From aab9a79d367363fecf739626ebd973e0d9140132 Mon Sep 17 00:00:00 2001 From: Mat Trudel Date: Mon, 15 Jan 2024 11:37:17 -0500 Subject: [PATCH] Move stream send window handling into stream process --- lib/bandit/http2/adapter.ex | 56 +++++++++++++++++++++------ lib/bandit/http2/connection.ex | 34 ++++++---------- lib/bandit/http2/handler.ex | 11 ++++-- lib/bandit/http2/stream.ex | 54 +++++++++++--------------- lib/bandit/http2/stream_collection.ex | 30 ++------------ lib/bandit/http2/stream_process.ex | 5 +++ test/bandit/http2/protocol_test.exs | 6 +-- 7 files changed, 99 insertions(+), 97 deletions(-) diff --git a/lib/bandit/http2/adapter.ex b/lib/bandit/http2/adapter.ex index 80112813..75e71a45 100644 --- a/lib/bandit/http2/adapter.ex +++ b/lib/bandit/http2/adapter.ex @@ -12,6 +12,7 @@ defmodule Bandit.HTTP2.Adapter do stream_id: nil, end_stream: false, recv_window_size: 65_535, + send_window_size: nil, method: nil, content_encoding: nil, pending_content_length: nil, @@ -26,6 +27,7 @@ defmodule Bandit.HTTP2.Adapter do stream_id: Bandit.HTTP2.Stream.stream_id(), end_stream: boolean(), recv_window_size: non_neg_integer(), + send_window_size: non_neg_integer(), method: Plug.Conn.method() | nil, content_encoding: String.t() | nil, pending_content_length: non_neg_integer() | nil, @@ -33,11 +35,12 @@ defmodule Bandit.HTTP2.Adapter do opts: keyword() } - def init(connection, transport_info, stream_id, opts) do + def init(connection, transport_info, stream_id, send_window_size, opts) do %__MODULE__{ connection: connection, transport_info: transport_info, stream_id: stream_id, + send_window_size: send_window_size, opts: opts } end @@ -280,8 +283,7 @@ defmodule Bandit.HTTP2.Adapter do # stream will have been closed in send_chunked/3 above, and so this call will return an # `{:error, :not_owner}` error here (which we ignore, but it's still kinda odd) validate_calling_process!(adapter) - _ = send_data(adapter, chunk, IO.iodata_length(chunk) == 0) - :ok + {:ok, nil, send_data(adapter, chunk, IO.iodata_length(chunk) == 0)} end @impl Plug.Conn.Adapter @@ -334,17 +336,47 @@ defmodule Bandit.HTTP2.Adapter do end defp send_data(adapter, data, end_stream) do - GenServer.call( - adapter.connection, - {:send_data, adapter.stream_id, data, end_stream}, - :infinity - ) + max_bytes_to_send = max(adapter.send_window_size, 0) + {data_to_send, bytes_to_send, rest} = split_data(data, max_bytes_to_send) - metrics = - adapter.metrics - |> Map.update(:resp_body_bytes, IO.iodata_length(data), &(&1 + IO.iodata_length(data))) + adapter = + if end_stream || bytes_to_send > 0 do + GenServer.call( + adapter.connection, + {:send_data, adapter.stream_id, data_to_send, end_stream && byte_size(rest) == 0}, + :infinity + ) - %{adapter | metrics: metrics} + metrics = + adapter.metrics |> Map.update(:resp_body_bytes, bytes_to_send, &(&1 + bytes_to_send)) + + %{adapter | metrics: metrics, send_window_size: adapter.send_window_size - bytes_to_send} + else + adapter + end + + if byte_size(rest) == 0 do + adapter + else + adapter = + receive do + {:send_window_update, delta} -> + %{adapter | send_window_size: adapter.send_window_size + delta} + end + + send_data(adapter, rest, end_stream) + end + end + + defp split_data(data, desired_length) do + data_length = IO.iodata_length(data) + + if data_length <= desired_length do + {data, data_length, <<>>} + else + <> = IO.iodata_to_binary(data) + {to_send, desired_length, rest} + end end defp split_cookies(headers) do diff --git a/lib/bandit/http2/connection.ex b/lib/bandit/http2/connection.ex index d344179d..37dbae57 100644 --- a/lib/bandit/http2/connection.ex +++ b/lib/bandit/http2/connection.ex @@ -134,16 +134,16 @@ defmodule Bandit.HTTP2.Connection do def handle_frame(%Frame.Settings{ack: false} = frame, socket, connection) do _ = %Frame.Settings{ack: true} |> send_frame(socket, connection) - streams = - connection.streams - |> StreamCollection.update_initial_send_window_size(frame.settings.initial_window_size) - send_hpack_state = HPAX.resize(connection.send_hpack_state, frame.settings.header_table_size) + delta = frame.settings.initial_window_size - connection.remote_settings.initial_window_size + + StreamCollection.get_streams(connection.streams) + |> Enum.map(&Stream.recv_send_window_update(elem(&1, 1), delta)) + do_pending_sends(socket, %{ connection | remote_settings: frame.settings, - streams: streams, send_hpack_state: send_hpack_state }) end @@ -175,20 +175,10 @@ defmodule Bandit.HTTP2.Connection do # Stream-level receiving # - def handle_frame(%Frame.WindowUpdate{} = frame, socket, connection) do + def handle_frame(%Frame.WindowUpdate{} = frame, _socket, connection) do with {:ok, stream} <- StreamCollection.get_stream(connection.streams, frame.stream_id), - {:ok, stream} <- Stream.recv_window_update(stream, frame.size_increment), - {:ok, streams} <- StreamCollection.put_stream(connection.streams, stream) do - do_pending_sends(socket, %{connection | streams: streams}) - else - {:error, {:connection, error_code, error_message}} -> - shutdown_connection(error_code, error_message, socket, connection) - - {:error, {:stream, stream_id, error_code, error_message}} -> - handle_stream_error(stream_id, error_code, error_message, socket, connection) - - {:error, error} -> - shutdown_connection(Errors.internal_error(), error, socket, connection) + {:ok, _stream} <- Stream.recv_send_window_update(stream, frame.size_increment) do + {:continue, connection} end end @@ -287,6 +277,7 @@ defmodule Bandit.HTTP2.Connection do stream, connection.transport_info, connection.telemetry_span, + connection.remote_settings.initial_window_size, headers, end_stream, connection.plug, @@ -418,17 +409,16 @@ defmodule Bandit.HTTP2.Connection do {:ok, t()} | {:error, term()} def send_data(stream_id, data, end_stream, on_unblock, socket, connection) do with {:ok, stream} <- StreamCollection.get_stream(connection.streams, stream_id), - stream_window_size <- Stream.get_send_window_size(stream), connection_window_size <- connection.send_window_size, - max_bytes_to_send <- max(min(stream_window_size, connection_window_size), 0), + max_bytes_to_send <- max(connection_window_size, 0), {data_to_send, bytes_to_send, rest} <- split_data(data, max_bytes_to_send), - {:ok, stream} <- Stream.send_data(stream, bytes_to_send), + {:ok, stream} <- Stream.send_data(stream), connection <- %{connection | send_window_size: connection_window_size - bytes_to_send}, end_stream_to_send <- end_stream && byte_size(rest) == 0, {:ok, stream} <- Stream.send_end_of_stream(stream, end_stream_to_send), {:ok, streams} <- StreamCollection.put_stream(connection.streams, stream) do _ = - if end_stream_to_send || IO.iodata_length(data_to_send) > 0 do + if end_stream_to_send || bytes_to_send > 0 do %Frame.Data{stream_id: stream_id, end_stream: end_stream_to_send, data: data_to_send} |> send_frame(socket, connection) end diff --git a/lib/bandit/http2/handler.ex b/lib/bandit/http2/handler.ex index 3edaf803..e8734988 100644 --- a/lib/bandit/http2/handler.ex +++ b/lib/bandit/http2/handler.ex @@ -78,11 +78,16 @@ defmodule Bandit.HTTP2.Handler do def handle_call({:send_data, stream_id, data, end_stream}, from, {socket, state}) do # In 'normal' cases where there is sufficient space in the send windows for this message to be # sent, Connection will call `unblock` synchronously in the `Connection.send_data` call below. - # In cases where there is not enough space in either / both windows, Connection will call - # `unblock` at some point in the future once space opens up in the relevant window(s). This + # In cases where there is not enough space in the connection window, Connection will call + # `unblock` at some point in the future once space opens up in the window. This # keeps this code simple in that we can blindly send noreply here and let Connection handle # the separate cases. This ensures that we have backpressure all the way back to the - # stream's handler process in the event of window overruns + # stream's handler process in the event of window overruns. + # + # Note that the above only applies to the connection-level send window; stream-level windows + # are managed internally by the stream and are not considered here at all. If the stream has + # managed to send this message, it is because there was enough room in the stream's send + # window to do so. unblock = fn -> GenServer.reply(from, :ok) end case Connection.send_data(stream_id, data, end_stream, unblock, socket, state.connection) do diff --git a/lib/bandit/http2/stream.ex b/lib/bandit/http2/stream.ex index 3a23f048..7699705b 100644 --- a/lib/bandit/http2/stream.ex +++ b/lib/bandit/http2/stream.ex @@ -11,12 +11,11 @@ defmodule Bandit.HTTP2.Stream do require Integer require Logger - alias Bandit.HTTP2.{Connection, Errors, FlowControl, StreamProcess} + alias Bandit.HTTP2.{Connection, Errors, StreamProcess} defstruct stream_id: nil, state: nil, - pid: nil, - send_window_size: nil + pid: nil defmodule StreamError, do: defexception([:message, :method, :request_target, :status]) @@ -33,14 +32,14 @@ defmodule Bandit.HTTP2.Stream do @type t :: %__MODULE__{ stream_id: stream_id(), state: state(), - pid: pid() | nil, - send_window_size: non_neg_integer() + pid: pid() | nil } @spec recv_headers( t(), Bandit.TransportInfo.t(), ThousandIsland.Telemetry.t(), + non_neg_integer(), Plug.Conn.headers(), boolean, Bandit.Pipeline.plug_def(), @@ -50,6 +49,7 @@ defmodule Bandit.HTTP2.Stream do %__MODULE__{state: state} = stream, _transport_info, _connection_span, + _initial_send_window_size, trailers, true, _plug, @@ -68,13 +68,21 @@ defmodule Bandit.HTTP2.Stream do %__MODULE__{state: :idle} = stream, transport_info, connection_span, + initial_send_window_size, headers, _end_stream, plug, opts ) do with :ok <- stream_id_is_valid_client(stream.stream_id), - req <- Bandit.HTTP2.Adapter.init(self(), transport_info, stream.stream_id, opts), + req <- + Bandit.HTTP2.Adapter.init( + self(), + transport_info, + stream.stream_id, + initial_send_window_size, + opts + ), {:ok, pid} <- StreamProcess.start_link(req, transport_info, headers, plug, connection_span) do {:ok, %{stream | state: :open, pid: pid}} @@ -88,6 +96,7 @@ defmodule Bandit.HTTP2.Stream do %__MODULE__{}, _transport_info, _connection_span, + _initial_send_window_size, _headers, _end_stream, _plug, @@ -125,20 +134,15 @@ defmodule Bandit.HTTP2.Stream do {:error, {:connection, Errors.protocol_error(), "Received DATA when in #{stream.state}"}} end - @spec recv_window_update(t(), non_neg_integer()) :: + @spec recv_send_window_update(t(), non_neg_integer()) :: {:ok, t()} | {:error, Connection.error()} | {:error, error()} - def recv_window_update(%__MODULE__{state: :idle}, _increment) do + def recv_send_window_update(%__MODULE__{state: :idle}, _increment) do {:error, {:connection, Errors.protocol_error(), "Received WINDOW_UPDATE when in idle"}} end - def recv_window_update(%__MODULE__{} = stream, increment) do - case FlowControl.update_send_window(stream.send_window_size, increment) do - {:ok, new_window} -> - {:ok, %{stream | send_window_size: new_window}} - - {:error, error} -> - {:error, {:stream, stream.stream_id, Errors.flow_control_error(), error}} - end + def recv_send_window_update(%__MODULE__{} = stream, increment) do + if is_pid(stream.pid), do: StreamProcess.recv_send_window_update(stream.pid, increment) + {:ok, stream} end @spec recv_rst_stream(t(), Errors.error_code()) :: @@ -172,9 +176,6 @@ defmodule Bandit.HTTP2.Stream do {:ok, stream} end - @spec get_send_window_size(t()) :: non_neg_integer() - def get_send_window_size(%__MODULE__{} = stream), do: stream.send_window_size - @spec send_headers(t()) :: {:ok, t()} | {:error, :invalid_state} def send_headers(%__MODULE__{state: state} = stream) when state in [:open, :remote_closed] do {:ok, stream} @@ -184,21 +185,12 @@ defmodule Bandit.HTTP2.Stream do {:error, :invalid_state} end - @spec send_data(t(), non_neg_integer()) :: - {:ok, t()} | {:error, :insufficient_window_size} | {:error, :invalid_state} - def send_data(%__MODULE__{state: state} = stream, 0) when state in [:open, :remote_closed] do + @spec send_data(t()) :: {:ok, t()} | {:error, :invalid_state} + def send_data(%__MODULE__{state: state} = stream) when state in [:open, :remote_closed] do {:ok, stream} end - def send_data(%__MODULE__{state: state} = stream, len) when state in [:open, :remote_closed] do - if len <= stream.send_window_size do - {:ok, %{stream | send_window_size: stream.send_window_size - len}} - else - {:error, :insufficient_window_size} - end - end - - def send_data(%__MODULE__{}, _len) do + def send_data(%__MODULE__{}) do {:error, :invalid_state} end diff --git a/lib/bandit/http2/stream_collection.ex b/lib/bandit/http2/stream_collection.ex index d8cd2da8..be7cbe05 100644 --- a/lib/bandit/http2/stream_collection.ex +++ b/lib/bandit/http2/stream_collection.ex @@ -8,37 +8,20 @@ defmodule Bandit.HTTP2.StreamCollection do alias Bandit.HTTP2.Stream - defstruct initial_send_window_size: 65_535, - last_local_stream_id: 0, + defstruct last_local_stream_id: 0, last_remote_stream_id: 0, stream_count: 0, streams: %{} @typedoc "A collection of Stream structs, accessible by id or pid" @type t :: %__MODULE__{ - initial_send_window_size: non_neg_integer(), last_remote_stream_id: Stream.stream_id(), last_local_stream_id: Stream.stream_id(), streams: %{Stream.stream_id() => Stream.t()} } - @spec update_initial_send_window_size(t(), non_neg_integer()) :: t() - def update_initial_send_window_size(collection, initial_send_window_size) do - delta = initial_send_window_size - collection.initial_send_window_size - - streams = - collection.streams - |> Enum.map(fn - {id, %Stream{state: state} = stream} when state in [:open, :remote_closed] -> - {id, %{stream | send_window_size: stream.send_window_size + delta}} - - {id, stream} -> - {id, stream} - end) - |> Map.new() - - %{collection | streams: streams, initial_send_window_size: initial_send_window_size} - end + @spec get_streams(t()) :: [Stream.t()] + def get_streams(collection), do: collection.streams @spec get_stream(t(), Stream.stream_id()) :: {:ok, Stream.t()} def get_stream(collection, stream_id) do @@ -55,12 +38,7 @@ defmodule Bandit.HTTP2.StreamCollection do :idle end - {:ok, - %Stream{ - stream_id: stream_id, - state: state, - send_window_size: collection.initial_send_window_size - }} + {:ok, %Stream{stream_id: stream_id, state: state}} end end diff --git a/lib/bandit/http2/stream_process.ex b/lib/bandit/http2/stream_process.ex index 5db35d54..c0268d12 100644 --- a/lib/bandit/http2/stream_process.ex +++ b/lib/bandit/http2/stream_process.ex @@ -41,6 +41,11 @@ defmodule Bandit.HTTP2.StreamProcess do @spec recv_data(pid(), iodata()) :: :ok | :noconnect | :nosuspend def recv_data(pid, data), do: send(pid, {:data, data}) + # Let the stream process know that the stream's send window has changed. The other half of this + # flow can be found in `Bandit.HTTP2.Adapter.send_resp/4` and friends + @spec recv_send_window_update(pid(), non_neg_integer()) :: :ok | :noconnect | :nosuspend + def recv_send_window_update(pid, delta), do: send(pid, {:send_window_update, delta}) + # Let the stream process know that the client has set the end of stream flag. The other half of # this flow can be found in `Bandit.HTTP2.Adapter.read_req_body/2` @spec recv_end_of_stream(pid()) :: :ok | :noconnect | :nosuspend diff --git a/test/bandit/http2/protocol_test.exs b/test/bandit/http2/protocol_test.exs index 436b557a..a027493f 100644 --- a/test/bandit/http2/protocol_test.exs +++ b/test/bandit/http2/protocol_test.exs @@ -1918,12 +1918,12 @@ defmodule HTTP2ProtocolTest do assert SimpleH2Client.recv_body(socket) == {:ok, 3, false, String.duplicate("D", 16_383)} # Grow the stream windows such that we expect to see 100 bytes from 1 and 50 bytes from - # 3 (note that 1 is queued at a higher priority than 3 due to FIFO ordering) Also note that - # we receive end_of_stream on stream 1 here - SimpleH2Client.send_window_update(socket, 3, 100) + # 3. Also note that we receive end_of_stream on stream 1 here SimpleH2Client.send_window_update(socket, 1, 100) SimpleH2Client.send_window_update(socket, 0, 150) assert SimpleH2Client.recv_body(socket) == {:ok, 1, true, "d" <> String.duplicate("e", 99)} + + SimpleH2Client.send_window_update(socket, 3, 100) assert SimpleH2Client.recv_body(socket) == {:ok, 3, false, "D" <> String.duplicate("E", 49)} # Finally grow our connection window and verify we get the last of stream 3