From 1e68efdd716d8ca43576fc7f1be0935b8d5a0405 Mon Sep 17 00:00:00 2001 From: martinabeleda Date: Fri, 21 Apr 2023 22:27:50 -0700 Subject: [PATCH 1/9] Implement zstd compression --- tests/compression/Cargo.toml | 2 +- tonic/Cargo.toml | 2 ++ tonic/src/codec/compression.rs | 51 ++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index 4c1e7132d..bca447765 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -15,7 +15,7 @@ hyper = "0.14.3" pin-project = "1.0" prost = "0.11" tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} -tonic = {path = "../../tonic", features = ["gzip"]} +tonic = {path = "../../tonic", features = ["gzip", "zstd"]} tower = {version = "0.4", features = []} tower-http = {version = "0.4", features = ["map-response-body", "map-request-body"]} diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 5ce7c542a..4749b31d6 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -25,6 +25,7 @@ version = "0.9.2" [features] codegen = ["dep:async-trait"] gzip = ["dep:flate2"] +zstd = ["dep:zstd"] default = ["transport", "codegen", "prost"] prost = ["dep:prost"] tls = ["dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:async-stream"] @@ -85,6 +86,7 @@ webpki-roots = { version = "0.23.0", optional = true } # compression flate2 = {version = "1.0", optional = true} +zstd = { version = "0.12.3", optional = true } [dev-dependencies] bencher = "0.1.5" diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 7063bd865..9f4184af5 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -4,6 +4,9 @@ use bytes::{Buf, BytesMut}; #[cfg(feature = "gzip")] use flate2::read::{GzDecoder, GzEncoder}; use std::fmt; +use std::io::Write; +#[cfg(feature = "zstd")] +use zstd::{Decoder, Encoder}; pub(crate) const ENCODING_HEADER: &str = "grpc-encoding"; pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; @@ -13,6 +16,8 @@ pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; pub struct EnabledCompressionEncodings { #[cfg(feature = "gzip")] pub(crate) gzip: bool, + #[cfg(feature = "zstd")] + pub(crate) zstd: bool, } impl EnabledCompressionEncodings { @@ -21,6 +26,8 @@ impl EnabledCompressionEncodings { match encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => self.gzip, + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => self.zstd, } } @@ -29,12 +36,16 @@ impl EnabledCompressionEncodings { match encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => self.gzip = true, + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => self.zstd = true, } } pub(crate) fn into_accept_encoding_header_value(self) -> Option { if self.is_gzip_enabled() { Some(http::HeaderValue::from_static("gzip,identity")) + } else if self.is_zstd_enabled() { + Some(http::HeaderValue::from_static("zstd,identity")) } else { None } @@ -49,6 +60,16 @@ impl EnabledCompressionEncodings { const fn is_gzip_enabled(&self) -> bool { false } + + #[cfg(feature = "zstd")] + const fn is_zstd_enabled(&self) -> bool { + self.zstd + } + + #[cfg(not(feature = "zstd"))] + const fn is_gzip_enabled(&self) -> bool { + false + } } /// The compression encodings Tonic supports. @@ -59,6 +80,10 @@ pub enum CompressionEncoding { #[cfg(feature = "gzip")] #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))] Gzip, + #[allow(missing_docs)] + #[cfg(feature = "zstd")] + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] + Zstd, } impl CompressionEncoding { @@ -77,6 +102,8 @@ impl CompressionEncoding { split_by_comma(header_value_str).find_map(|value| match value { #[cfg(feature = "gzip")] "gzip" => Some(CompressionEncoding::Gzip), + #[cfg(feature = "zstd")] + "zstd" => Some(CompressionEncoding::Zstd), _ => None, }) } @@ -103,6 +130,9 @@ impl CompressionEncoding { "gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => { Ok(Some(CompressionEncoding::Gzip)) } + "zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => { + Ok(Some(CompressionEncoding::Zstd)) + } "identity" => Ok(None), other => { let mut status = Status::unimplemented(format!( @@ -127,6 +157,8 @@ impl CompressionEncoding { match self { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => http::HeaderValue::from_static("gzip"), + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => http::HeaderValue::from_static("zstd"), } } @@ -134,6 +166,8 @@ impl CompressionEncoding { &[ #[cfg(feature = "gzip")] CompressionEncoding::Gzip, + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd, ] } } @@ -144,6 +178,8 @@ impl fmt::Display for CompressionEncoding { match *self { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => write!(f, "gzip"), + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => write!(f, "zstd"), } } } @@ -175,6 +211,14 @@ pub(crate) fn compress( std::io::copy(&mut gzip_encoder, &mut out_writer)?; } + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => { + let out_writer = bytes::BufMut::writer(out_buf); + let mut zstd_encoder = Encoder::new(out_writer, 0)?; + + zstd_encoder.write_all(&decompressed_buf[0..len])?; + zstd_encoder.finish()?; + } } decompressed_buf.advance(len); @@ -202,6 +246,13 @@ pub(crate) fn decompress( std::io::copy(&mut gzip_decoder, &mut out_writer)?; } + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => { + let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; + let mut out_writer = bytes::BufMut::writer(out_buf); + + std::io::copy(&mut zstd_decoder, &mut out_writer)?; + } } compressed_buf.advance(len); From da89eef8495e57fa2a29d05d7f521af124727ca4 Mon Sep 17 00:00:00 2001 From: martinabeleda Date: Sat, 22 Apr 2023 20:05:12 -0700 Subject: [PATCH 2/9] Parametrize compression tests --- tests/compression/Cargo.toml | 1 + tests/compression/src/bidirectional_stream.rs | 50 ++++-- tests/compression/src/client_stream.rs | 99 ++++++++---- tests/compression/src/compressing_request.rs | 76 ++++++--- tests/compression/src/compressing_response.rs | 153 ++++++++++++------ tests/compression/src/server_stream.rs | 49 ++++-- tests/compression/src/util.rs | 38 +++++ tonic/src/codec/compression.rs | 13 +- 8 files changed, 357 insertions(+), 122 deletions(-) diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index bca447765..95e04c4c1 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -12,6 +12,7 @@ futures = "0.3" http = "0.2" http-body = "0.4" hyper = "0.14.3" +paste = "1.0.12" pin-project = "1.0" prost = "0.11" tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} diff --git a/tests/compression/src/bidirectional_stream.rs b/tests/compression/src/bidirectional_stream.rs index a92783c34..eca3a8492 100644 --- a/tests/compression/src/bidirectional_stream.rs +++ b/tests/compression/src/bidirectional_stream.rs @@ -1,20 +1,43 @@ use super::*; +use http_body::Body; use tonic::codec::CompressionEncoding; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()) - .accept_compressed(CompressionEncoding::Gzip) - .send_compressed(CompressionEncoding::Gzip); + .accept_compressed(encoding) + .send_compressed(encoding); let request_bytes_counter = Arc::new(AtomicUsize::new(0)); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); - fn assert_right_encoding(req: http::Request) -> http::Request { - assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); - req + #[derive(Clone)] + pub struct AssertRightEncoding { + encoding: CompressionEncoding, + } + + #[allow(dead_code)] + impl AssertRightEncoding { + pub fn new(encoding: CompressionEncoding) -> Self { + Self { encoding } + } + + pub fn call(self, req: http::Request) -> http::Request { + assert_eq!( + req.headers().get("grpc-encoding").unwrap(), + self.encoding.as_str() + ); + + req + } } tokio::spawn({ @@ -24,7 +47,9 @@ async fn client_enabled_server_enabled() { Server::builder() .layer( ServiceBuilder::new() - .map_request(assert_right_encoding) + .map_request(move |req| { + AssertRightEncoding::new(encoding).clone().call(req) + }) .layer(measure_request_body_size_layer( request_bytes_counter.clone(), )) @@ -44,8 +69,8 @@ async fn client_enabled_server_enabled() { }); let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip) - .accept_compressed(CompressionEncoding::Gzip); + .send_compressed(encoding) + .accept_compressed(encoding); let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); @@ -56,7 +81,10 @@ async fn client_enabled_server_enabled() { .await .unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + assert_eq!( + res.metadata().get("grpc-encoding").unwrap(), + encoding.as_str() + ); let mut stream: Streaming = res.into_inner(); diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs index a749c2b58..de96d7bbc 100644 --- a/tests/compression/src/client_stream.rs +++ b/tests/compression/src/client_stream.rs @@ -1,19 +1,40 @@ use super::*; -use http_body::Body as _; +use http_body::Body; use tonic::codec::CompressionEncoding; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); let request_bytes_counter = Arc::new(AtomicUsize::new(0)); - fn assert_right_encoding(req: http::Request) -> http::Request { - assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); - req + #[derive(Clone)] + pub struct AssertRightEncoding { + encoding: CompressionEncoding, + } + + #[allow(dead_code)] + impl AssertRightEncoding { + pub fn new(encoding: CompressionEncoding) -> Self { + Self { encoding } + } + + pub fn call(self, req: http::Request) -> http::Request { + assert_eq!( + req.headers().get("grpc-encoding").unwrap(), + self.encoding.as_str() + ); + + req + } } tokio::spawn({ @@ -22,7 +43,9 @@ async fn client_enabled_server_enabled() { Server::builder() .layer( ServiceBuilder::new() - .map_request(assert_right_encoding) + .map_request(move |req| { + AssertRightEncoding::new(encoding).clone().call(req) + }) .layer(measure_request_body_size_layer( request_bytes_counter.clone(), )) @@ -35,8 +58,8 @@ async fn client_enabled_server_enabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); @@ -48,12 +71,17 @@ async fn client_enabled_server_enabled() { assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_disabled_server_enabled() { +util::parametrized_tests! { + client_disabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_disabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); let request_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -93,8 +121,14 @@ async fn client_disabled_server_enabled() { assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_disabled() { +util::parametrized_tests! { + client_enabled_server_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_disabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()); @@ -107,8 +141,8 @@ async fn client_enabled_server_disabled() { .unwrap(); }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); @@ -119,16 +153,24 @@ async fn client_enabled_server_disabled() { assert_eq!(status.code(), tonic::Code::Unimplemented); assert_eq!( status.message(), - "Content is compressed with `gzip` which isn't supported" + format!( + "Content is compressed with `{}` which isn't supported", + encoding.as_str() + ) ); } -#[tokio::test(flavor = "multi_thread")] -async fn compressing_response_from_client_stream() { +util::parametrized_tests! { + compressing_response_from_client_stream, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn compressing_response_from_client_stream(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -153,14 +195,17 @@ async fn compressing_response_from_client_stream() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let stream = futures::stream::iter(vec![]); let req = Request::new(Box::pin(stream)); let res = client.compress_output_client_stream(req).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + assert_eq!( + res.metadata().get("grpc-encoding").unwrap(), + encoding.as_str() + ); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index dd0536091..5d5bb86fe 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -1,19 +1,40 @@ use super::*; -use http_body::Body as _; +use http_body::Body; use tonic::codec::CompressionEncoding; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); let request_bytes_counter = Arc::new(AtomicUsize::new(0)); - fn assert_right_encoding(req: http::Request) -> http::Request { - assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); - req + #[derive(Clone)] + pub struct AssertRightEncoding { + encoding: CompressionEncoding, + } + + #[allow(dead_code)] + impl AssertRightEncoding { + pub fn new(encoding: CompressionEncoding) -> Self { + Self { encoding } + } + + pub fn call(self, req: http::Request) -> http::Request { + assert_eq!( + req.headers().get("grpc-encoding").unwrap(), + self.encoding.as_str() + ); + + req + } } tokio::spawn({ @@ -24,7 +45,9 @@ async fn client_enabled_server_enabled() { ServiceBuilder::new() .layer( ServiceBuilder::new() - .map_request(assert_right_encoding) + .map_request(move |req| { + AssertRightEncoding::new(encoding).clone().call(req) + }) .layer(measure_request_body_size_layer(request_bytes_counter)) .into_inner(), ) @@ -37,8 +60,8 @@ async fn client_enabled_server_enabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); for _ in 0..3 { client @@ -52,8 +75,14 @@ async fn client_enabled_server_enabled() { } } -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_disabled() { +parametrized_tests! { + client_enabled_server_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_disabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()); @@ -66,8 +95,8 @@ async fn client_enabled_server_disabled() { .unwrap(); }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); let status = client .compress_input_unary(SomeData { @@ -79,7 +108,10 @@ async fn client_enabled_server_disabled() { assert_eq!(status.code(), tonic::Code::Unimplemented); assert_eq!( status.message(), - "Content is compressed with `gzip` which isn't supported" + format!( + "Content is compressed with `{}` which isn't supported", + encoding.as_str() + ) ); assert_eq!( @@ -87,13 +119,17 @@ async fn client_enabled_server_disabled() { "identity" ); } +parametrized_tests! { + client_mark_compressed_without_header_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} -#[tokio::test(flavor = "multi_thread")] -async fn client_mark_compressed_without_header_server_enabled() { +#[allow(dead_code)] +async fn client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); tokio::spawn({ async move { diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index 5c1cb9fa9..b5a43f284 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -1,12 +1,21 @@ use super::*; use tonic::codec::CompressionEncoding; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); #[derive(Clone, Copy)] - struct AssertCorrectAcceptEncoding(S); + struct AssertCorrectAcceptEncoding { + service: S, + encoding: CompressionEncoding, + } impl Service> for AssertCorrectAcceptEncoding where @@ -20,20 +29,23 @@ async fn client_enabled_server_enabled() { &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.0.poll_ready(cx) + self.service.poll_ready(cx) } fn call(&mut self, req: http::Request) -> Self::Future { assert_eq!( - req.headers().get("grpc-accept-encoding").unwrap(), - "gzip,identity" + req.headers() + .get("grpc-accept-encoding") + .unwrap() + .to_str() + .unwrap(), + format!("{},identity", self.encoding.as_str()), ); - self.0.call(req) + self.service.call(req) } } - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -43,7 +55,10 @@ async fn client_enabled_server_enabled() { Server::builder() .layer( ServiceBuilder::new() - .layer(layer_fn(AssertCorrectAcceptEncoding)) + .layer(layer_fn(|service| AssertCorrectAcceptEncoding { + service, + encoding, + })) .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, @@ -59,19 +74,28 @@ async fn client_enabled_server_enabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); for _ in 0..3 { let res = client.compress_output_unary(()).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + assert_eq!( + res.metadata().get("grpc-encoding").unwrap(), + encoding.as_str() + ); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } } -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_disabled() { +util::parametrized_tests! { + client_enabled_server_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_disabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()); @@ -100,8 +124,8 @@ async fn client_enabled_server_disabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_unary(()).await.unwrap(); @@ -111,8 +135,14 @@ async fn client_enabled_server_disabled() { assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_disabled() { +util::parametrized_tests! { + client_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_disabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); #[derive(Clone, Copy)] @@ -139,8 +169,7 @@ async fn client_disabled() { } } - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -176,12 +205,17 @@ async fn client_disabled() { assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn server_replying_with_unsupported_encoding() { +util::parametrized_tests! { + server_replying_with_unsupported_encoding, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn server_replying_with_unsupported_encoding(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); fn add_weird_content_encoding(mut response: http::Response) -> http::Response { response @@ -203,8 +237,8 @@ async fn server_replying_with_unsupported_encoding() { .unwrap(); }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let status: Status = client.compress_output_unary(()).await.unwrap_err(); assert_eq!(status.code(), tonic::Code::Unimplemented); @@ -214,14 +248,20 @@ async fn server_replying_with_unsupported_encoding() { ); } -#[tokio::test(flavor = "multi_thread")] -async fn disabling_compression_on_single_response() { +util::parametrized_tests! { + disabling_compression_on_single_response, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn disabling_compression_on_single_response(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc { disable_compressing_on_response: true, }) - .send_compressed(CompressionEncoding::Gzip); + .send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -246,23 +286,34 @@ async fn disabling_compression_on_single_response() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_unary(()).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + assert_eq!( + res.metadata().get("grpc-encoding").unwrap(), + encoding.as_str() + ); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn disabling_compression_on_response_but_keeping_compression_on_stream() { +util::parametrized_tests! { + disabling_compression_on_response_but_keeping_compression_on_stream, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn disabling_compression_on_response_but_keeping_compression_on_stream( + encoding: CompressionEncoding, +) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc { disable_compressing_on_response: true, }) - .send_compressed(CompressionEncoding::Gzip); + .send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -287,12 +338,15 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_server_stream(()).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + assert_eq!( + res.metadata().get("grpc-encoding").unwrap(), + encoding.as_str() + ); let mut stream: Streaming = res.into_inner(); @@ -311,14 +365,20 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn disabling_compression_on_response_from_client_stream() { +util::parametrized_tests! { + disabling_compression_on_response_from_client_stream, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn disabling_compression_on_response_from_client_stream(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc { disable_compressing_on_response: true, }) - .send_compressed(CompressionEncoding::Gzip); + .send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -343,14 +403,17 @@ async fn disabling_compression_on_response_from_client_stream() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let stream = futures::stream::iter(vec![]); let req = Request::new(Box::pin(stream)); let res = client.compress_output_client_stream(req).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + assert_eq!( + res.metadata().get("grpc-encoding").unwrap(), + encoding.as_str() + ); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs index 2ec52bb08..6f2b3c593 100644 --- a/tests/compression/src/server_stream.rs +++ b/tests/compression/src/server_stream.rs @@ -2,12 +2,17 @@ use super::*; use tonic::codec::CompressionEncoding; use tonic::Streaming; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -32,12 +37,15 @@ async fn client_enabled_server_enabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_server_stream(()).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + assert_eq!( + res.metadata().get("grpc-encoding").unwrap(), + encoding.as_str() + ); let mut stream: Streaming = res.into_inner(); @@ -56,12 +64,17 @@ async fn client_enabled_server_enabled() { assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_disabled_server_enabled() { +util::parametrized_tests! { + client_disabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_disabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -102,8 +115,14 @@ async fn client_disabled_server_enabled() { assert!(response_bytes_counter.load(SeqCst) > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_disabled() { +util::parametrized_tests! { + client_enabled_server_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_disabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()); @@ -131,8 +150,8 @@ async fn client_enabled_server_disabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_server_stream(()).await.unwrap(); diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 34fdc0f3a..8a03a7c0b 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -12,9 +12,26 @@ use std::{ task::{Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tonic::codec::CompressionEncoding; use tonic::transport::{server::Connected, Channel}; use tower_http::map_request_body::MapRequestBodyLayer; +macro_rules! parametrized_tests { + ($fn_name:ident, $($test_name:ident: $input:expr),+ $(,)?) => { + paste::paste! { + $( + #[tokio::test(flavor = "multi_thread")] + async fn [<$fn_name _ $test_name>]() { + let input = $input; + $fn_name(input).await; + } + )+ + } + } +} + +pub(crate) use parametrized_tests; + /// A body that tracks how many bytes passes through it #[pin_project] pub struct CountBytesBody { @@ -100,3 +117,24 @@ pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { .await .unwrap() } + +#[derive(Clone)] +pub struct AssertRightEncoding { + encoding: CompressionEncoding, +} + +#[allow(dead_code)] +impl AssertRightEncoding { + pub fn new(encoding: CompressionEncoding) -> Self { + Self { encoding } + } + + pub fn call(self, req: http::Request) -> http::Request { + assert_eq!( + req.headers().get("grpc-encoding").unwrap(), + self.encoding.as_str() + ); + + req + } +} diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 9f4184af5..aead80b79 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -92,7 +92,7 @@ impl CompressionEncoding { map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Option { - if !enabled_encodings.is_gzip_enabled() { + if !enabled_encodings.is_gzip_enabled() && !enabled_encodings.is_zstd_enabled() { return None; } @@ -153,15 +153,20 @@ impl CompressionEncoding { } } - pub(crate) fn into_header_value(self) -> http::HeaderValue { + #[allow(missing_docs)] + pub fn as_str(&self) -> &'static str { match self { #[cfg(feature = "gzip")] - CompressionEncoding::Gzip => http::HeaderValue::from_static("gzip"), + CompressionEncoding::Gzip => "gzip", #[cfg(feature = "zstd")] - CompressionEncoding::Zstd => http::HeaderValue::from_static("zstd"), + CompressionEncoding::Zstd => "zstd", } } + pub(crate) fn into_header_value(self) -> http::HeaderValue { + http::HeaderValue::from_static(self.as_str()) + } + pub(crate) fn encodings() -> &'static [Self] { &[ #[cfg(feature = "gzip")] From 9ea070e4cd951482fce0f4a91b90bfe4d949dc96 Mon Sep 17 00:00:00 2001 From: martinabeleda Date: Sun, 23 Apr 2023 20:00:29 -0700 Subject: [PATCH 3/9] add tests for accepting multiple encodings --- tests/compression/src/compressing_request.rs | 60 +++++++++++++++++++ tests/compression/src/compressing_response.rs | 42 +++++++++++++ tonic/src/codec/compression.rs | 19 +++--- 3 files changed, 110 insertions(+), 11 deletions(-) diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 5d5bb86fe..194403810 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -75,6 +75,66 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { } } +util::parametrized_tests! { + client_enabled_server_enabled_multi_encoding, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding) { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + + let svc = test_server::TestServer::new(Svc::default()) + .accept_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Zstd); + + let request_bytes_counter = Arc::new(AtomicUsize::new(0)); + + fn assert_right_encoding(req: http::Request) -> http::Request { + let supported_encodings = ["gzip", "zstd"]; + let req_encoding = req.headers().get("grpc-encoding").unwrap(); + assert!(supported_encodings.iter().any(|e| e == req_encoding)); + + req + } + + tokio::spawn({ + let request_bytes_counter = request_bytes_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer( + ServiceBuilder::new() + .map_request(assert_right_encoding) + .layer(measure_request_body_size_layer(request_bytes_counter)) + .into_inner(), + ) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)])) + .await + .unwrap(); + } + }); + + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); + + for _ in 0..3 { + client + .compress_input_unary(SomeData { + data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), + }) + .await + .unwrap(); + let bytes_sent = request_bytes_counter.load(SeqCst); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); + } +} + parametrized_tests! { client_enabled_server_disabled, zstd: CompressionEncoding::Zstd, diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index b5a43f284..43fa7c7fc 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -135,6 +135,48 @@ async fn client_enabled_server_disabled(encoding: CompressionEncoding) { assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_disabled_multi_encoding() { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + + let svc = test_server::TestServer::new(Svc::default()); + + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let response_bytes_counter = response_bytes_counter.clone(); + async move { + Server::builder() + // no compression enable on the server so responses should not be compressed + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: response_bytes_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)])) + .await + .unwrap(); + } + }); + + let mut client = test_client::TestClient::new(mock_io_channel(client).await) + .accept_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Zstd); + + let res = client.compress_output_unary(()).await.unwrap(); + + assert!(res.metadata().get("grpc-encoding").is_none()); + + let bytes_sent = response_bytes_counter.load(SeqCst); + assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); +} + util::parametrized_tests! { client_disabled, zstd: CompressionEncoding::Zstd, diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index aead80b79..8e9cfbc1d 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -4,9 +4,6 @@ use bytes::{Buf, BytesMut}; #[cfg(feature = "gzip")] use flate2::read::{GzDecoder, GzEncoder}; use std::fmt; -use std::io::Write; -#[cfg(feature = "zstd")] -use zstd::{Decoder, Encoder}; pub(crate) const ENCODING_HEADER: &str = "grpc-encoding"; pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; @@ -219,10 +216,12 @@ pub(crate) fn compress( #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { let out_writer = bytes::BufMut::writer(out_buf); - let mut zstd_encoder = Encoder::new(out_writer, 0)?; - - zstd_encoder.write_all(&decompressed_buf[0..len])?; - zstd_encoder.finish()?; + zstd::stream::copy_encode( + &decompressed_buf[0..len], + out_writer, + // Use zstd's default level + 0, + )?; } } @@ -253,10 +252,8 @@ pub(crate) fn decompress( } #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { - let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; - let mut out_writer = bytes::BufMut::writer(out_buf); - - std::io::copy(&mut zstd_decoder, &mut out_writer)?; + let out_writer = bytes::BufMut::writer(out_buf); + zstd::stream::copy_decode(&compressed_buf[0..len], out_writer)?; } } From 6a71130940373443aa414fc05892839b82563068 Mon Sep 17 00:00:00 2001 From: martinabeleda Date: Sun, 23 Apr 2023 20:32:13 -0700 Subject: [PATCH 4/9] add some missing feature cfg for zstd --- tonic/src/codec/compression.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 8e9cfbc1d..9d5bb259a 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -64,7 +64,7 @@ impl EnabledCompressionEncodings { } #[cfg(not(feature = "zstd"))] - const fn is_gzip_enabled(&self) -> bool { + const fn is_zstd_enabled(&self) -> bool { false } } @@ -127,6 +127,7 @@ impl CompressionEncoding { "gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => { Ok(Some(CompressionEncoding::Gzip)) } + #[cfg(feature = "zstd")] "zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => { Ok(Some(CompressionEncoding::Zstd)) } From 9fb2b4759d5d42245fe800c84c1963e50204f091 Mon Sep 17 00:00:00 2001 From: martinabeleda Date: Mon, 24 Apr 2023 17:02:04 -0700 Subject: [PATCH 5/9] make as_str only crate public --- tests/compression/src/bidirectional_stream.rs | 19 ++++--- tests/compression/src/client_stream.rs | 26 ++++++---- tests/compression/src/compressing_request.rs | 16 ++++-- tests/compression/src/compressing_response.rs | 50 ++++++++++++------- tests/compression/src/server_stream.rs | 10 ++-- tonic/src/codec/compression.rs | 2 +- 6 files changed, 79 insertions(+), 44 deletions(-) diff --git a/tests/compression/src/bidirectional_stream.rs b/tests/compression/src/bidirectional_stream.rs index eca3a8492..e1c1c6896 100644 --- a/tests/compression/src/bidirectional_stream.rs +++ b/tests/compression/src/bidirectional_stream.rs @@ -31,10 +31,11 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { } pub fn call(self, req: http::Request) -> http::Request { - assert_eq!( - req.headers().get("grpc-encoding").unwrap(), - self.encoding.as_str() - ); + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + }; + assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); req } @@ -81,10 +82,12 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { .await .unwrap(); - assert_eq!( - res.metadata().get("grpc-encoding").unwrap(), - encoding.as_str() - ); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let mut stream: Streaming = res.into_inner(); diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs index de96d7bbc..265ef5944 100644 --- a/tests/compression/src/client_stream.rs +++ b/tests/compression/src/client_stream.rs @@ -28,10 +28,11 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { } pub fn call(self, req: http::Request) -> http::Request { - assert_eq!( - req.headers().get("grpc-encoding").unwrap(), - self.encoding.as_str() - ); + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + }; + assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); req } @@ -151,11 +152,16 @@ async fn client_enabled_server_disabled(encoding: CompressionEncoding) { let status = client.compress_input_client_stream(req).await.unwrap_err(); assert_eq!(status.code(), tonic::Code::Unimplemented); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; assert_eq!( status.message(), format!( "Content is compressed with `{}` which isn't supported", - encoding.as_str() + expected ) ); } @@ -202,10 +208,12 @@ async fn compressing_response_from_client_stream(encoding: CompressionEncoding) let req = Request::new(Box::pin(stream)); let res = client.compress_output_client_stream(req).await.unwrap(); - assert_eq!( - res.metadata().get("grpc-encoding").unwrap(), - encoding.as_str() - ); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 194403810..a95b117a0 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -28,10 +28,11 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { } pub fn call(self, req: http::Request) -> http::Request { - assert_eq!( - req.headers().get("grpc-encoding").unwrap(), - self.encoding.as_str() - ); + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + }; + assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); req } @@ -166,11 +167,16 @@ async fn client_enabled_server_disabled(encoding: CompressionEncoding) { .unwrap_err(); assert_eq!(status.code(), tonic::Code::Unimplemented); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; assert_eq!( status.message(), format!( "Content is compressed with `{}` which isn't supported", - encoding.as_str() + expected ) ); diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index 43fa7c7fc..40638d92a 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -33,13 +33,17 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { } fn call(&mut self, req: http::Request) -> Self::Future { + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + }; assert_eq!( req.headers() .get("grpc-accept-encoding") .unwrap() .to_str() .unwrap(), - format!("{},identity", self.encoding.as_str()), + format!("{},identity", expected) ); self.service.call(req) } @@ -77,12 +81,15 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + for _ in 0..3 { let res = client.compress_output_unary(()).await.unwrap(); - assert_eq!( - res.metadata().get("grpc-encoding").unwrap(), - encoding.as_str() - ); + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } @@ -332,10 +339,14 @@ async fn disabling_compression_on_single_response(encoding: CompressionEncoding) test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_unary(()).await.unwrap(); - assert_eq!( - res.metadata().get("grpc-encoding").unwrap(), - encoding.as_str() - ); + + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); + let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } @@ -385,10 +396,12 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream( let res = client.compress_output_server_stream(()).await.unwrap(); - assert_eq!( - res.metadata().get("grpc-encoding").unwrap(), - encoding.as_str() - ); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let mut stream: Streaming = res.into_inner(); @@ -452,10 +465,13 @@ async fn disabling_compression_on_response_from_client_stream(encoding: Compress let req = Request::new(Box::pin(stream)); let res = client.compress_output_client_stream(req).await.unwrap(); - assert_eq!( - res.metadata().get("grpc-encoding").unwrap(), - encoding.as_str() - ); + + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs index 6f2b3c593..d14372d88 100644 --- a/tests/compression/src/server_stream.rs +++ b/tests/compression/src/server_stream.rs @@ -42,10 +42,12 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let res = client.compress_output_server_stream(()).await.unwrap(); - assert_eq!( - res.metadata().get("grpc-encoding").unwrap(), - encoding.as_str() - ); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let mut stream: Streaming = res.into_inner(); diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 9d5bb259a..2df286146 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -152,7 +152,7 @@ impl CompressionEncoding { } #[allow(missing_docs)] - pub fn as_str(&self) -> &'static str { + pub(crate) fn as_str(&self) -> &'static str { match self { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => "gzip", From 6ce81c511831f0ae08484b1c0da07ba4e4e35929 Mon Sep 17 00:00:00 2001 From: martinabeleda Date: Mon, 19 Jun 2023 19:02:25 -0700 Subject: [PATCH 6/9] make into_accept_encoding_header_value handle all combinations --- tests/compression/src/bidirectional_stream.rs | 1 + tests/compression/src/client_stream.rs | 1 + tests/compression/src/compressing_request.rs | 1 + tests/compression/src/compressing_response.rs | 1 + tests/compression/src/util.rs | 10 ++++++---- tonic/src/codec/compression.rs | 11 +++++------ 6 files changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/compression/src/bidirectional_stream.rs b/tests/compression/src/bidirectional_stream.rs index e1c1c6896..d612ecf6c 100644 --- a/tests/compression/src/bidirectional_stream.rs +++ b/tests/compression/src/bidirectional_stream.rs @@ -34,6 +34,7 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let expected = match self.encoding { CompressionEncoding::Gzip => "gzip", CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), }; assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs index 265ef5944..344747b9b 100644 --- a/tests/compression/src/client_stream.rs +++ b/tests/compression/src/client_stream.rs @@ -31,6 +31,7 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let expected = match self.encoding { CompressionEncoding::Gzip => "gzip", CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), }; assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index a95b117a0..ef06d6986 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -31,6 +31,7 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let expected = match self.encoding { CompressionEncoding::Gzip => "gzip", CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), }; assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index 40638d92a..49808f530 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -36,6 +36,7 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let expected = match self.encoding { CompressionEncoding::Gzip => "gzip", CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), }; assert_eq!( req.headers() diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 8a03a7c0b..ca9fd22d9 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -130,10 +130,12 @@ impl AssertRightEncoding { } pub fn call(self, req: http::Request) -> http::Request { - assert_eq!( - req.headers().get("grpc-encoding").unwrap(), - self.encoding.as_str() - ); + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), + }; + assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); req } diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 2df286146..fcdc3ade5 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -39,12 +39,11 @@ impl EnabledCompressionEncodings { } pub(crate) fn into_accept_encoding_header_value(self) -> Option { - if self.is_gzip_enabled() { - Some(http::HeaderValue::from_static("gzip,identity")) - } else if self.is_zstd_enabled() { - Some(http::HeaderValue::from_static("zstd,identity")) - } else { - None + match (self.is_gzip_enabled(), self.is_zstd_enabled()) { + (true, false) => Some(http::HeaderValue::from_static("gzip,identity")), + (false, true) => Some(http::HeaderValue::from_static("zstd,identity")), + (true, true) => Some(http::HeaderValue::from_static("gzip,zstd,identity")), + (false, false) => None, } } From 384e95ec35c0cd5d1a1a8e7c8411db9ee29c2973 Mon Sep 17 00:00:00 2001 From: martinabeleda Date: Mon, 19 Jun 2023 19:23:29 -0700 Subject: [PATCH 7/9] make decompress implementation consistent --- tonic/src/codec/compression.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index fcdc3ade5..b73be7843 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -252,8 +252,10 @@ pub(crate) fn decompress( } #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { - let out_writer = bytes::BufMut::writer(out_buf); - zstd::stream::copy_decode(&compressed_buf[0..len], out_writer)?; + let mut zstd_decoder = zstd::Decoder::new(&compressed_buf[0..len])?; + let mut out_writer = bytes::BufMut::writer(out_buf); + + std::io::copy(&mut zstd_decoder, &mut out_writer)?; } } From 08393243a2f1e7dc42dd53bf509ee915cb6f43aa Mon Sep 17 00:00:00 2001 From: martinabeleda Date: Fri, 23 Jun 2023 20:24:20 -0700 Subject: [PATCH 8/9] use zstd::stream::read::Encoder --- tonic/src/codec/compression.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index b73be7843..168b9847b 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -4,6 +4,8 @@ use bytes::{Buf, BytesMut}; #[cfg(feature = "gzip")] use flate2::read::{GzDecoder, GzEncoder}; use std::fmt; +#[cfg(feature = "zstd")] +use zstd::stream::read::{Decoder, Encoder}; pub(crate) const ENCODING_HEADER: &str = "grpc-encoding"; pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; @@ -215,13 +217,14 @@ pub(crate) fn compress( } #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { - let out_writer = bytes::BufMut::writer(out_buf); - zstd::stream::copy_encode( + let mut zstd_encoder = Encoder::new( &decompressed_buf[0..len], - out_writer, - // Use zstd's default level + // FIXME: support customizing the compression level 0, )?; + let mut out_writer = bytes::BufMut::writer(out_buf); + + std::io::copy(&mut zstd_encoder, &mut out_writer)?; } } @@ -252,7 +255,7 @@ pub(crate) fn decompress( } #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { - let mut zstd_decoder = zstd::Decoder::new(&compressed_buf[0..len])?; + let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; let mut out_writer = bytes::BufMut::writer(out_buf); std::io::copy(&mut zstd_decoder, &mut out_writer)?; From 12653d0131766aa1e977e9df93fd6853f2e6fb34 Mon Sep 17 00:00:00 2001 From: martinabeleda Date: Mon, 26 Jun 2023 19:40:27 -0300 Subject: [PATCH 9/9] use default compression level for zstd --- tonic/src/codec/compression.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 168b9847b..bf94ca3fd 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -202,6 +202,7 @@ pub(crate) fn compress( ) -> Result<(), std::io::Error> { let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; out_buf.reserve(capacity); + let mut out_writer = bytes::BufMut::writer(out_buf); match encoding { #[cfg(feature = "gzip")] @@ -211,8 +212,6 @@ pub(crate) fn compress( // FIXME: support customizing the compression level flate2::Compression::new(6), ); - let mut out_writer = bytes::BufMut::writer(out_buf); - std::io::copy(&mut gzip_encoder, &mut out_writer)?; } #[cfg(feature = "zstd")] @@ -220,10 +219,8 @@ pub(crate) fn compress( let mut zstd_encoder = Encoder::new( &decompressed_buf[0..len], // FIXME: support customizing the compression level - 0, + zstd::DEFAULT_COMPRESSION_LEVEL, )?; - let mut out_writer = bytes::BufMut::writer(out_buf); - std::io::copy(&mut zstd_encoder, &mut out_writer)?; } } @@ -244,20 +241,17 @@ pub(crate) fn decompress( let estimate_decompressed_len = len * 2; let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE; out_buf.reserve(capacity); + let mut out_writer = bytes::BufMut::writer(out_buf); match encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); - let mut out_writer = bytes::BufMut::writer(out_buf); - std::io::copy(&mut gzip_decoder, &mut out_writer)?; } #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; - let mut out_writer = bytes::BufMut::writer(out_buf); - std::io::copy(&mut zstd_decoder, &mut out_writer)?; } }