From 5e0b99d8a6145e564f594ba3111290de713e1e95 Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Mon, 3 Apr 2023 22:24:01 +0300 Subject: [PATCH] Add `rtt` to `Client` --- .config/nats.dic | 1 + async-nats/src/client.rs | 92 +++++++++++++++++++++++++++++++- async-nats/src/lib.rs | 36 +++++++++++-- async-nats/tests/client_tests.rs | 11 ++++ 4 files changed, 135 insertions(+), 5 deletions(-) diff --git a/.config/nats.dic b/.config/nats.dic index 1d5c7f61c..715023e5d 100644 --- a/.config/nats.dic +++ b/.config/nats.dic @@ -133,3 +133,4 @@ ConnectError DNS RequestErrorKind rustls +RttError diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index 06ef82601..4593d2bd2 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -23,7 +23,7 @@ use regex::Regex; use std::fmt::Display; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use thiserror::Error; use tokio::sync::mpsc; use tracing::trace; @@ -463,6 +463,51 @@ impl Client { Ok(()) } + /// Calculates the round trip time between this client and the server, + /// if the server is currently connected. + /// + /// # Examples + /// + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// let client = async_nats::connect("demo.nats.io").await?; + /// let rtt = client.rtt().await?; + /// println!("server rtt: {:?}", rtt); + /// # Ok(()) + /// # } + /// ``` + pub async fn rtt(&self) -> Result { + let start = Instant::now(); + + let (ping_tx, ping_rx) = tokio::sync::oneshot::channel(); + let (pong_tx, pong_rx) = tokio::sync::oneshot::channel(); + + self.sender + .send(Command::Ping { + ping_result: Some(ping_tx), + pong_result: Some(pong_tx), + }) + .await + .map_err(|err| RttError::with_source(RttErrorKind::PingError, err))?; + + ping_rx + .await + // first handle rx error + .map_err(|err| RttError::with_source(RttErrorKind::PingError, err))? + // second handle the atual ping error + .map_err(|err| RttError::with_source(RttErrorKind::PingError, err))?; + + pong_rx + .await + // first handle rx error + .map_err(|err| RttError::with_source(RttErrorKind::PongError, err))? + // second handle the actual pong error + .map_err(|err| RttError::with_source(RttErrorKind::PongError, err))?; + + Ok(start.elapsed()) + } + /// Returns the current state of the connection. /// /// # Examples @@ -688,3 +733,48 @@ impl From for RequestError { RequestError::with_source(RequestErrorKind::Other, e) } } + +/// Error returned when doing a round-trip time measurement fails. +/// To enumerate over the variants, call [RttError::kind]. +#[derive(Debug, Error)] +pub struct RttError { + kind: RttErrorKind, + source: Option>, +} + +impl Display for RttError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let source_info = self + .source + .as_ref() + .map(|e| e.to_string()) + .unwrap_or_else(|| "no details".into()); + match self.kind { + RttErrorKind::PingError => { + write!(f, "failed to ping server: {}", source_info) + } + RttErrorKind::PongError => write!(f, "pong failed: {}", source_info), + } + } +} + +impl RttError { + fn with_source(kind: RttErrorKind, source: E) -> RttError + where + E: Into>, + { + RttError { + kind, + source: Some(source.into()), + } + } + pub fn kind(&self) -> RttErrorKind { + self.kind + } +} + +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum RttErrorKind { + PingError, + PongError, +} diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 58ada931b..0b704054b 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -108,7 +108,7 @@ use futures::stream::Stream; use tracing::{debug, error}; use core::fmt; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::fmt::Display; use std::iter; use std::net::{SocketAddr, ToSocketAddrs}; @@ -255,7 +255,10 @@ pub enum Command { sid: u64, max: Option, }, - Ping, + Ping { + ping_result: Option>>, + pong_result: Option>>, + }, Flush { result: oneshot::Sender>, }, @@ -305,6 +308,7 @@ pub(crate) struct ConnectionHandler { info_sender: tokio::sync::watch::Sender, ping_interval: Interval, flush_interval: Interval, + pending_pongs: VecDeque>>>, } impl ConnectionHandler { @@ -330,6 +334,7 @@ impl ConnectionHandler { info_sender, ping_interval, flush_interval, + pending_pongs: VecDeque::new(), } } @@ -398,6 +403,12 @@ impl ConnectionHandler { ServerOp::Pong => { debug!("received PONG"); self.pending_pings = self.pending_pings.saturating_sub(1); + + if let Some(Some(sender)) = self.pending_pongs.pop_front() { + sender.send(Ok(())).map_err(|_| { + io::Error::new(io::ErrorKind::Other, "one shot failed to be received") + })?; + } } ServerOp::Error(error) => { self.connector @@ -508,7 +519,10 @@ impl ConnectionHandler { } } } - Command::Ping => { + Command::Ping { + ping_result, + pong_result, + } => { debug!( "PING command. Pending pings {}, max pings {}", self.pending_pings, self.max_pings @@ -524,8 +538,21 @@ impl ConnectionHandler { self.handle_disconnect().await?; } - if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await { + // awlays push to pending_pongs in queue, will handle them in PONG + self.pending_pongs.push_back(pong_result); + + if let Err(err) = self.connection.write_op(&ClientOp::Ping).await { self.handle_disconnect().await?; + + if let Some(ping_result) = ping_result { + ping_result.send(Err(err)).map_err(|_| { + io::Error::new(io::ErrorKind::Other, "one shot failed to be received") + })?; + } + } else if let Some(ping_result) = ping_result { + ping_result.send(Ok(())).map_err(|_| { + io::Error::new(io::ErrorKind::Other, "one shot failed to be received") + })?; } self.handle_flush().await?; @@ -615,6 +642,7 @@ impl ConnectionHandler { async fn handle_disconnect(&mut self) -> io::Result<()> { self.pending_pings = 0; + self.pending_pongs.clear(); self.connector.events_tx.try_send(Event::Disconnected).ok(); self.connector.state_tx.send(State::Disconnected).ok(); self.handle_reconnect().await?; diff --git a/async-nats/tests/client_tests.rs b/async-nats/tests/client_tests.rs index 538b78d2a..c4e244fb9 100644 --- a/async-nats/tests/client_tests.rs +++ b/async-nats/tests/client_tests.rs @@ -764,4 +764,15 @@ mod client { drop(servers.remove(0)); rx.recv().await; } + + #[tokio::test] + async fn rtt() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let rtt = client.rtt().await.unwrap(); + + println!("rtt: {:?}", rtt); + assert!(rtt.as_nanos() > 0); + } }