Impl new API GhApiClient::download_artifact

Signed-off-by: Jiahao XU <Jiahao_XU@outlook.com>
This commit is contained in:
Jiahao XU 2024-06-03 23:40:55 +10:00
parent 4da2f0e64f
commit 03ae2b78d0
No known key found for this signature in database
GPG key ID: 76D1E687CA3C4928
5 changed files with 133 additions and 44 deletions

1
Cargo.lock generated
View file

@ -354,6 +354,7 @@ version = "0.1.0"
dependencies = [
"binstalk-downloader",
"compact_str",
"futures-core",
"percent-encoding",
"serde",
"serde-tuple-vec-map",

View file

@ -1,4 +1,4 @@
use std::{fmt, io, marker::PhantomData, path::Path};
use std::{fmt, io, path::Path};
use binstalk_types::cargo_toml_binstall::PkgFmtDecomposed;
use bytes::Bytes;
@ -8,7 +8,7 @@ use tracing::{debug, error, instrument};
pub use binstalk_types::cargo_toml_binstall::{PkgFmt, TarBasedFmt};
use crate::remote::{Client, Error as RemoteError, Url};
use crate::remote::{Client, Error as RemoteError, Response, Url};
mod async_extracter;
use async_extracter::*;
@ -90,38 +90,43 @@ impl DataVerifier for () {
}
}
#[derive(Debug)]
enum DownloadContent {
ToIssue { client: Client, url: Url },
Response(Response),
}
impl DownloadContent {
async fn to_response(self) -> Result<Response, DownloadError> {
Ok(match self {
DownloadContent::ToIssue { client, url } => client.get(url).send(true).await?,
DownloadContent::Response(response) => response,
})
}
}
pub struct Download<'a> {
client: Client,
url: Url,
content: DownloadContent,
data_verifier: Option<&'a mut dyn DataVerifier>,
}
impl fmt::Debug for Download<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
#[allow(dead_code, clippy::type_complexity)]
#[derive(Debug)]
struct Download<'a> {
client: &'a Client,
url: &'a Url,
data_verifier: Option<PhantomData<&'a mut dyn DataVerifier>>,
}
fmt::Debug::fmt(
&Download {
client: &self.client,
url: &self.url,
data_verifier: self.data_verifier.as_ref().map(|_| PhantomData),
},
f,
)
fmt::Debug::fmt(&self.content, f)
}
}
impl Download<'static> {
pub fn new(client: Client, url: Url) -> Self {
Self {
client,
url,
content: DownloadContent::ToIssue { client, url },
data_verifier: None,
}
}
pub fn from_response(response: Response) -> Self {
Self {
content: DownloadContent::Response(response),
data_verifier: None,
}
}
@ -134,8 +139,17 @@ impl<'a> Download<'a> {
data_verifier: &'a mut dyn DataVerifier,
) -> Self {
Self {
client,
url,
content: DownloadContent::ToIssue { client, url },
data_verifier: Some(data_verifier),
}
}
pub fn from_response_with_data_verifier(
response: Response,
data_verifier: &'a mut dyn DataVerifier,
) -> Self {
Self {
content: DownloadContent::Response(response),
data_verifier: Some(data_verifier),
}
}
@ -148,9 +162,10 @@ impl<'a> Download<'a> {
> {
let mut data_verifier = self.data_verifier;
Ok(self
.client
.get_stream(self.url)
.content
.to_response()
.await?
.bytes_stream()
.map(move |res| {
let bytes = res?;
@ -257,7 +272,7 @@ impl Download<'_> {
#[instrument]
pub async fn into_bytes(self) -> Result<Bytes, DownloadError> {
let bytes = self.client.get(self.url).send(true).await?.bytes().await?;
let bytes = self.content.to_response().await?.bytes().await?;
if let Some(verifier) = self.data_verifier {
verifier.update(&bytes);
}

View file

@ -14,6 +14,7 @@ binstalk-downloader = { version = "0.10.3", path = "../binstalk-downloader", def
"json",
] }
compact_str = "0.7.0"
futures-core = "0.3.30"
percent-encoding = "2.2.0"
serde = { version = "1.0.163", features = ["derive"] }
serde-tuple-vec-map = "1.0.1"

View file

@ -9,7 +9,7 @@ use std::{
time::{Duration, Instant},
};
use binstalk_downloader::remote;
use binstalk_downloader::{download::Download, remote};
use compact_str::{format_compact, CompactString};
use tokio::sync::OnceCell;
use url::Url;
@ -19,7 +19,7 @@ mod error;
mod release_artifacts;
mod repo_info;
use common::percent_decode_http_url_path;
use common::{check_http_status_and_header, percent_decode_http_url_path};
pub use error::{GhApiContextError, GhApiError, GhGraphQLErrors};
pub use repo_info::RepoInfo;
@ -201,7 +201,12 @@ impl GhApiClient {
Err(err) => Err(err),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ReleaseArtifactUrl(Url);
impl GhApiClient {
/// Return `Ok(Some(api_artifact_url))` if exists.
///
/// The returned future is guaranteed to be pointer size.
@ -211,7 +216,7 @@ impl GhApiClient {
release,
artifact_name,
}: GhReleaseArtifact,
) -> Result<Option<Url>, GhApiError> {
) -> Result<Option<ReleaseArtifactUrl>, GhApiError> {
let once_cell = self.0.release_artifacts.get(release.clone());
let res = once_cell
.get_or_try_init(|| {
@ -233,7 +238,9 @@ impl GhApiClient {
.await;
match res {
Ok(Some(artifacts)) => Ok(artifacts.get_artifact_url(&artifact_name)),
Ok(Some(artifacts)) => Ok(artifacts
.get_artifact_url(&artifact_name)
.map(ReleaseArtifactUrl)),
Ok(None) => Ok(None),
Err(GhApiError::RateLimit { retry_after }) => {
*self.0.retry_after.lock().unwrap() =
@ -244,6 +251,35 @@ impl GhApiClient {
Err(err) => Err(err),
}
}
pub async fn download_artifact(
&self,
artifact_url: ReleaseArtifactUrl,
) -> Result<Download<'static>, GhApiError> {
self.check_retry_after()?;
let Some(auth_token) = self.get_auth_token() else {
return Err(GhApiError::Unauthorized);
};
let response = self
.0
.client
.get(artifact_url.0)
.header("Accept", "application/octet-stream")
.bearer_auth(&auth_token)
.send(false)
.await?;
match check_http_status_and_header(&response) {
Err(GhApiError::Unauthorized) => {
self.0.is_auth_token_valid.store(false, Relaxed);
}
res => res?,
}
Ok(Download::from_response(response))
}
}
#[cfg(test)]
@ -485,11 +521,11 @@ mod test {
}
#[tokio::test]
async fn test_has_release_artifact() {
async fn test_has_release_artifact_and_download_artifacts() {
const RELEASES: [(GhRelease, &[&str]); 2] = [
(
cargo_audit_v_0_17_6::RELEASE,
cargo_audit_v_0_17_6::ARTIFACTS,
cargo_binstall_v0_20_1::RELEASE,
cargo_binstall_v0_20_1::ARTIFACTS,
),
(
cargo_audit_v_0_17_6::RELEASE,
@ -516,14 +552,48 @@ mod test {
let client = client.clone();
let release = release.clone();
tasks.push(tokio::spawn(async move {
client
.has_release_artifact(GhReleaseArtifact {
release,
artifact_name: artifact_name.to_compact_string(),
})
let artifact = GhReleaseArtifact {
release,
artifact_name: artifact_name.to_compact_string(),
};
let browser_download_task = client.get_auth_token().map(|_| {
tokio::spawn(
Download::new(
client.0.client.clone(),
Url::parse(&format!(
"https://github.com/{}/{}/releases/download/{}/{}",
artifact.release.repo.owner,
artifact.release.repo.repo,
artifact.release.tag,
artifact.artifact_name,
))
.unwrap(),
)
.into_bytes(),
)
});
let artifact_url = client
.has_release_artifact(artifact)
.await
.unwrap()
.unwrap();
if let Some(browser_download_task) = browser_download_task {
let artifact_download_data = client
.download_artifact(artifact_url)
.await
.unwrap()
.into_bytes()
.await
.unwrap();
let browser_download_data =
browser_download_task.await.unwrap().unwrap();
assert_eq!(artifact_download_data, browser_download_data);
}
}));
}

View file

@ -1,6 +1,6 @@
use std::{future::Future, sync::OnceLock, time::Duration};
use binstalk_downloader::remote::{self, header::HeaderMap, StatusCode, Url};
use binstalk_downloader::remote::{self, Response, Url};
use compact_str::CompactString;
use percent_encoding::percent_decode_str;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
@ -18,8 +18,10 @@ pub(super) fn percent_decode_http_url_path(input: &str) -> CompactString {
}
}
fn check_http_status_and_header(status: StatusCode, headers: &HeaderMap) -> Result<(), GhApiError> {
match status {
pub(super) fn check_http_status_and_header(response: &Response) -> Result<(), GhApiError> {
let headers = response.headers();
match response.status() {
remote::StatusCode::FORBIDDEN
if headers
.get("x-ratelimit-remaining")
@ -73,7 +75,7 @@ where
async move {
let response = future.await?;
check_http_status_and_header(response.status(), response.headers())?;
check_http_status_and_header(&response)?;
Ok(response.json().await?)
}
@ -126,7 +128,7 @@ where
async move {
let response = res?.await?;
check_http_status_and_header(response.status(), response.headers())?;
check_http_status_and_header(&response)?;
let mut response: GraphQLResponse<T> = response.json().await?;