From dbbd4732dc652578151edd87b3d9d2b5681a89d9 Mon Sep 17 00:00:00 2001 From: Cesar <174354243+cesar-startale@users.noreply.github.com> Date: Mon, 15 Jul 2024 18:23:46 -0400 Subject: [PATCH 1/2] Http Upstream client This pull request introduces the HTTP Upstream client. The main goal is to be backward compatible and introduce an HTTP upstream client. If the config has any HTTP/s server to connect to, the HttpClient struct will handle them. Any WebSocket will use the existing upstream client code. If one or more HTTP clients are configured, they route all requests, and Websocket clients will be used for subscriptions. If no HTTP clients are configured, the fallback behavior is used, and all requests and subscriptions are routed through the WebSocket upstream client. If no WebSocket upstream clients are configured, then subscriptions are not enabled, only method requests. --- src/extensions/client/http.rs | 61 +++++++++++++++++ src/extensions/client/mod.rs | 124 ++++++++++++++++++++++------------ 2 files changed, 140 insertions(+), 45 deletions(-) create mode 100644 src/extensions/client/http.rs diff --git a/src/extensions/client/http.rs b/src/extensions/client/http.rs new file mode 100644 index 0000000..f8497b2 --- /dev/null +++ b/src/extensions/client/http.rs @@ -0,0 +1,61 @@ +use crate::middlewares::CallResult; +use jsonrpsee::{ + core::{ + client::{ClientT, Error}, + JsonValue, + }, + http_client::HttpClient as RpcClient, + types::{error::INTERNAL_ERROR_CODE, ErrorObject}, +}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Very simple struct to have a set of JsonRpsee HTTP clients and send requests to them +pub struct HttpClient { + clients: Vec, + last_sent: AtomicUsize, +} + +impl HttpClient { + pub fn new(endpoints: Vec) -> Result<(Option, Vec), Error> { + let mut other_urls = vec![]; + let clients = endpoints + .into_iter() + .filter_map(|url| { + let t_url = url.to_lowercase(); + if t_url.starts_with("http://") || t_url.starts_with("https://") { + Some(RpcClient::builder().build(url)) + } else { + other_urls.push(url); + None + } + }) + .collect::, _>>()?; + + if clients.is_empty() { + Ok((None, other_urls)) + } else { + Ok(( + Some(Self { + clients, + last_sent: AtomicUsize::new(0), + }), + other_urls, + )) + } + } + + /// Sends a request to one of the clients + /// + /// The client is selected in a round-robin fashion as fair as possible + pub async fn request(&self, method: &str, params: Vec) -> CallResult { + let client_id = self.last_sent.fetch_add(1, Ordering::Relaxed) % self.clients.len(); + + self.clients[client_id] + .request(method, params) + .await + .map_err(|e| match e { + jsonrpsee::core::client::Error::Call(e) => e, + e => ErrorObject::owned(INTERNAL_ERROR_CODE, e.to_string(), None::), + }) + } +} diff --git a/src/extensions/client/mod.rs b/src/extensions/client/mod.rs index 2e03ea7..063214c 100644 --- a/src/extensions/client/mod.rs +++ b/src/extensions/client/mod.rs @@ -29,6 +29,7 @@ use crate::{ utils::{self, errors}, }; +mod http; #[cfg(test)] pub mod mock; #[cfg(test)] @@ -38,15 +39,18 @@ const TRACER: utils::telemetry::Tracer = utils::telemetry::Tracer::new("client") pub struct Client { endpoints: Vec, - sender: tokio::sync::mpsc::Sender, - rotation_notify: Arc, + http_client: Option, + sender: Option>, + rotation_notify: Option>, retries: u32, - background_task: tokio::task::JoinHandle<()>, + background_task: Option>, } impl Drop for Client { fn drop(&mut self) { - self.background_task.abort(); + if let Some(background_task) = self.background_task.take() { + background_task.abort(); + } } } @@ -152,12 +156,26 @@ impl Client { retries: Option, ) -> Result { let endpoints: Vec<_> = endpoints.into_iter().map(|e| e.as_ref().to_string()).collect(); + let endpoints_ = endpoints.clone(); if endpoints.is_empty() { return Err(anyhow!("No endpoints provided")); } - tracing::debug!("New client with endpoints: {:?}", endpoints); + let (http_client, ws_endpoints) = http::HttpClient::new(endpoints)?; + + if ws_endpoints.is_empty() { + return Ok(Self { + http_client, + endpoints: endpoints_, + sender: None, // No websocket + rotation_notify: None, + retries: retries.unwrap_or(3), + background_task: None, + }); + } + + tracing::debug!("New client with endpoints: {:?}", ws_endpoints); let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::(100); @@ -165,7 +183,6 @@ impl Client { let rotation_notify = Arc::new(Notify::new()); let rotation_notify_bg = rotation_notify.clone(); - let endpoints_ = endpoints.clone(); let background_task = tokio::spawn(async move { let connect_backoff_counter = Arc::new(AtomicU32::new(0)); @@ -177,7 +194,7 @@ impl Client { let build_ws = || async { let build = || { let current_endpoint = current_endpoint.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let url = &endpoints[current_endpoint % endpoints.len()]; + let url = &ws_endpoints[current_endpoint % ws_endpoints.len()]; tracing::info!("Connecting to endpoint: {}", url); @@ -414,11 +431,12 @@ impl Client { } Ok(Self { + http_client, endpoints: endpoints_, - sender: message_tx, - rotation_notify, + sender: Some(message_tx), + rotation_notify: Some(rotation_notify), retries: retries.unwrap_or(3), - background_task, + background_task: Some(background_task), }) } @@ -431,22 +449,30 @@ impl Client { } pub async fn request(&self, method: &str, params: Vec) -> CallResult { - async move { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.sender - .send(Message::Request { - method: method.into(), - params, - response: tx, - retries: self.retries, - }) - .await - .map_err(errors::internal_error)?; + if let Some(http_client) = &self.http_client { + return http_client.request(method, params).await; + } - rx.await.map_err(errors::internal_error)?.map_err(errors::map_error) + if let Some(sender) = self.sender.as_ref() { + async move { + let (tx, rx) = tokio::sync::oneshot::channel(); + sender + .send(Message::Request { + method: method.into(), + params, + response: tx, + retries: self.retries, + }) + .await + .map_err(errors::internal_error)?; + + rx.await.map_err(errors::internal_error)?.map_err(errors::map_error) + } + .with_context(TRACER.context(method.to_string())) + .await + } else { + Err(errors::internal_error("No sender")) } - .with_context(TRACER.context(method.to_string())) - .await } pub async fn subscribe( @@ -455,35 +481,43 @@ impl Client { params: Vec, unsubscribe: &str, ) -> Result, Error> { - async move { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.sender - .send(Message::Subscribe { - subscribe: subscribe.into(), - params, - unsubscribe: unsubscribe.into(), - response: tx, - retries: self.retries, - }) - .await - .map_err(errors::internal_error)?; - - rx.await.map_err(errors::internal_error)? + if let Some(sender) = self.sender.as_ref() { + async move { + let (tx, rx) = tokio::sync::oneshot::channel(); + sender + .send(Message::Subscribe { + subscribe: subscribe.into(), + params, + unsubscribe: unsubscribe.into(), + response: tx, + retries: self.retries, + }) + .await + .map_err(errors::internal_error)?; + + rx.await.map_err(errors::internal_error)? + } + .with_context(TRACER.context(subscribe.to_string())) + .await + } else { + Err(Error::Call(errors::internal_error("No websocket connection"))) } - .with_context(TRACER.context(subscribe.to_string())) - .await } pub async fn rotate_endpoint(&self) { - self.sender - .send(Message::RotateEndpoint) - .await - .expect("Failed to rotate endpoint"); + if let Some(sender) = self.sender.as_ref() { + sender + .send(Message::RotateEndpoint) + .await + .expect("Failed to rotate endpoint"); + } } /// Returns a future that resolves when the endpoint is rotated. pub async fn on_rotation(&self) { - self.rotation_notify.notified().await + if let Some(rotation_notify) = self.rotation_notify.as_ref() { + rotation_notify.notified().await + } } } From bd8c52a3c25e93e0fe46fdb13a33e219136fff89 Mon Sep 17 00:00:00 2001 From: Cesar <174354243+cesar-startale@users.noreply.github.com> Date: Fri, 19 Jul 2024 11:50:27 -0400 Subject: [PATCH 2/2] Move the original `mod.rs` to `ws.rs` Introduce the meta Client that will instantiate the original Ws client or the new Http upstream client --- src/extensions/client/http.rs | 30 +- src/extensions/client/mod.rs | 517 +++++----------------------------- src/extensions/client/ws.rs | 515 +++++++++++++++++++++++++++++++++ 3 files changed, 592 insertions(+), 470 deletions(-) create mode 100644 src/extensions/client/ws.rs diff --git a/src/extensions/client/http.rs b/src/extensions/client/http.rs index f8497b2..582b7d8 100644 --- a/src/extensions/client/http.rs +++ b/src/extensions/client/http.rs @@ -16,32 +16,16 @@ pub struct HttpClient { } impl HttpClient { - pub fn new(endpoints: Vec) -> Result<(Option, Vec), Error> { - let mut other_urls = vec![]; + pub fn new(endpoints: &[String]) -> Result { let clients = endpoints - .into_iter() - .filter_map(|url| { - let t_url = url.to_lowercase(); - if t_url.starts_with("http://") || t_url.starts_with("https://") { - Some(RpcClient::builder().build(url)) - } else { - other_urls.push(url); - None - } - }) + .iter() + .map(|url| RpcClient::builder().build(url)) .collect::, _>>()?; - if clients.is_empty() { - Ok((None, other_urls)) - } else { - Ok(( - Some(Self { - clients, - last_sent: AtomicUsize::new(0), - }), - other_urls, - )) - } + Ok(Self { + clients, + last_sent: AtomicUsize::new(0), + }) } /// Sends a request to one of the clients diff --git a/src/extensions/client/mod.rs b/src/extensions/client/mod.rs index 063214c..a696a3b 100644 --- a/src/extensions/client/mod.rs +++ b/src/extensions/client/mod.rs @@ -1,57 +1,30 @@ -use std::{ - sync::{ - atomic::{AtomicU32, AtomicUsize}, - Arc, - }, - time::Duration, -}; +use std::time::Duration; use anyhow::anyhow; use async_trait::async_trait; -use futures::TryFutureExt; use garde::Validate; -use jsonrpsee::{ - core::{ - client::{ClientT, Error, Subscription, SubscriptionClientT}, - JsonValue, - }, - ws_client::{WsClient, WsClientBuilder}, +use jsonrpsee::core::{ + client::{Error, Subscription}, + JsonValue, }; -use opentelemetry::trace::FutureExt; use rand::{seq::SliceRandom, thread_rng}; use serde::Deserialize; -use tokio::sync::Notify; use super::ExtensionRegistry; -use crate::{ - extensions::Extension, - middlewares::CallResult, - utils::{self, errors}, -}; +use crate::{extensions::Extension, middlewares::CallResult, utils::errors}; mod http; #[cfg(test)] pub mod mock; #[cfg(test)] mod tests; - -const TRACER: utils::telemetry::Tracer = utils::telemetry::Tracer::new("client"); +#[allow(dead_code)] +mod ws; pub struct Client { endpoints: Vec, http_client: Option, - sender: Option>, - rotation_notify: Option>, - retries: u32, - background_task: Option>, -} - -impl Drop for Client { - fn drop(&mut self) { - if let Some(background_task) = self.background_task.take() { - background_task.abort(); - } - } + ws_client: Option, } #[derive(Deserialize, Validate, Debug)] @@ -59,7 +32,7 @@ impl Drop for Client { pub struct ClientConfig { #[garde(inner(custom(validate_endpoint)))] pub endpoints: Vec, - #[serde(default = "bool_true")] + #[serde(default = "ws::bool_true")] pub shuffle_endpoints: bool, } @@ -73,64 +46,19 @@ fn validate_endpoint(endpoint: &str, _context: &()) -> garde::Result { impl ClientConfig { pub async fn all_endpoints_can_be_connected(&self) -> bool { - let join_handles: Vec<_> = self - .endpoints - .iter() - .map(|endpoint| { - let endpoint = endpoint.clone(); - tokio::spawn(async move { - match check_endpoint_connection(&endpoint).await { - Ok(_) => { - tracing::info!("Connected to endpoint: {endpoint}"); - true - } - Err(err) => { - tracing::error!("Failed to connect to endpoint: {endpoint}, error: {err:?}",); - false - } - } - }) - }) - .collect(); - let mut ok_all = true; - for join_handle in join_handles { - let ok = join_handle.await.unwrap_or_else(|e| { - tracing::error!("Failed to join: {e:?}"); - false - }); - if !ok { - ok_all = false - } - } - ok_all - } -} -// simple connection check with default client params and no retries -async fn check_endpoint_connection(endpoint: &str) -> Result<(), anyhow::Error> { - let _ = WsClientBuilder::default().build(&endpoint).await?; - Ok(()) -} + let (ws_clients, _) = Client::get_urls(&self.endpoints); -pub fn bool_true() -> bool { - true -} + if ws_clients.is_empty() { + return true; + } -#[derive(Debug)] -enum Message { - Request { - method: String, - params: Vec, - response: tokio::sync::oneshot::Sender>, - retries: u32, - }, - Subscribe { - subscribe: String, - params: Vec, - unsubscribe: String, - response: tokio::sync::oneshot::Sender, Error>>, - retries: u32, - }, - RotateEndpoint, + ws::ClientConfig { + endpoints: ws_clients, + shuffle_endpoints: self.shuffle_endpoints, + } + .all_endpoints_can_be_connected() + .await + } } #[async_trait] @@ -155,291 +83,53 @@ impl Client { connection_timeout: Option, retries: Option, ) -> Result { - let endpoints: Vec<_> = endpoints.into_iter().map(|e| e.as_ref().to_string()).collect(); - let endpoints_ = endpoints.clone(); - - if endpoints.is_empty() { + let endpoints = endpoints + .into_iter() + .map(|e| e.as_ref().to_string()) + .collect::>(); + let (ws_endpoints, http_endpoints) = Self::get_urls(&endpoints); + if ws_endpoints.is_empty() && http_endpoints.is_empty() { return Err(anyhow!("No endpoints provided")); } - let (http_client, ws_endpoints) = http::HttpClient::new(endpoints)?; - - if ws_endpoints.is_empty() { - return Ok(Self { - http_client, - endpoints: endpoints_, - sender: None, // No websocket - rotation_notify: None, - retries: retries.unwrap_or(3), - background_task: None, - }); - } - - tracing::debug!("New client with endpoints: {:?}", ws_endpoints); - - let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::(100); - - let message_tx_bg = message_tx.clone(); - - let rotation_notify = Arc::new(Notify::new()); - let rotation_notify_bg = rotation_notify.clone(); - - let background_task = tokio::spawn(async move { - let connect_backoff_counter = Arc::new(AtomicU32::new(0)); - let request_backoff_counter = Arc::new(AtomicU32::new(0)); - - let current_endpoint = AtomicUsize::new(0); - - let connect_backoff_counter2 = connect_backoff_counter.clone(); - let build_ws = || async { - let build = || { - let current_endpoint = current_endpoint.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let url = &ws_endpoints[current_endpoint % ws_endpoints.len()]; - - tracing::info!("Connecting to endpoint: {}", url); - - // TODO: make those configurable - WsClientBuilder::default() - .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) - .connection_timeout(connection_timeout.unwrap_or(Duration::from_secs(30))) - .max_buffer_capacity_per_subscription(2048) - .max_concurrent_requests(2048) - .max_response_size(20 * 1024 * 1024) - .build(url) - .map_err(|e| (e, url.to_string())) - }; - - loop { - match build().await { - Ok(ws) => { - let ws = Arc::new(ws); - tracing::info!("Endpoint connected"); - connect_backoff_counter2.store(0, std::sync::atomic::Ordering::Relaxed); - break ws; - } - Err((e, url)) => { - tracing::warn!("Unable to connect to endpoint: '{url}' error: {e}"); - tokio::time::sleep(get_backoff_time(&connect_backoff_counter2)).await; - } - } - } - }; - - let mut ws = build_ws().await; - - let handle_message = |message: Message, ws: Arc| { - let tx = message_tx_bg.clone(); - let request_backoff_counter = request_backoff_counter.clone(); - - // total timeout for a request - let task_timeout = request_timeout - .unwrap_or(Duration::from_secs(30)) - // buffer 5 seconds for the request to be processed - .saturating_add(Duration::from_secs(5)); - - tokio::spawn(async move { - match message { - Message::Request { - method, - params, - response, - mut retries, - } => { - retries = retries.saturating_sub(1); - - // make sure it's still connected - if response.is_closed() { - return; - } - - if let Ok(result) = - tokio::time::timeout(task_timeout, ws.request(&method, params.clone())).await - { - match result { - result @ Ok(_) => { - request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(result); - } - Err(err) => { - tracing::debug!("Request failed: {:?}", err); - match err { - Error::RequestTimeout | Error::Transport(_) | Error::RestartNeeded(_) => { - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; - - // make sure it's still connected - if response.is_closed() { - return; - } - - // make sure we still have retries left - if retries == 0 { - let _ = response.send(Err(Error::RequestTimeout)); - return; - } - - if matches!(err, Error::RequestTimeout) { - tx.send(Message::RotateEndpoint) - .await - .expect("Failed to send rotate message"); - } - - tx.send(Message::Request { - method, - params, - response, - retries, - }) - .await - .expect("Failed to send request message"); - } - err => { - // make sure it's still connected - if response.is_closed() { - return; - } - // not something we can handle, send it back to the caller - let _ = response.send(Err(err)); - } - } - } - } - } else { - tracing::error!("request timed out method: {} params: {:?}", method, params); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(Err(Error::RequestTimeout)); - } - } - Message::Subscribe { - subscribe, - params, - unsubscribe, - response, - mut retries, - } => { - retries = retries.saturating_sub(1); - - if let Ok(result) = tokio::time::timeout( - task_timeout, - ws.subscribe(&subscribe, params.clone(), &unsubscribe), - ) - .await - { - match result { - result @ Ok(_) => { - request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(result); - } - Err(err) => { - tracing::debug!("Subscribe failed: {:?}", err); - match err { - Error::RequestTimeout | Error::Transport(_) | Error::RestartNeeded(_) => { - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; - - // make sure it's still connected - if response.is_closed() { - return; - } - - // make sure we still have retries left - if retries == 0 { - let _ = response.send(Err(Error::RequestTimeout)); - return; - } - - if matches!(err, Error::RequestTimeout) { - tx.send(Message::RotateEndpoint) - .await - .expect("Failed to send rotate message"); - } - - tx.send(Message::Subscribe { - subscribe, - params, - unsubscribe, - response, - retries, - }) - .await - .expect("Failed to send subscribe message") - } - err => { - // make sure it's still connected - if response.is_closed() { - return; - } - // not something we can handle, send it back to the caller - let _ = response.send(Err(err)); - } - } - } - } - } else { - tracing::error!("subscribe timed out subscribe: {} params: {:?}", subscribe, params); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(Err(Error::RequestTimeout)); - } - } - Message::RotateEndpoint => { - unreachable!() - } - } - }); - }; - - loop { - tokio::select! { - _ = ws.on_disconnect() => { - tracing::info!("Endpoint disconnected"); - tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; - ws = build_ws().await; - } - message = message_rx.recv() => { - tracing::trace!("Received message {message:?}"); - match message { - Some(Message::RotateEndpoint) => { - rotation_notify_bg.notify_waiters(); - tracing::info!("Rotate endpoint"); - ws = build_ws().await; - } - Some(message) => handle_message(message, ws.clone()), - None => { - tracing::debug!("Client dropped"); - break; - } - } - }, - }; - } - }); - - if let Some(0) = retries { - return Err(anyhow!("Retries need to be at least 1")); - } - Ok(Self { - http_client, - endpoints: endpoints_, - sender: Some(message_tx), - rotation_notify: Some(rotation_notify), - retries: retries.unwrap_or(3), - background_task: Some(background_task), + endpoints, + http_client: if http_endpoints.is_empty() { + None + } else { + Some(http::HttpClient::new(&http_endpoints)?) + }, + ws_client: if ws_endpoints.is_empty() { + None + } else { + Some(ws::Client::new( + &ws_endpoints, + request_timeout, + connection_timeout, + retries, + )?) + }, }) } + pub fn get_urls(endpoints: impl IntoIterator>) -> (Vec, Vec) { + let endpoints = endpoints + .into_iter() + .map(|e| e.as_ref().to_string()) + .collect::>(); + ( + endpoints + .iter() + .filter(|e| e.starts_with("ws://") || e.starts_with("wss://")) + .map(|c| c.to_string()) + .collect::>(), + endpoints + .into_iter() + .filter(|e| e.starts_with("http://") || e.starts_with("https://")) + .collect::>(), + ) + } + pub fn with_endpoints(endpoints: impl IntoIterator>) -> Result { Self::new(endpoints, None, None, None) } @@ -450,28 +140,11 @@ impl Client { pub async fn request(&self, method: &str, params: Vec) -> CallResult { if let Some(http_client) = &self.http_client { - return http_client.request(method, params).await; - } - - if let Some(sender) = self.sender.as_ref() { - async move { - let (tx, rx) = tokio::sync::oneshot::channel(); - sender - .send(Message::Request { - method: method.into(), - params, - response: tx, - retries: self.retries, - }) - .await - .map_err(errors::internal_error)?; - - rx.await.map_err(errors::internal_error)?.map_err(errors::map_error) - } - .with_context(TRACER.context(method.to_string())) - .await + http_client.request(method, params).await + } else if let Some(ws_client) = &self.ws_client { + ws_client.request(method, params).await } else { - Err(errors::internal_error("No sender")) + Err(errors::internal_error("No upstream client")) } } @@ -481,73 +154,23 @@ impl Client { params: Vec, unsubscribe: &str, ) -> Result, Error> { - if let Some(sender) = self.sender.as_ref() { - async move { - let (tx, rx) = tokio::sync::oneshot::channel(); - sender - .send(Message::Subscribe { - subscribe: subscribe.into(), - params, - unsubscribe: unsubscribe.into(), - response: tx, - retries: self.retries, - }) - .await - .map_err(errors::internal_error)?; - - rx.await.map_err(errors::internal_error)? - } - .with_context(TRACER.context(subscribe.to_string())) - .await + if let Some(ws_client) = &self.ws_client { + ws_client.subscribe(subscribe, params, unsubscribe).await } else { Err(Error::Call(errors::internal_error("No websocket connection"))) } } pub async fn rotate_endpoint(&self) { - if let Some(sender) = self.sender.as_ref() { - sender - .send(Message::RotateEndpoint) - .await - .expect("Failed to rotate endpoint"); + if let Some(ws_client) = &self.ws_client { + ws_client.rotate_endpoint().await; } } /// Returns a future that resolves when the endpoint is rotated. pub async fn on_rotation(&self) { - if let Some(rotation_notify) = self.rotation_notify.as_ref() { - rotation_notify.notified().await + if let Some(ws_client) = &self.ws_client { + ws_client.on_rotation().await; } } } - -fn get_backoff_time(counter: &Arc) -> Duration { - let min_time = 100u64; - let step = 100u64; - let max_count = 10u32; - - let backoff_count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - let backoff_count = backoff_count.min(max_count) as u64; - let backoff_time = backoff_count * backoff_count * step; - - Duration::from_millis(backoff_time + min_time) -} - -#[test] -fn test_get_backoff_time() { - let counter = Arc::new(AtomicU32::new(0)); - - let mut times = Vec::new(); - - for _ in 0..12 { - times.push(get_backoff_time(&counter)); - } - - let times = times.into_iter().map(|t| t.as_millis()).collect::>(); - - assert_eq!( - times, - vec![100, 200, 500, 1000, 1700, 2600, 3700, 5000, 6500, 8200, 10100, 10100] - ); -} diff --git a/src/extensions/client/ws.rs b/src/extensions/client/ws.rs new file mode 100644 index 0000000..b3eef7d --- /dev/null +++ b/src/extensions/client/ws.rs @@ -0,0 +1,515 @@ +use std::{ + sync::{ + atomic::{AtomicU32, AtomicUsize}, + Arc, + }, + time::Duration, +}; + +use anyhow::anyhow; +use async_trait::async_trait; +use futures::TryFutureExt; +use garde::Validate; +use jsonrpsee::{ + core::{ + client::{ClientT, Error, Subscription, SubscriptionClientT}, + JsonValue, + }, + ws_client::{WsClient, WsClientBuilder}, +}; +use opentelemetry::trace::FutureExt; +use rand::{seq::SliceRandom, thread_rng}; +use serde::Deserialize; +use tokio::sync::Notify; + +use super::ExtensionRegistry; +use crate::{ + extensions::Extension, + middlewares::CallResult, + utils::{self, errors}, +}; + +const TRACER: utils::telemetry::Tracer = utils::telemetry::Tracer::new("client"); + +#[derive(Debug)] +pub struct Client { + endpoints: Vec, + sender: tokio::sync::mpsc::Sender, + rotation_notify: Arc, + retries: u32, + background_task: tokio::task::JoinHandle<()>, +} + +impl Drop for Client { + fn drop(&mut self) { + self.background_task.abort(); + } +} + +#[derive(Deserialize, Validate, Debug)] +#[garde(allow_unvalidated)] +pub struct ClientConfig { + #[garde(inner(custom(validate_endpoint)))] + pub endpoints: Vec, + #[serde(default = "bool_true")] + pub shuffle_endpoints: bool, +} + +fn validate_endpoint(endpoint: &str, _context: &()) -> garde::Result { + endpoint + .parse::() + .map_err(|_| garde::Error::new(format!("Invalid endpoint format: {}", endpoint)))?; + + Ok(()) +} + +impl ClientConfig { + pub async fn all_endpoints_can_be_connected(&self) -> bool { + let join_handles: Vec<_> = self + .endpoints + .iter() + .map(|endpoint| { + let endpoint = endpoint.clone(); + tokio::spawn(async move { + match check_endpoint_connection(&endpoint).await { + Ok(_) => { + tracing::info!("Connected to endpoint: {endpoint}"); + true + } + Err(err) => { + tracing::error!("Failed to connect to endpoint: {endpoint}, error: {err:?}",); + false + } + } + }) + }) + .collect(); + let mut ok_all = true; + for join_handle in join_handles { + let ok = join_handle.await.unwrap_or_else(|e| { + tracing::error!("Failed to join: {e:?}"); + false + }); + if !ok { + ok_all = false + } + } + ok_all + } +} +// simple connection check with default client params and no retries +async fn check_endpoint_connection(endpoint: &str) -> Result<(), anyhow::Error> { + let _ = WsClientBuilder::default().build(&endpoint).await?; + Ok(()) +} + +pub fn bool_true() -> bool { + true +} + +#[derive(Debug)] +enum Message { + Request { + method: String, + params: Vec, + response: tokio::sync::oneshot::Sender>, + retries: u32, + }, + Subscribe { + subscribe: String, + params: Vec, + unsubscribe: String, + response: tokio::sync::oneshot::Sender, Error>>, + retries: u32, + }, + RotateEndpoint, +} + +#[async_trait] +impl Extension for Client { + type Config = ClientConfig; + + async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { + if config.shuffle_endpoints { + let mut endpoints = config.endpoints.clone(); + endpoints.shuffle(&mut thread_rng()); + Ok(Self::new(endpoints, None, None, None)?) + } else { + Ok(Self::new(config.endpoints.clone(), None, None, None)?) + } + } +} + +impl Client { + pub fn new( + endpoints: impl IntoIterator>, + request_timeout: Option, + connection_timeout: Option, + retries: Option, + ) -> Result { + let endpoints: Vec<_> = endpoints.into_iter().map(|e| e.as_ref().to_string()).collect(); + + if endpoints.is_empty() { + return Err(anyhow!("No endpoints provided")); + } + + tracing::debug!("New client with endpoints: {:?}", endpoints); + + let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::(100); + + let message_tx_bg = message_tx.clone(); + + let rotation_notify = Arc::new(Notify::new()); + let rotation_notify_bg = rotation_notify.clone(); + let endpoints_ = endpoints.clone(); + + let background_task = tokio::spawn(async move { + let connect_backoff_counter = Arc::new(AtomicU32::new(0)); + let request_backoff_counter = Arc::new(AtomicU32::new(0)); + + let current_endpoint = AtomicUsize::new(0); + + let connect_backoff_counter2 = connect_backoff_counter.clone(); + let build_ws = || async { + let build = || { + let current_endpoint = current_endpoint.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let url = &endpoints[current_endpoint % endpoints.len()]; + + tracing::info!("Connecting to endpoint: {}", url); + + // TODO: make those configurable + WsClientBuilder::default() + .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) + .connection_timeout(connection_timeout.unwrap_or(Duration::from_secs(30))) + .max_buffer_capacity_per_subscription(2048) + .max_concurrent_requests(2048) + .max_response_size(20 * 1024 * 1024) + .build(url) + .map_err(|e| (e, url.to_string())) + }; + + loop { + match build().await { + Ok(ws) => { + let ws = Arc::new(ws); + tracing::info!("Endpoint connected"); + connect_backoff_counter2.store(0, std::sync::atomic::Ordering::Relaxed); + break ws; + } + Err((e, url)) => { + tracing::warn!("Unable to connect to endpoint: '{url}' error: {e}"); + tokio::time::sleep(get_backoff_time(&connect_backoff_counter2)).await; + } + } + } + }; + + let mut ws = build_ws().await; + + let handle_message = |message: Message, ws: Arc| { + let tx = message_tx_bg.clone(); + let request_backoff_counter = request_backoff_counter.clone(); + + // total timeout for a request + let task_timeout = request_timeout + .unwrap_or(Duration::from_secs(30)) + // buffer 5 seconds for the request to be processed + .saturating_add(Duration::from_secs(5)); + + tokio::spawn(async move { + match message { + Message::Request { + method, + params, + response, + mut retries, + } => { + retries = retries.saturating_sub(1); + + // make sure it's still connected + if response.is_closed() { + return; + } + + if let Ok(result) = + tokio::time::timeout(task_timeout, ws.request(&method, params.clone())).await + { + match result { + result @ Ok(_) => { + request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(result); + } + Err(err) => { + tracing::debug!("Request failed: {:?}", err); + match err { + Error::RequestTimeout | Error::Transport(_) | Error::RestartNeeded(_) => { + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; + + // make sure it's still connected + if response.is_closed() { + return; + } + + // make sure we still have retries left + if retries == 0 { + let _ = response.send(Err(Error::RequestTimeout)); + return; + } + + if matches!(err, Error::RequestTimeout) { + tx.send(Message::RotateEndpoint) + .await + .expect("Failed to send rotate message"); + } + + tx.send(Message::Request { + method, + params, + response, + retries, + }) + .await + .expect("Failed to send request message"); + } + err => { + // make sure it's still connected + if response.is_closed() { + return; + } + // not something we can handle, send it back to the caller + let _ = response.send(Err(err)); + } + } + } + } + } else { + tracing::error!("request timed out method: {} params: {:?}", method, params); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(Err(Error::RequestTimeout)); + } + } + Message::Subscribe { + subscribe, + params, + unsubscribe, + response, + mut retries, + } => { + retries = retries.saturating_sub(1); + + if let Ok(result) = tokio::time::timeout( + task_timeout, + ws.subscribe(&subscribe, params.clone(), &unsubscribe), + ) + .await + { + match result { + result @ Ok(_) => { + request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(result); + } + Err(err) => { + tracing::debug!("Subscribe failed: {:?}", err); + match err { + Error::RequestTimeout | Error::Transport(_) | Error::RestartNeeded(_) => { + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; + + // make sure it's still connected + if response.is_closed() { + return; + } + + // make sure we still have retries left + if retries == 0 { + let _ = response.send(Err(Error::RequestTimeout)); + return; + } + + if matches!(err, Error::RequestTimeout) { + tx.send(Message::RotateEndpoint) + .await + .expect("Failed to send rotate message"); + } + + tx.send(Message::Subscribe { + subscribe, + params, + unsubscribe, + response, + retries, + }) + .await + .expect("Failed to send subscribe message") + } + err => { + // make sure it's still connected + if response.is_closed() { + return; + } + // not something we can handle, send it back to the caller + let _ = response.send(Err(err)); + } + } + } + } + } else { + tracing::error!("subscribe timed out subscribe: {} params: {:?}", subscribe, params); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(Err(Error::RequestTimeout)); + } + } + Message::RotateEndpoint => { + unreachable!() + } + } + }); + }; + + loop { + tokio::select! { + _ = ws.on_disconnect() => { + tracing::info!("Endpoint disconnected"); + tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; + ws = build_ws().await; + } + message = message_rx.recv() => { + tracing::trace!("Received message {message:?}"); + match message { + Some(Message::RotateEndpoint) => { + rotation_notify_bg.notify_waiters(); + tracing::info!("Rotate endpoint"); + ws = build_ws().await; + } + Some(message) => handle_message(message, ws.clone()), + None => { + tracing::debug!("Client dropped"); + break; + } + } + }, + }; + } + }); + + if let Some(0) = retries { + return Err(anyhow!("Retries need to be at least 1")); + } + + Ok(Self { + endpoints: endpoints_, + sender: message_tx, + rotation_notify, + retries: retries.unwrap_or(3), + background_task, + }) + } + + pub fn with_endpoints(endpoints: impl IntoIterator>) -> Result { + Self::new(endpoints, None, None, None) + } + + pub fn endpoints(&self) -> &Vec { + &self.endpoints + } + + pub async fn request(&self, method: &str, params: Vec) -> CallResult { + async move { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.sender + .send(Message::Request { + method: method.into(), + params, + response: tx, + retries: self.retries, + }) + .await + .map_err(errors::internal_error)?; + + rx.await.map_err(errors::internal_error)?.map_err(errors::map_error) + } + .with_context(TRACER.context(method.to_string())) + .await + } + + pub async fn subscribe( + &self, + subscribe: &str, + params: Vec, + unsubscribe: &str, + ) -> Result, Error> { + async move { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.sender + .send(Message::Subscribe { + subscribe: subscribe.into(), + params, + unsubscribe: unsubscribe.into(), + response: tx, + retries: self.retries, + }) + .await + .map_err(errors::internal_error)?; + + rx.await.map_err(errors::internal_error)? + } + .with_context(TRACER.context(subscribe.to_string())) + .await + } + + pub async fn rotate_endpoint(&self) { + self.sender + .send(Message::RotateEndpoint) + .await + .expect("Failed to rotate endpoint"); + } + + /// Returns a future that resolves when the endpoint is rotated. + pub async fn on_rotation(&self) { + self.rotation_notify.notified().await + } +} + +fn get_backoff_time(counter: &Arc) -> Duration { + let min_time = 100u64; + let step = 100u64; + let max_count = 10u32; + + let backoff_count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let backoff_count = backoff_count.min(max_count) as u64; + let backoff_time = backoff_count * backoff_count * step; + + Duration::from_millis(backoff_time + min_time) +} + +#[test] +fn test_get_backoff_time() { + let counter = Arc::new(AtomicU32::new(0)); + + let mut times = Vec::new(); + + for _ in 0..12 { + times.push(get_backoff_time(&counter)); + } + + let times = times.into_iter().map(|t| t.as_millis()).collect::>(); + + assert_eq!( + times, + vec![100, 200, 500, 1000, 1700, 2600, 3700, 5000, 6500, 8200, 10100, 10100] + ); +}