diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 4ad2fca1f4..5f7b836d5d 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -32,6 +32,7 @@ const MAX_BUF_LIST_BUFFERS: usize = 16; pub(crate) struct Buffered { flush_pipeline: bool, io: T, + partial_len: Option, read_blocked: bool, read_buf: BytesMut, read_buf_strategy: ReadStrategy, @@ -65,6 +66,7 @@ where Buffered { flush_pipeline: false, io, + partial_len: None, read_blocked: false, read_buf: BytesMut::with_capacity(0), read_buf_strategy: ReadStrategy::default(), @@ -176,6 +178,7 @@ where loop { match super::role::parse_headers::( &mut self.read_buf, + self.partial_len, ParseContext { cached_headers: parse_ctx.cached_headers, req_method: parse_ctx.req_method, @@ -191,14 +194,19 @@ where )? { Some(msg) => { debug!("parsed {} headers", msg.head.headers.len()); + self.partial_len = None; return Poll::Ready(Ok(msg)); } None => { let max = self.read_buf_strategy.max(); - if self.read_buf.len() >= max { + let curr_len = self.read_buf.len(); + if curr_len >= max { debug!("max_buf_size ({}) reached, closing", max); return Poll::Ready(Err(crate::Error::new_too_large())); } + if curr_len > 0 { + self.partial_len = Some(curr_len); + } } } if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index e5a8872111..4f04acec96 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -66,6 +66,7 @@ macro_rules! maybe_panic { pub(super) fn parse_headers( bytes: &mut BytesMut, + prev_len: Option, ctx: ParseContext<'_>, ) -> ParseResult where @@ -78,9 +79,37 @@ where let _entered = trace_span!("parse_headers"); + if let Some(prev_len) = prev_len { + if !is_complete_fast(bytes, prev_len) { + return Ok(None); + } + } + T::parse(bytes, ctx) } +/// A fast scan for the end of a message. +/// Used when there was a partial read, to skip full parsing on a +/// a slow connection. +fn is_complete_fast(bytes: &[u8], prev_len: usize) -> bool { + let start = if prev_len < 3 { 0 } else { prev_len - 3 }; + let bytes = &bytes[start..]; + + for (i, b) in bytes.iter().copied().enumerate() { + if b == b'\r' { + if bytes[i + 1..].chunks(3).next() == Some(&b"\n\r\n"[..]) { + return true; + } + } else if b == b'\n' { + if bytes.get(i + 1) == Some(&b'\n') { + return true; + } + } + } + + false +} + pub(super) fn encode_headers( enc: Encode<'_, T::Outgoing>, dst: &mut Vec, @@ -2827,6 +2856,28 @@ mod tests { parse(Some(200), 210, false); } + #[test] + fn test_is_complete_fast() { + let s = b"GET / HTTP/1.1\r\na: b\r\n\r\n"; + for n in 0..s.len() { + assert!(is_complete_fast(s, n), "{:?}; {}", s, n); + } + let s = b"GET / HTTP/1.1\na: b\n\n"; + for n in 0..s.len() { + assert!(is_complete_fast(s, n)); + } + + // Not + let s = b"GET / HTTP/1.1\r\na: b\r\n\r"; + for n in 0..s.len() { + assert!(!is_complete_fast(s, n)); + } + let s = b"GET / HTTP/1.1\na: b\n"; + for n in 0..s.len() { + assert!(!is_complete_fast(s, n)); + } + } + #[test] fn test_write_headers_orig_case_empty_value() { let mut headers = HeaderMap::new();