Skip to content

Commit

Permalink
fix broken Select with error list on macOS (#104915)
Browse files Browse the repository at this point in the history
* fix broken Select with error list on macOS

* update

* Apply suggestions from code review

Co-authored-by: Stephen Toub <[email protected]>

* feedback

* feedback

* Apply suggestions from code review

Co-authored-by: Stephen Toub <[email protected]>

* feedback

* test

---------

Co-authored-by: Stephen Toub <[email protected]>
  • Loading branch information
wfurt and stephentoub authored Jul 28, 2024
1 parent f832e66 commit cb5acf5
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;

internal static partial class Interop
{
internal static partial class Sys
{
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Select")]
internal static unsafe partial Error Select(Span<int> readFDs, int readFDsLength, Span<int> writeFDs, int writeFDsLength, Span<int> checkError, int checkErrorLength, int timeout, int maxFd, out int triggered);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@
Link="Common\Interop\Unix\System.Native\Interop.ReceiveMessage.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Send.cs"
Link="Common\Interop\Unix\System.Native\Interop.Send.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Select.cs"
Link="Common\Interop\Unix\System.Native\Interop.Select.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SendMessage.cs"
Link="Common\Interop\Unix\System.Native\Interop.SendMessage.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SetSockOpt.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ internal static partial class SocketPal
public static readonly int MaximumAddressSize = Interop.Sys.GetMaximumAddressSize();
private static readonly bool SupportsDualModeIPv4PacketInfo = GetPlatformSupportsDualModeIPv4PacketInfo();

private static readonly bool SelectOverPollIsBroken = OperatingSystem.IsMacOS() || OperatingSystem.IsIOS() || OperatingSystem.IsTvOS() || OperatingSystem.IsMacCatalyst();

// IovStackThreshold matches Linux's UIO_FASTIOV, which is the number of 'struct iovec'
// that get stackalloced in the Linux kernel.
private const int IovStackThreshold = 8;
Expand Down Expand Up @@ -1782,6 +1784,11 @@ public static unsafe SocketError Select(IList? checkRead, IList? checkWrite, ILi
// by the system. Since poll then expects an array of entries, we try to allocate the array on the stack,
// only falling back to allocating it on the heap if it's deemed too big.

if (SelectOverPollIsBroken)
{
return SelectViaSelect(checkRead, checkWrite, checkError, microseconds);
}

const int StackThreshold = 80; // arbitrary limit to avoid too much space on stack
if (count < StackThreshold)
{
Expand All @@ -1806,6 +1813,103 @@ public static unsafe SocketError Select(IList? checkRead, IList? checkWrite, ILi
}
}

private static SocketError SelectViaSelect(IList? checkRead, IList? checkWrite, IList? checkError, int microseconds)
{
const int MaxStackAllocCount = 20; // this is just arbitrary limit 3x 20 -> 60 e.g. close to 64 we have in some other places
Span<int> readFDs = checkRead?.Count > MaxStackAllocCount ? new int[checkRead.Count] : stackalloc int[checkRead?.Count ?? 0];
Span<int> writeFDs = checkWrite?.Count > MaxStackAllocCount ? new int[checkWrite.Count] : stackalloc int[checkWrite?.Count ?? 0];
Span<int> errorFDs = checkError?.Count > MaxStackAllocCount ? new int[checkError.Count] : stackalloc int[checkError?.Count ?? 0];

int refsAdded = 0;
int maxFd = 0;
try
{
AddDesriptors(readFDs, checkRead, ref refsAdded, ref maxFd);
AddDesriptors(writeFDs, checkWrite, ref refsAdded, ref maxFd);
AddDesriptors(errorFDs, checkError, ref refsAdded, ref maxFd);

int triggered = 0;
Interop.Error err = Interop.Sys.Select(readFDs, readFDs.Length, writeFDs, writeFDs.Length, errorFDs, errorFDs.Length, microseconds, maxFd, out triggered);
if (err != Interop.Error.SUCCESS)
{
return GetSocketErrorForErrorCode(err);
}

Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);

if (triggered == 0)
{
checkRead?.Clear();
checkWrite?.Clear();
checkError?.Clear();
}
else
{
FilterSelectList(checkRead, readFDs);
FilterSelectList(checkWrite, writeFDs);
FilterSelectList(checkError, errorFDs);
}
}
finally
{
// This order matches with the AddToPollArray calls
// to release only the handles that were ref'd.
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);
Debug.Assert(refsAdded == 0);
}

return (SocketError)0;
}

private static void AddDesriptors(Span<int> buffer, IList? socketList, ref int refsAdded, ref int maxFd)
{
if (socketList == null || socketList.Count == 0 )
{
return;
}

Debug.Assert(buffer.Length == socketList.Count);
for (int i = 0; i < socketList.Count; i++)
{
Socket? socket = socketList[i] as Socket;
if (socket == null)
{
throw new ArgumentException(SR.Format(SR.net_sockets_select, socket?.GetType().FullName ?? "null", typeof(Socket).FullName), nameof(socketList));
}

if (socket.Handle > maxFd)
{
maxFd = (int)socket.Handle;
}

bool success = false;
socket.InternalSafeHandle.DangerousAddRef(ref success);
buffer[i] = (int)socket.InternalSafeHandle.DangerousGetHandle();

refsAdded++;
}
}

private static void FilterSelectList(IList? socketList, Span<int> results)
{
if (socketList == null)
return;

// This loop can be O(n^2) in the unexpected and worst case. Some more thoughts are written in FilterPollList that does exactly same operation.

for (int i = socketList.Count - 1; i >= 0; --i)
{
if (results[i] == 0)
{
socketList.RemoveAt(i);
}
}
}

private static unsafe SocketError SelectViaPoll(
IList? checkRead, int checkReadInitialCount,
IList? checkWrite, int checkWriteInitialCount,
Expand Down
107 changes: 103 additions & 4 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/SelectTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.DotNet.XUnitExtensions;
using Xunit;
using Xunit.Abstractions;

Expand All @@ -21,7 +21,7 @@ public SelectTest(ITestOutputHelper output)
}

private const int SmallTimeoutMicroseconds = 10 * 1000;
private const int FailTimeoutMicroseconds = 30 * 1000 * 1000;
internal const int FailTimeoutMicroseconds = 30 * 1000 * 1000;

[SkipOnPlatform(TestPlatforms.OSX, "typical OSX install has very low max open file descriptors value")]
[Theory]
Expand Down Expand Up @@ -78,6 +78,82 @@ public void Select_ReadWrite_AllReady(int reads, int writes)
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public void Select_ReadError_Success(bool dispose)
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

if (dispose)
{
sender.Dispose();
}
else
{
sender.Send(new byte[] { 1 });
}

var readList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(readList, null, errorList, -1);
if (dispose)
{
Assert.True(readList.Count == 1 || errorList.Count == 1);
}
else
{
Assert.Equal(1, readList.Count);
Assert.Equal(0, errorList.Count);
}
}

[Fact]
public void Select_WriteError_Success()
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

var writeList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(null, writeList, errorList, -1);
Assert.Equal(1, writeList.Count);
Assert.Equal(0, errorList.Count);
}

[Fact]
public void Select_ReadWriteError_Success()
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

sender.Send(new byte[] { 1 });
receiver.Poll(FailTimeoutMicroseconds, SelectMode.SelectRead);
var readList = new List<Socket> { receiver };
var writeList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(readList, writeList, errorList, -1);
Assert.Equal(1, readList.Count);
Assert.Equal(1, writeList.Count);
Assert.Equal(0, errorList.Count);
}

[Theory]
[InlineData(2, 0)]
[InlineData(2, 1)]
Expand Down Expand Up @@ -109,7 +185,6 @@ public void Select_SocketAlreadyClosed_AllSocketsClosableAfterException(int sock
}
}

[SkipOnPlatform(TestPlatforms.OSX, "typical OSX install has very low max open file descriptors value")]
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/51392", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)]
public void Select_ReadError_NoneReady_ManySockets()
Expand Down Expand Up @@ -245,7 +320,7 @@ public void Poll_ReadReady_LongTimeouts(int microsecondsTimeout)
}
}

private static KeyValuePair<Socket, Socket> CreateConnectedSockets()
internal static KeyValuePair<Socket, Socket> CreateConnectedSockets()
{
using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
Expand Down Expand Up @@ -342,5 +417,29 @@ private static void DoAccept(Socket listenSocket, int connectionsToAccept)
}
}
}

[ConditionalFact]
public void Select_LargeNumber_Succcess()
{
const int MaxSockets = 1025;
KeyValuePair<Socket, Socket>[] socketPairs;
try
{
// we try to shoot for more socket than FD_SETSIZE (that is typically 1024)
socketPairs = Enumerable.Range(0, MaxSockets).Select(_ => SelectTest.CreateConnectedSockets()).ToArray();
}
catch
{
throw new SkipTestException("Unable to open large count number of socket");
}

var readList = new List<Socket>(socketPairs.Select(p => p.Key).ToArray());

// Try to write and read on last sockets
(Socket reader, Socket writer) = socketPairs[MaxSockets - 1];
writer.Send(new byte[1]);
Socket.Select(readList, null, null, SelectTest.FailTimeoutMicroseconds);
Assert.Equal(1, readList.Count);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,8 @@ public void FailedConnect_GetSocketOption_SocketOptionNameError(bool simpleGet)
Assert.ThrowsAny<Exception>(() => client.Connect(server.LocalEndPoint));
}

// Verify via Select that there's an error
const int FailedTimeout = 10 * 1000 * 1000; // 10 seconds
var errorList = new List<Socket> { client };
Socket.Select(null, null, errorList, FailedTimeout);
Assert.Equal(1, errorList.Count);
// Verify via Poll that there's an error
Assert.True(client.Poll(10_000_000, SelectMode.SelectError));

// Get the last error and validate it's what's expected
int errorCode;
Expand Down
1 change: 1 addition & 0 deletions src/native/libs/System.Native/entrypoints.c
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ static const Entry s_sysNative[] =
DllImportEntry(SystemNative_GetGroupName)
DllImportEntry(SystemNative_GetUInt64OSThreadId)
DllImportEntry(SystemNative_TryGetUInt32OSThreadId)
DllImportEntry(SystemNative_Select)
};

EXTERN_C const void* SystemResolveDllImport(const char* name);
Expand Down
Loading

0 comments on commit cb5acf5

Please sign in to comment.