diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index f64b11717..b7ebdc251 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -22,7 +22,7 @@ use super::{header::HeaderMap, status::StatusCode, Command, Message, Subscriber} use crate::error::Error; use bytes::Bytes; use futures::future::TryFutureExt; -use futures::{Sink, SinkExt as _, StreamExt}; +use futures::{FutureExt, Sink, SinkExt as _, Stream, StreamExt}; use once_cell::sync::Lazy; use portable_atomic::AtomicU64; use regex::Regex; @@ -673,6 +673,10 @@ impl Client { pub fn statistics(&self) -> Arc { self.connection_stats.clone() } + + pub fn request_many(&self) -> RequestMany { + RequestMany::new(self.clone(), self.request_timeout) + } } /// Used for building customized requests. @@ -866,3 +870,141 @@ pub struct Statistics { /// Initial connect will be counted as well, then all successful reconnects. pub connects: AtomicU64, } + +pub struct RequestMany { + client: Client, + sentinel: Option bool + 'static>>, + max_wait: Option, + stall_wait: Option, + max_messags: Option, +} + +impl RequestMany { + pub fn new(client: Client, max_wait: Option) -> Self { + RequestMany { + client, + sentinel: None, + max_wait, + stall_wait: None, + max_messags: None, + } + } + + pub fn sentinel(mut self, sentinel: impl Fn(&crate::Message) -> bool + 'static) -> Self { + self.sentinel = Some(Box::new(sentinel)); + self + } + + pub fn stall_wait(mut self, stall_wait: Duration) -> Self { + self.stall_wait = Some(stall_wait); + self + } + + pub fn max_messages(mut self, max_messages: usize) -> Self { + self.max_messags = Some(max_messages); + self + } + + pub fn max_wait(mut self, max_wait: Option) -> Self { + self.max_wait = max_wait; + self + } + + pub async fn send( + self, + subject: S, + payload: Bytes, + ) -> Result { + let response_subject = self.client.new_inbox(); + let responses = self.client.subscribe(response_subject.clone()).await?; + self.client + .publish_with_reply(subject, response_subject, payload) + .await?; + + let timer = self + .max_wait + .map(|max_wait| Box::pin(tokio::time::sleep(max_wait))); + + Ok(Responses { + timer, + stall: None, + responses, + messages_received: 0, + sentinel: self.sentinel, + max_messages: self.max_messags, + stall_wait: self.stall_wait, + }) + } +} + +pub struct Responses { + responses: Subscriber, + messages_received: usize, + timer: Option>>, + stall: Option>>, + sentinel: Option bool + 'static>>, + max_messages: Option, + stall_wait: Option, +} + +impl Stream for Responses { + type Item = crate::Message; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // max_wait + if let Some(timer) = self.timer.as_mut() { + match timer.poll_unpin(cx) { + Poll::Ready(_) => { + return Poll::Ready(None); + } + Poll::Pending => {} + } + } + + // max_messages + if let Some(max_messages) = self.max_messages { + if self.messages_received >= max_messages { + return Poll::Ready(None); + } + } + + // stall_wait + if let Some(stall) = self.stall_wait { + let stall = self + .stall + .get_or_insert_with(|| Box::pin(tokio::time::sleep(stall))); + + match stall.as_mut().poll_unpin(cx) { + Poll::Ready(_) => { + return Poll::Ready(None); + } + Poll::Pending => {} + } + } + + match self.responses.receiver.poll_recv(cx) { + Poll::Ready(message) => match message { + Some(message) => { + self.messages_received += 1; + + // reset timer + self.stall = None; + + // sentinel + match self.sentinel { + Some(ref sentinel) => { + if sentinel(&message) { + Poll::Ready(None) + } else { + return Poll::Ready(Some(message)); + } + } + None => Poll::Ready(Some(message)), + } + } + None => Poll::Ready(None), + }, + Poll::Pending => Poll::Pending, + } + } +} diff --git a/async-nats/tests/client_tests.rs b/async-nats/tests/client_tests.rs index bdb0cea2c..d6a40d599 100644 --- a/async-nats/tests/client_tests.rs +++ b/async-nats/tests/client_tests.rs @@ -996,4 +996,109 @@ mod client { assert!(stats.out_bytes.load(Ordering::Relaxed) != 0); assert_eq!(stats.connects.load(Ordering::Relaxed), 2); } + + #[tokio::test] + async fn request_many() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + // request many with sentinel + let mut requests = client.subscribe("test").await.unwrap(); + let responses = client + .request_many() + .sentinel(|msg| msg.payload.is_empty()) + .send("test", "data".into()) + .await + .unwrap(); + + let request = requests.next().await.unwrap(); + + for _ in 0..100 { + client + .publish(request.reply.clone().unwrap(), "data".into()) + .await + .unwrap(); + } + client + .publish(request.reply.unwrap(), "".into()) + .await + .unwrap(); + + assert_eq!(responses.count().await, 100); + requests.unsubscribe().await.unwrap(); + + // request many with max messages + let mut requests = client.subscribe("test").await.unwrap(); + let responses = client + .request_many() + .max_messages(20) + .send("test", "data".into()) + .await + .unwrap(); + + let request = requests.next().await.unwrap(); + + for _ in 1..=100 { + client + .publish(request.reply.clone().unwrap(), "data".into()) + .await + .unwrap(); + } + + assert_eq!(responses.count().await, 20); + requests.unsubscribe().await.unwrap(); + + // request many with stall + let mut requests = client.subscribe("test").await.unwrap(); + let responses = client + .request_many() + .stall_wait(Duration::from_millis(100)) + .send("test", "data".into()) + .await + .unwrap(); + + tokio::task::spawn({ + let client = client.clone(); + async move { + let request = requests.next().await.unwrap(); + for i in 1..=100 { + if i == 51 { + tokio::time::sleep(Duration::from_millis(500)).await; + } + client + .publish(request.reply.clone().unwrap(), "data".into()) + .await + .unwrap(); + } + requests.unsubscribe().await.unwrap(); + } + }); + assert_eq!(responses.count().await, 50); + + // request many with max wait + let mut requests = client.subscribe("test").await.unwrap(); + let responses = client + .request_many() + .max_wait(Some(Duration::from_secs(5))) + .send("test", "data".into()) + .await + .unwrap(); + + tokio::task::spawn({ + let client = client.clone(); + async move { + let request = requests.next().await.unwrap(); + for i in 1..=100 { + if i == 21 { + tokio::time::sleep(Duration::from_secs(10)).await; + } + client + .publish(request.reply.clone().unwrap(), "data".into()) + .await + .unwrap(); + } + } + }); + assert_eq!(responses.count().await, 20); + } }