Make extraction cancellable for bin and tar based formats (#481)

Extraction wasn't cancellable by `cancel_on_user_sig_term` used in `entry` since it calls `block_in_place`.

This PR adds cancellation support to it by adding a `static` variable `OnceCell` to `wait_on_cancellation_signal` so that once it returns `Ok(())`, all other calls to it after that point also returns `Ok(())` immediately.

`StreamReadable`, which is used in cancellation process, then stores a boxed future of `wait_on_cancellation_signal` and polled it in `BufReader::fill_buf`.

Note that for zip, the extraction process takes `File` instead of `StreamReadable` due to `io::Seek` requirement, so it cancelling during extraction for zip is still not possible.

This PR also optimized `extract_bin` and `extract_zip` by using `StreamReadable::copy` introduced to this PR instead of `io::copy`, which allocates an internal buffer on stack, which imposes extra copy.

It also fixed `StreamReadable::fill_buf` by ensuring that empty buffer is only returned on eof.

* Make `wait_on_cancellation_signal` pub
* Enable feature `parking_lot` of dep tokio
* Mod `wait_on_cancellation_signal`: Use `OnceCell` internally
   to archive the effect that once call to it return `Ok(())`, all calls to
   it after that also returns `Ok(())`.
* Impl `From<BinstallError>` for `io::Error`
* Impl cancellation on user signal in `StreamReadable`
* Fix err msg when cancelling during extraction in `ops::resolve`
* Optimize: Impl & use `StreamReadable::copy`
   which is same as `io::copy` but does not allocate any internal buffer
   since `StreamReadable` is buffered.
* Fix `next_stream`: Return non-empty bytes on `Ok(Some(bytes))`

Signed-off-by: Jiahao XU <Jiahao_XU@outlook.com>
This commit is contained in:
Jiahao XU 2022-10-13 11:31:13 +11:00 committed by GitHub
parent fdc617d870
commit aa6012baae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 109 additions and 20 deletions

1
Cargo.lock generated
View file

@ -2042,6 +2042,7 @@ dependencies = [
"memchr", "memchr",
"mio", "mio",
"num_cpus", "num_cpus",
"parking_lot",
"pin-project-lite", "pin-project-lite",
"signal-hook-registry", "signal-hook-registry",
"socket2", "socket2",

View file

@ -46,7 +46,8 @@ tar = { package = "binstall-tar", version = "0.4.39" }
tempfile = "3.3.0" tempfile = "3.3.0"
thiserror = "1.0.37" thiserror = "1.0.37"
tinytemplate = "1.2.1" tinytemplate = "1.2.1"
tokio = { version = "1.21.2", features = ["macros", "rt", "process", "sync", "signal", "time"], default-features = false } # parking_lot - for OnceCell::const_new
tokio = { version = "1.21.2", features = ["macros", "rt", "process", "sync", "signal", "time", "parking_lot"], default-features = false }
toml_edit = { version = "0.14.4", features = ["easy"] } toml_edit = { version = "0.14.4", features = ["easy"] }
tower = { version = "0.4.13", features = ["limit", "util"] } tower = { version = "0.4.13", features = ["limit", "util"] }
trust-dns-resolver = { version = "0.21.2", optional = true, default-features = false, features = ["dnssec-ring"] } trust-dns-resolver = { version = "0.21.2", optional = true, default-features = false, features = ["dnssec-ring"] }

View file

@ -1,4 +1,5 @@
use std::{ use std::{
io,
path::PathBuf, path::PathBuf,
process::{ExitCode, ExitStatus, Termination}, process::{ExitCode, ExitStatus, Termination},
}; };
@ -99,7 +100,7 @@ pub enum BinstallError {
/// - Exit: 74 /// - Exit: 74
#[error(transparent)] #[error(transparent)]
#[diagnostic(severity(error), code(binstall::io))] #[diagnostic(severity(error), code(binstall::io))]
Io(std::io::Error), Io(io::Error),
/// An error interacting with the crates.io API. /// An error interacting with the crates.io API.
/// ///
@ -392,8 +393,8 @@ impl Termination for BinstallError {
} }
} }
impl From<std::io::Error> for BinstallError { impl From<io::Error> for BinstallError {
fn from(err: std::io::Error) -> Self { fn from(err: io::Error) -> Self {
if err.get_ref().is_some() { if err.get_ref().is_some() {
let kind = err.kind(); let kind = err.kind();
@ -404,9 +405,18 @@ impl From<std::io::Error> for BinstallError {
inner inner
.downcast() .downcast()
.map(|b| *b) .map(|b| *b)
.unwrap_or_else(|err| BinstallError::Io(std::io::Error::new(kind, err))) .unwrap_or_else(|err| BinstallError::Io(io::Error::new(kind, err)))
} else { } else {
BinstallError::Io(err) BinstallError::Io(err)
} }
} }
} }
impl From<BinstallError> for io::Error {
fn from(e: BinstallError) -> io::Error {
match e {
BinstallError::Io(io_error) => io_error,
e => io::Error::new(io::ErrorKind::Other, e),
}
}
}

View file

@ -1,7 +1,7 @@
use std::{ use std::{
fmt::Debug, fmt::Debug,
fs, fs,
io::{copy, Read, Seek}, io::{Read, Seek},
path::Path, path::Path,
}; };
@ -33,7 +33,7 @@ where
fs::remove_file(path).ok(); fs::remove_file(path).ok();
}); });
copy(&mut reader, &mut file)?; reader.copy(&mut file)?;
// Operation isn't aborted and all writes succeed, // Operation isn't aborted and all writes succeed,
// disarm the remove_guard. // disarm the remove_guard.
@ -54,7 +54,7 @@ where
let mut file = tempfile()?; let mut file = tempfile()?;
copy(&mut reader, &mut file)?; reader.copy(&mut file)?;
// rewind it so that we can pass it to unzip // rewind it so that we can pass it to unzip
file.rewind()?; file.rewind()?;

View file

@ -1,24 +1,26 @@
use std::{ use std::{
cmp::min, cmp::min,
io::{self, BufRead, Read}, future::Future,
io::{self, BufRead, Read, Write},
pin::Pin,
}; };
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use futures_util::stream::{Stream, StreamExt}; use futures_util::stream::{Stream, StreamExt};
use tokio::runtime::Handle; use tokio::runtime::Handle;
use crate::errors::BinstallError; use crate::{errors::BinstallError, helpers::signal::wait_on_cancellation_signal};
/// This wraps an AsyncIterator as a `Read`able. /// This wraps an AsyncIterator as a `Read`able.
/// It must be used in non-async context only, /// It must be used in non-async context only,
/// meaning you have to use it with /// meaning you have to use it with
/// `tokio::task::{block_in_place, spawn_blocking}` or /// `tokio::task::{block_in_place, spawn_blocking}` or
/// `std::thread::spawn`. /// `std::thread::spawn`.
#[derive(Debug)]
pub struct StreamReadable<S> { pub struct StreamReadable<S> {
stream: S, stream: S,
handle: Handle, handle: Handle,
bytes: Bytes, bytes: Bytes,
cancellation_future: Pin<Box<dyn Future<Output = Result<(), io::Error>> + Send>>,
} }
impl<S> StreamReadable<S> { impl<S> StreamReadable<S> {
@ -27,6 +29,39 @@ impl<S> StreamReadable<S> {
stream, stream,
handle: Handle::current(), handle: Handle::current(),
bytes: Bytes::new(), bytes: Bytes::new(),
cancellation_future: Box::pin(wait_on_cancellation_signal()),
}
}
}
impl<S, E> StreamReadable<S>
where
S: Stream<Item = Result<Bytes, E>> + Unpin,
BinstallError: From<E>,
{
/// Copies from `self` to `writer`.
///
/// Same as `io::copy` but does not allocate any internal buffer
/// since `self` is buffered.
pub(super) fn copy<W>(&mut self, mut writer: W) -> io::Result<()>
where
W: Write,
{
self.copy_inner(&mut writer)
}
fn copy_inner(&mut self, writer: &mut dyn Write) -> io::Result<()> {
loop {
let buf = self.fill_buf()?;
if buf.is_empty() {
// Eof
break Ok(());
}
writer.write_all(buf)?;
let n = buf.len();
self.consume(n);
} }
} }
} }
@ -56,6 +91,27 @@ where
Ok(n) Ok(n)
} }
} }
/// If `Ok(Some(bytes))` if returned, then `bytes.is_empty() == false`.
async fn next_stream<S, E>(stream: &mut S) -> io::Result<Option<Bytes>>
where
S: Stream<Item = Result<Bytes, E>> + Unpin,
BinstallError: From<E>,
{
loop {
let option = stream
.next()
.await
.transpose()
.map_err(BinstallError::from)?;
match option {
Some(bytes) if bytes.is_empty() => continue,
option => break Ok(option),
}
}
}
impl<S, E> BufRead for StreamReadable<S> impl<S, E> BufRead for StreamReadable<S>
where where
S: Stream<Item = Result<Bytes, E>> + Unpin, S: Stream<Item = Result<Bytes, E>> + Unpin,
@ -65,13 +121,18 @@ where
let bytes = &mut self.bytes; let bytes = &mut self.bytes;
if !bytes.has_remaining() { if !bytes.has_remaining() {
match self.handle.block_on(async { self.stream.next().await }) { let option = self.handle.block_on(async {
Some(Ok(new_bytes)) => *bytes = new_bytes, tokio::select! {
Some(Err(e)) => { res = next_stream(&mut self.stream) => res,
let e: BinstallError = e.into(); res = self.cancellation_future.as_mut() => {
return Err(io::Error::new(io::ErrorKind::Other, e)); Err(res.err().unwrap_or_else(|| io::Error::from(BinstallError::UserAbort)))
},
} }
None => (), })?;
if let Some(new_bytes) = option {
// new_bytes are guaranteed to be non-empty.
*bytes = new_bytes;
} }
} }
Ok(&*bytes) Ok(&*bytes)

View file

@ -1,7 +1,7 @@
use std::io; use std::io;
use futures_util::future::pending; use futures_util::future::pending;
use tokio::signal; use tokio::{signal, sync::OnceCell};
use super::tasks::AutoAbortJoinHandle; use super::tasks::AutoAbortJoinHandle;
use crate::errors::BinstallError; use crate::errors::BinstallError;
@ -24,12 +24,25 @@ pub async fn cancel_on_user_sig_term<T>(
tokio::select! { tokio::select! {
res = handle => res, res = handle => res,
res = wait_on_cancellation_signal() => { res = wait_on_cancellation_signal() => {
res.map_err(BinstallError::Io).and(Err(BinstallError::UserAbort)) res
.map_err(BinstallError::Io)
.and(Err(BinstallError::UserAbort))
} }
} }
} }
async fn wait_on_cancellation_signal() -> Result<(), io::Error> { /// If call to it returns `Ok(())`, then all calls to this function after
/// that also returns `Ok(())`.
pub async fn wait_on_cancellation_signal() -> Result<(), io::Error> {
static CANCELLED: OnceCell<()> = OnceCell::const_new();
CANCELLED
.get_or_try_init(wait_on_cancellation_signal_inner)
.await
.copied()
}
async fn wait_on_cancellation_signal_inner() -> Result<(), io::Error> {
#[cfg(unix)] #[cfg(unix)]
async fn inner() -> Result<(), io::Error> { async fn inner() -> Result<(), io::Error> {
unix::wait_on_cancellation_signal_unix().await unix::wait_on_cancellation_signal_unix().await

View file

@ -260,6 +260,9 @@ async fn resolve_inner(
} }
} }
Err(err) => { Err(err) => {
if let BinstallError::UserAbort = err {
return Err(err);
}
warn!( warn!(
"Error while downloading and extracting from fetcher {}: {}", "Error while downloading and extracting from fetcher {}: {}",
fetcher.source_name(), fetcher.source_name(),