From d27d690848e487b8996948e3fadafefdf5a19fa3 Mon Sep 17 00:00:00 2001 From: Mat Trudel Date: Sat, 13 Jan 2024 14:37:06 -0500 Subject: [PATCH] Move reset handling into stream process --- lib/bandit/http2/adapter.ex | 4 ++++ lib/bandit/http2/connection.ex | 21 ++++++++++-------- lib/bandit/http2/handler.ex | 7 +++++- lib/bandit/http2/stream.ex | 12 +++++----- lib/bandit/http2/stream_process.ex | 34 ++++++++++++++++++++++++++--- test/bandit/http2/plug_test.exs | 10 +++++++++ test/bandit/http2/protocol_test.exs | 13 ----------- 7 files changed, 69 insertions(+), 32 deletions(-) diff --git a/lib/bandit/http2/adapter.ex b/lib/bandit/http2/adapter.ex index d782b883..fe397399 100644 --- a/lib/bandit/http2/adapter.ex +++ b/lib/bandit/http2/adapter.ex @@ -349,4 +349,8 @@ defmodule Bandit.HTTP2.Adapter do raise "Adapter functions may only be called by the stream owner" end end + + def send_rst_stream(adapter, error_code) do + GenServer.call(adapter.connection, {:send_rst_stream, adapter.stream_id, error_code}) + end end diff --git a/lib/bandit/http2/connection.ex b/lib/bandit/http2/connection.ex index fd71076f..5fc5cc1b 100644 --- a/lib/bandit/http2/connection.ex +++ b/lib/bandit/http2/connection.ex @@ -460,17 +460,20 @@ defmodule Bandit.HTTP2.Connection do end end - @spec stream_terminated(pid(), term(), Socket.t(), t()) :: {:ok, t()} | {:error, term()} - def stream_terminated(pid, reason, socket, connection) do + @spec send_rst_stream(Stream.stream_id(), Errors.error_code(), Socket.t(), t()) :: :ok + def send_rst_stream(stream_id, error_code, socket, connection) do + _ = + %Frame.RstStream{stream_id: stream_id, error_code: error_code} + |> send_frame(socket, connection) + + :ok + end + + @spec stream_terminated(pid(), term(), t()) :: {:ok, t()} | {:error, term()} + def stream_terminated(pid, reason, connection) do with {:ok, stream} <- StreamCollection.get_active_stream_by_pid(connection.streams, pid), - {:ok, stream, error_code} <- Stream.stream_terminated(stream, reason), + {:ok, stream} <- Stream.stream_terminated(stream, reason), {:ok, streams} <- StreamCollection.put_stream(connection.streams, stream) do - _ = - if !is_nil(error_code) do - %Frame.RstStream{stream_id: stream.stream_id, error_code: error_code} - |> send_frame(socket, connection) - end - {:ok, %{connection | streams: streams}} end end diff --git a/lib/bandit/http2/handler.ex b/lib/bandit/http2/handler.ex index 0c7379a5..a897ca60 100644 --- a/lib/bandit/http2/handler.ex +++ b/lib/bandit/http2/handler.ex @@ -94,8 +94,13 @@ defmodule Bandit.HTTP2.Handler do end end + def handle_call({:send_rst_stream, stream_id, error_code}, _from, {socket, state}) do + Connection.send_rst_stream(stream_id, error_code, socket, state.connection) + {:reply, :ok, {socket, state}, socket.read_timeout} + end + def handle_info({:EXIT, pid, reason}, {socket, state}) do - case Connection.stream_terminated(pid, reason, socket, state.connection) do + case Connection.stream_terminated(pid, reason, state.connection) do {:ok, connection} -> {:noreply, {socket, %{state | connection: connection}}, socket.read_timeout} diff --git a/lib/bandit/http2/stream.ex b/lib/bandit/http2/stream.ex index 94ed31ed..b3121d74 100644 --- a/lib/bandit/http2/stream.ex +++ b/lib/bandit/http2/stream.ex @@ -246,18 +246,18 @@ defmodule Bandit.HTTP2.Stream do :ok end - @spec stream_terminated(t(), term()) :: {:ok, t(), Errors.error_code() | nil} + @spec stream_terminated(t(), term()) :: {:ok, t()} def stream_terminated(%__MODULE__{state: :closed} = stream, :normal) do # In the normal case, stop telemetry is emitted by the stream process to keep the main # connection process unblocked. In error cases we send from here, however, since there are # many error cases which never involve the stream process at all - {:ok, %{stream | state: :closed, pid: nil}, nil} + {:ok, %{stream | state: :closed, pid: nil}} end def stream_terminated(%__MODULE__{} = stream, {:bandit, reason}) do Bandit.Telemetry.stop_span(stream.span, %{}, %{error: reason}) Logger.warning("Stream #{stream.stream_id} was killed by bandit (#{reason})") - {:ok, %{stream | state: :closed, pid: nil}, nil} + {:ok, %{stream | state: :closed, pid: nil}} end def stream_terminated(%__MODULE__{} = stream, {%StreamError{} = error, _}) do @@ -269,12 +269,12 @@ defmodule Bandit.HTTP2.Stream do }) Logger.warning("Stream #{stream.stream_id} encountered a stream error (#{inspect(error)})") - {:ok, %{stream | state: :closed, pid: nil}, Errors.protocol_error()} + {:ok, %{stream | state: :closed, pid: nil}} end def stream_terminated(%__MODULE__{} = stream, :normal) do Logger.warning("Stream #{stream.stream_id} completed in unexpected state #{stream.state}") - {:ok, %{stream | state: :closed, pid: nil}, Errors.no_error()} + {:ok, %{stream | state: :closed, pid: nil}} end def stream_terminated(%__MODULE__{} = stream, reason) do @@ -288,6 +288,6 @@ defmodule Bandit.HTTP2.Stream do Logger.error("Process for stream #{stream.stream_id} crashed with #{inspect(reason)}") - {:ok, %{stream | state: :closed, pid: nil}, Errors.internal_error()} + {:ok, %{stream | state: :closed, pid: nil}} end end diff --git a/lib/bandit/http2/stream_process.ex b/lib/bandit/http2/stream_process.ex index f128b0a1..35c001d2 100644 --- a/lib/bandit/http2/stream_process.ex +++ b/lib/bandit/http2/stream_process.ex @@ -15,6 +15,8 @@ defmodule Bandit.HTTP2.StreamProcess do use GenServer, restart: :temporary + alias Bandit.HTTP2.{Adapter, Errors, Stream} + # A stream process can be created only once we have an adapter & set of headers. Pass them in # at creation time to ensure this invariant @spec start_link( @@ -25,7 +27,13 @@ defmodule Bandit.HTTP2.StreamProcess do Bandit.Telemetry.t() ) :: GenServer.on_start() def start_link(req, transport_info, headers, plug, span) do - GenServer.start_link(__MODULE__, {req, transport_info, headers, plug, span}) + GenServer.start_link(__MODULE__, %{ + req: req, + transport_info: transport_info, + headers: headers, + plug: plug, + span: span + }) end # Let the stream process know that body data has arrived from the client. The other half of this @@ -43,11 +51,19 @@ defmodule Bandit.HTTP2.StreamProcess do @spec recv_rst_stream(pid(), Bandit.HTTP2.Errors.error_code()) :: true def recv_rst_stream(pid, error_code), do: Process.exit(pid, {:recv_rst_stream, error_code}) + @impl GenServer def init(state) do {:ok, state, {:continue, :run}} end - def handle_continue(:run, {req, transport_info, all_headers, plug, span}) do + @impl GenServer + def handle_continue(:run, %{ + req: req, + transport_info: transport_info, + headers: all_headers, + plug: plug, + span: span + }) do req = %{req | owner_pid: self()} with {:ok, request_target} <- build_request_target(all_headers), @@ -77,7 +93,8 @@ defmodule Bandit.HTTP2.StreamProcess do status: conn.status }) - {:stop, :normal, {req, transport_info, all_headers, plug, span}} + {:stop, :normal, + %{req: req, transport_info: transport_info, headers: all_headers, plug: plug, span: span}} else {:error, reason} -> raise Bandit.HTTP2.Stream.StreamError, @@ -194,4 +211,15 @@ defmodule Bandit.HTTP2.StreamProcess do combined_cookie = Enum.map_join(crumbs, "; ", fn {"cookie", crumb} -> crumb end) [{"cookie", combined_cookie} | other_headers] end + + @impl GenServer + def terminate(:normal, _state), do: :ok + + def terminate({%Stream.StreamError{}, _stacktrace}, state) do + Adapter.send_rst_stream(state.req, Errors.protocol_error()) + end + + def terminate(_reason, state) do + Adapter.send_rst_stream(state.req, Errors.internal_error()) + end end diff --git a/test/bandit/http2/plug_test.exs b/test/bandit/http2/plug_test.exs index b12980b3..a601c0bb 100644 --- a/test/bandit/http2/plug_test.exs +++ b/test/bandit/http2/plug_test.exs @@ -321,6 +321,8 @@ defmodule HTTP2PlugTest do assert {:error, %Mint.HTTPError{reason: {:server_closed_request, :internal_error}}} = response + + Process.sleep(100) end) assert errors =~ @@ -341,6 +343,8 @@ defmodule HTTP2PlugTest do assert {:error, %Mint.HTTPError{reason: {:server_closed_request, :internal_error}}} = response + + Process.sleep(100) end) assert errors =~ @@ -356,6 +360,8 @@ defmodule HTTP2PlugTest do assert {:error, %Mint.HTTPError{reason: {:server_closed_request, :internal_error}}} = response + + Process.sleep(100) end) assert errors =~ @@ -405,6 +411,8 @@ defmodule HTTP2PlugTest do assert {:error, %Mint.HTTPError{reason: {:server_closed_request, :internal_error}}} = response + + Process.sleep(100) end) assert errors =~ @@ -894,6 +902,8 @@ defmodule HTTP2PlugTest do Req.get(context.req, url: "/raise_error") + Process.sleep(100) + assert Bandit.TelemetryCollector.get_events(collector_pid) ~> [ {[:bandit, :request, :exception], %{monotonic_time: integer()}, diff --git a/test/bandit/http2/protocol_test.exs b/test/bandit/http2/protocol_test.exs index a1d1f320..e4d49589 100644 --- a/test/bandit/http2/protocol_test.exs +++ b/test/bandit/http2/protocol_test.exs @@ -1565,19 +1565,6 @@ defmodule HTTP2ProtocolTest do end describe "RST_STREAM frames" do - @tag capture_log: true - test "sends RST_FRAME with no error if stream task ends without closed stream", context do - socket = SimpleH2Client.setup_connection(context) - - # Send headers with end_stream bit cleared - SimpleH2Client.send_simple_headers(socket, 1, :post, "/body_response", context.port) - SimpleH2Client.recv_headers(socket) - SimpleH2Client.recv_body(socket) - - assert SimpleH2Client.recv_rst_stream(socket) == {:ok, 1, 0} - assert SimpleH2Client.connection_alive?(socket) - end - @tag capture_log: true test "sends RST_FRAME with error if stream task crashes", context do socket = SimpleH2Client.setup_connection(context)