Disable tcp_nodelay for reqwest::Client and add rate limiting for https requests (#458)

This commit is contained in:
Jiahao XU 2022-10-07 15:51:34 +11:00 committed by GitHub
parent 8398ec2d4b
commit 76bc030f90
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 243 additions and 99 deletions

View file

@ -1,69 +1,109 @@
use std::env;
use std::{env, num::NonZeroU64, sync::Arc, time::Duration};
use bytes::Bytes;
use futures_util::stream::Stream;
use log::debug;
use reqwest::{tls, Client, ClientBuilder, Method, Response};
use url::Url;
use reqwest::{Request, Response};
use tokio::sync::Mutex;
use tower::{limit::rate::RateLimit, Service, ServiceBuilder, ServiceExt};
use crate::errors::BinstallError;
pub fn create_reqwest_client(min_tls: Option<tls::Version>) -> Result<Client, BinstallError> {
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
pub use reqwest::{tls, Method};
pub use url::Url;
let mut builder = ClientBuilder::new()
.user_agent(USER_AGENT)
.https_only(true)
.min_tls_version(tls::Version::TLS_1_2);
#[derive(Clone, Debug)]
pub struct Client {
client: reqwest::Client,
rate_limit: Arc<Mutex<RateLimit<reqwest::Client>>>,
}
if let Some(ver) = min_tls {
builder = builder.min_tls_version(ver);
impl Client {
/// * `per` - must not be 0.
pub fn new(
min_tls: Option<tls::Version>,
per: Duration,
num_request: NonZeroU64,
) -> Result<Self, BinstallError> {
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
let mut builder = reqwest::ClientBuilder::new()
.user_agent(USER_AGENT)
.https_only(true)
.min_tls_version(tls::Version::TLS_1_2)
.tcp_nodelay(false);
if let Some(ver) = min_tls {
builder = builder.min_tls_version(ver);
}
let client = builder.build()?;
Ok(Self {
client: client.clone(),
rate_limit: Arc::new(Mutex::new(
ServiceBuilder::new()
.rate_limit(num_request.get(), per)
.service(client),
)),
})
}
Ok(builder.build()?)
}
pub async fn remote_exists(
client: Client,
url: Url,
method: Method,
) -> Result<bool, BinstallError> {
let req = client
.request(method.clone(), url.clone())
.send()
.await
.map_err(|err| BinstallError::Http { method, url, err })?;
Ok(req.status().is_success())
}
pub async fn get_redirected_final_url(client: &Client, url: Url) -> Result<Url, BinstallError> {
let method = Method::HEAD;
let req = client
.request(method.clone(), url.clone())
.send()
.await
.and_then(Response::error_for_status)
.map_err(|err| BinstallError::Http { method, url, err })?;
Ok(req.url().clone())
}
pub(crate) async fn create_request(
client: Client,
url: Url,
) -> Result<impl Stream<Item = reqwest::Result<Bytes>>, BinstallError> {
debug!("Downloading from: '{url}'");
client
.get(url.clone())
.send()
.await
.and_then(|r| r.error_for_status())
.map_err(|err| BinstallError::Http {
method: Method::GET,
url,
err,
})
.map(Response::bytes_stream)
pub fn get_inner(&self) -> &reqwest::Client {
&self.client
}
async fn send_request(
&self,
method: Method,
url: Url,
error_for_status: bool,
) -> Result<Response, BinstallError> {
let request = Request::new(method.clone(), url.clone());
// Reduce critical section:
// - Construct the request before locking
// - Once the rate_limit is ready, call it and obtain
// the future, then release the lock before
// polling the future.
let future = self.rate_limit.lock().await.ready().await?.call(request);
future
.await
.and_then(|response| {
if error_for_status {
response.error_for_status()
} else {
Ok(response)
}
})
.map_err(|err| BinstallError::Http { method, url, err })
}
pub async fn remote_exists(&self, url: Url, method: Method) -> Result<bool, BinstallError> {
Ok(self
.send_request(method, url, false)
.await?
.status()
.is_success())
}
pub async fn get_redirected_final_url(&self, url: Url) -> Result<Url, BinstallError> {
Ok(self
.send_request(Method::HEAD, url, true)
.await?
.url()
.clone())
}
pub(crate) async fn create_request(
&self,
url: Url,
) -> Result<impl Stream<Item = reqwest::Result<Bytes>>, BinstallError> {
debug!("Downloading from: '{url}'");
self.send_request(Method::GET, url, true)
.await
.map(Response::bytes_stream)
}
}