diff --git a/.gitignore b/.gitignore index ea8c4bf7..05923927 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +.DS_Store diff --git a/src/drivers.rs b/src/drivers.rs index b16513d9..933a6d2f 100644 --- a/src/drivers.rs +++ b/src/drivers.rs @@ -1,126 +1,12 @@ -use std::collections::BTreeSet; use std::path::{Path, PathBuf}; -use std::time::Duration; -use crates_io_api::AsyncClient; -use log::debug; -use semver::{Version, VersionReq}; -use url::Url; +use crate::BinstallError; -use crate::{helpers::*, BinstallError, PkgFmt}; +mod version; +use version::find_version; -fn find_version<'a, V: Iterator>( - requirement: &str, - version_iter: V, -) -> Result { - // Parse version requirement - let version_req = VersionReq::parse(requirement).map_err(|err| BinstallError::VersionReq { - req: requirement.into(), - err, - })?; - - // Filter for matching versions - let filtered: BTreeSet<_> = version_iter - .filter_map(|v| { - // Remove leading `v` for git tags - let ver_str = match v.strip_prefix('s') { - Some(v) => v, - None => v, - }; - - // Parse out version - let ver = Version::parse(ver_str).ok()?; - debug!("Version: {:?}", ver); - - // Filter by version match - if version_req.matches(&ver) { - Some(ver) - } else { - None - } - }) - .collect(); - - debug!("Filtered: {:?}", filtered); - - // Return highest version - filtered - .iter() - .max() - .cloned() - .ok_or(BinstallError::VersionMismatch { req: version_req }) -} - -/// Fetch a crate Cargo.toml by name and version from crates.io -pub async fn fetch_crate_cratesio( - name: &str, - version_req: &str, - temp_dir: &Path, -) -> Result { - // Fetch / update index - debug!("Looking up crate information"); - - // Build crates.io api client - let api_client = AsyncClient::new( - "cargo-binstall (https://github.com/ryankurte/cargo-binstall)", - Duration::from_millis(100), - ) - .expect("bug: invalid user agent"); - - // Fetch online crate information - let base_info = - api_client - .get_crate(name.as_ref()) - 
.await - .map_err(|err| BinstallError::CratesIoApi { - crate_name: name.into(), - err, - })?; - - // Locate matching version - let version_iter = - base_info - .versions - .iter() - .filter_map(|v| if !v.yanked { Some(&v.num) } else { None }); - let version_name = find_version(version_req, version_iter)?; - - // Fetch information for the filtered version - let version = base_info - .versions - .iter() - .find(|v| v.num == version_name.to_string()) - .ok_or_else(|| BinstallError::VersionUnavailable { - crate_name: name.into(), - v: version_name.clone(), - })?; - - debug!("Found information for crate version: '{}'", version.num); - - // Download crate to temporary dir (crates.io or git?) - let crate_url = format!("https://crates.io/{}", version.dl_path); - - debug!("Fetching crate from: {crate_url} and extracting Cargo.toml from it"); - - let crate_dir: PathBuf = format!("{name}-{version_name}").into(); - let crate_path = temp_dir.join(&crate_dir); - - let cargo_toml = crate_dir.join("Cargo.toml"); - let src = crate_dir.join("src"); - let main = src.join("main.rs"); - let bin = src.join("bin"); - - download_and_extract_with_filter( - Url::parse(&crate_url)?, - PkgFmt::Tgz, - &temp_dir, - Some(move |path: &Path| path == cargo_toml || path == main || path.starts_with(&bin)), - ) - .await?; - - // Return crate directory - Ok(crate_path) -} +mod crates_io; +pub use crates_io::fetch_crate_cratesio; /// Fetch a crate by name and version from github /// TODO: implement this diff --git a/src/drivers/crates_io.rs b/src/drivers/crates_io.rs new file mode 100644 index 00000000..68f04e82 --- /dev/null +++ b/src/drivers/crates_io.rs @@ -0,0 +1,76 @@ +use std::path::PathBuf; +use std::time::Duration; + +use cargo_toml::Manifest; +use crates_io_api::AsyncClient; +use log::debug; +use url::Url; + +use super::find_version; +use crate::{helpers::*, BinstallError, Meta, TarBasedFmt}; + +mod vfs; + +mod visitor; +use visitor::ManifestVisitor; + +/// Fetch a crate Cargo.toml by name and 
version from crates.io +pub async fn fetch_crate_cratesio( + name: &str, + version_req: &str, +) -> Result, BinstallError> { + // Fetch / update index + debug!("Looking up crate information"); + + // Build crates.io api client + let api_client = AsyncClient::new( + "cargo-binstall (https://github.com/ryankurte/cargo-binstall)", + Duration::from_millis(100), + ) + .expect("bug: invalid user agent"); + + // Fetch online crate information + let base_info = + api_client + .get_crate(name.as_ref()) + .await + .map_err(|err| BinstallError::CratesIoApi { + crate_name: name.into(), + err, + })?; + + // Locate matching version + let version_iter = + base_info + .versions + .iter() + .filter_map(|v| if !v.yanked { Some(&v.num) } else { None }); + let version_name = find_version(version_req, version_iter)?; + + // Fetch information for the filtered version + let version = base_info + .versions + .iter() + .find(|v| v.num == version_name.to_string()) + .ok_or_else(|| BinstallError::VersionUnavailable { + crate_name: name.into(), + v: version_name.clone(), + })?; + + debug!("Found information for crate version: '{}'", version.num); + + // Download crate to temporary dir (crates.io or git?) + let crate_url = format!("https://crates.io/{}", version.dl_path); + + debug!("Fetching crate from: {crate_url} and extracting Cargo.toml from it"); + + let manifest_dir_path: PathBuf = format!("{name}-{version_name}").into(); + + download_tar_based_and_visit( + Url::parse(&crate_url)?, + TarBasedFmt::Tgz, + ManifestVisitor::new(manifest_dir_path), + ) + .await? 
+ .load_manifest() +} diff --git a/src/drivers/crates_io/vfs.rs b/src/drivers/crates_io/vfs.rs new file mode 100644 index 00000000..66e4875e --- /dev/null +++ b/src/drivers/crates_io/vfs.rs @@ -0,0 +1,52 @@ +use std::collections::{hash_map::HashMap, hash_set::HashSet}; +use std::io; +use std::path::Path; + +use cargo_toml::AbstractFilesystem; + +use crate::helpers::PathExt; + +/// This type stores the filesystem structure for the crate tarball +/// extracted in memory and can be passed to +/// `cargo_toml::Manifest::complete_from_abstract_filesystem`. +#[derive(Debug)] +pub(super) struct Vfs(HashMap, HashSet>>); + +impl Vfs { + pub(super) fn new() -> Self { + Self(HashMap::with_capacity(16)) + } + + /// * `path` - must be canonical, must not be empty. + pub(super) fn add_path(&mut self, mut path: &Path) { + while let Some(parent) = path.parent() { + // Since path has parent, it must have a filename + let filename = path.file_name().unwrap(); + + // `cargo_toml`'s implementation does the same thing. 
+ // https://docs.rs/cargo_toml/0.11.5/src/cargo_toml/afs.rs.html#24 + let filename = filename.to_string_lossy(); + + self.0 + .entry(parent.into()) + .or_insert_with(|| HashSet::with_capacity(4)) + .insert(filename.into()); + + path = parent; + } + } +} + +impl AbstractFilesystem for Vfs { + fn file_names_in(&self, rel_path: &str) -> io::Result>> { + let rel_path = Path::new(rel_path).normalize_path(); + + Ok(self.0.get(&*rel_path).map(Clone::clone).unwrap_or_default()) + } +} + +impl AbstractFilesystem for &Vfs { + fn file_names_in(&self, rel_path: &str) -> io::Result>> { + (*self).file_names_in(rel_path) + } +} diff --git a/src/drivers/crates_io/visitor.rs b/src/drivers/crates_io/visitor.rs new file mode 100644 index 00000000..3f354771 --- /dev/null +++ b/src/drivers/crates_io/visitor.rs @@ -0,0 +1,81 @@ +use std::io::Read; +use std::path::{Path, PathBuf}; + +use cargo_toml::Manifest; +use log::debug; +use tar::Entries; + +use super::vfs::Vfs; +use crate::{ + helpers::{PathExt, TarEntriesVisitor}, + BinstallError, Meta, +}; + +#[derive(Debug)] +pub(super) struct ManifestVisitor { + cargo_toml_content: Vec, + /// manifest_dir_path is treated as the current dir. + manifest_dir_path: PathBuf, + + vfs: Vfs, +} + +impl ManifestVisitor { + pub(super) fn new(manifest_dir_path: PathBuf) -> Self { + Self { + // Cargo.toml is quite large usually. + cargo_toml_content: Vec::with_capacity(2000), + manifest_dir_path, + vfs: Vfs::new(), + } + } + + /// Load binstall metadata using the extracted information stored in memory. 
+ pub(super) fn load_manifest(&self) -> Result, BinstallError> { + debug!("Loading manifest directly from extracted file"); + + // Load and parse manifest + let mut manifest = Manifest::::from_slice_with_metadata(&self.cargo_toml_content)?; + + // Checks vfs for binary output names + manifest.complete_from_abstract_filesystem(&self.vfs)?; + + // Return metadata + Ok(manifest) + } +} + +impl TarEntriesVisitor for ManifestVisitor { + fn visit(&mut self, entries: Entries<'_, R>) -> Result<(), BinstallError> { + for res in entries { + let mut entry = res?; + let path = entry.path()?; + let path = path.normalize_path(); + + let path = if let Ok(path) = path.strip_prefix(&self.manifest_dir_path) { + path + } else { + // The path is outside of the curr dir (manifest dir), + // ignore it. + continue; + }; + + if path == Path::new("Cargo.toml") + || path == Path::new("src/main.rs") + || path.starts_with("src/bin") + { + self.vfs.add_path(path); + } + + if path == Path::new("Cargo.toml") { + // Since it is possible for the same Cargo.toml to appear + // multiple times using `tar --keep-old-files`, here we + // clear the buffer first before reading into it. 
+ self.cargo_toml_content.clear(); + entry.read_to_end(&mut self.cargo_toml_content)?; + } + } + + Ok(()) + } +} diff --git a/src/drivers/version.rs b/src/drivers/version.rs new file mode 100644 index 00000000..7d5f4a74 --- /dev/null +++ b/src/drivers/version.rs @@ -0,0 +1,48 @@ +use std::collections::BTreeSet; + +use log::debug; +use semver::{Version, VersionReq}; + +use crate::BinstallError; + +pub(super) fn find_version<'a, V: Iterator>( + requirement: &str, + version_iter: V, +) -> Result { + // Parse version requirement + let version_req = VersionReq::parse(requirement).map_err(|err| BinstallError::VersionReq { + req: requirement.into(), + err, + })?; + + // Filter for matching versions + let filtered: BTreeSet<_> = version_iter + .filter_map(|v| { + // Remove leading `v` for git tags + let ver_str = match v.strip_prefix('v') { + Some(v) => v, + None => v, + }; + + // Parse out version + let ver = Version::parse(ver_str).ok()?; + debug!("Version: {:?}", ver); + + // Filter by version match + if version_req.matches(&ver) { + Some(ver) + } else { + None + } + }) + .collect(); + + debug!("Filtered: {:?}", filtered); + + // Return highest version + filtered + .iter() + .max() + .cloned() + .ok_or(BinstallError::VersionMismatch { req: version_req }) +} diff --git a/src/errors.rs b/src/errors.rs index 070c91ce..4c7e5cad 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -75,7 +75,7 @@ pub enum BinstallError { /// - Exit: 74 #[error(transparent)] #[diagnostic(severity(error), code(binstall::io))] - Io(#[from] std::io::Error), + Io(std::io::Error), /// An error interacting with the crates.io API. 
/// @@ -231,3 +231,22 @@ impl Termination for BinstallError { code } } + +impl From for BinstallError { + fn from(err: std::io::Error) -> Self { + if err.get_ref().is_some() { + let kind = err.kind(); + + let inner = err + .into_inner() + .expect("err.get_ref() returns Some, so err.into_inner() should also return Some"); + + inner + .downcast() + .map(|b| *b) + .unwrap_or_else(|err| BinstallError::Io(std::io::Error::new(kind, err))) + } else { + BinstallError::Io(err) + } + } +} diff --git a/src/format.rs b/src/format.rs new file mode 100644 index 00000000..840ffc6a --- /dev/null +++ b/src/format.rs @@ -0,0 +1,74 @@ +use serde::{Deserialize, Serialize}; +use strum_macros::{Display, EnumString, EnumVariantNames}; + +/// Binary format enumeration +#[derive( + Debug, Copy, Clone, PartialEq, Serialize, Deserialize, Display, EnumString, EnumVariantNames, +)] +#[strum(serialize_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum PkgFmt { + /// Download format is TAR (uncompressed) + Tar, + /// Download format is TGZ (TAR + GZip) + Tgz, + /// Download format is TAR + XZ + Txz, + /// Download format is TAR + Zstd + Tzstd, + /// Download format is Zip + Zip, + /// Download format is raw / binary + Bin, +} + +impl Default for PkgFmt { + fn default() -> Self { + Self::Tgz + } +} + +impl PkgFmt { + /// Decompose self into a `PkgFmtDecomposed`: `Tar(..)` for the tar based formats, + /// `Bin` or `Zip` otherwise. 
+ pub fn decompose(self) -> PkgFmtDecomposed { + match self { + PkgFmt::Tar => PkgFmtDecomposed::Tar(TarBasedFmt::Tar), + PkgFmt::Tgz => PkgFmtDecomposed::Tar(TarBasedFmt::Tgz), + PkgFmt::Txz => PkgFmtDecomposed::Tar(TarBasedFmt::Txz), + PkgFmt::Tzstd => PkgFmtDecomposed::Tar(TarBasedFmt::Tzstd), + PkgFmt::Bin => PkgFmtDecomposed::Bin, + PkgFmt::Zip => PkgFmtDecomposed::Zip, + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum PkgFmtDecomposed { + Tar(TarBasedFmt), + Bin, + Zip, +} + +#[derive(Debug, Display, Copy, Clone, PartialEq)] +pub enum TarBasedFmt { + /// Download format is TAR (uncompressed) + Tar, + /// Download format is TGZ (TAR + GZip) + Tgz, + /// Download format is TAR + XZ + Txz, + /// Download format is TAR + Zstd + Tzstd, +} + +impl From for PkgFmt { + fn from(fmt: TarBasedFmt) -> Self { + match fmt { + TarBasedFmt::Tar => PkgFmt::Tar, + TarBasedFmt::Tgz => PkgFmt::Tgz, + TarBasedFmt::Txz => PkgFmt::Txz, + TarBasedFmt::Tzstd => PkgFmt::Tzstd, + } + } +} diff --git a/src/helpers.rs b/src/helpers.rs index 5d6d350a..55aee08c 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -1,18 +1,19 @@ -use std::{ - path::{Path, PathBuf}, -}; +use std::fmt::Debug; +use std::path::{Path, PathBuf}; +use bytes::Bytes; use cargo_toml::Manifest; +use futures_util::stream::Stream; use log::debug; -use reqwest::Method; +use reqwest::{Method, Response}; use serde::Serialize; use tinytemplate::TinyTemplate; use url::Url; -use crate::{BinstallError, Meta, PkgFmt}; +use crate::{BinstallError, Meta, PkgFmt, PkgFmtDecomposed, TarBasedFmt}; mod async_extracter; -pub use async_extracter::extract_archive_stream; +pub use async_extracter::*; mod auto_abort_join_handle; pub use auto_abort_join_handle::AutoAbortJoinHandle; @@ -21,7 +22,10 @@ mod ui_thread; pub use ui_thread::UIThread; mod extracter; -mod readable_rx; +mod stream_readable; + +mod path_ext; +pub use path_ext::*; /// Load binstall metadata from the crate `Cargo.toml` at the provided path pub fn 
load_manifest_path>( @@ -45,13 +49,42 @@ pub async fn remote_exists(url: Url, method: Method) -> Result Result>, BinstallError> { + debug!("Downloading from: '{url}'"); + + reqwest::get(url.clone()) + .await + .and_then(|r| r.error_for_status()) + .map_err(|err| BinstallError::Http { + method: Method::GET, + url, + err, + }) + .map(Response::bytes_stream) +} + /// Download a file from the provided URL and extract it to the provided path. pub async fn download_and_extract>( url: Url, fmt: PkgFmt, path: P, ) -> Result<(), BinstallError> { - download_and_extract_with_filter:: bool, _>(url, fmt, path.as_ref(), None).await + let stream = create_request(url).await?; + + let path = path.as_ref(); + debug!("Downloading and extracting to: '{}'", path.display()); + + match fmt.decompose() { + PkgFmtDecomposed::Tar(fmt) => extract_tar_based_stream(stream, path, fmt).await?, + PkgFmtDecomposed::Bin => extract_bin(stream, path).await?, + PkgFmtDecomposed::Zip => extract_zip(stream, path).await?, + } + + debug!("Download OK, extracted to: '{}'", path.display()); + + Ok(()) } /// Download a file from the provided URL and extract part of it to @@ -59,36 +92,20 @@ pub async fn download_and_extract>( /// /// * `filter` - If Some, then it will pass the path of the file to it /// and only extract ones which filter returns `true`. -/// Note that this is a best-effort and it only works when `fmt` -/// is not `PkgFmt::Bin` or `PkgFmt::Zip`. 
-pub async fn download_and_extract_with_filter< - Filter: FnMut(&Path) -> bool + Send + 'static, - P: AsRef, ->( +pub async fn download_tar_based_and_visit( url: Url, - fmt: PkgFmt, - path: P, - filter: Option, -) -> Result<(), BinstallError> { - debug!("Downloading from: '{url}'"); + fmt: TarBasedFmt, + visitor: V, +) -> Result { + let stream = create_request(url).await?; - let resp = reqwest::get(url.clone()) - .await - .and_then(|r| r.error_for_status()) - .map_err(|err| BinstallError::Http { - method: Method::GET, - url, - err, - })?; + debug!("Downloading and extracting then in-memory processing"); - let path = path.as_ref(); - debug!("Downloading to file: '{}'", path.display()); + let visitor = extract_tar_based_stream_and_visit(stream, fmt, visitor).await?; - extract_archive_stream(resp.bytes_stream(), path, fmt, filter).await?; + debug!("Download, extraction and in-memory procession OK"); - debug!("Download OK, written to file: '{}'", path.display()); - - Ok(()) + Ok(visitor) } /// Fetch install path from environment diff --git a/src/helpers/async_extracter.rs b/src/helpers/async_extracter.rs index 7f858a1b..5a7ef9c1 100644 --- a/src/helpers/async_extracter.rs +++ b/src/helpers/async_extracter.rs @@ -1,226 +1,129 @@ +//! # Advantages +//! +//! Using this mod has the following advantages over downloading +//! to file then extracting: +//! +//! - The code is pipelined instead of storing the downloaded file in memory +//! and extract it, except for `PkgFmt::Zip`, since `ZipArchiver::new` +//! requires `std::io::Seek`, so it fallbacks to writing the a file then +//! unzip it. +//! - Compressing/writing which takes a lot of CPU time will not block +//! the runtime anymore. +//! - For all `tar` based formats, it can extract only specified files and +//! process them in memory, without any disk I/O. 
+ +use std::fmt::Debug; use std::fs; -use std::io::{self, Seek, Write}; +use std::io::{copy, Read, Seek}; use std::path::Path; use bytes::Bytes; -use futures_util::stream::{Stream, StreamExt}; -use scopeguard::{guard, Always, ScopeGuard}; +use futures_util::stream::Stream; +use log::debug; +use scopeguard::{guard, ScopeGuard}; +use tar::Entries; use tempfile::tempfile; -use tokio::{ - sync::mpsc, - task::{spawn_blocking, JoinHandle}, -}; +use tokio::task::block_in_place; -use super::{extracter::*, readable_rx::*}; -use crate::{BinstallError, PkgFmt}; +use super::{extracter::*, stream_readable::StreamReadable}; +use crate::{BinstallError, TarBasedFmt}; -pub(crate) enum Content { - /// Data to write to file - Data(Bytes), - - /// Abort the writing and remove the file. - Abort, -} - -#[derive(Debug)] -struct AsyncExtracterInner { - /// Use AutoAbortJoinHandle so that the task - /// will be cancelled on failure. - handle: JoinHandle>, - tx: mpsc::Sender, -} - -impl AsyncExtracterInner { - /// * `filter` - If Some, then it will pass the path of the file to it - /// and only extract ones which filter returns `true`. - /// Note that this is a best-effort and it only works when `fmt` - /// is not `PkgFmt::Bin` or `PkgFmt::Zip`. - fn new bool + Send + 'static>( - path: &Path, - fmt: PkgFmt, - filter: Option, - ) -> Self { - let path = path.to_owned(); - let (tx, mut rx) = mpsc::channel::(100); - - let handle = spawn_blocking(move || { - fs::create_dir_all(path.parent().unwrap())?; - - match fmt { - PkgFmt::Bin => { - let mut file = fs::File::create(&path)?; - - // remove it unless the operation isn't aborted and no write - // fails. - let remove_guard = guard(&path, |path| { - fs::remove_file(path).ok(); - }); - - Self::read_into_file(&mut file, &mut rx)?; - - // Operation isn't aborted and all writes succeed, - // disarm the remove_guard. 
- ScopeGuard::into_inner(remove_guard); - } - PkgFmt::Zip => { - let mut file = tempfile()?; - - Self::read_into_file(&mut file, &mut rx)?; - - // rewind it so that we can pass it to unzip - file.rewind()?; - - unzip(file, &path)?; - } - _ => { - extract_compressed_from_readable(ReadableRx::new(&mut rx), fmt, &path, filter)? - } - } - - Ok(()) - }); - - Self { handle, tx } - } - - fn read_into_file( - file: &mut fs::File, - rx: &mut mpsc::Receiver, - ) -> Result<(), BinstallError> { - while let Some(content) = rx.blocking_recv() { - match content { - Content::Data(bytes) => file.write_all(&*bytes)?, - Content::Abort => { - return Err(io::Error::new(io::ErrorKind::Other, "Aborted").into()) - } - } - } - - file.flush()?; - - Ok(()) - } - - /// Upon error, this extracter shall not be reused. - /// Otherwise, `Self::done` would panic. - async fn feed(&mut self, bytes: Bytes) -> Result<(), BinstallError> { - if self.tx.send(Content::Data(bytes)).await.is_err() { - // task failed - Err(Self::wait(&mut self.handle).await.expect_err( - "Implementation bug: write task finished successfully before all writes are done", - )) - } else { - Ok(()) - } - } - - async fn done(mut self) -> Result<(), BinstallError> { - // Drop tx as soon as possible so that the task would wrap up what it - // was doing and flush out all the pending data. - drop(self.tx); - - Self::wait(&mut self.handle).await - } - - async fn wait(handle: &mut JoinHandle>) -> Result<(), BinstallError> { - match handle.await { - Ok(res) => res, - Err(join_err) => Err(io::Error::new(io::ErrorKind::Other, join_err).into()), - } - } - - fn abort(self) { - let tx = self.tx; - // If Self::write fail, then the task is already tear down, - // tx closed and no need to abort. - if !tx.is_closed() { - // Use send here because blocking_send would panic if used - // in async threads. 
- tokio::spawn(async move { - tx.send(Content::Abort).await.ok(); - }); - } - } -} - -/// AsyncExtracter will pass the `Bytes` you give to another thread via -/// a `mpsc` and decompress and unpack it if needed. -/// -/// After all write is done, you must call `AsyncExtracter::done`, -/// otherwise the extracted content will be removed on drop. -/// -/// # Advantages -/// -/// `download_and_extract` has the following advantages over downloading -/// plus extracting in on the same thread: -/// -/// - The code is pipelined instead of storing the downloaded file in memory -/// and extract it, except for `PkgFmt::Zip`, since `ZipArchiver::new` -/// requires `std::io::Seek`, so it fallbacks to writing the a file then -/// unzip it. -/// - The async part (downloading) and the extracting part runs in parallel -/// using `tokio::spawn_nonblocking`. -/// - Compressing/writing which takes a lot of CPU time will not block -/// the runtime anymore. -/// - For any PkgFmt except for `PkgFmt::Zip` and `PkgFmt::Bin` (basically -/// all `tar` based formats), it can extract only specified files. -/// This means that `super::drivers::fetch_crate_cratesio` no longer need to -/// extract the whole crate and write them to disk, it now only extract the -/// relevant part (`Cargo.toml`) out to disk and open it. -#[derive(Debug)] -struct AsyncExtracter(ScopeGuard); - -impl AsyncExtracter { - /// * `path` - If `fmt` is `PkgFmt::Bin`, then this is the filename - /// for the bin. - /// Otherwise, it is the directory where the extracted content will be put. - /// * `fmt` - The format of the archive to feed in. - /// * `filter` - If Some, then it will pass the path of the file to it - /// and only extract ones which filter returns `true`. - /// Note that this is a best-effort and it only works when `fmt` - /// is not `PkgFmt::Bin` or `PkgFmt::Zip`. 
- fn new bool + Send + 'static>( - path: &Path, - fmt: PkgFmt, - filter: Option, - ) -> Self { - let inner = AsyncExtracterInner::new(path, fmt, filter); - Self(guard(inner, AsyncExtracterInner::abort)) - } - - /// Upon error, this extracter shall not be reused. - /// Otherwise, `Self::done` would panic. - async fn feed(&mut self, bytes: Bytes) -> Result<(), BinstallError> { - self.0.feed(bytes).await - } - - async fn done(self) -> Result<(), BinstallError> { - ScopeGuard::into_inner(self.0).done().await - } -} - -/// * `output` - If `fmt` is `PkgFmt::Bin`, then this is the filename -/// for the bin. -/// Otherwise, it is the directory where the extracted content will be put. -/// * `fmt` - The format of the archive to feed in. -/// * `filter` - If Some, then it will pass the path of the file to it -/// and only extract ones which filter returns `true`. -/// Note that this is a best-effort and it only works when `fmt` -/// is not `PkgFmt::Bin` or `PkgFmt::Zip`. -pub async fn extract_archive_stream bool + Send + 'static, E>( - mut stream: impl Stream> + Unpin, - output: &Path, - fmt: PkgFmt, - filter: Option, -) -> Result<(), BinstallError> +pub async fn extract_bin(stream: S, path: &Path) -> Result<(), BinstallError> where + S: Stream> + Unpin + 'static, BinstallError: From, { - let mut extracter = AsyncExtracter::new(output, fmt, filter); + let mut reader = StreamReadable::new(stream).await; + block_in_place(move || { + fs::create_dir_all(path.parent().unwrap())?; - while let Some(res) = stream.next().await { - extracter.feed(res?).await?; - } + let mut file = fs::File::create(&path)?; - extracter.done().await + // remove it unless the operation isn't aborted and no write + // fails. + let remove_guard = guard(&path, |path| { + fs::remove_file(path).ok(); + }); + + copy(&mut reader, &mut file)?; + + // Operation isn't aborted and all writes succeed, + // disarm the remove_guard. 
+ ScopeGuard::into_inner(remove_guard); + + Ok(()) + }) +} + +pub async fn extract_zip(stream: S, path: &Path) -> Result<(), BinstallError> +where + S: Stream> + Unpin + 'static, + BinstallError: From, +{ + let mut reader = StreamReadable::new(stream).await; + block_in_place(move || { + fs::create_dir_all(path.parent().unwrap())?; + + let mut file = tempfile()?; + + copy(&mut reader, &mut file)?; + + // rewind it so that we can pass it to unzip + file.rewind()?; + + unzip(file, path) + }) +} + +pub async fn extract_tar_based_stream( + stream: S, + path: &Path, + fmt: TarBasedFmt, +) -> Result<(), BinstallError> +where + S: Stream> + Unpin + 'static, + BinstallError: From, +{ + let reader = StreamReadable::new(stream).await; + block_in_place(move || { + fs::create_dir_all(path.parent().unwrap())?; + + debug!("Extracting from {fmt} archive to {path:#?}"); + + create_tar_decoder(reader, fmt)?.unpack(path)?; + + Ok(()) + }) +} + +/// Visitor must iterate over all entries. +/// Entries can be in arbitrary order. 
+pub trait TarEntriesVisitor { + fn visit(&mut self, entries: Entries<'_, R>) -> Result<(), BinstallError>; +} + +impl TarEntriesVisitor for &mut V { + fn visit(&mut self, entries: Entries<'_, R>) -> Result<(), BinstallError> { + (*self).visit(entries) + } +} + +pub async fn extract_tar_based_stream_and_visit( + stream: S, + fmt: TarBasedFmt, + mut visitor: V, +) -> Result +where + S: Stream> + Unpin + 'static, + V: TarEntriesVisitor + Debug + Send + 'static, + BinstallError: From, +{ + let reader = StreamReadable::new(stream).await; + block_in_place(move || { + debug!("Extracting from {fmt} archive to process it in memory"); + + let mut tar = create_tar_decoder(reader, fmt)?; + visitor.visit(tar.entries()?)?; + Ok(visitor) + }) } diff --git a/src/helpers/extracter.rs b/src/helpers/extracter.rs index 42426693..13f018b6 100644 --- a/src/helpers/extracter.rs +++ b/src/helpers/extracter.rs @@ -1,5 +1,5 @@ -use std::fs::{self, File}; -use std::io::{BufRead, Read}; +use std::fs::File; +use std::io::{self, BufRead, Read}; use std::path::Path; use flate2::bufread::GzDecoder; @@ -9,99 +9,31 @@ use xz2::bufread::XzDecoder; use zip::read::ZipArchive; use zstd::stream::Decoder as ZstdDecoder; -use crate::{BinstallError, PkgFmt}; +use crate::{BinstallError, TarBasedFmt}; -/// * `filter` - If Some, then it will pass the path of the file to it -/// and only extract ones which filter returns `true`. -/// Note that this is a best-effort and it only works when `fmt` -/// is not `PkgFmt::Bin` or `PkgFmt::Zip`. -fn untar bool>( - dat: impl Read, - path: &Path, - filter: Option, -) -> Result<(), BinstallError> { - let mut tar = Archive::new(dat); - - if let Some(mut filter) = filter { - debug!("Untaring with filter"); - - for res in tar.entries()? 
{ - let mut entry = res?; - let entry_path = entry.path()?; - - if filter(&entry_path) { - debug!("Extracting {entry_path:#?}"); - - let dst = path.join(entry_path); - - fs::create_dir_all(dst.parent().unwrap())?; - - entry.unpack(dst)?; - } - } - } else { - debug!("Untaring entire tar"); - tar.unpack(path)?; - } - - debug!("Untaring completed"); - - Ok(()) -} - -/// Extract files from the specified source onto the specified path. -/// -/// * `fmt` - must not be `PkgFmt::Bin` or `PkgFmt::Zip`. -/// * `filter` - If Some, then it will pass the path of the file to it -/// and only extract ones which filter returns `true`. -/// Note that this is a best-effort and it only works when `fmt` -/// is not `PkgFmt::Bin` or `PkgFmt::Zip`. -pub(crate) fn extract_compressed_from_readable bool>( - dat: impl BufRead, - fmt: PkgFmt, - path: &Path, - filter: Option, -) -> Result<(), BinstallError> { - match fmt { - PkgFmt::Tar => { - // Extract to install dir - debug!("Extracting from tar archive to `{path:?}`"); - - untar(dat, path, filter)? - } - PkgFmt::Tgz => { - // Extract to install dir - debug!("Decompressing from tgz archive to `{path:?}`"); - - let tar = GzDecoder::new(dat); - untar(tar, path, filter)?; - } - PkgFmt::Txz => { - // Extract to install dir - debug!("Decompressing from txz archive to `{path:?}`"); - - let tar = XzDecoder::new(dat); - untar(tar, path, filter)?; - } - PkgFmt::Tzstd => { - // Extract to install dir - debug!("Decompressing from tzstd archive to `{path:?}`"); +pub(super) fn create_tar_decoder( + dat: impl BufRead + 'static, + fmt: TarBasedFmt, +) -> io::Result>> { + use TarBasedFmt::*; + let r: Box = match fmt { + Tar => Box::new(dat), + Tgz => Box::new(GzDecoder::new(dat)), + Txz => Box::new(XzDecoder::new(dat)), + Tzstd => { // The error can only come from raw::Decoder::with_dictionary // as of zstd 0.10.2 and 0.11.2, which is specified // as &[] by ZstdDecoder::new, thus ZstdDecoder::new // should not return any error. 
- let tar = ZstdDecoder::with_buffer(dat)?; - untar(tar, path, filter)?; + Box::new(ZstdDecoder::with_buffer(dat)?) } - PkgFmt::Zip => panic!("Unexpected PkgFmt::Zip!"), - PkgFmt::Bin => panic!("Unexpected PkgFmt::Bin!"), }; - Ok(()) + Ok(Archive::new(r)) } -pub(crate) fn unzip(dat: File, dst: &Path) -> Result<(), BinstallError> { +pub(super) fn unzip(dat: File, dst: &Path) -> Result<(), BinstallError> { debug!("Decompressing from zip archive to `{dst:?}`"); let mut zip = ZipArchive::new(dat)?; diff --git a/src/helpers/path_ext.rs b/src/helpers/path_ext.rs new file mode 100644 index 00000000..78f26687 --- /dev/null +++ b/src/helpers/path_ext.rs @@ -0,0 +1,58 @@ +//! Shamelessly adapted from: +//! https://github.com/rust-lang/cargo/blob/fede83ccf973457de319ba6fa0e36ead454d2e20/src/cargo/util/paths.rs#L61 + +use std::borrow::Cow; +use std::path::{Component, Path, PathBuf}; + +pub trait PathExt { + /// Similar to `os.path.normpath`: It does not perform + /// any fs operation. + fn normalize_path(&self) -> Cow<'_, Path>; +} + +fn is_normalized(path: &Path) -> bool { + for component in path.components() { + match component { + Component::CurDir | Component::ParentDir => { + return false; + } + _ => continue, + } + } + + true +} + +impl PathExt for Path { + fn normalize_path(&self) -> Cow<'_, Path> { + if is_normalized(self) { + return Cow::Borrowed(self); + } + + let mut components = self.components().peekable(); + let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek() { + let buf = PathBuf::from(c.as_os_str()); + components.next(); + buf + } else { + PathBuf::new() + }; + + for component in components { + match component { + Component::Prefix(..) 
=> unreachable!(), + Component::RootDir => { + ret.push(component.as_os_str()); + } + Component::CurDir => {} + Component::ParentDir => { + ret.pop(); + } + Component::Normal(c) => { + ret.push(c); + } + } + } + Cow::Owned(ret) + } +} diff --git a/src/helpers/readable_rx.rs b/src/helpers/stream_readable.rs similarity index 50% rename from src/helpers/readable_rx.rs rename to src/helpers/stream_readable.rs index 545bc176..17113591 100644 --- a/src/helpers/readable_rx.rs +++ b/src/helpers/stream_readable.rs @@ -2,26 +2,38 @@ use std::cmp::min; use std::io::{self, BufRead, Read}; use bytes::{Buf, Bytes}; -use tokio::sync::mpsc::Receiver; +use futures_util::stream::{Stream, StreamExt}; +use tokio::runtime::Handle; -use super::async_extracter::Content; +use super::BinstallError; +/// This wraps an AsyncIterator as a `Read`able. +/// It must be used in non-async context only, +/// meaning you have to use it with +/// `tokio::task::{block_in_place, spawn_blocking}` or +/// `std::thread::spawn`. 
#[derive(Debug)] -pub(crate) struct ReadableRx<'a> { - rx: &'a mut Receiver, +pub(super) struct StreamReadable { + stream: S, + handle: Handle, bytes: Bytes, } -impl<'a> ReadableRx<'a> { - pub(crate) fn new(rx: &'a mut Receiver) -> Self { +impl StreamReadable { + pub(super) async fn new(stream: S) -> Self { Self { - rx, + stream, + handle: Handle::current(), bytes: Bytes::new(), } } } -impl Read for ReadableRx<'_> { +impl Read for StreamReadable +where + S: Stream> + Unpin, + BinstallError: From, +{ fn read(&mut self, buf: &mut [u8]) -> io::Result { if buf.is_empty() { return Ok(0); @@ -42,15 +54,20 @@ impl Read for ReadableRx<'_> { Ok(n) } } - -impl BufRead for ReadableRx<'_> { +impl BufRead for StreamReadable +where + S: Stream> + Unpin, + BinstallError: From, +{ fn fill_buf(&mut self) -> io::Result<&[u8]> { let bytes = &mut self.bytes; + if !bytes.has_remaining() { - match self.rx.blocking_recv() { - Some(Content::Data(new_bytes)) => *bytes = new_bytes, - Some(Content::Abort) => { - return Err(io::Error::new(io::ErrorKind::Other, "Aborted")) + match self.handle.block_on(async { self.stream.next().await }) { + Some(Ok(new_bytes)) => *bytes = new_bytes, + Some(Err(e)) => { + let e: BinstallError = e.into(); + return Err(io::Error::new(io::ErrorKind::Other, e)); } None => (), } diff --git a/src/lib.rs b/src/lib.rs index 0ad83f42..5725995d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use strum_macros::{Display, EnumString, EnumVariantNames}; pub mod drivers; pub use drivers::*; @@ -18,6 +17,9 @@ pub mod fetchers; mod target; pub use target::*; +mod format; +pub use format::*; + /// Default package path template (may be overridden in package Cargo.toml) pub const DEFAULT_PKG_URL: &str = "{ repo }/releases/download/v{ version }/{ name }-{ target }-v{ version }.{ archive-format }"; @@ -25,33 +27,6 @@ pub const DEFAULT_PKG_URL: &str = /// Default binary name template (may be overridden in 
package Cargo.toml) pub const DEFAULT_BIN_DIR: &str = "{ name }-{ target }-v{ version }/{ bin }{ binary-ext }"; -/// Binary format enumeration -#[derive( - Debug, Copy, Clone, PartialEq, Serialize, Deserialize, Display, EnumString, EnumVariantNames, -)] -#[strum(serialize_all = "snake_case")] -#[serde(rename_all = "snake_case")] -pub enum PkgFmt { - /// Download format is TAR (uncompressed) - Tar, - /// Download format is TGZ (TAR + GZip) - Tgz, - /// Download format is TAR + XZ - Txz, - /// Download format is TAR + Zstd - Tzstd, - /// Download format is Zip - Zip, - /// Download format is raw / binary - Bin, -} - -impl Default for PkgFmt { - fn default() -> Self { - Self::Tgz - } -} - /// `binstall` metadata container /// /// Required to nest metadata under `package.metadata.binstall` diff --git a/src/main.rs b/src/main.rs index 5b89353c..0d8748cc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -210,13 +210,11 @@ async fn entry() -> Result<()> { // Fetch crate via crates.io, git, or use a local manifest path // TODO: work out which of these to do based on `opts.name` // TODO: support git-based fetches (whole repo name rather than just crate name) - let manifest_path = match opts.manifest_path.clone() { - Some(p) => p, - None => fetch_crate_cratesio(&opts.name, &opts.version, temp_dir.path()).await?, + let manifest = match opts.manifest_path.clone() { + Some(manifest_path) => load_manifest_path(manifest_path.join("Cargo.toml"))?, + None => fetch_crate_cratesio(&opts.name, &opts.version).await?, }; - debug!("Reading manifest: {}", manifest_path.display()); - let manifest = load_manifest_path(manifest_path.join("Cargo.toml"))?; let package = manifest.package.unwrap(); let is_plain_version = semver::Version::from_str(&opts.version).is_ok();