From 90a96cabc98cea1ba539164987cad65051b0ec3b Mon Sep 17 00:00:00 2001 From: Jiahao XU Date: Sat, 11 Jun 2022 20:31:46 +1000 Subject: [PATCH] Rewrite `untar` to take a visitor & simplify signature of `download_and_extract_with_filter` Signed-off-by: Jiahao XU --- src/drivers.rs | 2 +- src/helpers.rs | 2 +- src/helpers/async_extracter.rs | 48 ++++++++++++++++++++++++++++---- src/helpers/extracter.rs | 50 ++++++++++++++-------------------- 4 files changed, 64 insertions(+), 38 deletions(-) diff --git a/src/drivers.rs b/src/drivers.rs index db520d4e..1c6a5eba 100644 --- a/src/drivers.rs +++ b/src/drivers.rs @@ -114,7 +114,7 @@ pub async fn fetch_crate_cratesio( Url::parse(&crate_url)?, TarBasedFmt::Tgz, &temp_dir, - Some(move |path: &Path| path == cargo_toml || path == main || path.starts_with(&bin)), + move |path: &Path| path == cargo_toml || path == main || path.starts_with(&bin), ) .await?; diff --git a/src/helpers.rs b/src/helpers.rs index 12cf35c7..bd2d904f 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -92,7 +92,7 @@ pub async fn download_and_extract_with_filter< url: Url, fmt: TarBasedFmt, path: P, - filter: Option, + filter: Filter, ) -> Result<(), BinstallError> { debug!("Downloading from: '{url}'"); diff --git a/src/helpers/async_extracter.rs b/src/helpers/async_extracter.rs index ac84ab16..5a4051ac 100644 --- a/src/helpers/async_extracter.rs +++ b/src/helpers/async_extracter.rs @@ -1,11 +1,14 @@ use std::fmt::Debug; use std::fs; -use std::io::{self, Seek, Write}; -use std::path::Path; +use std::io::{self, Read, Seek, Write}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; use bytes::Bytes; use futures_util::stream::{Stream, StreamExt}; +use log::debug; use scopeguard::{guard, ScopeGuard}; +use tar::Entries; use tempfile::tempfile; use tokio::{ sync::mpsc, @@ -209,17 +212,42 @@ pub async fn extract_tar_based_stream_with_filter< stream: impl Stream> + Unpin, output: &Path, fmt: TarBasedFmt, - filter: Option, + filter: Filter, ) -> Result<(), 
BinstallError> where BinstallError: From, { - let path = output.to_owned(); + struct Visitor(F, Arc); + + impl bool + Send + 'static> TarEntriesVisitor for Visitor { + fn visit(&mut self, entries: Entries<'_, R>) -> Result<(), BinstallError> { + for res in entries { + let mut entry = res?; + let entry_path = entry.path()?; + + if self.0(&entry_path) { + debug!("Extracting {entry_path:#?}"); + + let dst = self.1.join(entry_path); + + fs::create_dir_all(dst.parent().unwrap())?; + + entry.unpack(dst)?; + } + } + + Ok(()) + } + } + + let path = Arc::new(output.to_owned()); + + let visitor = Visitor(filter, path.clone()); extract_impl(stream, move |mut rx| { fs::create_dir_all(path.parent().unwrap())?; - extract_compressed_from_readable(ReadableRx::new(&mut rx), fmt, &path, filter) + extract_compressed_from_readable(ReadableRx::new(&mut rx), fmt, &*path, Some(visitor)) }) .await } @@ -232,12 +260,20 @@ pub async fn extract_tar_based_stream( where BinstallError: From, { + struct DummyVisitor; + + impl TarEntriesVisitor for DummyVisitor { + fn visit(&mut self, _entries: Entries<'_, R>) -> Result<(), BinstallError> { + unimplemented!() + } + } + let path = output.to_owned(); extract_impl(stream, move |mut rx| { fs::create_dir_all(path.parent().unwrap())?; - extract_compressed_from_readable:: bool, _>( + extract_compressed_from_readable::( ReadableRx::new(&mut rx), fmt, &path, diff --git a/src/helpers/extracter.rs b/src/helpers/extracter.rs index cda13ef8..969a1556 100644 --- a/src/helpers/extracter.rs +++ b/src/helpers/extracter.rs @@ -1,44 +1,34 @@ -use std::fs::{self, File}; +use std::fs::File; use std::io::{BufRead, Read}; use std::path::Path; use flate2::bufread::GzDecoder; use log::debug; -use tar::Archive; +use tar::{Archive, Entries}; use xz2::bufread::XzDecoder; use zip::read::ZipArchive; use zstd::stream::Decoder as ZstdDecoder; use crate::{BinstallError, TarBasedFmt}; -/// * `filter` - If Some, then it will pass the path of the file to it -/// and only extract 
ones which filter returns `true`. -/// Note that this is a best-effort and it only works when `fmt` -/// is not `PkgFmt::Bin` or `PkgFmt::Zip`. -fn untar bool>( - dat: impl Read, +pub trait TarEntriesVisitor { + fn visit(&mut self, entries: Entries<'_, R>) -> Result<(), BinstallError>; +} + +/// * `visitor` - If Some, then this function will pass +/// the entries of `dat` to it and let it decide +/// what to do with the tar. +fn untar( + dat: R, path: &Path, - filter: Option, + visitor: Option, ) -> Result<(), BinstallError> { let mut tar = Archive::new(dat); - if let Some(mut filter) = filter { + if let Some(mut visitor) = visitor { debug!("Untaring with filter"); - for res in tar.entries()? { - let mut entry = res?; - let entry_path = entry.path()?; - - if filter(&entry_path) { - debug!("Extracting {entry_path:#?}"); - - let dst = path.join(entry_path); - - fs::create_dir_all(dst.parent().unwrap())?; - - entry.unpack(dst)?; - } - } + visitor.visit(tar.entries()?)?; } else { debug!("Untaring entire tar"); tar.unpack(path)?; @@ -56,11 +46,11 @@ fn untar bool>( /// and only extract ones which filter returns `true`. /// Note that this is a best-effort and it only works when `fmt` /// is not `PkgFmt::Bin` or `PkgFmt::Zip`. -pub(crate) fn extract_compressed_from_readable bool, R: BufRead>( +pub(crate) fn extract_compressed_from_readable( dat: R, fmt: TarBasedFmt, path: &Path, - filter: Option, + visitor: Option, ) -> Result<(), BinstallError> { use TarBasedFmt::*; @@ -69,21 +59,21 @@ pub(crate) fn extract_compressed_from_readable bool, R: // Extract to install dir debug!("Extracting from tar archive to `{path:?}`"); - untar(dat, path, filter)? + untar(dat, path, visitor)? 
} Tgz => { // Extract to install dir debug!("Decompressing from tgz archive to `{path:?}`"); let tar = GzDecoder::new(dat); - untar(tar, path, filter)?; + untar(tar, path, visitor)?; } Txz => { // Extract to install dir debug!("Decompressing from txz archive to `{path:?}`"); let tar = XzDecoder::new(dat); - untar(tar, path, filter)?; + untar(tar, path, visitor)?; } Tzstd => { // Extract to install dir @@ -94,7 +84,7 @@ pub(crate) fn extract_compressed_from_readable bool, R: // as &[] by ZstdDecoder::new, thus ZstdDecoder::new // should not return any error. let tar = ZstdDecoder::with_buffer(dat)?; - untar(tar, path, filter)?; + untar(tar, path, visitor)?; } };