From 4a6f49a481629482550b8bf75bf93d187bd970eb Mon Sep 17 00:00:00 2001 From: Arne Bahlo Date: Thu, 24 Aug 2023 12:32:40 +0200 Subject: [PATCH 1/4] Implement blocking client This adds the sync HTTP client [ureq](https://lib.rs/ureq) as well as [maybe-async](https://lib.rs/maybe-async) for conditional compilation. You can now use a blocking client by setting the `blocking` feature flag. --- Cargo.toml | 22 ++- src/client.rs | 40 +++-- src/datasets/client.rs | 8 + src/error.rs | 29 +++- src/http.rs | 382 +++++++---------------------------------- src/http_async.rs | 305 ++++++++++++++++++++++++++++++++ src/http_blocking.rs | 181 +++++++++++++++++++ src/lib.rs | 9 + src/limits.rs | 70 ++++++++ src/users/client.rs | 2 + 10 files changed, 704 insertions(+), 344 deletions(-) create mode 100644 src/http_async.rs create mode 100644 src/http_blocking.rs diff --git a/Cargo.toml b/Cargo.toml index 50b66e3..27e55cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ include = ["src/**/*.rs", "README.md", "LICENSE-APACHE", "LICENSE-MIT"] resolver = "2" [dependencies] -reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "gzip", "blocking"] } serde = { version = "1", features = ["derive"] } serde_json = "1" chrono = { version = "0.4", features = ["serde"] } @@ -25,13 +24,21 @@ bytes = "1" flate2 = "1" http = "0.2" backoff = { version = "0.4", features = ["futures"] } -futures = "0.3" -tokio = { version = "1", optional = true, features = ["rt", "sync"] } -async-std = { version = "1", optional = true, features = ["tokio1"] } url = "2" tracing = { version = "0.1" } -tokio-stream = "0.1" bitflags = "2" +maybe-async = "0.2.7" + +# sync packages +ureq = { version = "2.7.1", optional = true, features = ["json"] } + +# async packages +futures = { version = "0.3", optional = true } +async-trait = { version = "0.1", optional = true } +tokio = { version = "1", optional = true, features = ["rt", "sync"] } +tokio-stream = { version = "0.1", optional = true } +async-std = { version = "1", optional = true, features = ["tokio1"] } +reqwest = { version = "0.11", optional = true, default-features = false, features = ["json", "stream", "gzip", "blocking"] } [dev-dependencies] tokio = { version = "1", features = ["full"] } @@ -46,8 +53,9 @@ tracing-subscriber = { version = "0.3", features = ["ansi", "env-filter"] } [features] default = ["tokio", "default-tls"] -tokio = ["backoff/tokio", "dep:tokio"] -async-std = ["backoff/async-std", "dep:async-std"] +tokio = ["backoff/tokio", "dep:tokio", "futures", "async-trait", "tokio-stream", "reqwest"] +async-std = ["backoff/async-std", "dep:async-std", "futures", "async-trait", "tokio-stream", "reqwest"] default-tls = ["reqwest/default-tls"] native-tls = ["reqwest/native-tls"] rustls-tls = ["reqwest/rustls-tls"] +blocking = ["ureq", "maybe-async/is_sync"] diff --git a/src/client.rs b/src/client.rs index adb41a8..8881292 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,7 +3,9 @@ use async_std::task::spawn_blocking; use bytes::Bytes; use flate2::{write::GzEncoder, Compression}; +#[cfg(not(feature = "blocking"))] use futures::Stream; +use maybe_async::{async_impl, maybe_async}; use reqwest::header; use serde::Serialize; use std::{ @@ -12,6 +14,7 @@ use std::{ }; #[cfg(feature = "tokio")] use tokio::task::spawn_blocking; +#[cfg(not(feature = "blocking"))] use tokio_stream::StreamExt; use tracing::instrument; @@ -78,12 +81,13 @@ impl Client { } /// Get client version. 
-    pub async fn version(&self) -> String {
+    pub fn version(&self) -> String {
         env!("CARGO_PKG_VERSION").to_string()
     }
 
     /// Executes the given query specified using the Axiom Processing Language (APL).
     /// To learn more about APL, see the APL documentation at https://www.axiom.co/docs/apl/introduction.
+    #[maybe_async]
     #[instrument(skip(self, opts))]
     pub async fn query<S, O>(&self, apl: S, opts: O) -> Result<QueryResult>
     where
@@ -122,20 +126,17 @@ impl Client {
         let res = self.http_client.post(path, &req).await?;
 
         let saved_query_id = res
-            .headers()
-            .get("X-Axiom-History-Query-Id")
-            .map(|s| s.to_str())
-            .transpose()
-            .map_err(|_e| Error::InvalidQueryId)?
+            .get_header("X-Axiom-History-Query-Id")
             .map(|s| s.to_string());
-        let mut result = res.json::<QueryResult>().await?;
+        let mut result: QueryResult = res.json::<QueryResult>().await?;
         result.saved_query_id = saved_query_id;
 
         Ok(result)
     }
 
     /// Execute the given query on the dataset identified by its id.
+    #[maybe_async]
     #[instrument(skip(self, opts))]
     #[deprecated(
         since = "0.6.0",
@@ -162,13 +163,9 @@ impl Client {
         let res = self.http_client.post(path, &query).await?;
 
         let saved_query_id = res
-            .headers()
-            .get("X-Axiom-History-Query-Id")
-            .map(|s| s.to_str())
-            .transpose()
-            .map_err(|_e| Error::InvalidQueryId)?
+            .get_header("X-Axiom-History-Query-Id")
             .map(|s| s.to_string());
-        let mut result = res.json::<LegacyQueryResult>().await?;
+        let mut result: LegacyQueryResult = res.json::<LegacyQueryResult>().await?;
         result.saved_query_id = saved_query_id;
 
         Ok(result)
@@ -177,6 +174,7 @@ impl Client {
     /// Ingest events into the dataset identified by its id.
     /// Restrictions for field names (JSON object keys) can be reviewed here:
     /// <https://www.axiom.co/docs/usage/field-restrictions>.
+    #[maybe_async]
     #[instrument(skip(self, events))]
     pub async fn ingest<N, I, E>(&self, dataset_name: N, events: I) -> Result<IngestStatus>
     where
@@ -189,13 +187,24 @@ impl Client {
             .map(|event| serde_json::to_vec(&event).map_err(Error::Serialize))
             .collect();
         let json_payload = json_lines?.join(&b"\n"[..]);
+
+        #[cfg(not(feature = "blocking"))]
         let payload = spawn_blocking(move || {
             let mut gzip_payload = GzEncoder::new(Vec::new(), Compression::default());
             gzip_payload.write_all(&json_payload)?;
             gzip_payload.finish()
         })
         .await;
-        #[cfg(feature = "tokio")]
+        #[cfg(feature = "blocking")]
+        let payload = {
+            let mut gzip_payload = GzEncoder::new(Vec::new(), Compression::default());
+            gzip_payload
+                .write_all(&json_payload)
+                .map_err(Error::Encoding)?;
+            gzip_payload.finish()
+        };
+
+        #[cfg(all(feature = "tokio", not(feature = "blocking")))]
         let payload = payload.map_err(Error::JoinError)?;
         let payload = payload.map_err(Error::Encoding)?;
 
@@ -211,6 +220,7 @@ impl Client {
     /// Ingest data into the dataset identified by its id.
     /// Restrictions for field names (JSON object keys) can be reviewed here:
     /// <https://www.axiom.co/docs/usage/field-restrictions>.
+    #[maybe_async]
     #[instrument(skip(self, payload))]
     pub async fn ingest_bytes(
         &self,
@@ -243,6 +253,7 @@ impl Client {
     /// with a backoff.
     /// Restrictions for field names (JSON object keys) can be reviewed here:
     /// <https://www.axiom.co/docs/usage/field-restrictions>.
+    #[async_impl]
     #[instrument(skip(self, stream))]
     pub async fn ingest_stream<N, S, I>(&self, dataset_name: N, stream: S) -> Result<IngestStatus>
     where
@@ -261,6 +272,7 @@ impl Client {
     }
 
     /// Like [`Client::ingest_stream`], but takes a stream that contains results.
+    #[async_impl]
     #[instrument(skip(self, stream))]
     pub async fn try_ingest_stream(
         &self,
diff --git a/src/datasets/client.rs b/src/datasets/client.rs
index b7f9abe..4289193 100644
--- a/src/datasets/client.rs
+++ b/src/datasets/client.rs
@@ -1,3 +1,4 @@
+use maybe_async::maybe_async;
 use std::{
     convert::{TryFrom, TryInto},
     fmt::Debug as FmtDebug,
@@ -27,6 +28,7 @@ impl Client {
     }
 
     /// Create a dataset with the given name and description.
+    #[maybe_async]
     #[instrument(skip(self))]
     pub async fn create<N, D>(&self, dataset_name: N, description: D) -> Result<Dataset>
     where
@@ -45,6 +47,7 @@ impl Client {
     }
 
     /// Delete the dataset with the given ID.
+    #[maybe_async]
     #[instrument(skip(self))]
     pub async fn delete<N>(&self, dataset_name: N) -> Result<()>
     where
@@ -56,6 +59,7 @@ impl Client {
     }
 
     /// Get a dataset by its id.
+    #[maybe_async]
     #[instrument(skip(self))]
     pub async fn get<N>(&self, dataset_name: N) -> Result<Dataset>
     where
@@ -69,6 +73,7 @@ impl Client {
     }
 
     /// Retrieve the information of the dataset identified by its id.
+    #[maybe_async]
     #[instrument(skip(self))]
     #[deprecated(
         since = "0.8.0",
@@ -86,6 +91,7 @@ impl Client {
     }
 
     /// List all available datasets.
+    #[maybe_async]
     #[instrument(skip(self))]
     pub async fn list(&self) -> Result<Vec<Dataset>> {
         self.http_client.get("/v1/datasets").await?.json().await
@@ -96,6 +102,7 @@ impl Client {
     /// Older ones will be deleted from the dataset.
     /// The duration can either be a [`std::time::Duration`] or a
     /// [`chrono::Duration`].
+    #[maybe_async]
     #[instrument(skip(self))]
     #[allow(deprecated)]
     pub async fn trim<N, D>(&self, dataset_name: N, duration: D) -> Result<TrimResult>
     where
@@ -113,6 +120,7 @@ impl Client {
     }
 
     /// Update a dataset.
+    #[maybe_async]
     #[instrument(skip(self))]
     pub async fn update<N, D>(&self, dataset_name: N, new_description: D) -> Result<Dataset>
     where
diff --git a/src/error.rs b/src/error.rs
index 8f7e3a2..79602c3 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,7 +1,7 @@
 //! Error type definitions.
 
 use serde::Deserialize;
-use std::fmt;
+use std::{fmt, io};
 use thiserror::Error;
 
 use crate::limits::Limits;
@@ -21,12 +21,21 @@ pub enum Error {
     InvalidToken,
     #[error("Invalid Org ID (make sure there are no invalid characters)")]
     InvalidOrgId,
+    #[cfg(not(feature = "blocking"))]
     #[error("Failed to setup HTTP client: {0}")]
     HttpClientSetup(reqwest::Error),
+    #[cfg(not(feature = "blocking"))]
     #[error("Failed to deserialize response: {0}")]
     Deserialize(reqwest::Error),
+    #[cfg(feature = "blocking")]
+    #[error("Failed to deserialize response: {0}")]
+    Deserialize(io::Error),
+    #[cfg(not(feature = "blocking"))]
     #[error("Http error: {0}")]
     Http(reqwest::Error),
+    #[cfg(feature = "blocking")]
+    #[error("Http error: {0}")]
+    Http(ureq::Error),
     #[error(transparent)]
     Axiom(AxiomError),
     #[error("Query ID contains invisible characters (this is a server error)")]
@@ -36,10 +45,10 @@ pub enum Error {
     #[error(transparent)]
     Serialize(#[from] serde_json::Error),
     #[error("Failed to encode payload: {0}")]
-    Encoding(std::io::Error),
+    Encoding(io::Error),
     #[error("Duration is out of range (can't be larger than i64::MAX milliseconds)")]
     DurationOutOfRange,
-    #[cfg(feature = "tokio")]
+    #[cfg(all(feature = "tokio", not(feature = "blocking")))]
     #[error("Failed to join thread: {0}")]
     JoinError(tokio::task::JoinError),
     #[error("Rate limit exceeded for the {scope} scope: {limits}")]
@@ -60,6 +69,7 @@ pub enum Error {
 
 /// This is the manual implementation. We don't really care if the error is
 /// permanent or transient at this stage so we just return Error::Http.
+#[cfg(not(feature = "blocking"))]
 impl From<backoff::Error<reqwest::Error>> for Error {
     fn from(err: backoff::Error<reqwest::Error>) -> Self {
         match err {
             backoff::Error::Permanent(err) => Error::Http(err),
             backoff::Error::Transient {
                 err,
                 retry_after: _,
             } => Error::Http(err),
         }
     }
 }
 
+#[cfg(feature = "blocking")]
+impl From<backoff::Error<ureq::Error>> for Error {
+    fn from(err: backoff::Error<ureq::Error>) -> Self {
+        match err {
+            backoff::Error::Permanent(err) => Error::Http(err),
+            backoff::Error::Transient {
+                err,
+                retry_after: _,
+            } => Error::Http(err),
+        }
+    }
+}
+
 /// An error returned by the Axiom API.
 #[derive(Deserialize, Debug)]
 pub struct AxiomError {
diff --git a/src/http.rs b/src/http.rs
index 6404732..7dc1b7e 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -1,25 +1,18 @@
-use backoff::{future::retry, ExponentialBackoffBuilder};
+use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
 use bytes::Bytes;
-use http::header;
 pub(crate) use http::HeaderMap;
-use serde::{de::DeserializeOwned, Serialize};
-use std::{env, time::Duration};
-use url::Url;
+use maybe_async::maybe_async;
+use serde::Serialize;
+use std::time::Duration;
 
-use crate::{
-    error::{AxiomError, Error, Result},
-    limits::Limit,
-};
+use crate::error::{Error, Result};
+#[cfg(not(feature = "blocking"))]
+use crate::http_async::{Client as ClientImpl, Response as ResponseImpl};
+#[cfg(feature = "blocking")]
+use crate::http_blocking::{Client as ClientImpl, Response as ResponseImpl};
 
-static USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
-
-/// Client is a wrapper around reqwest::Client which provides automatically
-/// prepending the base url.
-#[derive(Debug, Clone)]
-pub(crate) struct Client {
-    base_url: Url,
-    inner: reqwest::Client,
-}
+pub(crate) static USER_AGENT: &str =
+    concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
 
 #[derive(Clone)]
 pub(crate) enum Body {
@@ -28,354 +21,103 @@
     Bytes(Bytes),
 }
 
+pub(crate) fn build_backoff() -> ExponentialBackoff {
+    ExponentialBackoffBuilder::new()
+        .with_initial_interval(Duration::from_millis(500)) // first retry after 500ms
+        .with_multiplier(2.0) // all following retries are twice as long as the previous one
+        .with_max_elapsed_time(Some(Duration::from_secs(30))) // try up to 30s
+        .build()
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct Client {
+    inner: ClientImpl,
+}
+
 impl Client {
     /// Creates a new client.
     pub(crate) fn new<U, T, O>(base_url: U, token: T, org_id: O) -> Result<Self>
     where
         U: AsRef<str>,
         T: Into<String>,
         O: Into<Option<String>>,
     {
-        let base_url = Url::parse(base_url.as_ref()).map_err(Error::InvalidUrl)?;
-        let token = token.into();
-
-        let mut default_headers = header::HeaderMap::new();
-        let token_header_value = header::HeaderValue::from_str(&format!("Bearer {}", token))
-            .map_err(|_e| Error::InvalidToken)?;
-        default_headers.insert(header::AUTHORIZATION, token_header_value);
-        if let Some(org_id) = org_id.into() {
-            let org_id_header_value =
-                header::HeaderValue::from_str(&org_id).map_err(|_e| Error::InvalidOrgId)?;
-            default_headers.insert("X-Axiom-Org-Id", org_id_header_value);
-        }
-
-        let http_client = reqwest::Client::builder()
-            .user_agent(USER_AGENT)
-            .default_headers(default_headers)
-            .timeout(Duration::from_secs(10))
-            .build()
-            .map_err(Error::HttpClientSetup)?;
-
         Ok(Self {
-            base_url,
-            inner: http_client,
+            inner: ClientImpl::new(base_url, token, org_id)?,
         })
     }
 
-    async fn execute<P, H>(
-        &self,
-        method: http::Method,
-        path: P,
-        body: Body,
-        headers: H,
-    ) -> Result<Response>
-    where
-        P: AsRef<str>,
-        H: Into<Option<HeaderMap>>,
-    {
-        let url = self
-            .base_url
-            .join(path.as_ref().trim_start_matches('/'))
-            .map_err(Error::InvalidUrl)?;
-
-        let headers = headers.into();
-        let backoff = ExponentialBackoffBuilder::new()
-            .with_initial_interval(Duration::from_millis(500)) // first retry after 500ms
-            .with_multiplier(2.0) // all following retries are twice as long as the previous one
-            .with_max_elapsed_time(Some(Duration::from_secs(30))) // try up to 30s
-            .build();
-
-        let res = retry(backoff, || async {
-            let mut req = self.inner.request(method.clone(), url.clone());
-            if let Some(headers) = headers.clone() {
-                req = req.headers(headers);
-            }
-            match body.clone() {
-                Body::Empty => {}
-                Body::Json(value) => req = req.json(&value),
-                Body::Bytes(bytes) => req = req.body(bytes),
-            }
-            self.inner.execute(req.build()?).await.map_err(|e| {
-                if let Some(status) = e.status() {
-                    if status.is_client_error() {
-                        // Don't retry 4XX
-                        return backoff::Error::permanent(e);
-                    }
-                }
-
-                backoff::Error::transient(e)
-            })
-        })
-        .await
-        .map(|res| Response::new(res, method, path.as_ref().to_string()))
-        .map_err(Error::Http)?;
-
-        Ok(res)
-    }
-
-    pub(crate) async fn get<S>(&self, path: S) -> Result<Response>
+    #[maybe_async]
+    pub(crate) async fn get<S>(&self, path: S) -> Result<ResponseImpl>
     where
         S: AsRef<str>,
     {
-        self.execute(http::Method::GET, path.as_ref(), Body::Empty, None)
+        self.inner
+            .execute(http::Method::GET, path.as_ref(), Body::Empty, None)
             .await
     }
 
-    pub(crate) async fn post<S, P>(&self, path: S, payload: P) -> Result<Response>
+    #[maybe_async]
+    pub(crate) async fn post<S, P>(&self, path: S, payload: P) -> Result<ResponseImpl>
     where
         S: AsRef<str>,
         P: Serialize,
     {
-        self.execute(
-            http::Method::POST,
-            path,
-            Body::Json(serde_json::to_value(payload).map_err(Error::Serialize)?),
-            None,
-        )
-        .await
+        self.inner
+            .execute(
+                http::Method::POST,
+                path,
+                Body::Json(serde_json::to_value(payload).map_err(Error::Serialize)?),
+                None,
+            )
+            .await
     }
 
+    #[maybe_async]
     pub(crate) async fn post_bytes<S, P, H>(
         &self,
         path: S,
         payload: P,
         headers: H,
-    ) -> Result<Response>
+    ) -> Result<ResponseImpl>
     where
         S: AsRef<str>,
         P: Into<Bytes>,
         H: Into<Option<HeaderMap>>,
    {
-        self.execute(
-            http::Method::POST,
-            path,
-            Body::Bytes(payload.into()),
-            headers,
-        )
-        .await
+        self.inner
+            .execute(
+                http::Method::POST,
+                path,
+                Body::Bytes(payload.into()),
+                headers,
+            )
+            .await
     }
 
-    pub(crate) async fn put<S, P>(&self, path: S, payload: P) -> Result<Response>
+    #[maybe_async]
+    pub(crate) async fn put<S, P>(&self, path: S, payload: P) -> Result<ResponseImpl>
     where
         S: AsRef<str>,
         P: Serialize,
     {
-        self.execute(
-            http::Method::PUT,
-            path,
-            Body::Json(serde_json::to_value(payload).map_err(Error::Serialize)?),
-            None,
-        )
-        .await
+        self.inner
+            .execute(
+                http::Method::PUT,
+                path,
+                Body::Json(serde_json::to_value(payload).map_err(Error::Serialize)?),
+                None,
+            )
+            .await
     }
 
+    #[maybe_async]
     pub(crate) async fn delete<S>(&self, path: S) -> Result<()>
     where
         S: AsRef<str>,
     {
-        self.execute(http::Method::DELETE, path, Body::Empty, None)
+        self.inner
+            .execute(http::Method::DELETE, path, Body::Empty, None)
             .await?;
 
         Ok(())
     }
 }
-
-pub(crate) struct Response {
-    inner: reqwest::Response,
-    method: http::Method,
-    path: String,
-    limits: Option<Limit>,
-}
-
-impl Response {
-    pub(crate) fn new(inner: reqwest::Response, method: http::Method, path: String) -> Self {
-        let limits = Limit::try_from(&inner);
-        Self {
-            inner,
-            method,
-            path,
-            limits,
-        }
-    }
-
-    pub(crate) async fn json<T: DeserializeOwned>(self) -> Result<T> {
-        self.check_error()
-            .await?
-            .inner
-            .json::<T>()
-            .await
-            .map_err(Error::Deserialize)
-    }
-
-    pub(crate) async fn check_error(self) -> Result<Self> {
-        let status = self.inner.status();
-        if !status.is_success() {
-            // Check if we hit some limits
-            match self.limits {
-                Some(Limit::Rate(scope, limits)) => {
-                    return Err(Error::RateLimitExceeded { scope, limits });
-                }
-                Some(Limit::Query(limit)) => {
-                    return Err(Error::QueryLimitExceeded(limit));
-                }
-                Some(Limit::Ingest(limit)) => {
-                    return Err(Error::IngestLimitExceeded(limit));
-                }
-                None => {}
-            }
-
-            // Try to decode the error
-            let e = match self.inner.json::<AxiomError>().await {
-                Ok(mut e) => {
-                    e.status = status.as_u16();
-                    e.method = self.method;
-                    e.path = self.path;
-                    Error::Axiom(e)
-                }
-                Err(_e) => {
-                    // Decoding failed, we still want an AxiomError
-                    Error::Axiom(AxiomError::new(
-                        status.as_u16(),
-                        self.method,
-                        self.path,
-                        None,
-                    ))
-                }
-            };
-            return Err(e);
-        }
-
-        Ok(self)
-    }
-
-    pub(crate) fn headers(&self) -> &header::HeaderMap {
-        self.inner.headers()
-    }
-}
-
-impl From<Response> for reqwest::Response {
-    fn from(res: Response) -> Self {
-        res.inner
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use chrono::{Duration, Utc};
-    use httpmock::prelude::*;
-    use serde_json::json;
-
-    use crate::{limits, Client, Error};
-
-    #[tokio::test]
-    async fn test_ingest_limit_exceeded() -> Result<(), Box<dyn std::error::Error>> {
-        let expires_after = Duration::seconds(1);
-        let tomorrow = Utc::now() + expires_after;
-
-        let server = MockServer::start();
-        let rate_mock = server.mock(|when, then| {
-            when.method(POST).path("/v1/datasets/test/ingest");
-            then.status(430)
-                .json_body(json!({ "message": "ingest limit exceeded" }))
-                .header(limits::HEADER_INGEST_LIMIT, "42")
-                .header(limits::HEADER_INGEST_REMAINING, "0")
-                .header(
-                    limits::HEADER_INGEST_RESET,
-                    format!("{}", tomorrow.timestamp()),
-                );
-        });
-
-        let client = Client::builder()
-            .no_env()
-            .with_url(server.base_url())
-            .with_token("xapt-nope")
-            .build()?;
-
-        match client.ingest("test", vec![json!({"foo": "bar"})]).await {
-            Err(Error::IngestLimitExceeded(limits)) => {
-                assert_eq!(limits.limit, 42);
-                assert_eq!(limits.remaining, 0);
-                assert_eq!(limits.reset.timestamp(), tomorrow.timestamp());
-            }
-            res => panic!("Expected ingest limit error, got {:?}", res),
-        };
-
-        rate_mock.assert_hits_async(1).await;
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_query_limit_exceeded() -> Result<(), Box<dyn std::error::Error>> {
-        let expires_after = Duration::seconds(1);
-        let tomorrow = Utc::now() + expires_after;
-
-        let server = MockServer::start();
-        let rate_mock = server.mock(|when, then| {
-            when.method(POST).path("/v1/datasets/_apl");
-            then.status(430)
-                .json_body(json!({ "message": "query limit exceeded" }))
-                .header(limits::HEADER_QUERY_LIMIT, "42")
-                .header(limits::HEADER_QUERY_REMAINING, "0")
-                .header(
-                    limits::HEADER_QUERY_RESET,
-                    format!("{}", tomorrow.timestamp()),
-                );
-        });
-
-        let client = Client::builder()
-            .no_env()
-            .with_url(server.base_url())
-            .with_token("xapt-nope")
-            .build()?;
-
-        match client.query("test | count", None).await {
-            Err(Error::QueryLimitExceeded(limits)) => {
-                assert_eq!(limits.limit, 42);
-                assert_eq!(limits.remaining, 0);
-                assert_eq!(limits.reset.timestamp(), tomorrow.timestamp());
-            }
-            res => panic!("Expected ingest limit error, got {:?}", res),
-        };
-
-        rate_mock.assert_hits_async(1).await;
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_rate_limit_exceeded() -> Result<(), Box<dyn std::error::Error>> {
-        let expires_after = Duration::seconds(1);
-        let tomorrow = Utc::now() + expires_after;
-
-        let server = MockServer::start();
-        let rate_mock = server.mock(|when, then| {
-            when.method(GET).path("/v1/datasets");
-            then.status(429)
-                .json_body(json!({ "message": "rate limit exceeded" }))
-                .header(limits::HEADER_RATE_SCOPE, "user")
-                .header(limits::HEADER_RATE_LIMIT, "42")
-                .header(limits::HEADER_RATE_REMAINING, "0")
-                .header(
-                    limits::HEADER_RATE_RESET,
-                    format!("{}", tomorrow.timestamp()),
-                );
-        });
-
-        let client = Client::builder()
-            .no_env()
-            .with_url(server.base_url())
-            .with_token("xapt-nope")
-            .build()?;
-
-        match client.datasets.list().await {
-            Err(Error::RateLimitExceeded { scope, limits }) => {
-                assert_eq!(scope, "user");
-                assert_eq!(limits.limit, 42);
-                assert_eq!(limits.remaining, 0);
-                assert_eq!(limits.reset.timestamp(), tomorrow.timestamp());
-            }
-            res => panic!("Expected ingest limit error, got {:?}", res),
-        };
-
-        rate_mock.assert_hits_async(1).await;
-        Ok(())
-    }
-}
diff --git a/src/http_async.rs b/src/http_async.rs
new file mode 100644
index 0000000..8158401
--- /dev/null
+++ b/src/http_async.rs
@@ -0,0 +1,305 @@
+use backoff::future::retry;
+use http::header;
+use http::HeaderMap;
+use serde::de::DeserializeOwned;
+use std::time::Duration;
+use url::Url;
+
+use crate::{
+    error::{AxiomError, Error, Result},
+    http::{build_backoff, Body, USER_AGENT},
+    limits::Limit,
+};
+
+/// Client is a wrapper around reqwest::Client which provides automatically
+/// prepending the base url.
+#[derive(Debug, Clone)]
+pub(crate) struct Client {
+    base_url: Url,
+    inner: reqwest::Client,
+}
+
+impl Client {
+    /// Creates a new client.
+    pub(crate) fn new<U, T, O>(base_url: U, token: T, org_id: O) -> Result<Self>
+    where
+        U: AsRef<str>,
+        T: Into<String>,
+        O: Into<Option<String>>,
+    {
+        let base_url = Url::parse(base_url.as_ref()).map_err(Error::InvalidUrl)?;
+        let token = token.into();
+
+        let mut default_headers = header::HeaderMap::new();
+        let token_header_value = header::HeaderValue::from_str(&format!("Bearer {}", token))
+            .map_err(|_e| Error::InvalidToken)?;
+        default_headers.insert(header::AUTHORIZATION, token_header_value);
+        if let Some(org_id) = org_id.into() {
+            let org_id_header_value =
+                header::HeaderValue::from_str(&org_id).map_err(|_e| Error::InvalidOrgId)?;
+            default_headers.insert("X-Axiom-Org-Id", org_id_header_value);
+        }
+
+        let http_client = reqwest::Client::builder()
+            .user_agent(USER_AGENT)
+            .default_headers(default_headers)
+            .timeout(Duration::from_secs(10))
+            .build()
+            .map_err(Error::HttpClientSetup)?;
+
+        Ok(Self {
+            base_url,
+            inner: http_client,
+        })
+    }
+
+    pub(crate) async fn execute<P, H>(
+        &self,
+        method: http::Method,
+        path: P,
+        body: Body,
+        headers: H,
+    ) -> Result<Response>
+    where
+        P: AsRef<str>,
+        H: Into<Option<HeaderMap>>,
+    {
+        let url = self
+            .base_url
+            .join(path.as_ref().trim_start_matches('/'))
+            .map_err(Error::InvalidUrl)?;
+
+        let headers = headers.into();
+
+        let res = retry(build_backoff(), || async {
+            let mut req = self.inner.request(method.clone(), url.clone());
+            if let Some(headers) = headers.clone() {
+                req = req.headers(headers);
+            }
+            match body.clone() {
+                Body::Empty => {}
+                Body::Json(value) => req = req.json(&value),
+                Body::Bytes(bytes) => req = req.body(bytes),
+            }
+            self.inner.execute(req.build()?).await.map_err(|e| {
+                if let Some(status) = e.status() {
+                    if status.is_client_error() {
+                        // Don't retry 4XX
+                        return backoff::Error::permanent(e);
+                    }
+                }
+
+                backoff::Error::transient(e)
+            })
+        })
+        .await
+        .map(|res| Response::new(res, method, path.as_ref().to_string()))
+        .map_err(Error::Http)?;
+
+        Ok(res)
+    }
+}
+
+pub(crate) struct Response {
+    inner: reqwest::Response,
+    method: http::Method,
+    path: String,
+    limits: Option<Limit>,
+}
+
+impl Response {
+    pub(crate) fn new(inner: reqwest::Response, method: http::Method, path: String) -> Self {
+        let limits = Limit::try_from(&inner);
+        Self {
+            inner,
+            method,
+            path,
+            limits,
+        }
+    }
+
+    pub(crate) async fn json<T: DeserializeOwned>(self) -> Result<T> {
+        self.check_error()
+            .await?
+            .inner
+            .json::<T>()
+            .await
+            .map_err(Error::Deserialize)
+    }
+
+    pub(crate) async fn check_error(self) -> Result<Self> {
+        let status = self.inner.status();
+        if !status.is_success() {
+            // Check if we hit some limits
+            match self.limits {
+                Some(Limit::Rate(scope, limits)) => {
+                    return Err(Error::RateLimitExceeded { scope, limits });
+                }
+                Some(Limit::Query(limit)) => {
+                    return Err(Error::QueryLimitExceeded(limit));
+                }
+                Some(Limit::Ingest(limit)) => {
+                    return Err(Error::IngestLimitExceeded(limit));
+                }
+                None => {}
+            }
+
+            // Try to decode the error
+            let e = match self.inner.json::<AxiomError>().await {
+                Ok(mut e) => {
+                    e.status = status.as_u16();
+                    e.method = self.method;
+                    e.path = self.path;
+                    Error::Axiom(e)
+                }
+                Err(_e) => {
+                    // Decoding failed, we still want an AxiomError
+                    Error::Axiom(AxiomError::new(
+                        status.as_u16(),
+                        self.method,
+                        self.path,
+                        None,
+                    ))
+                }
+            };
+            return Err(e);
+        }
+
+        Ok(self)
+    }
+
+    pub(crate) fn get_header(&self, name: impl AsRef<str>) -> Option<&str> {
+        self.inner
+            .headers()
+            .get(name.as_ref())
+            .and_then(|name| name.to_str().ok())
+    }
+}
+
+impl From<Response> for reqwest::Response {
+    fn from(res: Response) -> Self {
+        res.inner
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use chrono::{Duration, Utc};
+    use httpmock::prelude::*;
+    use serde_json::json;
+
+    use crate::{limits, Client, Error};
+
+    #[tokio::test]
+    async fn test_ingest_limit_exceeded() -> Result<(), Box<dyn std::error::Error>> {
+        let expires_after = Duration::seconds(1);
+        let tomorrow = Utc::now() + expires_after;
+
+        let server = MockServer::start();
+        let rate_mock = server.mock(|when, then| {
+            when.method(POST).path("/v1/datasets/test/ingest");
+            then.status(430)
+                .json_body(json!({ "message": "ingest limit exceeded" }))
+                .header(limits::HEADER_INGEST_LIMIT, "42")
+                .header(limits::HEADER_INGEST_REMAINING, "0")
+                .header(
+                    limits::HEADER_INGEST_RESET,
+                    format!("{}", tomorrow.timestamp()),
+                );
+        });
+
+        let client = Client::builder()
+            .no_env()
+            .with_url(server.base_url())
+            .with_token("xapt-nope")
+            .build()?;
+
+        match client.ingest("test", vec![json!({"foo": "bar"})]).await {
+            Err(Error::IngestLimitExceeded(limits)) => {
+                assert_eq!(limits.limit, 42);
+                assert_eq!(limits.remaining, 0);
+                assert_eq!(limits.reset.timestamp(), tomorrow.timestamp());
+            }
+            res => panic!("Expected ingest limit error, got {:?}", res),
+        };
+
+        rate_mock.assert_hits_async(1).await;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_query_limit_exceeded() -> Result<(), Box<dyn std::error::Error>> {
+        let expires_after = Duration::seconds(1);
+        let tomorrow = Utc::now() + expires_after;
+
+        let server = MockServer::start();
+        let rate_mock = server.mock(|when, then| {
+            when.method(POST).path("/v1/datasets/_apl");
+            then.status(430)
+                .json_body(json!({ "message": "query limit exceeded" }))
+                .header(limits::HEADER_QUERY_LIMIT, "42")
+                .header(limits::HEADER_QUERY_REMAINING, "0")
+                .header(
+                    limits::HEADER_QUERY_RESET,
+                    format!("{}", tomorrow.timestamp()),
+                );
+        });
+
+        let client = Client::builder()
+            .no_env()
+            .with_url(server.base_url())
+            .with_token("xapt-nope")
+            .build()?;
+
+        match client.query("test | count", None).await {
+            Err(Error::QueryLimitExceeded(limits)) => {
+                assert_eq!(limits.limit, 42);
+                assert_eq!(limits.remaining, 0);
+                assert_eq!(limits.reset.timestamp(), tomorrow.timestamp());
+            }
+            res => panic!("Expected query limit error, got {:?}", res),
+        };
+
+        rate_mock.assert_hits_async(1).await;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_rate_limit_exceeded() -> Result<(), Box<dyn std::error::Error>> {
+        let expires_after = Duration::seconds(1);
+        let tomorrow = Utc::now() + expires_after;
+
+        let server = MockServer::start();
+        let rate_mock = server.mock(|when, then| {
+            when.method(GET).path("/v1/datasets");
+            then.status(429)
+                .json_body(json!({ "message": "rate limit exceeded" }))
+                .header(limits::HEADER_RATE_SCOPE, "user")
+                .header(limits::HEADER_RATE_LIMIT, "42")
+                .header(limits::HEADER_RATE_REMAINING, "0")
+                .header(
+                    limits::HEADER_RATE_RESET,
+                    format!("{}", tomorrow.timestamp()),
+                );
+        });
+
+        let client = Client::builder()
+            .no_env()
+            .with_url(server.base_url())
+            .with_token("xapt-nope")
+            .build()?;
+
+        match client.datasets.list().await {
+            Err(Error::RateLimitExceeded { scope, limits }) => {
+                assert_eq!(scope, "user");
+                assert_eq!(limits.limit, 42);
+                assert_eq!(limits.remaining, 0);
+                assert_eq!(limits.reset.timestamp(), tomorrow.timestamp());
+            }
+            res => panic!("Expected rate limit error, got {:?}", res),
+        };
+
+        rate_mock.assert_hits_async(1).await;
+        Ok(())
+    }
+}
diff --git a/src/http_blocking.rs b/src/http_blocking.rs
new file mode 100644
index 0000000..708e8ab
--- /dev/null
+++ b/src/http_blocking.rs
@@ -0,0 +1,181 @@
+use backoff::retry;
+use http::HeaderMap;
+use serde::de::DeserializeOwned;
+use std::time::Duration;
+use ureq::{Agent, Middleware, MiddlewareNext, Request};
+use url::Url;
+
+use crate::{
+    error::{AxiomError, Error},
+    http::{build_backoff, Body, USER_AGENT},
+    limits::Limit,
+};
+
+#[derive(Debug, Clone)]
+pub(crate) struct Client {
+    agent: Agent,
+    base_url: Url,
+}
+
+impl Client {
+    pub(crate) fn new(
+        base_url: impl AsRef<str>,
+        token: impl Into<String>,
+        org_id: impl Into<Option<String>>,
+    ) -> Result<Self, Error> {
+        let base_url = Url::parse(base_url.as_ref()).map_err(Error::InvalidUrl)?;
+        Ok(Self {
+            agent: ureq::AgentBuilder::new()
+                .user_agent(USER_AGENT)
+                .middleware(TokenMiddleware::new(token, org_id))
+                .timeout(Duration::from_secs(10))
+                .build(),
+            base_url,
+        })
+    }
+
+    pub(crate) fn execute<P, H>(
+        &self,
+        method: http::Method,
+        path: P,
+        body: Body,
+        headers: H,
+    ) -> Result<Response, Error>
+    where
+        P: AsRef<str>,
+        H: Into<Option<HeaderMap>>,
+    {
+        let url = self
+            .base_url
+            .join(path.as_ref())
+            .map_err(Error::InvalidUrl)?;
+
+        let mut req = ureq::request_url(method.as_str(), &url);
+        if let Some(headers) = headers.into() {
+            for (key, value) in headers {
+                if let Some(name) = key {
+                    if let Some(value) = value.to_str().ok() {
+                        req = req.set(name.as_str(), value);
+                    }
+                }
+            }
+        }
+
+        let res = retry(build_backoff(), || {
+            match &body {
+                Body::Empty => req.clone().call(),
+                Body::Json(json) => req.clone().send_json(json),
+                Body::Bytes(bytes) => req.clone().send_bytes(&bytes),
+            }
+            .map_err(|e| match e {
+                ureq::Error::Status(status, _) => {
+                    if status >= 400 && status < 500 {
+                        // Don't retry 4XX
+                        backoff::Error::permanent(e)
+                    } else {
+                        backoff::Error::transient(e)
+                    }
+                }
+                ureq::Error::Transport(_) => backoff::Error::transient(e),
+            })
+        })?;
+
+        Ok(Response::new(res, method, path.as_ref().to_string()))
+    }
+}
+
+struct TokenMiddleware {
+    token: String,
+    org_id: Option<String>,
+}
+
+impl TokenMiddleware {
+    fn new(token: impl Into<String>, org_id: impl Into<Option<String>>) -> Self {
+        Self {
+            token: token.into(),
+            org_id: org_id.into(),
+        }
+    }
+}
+
+impl Middleware for TokenMiddleware {
+    fn handle(
+        &self,
+        request: Request,
+        next: MiddlewareNext,
+    ) -> Result<ureq::Response, ureq::Error> {
+        let req = request.set("Authorization", &format!("Bearer {}", self.token));
+        let req = if let Some(org_id) = &self.org_id {
+            req.set("X-Axiom-Org-Id", org_id)
+        } else {
+            req
+        };
+        next.handle(req)
+    }
+}
+
+pub(crate) struct Response {
+    inner: ureq::Response,
+    method: http::Method,
+    path: String,
+    limits: Option<Limit>,
+}
+
+impl Response {
+    pub(crate) fn new(inner: ureq::Response, method: http::Method, path: String) -> Self {
+        let limits = Limit::try_from(&inner);
+        Self {
+            inner,
+            method,
+            path,
+            limits,
+        }
+    }
+
+    pub(crate) fn json<T: DeserializeOwned>(self) -> Result<T, Error> {
+        self.check_error()?
+            .inner
+            .into_json::<T>()
+            .map_err(Error::Deserialize)
+    }
+
+    pub(crate) fn check_error(self) -> Result<Self, Error> {
+        let status = self.inner.status();
+        if status < 200 || status > 299 {
+            // Check if we hit some limits
+            match self.limits {
+                Some(Limit::Rate(scope, limits)) => {
+                    return Err(Error::RateLimitExceeded { scope, limits });
+                }
+                Some(Limit::Query(limit)) => {
+                    return Err(Error::QueryLimitExceeded(limit));
+                }
+                Some(Limit::Ingest(limit)) => {
+                    return Err(Error::IngestLimitExceeded(limit));
+                }
+                None => {}
+            }
+
+            // Try to decode the error
+            let e = match self.inner.into_json::<AxiomError>() {
+                Ok(mut e) => {
+                    e.status = status;
+                    e.method = self.method;
+                    e.path = self.path;
+                    Error::Axiom(e)
+                }
+                Err(_e) => {
+                    // Decoding failed, we still want an AxiomError
+                    Error::Axiom(AxiomError::new(status, self.method, self.path, None))
+                }
+            };
+            return Err(e);
+        }
+
+        Ok(self)
+    }
+
+    pub(crate) fn get_header(&self, name: impl AsRef<str>) -> Option<&str> {
+        self.inner.header(name.as_ref())
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 83b724c..7b2afc1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -33,6 +33,10 @@
 pub mod client;
 pub mod error;
 mod http;
+#[cfg(not(feature = "blocking"))]
+mod http_async;
+#[cfg(feature = "blocking")]
+mod http_blocking;
 pub mod limits;
 mod serde;
@@ -58,6 +62,11 @@ compile_error!("Feature \"native-tls\" and \"rustls-tls\" cannot be enabled at t
 #[cfg(all(feature = "rustls-tls", feature = "default-tls"))]
 compile_error!("Feature \"rustls-tls\" and \"default-tls\" cannot be enabled at the same time");
 
+#[cfg(all(feature = "blocking", any(feature = "tokio", feature = "async-std")))]
+compile_error!(
+    "Feature \"blocking\" cannot be enabled at the same time as \"tokio\" or \"async-std\""
+);
+
 /// Returns true if the given acces token is a personal token.
 fn is_personal_token<S: Into<String>>(token: S) -> bool {
     token.into().starts_with("xapt-")
diff --git a/src/limits.rs b/src/limits.rs
index 6c46261..2d1c5fb 100644
--- a/src/limits.rs
+++ b/src/limits.rs
@@ -36,6 +36,7 @@ pub(crate) enum Limit {
 }
 
 impl Limit {
+    #[cfg(not(feature = "blocking"))]
     pub(crate) fn try_from(response: &reqwest::Response) -> Option<Self> {
         match response.status().as_u16() {
             429 => {
@@ -81,6 +82,50 @@ impl Limit {
             _ => None,
         }
     }
+
+    #[cfg(feature = "blocking")]
+    pub(crate) fn try_from(response: &ureq::Response) -> Option<Self> {
+        match response.status() {
+            429 => {
+                // Rate limit
+                let scope = response.header(HEADER_RATE_SCOPE);
+                let limits = Limits::from_headers(
+                    response,
+                    HEADER_RATE_LIMIT,
+                    HEADER_RATE_REMAINING,
+                    HEADER_RATE_RESET,
+                )
+                .ok();
+
+                scope
+                    .zip(limits)
+                    .map(|(scope, limits)| Limit::Rate(scope.to_string(), limits))
+            }
+            430 => {
+                // Query or ingest limit
+                let query_limit = Limits::from_headers(
+                    response,
+                    HEADER_QUERY_LIMIT,
+                    HEADER_QUERY_REMAINING,
+                    HEADER_QUERY_RESET,
+                )
+                .map(Limit::Query)
+                .ok();
+                let ingest_limit = Limits::from_headers(
+                    response,
+                    HEADER_INGEST_LIMIT,
+                    HEADER_INGEST_REMAINING,
+                    HEADER_INGEST_RESET,
+                )
+                .map(Limit::Ingest)
+                .ok();
+
+                // Can't have both
+                query_limit.or(ingest_limit)
+            }
+            _ => None,
+        }
+    }
 }
 
 /// Rate-limit information.
@@ -110,6 +155,7 @@ impl Limits {
         self.remaining == 0 && self.reset > Utc::now()
     }
 
+    #[cfg(not(feature = "blocking"))]
     fn from_headers(
         headers: &header::HeaderMap,
         header_limit: &str,
@@ -135,4 +181,28 @@ impl Limits {
             .ok_or(InvalidHeaderError::Reset)?,
         })
     }
+
+    #[cfg(feature = "blocking")]
+    fn from_headers(
+        response: &ureq::Response,
+        header_limit: &str,
+        header_remaining: &str,
+        header_reset: &str,
+    ) -> Result<Self, InvalidHeaderError> {
+        Ok(Limits {
+            limit: response
+                .header(header_limit)
+                .and_then(|limit| limit.parse::<u64>().ok())
+                .ok_or(InvalidHeaderError::Limit)?,
+            remaining: response
+                .header(header_remaining)
+                .and_then(|limit| limit.parse::<u64>().ok())
+                .ok_or(InvalidHeaderError::Remaining)?,
+            reset: response
+                .header(header_reset)
+                .and_then(|limit| limit.parse::<i64>().ok())
+                .and_then(|limit| Utc.timestamp_opt(limit, 0).single())
+                .ok_or(InvalidHeaderError::Reset)?,
+        })
+    }
 }
diff --git a/src/users/client.rs b/src/users/client.rs
index 190784a..589f07a 100644
--- a/src/users/client.rs
+++ b/src/users/client.rs
@@ -1,4 +1,5 @@
 use crate::{error::Result, http, users::model::*};
+use maybe_async::maybe_async;
 use tracing::instrument;
 
 /// Provides methods to work with Axiom datasets.
@@ -13,6 +14,7 @@ impl Client {
     }
 
     /// Retrieve the authenticated user.
+    #[maybe_async]
     #[instrument(skip(self))]
     pub async fn current(&self) -> Result<User> {
         self.http_client.get("/v1/user").await?.json().await

From a9b49670cce9c6ae6feb18c60509fd9dc2353561 Mon Sep 17 00:00:00 2001
From: Arne Bahlo
Date: Thu, 24 Aug 2023 12:52:32 +0200
Subject: [PATCH 2/4] Document required features for async ingest methods

---
 Cargo.toml    | 4 ++++
 src/client.rs | 2 ++
 src/lib.rs    | 1 +
 3 files changed, 7 insertions(+)

diff --git a/Cargo.toml b/Cargo.toml
index 27e55cd..acb0970 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -59,3 +59,7 @@ default-tls = ["reqwest/default-tls"]
 native-tls = ["reqwest/native-tls"]
 rustls-tls = ["reqwest/rustls-tls"]
 blocking = ["ureq", "maybe-async/is_sync"]
+
+[package.metadata.docs.rs]
+all-features = true
+rustdoc-args = ["--cfg", "docsrs"]
diff --git a/src/client.rs b/src/client.rs
index 8881292..e85fcfd 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -254,6 +254,7 @@ impl Client {
     /// Restrictions for field names (JSON object keys) can be reviewed here:
     /// <https://www.axiom.co/docs/usage/field-restrictions>.
     #[async_impl]
+    #[cfg_attr(docsrs, doc(cfg(any(feature = "tokio", feature = "async-std"))))]
     #[instrument(skip(self, stream))]
     pub async fn ingest_stream<N, S, I>(&self, dataset_name: N, stream: S) -> Result<IngestStatus>
     where
@@ -273,6 +274,7 @@ impl Client {
 
     /// Like [`Client::ingest_stream`], but takes a stream that contains results.
     #[async_impl]
+    #[cfg_attr(docsrs, doc(cfg(any(feature = "tokio", feature = "async-std"))))]
     #[instrument(skip(self, stream))]
     pub async fn try_ingest_stream(
         &self,
diff --git a/src/lib.rs b/src/lib.rs
index 7b2afc1..35fc06e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -30,6 +30,7 @@
 //!     Ok(())
 //! }
 //! ```
+#![cfg_attr(docsrs, feature(doc_cfg))]
 pub mod client;
 pub mod error;
 mod http;

From 89f03ba8a067a5556c2b8730a564c0256c810944 Mon Sep 17 00:00:00 2001
From: Arne Bahlo
Date: Thu, 24 Aug 2023 14:52:13 +0200
Subject: [PATCH 3/4] Fix tests to work in blocking contexts

---
 .github/workflows/ci.yaml |   4 +-
 Cargo.toml                |   7 +-
 src/client.rs             |  19 ++--
 src/http_blocking.rs      |   2 +-
 src/lib.rs                |  13 ++-
 src/limits.rs             |   1 +
 tests/cursor.rs           |  92 ++++++------------
 tests/datasets.rs         | 196 ++++++++++++++++++--------------------
 8 files changed, 146 insertions(+), 188 deletions(-)

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 7a7f4e4..e2e9634 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -61,11 +61,13 @@ jobs:
     strategy:
       max-parallel: 1
       matrix:
-        runtime: [Tokio, async-std]
+        runtime: [Tokio, async-std, blocking]
         environment: [development, staging]
         include:
           - runtime: async-std
             flags: --no-default-features --features async-std,default-tls
+          - runtime: blocking
+            flags: --no-default-features --features blocking,default-tls --tests
           - environment: development
             url: TESTING_DEV_API_URL
             token: TESTING_DEV_TOKEN
diff --git a/Cargo.toml b/Cargo.toml
index acb0970..5d36256 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -44,7 +44,6 @@ reqwest = { version = "0.11", optional = true, default-features = false, featur
 tokio = { version = "1", features = ["full"] }
 async-std = { version = "1", features = ["attributes"] }
 serde_test = "1"
-test-context = "0.1"
 async-trait = "0.1"
 futures-util = "0.3"
 httpmock = "0.6"
@@ -55,9 +54,9 @@ tracing-subscriber = { version = "0.3", features = ["ansi", "env-filter"] }
 default = ["tokio", "default-tls"]
 tokio = ["backoff/tokio", "dep:tokio", "futures", "async-trait", "tokio-stream", "reqwest"]
 async-std = ["backoff/async-std", "dep:async-std", "futures", "async-trait", "tokio-stream", "reqwest"]
-default-tls = ["reqwest/default-tls"]
-native-tls = ["reqwest/native-tls"]
-rustls-tls = ["reqwest/rustls-tls"]
+default-tls = ["reqwest/default-tls", "ureq/tls"]
+native-tls = ["reqwest/native-tls", "ureq/native-tls"]
+rustls-tls = ["reqwest/rustls-tls", "ureq/rustls"]
 blocking = ["ureq", "maybe-async/is_sync"]
 
 [package.metadata.docs.rs]
diff --git a/src/client.rs b/src/client.rs
index e85fcfd..e72a6cd 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -5,13 +5,12 @@ use bytes::Bytes;
 use flate2::{write::GzEncoder, Compression};
 #[cfg(not(feature = "blocking"))]
 use futures::Stream;
+use http::header;
 use maybe_async::{async_impl, maybe_async};
-use reqwest::header;
 use serde::Serialize;
-use std::{
-    env, fmt::Debug as FmtDebug, io::Write, result::Result as StdResult,
-    time::Duration as StdDuration,
-};
+#[cfg(not(feature = "blocking"))]
+use std::time::Duration as StdDuration;
+use std::{env, fmt::Debug as FmtDebug, io::Write};
 #[cfg(feature = "tokio")]
 use tokio::task::spawn_blocking;
 #[cfg(not(feature = "blocking"))]
@@ -24,7 +23,7 @@ use crate::{
         LegacyQueryResult, Query, QueryOptions, QueryParams, QueryResult,
     },
     error::{Error, Result},
-    http::{self, HeaderMap},
+    http::{Client as HttpClient, HeaderMap},
     is_personal_token, users,
 };
 
@@ -56,7 +55,7 @@ static API_URL: &str = "https://api.axiom.co";
 /// ```
 #[derive(Debug, Clone)]
 pub struct Client {
-    http_client: http::Client,
+    http_client: HttpClient,
     url: String,
 
     pub datasets: datasets::Client,
@@ -283,7 +282,7 @@ impl Client {
     ) -> Result<IngestStatus>
     where
         N: Into<String> + FmtDebug,
-        S: Stream<Item = StdResult<I, E>> + Send + Sync + 'static,
+        S: Stream<Item = std::result::Result<I, E>> + Send + Sync + 'static,
         I: Serialize,
         E: std::error::Error + Send + Sync + 'static,
     {
@@ -291,7 +290,7 @@ impl Client {
         let mut chunks = Box::pin(stream.chunks_timeout(1000, StdDuration::from_secs(1)));
         let mut ingest_status = IngestStatus::default();
         while let Some(events) = chunks.next().await {
-            let events: StdResult<Vec<I>, E> = events.into_iter().collect();
+            let events: std::result::Result<Vec<I>, E> = events.into_iter().collect();
             match events {
                 Ok(events) => {
                     let new_ingest_status = self.ingest(dataset_name.clone(), events).await?;
@@ -381,7 +380,7 @@ impl Builder {
             return Err(Error::MissingOrgId);
         }
 
-        let http_client = http::Client::new(url.clone(), token, org_id)?;
+        let http_client = HttpClient::new(url.clone(), token, org_id)?;
 
         Ok(Client {
             http_client: http_client.clone(),
diff --git a/src/http_blocking.rs b/src/http_blocking.rs
index 708e8ab..bf6ee00 100644
--- a/src/http_blocking.rs
+++ b/src/http_blocking.rs
@@ -50,7 +50,7 @@ impl Client {
             .join(path.as_ref())
             .map_err(Error::InvalidUrl)?;
 
-        let mut req = ureq::request_url(method.as_str(), &url);
+        let mut req = self.agent.request_url(method.as_str(), &url);
         if let Some(headers) = headers.into() {
             for (key, value) in headers {
                 if let Some(name) = key {
diff --git a/src/lib.rs b/src/lib.rs
index 35fc06e..6f9b8ec 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -54,6 +54,14 @@ pub struct ReadmeDoctests;
 #[cfg(all(feature = "tokio", feature = "async-std"))]
 compile_error!("Feature \"tokio\" and \"async-std\" cannot be enabled at the same time");
 
+#[cfg(all(feature = "blocking", any(feature = "tokio", feature = "async-std")))]
+compile_error!(
+    "Feature \"blocking\" cannot be enabled at the same time as \"tokio\" or \"async-std\""
+);
+
+#[cfg(not(any(feature = "blocking", feature = "tokio", feature = "async-std")))]
+compile_error!("Needs at least one of \"blocking\", \"tokio\" or \"async-std\" features");
+
 #[cfg(all(feature = "default-tls", feature = "native-tls"))]
 compile_error!("Feature \"default-tls\" and \"native-tls\" cannot be enabled at the same time");
 
@@ -63,11 +71,6 @@ compile_error!("Feature \"native-tls\" and \"rustls-tls\" cannot be enabled at t
 #[cfg(all(feature = "rustls-tls", feature = "default-tls"))]
 compile_error!("Feature \"rustls-tls\" and \"default-tls\" cannot be enabled at the same time");
 
-#[cfg(all(feature = "blocking", any(feature = "tokio", feature = "async-std")))]
-compile_error!(
-    "Feature \"blocking\" cannot be enabled at the same time as \"tokio\" or \"async-std\""
-);
-
 /// Returns true if the given acces token is a personal token.
 fn is_personal_token<S: Into<String>>(token: S) -> bool {
     token.into().starts_with("xapt-")
diff --git a/src/limits.rs b/src/limits.rs
index 2d1c5fb..51b0f96 100644
--- a/src/limits.rs
+++ b/src/limits.rs
@@ -1,6 +1,7 @@
 //! Rate-limit type definitions.
 use chrono::{DateTime, TimeZone, Utc};
+#[cfg(not(feature = "blocking"))]
 use http::header;
 use std::fmt::Display;
 use thiserror::Error;
diff --git a/tests/cursor.rs b/tests/cursor.rs
index 3d494de..e4c3ada 100644
--- a/tests/cursor.rs
+++ b/tests/cursor.rs
@@ -1,79 +1,47 @@
-use async_trait::async_trait;
 use axiom_rs::{datasets::*, Client};
 use chrono::{Duration, Utc};
+use maybe_async::maybe_async;
 use serde_json::json;
 use std::env;
-use test_context::{test_context, AsyncTestContext};
 
-struct Context {
-    client: Client,
-    dataset: Dataset,
-}
-
-#[async_trait]
-impl AsyncTestContext for Context {
-    async fn setup() -> Context {
-        let client = Client::new().unwrap();
-
-        let dataset_name = format!(
-            "test-axiom-rs-{}",
-            env::var("AXIOM_DATASET_SUFFIX").expect("AXIOM_DATASET_SUFFIX is not set"),
-        );
-
-        // Delete dataset in case we have a zombie
-        client.datasets.delete(&dataset_name).await.ok();
-
-        let dataset = client.datasets.create(&dataset_name, "bar").await.unwrap();
-        assert_eq!(dataset_name.clone(), dataset.name);
-        assert_eq!("bar".to_string(), dataset.description);
-
-        Context { client, dataset }
-    }
+#[maybe_async]
+#[cfg_attr(feature = "blocking", test)]
+#[cfg_attr(feature = "tokio", tokio::test)]
+#[cfg_attr(feature = "async-std", async_std::test)]
+async fn test_cursor() {
+    let client = Client::new().unwrap();
 
-    async fn teardown(self) {
-        self.client.datasets.delete(self.dataset.name).await.ok();
-    }
-}
+    let dataset_name = format!(
+        "test-axiom-rs-{}",
+        env::var("AXIOM_DATASET_SUFFIX").expect("AXIOM_DATASET_SUFFIX is not set"),
+    );
 
-#[cfg(feature = "tokio")]
-#[test_context(Context)]
-#[tokio::test]
-async fn test_cursor(ctx: &mut Context) {
-    test_cursor_impl(ctx).await;
-}
+    // Delete dataset in case we have a zombie
+    client.datasets.delete(&dataset_name).await.ok();
 
-#[cfg(feature = "async-std")]
-#[test_context(Context)]
-#[async_std::test]
-async fn test_cursor(ctx: &mut Context) {
-    test_cursor_impl(ctx).await;
-}
+    let dataset = client.datasets.create(&dataset_name, "bar").await.unwrap();
+    assert_eq!(dataset_name.clone(), dataset.name);
+    assert_eq!("bar".to_string(), dataset.description);
 
-async fn test_cursor_impl(ctx: &mut Context) {
     // Let's update the dataset.
-    let dataset = ctx
-        .client
+    let dataset = client
         .datasets
-        .update(
-            &ctx.dataset.name,
-            "This is a soon to be filled test dataset",
-        )
+        .update(&dataset.name, "This is a soon to be filled test dataset")
         .await
         .unwrap();
-    ctx.dataset = dataset;
 
     // Get the dataset and make sure it matches what we have updated it to.
-    let dataset = ctx.client.datasets.get(&ctx.dataset.name).await.unwrap();
-    assert_eq!(ctx.dataset.name, dataset.name);
-    assert_eq!(ctx.dataset.name, dataset.name);
-    assert_eq!(ctx.dataset.description, dataset.description);
+    let fetched_dataset = client.datasets.get(&dataset.name).await.unwrap();
+    assert_eq!(dataset.name, fetched_dataset.name);
+    assert_eq!(dataset.description, fetched_dataset.description);
 
     // List all datasets and make sure the created dataset is part of that
     // list.
-    let datasets = ctx.client.datasets.list().await.unwrap();
+    let datasets = client.datasets.list().await.unwrap();
     datasets
         .iter()
-        .find(|dataset| dataset.name == ctx.dataset.name)
+        .find(|d| d.name == dataset.name)
         .expect("Expected dataset to be in the list");
 
     let mut events = Vec::new();
@@ -93,7 +61,7 @@
         }));
     }
 
-    let ingest_status = ctx.client.ingest(&ctx.dataset.name, &events).await.unwrap();
+    let ingest_status = client.ingest(&dataset.name, &events).await.unwrap();
     assert_eq!(ingest_status.ingested, 1000);
     assert_eq!(ingest_status.failed, 0);
     assert_eq!(ingest_status.failures.len(), 0);
@@ -101,10 +69,9 @@
     let start_time = Utc::now() - Duration::minutes(1);
     let end_time = Utc::now() + Duration::minutes(1);
 
-    let apl_query_result = ctx
-        .client
+    let apl_query_result = client
         .query(
-            format!("['{}'] | sort by _time desc", ctx.dataset.name),
+            format!("['{}'] | sort by _time desc", dataset.name),
             QueryOptions {
                 start_time: Some(start_time),
                 end_time: Some(end_time),
@@ -119,10 +86,9 @@
 
     let mid_row_id = &apl_query_result.matches[500].row_id;
 
-    let apl_query_result = ctx
-        .client
+    let apl_query_result = client
         .query(
-            format!("['{}'] | sort by _time desc", ctx.dataset.name),
+            format!("['{}'] | sort by _time desc", dataset.name),
             QueryOptions {
                 start_time: Some(start_time),
                 end_time: Some(end_time),
diff --git a/tests/datasets.rs b/tests/datasets.rs
index ec20577..0087fdb 100644
--- a/tests/datasets.rs
+++ b/tests/datasets.rs
@@ -1,80 +1,53 @@
-use async_trait::async_trait;
 use axiom_rs::{datasets::*, Client};
 use chrono::{Duration, Utc};
+#[cfg(not(feature = "blocking"))]
 use futures::StreamExt;
+use maybe_async::maybe_async;
 use serde_json::json;
+#[cfg(feature = "blocking")]
+use std::thread::sleep;
 use std::{env, time::Duration as StdDuration};
-use test_context::{test_context, AsyncTestContext};
+#[cfg(not(feature = "blocking"))]
+use tokio::time::sleep;
 
-struct Context {
-    client: Client,
-    dataset: Dataset,
-}
-
-#[async_trait]
-impl AsyncTestContext for Context {
-    async fn setup() -> Context {
-        let client = Client::new().unwrap();
-
-        let dataset_name = format!(
-            "test-axiom-rs-{}",
-            env::var("AXIOM_DATASET_SUFFIX").expect("AXIOM_DATASET_SUFFIX is not set"),
-        );
-
-        // Delete dataset in case we have a zombie
-        client.datasets.delete(&dataset_name).await.ok();
-
-        let dataset = client.datasets.create(&dataset_name, "bar").await.unwrap();
-        assert_eq!(dataset_name.clone(), dataset.name);
-        assert_eq!("bar".to_string(), dataset.description);
-
-        Context { client, dataset }
-    }
+#[maybe_async]
+#[cfg_attr(feature = "blocking", test)]
+#[cfg_attr(feature = "tokio", tokio::test)]
+#[cfg_attr(feature = "async-std", async_std::test)]
+async fn test_datasets_impl() {
+    let client = Client::new().unwrap();
 
-    async fn teardown(self) {
-        self.client.datasets.delete(self.dataset.name).await.ok();
-    }
-}
+    let dataset_name = format!(
+        "test-axiom-rs-{}",
+        env::var("AXIOM_DATASET_SUFFIX").expect("AXIOM_DATASET_SUFFIX is not set"),
+    );
 
-#[cfg(feature = "tokio")]
-#[test_context(Context)]
-#[tokio::test]
-async fn test_datasets(ctx: &mut Context) {
-    test_datasets_impl(ctx).await;
-}
+    // Delete dataset in case we have a zombie
+    client.datasets.delete(&dataset_name).await.ok();
 
-#[cfg(feature = "async-std")]
-#[test_context(Context)]
-#[async_std::test]
-async fn test_datasets(ctx: &mut Context) {
-    test_datasets_impl(ctx).await;
-}
+    let dataset = client.datasets.create(&dataset_name, "bar").await.unwrap();
+    assert_eq!(dataset_name.clone(), dataset.name);
+    assert_eq!("bar".to_string(), dataset.description);
 
-async fn test_datasets_impl(ctx: &mut Context) {
     // Let's update the dataset.
-    let dataset = ctx
-        .client
+    let dataset = client
         .datasets
-        .update(
-            &ctx.dataset.name,
-            "This is a soon to be filled test dataset",
-        )
+        .update(&dataset.name, "This is a soon to be filled test dataset")
         .await
         .unwrap();
-    ctx.dataset = dataset;
 
     // Get the dataset and make sure it matches what we have updated it to.
-    let dataset = ctx.client.datasets.get(&ctx.dataset.name).await.unwrap();
-    assert_eq!(ctx.dataset.name, dataset.name);
-    assert_eq!(ctx.dataset.name, dataset.name);
-    assert_eq!(ctx.dataset.description, dataset.description);
+    let fetched_dataset = client.datasets.get(&dataset.name).await.unwrap();
+    assert_eq!(dataset.name, fetched_dataset.name);
+    assert_eq!(dataset.description, fetched_dataset.description);
 
     // List all datasets and make sure the created dataset is part of that
     // list.
-    let datasets = ctx.client.datasets.list().await.unwrap();
+    let datasets = client.datasets.list().await.unwrap();
     datasets
         .iter()
-        .find(|dataset| dataset.name == ctx.dataset.name)
+        .find(|d| d.name == dataset.name)
         .expect("Expected dataset to be in the list");
 
     // Let's ingest some data
@@ -100,10 +73,9 @@
         "agent": "Debian APT-HTTP/1.3 (0.8.16~exp12ubuntu10.21)"
     }
 ]"#;
-    let ingest_status = ctx
-        .client
+    let ingest_status = client
         .ingest_bytes(
-            &ctx.dataset.name,
+            &dataset.name,
             PAYLOAD,
             ContentType::Json,
             ContentEncoding::Identity,
@@ -138,48 +110,48 @@
             "agent": "Debian APT-HTTP/1.3 (0.8.16~exp12ubuntu10.21)"
         }),
     ];
-    let ingest_status = ctx.client.ingest(&ctx.dataset.name, &events).await.unwrap();
+    let ingest_status = client.ingest(&dataset.name, &events).await.unwrap();
     assert_eq!(ingest_status.ingested, 2);
     assert_eq!(ingest_status.failed, 0);
     assert_eq!(ingest_status.failures.len(), 0);
 
-    // ... a small stream
-    let stream = futures_util::stream::iter(events.clone());
-    let ingest_status = ctx
-        .client
-        .ingest_stream(&ctx.dataset.name, stream)
-        .await
-        .unwrap();
-    assert_eq!(ingest_status.ingested, 2);
-    assert_eq!(ingest_status.failed, 0);
-    assert_eq!(ingest_status.failures.len(), 0);
+    #[cfg(not(feature = "blocking"))]
+    {
+        // ... a small stream
+        let stream = futures_util::stream::iter(events.clone());
+        let ingest_status = client.ingest_stream(&dataset.name, stream).await.unwrap();
+        assert_eq!(ingest_status.ingested, 2);
+        assert_eq!(ingest_status.failed, 0);
+        assert_eq!(ingest_status.failures.len(), 0);
 
-    // ... and a big stream (4321 items)
-    let stream = futures_util::stream::iter(events).cycle().take(4321);
-    let ingest_status = ctx
-        .client
-        .ingest_stream(&ctx.dataset.name, stream)
-        .await
-        .unwrap();
-    assert_eq!(ingest_status.ingested, 4321);
-    assert_eq!(ingest_status.failed, 0);
-    assert_eq!(ingest_status.failures.len(), 0);
+        // ... and a big stream (4321 items)
+        let stream = futures_util::stream::iter(events).cycle().take(4321);
+        let ingest_status = client.ingest_stream(&dataset.name, stream).await.unwrap();
+        assert_eq!(ingest_status.ingested, 4321);
+        assert_eq!(ingest_status.failed, 0);
+        assert_eq!(ingest_status.failures.len(), 0);
+    }
 
     // Give the db some time to write the data.
- tokio::time::sleep(StdDuration::from_secs(15)).await; + sleep(StdDuration::from_secs(15)).await; + + let expected_event_count = if cfg!(feature = "blocking") { + 4 + } else { + 4327 // From async stream tests + }; // Get the dataset info and make sure four events have been ingested. - let info = ctx.client.datasets.info(&ctx.dataset.name).await.unwrap(); - assert_eq!(ctx.dataset.name, info.stat.name); - assert_eq!(4327, info.stat.num_events); + let info = client.datasets.info(&dataset.name).await.unwrap(); + assert_eq!(dataset.name, info.stat.name); + assert_eq!(expected_event_count, info.stat.num_events); assert!(info.fields.len() > 0); // Run a query and make sure we see some results. #[allow(deprecated)] - let simple_query_result = ctx - .client + let simple_query_result = client .query_legacy( - &ctx.dataset.name, + &dataset.name, LegacyQuery { start_time: Some(Utc::now() - Duration::minutes(1)), end_time: Some(Utc::now()), @@ -194,15 +166,24 @@ async fn test_datasets_impl(ctx: &mut Context) { .unwrap(); assert!(simple_query_result.saved_query_id.is_some()); // assert_eq!(1, simple_query_result.status.blocks_examined); - assert_eq!(4327, simple_query_result.status.rows_examined); - assert_eq!(4327, simple_query_result.status.rows_matched); - assert_eq!(1000, simple_query_result.matches.len()); + assert_eq!( + expected_event_count, + simple_query_result.status.rows_examined + ); + assert_eq!( + expected_event_count, + simple_query_result.status.rows_matched + ); + if cfg!(feature = "blocking") { + assert_eq!(4, simple_query_result.matches.len()); + } else { + assert_eq!(1000, simple_query_result.matches.len()); + } // Run another query but using APL. - let apl_query_result = ctx - .client + let apl_query_result = client .query( - format!("['{}']", ctx.dataset.name), + format!("['{}']", dataset.name), QueryOptions { save: true, ..Default::default() @@ -212,9 +193,13 @@ async fn test_datasets_impl(ctx: &mut Context) { .unwrap(); assert!(apl_query_result.saved_query_id.is_some()); // assert_eq!(1, apl_query_result.status.blocks_examined); - assert_eq!(4327, apl_query_result.status.rows_examined); - assert_eq!(4327, apl_query_result.status.rows_matched); - assert_eq!(1000, apl_query_result.matches.len()); + assert_eq!(expected_event_count, apl_query_result.status.rows_examined); + assert_eq!(expected_event_count, apl_query_result.status.rows_matched); + if cfg!(feature = "blocking") { + assert_eq!(4, apl_query_result.matches.len()); + } else { + assert_eq!(1000, apl_query_result.matches.len()); + } // Run a more complex query. 
     let query = LegacyQuery {
@@ -254,10 +239,9 @@
         ..Default::default()
     };
     #[allow(deprecated)]
-    let query_result = ctx
-        .client
+    let query_result = client
         .query_legacy(
-            &ctx.dataset.name,
+            &dataset.name,
             query,
             LegacyQueryOptions {
                 save_as_kind: QueryKind::Analytics,
@@ -266,8 +250,8 @@
         )
         .await
         .unwrap();
-    assert_eq!(4327, query_result.status.rows_examined);
-    assert_eq!(4327, query_result.status.rows_matched);
+    assert_eq!(expected_event_count, query_result.status.rows_examined);
+    assert_eq!(expected_event_count, query_result.status.rows_matched);
     assert!(query_result.buckets.totals.len() == 2);
     let agg = query_result
         .buckets
         .totals
         .get(0)
         .unwrap()
         .aggregations
         .get(0)
         .unwrap();
     assert_eq!("event_count", agg.alias);
-    assert_eq!(2164, agg.value);
+    if cfg!(feature = "blocking") {
+        assert_eq!(2, agg.value);
+    } else {
+        assert_eq!(2164, agg.value);
+    }
 
     // Trim the dataset down to a minimum.
-    ctx.client
+    client
         .datasets
-        .trim(&ctx.dataset.name, Duration::seconds(1))
+        .trim(&dataset.name, Duration::seconds(1))
         .await
         .unwrap();
 }

From 9d1d71824c9306c62d2403c8a07425e87d884d18 Mon Sep 17 00:00:00 2001
From: Arne Bahlo
Date: Thu, 24 Aug 2023 15:01:16 +0200
Subject: [PATCH 4/4] Update README.md to add blocking feature

---
 README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b4939c0..77dfd6a 100644
--- a/README.md
+++ b/README.md
@@ -87,7 +87,8 @@ that can be enabled or disabled:
 - **native-tls**: Enables TLS functionality provided by `native-tls`.
 - **rustls-tls**: Enables TLS functionality provided by `rustls`.
 - **tokio** _(enabled by default)_: Enables the usage with the `tokio` runtime.
-- **async-std** : Enables the usage with the `async-std` runtime.
+- **async-std**: Enables the usage with the `async-std` runtime.
+- **blocking**: Provides a blocking (synchronous) client based on `ureq`; cannot be combined with `tokio` or `async-std`.
 
 ## License
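
For reference, here is a minimal sketch of how the blocking client reads once the `blocking` feature from patch 1 is enabled. The dataset name and event payload are placeholders, and the snippet assumes `serde_json` is a dependency of the calling crate; with `maybe-async/is_sync` the client methods compile to plain synchronous functions, so there is no `.await`:

```rust
// Cargo.toml of the consuming crate (version is a placeholder):
// axiom-rs = { version = "*", default-features = false, features = ["blocking", "default-tls"] }
use axiom_rs::Client;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Reads AXIOM_TOKEN (and AXIOM_ORG_ID, if applicable) from the environment.
    let client = Client::new()?;

    // With `blocking` enabled these are ordinary synchronous calls.
    client.ingest("my-dataset", vec![serde_json::json!({ "foo": "bar" })])?;
    let result = client.query(r#"['my-dataset'] | where foo == "bar" | limit 100"#, None)?;
    println!("{:?}", result);

    Ok(())
}
```

Note that `ingest_stream` and `try_ingest_stream` stay async-only (`#[async_impl]`), which is why patch 2 marks them on docs.rs as requiring the `tokio` or `async-std` feature.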