diff --git a/async-nats/Cargo.toml b/async-nats/Cargo.toml index 6dd6db8e9..a3cbffe76 100644 --- a/async-nats/Cargo.toml +++ b/async-nats/Cargo.toml @@ -27,6 +27,7 @@ serde_repr = "0.1.16" tokio = { version = "1.36", features = ["macros", "rt", "fs", "net", "sync", "time", "io-util"] } url = { version = "2"} tokio-rustls = { version = "0.26", default-features = false } +tokio-util = "0.7" rustls-pemfile = "2" nuid = "0.5" serde_nanos = "0.1.3" diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index 47ab987e8..4ff1140c8 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -11,15 +11,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +use core::pin::Pin; +use core::task::{Context, Poll}; + use crate::connection::State; use crate::subject::ToSubject; -use crate::ServerInfo; +use crate::{PublishMessage, ServerInfo}; use super::{header::HeaderMap, status::StatusCode, Command, Message, Subscriber}; use crate::error::Error; use bytes::Bytes; use futures::future::TryFutureExt; -use futures::StreamExt; +use futures::{Sink, SinkExt as _, StreamExt}; use once_cell::sync::Lazy; use portable_atomic::AtomicU64; use regex::Regex; @@ -29,6 +32,7 @@ use std::sync::Arc; use std::time::Duration; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; +use tokio_util::sync::PollSender; use tracing::trace; static VERSION_RE: Lazy = @@ -44,6 +48,12 @@ impl From> for PublishError { } } +impl From> for PublishError { + fn from(err: tokio_util::sync::PollSendError) -> Self { + PublishError::with_source(PublishErrorKind::Send, err) + } +} + #[derive(Copy, Clone, Debug, PartialEq)] pub enum PublishErrorKind { MaxPayloadExceeded, @@ -67,6 +77,7 @@ pub struct Client { info: tokio::sync::watch::Receiver, pub(crate) state: tokio::sync::watch::Receiver, pub(crate) sender: mpsc::Sender, + poll_sender: PollSender, next_subscription_id: Arc, subscription_capacity: usize, inbox_prefix: Arc, @@ -74,6 +85,28 @@ pub struct Client { max_payload: Arc, } +impl Sink for Client { + type Error = PublishError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_sender.poll_ready_unpin(cx).map_err(Into::into) + } + + fn start_send(mut self: Pin<&mut Self>, msg: PublishMessage) -> Result<(), Self::Error> { + self.poll_sender + .start_send_unpin(Command::Publish(msg)) + .map_err(Into::into) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_sender.poll_flush_unpin(cx).map_err(Into::into) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_sender.poll_close_unpin(cx).map_err(Into::into) + } +} + impl Client { pub(crate) fn new( info: tokio::sync::watch::Receiver, @@ -84,10 +117,12 @@ impl Client { request_timeout: Option, max_payload: Arc, ) -> Client { + let poll_sender = PollSender::new(sender.clone()); Client { info, state, sender, + poll_sender, next_subscription_id: Arc::new(AtomicU64::new(1)), subscription_capacity: capacity, inbox_prefix: inbox_prefix.into(), @@ -191,12 +226,12 @@ impl Client { } self.sender - .send(Command::Publish { + .send(Command::Publish(PublishMessage { subject, payload, - respond: None, + reply: None, headers: None, - }) + })) .await?; Ok(()) } @@ -229,12 +264,12 @@ impl Client { let subject = subject.to_subject(); self.sender - .send(Command::Publish { + .send(Command::Publish(PublishMessage { subject, payload, - respond: None, + reply: None, headers: Some(headers), - }) + })) .await?; Ok(()) } @@ -265,12 +300,12 @@ impl Client { let reply = reply.to_subject(); self.sender - .send(Command::Publish { + .send(Command::Publish(PublishMessage { subject, payload, - respond: Some(reply), + reply: Some(reply), headers: None, - }) + })) .await?; Ok(()) } @@ -304,12 +339,12 @@ impl Client { let reply = reply.to_subject(); self.sender - .send(Command::Publish { + .send(Command::Publish(PublishMessage { subject, payload, - respond: Some(reply), + reply: Some(reply), headers: Some(headers), - }) + })) .await?; Ok(()) } diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index a530d63a5..3981f172c 100755 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -342,14 +342,19 @@ pub(crate) enum ServerOp { }, } +/// `PublishMessage` represents a message being published +#[derive(Debug)] +pub struct PublishMessage { + pub subject: Subject, + pub payload: Bytes, + pub reply: Option, + pub headers: Option, +} + +/// `Command` represents all commands that a [`Client`] can handle #[derive(Debug)] pub(crate) enum Command { - Publish { - subject: Subject, - payload: Bytes, - respond: Option, - headers: Option, - }, + Publish(PublishMessage), Request { subject: Subject, payload: Bytes, @@ -822,12 +827,12 @@ impl ConnectionHandler { self.connection.enqueue_write_op(&pub_op); } - Command::Publish { + Command::Publish(PublishMessage { subject, payload, - respond, + reply: respond, headers, - } => { + }) => { self.connection.enqueue_write_op(&ClientOp::Publish { subject, payload,