diff --git a/src/api/mod.rs b/src/api/mod.rs index d6fb6a7..e0549b1 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -19,6 +19,7 @@ pub const DEFAULT_NAME: &str = "POSTGRESQL_DEFAULT_NAME"; #[derive(Debug, Clone, Copy, Default)] pub enum PgWireConnectionState { #[default] + AwaitingSslRequest, AwaitingStartup, AuthenticationInProgress, ReadyForQuery, diff --git a/src/messages/mod.rs b/src/messages/mod.rs index 6af5bd9..f69d9a4 100644 --- a/src/messages/mod.rs +++ b/src/messages/mod.rs @@ -74,7 +74,9 @@ pub mod terminate; #[derive(Debug)] pub enum PgWireFrontendMessage { Startup(startup::Startup), - SslRequest(startup::SslRequest), + // when client has no ssl configured, it skip this message. + // our decoder will return a `SslRequest(None)` for this case. + SslRequest(Option), PasswordMessageFamily(startup::PasswordMessageFamily), Query(simplequery::Query), @@ -111,7 +113,13 @@ impl PgWireFrontendMessage { pub fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> { match self { Self::Startup(msg) => msg.encode(buf), - Self::SslRequest(msg) => msg.encode(buf), + Self::SslRequest(msg) => { + if let Some(msg) = msg { + msg.encode(buf) + } else { + Ok(()) + } + } Self::PasswordMessageFamily(msg) => msg.encode(buf), Self::Query(msg) => msg.encode(buf), @@ -135,6 +143,7 @@ impl PgWireFrontendMessage { pub fn decode(buf: &mut BytesMut) -> PgWireResult> { if buf.remaining() > 1 { let first_byte = buf[0]; + match first_byte { // Password, SASLInitialResponse, SASLResponse can only be // decoded under certain context diff --git a/src/tokio.rs b/src/tokio.rs index 5bcc39a..6fae638 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -1,7 +1,7 @@ use std::io::Error as IOError; use std::sync::Arc; -use bytes::BytesMut; +use bytes::Buf; use futures::{SinkExt, StreamExt}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; @@ -34,17 +34,31 @@ impl Decoder for PgWireMessageServerCodec { fn decode(&mut self, src: &mut bytes::BytesMut) -> Result, Self::Error> { match self.client_info.state() { - PgWireConnectionState::AwaitingStartup => { - if let Some(request) = SslRequest::decode(src)? { - return Ok(Some(PgWireFrontendMessage::SslRequest(request))); + PgWireConnectionState::AwaitingSslRequest => { + if src.remaining() >= SslRequest::BODY_SIZE { + self.client_info + .set_state(PgWireConnectionState::AwaitingStartup); + + if let Some(request) = SslRequest::decode(src)? { + return Ok(Some(PgWireFrontendMessage::SslRequest(Some(request)))); + } else { + // this is not a real message, but to indicate that + // client will not init ssl handshake + return Ok(Some(PgWireFrontendMessage::SslRequest(None))); + } } + Ok(None) + } + + PgWireConnectionState::AwaitingStartup => { if let Some(startup) = Startup::decode(src)? { - return Ok(Some(PgWireFrontendMessage::Startup(startup))); + Ok(Some(PgWireFrontendMessage::Startup(startup))) + } else { + Ok(None) } - - Ok(None) } + _ => PgWireFrontendMessage::decode(src), } } @@ -256,54 +270,34 @@ enum SslNegotiationType { None, } -async fn check_ssl_negotiation(tcp_socket: &TcpStream) -> Result { - let mut buf = [0u8; SslRequest::BODY_SIZE]; - loop { - let n = tcp_socket.peek(&mut buf).await?; - - // the tcp_stream has ended - if n == 0 { - return Ok(SslNegotiationType::None); - } - - if n >= SslRequest::BODY_SIZE { - break; - } - } - if buf[0] == 0x16 { - return Ok(SslNegotiationType::Direct); - } +async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result { + let mut buf = [0u8; 1]; + let n = tcp_socket.peek(&mut buf).await?; - let mut buf = BytesMut::from(buf.as_slice()); - if let Ok(Some(_)) = SslRequest::decode(&mut buf) { - return Ok(SslNegotiationType::Postgres); - } - Ok(SslNegotiationType::None) + Ok(n > 0 && buf[0] == 0x16) } async fn peek_for_sslrequest( socket: &mut Framed>, ssl_supported: bool, ) -> Result { - let mut negotiation_type = check_ssl_negotiation(socket.get_ref()).await?; - match negotiation_type { - SslNegotiationType::Postgres => { - // consume request - socket.next().await; - - let response = if ssl_supported { - PgWireBackendMessage::SslResponse(SslResponse::Accept) - } else { - negotiation_type = SslNegotiationType::None; - PgWireBackendMessage::SslResponse(SslResponse::Refuse) - }; - socket.send(response).await?; + if check_ssl_direct_negotiation(socket.get_ref()).await? { + Ok(SslNegotiationType::Direct) + } else if let Some(Ok(PgWireFrontendMessage::SslRequest(Some(_)))) = socket.next().await { + if ssl_supported { + socket + .send(PgWireBackendMessage::SslResponse(SslResponse::Accept)) + .await?; + Ok(SslNegotiationType::Postgres) + } else { + socket + .send(PgWireBackendMessage::SslResponse(SslResponse::Refuse)) + .await?; + Ok(SslNegotiationType::None) } - SslNegotiationType::Direct => {} - SslNegotiationType::None => {} + } else { + Ok(SslNegotiationType::None) } - - Ok(negotiation_type) } async fn do_process_socket(