Skip to content

Commit

Permalink
tlsconfig: Add support for configuring use of PQ KEMs during TLS hand…
Browse files Browse the repository at this point in the history
…shake

Signed-off-by: Hugo Landau <[email protected]>
  • Loading branch information
hlandau committed Aug 7, 2024
1 parent 51299b0 commit b813997
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 9 deletions.
83 changes: 82 additions & 1 deletion v2/spiffetls/tlsconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func HookTLSClientConfig(config *tls.Config, bundle x509bundle.Source, authorize
resetAuthFields(config)
config.InsecureSkipVerify = true
config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer, opts...)
applyOptions(config, opts)
}

// A Option changes the defaults used to by mTLS ClientConfig functions.
Expand All @@ -36,7 +37,8 @@ type option func(*options)
func (fn option) apply(o *options) { fn(o) }

type options struct {
trace Trace
trace Trace
pqKemMode PQKEMMode
}

func newOptions(opts []Option) *options {
Expand All @@ -55,6 +57,38 @@ func WithTrace(trace Trace) Option {
})
}

// Post-quantum TLS KEM mode. Determines whether a post-quantum safe KEM should
// be used when establishing a TLS connection.
type PQKEMMode int

const (
// Do not require use of a post-quantum KEM when establishing a TLS
// connection. Whether a post-quantum KEM is attempted depends on
// environmental configuration (e.g. GODEBUG setting tlskyber) and the target
// Go version at build time.
PQKEMModeDefault PQKEMMode = iota

// Attempt use of a post-quantum KEM as the most preferred key exchange
// method when establishing a TLS connection.
// Support for this requires Go 1.23 or later.
// Configuring this will cause connections to fail if support is not available.
PQKEMModeAttempt

// Require use of a post-quantum KEM when establishing a TLS connection.
// Attempts to initiate a connection with a key exchange method which is not
// post-quantum safe will fail. Support for this requires Go 1.23 or later.
// Configuring this will cause connections to fail if support is not available.
PQKEMModeRequire
)

// WithPQKEMMode configures whether a post-quantum safe KEM should be used when
// establishing a TLS connection.
func WithPQKEMMode(mode PQKEMMode) Option {
return option(func(opts *options) {
opts.pqKemMode = mode
})
}

// MTLSClientConfig returns a TLS configuration which presents an X509-SVID
// to the server and verifies and authorizes the server X509-SVID.
func MTLSClientConfig(svid x509svid.Source, bundle x509bundle.Source, authorizer Authorizer, opts ...Option) *tls.Config {
Expand All @@ -72,6 +106,7 @@ func HookMTLSClientConfig(config *tls.Config, svid x509svid.Source, bundle x509b
config.GetClientCertificate = GetClientCertificate(svid, opts...)
config.InsecureSkipVerify = true
config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer, opts...)
applyOptions(config, opts)
}

// MTLSWebClientConfig returns a TLS configuration which presents an X509-SVID
Expand All @@ -90,6 +125,7 @@ func HookMTLSWebClientConfig(config *tls.Config, svid x509svid.Source, roots *x5
resetAuthFields(config)
config.GetClientCertificate = GetClientCertificate(svid, opts...)
config.RootCAs = roots
applyOptions(config, opts)
}

// TLSServerConfig returns a TLS configuration which presents an X509-SVID
Expand All @@ -105,6 +141,7 @@ func TLSServerConfig(svid x509svid.Source, opts ...Option) *tls.Config {
func HookTLSServerConfig(config *tls.Config, svid x509svid.Source, opts ...Option) {
resetAuthFields(config)
config.GetCertificate = GetCertificate(svid, opts...)
applyOptions(config, opts)
}

// MTLSServerConfig returns a TLS configuration which presents an X509-SVID
Expand All @@ -125,6 +162,7 @@ func HookMTLSServerConfig(config *tls.Config, svid x509svid.Source, bundle x509b
config.ClientAuth = tls.RequireAnyClientCert
config.GetCertificate = GetCertificate(svid, opts...)
config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer, opts...)
applyOptions(config, opts)
}

// MTLSWebServerConfig returns a TLS configuration which presents a web
Expand All @@ -146,6 +184,7 @@ func HookMTLSWebServerConfig(config *tls.Config, cert *tls.Certificate, bundle x
config.ClientAuth = tls.RequireAnyClientCert
config.Certificates = []tls.Certificate{*cert}
config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer, opts...)
applyOptions(config, opts)
}

// GetCertificate returns a GetCertificate callback for tls.Config. It uses the
Expand Down Expand Up @@ -252,3 +291,45 @@ func resetAuthFields(config *tls.Config) {
config.NameToCertificate = nil //nolint:staticcheck // setting to nil is OK
config.RootCAs = nil
}

// Not exported by crypto/tls, so we define it here from the I-D.
const x25519Kyber768Draft00 tls.CurveID = 0x6399

func applyOptions(config *tls.Config, opts []Option) {
o := newOptions(opts)

// Apply post-quantum KEM mode option.
switch o.pqKemMode {
case PQKEMModeDefault:
// Nothing to do - allow default curve preferences.

case PQKEMModeAttempt:
if len(config.CurvePreferences) == 0 {
// This is copied from the crypto/tls default curve list.
config.CurvePreferences = []tls.CurveID{
x25519Kyber768Draft00,
tls.X25519,
tls.CurveP256,
tls.CurveP384,
tls.CurveP521,
}
} else if config.CurvePreferences[0] != x25519Kyber768Draft00 {
// Prepend X25519Kyber768Draft00 to the list, making it most preferred.
curves := make([]tls.CurveID, 0, len(config.CurvePreferences)+1)
curves = append(curves, x25519Kyber768Draft00)
curves = append(curves, config.CurvePreferences...)
config.CurvePreferences = curves
}

case PQKEMModeRequire:
// List only known PQ-safe KEMs as valid curves.
config.CurvePreferences = []tls.CurveID{
x25519Kyber768Draft00,
}

// Require TLS 1.3, as all PQ-safe KEMs require it anyway.
if config.MinVersion < tls.VersionTLS13 {
config.MinVersion = tls.VersionTLS13
}
}
}
134 changes: 128 additions & 6 deletions v2/spiffetls/tlsconfig/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ func TestTLSClientConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Nil(t, config.RootCAs)
assert.NotNil(t, config.VerifyPeerCertificate)
assert.Nil(t, config.CurvePreferences)

config = tlsconfig.TLSClientConfig(bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire))
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestHookTLSClientConfig(t *testing.T) {
Expand All @@ -62,7 +67,15 @@ func TestHookTLSClientConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Nil(t, config.RootCAs)
assert.NotNil(t, config.VerifyPeerCertificate)
assert.Equal(t, []tls.CurveID{tls.CurveP256}, config.CurvePreferences)
assertUnrelatedFieldsUntouched(t, base, config)

base = createBaseTLSConfig()
config = createTestTLSConfig(base)

tlsconfig.HookTLSClientConfig(config, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt))
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestMTLSClientConfig(t *testing.T) {
Expand All @@ -82,6 +95,12 @@ func TestMTLSClientConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Nil(t, config.RootCAs)
assert.NotNil(t, config.VerifyPeerCertificate)
assert.Nil(t, config.CurvePreferences)

config = tlsconfig.MTLSClientConfig(svid, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithTrace(localTrace),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt))
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestHookMTLSClientConfig(t *testing.T) {
Expand All @@ -103,7 +122,16 @@ func TestHookMTLSClientConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Nil(t, config.RootCAs)
assert.NotNil(t, config.VerifyPeerCertificate)
assert.Equal(t, []tls.CurveID{tls.CurveP256}, config.CurvePreferences)
assertUnrelatedFieldsUntouched(t, base, config)

base = createBaseTLSConfig()
config = createTestTLSConfig(base)

tlsconfig.HookMTLSClientConfig(config, svid, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithTrace(localTrace),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt))
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestMTLSWebClientConfig(t *testing.T) {
Expand All @@ -122,6 +150,12 @@ func TestMTLSWebClientConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Equal(t, roots, config.RootCAs)
assert.Nil(t, config.VerifyPeerCertificate)
assert.Nil(t, config.CurvePreferences)

config = tlsconfig.MTLSWebClientConfig(svid, roots,
tlsconfig.WithTrace(localTrace),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt))
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestHookMTLSWebClientConfig(t *testing.T) {
Expand All @@ -143,7 +177,17 @@ func TestHookMTLSWebClientConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Equal(t, roots, config.RootCAs)
assert.Nil(t, config.VerifyPeerCertificate)
assert.Equal(t, []tls.CurveID{tls.CurveP256}, config.CurvePreferences)
assertUnrelatedFieldsUntouched(t, base, config)

base = createBaseTLSConfig()
config = createTestTLSConfig(base)
tlsconfig.HookMTLSWebClientConfig(config, svid, roots,
tlsconfig.WithTrace(localTrace),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt),
)

assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestTLSServerConfig(t *testing.T) {
Expand All @@ -161,6 +205,13 @@ func TestTLSServerConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Nil(t, config.RootCAs)
assert.Nil(t, config.VerifyPeerCertificate)
assert.Nil(t, config.CurvePreferences)

config = tlsconfig.TLSServerConfig(svid,
tlsconfig.WithTrace(localTrace),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire),
)
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestHookTLSServerConfig(t *testing.T) {
Expand All @@ -181,6 +232,16 @@ func TestHookTLSServerConfig(t *testing.T) {
assert.Nil(t, config.RootCAs)
assert.Nil(t, config.VerifyPeerCertificate)
assertUnrelatedFieldsUntouched(t, base, config)
assert.Equal(t, []tls.CurveID{tls.CurveP256}, config.CurvePreferences)

base = createBaseTLSConfig()
config = createTestTLSConfig(base)

tlsconfig.HookTLSServerConfig(config, svid,
tlsconfig.WithTrace(localTrace),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt),
)
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestMTLSServerConfig(t *testing.T) {
Expand All @@ -200,6 +261,13 @@ func TestMTLSServerConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Nil(t, config.RootCAs)
assert.NotNil(t, config.VerifyPeerCertificate)
assert.Nil(t, config.CurvePreferences)

config = tlsconfig.MTLSServerConfig(svid, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithTrace(localTrace),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire),
)
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestHookMTLSServerConfig(t *testing.T) {
Expand All @@ -221,7 +289,17 @@ func TestHookMTLSServerConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Nil(t, config.RootCAs)
assert.NotNil(t, config.VerifyPeerCertificate)
assert.Equal(t, []tls.CurveID{tls.CurveP256}, config.CurvePreferences)
assertUnrelatedFieldsUntouched(t, base, config)

base = createBaseTLSConfig()
config = createTestTLSConfig(base)

tlsconfig.HookMTLSServerConfig(config, svid, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithTrace(localTrace),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire),
)
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestMTLSWebServerConfig(t *testing.T) {
Expand All @@ -239,6 +317,11 @@ func TestMTLSWebServerConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Nil(t, config.RootCAs)
assert.NotNil(t, config.VerifyPeerCertificate)
assert.Nil(t, config.CurvePreferences)

config = tlsconfig.MTLSWebServerConfig(tlsCert, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire))
assert.Greater(t, len(config.CurvePreferences), 0)
}

func TestHookMTLSWebServerConfig(t *testing.T) {
Expand All @@ -258,7 +341,15 @@ func TestHookMTLSWebServerConfig(t *testing.T) {
assert.Nil(t, config.NameToCertificate) //nolint:staticcheck // setting to nil is OK
assert.Nil(t, config.RootCAs)
assert.NotNil(t, config.VerifyPeerCertificate)
assert.Equal(t, []tls.CurveID{tls.CurveP256}, config.CurvePreferences)
assertUnrelatedFieldsUntouched(t, base, config)

base = createBaseTLSConfig()
config = createTestTLSConfig(base)

tlsconfig.HookMTLSWebServerConfig(config, tlsCert, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire))
assert.Greater(t, len(config.CurvePreferences), 0)
}

func hookedTracer(onGetCertificate, onGotCertificate func()) tlsconfig.Trace {
Expand Down Expand Up @@ -516,6 +607,10 @@ func TestWrapVerifyPeerCertificate(t *testing.T) {
}
}

func isPQKEMSupported(t *testing.T) bool {
return supportsPQKEM
}

func TestTLSHandshake(t *testing.T) {
td := spiffeid.RequireTrustDomainFromString("domain1.test")
ca1 := test.NewCA(t, td)
Expand All @@ -532,17 +627,42 @@ func TestTLSHandshake(t *testing.T) {
bundle3 := ca3.Bundle()

testCases := []struct {
name string
serverConfig *tls.Config
clientConfig *tls.Config
clientErr string
serverErr string
name string
serverConfig *tls.Config
clientConfig *tls.Config
clientErr string
serverErr string
shouldRunFunc func(t *testing.T) bool
}{
{
name: "success",
serverConfig: tlsconfig.TLSServerConfig(serverSVID),
clientConfig: tlsconfig.TLSClientConfig(bundle1, tlsconfig.AuthorizeAny()),
},
{
name: "success (PQ KEM attempted)",
serverConfig: tlsconfig.TLSServerConfig(serverSVID, tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt)),
clientConfig: tlsconfig.TLSClientConfig(bundle1, tlsconfig.AuthorizeAny(), tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt)),
shouldRunFunc: isPQKEMSupported,
},
{
name: "success (PQ KEM required by client)",
serverConfig: tlsconfig.TLSServerConfig(serverSVID, tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt)),
clientConfig: tlsconfig.TLSClientConfig(bundle1, tlsconfig.AuthorizeAny(), tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire)),
shouldRunFunc: isPQKEMSupported,
},
{
name: "success (PQ KEM required by server)",
serverConfig: tlsconfig.TLSServerConfig(serverSVID, tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire)),
clientConfig: tlsconfig.TLSClientConfig(bundle1, tlsconfig.AuthorizeAny(), tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeAttempt)),
shouldRunFunc: isPQKEMSupported,
},
{
name: "success (PQ KEM mutually required)",
serverConfig: tlsconfig.TLSServerConfig(serverSVID, tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire)),
clientConfig: tlsconfig.TLSClientConfig(bundle1, tlsconfig.AuthorizeAny(), tlsconfig.WithPQKEMMode(tlsconfig.PQKEMModeRequire)),
shouldRunFunc: isPQKEMSupported,
},
{
name: "authentication fails",
serverConfig: tlsconfig.TLSServerConfig(serverSVID),
Expand All @@ -569,7 +689,9 @@ func TestTLSHandshake(t *testing.T) {
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
testConnection(t, testCase.serverConfig, testCase.clientConfig, testCase.serverErr, testCase.clientErr)
if testCase.shouldRunFunc == nil || testCase.shouldRunFunc(t) {
testConnection(t, testCase.serverConfig, testCase.clientConfig, testCase.serverErr, testCase.clientErr)
}
})
}
}
Expand Down
Loading

0 comments on commit b813997

Please sign in to comment.