From 03ae2b78d0e34a62b73249e6e8a54e1d33eaec66 Mon Sep 17 00:00:00 2001 From: Jiahao XU Date: Mon, 3 Jun 2024 23:40:55 +1000 Subject: [PATCH] Impl new API `GhApiClient::download_artifact` Signed-off-by: Jiahao XU --- Cargo.lock | 1 + crates/binstalk-downloader/src/download.rs | 69 ++++++++------ crates/binstalk-git-repo-api/Cargo.toml | 1 + .../src/gh_api_client.rs | 94 ++++++++++++++++--- .../src/gh_api_client/common.rs | 12 ++- 5 files changed, 133 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 828d35b6..24e727ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -354,6 +354,7 @@ version = "0.1.0" dependencies = [ "binstalk-downloader", "compact_str", + "futures-core", "percent-encoding", "serde", "serde-tuple-vec-map", diff --git a/crates/binstalk-downloader/src/download.rs b/crates/binstalk-downloader/src/download.rs index c1b7f59e..57250b94 100644 --- a/crates/binstalk-downloader/src/download.rs +++ b/crates/binstalk-downloader/src/download.rs @@ -1,4 +1,4 @@ -use std::{fmt, io, marker::PhantomData, path::Path}; +use std::{fmt, io, path::Path}; use binstalk_types::cargo_toml_binstall::PkgFmtDecomposed; use bytes::Bytes; @@ -8,7 +8,7 @@ use tracing::{debug, error, instrument}; pub use binstalk_types::cargo_toml_binstall::{PkgFmt, TarBasedFmt}; -use crate::remote::{Client, Error as RemoteError, Url}; +use crate::remote::{Client, Error as RemoteError, Response, Url}; mod async_extracter; use async_extracter::*; @@ -90,38 +90,43 @@ impl DataVerifier for () { } } +#[derive(Debug)] +enum DownloadContent { + ToIssue { client: Client, url: Url }, + Response(Response), +} + +impl DownloadContent { + async fn to_response(self) -> Result { + Ok(match self { + DownloadContent::ToIssue { client, url } => client.get(url).send(true).await?, + DownloadContent::Response(response) => response, + }) + } +} + pub struct Download<'a> { - client: Client, - url: Url, + content: DownloadContent, data_verifier: Option<&'a mut dyn DataVerifier>, } impl fmt::Debug for Download<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - #[allow(dead_code, clippy::type_complexity)] - #[derive(Debug)] - struct Download<'a> { - client: &'a Client, - url: &'a Url, - data_verifier: Option>, - } - - fmt::Debug::fmt( - &Download { - client: &self.client, - url: &self.url, - data_verifier: self.data_verifier.as_ref().map(|_| PhantomData), - }, - f, - ) + fmt::Debug::fmt(&self.content, f) } } impl Download<'static> { pub fn new(client: Client, url: Url) -> Self { Self { - client, - url, + content: DownloadContent::ToIssue { client, url }, + data_verifier: None, + } + } + + pub fn from_response(response: Response) -> Self { + Self { + content: DownloadContent::Response(response), data_verifier: None, } } @@ -134,8 +139,17 @@ impl<'a> Download<'a> { data_verifier: &'a mut dyn DataVerifier, ) -> Self { Self { - client, - url, + content: DownloadContent::ToIssue { client, url }, + data_verifier: Some(data_verifier), + } + } + + pub fn from_response_with_data_verifier( + response: Response, + data_verifier: &'a mut dyn DataVerifier, + ) -> Self { + Self { + content: DownloadContent::Response(response), data_verifier: Some(data_verifier), } } @@ -148,9 +162,10 @@ impl<'a> Download<'a> { > { let mut data_verifier = self.data_verifier; Ok(self - .client - .get_stream(self.url) + .content + .to_response() .await? + .bytes_stream() .map(move |res| { let bytes = res?; @@ -257,7 +272,7 @@ impl Download<'_> { #[instrument] pub async fn into_bytes(self) -> Result { - let bytes = self.client.get(self.url).send(true).await?.bytes().await?; + let bytes = self.content.to_response().await?.bytes().await?; if let Some(verifier) = self.data_verifier { verifier.update(&bytes); } diff --git a/crates/binstalk-git-repo-api/Cargo.toml b/crates/binstalk-git-repo-api/Cargo.toml index 71b0844c..20bb85a2 100644 --- a/crates/binstalk-git-repo-api/Cargo.toml +++ b/crates/binstalk-git-repo-api/Cargo.toml @@ -14,6 +14,7 @@ binstalk-downloader = { version = "0.10.3", path = "../binstalk-downloader", def "json", ] } compact_str = "0.7.0" +futures-core = "0.3.30" percent-encoding = "2.2.0" serde = { version = "1.0.163", features = ["derive"] } serde-tuple-vec-map = "1.0.1" diff --git a/crates/binstalk-git-repo-api/src/gh_api_client.rs b/crates/binstalk-git-repo-api/src/gh_api_client.rs index f7d6a6ef..86a2e7b9 100644 --- a/crates/binstalk-git-repo-api/src/gh_api_client.rs +++ b/crates/binstalk-git-repo-api/src/gh_api_client.rs @@ -9,7 +9,7 @@ use std::{ time::{Duration, Instant}, }; -use binstalk_downloader::remote; +use binstalk_downloader::{download::Download, remote}; use compact_str::{format_compact, CompactString}; use tokio::sync::OnceCell; use url::Url; @@ -19,7 +19,7 @@ mod error; mod release_artifacts; mod repo_info; -use common::percent_decode_http_url_path; +use common::{check_http_status_and_header, percent_decode_http_url_path}; pub use error::{GhApiContextError, GhApiError, GhGraphQLErrors}; pub use repo_info::RepoInfo; @@ -201,7 +201,12 @@ impl GhApiClient { Err(err) => Err(err), } } +} +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct ReleaseArtifactUrl(Url); + +impl GhApiClient { /// Return `Ok(Some(api_artifact_url))` if exists. /// /// The returned future is guaranteed to be pointer size. @@ -211,7 +216,7 @@ impl GhApiClient { release, artifact_name, }: GhReleaseArtifact, - ) -> Result, GhApiError> { + ) -> Result, GhApiError> { let once_cell = self.0.release_artifacts.get(release.clone()); let res = once_cell .get_or_try_init(|| { @@ -233,7 +238,9 @@ impl GhApiClient { .await; match res { - Ok(Some(artifacts)) => Ok(artifacts.get_artifact_url(&artifact_name)), + Ok(Some(artifacts)) => Ok(artifacts + .get_artifact_url(&artifact_name) + .map(ReleaseArtifactUrl)), Ok(None) => Ok(None), Err(GhApiError::RateLimit { retry_after }) => { *self.0.retry_after.lock().unwrap() = @@ -244,6 +251,35 @@ impl GhApiClient { Err(err) => Err(err), } } + + pub async fn download_artifact( + &self, + artifact_url: ReleaseArtifactUrl, + ) -> Result, GhApiError> { + self.check_retry_after()?; + + let Some(auth_token) = self.get_auth_token() else { + return Err(GhApiError::Unauthorized); + }; + + let response = self + .0 + .client + .get(artifact_url.0) + .header("Accept", "application/octet-stream") + .bearer_auth(&auth_token) + .send(false) + .await?; + + match check_http_status_and_header(&response) { + Err(GhApiError::Unauthorized) => { + self.0.is_auth_token_valid.store(false, Relaxed); + } + res => res?, + } + + Ok(Download::from_response(response)) + } } #[cfg(test)] @@ -485,11 +521,11 @@ mod test { } #[tokio::test] - async fn test_has_release_artifact() { + async fn test_has_release_artifact_and_download_artifacts() { const RELEASES: [(GhRelease, &[&str]); 2] = [ ( - cargo_audit_v_0_17_6::RELEASE, - cargo_audit_v_0_17_6::ARTIFACTS, + cargo_binstall_v0_20_1::RELEASE, + cargo_binstall_v0_20_1::ARTIFACTS, ), ( cargo_audit_v_0_17_6::RELEASE, @@ -516,14 +552,48 @@ mod test { let client = client.clone(); let release = release.clone(); tasks.push(tokio::spawn(async move { - client - .has_release_artifact(GhReleaseArtifact { - release, - artifact_name: artifact_name.to_compact_string(), - }) + let artifact = GhReleaseArtifact { + release, + artifact_name: artifact_name.to_compact_string(), + }; + + let browser_download_task = client.get_auth_token().map(|_| { + tokio::spawn( + Download::new( + client.0.client.clone(), + Url::parse(&format!( + "https://github.com/{}/{}/releases/download/{}/{}", + artifact.release.repo.owner, + artifact.release.repo.repo, + artifact.release.tag, + artifact.artifact_name, + )) + .unwrap(), + ) + .into_bytes(), + ) + }); + + let artifact_url = client + .has_release_artifact(artifact) .await .unwrap() .unwrap(); + + if let Some(browser_download_task) = browser_download_task { + let artifact_download_data = client + .download_artifact(artifact_url) + .await + .unwrap() + .into_bytes() + .await + .unwrap(); + + let browser_download_data = + browser_download_task.await.unwrap().unwrap(); + + assert_eq!(artifact_download_data, browser_download_data); + } })); } diff --git a/crates/binstalk-git-repo-api/src/gh_api_client/common.rs b/crates/binstalk-git-repo-api/src/gh_api_client/common.rs index b405f2eb..e8c293f3 100644 --- a/crates/binstalk-git-repo-api/src/gh_api_client/common.rs +++ b/crates/binstalk-git-repo-api/src/gh_api_client/common.rs @@ -1,6 +1,6 @@ use std::{future::Future, sync::OnceLock, time::Duration}; -use binstalk_downloader::remote::{self, header::HeaderMap, StatusCode, Url}; +use binstalk_downloader::remote::{self, Response, Url}; use compact_str::CompactString; use percent_encoding::percent_decode_str; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -18,8 +18,10 @@ pub(super) fn percent_decode_http_url_path(input: &str) -> CompactString { } } -fn check_http_status_and_header(status: StatusCode, headers: &HeaderMap) -> Result<(), GhApiError> { - match status { +pub(super) fn check_http_status_and_header(response: &Response) -> Result<(), GhApiError> { + let headers = response.headers(); + + match response.status() { remote::StatusCode::FORBIDDEN if headers .get("x-ratelimit-remaining") @@ -73,7 +75,7 @@ where async move { let response = future.await?; - check_http_status_and_header(response.status(), response.headers())?; + check_http_status_and_header(&response)?; Ok(response.json().await?) } @@ -126,7 +128,7 @@ where async move { let response = res?.await?; - check_http_status_and_header(response.status(), response.headers())?; + check_http_status_and_header(&response)?; let mut response: GraphQLResponse = response.json().await?;