Skip to content

Commit

Permalink
Implement straightforward ServicePoint(Manager) properties in HttpWeb…
Browse files Browse the repository at this point in the history
…Request (#94664)

* Implement TcpKeepAlive and some properties from ServicePoint

* Delete unicast set socket option

* Review feedback

* Delete else case

* Make TcpKeepAlive reference type

* Delete unnecessary usings

* Compile error fix

* Review Feedback

* Add Expect100Continue Support

* Add ContinueTimeout Tests

* Correct test cases naming

* Review feedback

* Add Expect 100 Continue header tests

* Apply suggestions from code review

Co-authored-by: Anton Firszov <[email protected]>

* Review feedback

* Apply suggestions from code review

Co-authored-by: Miha Zupan <[email protected]>

* Review feedback

---------

Co-authored-by: Anton Firszov <[email protected]>
Co-authored-by: Miha Zupan <[email protected]>
  • Loading branch information
3 people authored Dec 10, 2023
1 parent 1473dea commit f0a6dbd
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
<Compile Include="System\Net\ServicePoint\SecurityProtocolType.cs" />
<Compile Include="System\Net\ServicePoint\ServicePoint.cs" />
<Compile Include="System\Net\ServicePoint\ServicePointManager.cs" />
<Compile Include="System\Net\ServicePoint\TcpKeepAlive.cs" />
<Compile Include="$(CommonPath)System\Obsoletions.cs"
Link="Common\System\Obsoletions.cs" />
<Compile Include="$(CommonPath)System\Net\Http\HttpHandlerDefaults.cs"
Expand Down
34 changes: 30 additions & 4 deletions src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ public class HttpWebRequest : WebRequest, ISerializable
private readonly Uri _requestUri = null!;
private string _originVerb = HttpMethod.Get.Method;

// We allow getting and setting this (to preserve app-compat). But we don't do anything with it
// as the underlying System.Net.Http API doesn't support it.
private int _continueTimeout = DefaultContinueTimeout;

private bool _allowReadStreamBuffering;
Expand Down Expand Up @@ -115,6 +113,8 @@ private sealed class HttpClientParameters
public readonly RemoteCertificateValidationCallback? ServerCertificateValidationCallback;
public readonly X509CertificateCollection? ClientCertificates;
public readonly CookieContainer? CookieContainer;
public readonly ServicePoint? ServicePoint;
public readonly TimeSpan ContinueTimeout;

public HttpClientParameters(HttpWebRequest webRequest, bool async)
{
Expand All @@ -135,6 +135,8 @@ public HttpClientParameters(HttpWebRequest webRequest, bool async)
ServerCertificateValidationCallback = webRequest.ServerCertificateValidationCallback ?? ServicePointManager.ServerCertificateValidationCallback;
ClientCertificates = webRequest._clientCertificates;
CookieContainer = webRequest._cookieContainer;
ServicePoint = webRequest._servicePoint;
ContinueTimeout = TimeSpan.FromMilliseconds(webRequest.ContinueTimeout);
}

public bool Matches(HttpClientParameters requestParameters)
Expand All @@ -149,11 +151,13 @@ public bool Matches(HttpClientParameters requestParameters)
&& Timeout == requestParameters.Timeout
&& SslProtocols == requestParameters.SslProtocols
&& CheckCertificateRevocationList == requestParameters.CheckCertificateRevocationList
&& ContinueTimeout == requestParameters.ContinueTimeout
&& ReferenceEquals(Credentials, requestParameters.Credentials)
&& ReferenceEquals(Proxy, requestParameters.Proxy)
&& ReferenceEquals(ServerCertificateValidationCallback, requestParameters.ServerCertificateValidationCallback)
&& ReferenceEquals(ClientCertificates, requestParameters.ClientCertificates)
&& ReferenceEquals(CookieContainer, requestParameters.CookieContainer);
&& ReferenceEquals(CookieContainer, requestParameters.CookieContainer)
&& ReferenceEquals(ServicePoint, requestParameters.ServicePoint);
}

public bool AreParametersAcceptableForCaching()
Expand All @@ -162,7 +166,8 @@ public bool AreParametersAcceptableForCaching()
&& ReferenceEquals(Proxy, DefaultWebProxy)
&& ServerCertificateValidationCallback == null
&& ClientCertificates == null
&& CookieContainer == null;
&& CookieContainer == null
&& ServicePoint == null;
}
}

Expand Down Expand Up @@ -1178,6 +1183,11 @@ private async Task<WebResponse> SendRequest(bool async)
request.Headers.ConnectionClose = true;
}

if (_servicePoint?.Expect100Continue == true)
{
request.Headers.ExpectContinue = true;
}

request.Version = ProtocolVersion;

_sendRequestTask = async ?
Expand Down Expand Up @@ -1598,6 +1608,7 @@ private static HttpClient CreateHttpClient(HttpClientParameters parameters, Http
handler.MaxAutomaticRedirections = parameters.MaximumAutomaticRedirections;
handler.MaxResponseHeadersLength = parameters.MaximumResponseHeadersLength;
handler.PreAuthenticate = parameters.PreAuthenticate;
handler.Expect100ContinueTimeout = parameters.ContinueTimeout;
client.Timeout = parameters.Timeout;

if (parameters.CookieContainer != null)
Expand Down Expand Up @@ -1660,6 +1671,21 @@ private static HttpClient CreateHttpClient(HttpClientParameters parameters, Http

try
{
if (parameters.ServicePoint is { } servicePoint)
{
if (servicePoint.ReceiveBufferSize != -1)
{
socket.ReceiveBufferSize = servicePoint.ReceiveBufferSize;
}

if (servicePoint.KeepAlive is { } keepAlive)
{
socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true);
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, keepAlive.Time);
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, keepAlive.Interval);
}
}

socket.NoDelay = true;

if (parameters.Async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ public class ServicePoint
private int _receiveBufferSize = -1;
private int _connectionLimit;

internal TcpKeepAlive? KeepAlive { get; set; }

internal ServicePoint(Uri address)
{
Debug.Assert(address != null);
Expand Down Expand Up @@ -87,11 +89,20 @@ public int ConnectionLimit

public void SetTcpKeepAlive(bool enabled, int keepAliveTime, int keepAliveInterval)
{
if (enabled)
if (!enabled)
{
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(keepAliveTime);
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(keepAliveInterval);
KeepAlive = null;
return;
}

ArgumentOutOfRangeException.ThrowIfNegativeOrZero(keepAliveTime);
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(keepAliveInterval);

KeepAlive = new TcpKeepAlive
{
Time = keepAliveTime,
Interval = keepAliveInterval
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ private static void ValidateSecurityProtocol(SecurityProtocolType value)
}
}

internal static TcpKeepAlive? KeepAlive { get; private set; }

public static int MaxServicePoints
{
get { return s_maxServicePoints; }
Expand Down Expand Up @@ -153,7 +155,8 @@ public static ServicePoint FindServicePoint(Uri address, IWebProxy? proxy)
ConnectionLimit = DefaultConnectionLimit,
IdleSince = DateTime.Now,
Expect100Continue = Expect100Continue,
UseNagleAlgorithm = UseNagleAlgorithm
UseNagleAlgorithm = UseNagleAlgorithm,
KeepAlive = KeepAlive
};
s_servicePointTable[tableKey] = new WeakReference<ServicePoint>(sp);

Expand Down Expand Up @@ -208,11 +211,20 @@ private static string MakeQueryString(Uri address, bool isProxy)

public static void SetTcpKeepAlive(bool enabled, int keepAliveTime, int keepAliveInterval)
{
if (enabled)
if (!enabled)
{
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(keepAliveTime);
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(keepAliveInterval);
KeepAlive = null;
return;
}

ArgumentOutOfRangeException.ThrowIfNegativeOrZero(keepAliveTime);
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(keepAliveInterval);

KeepAlive = new TcpKeepAlive
{
Time = keepAliveTime,
Interval = keepAliveInterval
};
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace System.Net
{
internal sealed class TcpKeepAlive
{
internal int Time { get; set; }
internal int Interval { get; set; }
}
}
91 changes: 91 additions & 0 deletions src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Linq;
using System.Net.Cache;
using System.Net.Http;
using System.Net.Http.Functional.Tests;
using System.Net.Sockets;
using System.Net.Test.Common;
using System.Runtime.Serialization.Formatters.Binary;
Expand Down Expand Up @@ -2076,6 +2077,96 @@ await server.AcceptConnectionAsync(async connection =>
});
}

[Fact]
public async Task SendHttpPostRequest_WithContinueTimeoutAndBody_BodyIsDelayed()
{
await LoopbackServer.CreateClientAndServerAsync(
async (uri) =>
{
HttpWebRequest request = WebRequest.CreateHttp(uri);
request.Method = "POST";
request.ServicePoint.Expect100Continue = true;
request.ContinueTimeout = 30000;
Stream requestStream = await request.GetRequestStreamAsync();
requestStream.Write("aaaa\r\n\r\n"u8);
await request.GetResponseAsync();
},
async (server) =>
{
await server.AcceptConnectionAsync(async (client) =>
{
await client.ReadRequestHeaderAsync();
// This should time out, because we're expecting the body itself but we'll get it after 30 sec.
await Assert.ThrowsAsync<TimeoutException>(() => client.ReadLineAsync().WaitAsync(TimeSpan.FromMilliseconds(100)));
await client.SendResponseAsync();
});
}
);
}

[Theory]
[InlineData(true, 1)]
[InlineData(false, 30000)]
public async Task SendHttpPostRequest_WithContinueTimeoutAndBody_Success(bool expect100Continue, int continueTimeout)
{
await LoopbackServer.CreateClientAndServerAsync(
async (uri) =>
{
HttpWebRequest request = WebRequest.CreateHttp(uri);
request.Method = "POST";
request.ServicePoint.Expect100Continue = expect100Continue;
request.ContinueTimeout = continueTimeout;
Stream requestStream = await request.GetRequestStreamAsync();
requestStream.Write("aaaa\r\n\r\n"u8);
await request.GetResponseAsync();
},
async (server) =>
{
await server.AcceptConnectionAsync(async (client) =>
{
await client.ReadRequestHeaderAsync();
// This should not time out, because we're expecting the body itself and we should get it after 1 sec.
string data = await client.ReadLineAsync().WaitAsync(TimeSpan.FromSeconds(10));
Assert.StartsWith("aaaa", data);
await client.SendResponseAsync();
});
});
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task SendHttpPostRequest_When100ContinueSet_ReceivedByServer(bool expect100Continue)
{
await LoopbackServer.CreateClientAndServerAsync(
async (uri) =>
{
HttpWebRequest request = WebRequest.CreateHttp(uri);
request.Method = "POST";
request.ServicePoint.Expect100Continue = expect100Continue;
await request.GetResponseAsync();
},
async (server) =>
{
await server.AcceptConnectionAsync(
async (client) =>
{
List<string> headers = await client.ReadRequestHeaderAsync();
if (expect100Continue)
{
Assert.Contains("Expect: 100-continue", headers);
}
else
{
Assert.DoesNotContain("Expect: 100-continue", headers);
}
await client.SendResponseAsync();
}
);
}
);
}

private void RequestStreamCallback(IAsyncResult asynchronousResult)
{
RequestState state = (RequestState)asynchronousResult.AsyncState;
Expand Down

0 comments on commit f0a6dbd

Please sign in to comment.