Fix binstalk_downloader::Download for data-verifier (#1313)

When extraction fails, drain the rest of the download stream so that the
`data_verifier` consumes the entire file and produces the correct checksum.

Signed-off-by: Jiahao XU <Jiahao_XU@outlook.com>
This commit is contained in:
Jiahao XU 2023-08-24 10:04:57 +10:00 committed by GitHub
parent b9adaa006f
commit cb9cb0e937
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2,9 +2,9 @@ use std::{fmt, io, marker::PhantomData, path::Path};
use binstalk_types::cargo_toml_binstall::PkgFmtDecomposed;
use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use futures_util::{stream::FusedStream, Stream, StreamExt};
use thiserror::Error as ThisError;
use tracing::{debug, instrument};
use tracing::{debug, error, instrument};
pub use binstalk_types::cargo_toml_binstall::{PkgFmt, TarBasedFmt};
@ -142,19 +142,42 @@ impl<'a> Download<'a> {
async fn get_stream(
self,
) -> Result<
impl Stream<Item = Result<Bytes, DownloadError>> + Send + Sync + Unpin + 'a,
impl Stream<Item = Result<Bytes, DownloadError>> + FusedStream + Send + Sync + Unpin + 'a,
DownloadError,
> {
let mut data_verifier = self.data_verifier;
Ok(self.client.get_stream(self.url).await?.map(move |res| {
let bytes = res?;
Ok(self
.client
.get_stream(self.url)
.await?
.map(move |res| {
let bytes = res?;
if let Some(data_verifier) = &mut data_verifier {
data_verifier.update(&bytes);
}
if let Some(data_verifier) = &mut data_verifier {
data_verifier.update(&bytes);
}
Ok(bytes)
}))
Ok(bytes)
})
// Call `fuse` at the end to make sure `data_verifier` is only
// called when the stream still has elements left.
.fuse())
}
}
/// Drain whatever is left in `stream`, discarding every chunk it yields.
///
/// `stream` is borrowed mutably instead of being taken by value to avoid
/// exploding the size of the generated future.
///
/// Only a `FusedStream` is accepted, since `stream` may already have been
/// consumed by the time this function is called.
///
/// Draining stops at the first error; that error is logged and is not
/// returned to the caller.
async fn consume_stream<S>(stream: &mut S)
where
    S: Stream<Item = Result<Bytes, DownloadError>> + FusedStream + Unpin,
{
    loop {
        match stream.next().await {
            // A successfully received chunk is simply dropped; keep polling.
            Some(Ok(_)) => {}
            // Give up on the first failure, there is nothing left to recover.
            Some(Err(err)) => {
                error!(?err, "failed to consume stream");
                return;
            }
            // The stream is exhausted.
            None => return,
        }
    }
}
@ -172,15 +195,23 @@ impl Download<'_> {
fmt: TarBasedFmt,
visitor: &mut dyn TarEntriesVisitor,
) -> Result<(), DownloadError> {
let stream = self.get_stream().await?;
let has_data_verifier = self.data_verifier.is_some();
let mut stream = self.get_stream().await?;
debug!("Downloading and extracting then in-memory processing");
extract_tar_based_stream_and_visit(stream, fmt, visitor).await?;
debug!("Download, extraction and in-memory procession OK");
Ok(())
match extract_tar_based_stream_and_visit(&mut stream, fmt, visitor).await {
Ok(()) => {
debug!("Download, extraction and in-memory procession OK");
Ok(())
}
Err(err) => {
if has_data_verifier {
consume_stream(&mut stream).await;
}
Err(err)
}
}
}
/// Download a file from the provided URL and extract it to the provided path.
@ -197,19 +228,31 @@ impl Download<'_> {
fmt: PkgFmt,
path: &Path,
) -> Result<ExtractedFiles, DownloadError> {
let stream = this.get_stream().await?;
let has_data_verifier = this.data_verifier.is_some();
let mut stream = this.get_stream().await?;
debug!("Downloading and extracting to: '{}'", path.display());
let extracted_files = match fmt.decompose() {
PkgFmtDecomposed::Tar(fmt) => extract_tar_based_stream(stream, path, fmt).await?,
PkgFmtDecomposed::Bin => extract_bin(stream, path).await?,
PkgFmtDecomposed::Zip => extract_zip(stream, path).await?,
let res = match fmt.decompose() {
PkgFmtDecomposed::Tar(fmt) => {
extract_tar_based_stream(&mut stream, path, fmt).await
}
PkgFmtDecomposed::Bin => extract_bin(&mut stream, path).await,
PkgFmtDecomposed::Zip => extract_zip(&mut stream, path).await,
};
debug!("Download OK, extracted to: '{}'", path.display());
Ok(extracted_files)
match res {
Ok(extracted_files) => {
debug!("Download OK, extracted to: '{}'", path.display());
Ok(extracted_files)
}
Err(err) => {
if has_data_verifier {
consume_stream(&mut stream).await;
}
Err(err)
}
}
}
inner(self, fmt, path.as_ref()).await