Commit a268da36 authored by Ben Boeckel's avatar Ben Boeckel Committed by Kitware Robot
Browse files

Merge topic 'async-client'

9cf36abe Add async API additions to the changelog.
ad08a844 gitlab-ci: bump minimum version
d3fd1108 Add async tests to Endpoint, Ignore, Paged, Raw.
7294cae9 Add GitlabBuilder::build_async and async gitlab client.
482e147b Implement AsyncQuery for Endpoint, Ignore, Paged, Raw.
7fb170ef Add traits for async query and client.
695a470b

 rustfmt: specify that we want the 2018 edition
Acked-by: Kitware Robot's avatarKitware Robot <kwrobot@kitware.com>
Acked-by: Ben Boeckel's avatarBen Boeckel <ben.boeckel@kitware.com>
Reviewed-by: Ben Boeckel's avatarBen Boeckel <ben.boeckel@kitware.com>
Merge-request: !284
parents 9e698263 9cf36abe
...@@ -50,7 +50,7 @@ before_script: ...@@ -50,7 +50,7 @@ before_script:
- cargo tarpaulin --frozen $CARGO_FEATURES --exclude-files vendor --ignore-panics --all --verbose --out Html - cargo tarpaulin --frozen $CARGO_FEATURES --exclude-files vendor --ignore-panics --all --verbose --out Html
.rust_minimum: &rust_minimum .rust_minimum: &rust_minimum
image: "rust:1.42.0" image: "rust:1.43.0"
variables: variables:
CARGO_UPDATE_POLICY: newest CARGO_UPDATE_POLICY: newest
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
* Added the `api::projects::releases::ProjectReleases` endpoint to list all * Added the `api::projects::releases::ProjectReleases` endpoint to list all
releases for a project. releases for a project.
* Listing commits in a repository can now be done via `Commits` * Listing commits in a repository can now be done via `Commits`
* Added asynchronous API for query `api::AsyncQuery` and client `api::AsyncClient`.
* Added asynchronous client `AsyncGitlab` (created by `GitlabBuilder::build_async`).
# v0.1310.0 # v0.1310.0
......
...@@ -15,7 +15,15 @@ edition = "2018" ...@@ -15,7 +15,15 @@ edition = "2018"
[features] [features]
default = ["client_api"] default = ["client_api"]
client_api = ["itertools", "percent-encoding", "reqwest", "thiserror", "graphql_client"] client_api = [
"itertools",
"percent-encoding",
"reqwest",
"thiserror",
"graphql_client",
"async-trait",
"futures-util",
]
[dependencies] [dependencies]
base64 = "~0.12" base64 = "~0.12"
...@@ -25,6 +33,8 @@ log = "~0.4" ...@@ -25,6 +33,8 @@ log = "~0.4"
percent-encoding = { version = "^2.0", optional = true } percent-encoding = { version = "^2.0", optional = true }
reqwest = { version = "~0.10.5", features = ["blocking", "json"], optional = true } reqwest = { version = "~0.10.5", features = ["blocking", "json"], optional = true }
thiserror = { version = "^1.0.2", optional = true } thiserror = { version = "^1.0.2", optional = true }
async-trait = { version = "~0.1", optional = true }
futures-util = { version = "0.3.13", default-features = false, optional = true }
bytes = "~0.5" bytes = "~0.5"
chrono = { version = "~0.4", features = ["serde"] } chrono = { version = "~0.4", features = ["serde"] }
...@@ -37,3 +47,4 @@ url = "^2.1" ...@@ -37,3 +47,4 @@ url = "^2.1"
[dev-dependencies] [dev-dependencies]
itertools = { version = "~0.8" } itertools = { version = "~0.8" }
tokio = { version = "1.4.0", features = ["macros", "rt-multi-thread"] }
edition = "2018"
match_block_trailing_comma = true match_block_trailing_comma = true
force_multiline_blocks = true force_multiline_blocks = true
report_todo = "Unnumbered" report_todo = "Unnumbered"
......
...@@ -76,6 +76,7 @@ pub mod users; ...@@ -76,6 +76,7 @@ pub mod users;
pub(crate) mod helpers; pub(crate) mod helpers;
pub use self::client::AsyncClient;
pub use self::client::Client; pub use self::client::Client;
pub use self::endpoint::Endpoint; pub use self::endpoint::Endpoint;
...@@ -97,6 +98,7 @@ pub use self::params::FormParams; ...@@ -97,6 +98,7 @@ pub use self::params::FormParams;
pub use self::params::ParamValue; pub use self::params::ParamValue;
pub use self::params::QueryParams; pub use self::params::QueryParams;
pub use self::query::AsyncQuery;
pub use self::query::Query; pub use self::query::Query;
pub use self::raw::raw; pub use self::raw::raw;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
use std::error::Error; use std::error::Error;
use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use http::request::Builder as RequestBuilder; use http::request::Builder as RequestBuilder;
use http::Response; use http::Response;
...@@ -30,3 +31,22 @@ pub trait Client { ...@@ -30,3 +31,22 @@ pub trait Client {
body: Vec<u8>, body: Vec<u8>,
) -> Result<Response<Bytes>, ApiError<Self::Error>>; ) -> Result<Response<Bytes>, ApiError<Self::Error>>;
} }
/// A trait representing an asynchronous client which can communicate with a GitLab instance.
#[async_trait]
pub trait AsyncClient {
/// The errors which may occur for this client.
type Error: Error + Send + Sync + 'static;
/// Get the URL for the endpoint for the client.
///
/// This method adds the hostname for the client's target instance.
fn rest_endpoint(&self, endpoint: &str) -> Result<Url, ApiError<Self::Error>>;
/// Send a REST query asynchronously.
async fn rest_async(
&self,
request: RequestBuilder,
body: Vec<u8>,
) -> Result<Response<Bytes>, ApiError<Self::Error>>;
}
...@@ -6,10 +6,11 @@ ...@@ -6,10 +6,11 @@
use std::borrow::Cow; use std::borrow::Cow;
use async_trait::async_trait;
use http::{self, header, Method, Request}; use http::{self, header, Method, Request};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use crate::api::{query, ApiError, BodyError, Client, Query, QueryParams}; use crate::api::{query, ApiError, AsyncClient, AsyncQuery, BodyError, Client, Query, QueryParams};
/// A trait for providing the necessary information for a single REST API endpoint. /// A trait for providing the necessary information for a single REST API endpoint.
pub trait Endpoint { pub trait Endpoint {
...@@ -61,6 +62,37 @@ where ...@@ -61,6 +62,37 @@ where
} }
} }
#[async_trait]
impl<E, T, C> AsyncQuery<T, C> for E
where
E: Endpoint + Sync,
T: DeserializeOwned + 'static,
C: AsyncClient + Sync,
{
async fn query_async(&self, client: &C) -> Result<T, ApiError<C::Error>> {
let mut url = client.rest_endpoint(&self.endpoint())?;
self.parameters().add_to_url(&mut url);
let req = Request::builder()
.method(self.method())
.uri(query::url_to_http_uri(url));
let (req, data) = if let Some((mime, data)) = self.body()? {
let req = req.header(header::CONTENT_TYPE, mime);
(req, data)
} else {
(req, Vec::new())
};
let rsp = client.rest_async(req, data).await?;
let status = rsp.status();
let v = serde_json::from_slice(rsp.body())?;
if !status.is_success() {
return Err(ApiError::from_gitlab(v));
}
serde_json::from_value::<T>(v).map_err(ApiError::data_type::<T>)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use http::StatusCode; use http::StatusCode;
...@@ -68,7 +100,7 @@ mod tests { ...@@ -68,7 +100,7 @@ mod tests {
use serde_json::json; use serde_json::json;
use crate::api::endpoint_prelude::*; use crate::api::endpoint_prelude::*;
use crate::api::{ApiError, Query}; use crate::api::{ApiError, AsyncQuery, Query};
use crate::test::client::{ExpectedUrl, SingleTestClient}; use crate::test::client::{ExpectedUrl, SingleTestClient};
struct Dummy; struct Dummy;
...@@ -262,4 +294,18 @@ mod tests { ...@@ -262,4 +294,18 @@ mod tests {
let res: DummyResult = Dummy.query(&client).unwrap(); let res: DummyResult = Dummy.query(&client).unwrap();
assert_eq!(res.value, 0); assert_eq!(res.value, 0);
} }
#[tokio::test]
async fn test_good_deserialization_async() {
let endpoint = ExpectedUrl::builder().endpoint("dummy").build().unwrap();
let client = SingleTestClient::new_json(
endpoint,
&json!({
"value": 0,
}),
);
let res: DummyResult = Dummy.query_async(&client).await.unwrap();
assert_eq!(res.value, 0);
}
} }
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
// option. This file may not be copied, modified, or distributed // option. This file may not be copied, modified, or distributed
// except according to those terms. // except according to those terms.
use async_trait::async_trait;
use http::{header, Request}; use http::{header, Request};
use crate::api::{query, ApiError, Client, Endpoint, Query}; use crate::api::{query, ApiError, AsyncClient, AsyncQuery, Client, Endpoint, Query};
/// A query modifier that ignores the data returned from an endpoint. /// A query modifier that ignores the data returned from an endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
...@@ -49,13 +50,42 @@ where ...@@ -49,13 +50,42 @@ where
} }
} }
#[async_trait]
impl<E, C> AsyncQuery<(), C> for Ignore<E>
where
E: Endpoint + Sync,
C: AsyncClient + Sync,
{
async fn query_async(&self, client: &C) -> Result<(), ApiError<C::Error>> {
let mut url = client.rest_endpoint(&self.endpoint.endpoint())?;
self.endpoint.parameters().add_to_url(&mut url);
let req = Request::builder()
.method(self.endpoint.method())
.uri(query::url_to_http_uri(url));
let (req, data) = if let Some((mime, data)) = self.endpoint.body()? {
let req = req.header(header::CONTENT_TYPE, mime);
(req, data)
} else {
(req, Vec::new())
};
let rsp = client.rest_async(req, data).await?;
if !rsp.status().is_success() {
let v = serde_json::from_slice(rsp.body())?;
return Err(ApiError::from_gitlab(v));
}
Ok(())
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use http::StatusCode; use http::StatusCode;
use serde_json::json; use serde_json::json;
use crate::api::endpoint_prelude::*; use crate::api::endpoint_prelude::*;
use crate::api::{self, ApiError, Query}; use crate::api::{self, ApiError, AsyncQuery, Query};
use crate::test::client::{ExpectedUrl, SingleTestClient}; use crate::test::client::{ExpectedUrl, SingleTestClient};
struct Dummy; struct Dummy;
...@@ -83,6 +113,14 @@ mod tests { ...@@ -83,6 +113,14 @@ mod tests {
api::ignore(Dummy).query(&client).unwrap() api::ignore(Dummy).query(&client).unwrap()
} }
#[tokio::test]
async fn test_gitlab_non_json_response_async() {
let endpoint = ExpectedUrl::builder().endpoint("dummy").build().unwrap();
let client = SingleTestClient::new_raw(endpoint, "not json");
api::ignore(Dummy).query_async(&client).await.unwrap()
}
#[test] #[test]
fn test_gitlab_error_bad_json() { fn test_gitlab_error_bad_json() {
let endpoint = ExpectedUrl::builder() let endpoint = ExpectedUrl::builder()
......
...@@ -4,12 +4,15 @@ ...@@ -4,12 +4,15 @@
// option. This file may not be copied, modified, or distributed // option. This file may not be copied, modified, or distributed
// except according to those terms. // except according to those terms.
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use http::{header, HeaderMap, Request}; use http::{header, HeaderMap, Request};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use thiserror::Error; use thiserror::Error;
use url::Url; use url::Url;
use crate::api::{query, ApiError, Client, Endpoint, Query}; use crate::api::{query, ApiError, AsyncClient, AsyncQuery, Client, Endpoint, Query};
/// Errors which may occur with pagination. /// Errors which may occur with pagination.
#[non_exhaustive] #[non_exhaustive]
...@@ -262,6 +265,104 @@ where ...@@ -262,6 +265,104 @@ where
} }
} }
#[async_trait]
impl<E, T, C> AsyncQuery<Vec<T>, C> for Paged<E>
where
E: Endpoint + Sync,
E: Pageable,
T: DeserializeOwned + Send + 'static,
C: AsyncClient + Sync,
{
async fn query_async(&self, client: &C) -> Result<Vec<T>, ApiError<C::Error>> {
let url = {
let mut url = client.rest_endpoint(&self.endpoint.endpoint())?;
self.endpoint.parameters().add_to_url(&mut url);
url
};
let mut page_num = 1;
let per_page = self.pagination.page_limit();
let per_page_str = format!("{}", per_page);
let results = Arc::new(Mutex::new(Vec::new()));
let mut next_url = None;
let use_keyset_pagination = self.endpoint.use_keyset_pagination();
let body = self.endpoint.body()?;
loop {
let page_url = if let Some(url) = next_url.take() {
url
} else {
let page_str = format!("{}", page_num);
let mut page_url = url.clone();
{
let mut pairs = page_url.query_pairs_mut();
pairs.append_pair("per_page", &per_page_str);
if use_keyset_pagination {
pairs.append_pair("pagination", "keyset");
} else {
pairs.append_pair("page", &page_str);
}
}
page_url
};
let req = Request::builder()
.method(self.endpoint.method())
.uri(query::url_to_http_uri(page_url));
let (req, data) = if let Some((mime, data)) = body.as_ref() {
let req = req.header(header::CONTENT_TYPE, *mime);
(req, data.clone())
} else {
(req, Vec::new())
};
let rsp = client.rest_async(req, data).await?;
let status = rsp.status();
if use_keyset_pagination {
next_url = next_page_from_headers(rsp.headers())?;
}
let v = serde_json::from_slice(rsp.body())?;
if !status.is_success() {
return Err(ApiError::from_gitlab(v));
}
let page =
serde_json::from_value::<Vec<T>>(v).map_err(ApiError::data_type::<Vec<T>>)?;
let page_len = page.len();
// Gitlab used to have issues returning paginated results; these have been fixed since,
// but if it is needed, the bug manifests as Gitlab returning *all* results instead of
// just the requested results. This can cause an infinite loop here if the number of
// total results is exactly equal to `per_page`.
let is_last_page = {
let mut locked_results = results.lock().expect("poisoned results");
locked_results.extend(page);
self.pagination.is_last_page(page_len, &locked_results)
};
if is_last_page {
break;
}
if use_keyset_pagination {
if next_url.is_none() {
break;
}
} else {
page_num += 1;
}
}
let mut locked_results = results.lock().expect("poisoned results");
Ok(std::mem::take(&mut locked_results))
}
}
fn next_page_from_headers(headers: &HeaderMap) -> Result<Option<Url>, PaginationError> { fn next_page_from_headers(headers: &HeaderMap) -> Result<Option<Url>, PaginationError> {
let link_headers = headers.get_all(reqwest::header::LINK).iter(); let link_headers = headers.get_all(reqwest::header::LINK).iter();
// GitLab 14.0 will deprecate this header in preference for the W3C spec's `Link` header. Make // GitLab 14.0 will deprecate this header in preference for the W3C spec's `Link` header. Make
...@@ -277,7 +378,7 @@ fn next_page_from_headers(headers: &HeaderMap) -> Result<Option<Url>, Pagination ...@@ -277,7 +378,7 @@ fn next_page_from_headers(headers: &HeaderMap) -> Result<Option<Url>, Pagination
}) })
.collect::<Result<Vec<_>, PaginationError>>()? .collect::<Result<Vec<_>, PaginationError>>()?
.into_iter() .into_iter()
.filter_map(|header| { .find_map(|header| {
let is_next_link = header let is_next_link = header
.params .params
.into_iter() .into_iter()
...@@ -289,7 +390,6 @@ fn next_page_from_headers(headers: &HeaderMap) -> Result<Option<Url>, Pagination ...@@ -289,7 +390,6 @@ fn next_page_from_headers(headers: &HeaderMap) -> Result<Option<Url>, Pagination
None None
} }
}) })
.next()
.transpose() .transpose()
} }
...@@ -301,7 +401,7 @@ mod tests { ...@@ -301,7 +401,7 @@ mod tests {
use crate::api::endpoint_prelude::*; use crate::api::endpoint_prelude::*;
use crate::api::paged::LinkHeader; use crate::api::paged::LinkHeader;
use crate::api::{self, ApiError, LinkHeaderParseError, Pagination, Query}; use crate::api::{self, ApiError, AsyncQuery, LinkHeaderParseError, Pagination, Query};
use crate::test::client::{ExpectedUrl, PagedTestClient, SingleTestClient}; use crate::test::client::{ExpectedUrl, PagedTestClient, SingleTestClient};
#[test] #[test]
...@@ -542,6 +642,35 @@ mod tests { ...@@ -542,6 +642,35 @@ mod tests {
} }
} }
#[tokio::test]
async fn test_pagination_limit_async() {
let endpoint = ExpectedUrl::builder()
.endpoint("paged_dummy")
.paginated(true)
.build()
.unwrap();
let client = PagedTestClient::new_raw(
endpoint,
(0..=255).map(|value| {
DummyResult {
value,
}
}),
);
let query = Dummy {
with_keyset: false,
};
let res: Vec<DummyResult> = api::paged(query, Pagination::Limit(25))
.query_async(&client)
.await
.unwrap();
assert_eq!(res.len(), 25);
for (i, value) in res.iter().enumerate() {
assert_eq!(value.value, i as u8);
}
}
#[test] #[test]
fn test_pagination_all() { fn test_pagination_all() {
let endpoint = ExpectedUrl::builder() let endpoint = ExpectedUrl::builder()
...@@ -566,6 +695,33 @@ mod tests { ...@@ -566,6 +695,33 @@ mod tests {
} }
} }
#[tokio::test]
async fn test_pagination_all_async() {
let endpoint = ExpectedUrl::builder()
.endpoint("paged_dummy")
.paginated(true)
.build()
.unwrap();
let client = PagedTestClient::new_raw(
endpoint,
(0..=255).map(|value| {
DummyResult {
value,
}
}),
);
let query = Dummy::default();
let res: Vec<DummyResult> = api::paged(query, Pagination::All)
.query_async(&client)
.await
.unwrap();
assert_eq!(res.len(), 256);
for (i, value) in res.iter().enumerate() {
assert_eq!(value.value, i as u8);
}
}
#[test] #[test]
fn test_keyset_pagination_limit() { fn test_keyset_pagination_limit() {
let endpoint = ExpectedUrl::builder() let endpoint = ExpectedUrl::builder()
......
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
// option. This file may not be copied, modified, or distributed // option. This file may not be copied, modified, or distributed
// except according to those terms. // except according to those terms.
use async_trait::async_trait;
use http::Uri; use http::Uri;
use url::Url; use url::Url;
use crate::api::{ApiError, Client}; use crate::api::{ApiError, AsyncClient, Client};
pub fn url_to_http_uri(url: Url) -> Uri { pub fn url_to_http_uri(url: Url) -> Uri {
url.as_str() url.as_str()
...@@ -23,3 +24,13 @@ where ...@@ -23,3 +24,13 @@ where
/// Perform the query against the client. /// Perform the query against the client.
fn query(&self, client: &C) -> Result<T, ApiError<C::Error>>; fn query(&self, client: &C) -> Result<T, ApiError<C::Error>>;
} }
/// A trait which represents an asynchronous query which may be made to a GitLab client.
#[async_trait]
pub trait AsyncQuery<T, C>