Skip to content

Commit

Permalink
avoid allocations for TLS handshake (#87874)
Browse files Browse the repository at this point in the history
* avoid allocations for TLS handshake

* cleanup

* feedback

* update

* fixes

* cleanup

* feedback
  • Loading branch information
wfurt authored Dec 6, 2023
1 parent 9e4efd4 commit 71f3bf1
Show file tree
Hide file tree
Showing 18 changed files with 401 additions and 413 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ internal static unsafe int BioGets(SafeBioHandle b, Span<byte> buf)
[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioRead")]
internal static partial int BioRead(SafeBioHandle b, byte[] data, int len);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioRead")]
private static partial int BioRead(SafeBioHandle b, Span<byte> data, int len);
internal static int BioRead(SafeBioHandle b, Span<byte> data) => BioRead(b, data, data.Length);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioWrite")]
internal static partial int BioWrite(SafeBioHandle b, byte[] data, int len);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,9 @@ internal static SecurityStatusPal SslRenegotiate(SafeSslHandle sslContext, out b
return new SecurityStatusPal(SecurityStatusPalErrorCode.OK);
}

internal static SecurityStatusPalErrorCode DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> input, out byte[]? sendBuf, out int sendCount)
internal static SecurityStatusPalErrorCode DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> input, ref ProtocolToken token)
{
sendBuf = null;
sendCount = 0;
token.Size = 0;
Exception? handshakeException = null;

if (input.Length > 0)
Expand Down Expand Up @@ -524,14 +523,13 @@ internal static SecurityStatusPalErrorCode DoSslHandshake(SafeSslHandle context,
}
}

sendCount = Crypto.BioCtrlPending(context.OutputBio!);
int sendCount = Crypto.BioCtrlPending(context.OutputBio!);
if (sendCount > 0)
{
sendBuf = new byte[sendCount];

token.EnsureAvailableSpace(sendCount);
try
{
sendCount = BioRead(context.OutputBio!, sendBuf, sendCount);
sendCount = BioRead(context.OutputBio!, token.AvailableSpan, sendCount);
}
catch (Exception) when (handshakeException != null)
{
Expand All @@ -543,12 +541,13 @@ internal static SecurityStatusPalErrorCode DoSslHandshake(SafeSslHandle context,
{
// Make sure we clear out the error that is stored in the queue
Crypto.ErrClearError();
sendBuf = null;
sendCount = 0;
}
}
}

token.Size = sendCount;

if (handshakeException != null)
{
throw handshakeException;
Expand All @@ -563,14 +562,13 @@ internal static SecurityStatusPalErrorCode DoSslHandshake(SafeSslHandle context,
return stateOk ? SecurityStatusPalErrorCode.OK : SecurityStatusPalErrorCode.ContinueNeeded;
}

internal static int Encrypt(SafeSslHandle context, ReadOnlySpan<byte> input, ref byte[] output, out Ssl.SslErrorCode errorCode)
internal static Ssl.SslErrorCode Encrypt(SafeSslHandle context, ReadOnlySpan<byte> input, ref ProtocolToken outToken)
{
int retVal = Ssl.SslWrite(context, ref MemoryMarshal.GetReference(input), input.Length, out errorCode);
int retVal = Ssl.SslWrite(context, ref MemoryMarshal.GetReference(input), input.Length, out Ssl.SslErrorCode errorCode);

if (retVal != input.Length)
{
retVal = 0;

outToken.Size = 0;
switch (errorCode)
{
// indicate end-of-file
Expand All @@ -585,22 +583,22 @@ internal static int Encrypt(SafeSslHandle context, ReadOnlySpan<byte> input, ref
else
{
int capacityNeeded = Crypto.BioCtrlPending(context.OutputBio!);

if (output == null || output.Length < capacityNeeded)
{
output = new byte[capacityNeeded];
}

retVal = BioRead(context.OutputBio!, output, capacityNeeded);
outToken.EnsureAvailableSpace(capacityNeeded);
retVal = BioRead(context.OutputBio!, outToken.AvailableSpan, capacityNeeded);

if (retVal <= 0)
{
// Make sure we clear out the error that is stored in the queue
Crypto.ErrClearError();
outToken.Size = 0;
}
else
{
outToken.Size = retVal;
}
}

return retVal;
return errorCode;
}

internal static int Decrypt(SafeSslHandle context, Span<byte> buffer, out Ssl.SslErrorCode errorCode)
Expand Down Expand Up @@ -811,13 +809,12 @@ private static unsafe void KeyLogCallback(IntPtr ssl, char* line)
}
#endif

private static int BioRead(SafeBioHandle bio, byte[] buffer, int count)
private static int BioRead(SafeBioHandle bio, Span<byte> buffer, int count)
{
Debug.Assert(buffer != null);
Debug.Assert(count >= 0);
Debug.Assert(buffer.Length >= count);

int bytes = Crypto.BioRead(bio, buffer, count);
int bytes = Crypto.BioRead(bio, buffer);
if (bytes != count)
{
throw CreateSslException(SR.net_ssl_read_bio_failed_error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ internal interface ISSPIInterface
unsafe int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.CredentialUse usage, Interop.SspiCli.SCHANNEL_CRED* authdata, out SafeFreeCredentials outCredential);
unsafe int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.CredentialUse usage, Interop.SspiCli.SCH_CREDENTIALS* authdata, out SafeFreeCredentials outCredential);
int AcquireDefaultCredential(string moduleName, Interop.SspiCli.CredentialUse usage, out SafeFreeCredentials outCredential);
int AcceptSecurityContext(SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags);
int InitializeSecurityContext(ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, string? targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags);
int AcceptSecurityContext(SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref ProtocolToken outToken, ref Interop.SspiCli.ContextFlags outFlags);
int InitializeSecurityContext(ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, string? targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref ProtocolToken outToken, ref Interop.SspiCli.ContextFlags outFlags);
int EncryptMessage(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, uint qop);
int DecryptMessage(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, out uint qop);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ public unsafe int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.Cr
return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, authdata, out outCredential);
}

public int AcceptSecurityContext(SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
public int AcceptSecurityContext(SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref ProtocolToken outToken, ref Interop.SspiCli.ContextFlags outFlags)
{
return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags);
return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffers, ref outToken, ref outFlags);
}

public int InitializeSecurityContext(ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, string? targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
public int InitializeSecurityContext(ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, string? targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref ProtocolToken outToken, ref Interop.SspiCli.ContextFlags outFlags)
{
return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags);
return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffers, ref outToken, ref outFlags);
}

public int EncryptMessage(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, uint qop)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ public unsafe int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.Cr
return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, authdata, out outCredential);
}

public int AcceptSecurityContext(SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
public int AcceptSecurityContext(SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref ProtocolToken outToken, ref Interop.SspiCli.ContextFlags outFlags)
{
return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags);
return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffers, ref outToken, ref outFlags);
}

public int InitializeSecurityContext(ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, string? targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
public int InitializeSecurityContext(ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, string? targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref ProtocolToken outToken, ref Interop.SspiCli.ContextFlags outFlags)
{
return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags);
return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffers, ref outToken, ref outFlags);
}

public int EncryptMessage(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, uint qop)
Expand Down
12 changes: 6 additions & 6 deletions src/libraries/Common/src/Interop/Windows/SspiCli/SSPIWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,24 +141,24 @@ public static unsafe SafeFreeCredentials AcquireCredentialsHandle(ISSPIInterface
return outCredential;
}

internal static int InitializeSecurityContext(ISSPIInterface secModule, ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, string? targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
internal static int InitializeSecurityContext(ISSPIInterface secModule, ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, string? targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, InputSecurityBuffers inputBuffers, ref ProtocolToken outToken, ref Interop.SspiCli.ContextFlags outFlags)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.InitializeSecurityContext(credential, context, targetName, inFlags);

int errorCode = secModule.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, datarep, inputBuffers, ref outputBuffer, ref outFlags);
int errorCode = secModule.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, datarep, inputBuffers, ref outToken, ref outFlags);

if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.SecurityContextInputBuffers(nameof(InitializeSecurityContext), inputBuffers.Count, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode);
if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.SecurityContextInputBuffers(nameof(InitializeSecurityContext), inputBuffers.Count, outToken.Size, (Interop.SECURITY_STATUS)errorCode);

return errorCode;
}

internal static int AcceptSecurityContext(ISSPIInterface secModule, SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
internal static int AcceptSecurityContext(ISSPIInterface secModule, SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, InputSecurityBuffers inputBuffers, ref ProtocolToken outToken, ref Interop.SspiCli.ContextFlags outFlags)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.AcceptSecurityContext(credential, context, inFlags);

int errorCode = secModule.AcceptSecurityContext(credential, ref context, inputBuffers, inFlags, datarep, ref outputBuffer, ref outFlags);
int errorCode = secModule.AcceptSecurityContext(credential, ref context, inputBuffers, inFlags, datarep, ref outToken, ref outFlags);

if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.SecurityContextInputBuffers(nameof(AcceptSecurityContext), inputBuffers.Count, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode);
if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.SecurityContextInputBuffers(nameof(AcceptSecurityContext), inputBuffers.Count, outToken.Size, (Interop.SECURITY_STATUS)errorCode);

return errorCode;
}
Expand Down
Loading

0 comments on commit 71f3bf1

Please sign in to comment.