diff --git a/crates/binstalk-git-repo-api/src/gh_api_client.rs b/crates/binstalk-git-repo-api/src/gh_api_client.rs index 01313311..213ffb73 100644 --- a/crates/binstalk-git-repo-api/src/gh_api_client.rs +++ b/crates/binstalk-git-repo-api/src/gh_api_client.rs @@ -9,44 +9,29 @@ use std::{ }; use binstalk_downloader::remote; -use compact_str::CompactString; -use percent_encoding::{ - percent_decode_str, utf8_percent_encode, AsciiSet, PercentEncode, CONTROLS, -}; +use compact_str::{format_compact, CompactString}; use tokio::sync::OnceCell; -mod release_artifacts; - +mod common; mod error; +mod release_artifacts; +mod repo_info; + +use common::percent_decode_http_url_path; pub use error::{GhApiContextError, GhApiError, GhGraphQLErrors}; +pub use repo_info::RepoInfo; /// default retry duration if x-ratelimit-reset is not found in response header const DEFAULT_RETRY_DURATION: Duration = Duration::from_secs(10 * 60); -fn percent_encode_http_url_path(path: &str) -> PercentEncode<'_> { - /// https://url.spec.whatwg.org/#fragment-percent-encode-set - const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`'); - - /// https://url.spec.whatwg.org/#path-percent-encode-set - const PATH: &AsciiSet = &FRAGMENT.add(b'#').add(b'?').add(b'{').add(b'}'); - - const PATH_SEGMENT: &AsciiSet = &PATH.add(b'/').add(b'%'); - - // The backslash (\) character is treated as a path separator in special URLs - // so it needs to be additionally escaped in that case. - // - // http is considered to have special path. - const SPECIAL_PATH_SEGMENT: &AsciiSet = &PATH_SEGMENT.add(b'\\'); - - utf8_percent_encode(path, SPECIAL_PATH_SEGMENT) +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +pub struct GhRepo { + pub owner: CompactString, + pub repo: CompactString, } - -fn percent_decode_http_url_path(input: &str) -> CompactString { - if input.contains('%') { - percent_decode_str(input).decode_utf8_lossy().into() - } else { - // No '%', no need to decode. - CompactString::new(input) +impl GhRepo { + pub fn repo_url(&self) -> CompactString { + format_compact!("https://github.com/{}/{}", self.owner, self.repo) } } @@ -157,13 +142,13 @@ impl GhApiClient { release: &GhRelease, auth_token: Option<&str>, ) -> Result, FetchReleaseArtifactError> { - use release_artifacts::FetchReleaseRet::*; + use common::GhApiRet::*; use FetchReleaseArtifactError as Error; match release_artifacts::fetch_release_artifacts(&self.0.client, release, auth_token).await { - Ok(ReleaseNotFound) => Ok(None), - Ok(Artifacts(artifacts)) => Ok(Some(artifacts)), + Ok(NotFound) => Ok(None), + Ok(Success(artifacts)) => Ok(Some(artifacts)), Ok(ReachedRateLimit { retry_after }) => { let retry_after = retry_after.unwrap_or(DEFAULT_RETRY_DURATION); diff --git a/crates/binstalk-git-repo-api/src/gh_api_client/common.rs b/crates/binstalk-git-repo-api/src/gh_api_client/common.rs new file mode 100644 index 00000000..60041488 --- /dev/null +++ b/crates/binstalk-git-repo-api/src/gh_api_client/common.rs @@ -0,0 +1,134 @@ +use std::{sync::OnceLock, time::Duration}; + +use binstalk_downloader::remote::{self, header::HeaderMap, StatusCode, Url}; +use compact_str::CompactString; +use percent_encoding::{ + percent_decode_str, utf8_percent_encode, AsciiSet, PercentEncode, CONTROLS, +}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::to_string as to_json_string; +use tracing::debug; + +use super::{GhApiError, GhGraphQLErrors}; + +pub(super) fn percent_encode_http_url_path(path: &str) -> PercentEncode<'_> { + /// https://url.spec.whatwg.org/#fragment-percent-encode-set + const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`'); + + /// https://url.spec.whatwg.org/#path-percent-encode-set + const PATH: &AsciiSet = &FRAGMENT.add(b'#').add(b'?').add(b'{').add(b'}'); + + const PATH_SEGMENT: &AsciiSet = &PATH.add(b'/').add(b'%'); + + // The backslash (\) character is treated as a path separator in special URLs + // so it needs to be additionally escaped in that case. + // + // http is considered to have special path. + const SPECIAL_PATH_SEGMENT: &AsciiSet = &PATH_SEGMENT.add(b'\\'); + + utf8_percent_encode(path, SPECIAL_PATH_SEGMENT) +} + +pub(super) fn percent_decode_http_url_path(input: &str) -> CompactString { + if input.contains('%') { + percent_decode_str(input).decode_utf8_lossy().into() + } else { + // No '%', no need to decode. + CompactString::new(input) + } +} + +pub(super) enum GhApiRet { + ReachedRateLimit { retry_after: Option }, + NotFound, + Success(T), + Unauthorized, +} + +pub(super) fn check_for_status(status: StatusCode, headers: &HeaderMap) -> Option> { + match status { + remote::StatusCode::FORBIDDEN + if headers + .get("x-ratelimit-remaining") + .map(|val| val == "0") + .unwrap_or(false) => + { + Some(GhApiRet::ReachedRateLimit { + retry_after: headers.get("x-ratelimit-reset").and_then(|value| { + let secs = value.to_str().ok()?.parse().ok()?; + Some(Duration::from_secs(secs)) + }), + }) + } + + remote::StatusCode::UNAUTHORIZED => Some(GhApiRet::Unauthorized), + remote::StatusCode::NOT_FOUND => Some(GhApiRet::NotFound), + + _ => None, + } +} + +#[derive(Deserialize)] +enum GraphQLResponse { + #[serde(rename = "data")] + Data(T), + + #[serde(rename = "errors")] + Errors(GhGraphQLErrors), +} + +#[derive(Serialize)] +struct GraphQLQuery { + query: String, +} + +fn get_graphql_endpoint() -> &'static Url { + static GRAPHQL_ENDPOINT: OnceLock = OnceLock::new(); + + GRAPHQL_ENDPOINT.get_or_init(|| { + Url::parse("https://api.github.com/graphql").expect("Literal provided must be a valid url") + }) +} + +pub(super) enum GraphQLResult { + Data(T), + Else(GhApiRet), +} + +pub(super) async fn issue_graphql_query( + client: &remote::Client, + query: String, + auth_token: &str, +) -> Result, GhApiError> +where + T: DeserializeOwned, +{ + let graphql_endpoint = get_graphql_endpoint(); + + let graphql_query = to_json_string(&GraphQLQuery { query }).map_err(remote::Error::from)?; + + debug!("Sending graphql query to {graphql_endpoint}: '{graphql_query}'"); + + let request_builder = client + .post(graphql_endpoint.clone(), graphql_query) + .header("Accept", "application/vnd.github+json") + .bearer_auth(&auth_token); + + let response = request_builder.send(false).await?; + + if let Some(ret) = check_for_status(response.status(), response.headers()) { + return Ok(GraphQLResult::Else(ret)); + } + + let response: GraphQLResponse = response.json().await?; + + match response { + GraphQLResponse::Data(data) => Ok(GraphQLResult::Data(data)), + GraphQLResponse::Errors(errors) if errors.is_rate_limited() => { + Ok(GraphQLResult::Else(GhApiRet::ReachedRateLimit { + retry_after: None, + })) + } + GraphQLResponse::Errors(errors) => Err(errors.into()), + } +} diff --git a/crates/binstalk-git-repo-api/src/gh_api_client/release_artifacts.rs b/crates/binstalk-git-repo-api/src/gh_api_client/release_artifacts.rs index ec390aae..9c64775e 100644 --- a/crates/binstalk-git-repo-api/src/gh_api_client/release_artifacts.rs +++ b/crates/binstalk-git-repo-api/src/gh_api_client/release_artifacts.rs @@ -3,17 +3,16 @@ use std::{ collections::HashSet, fmt, hash::{Hash, Hasher}, - sync::OnceLock, - time::Duration, }; -use binstalk_downloader::remote::{header::HeaderMap, StatusCode, Url}; +use binstalk_downloader::remote::{self, header::HeaderMap, StatusCode, Url}; use compact_str::CompactString; -use serde::{Deserialize, Serialize}; -use serde_json::to_string as to_json_string; -use tracing::debug; +use serde::Deserialize; -use super::{percent_encode_http_url_path, remote, GhApiError, GhGraphQLErrors, GhRelease}; +use super::{ + common::{self, issue_graphql_query, percent_encode_http_url_path, GraphQLResult}, + GhApiError, GhRelease, +}; // Only include fields we do care about @@ -66,34 +65,10 @@ impl Artifacts { } } -pub(super) enum FetchReleaseRet { - ReachedRateLimit { retry_after: Option }, - ReleaseNotFound, - Artifacts(Artifacts), - Unauthorized, -} +pub(super) type FetchReleaseRet = common::GhApiRet; fn check_for_status(status: StatusCode, headers: &HeaderMap) -> Option { - match status { - remote::StatusCode::FORBIDDEN - if headers - .get("x-ratelimit-remaining") - .map(|val| val == "0") - .unwrap_or(false) => - { - Some(FetchReleaseRet::ReachedRateLimit { - retry_after: headers.get("x-ratelimit-reset").and_then(|value| { - let secs = value.to_str().ok()?.parse().ok()?; - Some(Duration::from_secs(secs)) - }), - }) - } - - remote::StatusCode::UNAUTHORIZED => Some(FetchReleaseRet::Unauthorized), - remote::StatusCode::NOT_FOUND => Some(FetchReleaseRet::ReleaseNotFound), - - _ => None, - } + common::check_for_status(status, headers) } async fn fetch_release_artifacts_restful_api( @@ -120,19 +95,10 @@ async fn fetch_release_artifacts_restful_api( if let Some(ret) = check_for_status(response.status(), response.headers()) { Ok(ret) } else { - Ok(FetchReleaseRet::Artifacts(response.json().await?)) + Ok(FetchReleaseRet::Success(response.json().await?)) } } -#[derive(Deserialize)] -enum GraphQLResponse { - #[serde(rename = "data")] - Data(GraphQLData), - - #[serde(rename = "errors")] - Errors(GhGraphQLErrors), -} - #[derive(Deserialize)] struct GraphQLData { repository: Option, @@ -179,22 +145,11 @@ impl fmt::Display for FilterCondition { } } -#[derive(Serialize)] -struct GraphQLQuery { - query: String, -} - async fn fetch_release_artifacts_graphql_api( client: &remote::Client, GhRelease { owner, repo, tag }: &GhRelease, auth_token: &str, ) -> Result { - static GRAPHQL_ENDPOINT: OnceLock = OnceLock::new(); - - let graphql_endpoint = GRAPHQL_ENDPOINT.get_or_init(|| { - Url::parse("https://api.github.com/graphql").expect("Literal provided must be a valid url") - }); - let mut artifacts = Artifacts::default(); let mut cond = FilterCondition::Init; @@ -216,29 +171,9 @@ query {{ }}"# ); - let graphql_query = to_json_string(&GraphQLQuery { query }).map_err(remote::Error::from)?; - - debug!("Sending graphql query to https://api.github.com/graphql: '{graphql_query}'"); - - let request_builder = client - .post(graphql_endpoint.clone(), graphql_query) - .header("Accept", "application/vnd.github+json") - .bearer_auth(&auth_token); - - let response = request_builder.send(false).await?; - - if let Some(ret) = check_for_status(response.status(), response.headers()) { - return Ok(ret); - } - - let response: GraphQLResponse = response.json().await?; - - let data = match response { - GraphQLResponse::Data(data) => data, - GraphQLResponse::Errors(errors) if errors.is_rate_limited() => { - return Ok(FetchReleaseRet::ReachedRateLimit { retry_after: None }) - } - GraphQLResponse::Errors(errors) => return Err(errors.into()), + let data: GraphQLData = match issue_graphql_query(client, query, auth_token).await? { + GraphQLResult::Data(data) => data, + GraphQLResult::Else(ret) => return Ok(ret), }; let assets = data @@ -256,10 +191,10 @@ query {{ } => { cond = FilterCondition::After(end_cursor); } - _ => break Ok(FetchReleaseRet::Artifacts(artifacts)), + _ => break Ok(FetchReleaseRet::Success(artifacts)), } } else { - break Ok(FetchReleaseRet::ReleaseNotFound); + break Ok(FetchReleaseRet::NotFound); } } } diff --git a/crates/binstalk-git-repo-api/src/gh_api_client/repo_info.rs b/crates/binstalk-git-repo-api/src/gh_api_client/repo_info.rs new file mode 100644 index 00000000..ec144398 --- /dev/null +++ b/crates/binstalk-git-repo-api/src/gh_api_client/repo_info.rs @@ -0,0 +1,117 @@ +use binstalk_downloader::remote::{header::HeaderMap, StatusCode, Url}; +use compact_str::CompactString; +use serde::Deserialize; + +use super::{ + common::{self, issue_graphql_query, percent_encode_http_url_path, GraphQLResult}, + remote, GhApiError, GhRepo, +}; + +#[derive(Debug, Deserialize)] +struct Owner { + login: CompactString, +} + +#[derive(Debug, Deserialize)] +pub struct RepoInfo { + owner: Owner, + name: CompactString, + private: bool, +} + +impl RepoInfo { + pub fn repo(&self) -> GhRepo { + GhRepo { + owner: self.owner.login.clone(), + repo: self.name.clone(), + } + } + + pub fn is_private(&self) -> bool { + self.private + } +} + +pub(super) type FetchRepoInfoRet = common::GhApiRet; + +fn check_for_status(status: StatusCode, headers: &HeaderMap) -> Option { + common::check_for_status(status, headers) +} + +async fn fetch_repo_info_restful_api( + client: &remote::Client, + GhRepo { owner, repo }: &GhRepo, + auth_token: Option<&str>, +) -> Result { + let mut request_builder = client + .get(Url::parse(&format!( + "https://api.github.com/repos/{owner}/{repo}", + owner = percent_encode_http_url_path(owner), + repo = percent_encode_http_url_path(repo), + ))?) + .header("Accept", "application/vnd.github+json") + .header("X-GitHub-Api-Version", "2022-11-28"); + + if let Some(auth_token) = auth_token { + request_builder = request_builder.bearer_auth(&auth_token); + } + + let response = request_builder.send(false).await?; + + if let Some(ret) = check_for_status(response.status(), response.headers()) { + Ok(ret) + } else { + Ok(FetchRepoInfoRet::Success(response.json().await?)) + } +} + +#[derive(Deserialize)] +struct GraphQLData { + repository: Option, +} + +async fn fetch_repo_info_graphql_api( + client: &remote::Client, + GhRepo { owner, repo }: &GhRepo, + auth_token: &str, +) -> Result { + let query = format!( + r#" +query {{ + repository(owner:"{owner}",name:"{repo}") {{ + owner {{ + login + }} + name + private: isPrivate + }} +}}"# + ); + + match issue_graphql_query(client, query, auth_token).await? { + GraphQLResult::Data(repo_info) => Ok(common::GhApiRet::Success(repo_info)), + GraphQLResult::Else(ret) => Ok(ret), + } +} + +pub(super) async fn fetch_repo_info( + client: &remote::Client, + repo: &GhRepo, + auth_token: Option<&str>, +) -> Result { + if let Some(auth_token) = auth_token { + let res = fetch_repo_info_graphql_api(client, repo, auth_token) + .await + .map_err(|err| err.context("GraphQL API")); + + match res { + // Fallback to Restful API + Ok(FetchRepoInfoRet::Unauthorized) => (), + res => return res, + } + } + + fetch_repo_info_restful_api(client, repo, auth_token) + .await + .map_err(|err| err.context("Restful API")) +}