Skip to content

Commit

Permalink
Simplify DynamicWinsockMethods (#43190)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephen Toub <[email protected]>
  • Loading branch information
jkotas and stephentoub authored Oct 9, 2020
1 parent bae8b42 commit c5b6881
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 226 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
using System.Threading;

Expand Down Expand Up @@ -38,86 +39,24 @@ public static DynamicWinsockMethods GetMethods(AddressFamily addressFamily, Sock
private readonly AddressFamily _addressFamily;
private readonly SocketType _socketType;
private readonly ProtocolType _protocolType;
private readonly object _lockObject;

private AcceptExDelegate? _acceptEx;
private GetAcceptExSockaddrsDelegate? _getAcceptExSockaddrs;
private ConnectExDelegate? _connectEx;
private TransmitPacketsDelegate? _transmitPackets;

private DisconnectExDelegate? _disconnectEx;
private DisconnectExDelegateBlocking? _disconnectExBlocking;

private WSARecvMsgDelegate? _recvMsg;
private WSARecvMsgDelegateBlocking? _recvMsgBlocking;

private DynamicWinsockMethods(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
_addressFamily = addressFamily;
_socketType = socketType;
_protocolType = protocolType;
_lockObject = new object();
}

public T GetDelegate<T>(SafeSocketHandle socketHandle)
where T : class
{
if (typeof(T) == typeof(AcceptExDelegate))
{
EnsureAcceptEx(socketHandle);
Debug.Assert(_acceptEx != null);
return (T)(object)_acceptEx;
}
else if (typeof(T) == typeof(GetAcceptExSockaddrsDelegate))
{
EnsureGetAcceptExSockaddrs(socketHandle);
Debug.Assert(_getAcceptExSockaddrs != null);
return (T)(object)_getAcceptExSockaddrs;
}
else if (typeof(T) == typeof(ConnectExDelegate))
{
EnsureConnectEx(socketHandle);
Debug.Assert(_connectEx != null);
return (T)(object)_connectEx;
}
else if (typeof(T) == typeof(DisconnectExDelegate))
{
EnsureDisconnectEx(socketHandle);
Debug.Assert(_disconnectEx != null);
return (T)(object)_disconnectEx;
}
else if (typeof(T) == typeof(DisconnectExDelegateBlocking))
{
EnsureDisconnectEx(socketHandle);
Debug.Assert(_disconnectExBlocking != null);
return (T)(object)_disconnectExBlocking;
}
else if (typeof(T) == typeof(WSARecvMsgDelegate))
{
EnsureWSARecvMsg(socketHandle);
Debug.Assert(_recvMsg != null);
return (T)(object)_recvMsg;
}
else if (typeof(T) == typeof(WSARecvMsgDelegateBlocking))
{
EnsureWSARecvMsgBlocking(socketHandle);
Debug.Assert(_recvMsgBlocking != null);
return (T)(object)_recvMsgBlocking;
}
else if (typeof(T) == typeof(TransmitPacketsDelegate))
{
EnsureTransmitPackets(socketHandle);
Debug.Assert(_transmitPackets != null);
return (T)(object)_transmitPackets;
}

Debug.Fail("Invalid type passed to DynamicWinsockMethods.GetDelegate");
return null;
}

// Private methods that actually load the function pointers.
private IntPtr LoadDynamicFunctionPointer(SafeSocketHandle socketHandle, ref Guid guid)
private static T CreateDelegate<T>([NotNull] ref T? cache, SafeSocketHandle socketHandle, string guidString) where T: Delegate
{
Guid guid = new Guid(guidString);
IntPtr ptr = IntPtr.Zero;
int length;
SocketError errorCode;
Expand All @@ -141,125 +80,27 @@ private IntPtr LoadDynamicFunctionPointer(SafeSocketHandle socketHandle, ref Gui
throw new SocketException();
}

return ptr;
Interlocked.CompareExchange(ref cache, Marshal.GetDelegateForFunctionPointer<T>(ptr), null);
return cache;
}

// NOTE: the volatile writes in the functions below are necessary to ensure that all writes
// to the fields of the delegate instances are visible before the write to the field
// that holds the reference to the delegate instance.
internal AcceptExDelegate GetAcceptExDelegate(SafeSocketHandle socketHandle)
=> _acceptEx ?? CreateDelegate(ref _acceptEx, socketHandle, "b5367df1cbac11cf95ca00805f48a192");

private void EnsureAcceptEx(SafeSocketHandle socketHandle)
{
if (_acceptEx == null)
{
lock (_lockObject)
{
if (_acceptEx == null)
{
Guid guid = new Guid("{0xb5367df1,0xcbac,0x11cf,{0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92}}");
IntPtr ptrAcceptEx = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _acceptEx, Marshal.GetDelegateForFunctionPointer<AcceptExDelegate>(ptrAcceptEx));
}
}
}
}

private void EnsureGetAcceptExSockaddrs(SafeSocketHandle socketHandle)
{
if (_getAcceptExSockaddrs == null)
{
lock (_lockObject)
{
if (_getAcceptExSockaddrs == null)
{
Guid guid = new Guid("{0xb5367df2,0xcbac,0x11cf,{0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92}}");
IntPtr ptrGetAcceptExSockaddrs = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _getAcceptExSockaddrs, Marshal.GetDelegateForFunctionPointer<GetAcceptExSockaddrsDelegate>(ptrGetAcceptExSockaddrs));
}
}
}
}
internal GetAcceptExSockaddrsDelegate GetGetAcceptExSockaddrsDelegate(SafeSocketHandle socketHandle)
=> _getAcceptExSockaddrs ?? CreateDelegate(ref _getAcceptExSockaddrs, socketHandle, "b5367df2cbac11cf95ca00805f48a192");

private void EnsureConnectEx(SafeSocketHandle socketHandle)
{
if (_connectEx == null)
{
lock (_lockObject)
{
if (_connectEx == null)
{
Guid guid = new Guid("{0x25a207b9,0x0ddf3,0x4660,{0x8e,0xe9,0x76,0xe5,0x8c,0x74,0x06,0x3e}}");
IntPtr ptrConnectEx = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _connectEx, Marshal.GetDelegateForFunctionPointer<ConnectExDelegate>(ptrConnectEx));
}
}
}
}
internal ConnectExDelegate GetConnectExDelegate(SafeSocketHandle socketHandle)
=> _connectEx ?? CreateDelegate(ref _connectEx, socketHandle, "25a207b9ddf346608ee976e58c74063e");

private void EnsureDisconnectEx(SafeSocketHandle socketHandle)
{
if (_disconnectEx == null)
{
lock (_lockObject)
{
if (_disconnectEx == null)
{
Guid guid = new Guid("{0x7fda2e11,0x8630,0x436f,{0xa0, 0x31, 0xf5, 0x36, 0xa6, 0xee, 0xc1, 0x57}}");
IntPtr ptrDisconnectEx = LoadDynamicFunctionPointer(socketHandle, ref guid);
_disconnectExBlocking = Marshal.GetDelegateForFunctionPointer<DisconnectExDelegateBlocking>(ptrDisconnectEx);
Volatile.Write(ref _disconnectEx, Marshal.GetDelegateForFunctionPointer<DisconnectExDelegate>(ptrDisconnectEx));
}
}
}
}
private void EnsureWSARecvMsg(SafeSocketHandle socketHandle)
{
if (_recvMsg == null)
{
lock (_lockObject)
{
if (_recvMsg == null)
{
Guid guid = new Guid("{0xf689d7c8,0x6f1f,0x436b,{0x8a,0x53,0xe5,0x4f,0xe3,0x51,0xc3,0x22}}");
IntPtr ptrWSARecvMsg = LoadDynamicFunctionPointer(socketHandle, ref guid);
_recvMsgBlocking = Marshal.GetDelegateForFunctionPointer<WSARecvMsgDelegateBlocking>(ptrWSARecvMsg);
Volatile.Write(ref _recvMsg, Marshal.GetDelegateForFunctionPointer<WSARecvMsgDelegate>(ptrWSARecvMsg));
}
}
}
}
internal DisconnectExDelegate GetDisconnectExDelegate(SafeSocketHandle socketHandle)
=> _disconnectEx ?? CreateDelegate(ref _disconnectEx, socketHandle, "7fda2e118630436fa031f536a6eec157");

private void EnsureWSARecvMsgBlocking(SafeSocketHandle socketHandle)
{
if (_recvMsgBlocking == null)
{
lock (_lockObject)
{
if (_recvMsgBlocking == null)
{
Guid guid = new Guid("{0xf689d7c8,0x6f1f,0x436b,{0x8a,0x53,0xe5,0x4f,0xe3,0x51,0xc3,0x22}}");
IntPtr ptrWSARecvMsg = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _recvMsgBlocking, Marshal.GetDelegateForFunctionPointer<WSARecvMsgDelegateBlocking>(ptrWSARecvMsg));
}
}
}
}
internal WSARecvMsgDelegate GetWSARecvMsgDelegate(SafeSocketHandle socketHandle)
=> _recvMsg ?? CreateDelegate(ref _recvMsg, socketHandle, "f689d7c86f1f436b8a53e54fe351c322");

private void EnsureTransmitPackets(SafeSocketHandle socketHandle)
{
if (_transmitPackets == null)
{
lock (_lockObject)
{
if (_transmitPackets == null)
{
Guid guid = new Guid("{0xd9689da0,0x1f90,0x11d3,{0x99,0x71,0x00,0xc0,0x4f,0x68,0xc8,0x76}}");
IntPtr ptrTransmitPackets = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _transmitPackets, Marshal.GetDelegateForFunctionPointer<TransmitPacketsDelegate>(ptrTransmitPackets));
}
}
}
}
internal TransmitPacketsDelegate GetTransmitPacketsDelegate(SafeSocketHandle socketHandle)
=> _transmitPackets ?? CreateDelegate(ref _transmitPackets, socketHandle, "d9689da01f9011d3997100c04f68c876");
}

[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
Expand Down Expand Up @@ -302,13 +143,6 @@ internal unsafe delegate bool DisconnectExDelegate(
int flags,
int reserved);

[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
internal delegate bool DisconnectExDelegateBlocking(
SafeSocketHandle socketHandle,
IntPtr overlapped,
int flags,
int reserved);

[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
internal unsafe delegate SocketError WSARecvMsgDelegate(
SafeSocketHandle socketHandle,
Expand All @@ -317,14 +151,6 @@ internal unsafe delegate SocketError WSARecvMsgDelegate(
NativeOverlapped* overlapped,
IntPtr completionRoutine);

[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
internal delegate SocketError WSARecvMsgDelegateBlocking(
SafeSocketHandle socketHandle,
IntPtr msg,
out int bytesTransferred,
IntPtr overlapped,
IntPtr completionRoutine);

[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
internal unsafe delegate bool TransmitPacketsDelegate(
SafeSocketHandle socketHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,9 @@ public Socket EndAccept(out byte[] buffer, out int bytesTransferred, IAsyncResul
return EndAcceptCommon(out buffer!, out bytesTransferred, asyncResult);
}

private void EnsureDynamicWinsockMethods()
private DynamicWinsockMethods GetDynamicWinsockMethods()
{
if (_dynamicWinsockMethods == null)
{
_dynamicWinsockMethods = DynamicWinsockMethods.GetMethods(_addressFamily, _socketType, _protocolType);
}
return _dynamicWinsockMethods ??= DynamicWinsockMethods.GetMethods(_addressFamily, _socketType, _protocolType);
}

internal unsafe bool AcceptEx(SafeSocketHandle listenSocketHandle,
Expand All @@ -168,8 +165,7 @@ internal unsafe bool AcceptEx(SafeSocketHandle listenSocketHandle,
out int bytesReceived,
NativeOverlapped* overlapped)
{
EnsureDynamicWinsockMethods();
AcceptExDelegate acceptEx = _dynamicWinsockMethods!.GetDelegate<AcceptExDelegate>(listenSocketHandle);
AcceptExDelegate acceptEx = GetDynamicWinsockMethods().GetAcceptExDelegate(listenSocketHandle);

return acceptEx(listenSocketHandle,
acceptSocketHandle,
Expand All @@ -190,8 +186,7 @@ internal void GetAcceptExSockaddrs(IntPtr buffer,
out IntPtr remoteSocketAddress,
out int remoteSocketAddressLength)
{
EnsureDynamicWinsockMethods();
GetAcceptExSockaddrsDelegate getAcceptExSockaddrs = _dynamicWinsockMethods!.GetDelegate<GetAcceptExSockaddrsDelegate>(_handle);
GetAcceptExSockaddrsDelegate getAcceptExSockaddrs = GetDynamicWinsockMethods().GetGetAcceptExSockaddrsDelegate(_handle);

getAcceptExSockaddrs(buffer,
receiveDataLength,
Expand All @@ -205,18 +200,16 @@ internal void GetAcceptExSockaddrs(IntPtr buffer,

internal unsafe bool DisconnectEx(SafeSocketHandle socketHandle, NativeOverlapped* overlapped, int flags, int reserved)
{
EnsureDynamicWinsockMethods();
DisconnectExDelegate disconnectEx = _dynamicWinsockMethods!.GetDelegate<DisconnectExDelegate>(socketHandle);
DisconnectExDelegate disconnectEx = GetDynamicWinsockMethods().GetDisconnectExDelegate(socketHandle);

return disconnectEx(socketHandle, overlapped, flags, reserved);
}

internal bool DisconnectExBlocking(SafeSocketHandle socketHandle, IntPtr overlapped, int flags, int reserved)
internal unsafe bool DisconnectExBlocking(SafeSocketHandle socketHandle, int flags, int reserved)
{
EnsureDynamicWinsockMethods();
DisconnectExDelegateBlocking disconnectEx_Blocking = _dynamicWinsockMethods!.GetDelegate<DisconnectExDelegateBlocking>(socketHandle);
DisconnectExDelegate disconnectEx = GetDynamicWinsockMethods().GetDisconnectExDelegate(socketHandle);

return disconnectEx_Blocking(socketHandle, overlapped, flags, reserved);
return disconnectEx(socketHandle, null, flags, reserved);
}

partial void WildcardBindForConnectIfNecessary(AddressFamily addressFamily)
Expand Down Expand Up @@ -257,32 +250,28 @@ internal unsafe bool ConnectEx(SafeSocketHandle socketHandle,
out int bytesSent,
NativeOverlapped* overlapped)
{
EnsureDynamicWinsockMethods();
ConnectExDelegate connectEx = _dynamicWinsockMethods!.GetDelegate<ConnectExDelegate>(socketHandle);
ConnectExDelegate connectEx = GetDynamicWinsockMethods().GetConnectExDelegate(socketHandle);

return connectEx(socketHandle, socketAddress, socketAddressSize, buffer, dataLength, out bytesSent, overlapped);
}

internal unsafe SocketError WSARecvMsg(SafeSocketHandle socketHandle, IntPtr msg, out int bytesTransferred, NativeOverlapped* overlapped, IntPtr completionRoutine)
{
EnsureDynamicWinsockMethods();
WSARecvMsgDelegate recvMsg = _dynamicWinsockMethods!.GetDelegate<WSARecvMsgDelegate>(socketHandle);
WSARecvMsgDelegate recvMsg = GetDynamicWinsockMethods().GetWSARecvMsgDelegate(socketHandle);

return recvMsg(socketHandle, msg, out bytesTransferred, overlapped, completionRoutine);
}

internal SocketError WSARecvMsgBlocking(SafeSocketHandle socketHandle, IntPtr msg, out int bytesTransferred, IntPtr overlapped, IntPtr completionRoutine)
internal unsafe SocketError WSARecvMsgBlocking(SafeSocketHandle socketHandle, IntPtr msg, out int bytesTransferred)
{
EnsureDynamicWinsockMethods();
WSARecvMsgDelegateBlocking recvMsg_Blocking = _dynamicWinsockMethods!.GetDelegate<WSARecvMsgDelegateBlocking>(_handle);
WSARecvMsgDelegate recvMsg = GetDynamicWinsockMethods().GetWSARecvMsgDelegate(_handle);

return recvMsg_Blocking(socketHandle, msg, out bytesTransferred, overlapped, completionRoutine);
return recvMsg(socketHandle, msg, out bytesTransferred, null, IntPtr.Zero);
}

internal unsafe bool TransmitPackets(SafeSocketHandle socketHandle, IntPtr packetArray, int elementCount, int sendSize, NativeOverlapped* overlapped, TransmitFileOptions flags)
{
EnsureDynamicWinsockMethods();
TransmitPacketsDelegate transmitPackets = _dynamicWinsockMethods!.GetDelegate<TransmitPacketsDelegate>(socketHandle);
TransmitPacketsDelegate transmitPackets = GetDynamicWinsockMethods().GetTransmitPacketsDelegate(socketHandle);

return transmitPackets(socketHandle, packetArray, elementCount, sendSize, overlapped, flags);
}
Expand Down
Loading

0 comments on commit c5b6881

Please sign in to comment.