diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs index ec677f9f4e58d..e99f5fda640e1 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs @@ -27,13 +27,27 @@ internal static QuicException GetOperationAbortedException(string? message = nul return new QuicException(QuicError.OperationAborted, null, message ?? SR.net_quic_operationaborted); } - internal static bool TryGetStreamExceptionForMsQuicStatus(int status, [NotNullWhen(true)] out Exception? exception) + internal static bool TryGetStreamExceptionForMsQuicStatus(int status, [NotNullWhen(true)] out Exception? exception, bool streamWasSuccessfullyStarted = true, string? message = null) { if (status == QUIC_STATUS_ABORTED) { - // If status == QUIC_STATUS_ABORTED, we will receive an event later, which will complete the task source. - exception = null; - return false; + // Connection has been closed by the peer (either at transport or application level), + if (streamWasSuccessfullyStarted) + { + // we will receive an event later, which will complete the stream with concrete + // information why the connection was aborted. + exception = null; + return false; + } + else + { + // we won't be receiving any event callback for shutdown on this stream, so we don't + // necessarily know which error to report. So we throw an exception which we can distinguish + // at the caller (ConnectionAborted normally has App error code) and throw the correct + // exception from there. + exception = new QuicException(QuicError.ConnectionAborted, null, ""); + return true; + } } else if (status == QUIC_STATUS_INVALID_STATE) { @@ -43,13 +57,16 @@ internal static bool TryGetStreamExceptionForMsQuicStatus(int status, [NotNullWh } else if (StatusFailed(status)) { - exception = GetExceptionForMsQuicStatus(status); + exception = GetExceptionForMsQuicStatus(status, message: message); return true; } exception = null; return false; } + // see TryGetStreamExceptionForMsQuicStatus for explanation + internal static bool IsConnectionAbortedWhenStartingStreamException(Exception ex) => ex is QuicException qe && qe.QuicError == QuicError.ConnectionAborted && qe.ApplicationErrorCode is null; + internal static Exception GetExceptionForMsQuicStatus(int status, long? errorCode = default, string? message = null) { Exception ex = GetExceptionInternal(status, errorCode, message); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index a2ade033afe59..13351faaa20c0 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -98,6 +98,11 @@ static async ValueTask StartConnectAsync(QuicClientConnectionOpt /// private int _disposed; + /// + /// Completed when connection shutdown is initiated. + /// + private TaskCompletionSource _connectionCloseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly ValueTaskSource _connectedTcs = new ValueTaskSource(); private readonly ValueTaskSource _shutdownTcs = new ValueTaskSource(); @@ -376,16 +381,22 @@ public async ValueTask OpenOutboundStreamAsync(QuicStreamType type, stream = new QuicStream(_handle, type, _defaultStreamErrorCode); await stream.StartAsync(cancellationToken).ConfigureAwait(false); } - catch + catch (Exception ex) { if (stream is not null) { await stream.DisposeAsync().ConfigureAwait(false); } + + // In case of an incoming race when the connection is closed by the peer just before we open the stream, + // we receive QUIC_STATUS_ABORTED from MsQuic, but we don't know how the connection was closed. We throw + // special exception and handle it here where we can determine the shutdown reason. + bool connectionAbortedByPeer = ThrowHelper.IsConnectionAbortedWhenStartingStreamException(ex); + // Propagate connection error if present. - if (_acceptQueue.Reader.Completion.IsFaulted) + if (_connectionCloseTcs.Task.IsFaulted || connectionAbortedByPeer) { - await _acceptQueue.Reader.Completion.ConfigureAwait(false); + await _connectionCloseTcs.Task.ConfigureAwait(false); } throw; } @@ -475,17 +486,21 @@ private unsafe int HandleEventShutdownInitiatedByTransport(ref SHUTDOWN_INITIATE { Exception exception = ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetExceptionForMsQuicStatus(data.Status, (long)data.ErrorCode)); _connectedTcs.TrySetException(exception); + _connectionCloseTcs.TrySetException(exception); _acceptQueue.Writer.TryComplete(exception); return QUIC_STATUS_SUCCESS; } private unsafe int HandleEventShutdownInitiatedByPeer(ref SHUTDOWN_INITIATED_BY_PEER_DATA data) { - _acceptQueue.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetConnectionAbortedException((long)data.ErrorCode))); + Exception exception = ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetConnectionAbortedException((long)data.ErrorCode)); + _connectionCloseTcs.TrySetException(exception); + _acceptQueue.Writer.TryComplete(exception); return QUIC_STATUS_SUCCESS; } private unsafe int HandleEventShutdownComplete() { Exception exception = ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException()); + _connectionCloseTcs.TrySetException(exception); _acceptQueue.Writer.TryComplete(exception); _connectedTcs.TrySetException(exception); _shutdownTcs.TrySetResult(); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index 6165f2085cb5f..82ee656dc6bdc 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -161,13 +161,18 @@ internal unsafe QuicStream(MsQuicContextSafeHandle connectionHandle, QuicStreamT try { QUIC_HANDLE* handle; - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.StreamOpen( + int status = MsQuicApi.Api.StreamOpen( connectionHandle, type == QuicStreamType.Unidirectional ? QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL : QUIC_STREAM_OPEN_FLAGS.NONE, &NativeCallback, (void*)GCHandle.ToIntPtr(context), - &handle), - "StreamOpen failed"); + &handle); + + if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? ex, streamWasSuccessfullyStarted: false, message: "StreamOpen failed")) + { + throw ex; + } + _handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Stream, connectionHandle); } catch @@ -241,7 +246,8 @@ internal ValueTask StartAsync(CancellationToken cancellationToken = default) int status = MsQuicApi.Api.StreamStart( _handle, QUIC_STREAM_START_FLAGS.SHUTDOWN_ON_FAIL | QUIC_STREAM_START_FLAGS.INDICATE_PEER_ACCEPT); - if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception)) + + if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception, streamWasSuccessfullyStarted: false)) { _startedTcs.TrySetException(exception); }