Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/7.0] Fix SslStream.IsMutuallyAuthenticated #95733

Merged
merged 5 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/libraries/Common/src/Interop/Windows/SspiCli/Interop.SSPI.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ internal enum ContextAttribute
SECPKG_ATTR_ISSUER_LIST_EX = 0x59, // returns SecPkgContext_IssuerListInfoEx
SECPKG_ATTR_CLIENT_CERT_POLICY = 0x60, // sets SecPkgCred_ClientCertCtlPolicy
SECPKG_ATTR_CONNECTION_INFO = 0x5A, // returns SecPkgContext_ConnectionInfo
SECPKG_ATTR_SESSION_INFO = 0x5D, // sets SecPkgContext_SessionInfo
SECPKG_ATTR_CIPHER_INFO = 0x64, // returns SecPkgContext_CipherInfo
SECPKG_ATTR_REMOTE_CERT_CHAIN = 0x67, // returns PCCERT_CONTEXT
SECPKG_ATTR_UI_INFO = 0x68, // sets SEcPkgContext_UiInfo
Expand Down Expand Up @@ -249,7 +250,7 @@ public enum Flags
SCH_CRED_IGNORE_REVOCATION_OFFLINE = 0x1000,
SCH_CRED_CACHE_ONLY_URL_RETRIEVAL_ON_CREATE = 0x2000,
SCH_SEND_ROOT_CERT = 0x40000,
SCH_SEND_AUX_RECORD = 0x00200000,
SCH_SEND_AUX_RECORD = 0x00200000,
SCH_USE_STRONG_CRYPTO = 0x00400000,
SCH_USE_PRESHAREDKEY_ONLY = 0x800000,
SCH_ALLOW_NULL_ENCRYPTION = 0x02000000,
Expand Down Expand Up @@ -334,6 +335,21 @@ internal unsafe struct SecPkgCred_ClientCertPolicy
public char* pwszSslCtlIdentifier;
}

[StructLayout(LayoutKind.Sequential)]
internal unsafe struct SecPkgContext_SessionInfo
{
public uint dwFlags;
public uint cbSessionId;
public fixed byte rgbSessionId[32];

[Flags]
public enum Flags
{
Zero = 0,
SSL_SESSION_RECONNECT = 0x01,
};
}

[LibraryImport(Interop.Libraries.SspiCli, SetLastError = true)]
internal static partial int EncryptMessage(
ref CredHandle contextHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ private static bool QueryCertContextAttribute(ISSPIInterface secModule, SafeDele
public static bool QueryContextAttributes_SECPKG_ATTR_REMOTE_CERT_CONTEXT(ISSPIInterface secModule, SafeDeleteContext securityContext, out SafeFreeCertContext? certContext)
=> QueryCertContextAttribute(secModule, securityContext, Interop.SspiCli.ContextAttribute.SECPKG_ATTR_REMOTE_CERT_CONTEXT, out certContext);

public static bool QueryContextAttributes_SECPKG_ATTR_LOCAL_CERT_CONTEXT(ISSPIInterface secModule, SafeDeleteContext securityContext, out SafeFreeCertContext? certContext)
=> QueryCertContextAttribute(secModule, securityContext, Interop.SspiCli.ContextAttribute.SECPKG_ATTR_LOCAL_CERT_CONTEXT, out certContext);

public static bool QueryContextAttributes_SECPKG_ATTR_REMOTE_CERT_CHAIN(ISSPIInterface secModule, SafeDeleteContext securityContext, out SafeFreeCertContext? certContext)
=> QueryCertContextAttribute(secModule, securityContext, Interop.SspiCli.ContextAttribute.SECPKG_ATTR_REMOTE_CERT_CHAIN, out certContext);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Globalization;
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
using System.Security.Authentication.ExtendedProtection;
using Microsoft.Win32.SafeHandles;

Expand Down Expand Up @@ -310,10 +311,15 @@ public static unsafe int AcquireCredentialsHandle(

internal sealed class SafeFreeCredential_SECURITY : SafeFreeCredentials
{
#pragma warning disable 0649
// This is used only by SslStream but it is included elsewhere
public X509Certificate? LocalCertificate;
#pragma warning restore 0649
public SafeFreeCredential_SECURITY() : base() { }

protected override bool ReleaseHandle()
{
LocalCertificate?.Dispose();
rzikm marked this conversation as resolved.
Show resolved Hide resolved
return Interop.SspiCli.FreeCredentialsHandle(ref _handle) == 0;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal static SslPolicyErrors VerifyCertificateProperties(
string? hostName)
{
if (remoteCertificate == null)
return SslPolicyErrors.RemoteCertificateNotAvailable;
return SslPolicyErrors.RemoteCertificateNotAvailable;

SslPolicyErrors errors = chain.Build(remoteCertificate)
? SslPolicyErrors.None
Expand Down Expand Up @@ -91,6 +91,10 @@ internal static SslPolicyErrors VerifyCertificateProperties(
return cert;
}

// This is only called when we selected local client certificate.
// Currently this is only when Java crypto asked for it.
internal static bool IsLocalCertificateUsed(SafeFreeCredentials? _1, SafeDeleteContext? _2) => true;
stephentoub marked this conversation as resolved.
Show resolved Hide resolved

//
// Used only by client SSL code, never returns null.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ internal static SslPolicyErrors VerifyCertificateProperties(
return result;
}

// This is only called when we selected local client certificate.
// Currently this is only when Apple crypto asked for it.
internal static bool IsLocalCertificateUsed(SafeFreeCredentials? _1, SafeDeleteContext? _2) => true;

//
// Used only by client SSL code, never returns null.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ internal static SslPolicyErrors VerifyCertificateProperties(
return result;
}

// This is only called when we selected local client certificate.
// Currently this is only when OpenSSL needs it because peer asked.
internal static bool IsLocalCertificateUsed(SafeFreeCredentials? _1, SafeDeleteContext? _2) => true;

//
// Used only by client SSL code, never returns null.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Security.Principal;
using static Interop.SspiCli;

namespace System.Net
{
Expand Down Expand Up @@ -89,6 +90,44 @@ internal static SslPolicyErrors VerifyCertificateProperties(
return result;
}

// Check that local certificate was used by schannel.
internal static bool IsLocalCertificateUsed(SafeFreeCredentials? credentialsHandle, SafeDeleteContext securityContext)
{
SecPkgContext_SessionInfo info = default;
rzikm marked this conversation as resolved.
Show resolved Hide resolved

// fails on Server 2008 and older. We will fall-back to probing LOCAL_CERT_CONTEXT in that case.
if (SSPIWrapper.QueryBlittableContextAttributes(
GlobalSSPI.SSPISecureChannel,
securityContext,
Interop.SspiCli.ContextAttribute.SECPKG_ATTR_SESSION_INFO,
ref info) &&
((SecPkgContext_SessionInfo.Flags)info.dwFlags).HasFlag(SecPkgContext_SessionInfo.Flags.SSL_SESSION_RECONNECT))
{
// This is TLS Resumed session. Windows can fail to query the local cert bellow.
// Instead, we will determine the usage form used credentials.
SafeFreeCredential_SECURITY creds = (SafeFreeCredential_SECURITY)credentialsHandle!;
return creds.LocalCertificate != null;
}

SafeFreeCertContext? localContext = null;
try
{
if (SSPIWrapper.QueryContextAttributes_SECPKG_ATTR_LOCAL_CERT_CONTEXT(GlobalSSPI.SSPISecureChannel, securityContext, out localContext) &&
localContext != null)
{
return !localContext.IsInvalid;
}
}
finally
{
localContext?.Dispose();
}

// Some older Windows do not support that. This is only called when client certificate was provided
// so assume it was for a reason.
return true;
}

//
// Used only by client SSL code, never returns null.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ public partial class SslStream
private int _trailerSize = 16;
private int _maxDataSize = 16354;

private bool _refreshCredentialNeeded = true;

private static readonly Oid s_serverAuthOid = new Oid("1.3.6.1.5.5.7.3.1", "1.3.6.1.5.5.7.3.1");
private static readonly Oid s_clientAuthOid = new Oid("1.3.6.1.5.5.7.3.2", "1.3.6.1.5.5.7.3.2");

Expand All @@ -56,7 +54,12 @@ internal X509Certificate? LocalClientCertificate
{
get
{
return _selectedClientCertificate;
if (_selectedClientCertificate != null && CertificateValidationPal.IsLocalCertificateUsed(_credentialsHandle, _securityContext!))
{
return _selectedClientCertificate;
}

return null;
}
}

Expand Down Expand Up @@ -104,11 +107,6 @@ internal bool RemoteCertRequired
}
}

internal void SetRefreshCredentialNeeded()
{
_refreshCredentialNeeded = true;
}

internal void CloseContext()
{
if (!_remoteCertificateExposed)
Expand Down Expand Up @@ -510,15 +508,14 @@ This will not restart a session but helps minimizing the number of handles we cr

--*/

private bool AcquireClientCredentials(ref byte[]? thumbPrint)
private bool AcquireClientCredentials(ref byte[]? thumbPrint, bool newCredentialsRequested = false)
{
// Acquire possible Client Certificate information and set it on the handle.

bool sessionRestartAttempt; // If true and no cached creds we will use anonymous creds.
bool cachedCred = false; // this is a return result from this method.

X509Certificate2? selectedCert = SelectClientCertificate(out sessionRestartAttempt);

try
{
// Try to locate cached creds first.
Expand Down Expand Up @@ -576,7 +573,7 @@ private bool AcquireClientCredentials(ref byte[]? thumbPrint)
_sslAuthenticationOptions.CertificateContext = SslStreamCertificateContext.Create(selectedCert!);
}

_credentialsHandle = AcquireCredentialsHandle(_sslAuthenticationOptions);
_credentialsHandle = AcquireCredentialsHandle(_sslAuthenticationOptions, newCredentialsRequested);
thumbPrint = guessedThumbPrint; // Delay until here in case something above threw.
}
}
Expand Down Expand Up @@ -687,9 +684,9 @@ private bool AcquireServerCredentials(ref byte[]? thumbPrint)
return cachedCred;
}

private static SafeFreeCredentials? AcquireCredentialsHandle(SslAuthenticationOptions sslAuthenticationOptions)
private static SafeFreeCredentials? AcquireCredentialsHandle(SslAuthenticationOptions sslAuthenticationOptions, bool newCredentialsRequested = false)
{
SafeFreeCredentials? cred = SslStreamPal.AcquireCredentialsHandle(sslAuthenticationOptions);
SafeFreeCredentials? cred = SslStreamPal.AcquireCredentialsHandle(sslAuthenticationOptions, newCredentialsRequested);

if (sslAuthenticationOptions.CertificateContext != null && cred != null)
{
Expand Down Expand Up @@ -749,7 +746,6 @@ internal ProtocolToken NextMessage(ReadOnlySpan<byte> incomingBuffer)
if (NetEventSource.Log.IsEnabled())
NetEventSource.Info(this, "NextMessage() returned SecurityStatusPal.CredentialsNeeded");

SetRefreshCredentialNeeded();
status = GenerateToken(incomingBuffer, ref nextmsg);
}

Expand Down Expand Up @@ -788,6 +784,11 @@ private SecurityStatusPal GenerateToken(ReadOnlySpan<byte> inputBuffer, ref byte
bool sendTrustList = false;
byte[]? thumbPrint = null;

// We need to try get credentials at the beginning.
// _credentialsHandle may be always null on some platforms but
// _securityContext will be allocated on first call.
bool refreshCredentialNeeded = _securityContext == null;

//
// Looping through ASC or ISC with potentially cached credential that could have been
// already disposed from a different thread before ISC or ASC dir increment a cred ref count.
Expand All @@ -797,7 +798,7 @@ private SecurityStatusPal GenerateToken(ReadOnlySpan<byte> inputBuffer, ref byte
do
{
thumbPrint = null;
if (_refreshCredentialNeeded)
if (refreshCredentialNeeded)
{
cachedCreds = _sslAuthenticationOptions.IsServer
? AcquireServerCredentials(ref thumbPrint)
Expand Down Expand Up @@ -826,15 +827,31 @@ private SecurityStatusPal GenerateToken(ReadOnlySpan<byte> inputBuffer, ref byte
_sslAuthenticationOptions,
SelectClientCertificate
);

if (status.ErrorCode == SecurityStatusPalErrorCode.CredentialsNeeded)
{
refreshCredentialNeeded = true;
cachedCreds = AcquireClientCredentials(ref thumbPrint, newCredentialsRequested: true);

if (NetEventSource.Log.IsEnabled())
NetEventSource.Info(this, "InitializeSecurityContext() returned 'CredentialsNeeded'.");

status = SslStreamPal.InitializeSecurityContext(
ref _credentialsHandle!,
ref _securityContext,
_sslAuthenticationOptions.TargetHost,
inputBuffer,
ref result,
_sslAuthenticationOptions,
SelectClientCertificate);
}
}
} while (cachedCreds && _credentialsHandle == null);
}
finally
{
if (_refreshCredentialNeeded)
if (refreshCredentialNeeded)
{
_refreshCredentialNeeded = false;

//
// Assuming the ISC or ASC has referenced the credential,
// we want to call dispose so to decrement the effective ref count.
Expand Down Expand Up @@ -974,7 +991,6 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
}

_remoteCertificate = certificate;

if (_remoteCertificate == null)
{
if (NetEventSource.Log.IsEnabled() && RemoteCertRequired) NetEventSource.Error(this, $"Remote certificate required, but no remote certificate received");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public static SecurityStatusPal Renegotiate(
throw new PlatformNotSupportedException();
}

public static SafeFreeCredentials? AcquireCredentialsHandle(SslAuthenticationOptions sslAuthenticationOptions)
public static SafeFreeCredentials? AcquireCredentialsHandle(SslAuthenticationOptions _1, bool _2)
{
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public static SecurityStatusPal Renegotiate(
throw new PlatformNotSupportedException();
}

public static SafeFreeCredentials? AcquireCredentialsHandle(SslAuthenticationOptions sslAuthenticationOptions)
public static SafeFreeCredentials? AcquireCredentialsHandle(SslAuthenticationOptions _1, bool _2)
{
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public static SecurityStatusPal InitializeSecurityContext(
return HandshakeInternal(ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions, clientCertificateSelectionCallback);
}

public static SafeFreeCredentials? AcquireCredentialsHandle(SslAuthenticationOptions sslAuthenticationOptions)
rzikm marked this conversation as resolved.
Show resolved Hide resolved
public static SafeFreeCredentials? AcquireCredentialsHandle(SslAuthenticationOptions _1, bool _2)
{
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public static SecurityStatusPal Renegotiate(
return status;
}

public static SafeFreeCredentials AcquireCredentialsHandle(SslAuthenticationOptions sslAuthenticationOptions)
public static SafeFreeCredentials AcquireCredentialsHandle(SslAuthenticationOptions sslAuthenticationOptions, bool newCredentialsRequested)
{
try
{
Expand All @@ -156,6 +156,16 @@ public static SafeFreeCredentials AcquireCredentialsHandle(SslAuthenticationOpti
AttachCertificateStore(cred, certificateContext.Trust._store!);
}

// Windows can fail to get local credentials in case of TLS Resume.
// We will store associated certificate in credentials and use it in case
// of TLS resume. It will be disposed when the credentials are.
if (newCredentialsRequested && sslAuthenticationOptions.CertificateContext != null)
{
SafeFreeCredential_SECURITY handle = (SafeFreeCredential_SECURITY)cred;
// We need to create copy to avoid Disposal issue.
handle.LocalCertificate = new X509Certificate2(sslAuthenticationOptions.CertificateContext.Certificate);
}

return cred;
}
catch (Win32Exception e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,16 @@ public async Task CertificateValidationClientServer_EndToEnd_Ok(bool useClientSe
clientCerts.Add(_clientCertificate);
}

Task clientAuthentication = sslClientStream.AuthenticateAsClientAsync(
serverName,
clientCerts,
SslProtocolSupport.DefaultSslProtocols,
false);
// Connect to GUID to prevent TLS resume
var options = new SslClientAuthenticationOptions()
{
TargetHost = Guid.NewGuid().ToString("N"),
ClientCertificates = clientCerts,
EnabledSslProtocols = SslProtocolSupport.DefaultSslProtocols,
CertificateChainPolicy = new X509ChainPolicy(),
};
options.CertificateChainPolicy.VerificationFlags = X509VerificationFlags.IgnoreInvalidName;
Task clientAuthentication = sslClientStream.AuthenticateAsClientAsync(options, default);

Task serverAuthentication = sslServerStream.AuthenticateAsServerAsync(
_serverCertificate,
Expand Down Expand Up @@ -258,7 +263,6 @@ private bool ClientSideRemoteServerCertificateValidation(object sender, X509Cert

Assert.Equal(expectedSslPolicyErrors, sslPolicyErrors);
Assert.Equal(_serverCertificate, certificate);

return true;
}

Expand Down
Loading
Loading