Rewrite AsyncExtracter: Extract fmt logic as callback fn

Signed-off-by: Jiahao XU <Jiahao_XU@outlook.com>
Author: Jiahao XU
Date:   2022-06-11 20:10:46 +10:00
Parent: d1033758a7
Commit: 7b52eaad5b
GPG key ID: 591C0B03040416D6
2 changed files with 128 additions and 108 deletions

Changed file 1 of 2:

@@ -1,3 +1,4 @@
+use std::fmt::Debug;
 use std::fs;
 use std::io::{self, Seek, Write};
 use std::path::Path;
@@ -22,28 +23,47 @@ pub(crate) enum Content {
     Abort,
 }
 
+/// AsyncExtracter will pass the `Bytes` you give to another thread via
+/// a `mpsc` and decompress and unpack it if needed.
+///
+/// After all write is done, you must call `AsyncExtracter::done`,
+/// otherwise the extracted content will be removed on drop.
+///
+/// # Advantages
+///
+/// `download_and_extract` has the following advantages over downloading
+/// plus extracting in on the same thread:
+///
+/// - The code is pipelined instead of storing the downloaded file in memory
+/// and extract it, except for `PkgFmt::Zip`, since `ZipArchiver::new`
+/// requires `std::io::Seek`, so it fallbacks to writing the a file then
+/// unzip it.
+/// - The async part (downloading) and the extracting part runs in parallel
+/// using `tokio::spawn_nonblocking`.
+/// - Compressing/writing which takes a lot of CPU time will not block
+/// the runtime anymore.
+/// - For any PkgFmt except for `PkgFmt::Zip` and `PkgFmt::Bin` (basically
+/// all `tar` based formats), it can extract only specified files.
+/// This means that `super::drivers::fetch_crate_cratesio` no longer need to
+/// extract the whole crate and write them to disk, it now only extract the
+/// relevant part (`Cargo.toml`) out to disk and open it.
 #[derive(Debug)]
-struct AsyncExtracterInner {
+struct AsyncExtracterInner<T> {
     /// Use AutoAbortJoinHandle so that the task
     /// will be cancelled on failure.
-    handle: JoinHandle<Result<(), BinstallError>>,
+    handle: JoinHandle<Result<T, BinstallError>>,
     tx: mpsc::Sender<Content>,
 }
 
-impl AsyncExtracterInner {
-    /// * `filter` - If Some, then it will pass the path of the file to it
-    /// and only extract ones which filter returns `true`.
-    /// Note that this is a best-effort and it only works when `fmt`
-    /// is not `PkgFmt::Bin` or `PkgFmt::Zip`.
-    fn new<Filter: FnMut(&Path) -> bool + Send + 'static>(
-        path: &Path,
-        fmt: PkgFmt,
-        filter: Option<Filter>,
+impl<T: Debug + Send + 'static> AsyncExtracterInner<T> {
+    fn new<F: FnOnce(mpsc::Receiver<Content>) -> Result<T, BinstallError> + Send + 'static>(
+        f: F,
     ) -> Self {
-        let path = path.to_owned();
-        let (tx, mut rx) = mpsc::channel::<Content>(100);
+        let (tx, rx) = mpsc::channel::<Content>(100);
 
         let handle = spawn_blocking(move || {
+            f(rx)
+            /*
             fs::create_dir_all(path.parent().unwrap())?;
 
             match fmt {
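
For readers skimming the diff, here is a minimal self-contained sketch of the pipelining idea the new doc comment describes: the async side feeds `Bytes` into a bounded mpsc channel while a `spawn_blocking` task drains it through a caller-supplied callback. All names are illustrative and the error handling is simplified; it only assumes the tokio and bytes crates, not anything from this repository.

    use std::io;

    use bytes::Bytes;
    use tokio::sync::mpsc;
    use tokio::task::spawn_blocking;

    // The async side feeds chunks into a bounded channel; the caller-supplied
    // callback drains it on a blocking thread, so CPU-heavy unpacking never
    // blocks the async runtime.
    async fn pipeline<T, F>(chunks: Vec<Bytes>, f: F) -> io::Result<T>
    where
        T: Send + 'static,
        F: FnOnce(mpsc::Receiver<Bytes>) -> io::Result<T> + Send + 'static,
    {
        let (tx, rx) = mpsc::channel::<Bytes>(100);

        // Blocking half: whatever the callback does (decompress, unpack, ...).
        let handle = spawn_blocking(move || f(rx));

        // Async half: feed the "downloaded" chunks as they arrive.
        for chunk in chunks {
            if tx.send(chunk).await.is_err() {
                break; // receiver gone: the blocking task already stopped
            }
        }
        drop(tx); // no more data; lets the callback's recv loop finish

        // Turn a panicked/cancelled task (JoinError) into an io::Error.
        match handle.await {
            Ok(res) => res,
            Err(join_err) => Err(io::Error::new(io::ErrorKind::Other, join_err)),
        }
    }

    #[tokio::main]
    async fn main() -> io::Result<()> {
        let data = vec![Bytes::from_static(b"hello "), Bytes::from_static(b"world")];
        let total = pipeline(data, |mut rx| {
            let mut n = 0;
            while let Some(bytes) = rx.blocking_recv() {
                n += bytes.len(); // stand-in for real unpacking work
            }
            Ok(n)
        })
        .await?;
        assert_eq!(total, 11);
        Ok(())
    }
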
@@ -78,29 +98,12 @@ impl AsyncExtracterInner {
             }
 
             Ok(())
+            */
         });
 
         Self { handle, tx }
     }
 
-    fn read_into_file(
-        file: &mut fs::File,
-        rx: &mut mpsc::Receiver<Content>,
-    ) -> Result<(), BinstallError> {
-        while let Some(content) = rx.blocking_recv() {
-            match content {
-                Content::Data(bytes) => file.write_all(&*bytes)?,
-                Content::Abort => {
-                    return Err(io::Error::new(io::ErrorKind::Other, "Aborted").into())
-                }
-            }
-        }
-
-        file.flush()?;
-
-        Ok(())
-    }
-
     /// Upon error, this extracter shall not be reused.
     /// Otherwise, `Self::done` would panic.
     async fn feed(&mut self, bytes: Bytes) -> Result<(), BinstallError> {
@@ -114,7 +117,7 @@ impl AsyncExtracterInner {
         }
     }
 
-    async fn done(mut self) -> Result<(), BinstallError> {
+    async fn done(mut self) -> Result<T, BinstallError> {
         // Drop tx as soon as possible so that the task would wrap up what it
         // was doing and flush out all the pending data.
        drop(self.tx);
@@ -122,7 +125,7 @@ impl AsyncExtracterInner {
         Self::wait(&mut self.handle).await
     }
 
-    async fn wait(handle: &mut JoinHandle<Result<(), BinstallError>>) -> Result<(), BinstallError> {
+    async fn wait(handle: &mut JoinHandle<Result<T, BinstallError>>) -> Result<T, BinstallError> {
         match handle.await {
             Ok(res) => res,
             Err(join_err) => Err(io::Error::new(io::ErrorKind::Other, join_err).into()),
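
Why `done()` can simply drop `tx` and then await the handle: once every `Sender` is gone, `blocking_recv` returns `None` after the buffered items are drained, so the blocking loop ends on its own. A tiny illustrative check (assumes tokio; not code from this repository):

    use tokio::sync::mpsc;
    use tokio::task::spawn_blocking;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<u32>(8);

        let handle = spawn_blocking(move || {
            let mut sum = 0;
            // Returns None as soon as all senders are dropped and the
            // buffered items have been drained.
            while let Some(n) = rx.blocking_recv() {
                sum += n;
            }
            sum
        });

        tx.send(1).await.unwrap();
        tx.send(2).await.unwrap();
        drop(tx); // the equivalent of `done()`: signal "no more data"

        assert_eq!(handle.await.unwrap(), 3);
    }
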
@@ -143,92 +146,98 @@ impl AsyncExtracterInner {
     }
 }
 
-/// AsyncExtracter will pass the `Bytes` you give to another thread via
-/// a `mpsc` and decompress and unpack it if needed.
-///
-/// After all write is done, you must call `AsyncExtracter::done`,
-/// otherwise the extracted content will be removed on drop.
-///
-/// # Advantages
-///
-/// `download_and_extract` has the following advantages over downloading
-/// plus extracting in on the same thread:
-///
-/// - The code is pipelined instead of storing the downloaded file in memory
-/// and extract it, except for `PkgFmt::Zip`, since `ZipArchiver::new`
-/// requires `std::io::Seek`, so it fallbacks to writing the a file then
-/// unzip it.
-/// - The async part (downloading) and the extracting part runs in parallel
-/// using `tokio::spawn_nonblocking`.
-/// - Compressing/writing which takes a lot of CPU time will not block
-/// the runtime anymore.
-/// - For any PkgFmt except for `PkgFmt::Zip` and `PkgFmt::Bin` (basically
-/// all `tar` based formats), it can extract only specified files.
-/// This means that `super::drivers::fetch_crate_cratesio` no longer need to
-/// extract the whole crate and write them to disk, it now only extract the
-/// relevant part (`Cargo.toml`) out to disk and open it.
-#[derive(Debug)]
-struct AsyncExtracter(ScopeGuard<AsyncExtracterInner, fn(AsyncExtracterInner), Always>);
-
-impl AsyncExtracter {
-    /// * `path` - If `fmt` is `PkgFmt::Bin`, then this is the filename
-    /// for the bin.
-    /// Otherwise, it is the directory where the extracted content will be put.
-    /// * `fmt` - The format of the archive to feed in.
-    /// * `filter` - If Some, then it will pass the path of the file to it
-    /// and only extract ones which filter returns `true`.
-    /// Note that this is a best-effort and it only works when `fmt`
-    /// is not `PkgFmt::Bin` or `PkgFmt::Zip`.
-    fn new<Filter: FnMut(&Path) -> bool + Send + 'static>(
-        path: &Path,
-        fmt: PkgFmt,
-        filter: Option<Filter>,
-    ) -> Self {
-        let inner = AsyncExtracterInner::new(path, fmt, filter);
-        Self(guard(inner, AsyncExtracterInner::abort))
-    }
-
-    /// Upon error, this extracter shall not be reused.
-    /// Otherwise, `Self::done` would panic.
-    async fn feed(&mut self, bytes: Bytes) -> Result<(), BinstallError> {
-        self.0.feed(bytes).await
-    }
-
-    async fn done(self) -> Result<(), BinstallError> {
-        ScopeGuard::into_inner(self.0).done().await
-    }
-}
+async fn extract_impl<
+    F: FnOnce(mpsc::Receiver<Content>) -> Result<T, BinstallError> + Send + 'static,
+    T: Debug + Send + 'static,
+    S: Stream<Item = Result<Bytes, E>> + Unpin,
+    E,
+>(
+    mut stream: S,
+    f: F,
+) -> Result<T, BinstallError>
+where
+    BinstallError: From<E>,
+{
+    let mut extracter = guard(AsyncExtracterInner::new(f), AsyncExtracterInner::abort);
+
+    while let Some(res) = stream.next().await {
+        extracter.feed(res?).await?;
+    }
+
+    ScopeGuard::into_inner(extracter).done().await
+}
+
+fn read_into_file(
+    file: &mut fs::File,
+    rx: &mut mpsc::Receiver<Content>,
+) -> Result<(), BinstallError> {
+    while let Some(content) = rx.blocking_recv() {
+        match content {
+            Content::Data(bytes) => file.write_all(&*bytes)?,
+            Content::Abort => return Err(io::Error::new(io::ErrorKind::Other, "Aborted").into()),
+        }
+    }
+
+    file.flush()?;
+
+    Ok(())
+}
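
The `guard(AsyncExtracterInner::new(f), AsyncExtracterInner::abort)` plus `ScopeGuard::into_inner` pair above is the scopeguard abort-on-early-return pattern: if a `?` bails out of the loop, the guard's drop handler cleans up; on success the guard is defused first. A small self-contained sketch of the same shape, with a stand-in `Worker` type instead of the real extracter (assumes the scopeguard, futures-util and tokio crates):

    use futures_util::{stream, Stream, StreamExt};
    use scopeguard::{guard, ScopeGuard};

    // Stand-in for the extracter in this hunk: `abort` must run if processing
    // bails out early, `done` only on the success path.
    struct Worker;

    impl Worker {
        fn abort(self) {
            println!("aborted: cleaning up partial output");
        }

        fn done(self) -> &'static str {
            "finished"
        }
    }

    async fn run(
        mut stream: impl Stream<Item = Result<u32, &'static str>> + Unpin,
    ) -> Result<&'static str, &'static str> {
        // If a `?` below returns early, the guard is dropped and `abort` runs.
        let worker = guard(Worker, Worker::abort);

        while let Some(item) = stream.next().await {
            let _n = item?;
        }

        // Success: defuse the guard so `abort` does NOT run, then finish up.
        Ok(ScopeGuard::into_inner(worker).done())
    }

    #[tokio::main]
    async fn main() {
        let ok: Vec<Result<u32, &'static str>> = vec![Ok(1), Ok(2)];
        assert_eq!(run(stream::iter(ok)).await, Ok("finished"));

        let bad: Vec<Result<u32, &'static str>> = vec![Ok(1), Err("boom")];
        assert_eq!(run(stream::iter(bad)).await, Err("boom"));
    }
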
 
 pub async fn extract_bin<E>(
-    mut stream: impl Stream<Item = Result<Bytes, E>> + Unpin,
+    stream: impl Stream<Item = Result<Bytes, E>> + Unpin,
     output: &Path,
 ) -> Result<(), BinstallError>
 where
     BinstallError: From<E>,
 {
-    let mut extracter = AsyncExtracter::new::<fn(&Path) -> bool>(output, PkgFmt::Bin, None);
-
-    while let Some(res) = stream.next().await {
-        extracter.feed(res?).await?;
-    }
-
-    extracter.done().await
+    let path = output.to_owned();
+    extract_impl(stream, move |mut rx| {
+        fs::create_dir_all(path.parent().unwrap())?;
+
+        let mut file = fs::File::create(&path)?;
+
+        // remove it unless the operation isn't aborted and no write
+        // fails.
+        let remove_guard = guard(&path, |path| {
+            fs::remove_file(path).ok();
+        });
+
+        read_into_file(&mut file, &mut rx)?;
+
+        // Operation isn't aborted and all writes succeed,
+        // disarm the remove_guard.
+        ScopeGuard::into_inner(remove_guard);
+
+        Ok(())
+    })
+    .await
 }
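
The `remove_guard` in the new `extract_bin` is the usual "delete the partially written file unless we reach the end" idiom. A standalone sketch of just that idiom, with illustrative names and paths (assumes the scopeguard crate):

    use std::fs;
    use std::io::{self, Write};
    use std::path::Path;

    use scopeguard::{guard, ScopeGuard};

    fn write_all_or_remove(path: &Path, chunks: &[&[u8]]) -> io::Result<()> {
        let mut file = fs::File::create(path)?;

        // If any write below fails, the early return drops this guard and the
        // half-written file is removed.
        let remove_guard = guard(path, |path| {
            let _ = fs::remove_file(path);
        });

        for chunk in chunks {
            file.write_all(chunk)?;
        }
        file.flush()?;

        // Every write succeeded: disarm the guard so the file is kept.
        let _ = ScopeGuard::into_inner(remove_guard);
        Ok(())
    }

    fn main() -> io::Result<()> {
        let path = std::env::temp_dir().join("binstall-demo.bin");
        write_all_or_remove(&path, &[b"partial ".as_slice(), b"download".as_slice()])?;
        assert!(path.exists());
        fs::remove_file(&path)?;
        Ok(())
    }
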
 
 pub async fn extract_zip<E>(
-    mut stream: impl Stream<Item = Result<Bytes, E>> + Unpin,
+    stream: impl Stream<Item = Result<Bytes, E>> + Unpin,
     output: &Path,
 ) -> Result<(), BinstallError>
 where
     BinstallError: From<E>,
 {
-    let mut extracter = AsyncExtracter::new::<fn(&Path) -> bool>(output, PkgFmt::Zip, None);
-
-    while let Some(res) = stream.next().await {
-        extracter.feed(res?).await?;
-    }
-
-    extracter.done().await
+    let path = output.to_owned();
+    extract_impl(stream, move |mut rx| {
+        fs::create_dir_all(path.parent().unwrap())?;
+
+        let mut file = tempfile()?;
+
+        read_into_file(&mut file, &mut rx)?;
+
+        // rewind it so that we can pass it to unzip
+        file.rewind()?;
+
+        unzip(file, &path)?;
+
+        Ok(())
+    })
+    .await
 }
 
 /// * `filter` - If Some, then it will pass the path of the file to it
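
`extract_zip` still has to spool to a temporary file first because a zip reader needs `Seek`, then rewind before unpacking. A minimal sketch of that spool-then-rewind step in isolation (assumes the tempfile crate; the actual unzip step is elided):

    use std::io::{self, Read, Seek, Write};

    use tempfile::tempfile;

    fn main() -> io::Result<()> {
        // Spool the "downloaded" bytes to an anonymous temporary file...
        let mut file = tempfile()?;
        file.write_all(b"pretend this is a .zip archive")?;
        file.flush()?;

        // ...then rewind so a Seek-based reader (e.g. a zip archive reader)
        // can start from the beginning.
        file.rewind()?;

        let mut contents = String::new();
        file.read_to_string(&mut contents)?;
        assert_eq!(contents, "pretend this is a .zip archive");
        Ok(())
    }
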
@@ -237,7 +246,7 @@ pub async fn extract_tar_based_stream_with_filter<
     Filter: FnMut(&Path) -> bool + Send + 'static,
     E,
 >(
-    mut stream: impl Stream<Item = Result<Bytes, E>> + Unpin,
+    stream: impl Stream<Item = Result<Bytes, E>> + Unpin,
     output: &Path,
     fmt: TarBasedFmt,
     filter: Option<Filter>,
@@ -245,28 +254,39 @@
 where
     BinstallError: From<E>,
 {
-    let mut extracter = AsyncExtracter::new(output, fmt.into(), filter);
-
-    while let Some(res) = stream.next().await {
-        extracter.feed(res?).await?;
-    }
-
-    extracter.done().await
+    let path = output.to_owned();
+    extract_impl(stream, move |mut rx| {
+        fs::create_dir_all(path.parent().unwrap())?;
+
+        extract_compressed_from_readable(ReadableRx::new(&mut rx), fmt.into(), &path, filter)?;
+
+        Ok(())
+    })
+    .await
 }
 
 pub async fn extract_tar_based_stream<E>(
-    mut stream: impl Stream<Item = Result<Bytes, E>> + Unpin,
+    stream: impl Stream<Item = Result<Bytes, E>> + Unpin,
     output: &Path,
     fmt: TarBasedFmt,
 ) -> Result<(), BinstallError>
 where
     BinstallError: From<E>,
 {
-    let mut extracter = AsyncExtracter::new::<fn(&Path) -> bool>(output, fmt.into(), None);
-
-    while let Some(res) = stream.next().await {
-        extracter.feed(res?).await?;
-    }
-
-    extracter.done().await
+    let path = output.to_owned();
+    extract_impl(stream, move |mut rx| {
+        fs::create_dir_all(path.parent().unwrap())?;
+
+        extract_compressed_from_readable::<fn(&Path) -> bool, _>(
+            ReadableRx::new(&mut rx),
+            fmt.into(),
+            &path,
+            None,
+        )?;
+
+        Ok(())
+    })
+    .await
 }
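
The `filter` threaded through to `extract_compressed_from_readable` is what lets tar-based formats unpack only selected entries, e.g. just `Cargo.toml`, as the doc comment mentions for `fetch_crate_cratesio`. Here is a self-contained sketch of filtered unpacking with the tar crate; the helper name and the in-memory archive are purely illustrative and not taken from this repository:

    use std::io::{Cursor, Read};
    use std::path::Path;

    use tar::{Archive, Builder, Header};

    // Unpack only entries for which `filter` returns true, the same idea used
    // for pulling just `Cargo.toml` out of a crate tarball.
    fn untar_filtered<R: Read, F: FnMut(&Path) -> bool>(
        dat: R,
        dst: &Path,
        mut filter: F,
    ) -> std::io::Result<()> {
        let mut archive = Archive::new(dat);
        for entry in archive.entries()? {
            let mut entry = entry?;
            let keep = filter(&entry.path()?);
            if keep {
                entry.unpack_in(dst)?;
            }
        }
        Ok(())
    }

    fn main() -> std::io::Result<()> {
        // Build a tiny in-memory tar with two files.
        let mut builder = Builder::new(Vec::new());
        for (name, data) in [("Cargo.toml", "[package]"), ("src/main.rs", "fn main() {}")] {
            let mut header = Header::new_gnu();
            header.set_size(data.len() as u64);
            header.set_mode(0o644);
            header.set_cksum();
            builder.append_data(&mut header, name, data.as_bytes())?;
        }
        let tar_bytes = builder.into_inner()?;

        // Extract only Cargo.toml.
        let dst = std::env::temp_dir().join("filtered-untar-demo");
        std::fs::create_dir_all(&dst)?;
        untar_filtered(Cursor::new(tar_bytes), &dst, |path| {
            path.file_name() == Some("Cargo.toml".as_ref())
        })?;

        assert!(dst.join("Cargo.toml").exists());
        assert!(!dst.join("src/main.rs").exists());
        Ok(())
    }
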

Changed file 2 of 2:

@@ -56,8 +56,8 @@ fn untar<Filter: FnMut(&Path) -> bool>(
 /// and only extract ones which filter returns `true`.
 /// Note that this is a best-effort and it only works when `fmt`
 /// is not `PkgFmt::Bin` or `PkgFmt::Zip`.
-pub(crate) fn extract_compressed_from_readable<Filter: FnMut(&Path) -> bool>(
-    dat: impl BufRead,
+pub(crate) fn extract_compressed_from_readable<Filter: FnMut(&Path) -> bool, R: BufRead>(
+    dat: R,
     fmt: PkgFmt,
     path: &Path,
     filter: Option<Filter>,
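
This second file swaps `dat: impl BufRead` for a named `R: BufRead` type parameter. That is what makes the turbofish call `extract_compressed_from_readable::<fn(&Path) -> bool, _>(...)` in the first file possible, since explicit generic arguments cannot be supplied to a function that uses `impl Trait` in argument position. A small illustration with a hypothetical function, not code from this repository:

    use std::io::BufRead;
    use std::path::Path;

    // With `impl BufRead` as the argument type, the function would have an
    // anonymous type parameter that callers cannot name, so the turbofish
    // below would be rejected. Naming it `R` allows writing
    // `demo::<fn(&Path) -> bool, _>(...)`, spelling out only `Filter`.
    fn demo<Filter: FnMut(&Path) -> bool, R: BufRead>(dat: R, filter: Option<Filter>) -> usize {
        let _ = filter;
        dat.lines().count()
    }

    fn main() {
        let input = std::io::Cursor::new("a\nb\nc\n");
        // `Filter` cannot be inferred from `None`, hence the turbofish.
        let n = demo::<fn(&Path) -> bool, _>(input, None);
        assert_eq!(n, 3);
    }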