Impl new API GhApiClient::download_artifact

Signed-off-by: Jiahao XU <Jiahao_XU@outlook.com>
This commit is contained in:
Jiahao XU 2024-06-03 23:40:55 +10:00
parent 4da2f0e64f
commit 03ae2b78d0
No known key found for this signature in database
GPG key ID: 76D1E687CA3C4928
5 changed files with 133 additions and 44 deletions

View file

@ -9,7 +9,7 @@ use std::{
time::{Duration, Instant},
};
use binstalk_downloader::remote;
use binstalk_downloader::{download::Download, remote};
use compact_str::{format_compact, CompactString};
use tokio::sync::OnceCell;
use url::Url;
@ -19,7 +19,7 @@ mod error;
mod release_artifacts;
mod repo_info;
use common::percent_decode_http_url_path;
use common::{check_http_status_and_header, percent_decode_http_url_path};
pub use error::{GhApiContextError, GhApiError, GhGraphQLErrors};
pub use repo_info::RepoInfo;
@ -201,7 +201,12 @@ impl GhApiClient {
Err(err) => Err(err),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ReleaseArtifactUrl(Url);
impl GhApiClient {
/// Return `Ok(Some(api_artifact_url))` if exists.
///
/// The returned future is guaranteed to be pointer size.
@ -211,7 +216,7 @@ impl GhApiClient {
release,
artifact_name,
}: GhReleaseArtifact,
) -> Result<Option<Url>, GhApiError> {
) -> Result<Option<ReleaseArtifactUrl>, GhApiError> {
let once_cell = self.0.release_artifacts.get(release.clone());
let res = once_cell
.get_or_try_init(|| {
@ -233,7 +238,9 @@ impl GhApiClient {
.await;
match res {
Ok(Some(artifacts)) => Ok(artifacts.get_artifact_url(&artifact_name)),
Ok(Some(artifacts)) => Ok(artifacts
.get_artifact_url(&artifact_name)
.map(ReleaseArtifactUrl)),
Ok(None) => Ok(None),
Err(GhApiError::RateLimit { retry_after }) => {
*self.0.retry_after.lock().unwrap() =
@ -244,6 +251,35 @@ impl GhApiClient {
Err(err) => Err(err),
}
}
pub async fn download_artifact(
&self,
artifact_url: ReleaseArtifactUrl,
) -> Result<Download<'static>, GhApiError> {
self.check_retry_after()?;
let Some(auth_token) = self.get_auth_token() else {
return Err(GhApiError::Unauthorized);
};
let response = self
.0
.client
.get(artifact_url.0)
.header("Accept", "application/octet-stream")
.bearer_auth(&auth_token)
.send(false)
.await?;
match check_http_status_and_header(&response) {
Err(GhApiError::Unauthorized) => {
self.0.is_auth_token_valid.store(false, Relaxed);
}
res => res?,
}
Ok(Download::from_response(response))
}
}
#[cfg(test)]
@ -485,11 +521,11 @@ mod test {
}
#[tokio::test]
async fn test_has_release_artifact() {
async fn test_has_release_artifact_and_download_artifacts() {
const RELEASES: [(GhRelease, &[&str]); 2] = [
(
cargo_audit_v_0_17_6::RELEASE,
cargo_audit_v_0_17_6::ARTIFACTS,
cargo_binstall_v0_20_1::RELEASE,
cargo_binstall_v0_20_1::ARTIFACTS,
),
(
cargo_audit_v_0_17_6::RELEASE,
@ -516,14 +552,48 @@ mod test {
let client = client.clone();
let release = release.clone();
tasks.push(tokio::spawn(async move {
client
.has_release_artifact(GhReleaseArtifact {
release,
artifact_name: artifact_name.to_compact_string(),
})
let artifact = GhReleaseArtifact {
release,
artifact_name: artifact_name.to_compact_string(),
};
let browser_download_task = client.get_auth_token().map(|_| {
tokio::spawn(
Download::new(
client.0.client.clone(),
Url::parse(&format!(
"https://github.com/{}/{}/releases/download/{}/{}",
artifact.release.repo.owner,
artifact.release.repo.repo,
artifact.release.tag,
artifact.artifact_name,
))
.unwrap(),
)
.into_bytes(),
)
});
let artifact_url = client
.has_release_artifact(artifact)
.await
.unwrap()
.unwrap();
if let Some(browser_download_task) = browser_download_task {
let artifact_download_data = client
.download_artifact(artifact_url)
.await
.unwrap()
.into_bytes()
.await
.unwrap();
let browser_download_data =
browser_download_task.await.unwrap().unwrap();
assert_eq!(artifact_download_data, browser_download_data);
}
}));
}

View file

@ -1,6 +1,6 @@
use std::{future::Future, sync::OnceLock, time::Duration};
use binstalk_downloader::remote::{self, header::HeaderMap, StatusCode, Url};
use binstalk_downloader::remote::{self, Response, Url};
use compact_str::CompactString;
use percent_encoding::percent_decode_str;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
@ -18,8 +18,10 @@ pub(super) fn percent_decode_http_url_path(input: &str) -> CompactString {
}
}
fn check_http_status_and_header(status: StatusCode, headers: &HeaderMap) -> Result<(), GhApiError> {
match status {
pub(super) fn check_http_status_and_header(response: &Response) -> Result<(), GhApiError> {
let headers = response.headers();
match response.status() {
remote::StatusCode::FORBIDDEN
if headers
.get("x-ratelimit-remaining")
@ -73,7 +75,7 @@ where
async move {
let response = future.await?;
check_http_status_and_header(response.status(), response.headers())?;
check_http_status_and_header(&response)?;
Ok(response.json().await?)
}
@ -126,7 +128,7 @@ where
async move {
let response = res?.await?;
check_http_status_and_header(response.status(), response.headers())?;
check_http_status_and_header(&response)?;
let mut response: GraphQLResponse<T> = response.json().await?;