Mirror of https://github.com/cargo-bins/cargo-binstall.git, synced 2025-04-24 14:28:42 +00:00
Refactor: Extract new crate binstalk-{signal, downloader} (#518)
* Refactor: Extract new crate binstalk-downloader
* Re-export `PkgFmt` from `binstalk_manifests`
* Update release-pr.yml
* Update dependabot

Signed-off-by: Jiahao XU <Jiahao_XU@outlook.com>
parent 3841762a5b
commit 89fa5b1769
21 changed files with 456 additions and 260 deletions
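Taken together, the files below add a small download-and-extract API. A minimal usage sketch, assuming a multi-threaded tokio runtime (the crate calls `block_in_place` internally) and a hypothetical URL and destination path:

use std::{num::NonZeroU64, time::Duration};

use binstalk_downloader::{
    download::{Download, PkgFmt},
    remote::{Client, Url},
};

#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Rate limit: at most 10 requests per second (illustrative values).
    let client = Client::new(None, Duration::from_secs(1), NonZeroU64::new(10).unwrap())?;

    let url = Url::parse("https://example.com/releases/app-v1.0.0.tgz")?;

    // `None`: no cancellation future, run the download to completion.
    Download::new(client, url)
        .and_extract(PkgFmt::Tgz, "/tmp/app", None)
        .await?;

    Ok(())
}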
crates/binstalk-downloader/Cargo.toml (new file, 70 lines)

@@ -0,0 +1,70 @@
[package]
name = "binstalk-downloader"
description = "The binstall toolkit for downloading and extracting files"
repository = "https://github.com/cargo-bins/cargo-binstall"
documentation = "https://docs.rs/binstalk-downloader"
version = "0.1.0"
rust-version = "1.61.0"
authors = ["ryan <ryan@kurte.nz>"]
edition = "2021"
license = "GPL-3.0"

[dependencies]
binstalk-manifests = { version = "0.1.0", path = "../binstalk-manifests" }
bytes = "1.2.1"
bzip2 = "0.4.3"
digest = "0.10.5"
flate2 = { version = "1.0.24", default-features = false }
futures-util = { version = "0.3.25", default-features = false, features = ["std"] }
generic-array = "0.14.6"
httpdate = "1.0.2"
log = { version = "0.4.17", features = ["std"] }
reqwest = { version = "0.11.12", features = ["stream", "gzip", "brotli", "deflate"], default-features = false }
scopeguard = "1.1.0"
# Use a fork here since we need PAX support, but the upstream
# does not have the PR merged yet.
#
#tar = "0.4.38"
tar = { package = "binstall-tar", version = "0.4.39" }
tempfile = "3.3.0"
thiserror = "1.0.37"
tokio = { version = "1.21.2", features = ["macros", "rt-multi-thread", "sync", "time"], default-features = false }
tower = { version = "0.4.13", features = ["limit", "util"] }
trust-dns-resolver = { version = "0.21.2", optional = true, default-features = false, features = ["dnssec-ring"] }
url = "2.3.1"

xz2 = "0.1.7"

# Disable all features of zip except those for compression algorithms.
# Disabled features include:
# - aes-crypto: enables decryption of files which were encrypted with AES;
#   of absolutely no use for this crate.
# - time: enables features using the [time](https://github.com/time-rs/time) crate,
#   which is not used by this crate.
zip = { version = "0.6.3", default-features = false, features = ["deflate", "bzip2", "zstd"] }

# zstd is also a dependency of zip.
# Since zip 0.6.3 depends on zstd 0.11, we also have to use 0.11 here,
# otherwise there will be a link conflict.
zstd = { version = "0.11.2", default-features = false }

[features]
default = ["static", "rustls"]

static = ["bzip2/static", "xz2/static"]
pkg-config = ["zstd/pkg-config"]

zlib-ng = ["flate2/zlib-ng"]

rustls = [
    "reqwest/rustls-tls",

    # Enable the following features only if trust-dns-resolver is enabled.
    "trust-dns-resolver?/dns-over-rustls",
    # trust-dns-resolver currently supports https with rustls
    "trust-dns-resolver?/dns-over-https-rustls",
]
native-tls = ["reqwest/native-tls", "trust-dns-resolver?/dns-over-native-tls"]

# Enable trust-dns-resolver so that features on it will also be enabled.
trust-dns = ["trust-dns-resolver", "reqwest/trust-dns"]
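The feature layout above lets downstream crates swap the TLS stack. For instance, a consumer preferring the platform's native TLS might declare the dependency like this (illustrative, not part of the commit):

[dependencies.binstalk-downloader]
version = "0.1.0"
default-features = false
features = ["static", "native-tls", "trust-dns"]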
crates/binstalk-downloader/src/download.rs (new file, 170 lines)

@@ -0,0 +1,170 @@
use std::{fmt::Debug, future::Future, io, marker::PhantomData, path::Path, pin::Pin};

use binstalk_manifests::cargo_toml_binstall::{PkgFmtDecomposed, TarBasedFmt};
use digest::{Digest, FixedOutput, HashMarker, Output, OutputSizeUser, Update};
use log::debug;
use thiserror::Error as ThisError;

pub use binstalk_manifests::cargo_toml_binstall::PkgFmt;
pub use tar::Entries;
pub use zip::result::ZipError;

use crate::remote::{Client, Error as RemoteError, Url};

mod async_extracter;
pub use async_extracter::TarEntriesVisitor;
use async_extracter::*;

mod extracter;
mod stream_readable;

pub type CancellationFuture = Option<Pin<Box<dyn Future<Output = Result<(), io::Error>> + Send>>>;

#[derive(Debug, ThisError)]
pub enum DownloadError {
    #[error(transparent)]
    Unzip(#[from] ZipError),

    #[error(transparent)]
    Remote(#[from] RemoteError),

    /// A generic I/O error.
    ///
    /// - Code: `binstall::io`
    /// - Exit: 74
    #[error(transparent)]
    Io(io::Error),

    #[error("installation cancelled by user")]
    UserAbort,
}

impl From<io::Error> for DownloadError {
    fn from(err: io::Error) -> Self {
        if err.get_ref().is_some() {
            let kind = err.kind();

            let inner = err
                .into_inner()
                .expect("err.get_ref() returns Some, so err.into_inner() should also return Some");

            inner
                .downcast()
                .map(|b| *b)
                .unwrap_or_else(|err| DownloadError::Io(io::Error::new(kind, err)))
        } else {
            DownloadError::Io(err)
        }
    }
}

impl From<DownloadError> for io::Error {
    fn from(e: DownloadError) -> io::Error {
        match e {
            DownloadError::Io(io_error) => io_error,
            e => io::Error::new(io::ErrorKind::Other, e),
        }
    }
}

#[derive(Debug)]
pub struct Download<D: Digest = NoDigest> {
    client: Client,
    url: Url,
    _digest: PhantomData<D>,
    _checksum: Vec<u8>,
}

impl Download {
    pub fn new(client: Client, url: Url) -> Self {
        Self {
            client,
            url,
            _digest: PhantomData::default(),
            _checksum: Vec::new(),
        }
    }

    /// Download a file from the provided URL and process it in memory.
    ///
    /// This does not support verifying a checksum due to the partial extraction
    /// and will ignore one if specified.
    ///
    /// `cancellation_future` can be used to cancel the extraction and return
    /// a [`DownloadError::UserAbort`] error.
    pub async fn and_visit_tar<V: TarEntriesVisitor + Debug + Send + 'static>(
        self,
        fmt: TarBasedFmt,
        visitor: V,
        cancellation_future: CancellationFuture,
    ) -> Result<V::Target, DownloadError> {
        let stream = self.client.get_stream(self.url).await?;

        debug!("Downloading and extracting, then processing in memory");

        let ret =
            extract_tar_based_stream_and_visit(stream, fmt, visitor, cancellation_future).await?;

        debug!("Download, extraction and in-memory processing OK");

        Ok(ret)
    }

    /// Download a file from the provided URL and extract it to the provided path.
    ///
    /// `cancellation_future` can be used to cancel the extraction and return
    /// a [`DownloadError::UserAbort`] error.
    pub async fn and_extract(
        self,
        fmt: PkgFmt,
        path: impl AsRef<Path>,
        cancellation_future: CancellationFuture,
    ) -> Result<(), DownloadError> {
        let stream = self.client.get_stream(self.url).await?;

        let path = path.as_ref();
        debug!("Downloading and extracting to: '{}'", path.display());

        match fmt.decompose() {
            PkgFmtDecomposed::Tar(fmt) => {
                extract_tar_based_stream(stream, path, fmt, cancellation_future).await?
            }
            PkgFmtDecomposed::Bin => extract_bin(stream, path, cancellation_future).await?,
            PkgFmtDecomposed::Zip => extract_zip(stream, path, cancellation_future).await?,
        }

        debug!("Download OK, extracted to: '{}'", path.display());

        Ok(())
    }
}

impl<D: Digest> Download<D> {
    pub fn new_with_checksum(client: Client, url: Url, checksum: Vec<u8>) -> Self {
        Self {
            client,
            url,
            _digest: PhantomData::default(),
            _checksum: checksum,
        }
    }

    // TODO: implement checking the sum, may involve bringing (parts of) and_extract() back in here
}

#[derive(Clone, Copy, Debug, Default)]
pub struct NoDigest;

impl FixedOutput for NoDigest {
    fn finalize_into(self, _out: &mut Output<Self>) {}
}

impl OutputSizeUser for NoDigest {
    type OutputSize = generic_array::typenum::U0;
}

impl Update for NoDigest {
    fn update(&mut self, _data: &[u8]) {}
}

impl HashMarker for NoDigest {}
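The `Digest` type parameter and stored checksum are plumbing for future verification (see the TODO above); as of this commit only construction is wired up. A hedged sketch of how a caller might construct a checksummed download, assuming a `sha2` dependency that is not part of this crate:

use binstalk_downloader::{
    download::Download,
    remote::{Client, Url},
};
use sha2::Sha256; // assumed extra dependency, for illustration only

fn checked_download(client: Client, url: Url, checksum: Vec<u8>) -> Download<Sha256> {
    // As of this commit the checksum is only stored, not yet verified.
    Download::<Sha256>::new_with_checksum(client, url, checksum)
}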
crates/binstalk-downloader/src/download/async_extracter.rs (new file, 125 lines)

@@ -0,0 +1,125 @@
use std::{
    fmt::Debug,
    fs,
    io::{Read, Seek},
    path::Path,
};

use bytes::Bytes;
use futures_util::stream::Stream;
use log::debug;
use scopeguard::{guard, ScopeGuard};
use tar::Entries;
use tempfile::tempfile;
use tokio::task::block_in_place;

use super::{
    extracter::*, stream_readable::StreamReadable, CancellationFuture, DownloadError, TarBasedFmt,
};

pub async fn extract_bin<S, E>(
    stream: S,
    path: &Path,
    cancellation_future: CancellationFuture,
) -> Result<(), DownloadError>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
    DownloadError: From<E>,
{
    let mut reader = StreamReadable::new(stream, cancellation_future).await;
    block_in_place(move || {
        fs::create_dir_all(path.parent().unwrap())?;

        let mut file = fs::File::create(path)?;

        // Remove the file if the operation is aborted or any write fails.
        let remove_guard = guard(&path, |path| {
            fs::remove_file(path).ok();
        });

        reader.copy(&mut file)?;

        // The operation wasn't aborted and all writes succeeded,
        // so disarm the remove_guard.
        ScopeGuard::into_inner(remove_guard);

        Ok(())
    })
}

pub async fn extract_zip<S, E>(
    stream: S,
    path: &Path,
    cancellation_future: CancellationFuture,
) -> Result<(), DownloadError>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
    DownloadError: From<E>,
{
    let mut reader = StreamReadable::new(stream, cancellation_future).await;
    block_in_place(move || {
        fs::create_dir_all(path.parent().unwrap())?;

        let mut file = tempfile()?;

        reader.copy(&mut file)?;

        // Rewind it so that we can pass it to unzip.
        file.rewind()?;

        unzip(file, path)
    })
}

pub async fn extract_tar_based_stream<S, E>(
    stream: S,
    path: &Path,
    fmt: TarBasedFmt,
    cancellation_future: CancellationFuture,
) -> Result<(), DownloadError>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
    DownloadError: From<E>,
{
    let reader = StreamReadable::new(stream, cancellation_future).await;
    block_in_place(move || {
        fs::create_dir_all(path.parent().unwrap())?;

        debug!("Extracting from {fmt} archive to {path:#?}");

        create_tar_decoder(reader, fmt)?.unpack(path)?;

        Ok(())
    })
}

/// Visitor must iterate over all entries.
/// Entries can be in arbitrary order.
pub trait TarEntriesVisitor {
    type Target;

    fn visit<R: Read>(&mut self, entries: Entries<'_, R>) -> Result<(), DownloadError>;
    fn finish(self) -> Result<Self::Target, DownloadError>;
}

pub async fn extract_tar_based_stream_and_visit<S, V, E>(
    stream: S,
    fmt: TarBasedFmt,
    mut visitor: V,
    cancellation_future: CancellationFuture,
) -> Result<V::Target, DownloadError>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
    V: TarEntriesVisitor + Debug + Send + 'static,
    DownloadError: From<E>,
{
    let reader = StreamReadable::new(stream, cancellation_future).await;
    block_in_place(move || {
        debug!("Extracting from {fmt} archive to process it in memory");

        let mut tar = create_tar_decoder(reader, fmt)?;
        visitor.visit(tar.entries()?)?;
        visitor.finish()
    })
}
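A minimal sketch of a `TarEntriesVisitor` implementation that collects entry paths, meant to be driven through `Download::and_visit_tar` (the `PathCollector` name is illustrative, not from the commit):

use std::io::Read;

use binstalk_downloader::download::{DownloadError, Entries, TarEntriesVisitor};

#[derive(Debug, Default)]
struct PathCollector(Vec<String>);

impl TarEntriesVisitor for PathCollector {
    type Target = Vec<String>;

    fn visit<R: Read>(&mut self, entries: Entries<'_, R>) -> Result<(), DownloadError> {
        for entry in entries {
            // Both errors convert to DownloadError via From<io::Error>.
            let entry = entry?;
            self.0.push(entry.path()?.display().to_string());
        }
        Ok(())
    }

    fn finish(self) -> Result<Self::Target, DownloadError> {
        Ok(self.0)
    }
}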
crates/binstalk-downloader/src/download/extracter.rs (new file, 46 lines)

@@ -0,0 +1,46 @@
use std::{
    fs::File,
    io::{self, BufRead, Read},
    path::Path,
};

use bzip2::bufread::BzDecoder;
use flate2::bufread::GzDecoder;
use log::debug;
use tar::Archive;
use xz2::bufread::XzDecoder;
use zip::read::ZipArchive;
use zstd::stream::Decoder as ZstdDecoder;

use super::{DownloadError, TarBasedFmt};

pub fn create_tar_decoder(
    dat: impl BufRead + 'static,
    fmt: TarBasedFmt,
) -> io::Result<Archive<Box<dyn Read>>> {
    use TarBasedFmt::*;

    let r: Box<dyn Read> = match fmt {
        Tar => Box::new(dat),
        Tbz2 => Box::new(BzDecoder::new(dat)),
        Tgz => Box::new(GzDecoder::new(dat)),
        Txz => Box::new(XzDecoder::new(dat)),
        Tzstd => {
            // The error can only come from raw::Decoder::with_dictionary as of zstd 0.10.2 and
            // 0.11.2, which is specified as `&[]` by `ZstdDecoder::new`, thus `ZstdDecoder::new`
            // should not return any error.
            Box::new(ZstdDecoder::with_buffer(dat)?)
        }
    };

    Ok(Archive::new(r))
}

pub fn unzip(dat: File, dst: &Path) -> Result<(), DownloadError> {
    debug!("Decompressing from zip archive to `{dst:?}`");

    let mut zip = ZipArchive::new(dat)?;
    zip.extract(dst)?;

    Ok(())
}
crates/binstalk-downloader/src/download/stream_readable.rs (new file, 146 lines)

@@ -0,0 +1,146 @@
use std::{
    cmp::min,
    io::{self, BufRead, Read, Write},
};

use bytes::{Buf, Bytes};
use futures_util::stream::{Stream, StreamExt};
use tokio::runtime::Handle;

use super::{CancellationFuture, DownloadError};

/// This wraps an async `Stream` as a `Read`able.
/// It must be used in a non-async context only,
/// meaning you have to use it with
/// `tokio::task::{block_in_place, spawn_blocking}` or
/// `std::thread::spawn`.
pub struct StreamReadable<S> {
    stream: S,
    handle: Handle,
    bytes: Bytes,
    cancellation_future: CancellationFuture,
}

impl<S> StreamReadable<S> {
    pub(super) async fn new(stream: S, cancellation_future: CancellationFuture) -> Self {
        Self {
            stream,
            handle: Handle::current(),
            bytes: Bytes::new(),
            cancellation_future,
        }
    }
}

impl<S, E> StreamReadable<S>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    DownloadError: From<E>,
{
    /// Copies from `self` to `writer`.
    ///
    /// Same as `io::copy` but does not allocate any internal buffer
    /// since `self` is buffered.
    pub(super) fn copy<W>(&mut self, mut writer: W) -> io::Result<()>
    where
        W: Write,
    {
        self.copy_inner(&mut writer)
    }

    fn copy_inner(&mut self, writer: &mut dyn Write) -> io::Result<()> {
        loop {
            let buf = self.fill_buf()?;
            if buf.is_empty() {
                // Eof
                break Ok(());
            }

            writer.write_all(buf)?;

            let n = buf.len();
            self.consume(n);
        }
    }
}

impl<S, E> Read for StreamReadable<S>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    DownloadError: From<E>,
{
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        if self.fill_buf()?.is_empty() {
            return Ok(0);
        }

        let bytes = &mut self.bytes;

        // copy_to_slice requires the bytes to have enough remaining bytes
        // to fill buf.
        let n = min(buf.len(), bytes.remaining());

        bytes.copy_to_slice(&mut buf[..n]);

        Ok(n)
    }
}

/// If `Ok(Some(bytes))` is returned, then `bytes.is_empty() == false`.
async fn next_stream<S, E>(stream: &mut S) -> io::Result<Option<Bytes>>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    DownloadError: From<E>,
{
    loop {
        let option = stream
            .next()
            .await
            .transpose()
            .map_err(DownloadError::from)?;

        match option {
            Some(bytes) if bytes.is_empty() => continue,
            option => break Ok(option),
        }
    }
}

impl<S, E> BufRead for StreamReadable<S>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    DownloadError: From<E>,
{
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        let bytes = &mut self.bytes;

        if !bytes.has_remaining() {
            let option = self.handle.block_on(async {
                if let Some(cancellation_future) = self.cancellation_future.as_mut() {
                    tokio::select! {
                        res = next_stream(&mut self.stream) => res,
                        res = cancellation_future => {
                            Err(res.err().unwrap_or_else(|| io::Error::from(DownloadError::UserAbort)))
                        },
                    }
                } else {
                    next_stream(&mut self.stream).await
                }
            })?;

            if let Some(new_bytes) = option {
                // new_bytes is guaranteed to be non-empty.
                *bytes = new_bytes;
            }
        }
        Ok(&*bytes)
    }

    fn consume(&mut self, amt: usize) {
        self.bytes.advance(amt);
    }
}
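Since `fill_buf` races the stream against the cancellation future, callers can cancel a download by supplying any future matching `CancellationFuture`. A hedged sketch of building one from a timeout (the helper name is illustrative, not from the commit):

use std::{io, time::Duration};

use binstalk_downloader::download::CancellationFuture;

// Hypothetical helper: abort the download/extraction after `d` elapses.
fn timeout_cancellation(d: Duration) -> CancellationFuture {
    Some(Box::pin(async move {
        tokio::time::sleep(d).await;
        // Resolving with Ok(()) makes fill_buf report DownloadError::UserAbort.
        io::Result::Ok(())
    }))
}

Passing `timeout_cancellation(Duration::from_secs(60))` as the `cancellation_future` argument of `and_extract` would then abort extraction after a minute.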
crates/binstalk-downloader/src/lib.rs (new file, 2 lines)

@@ -0,0 +1,2 @@
pub mod download;
pub mod remote;
crates/binstalk-downloader/src/remote.rs (new file, 201 lines)

@@ -0,0 +1,201 @@
use std::{
    env,
    num::NonZeroU64,
    sync::Arc,
    time::{Duration, SystemTime},
};

use bytes::Bytes;
use futures_util::stream::{Stream, StreamExt};
use httpdate::parse_http_date;
use log::{debug, info};
use reqwest::{
    header::{HeaderMap, RETRY_AFTER},
    Request, Response, StatusCode,
};
use thiserror::Error as ThisError;
use tokio::{sync::Mutex, time::sleep};
use tower::{limit::rate::RateLimit, Service, ServiceBuilder, ServiceExt};

pub use reqwest::{tls, Error as ReqwestError, Method};
pub use url::Url;

const MAX_RETRY_DURATION: Duration = Duration::from_secs(120);
const MAX_RETRY_COUNT: u8 = 3;

#[derive(Debug, ThisError)]
pub enum Error {
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),

    #[error(transparent)]
    Http(HttpError),
}

#[derive(Debug, ThisError)]
#[error("could not {method} {url}: {err}")]
pub struct HttpError {
    method: reqwest::Method,
    url: url::Url,
    #[source]
    err: reqwest::Error,
}

#[derive(Clone, Debug)]
pub struct Client {
    client: reqwest::Client,
    rate_limit: Arc<Mutex<RateLimit<reqwest::Client>>>,
}

impl Client {
    /// * `per` - must not be 0.
    /// * `num_request` - maximum number of requests to be processed for
    ///   each `per` duration.
    pub fn new(
        min_tls: Option<tls::Version>,
        per: Duration,
        num_request: NonZeroU64,
    ) -> Result<Self, Error> {
        const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

        let mut builder = reqwest::ClientBuilder::new()
            .user_agent(USER_AGENT)
            .https_only(true)
            .min_tls_version(tls::Version::TLS_1_2)
            .tcp_nodelay(false);

        if let Some(ver) = min_tls {
            builder = builder.min_tls_version(ver);
        }

        let client = builder.build()?;

        Ok(Self {
            client: client.clone(),
            rate_limit: Arc::new(Mutex::new(
                ServiceBuilder::new()
                    .rate_limit(num_request.get(), per)
                    .service(client),
            )),
        })
    }

    /// Return the inner reqwest client.
    pub fn get_inner(&self) -> &reqwest::Client {
        &self.client
    }

    async fn send_request_inner(
        &self,
        method: &Method,
        url: &Url,
    ) -> Result<Response, ReqwestError> {
        let mut count = 0;

        loop {
            let request = Request::new(method.clone(), url.clone());

            // Reduce the critical section:
            // - Construct the request before locking.
            // - Once the rate_limit is ready, call it to obtain
            //   the future, then release the lock before
            //   polling the future, which performs network I/O that could
            //   take a really long time.
            let future = self.rate_limit.lock().await.ready().await?.call(request);

            let response = future.await?;

            let status = response.status();

            match (status, parse_header_retry_after(response.headers())) {
                (
                    // 503 429
                    StatusCode::SERVICE_UNAVAILABLE | StatusCode::TOO_MANY_REQUESTS,
                    Some(duration),
                ) if duration <= MAX_RETRY_DURATION && count < MAX_RETRY_COUNT => {
                    info!("Received status code {status}, will wait for {duration:#?} and retry");
                    sleep(duration).await
                }
                _ => break Ok(response),
            }

            count += 1;
        }
    }

    async fn send_request(
        &self,
        method: Method,
        url: Url,
        error_for_status: bool,
    ) -> Result<Response, Error> {
        self.send_request_inner(&method, &url)
            .await
            .and_then(|response| {
                if error_for_status {
                    response.error_for_status()
                } else {
                    Ok(response)
                }
            })
            .map_err(|err| Error::Http(HttpError { method, url, err }))
    }

    /// Check if the remote exists using `method`.
    pub async fn remote_exists(&self, url: Url, method: Method) -> Result<bool, Error> {
        Ok(self
            .send_request(method, url, false)
            .await?
            .status()
            .is_success())
    }

    /// Attempt to get the final redirected url.
    pub async fn get_redirected_final_url(&self, url: Url) -> Result<Url, Error> {
        Ok(self
            .send_request(Method::HEAD, url, true)
            .await?
            .url()
            .clone())
    }

    /// Create a `GET` request to `url` and return a stream of the response data.
    /// On a status code other than 200, it will return an error.
    pub async fn get_stream(
        &self,
        url: Url,
    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
        debug!("Downloading from: '{url}'");

        self.send_request(Method::GET, url, true)
            .await
            .map(|response| response.bytes_stream().map(|res| res.map_err(Error::from)))
    }
}

fn parse_header_retry_after(headers: &HeaderMap) -> Option<Duration> {
    let header = headers
        .get_all(RETRY_AFTER)
        .into_iter()
        .last()?
        .to_str()
        .ok()?;

    match header.parse::<u64>() {
        Ok(dur) => Some(Duration::from_secs(dur)),
        Err(_) => {
            let system_time = parse_http_date(header).ok()?;

            let retry_after_unix_timestamp =
                system_time.duration_since(SystemTime::UNIX_EPOCH).ok()?;

            let curr_time_unix_timestamp = SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .expect("SystemTime before UNIX EPOCH!");

            // retry_after_unix_timestamp - curr_time_unix_timestamp
            // If it underflows, return Duration::ZERO.
            Some(retry_after_unix_timestamp.saturating_sub(curr_time_unix_timestamp))
        }
    }
}
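A minimal sketch of driving the rate-limited client directly, assuming illustrative limits and a hypothetical URL; retries on 429/503 responses carrying a parseable Retry-After header happen transparently inside `send_request_inner`:

use std::{num::NonZeroU64, time::Duration};

use binstalk_downloader::remote::{Client, Method, Url};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // At most 5 requests per second (illustrative values).
    let client = Client::new(None, Duration::from_secs(1), NonZeroU64::new(5).unwrap())?;

    let url = Url::parse("https://example.com/releases/app-v1.0.0.tgz")?;
    if client.remote_exists(url, Method::HEAD).await? {
        println!("release asset exists");
    }

    Ok(())
}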