From 0ca6ea64ee230edb25e288e6e409035316ec2247 Mon Sep 17 00:00:00 2001 From: Joe Williams Date: Tue, 28 Jul 2020 16:03:39 -0700 Subject: [PATCH] tlsconfig trace implementation Signed-off-by: Joe Williams add time to function call typo forgotten imports aybabtme's suggestions function name change clean up use two structs Signed-off-by: Antoine Grondin introduce func. opts. pattern and convert GetCertificate tracing to it Signed-off-by: Antoine Grondin adjust trickle down occurences of tlsconfig.Trace Signed-off-by: Antoine Grondin collapse all option types into 1 for whole tlsconfig package Signed-off-by: Antoine Grondin rework option type and trace mechanism Signed-off-by: Antoine Grondin remove weird go-sum addition Signed-off-by: Antoine Grondin test clean up Signed-off-by: Antoine Grondin Update v2 to beta Signed-off-by: Andrew Harding Signed-off-by: Antoine Grondin Print out the expected peer and domains when encountering mismatches Signed-off-by: Kyle Anderson Signed-off-by: Antoine Grondin assert that GetCertificate gets called as expected Signed-off-by: Antoine Grondin pass config.Option to every func Signed-off-by: Antoine Grondin pluralize options funcs Signed-off-by: Antoine Grondin --- README.md | 8 +- spiffe/expect.go | 6 +- spiffe/expect_test.go | 6 +- spiffe/tls_verify_test.go | 2 +- v2/spiffetls/dial.go | 4 +- v2/spiffetls/listen.go | 4 +- v2/spiffetls/option.go | 17 ++++ v2/spiffetls/tlsconfig/config.go | 113 ++++++++++++++++++-------- v2/spiffetls/tlsconfig/config_test.go | 75 ++++++++++++++--- v2/spiffetls/tlsconfig/trace.go | 18 ++++ 10 files changed, 193 insertions(+), 60 deletions(-) create mode 100644 v2/spiffetls/tlsconfig/trace.go diff --git a/README.md b/README.md index 1768b973..be26af10 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,16 @@ -# go-spiffe (v1) library [![GoDoc](https://godoc.org/github.com/spiffe/go-spiffe?status.svg)](https://godoc.org/github.com/spiffe/go-spiffe) - # Deprecation Warning __NOTE:__ This version of the library will be deprecated soon. -The new [v2](./v2) module is currently in alpha release and published under +The [v2](./v2) module is in **beta** and published under `github.com/spiffe/go-spiffe/v2`, following go module guidelines. -New code should consider using the `v2` module. +**New code should strongly consider using the `v2` module.** See the [v2 README](./v2) for more details. +# go-spiffe (v1) library [![GoDoc](https://godoc.org/github.com/spiffe/go-spiffe?status.svg)](https://godoc.org/github.com/spiffe/go-spiffe) + ## Overview The go-spiffe project provides two components: diff --git a/spiffe/expect.go b/spiffe/expect.go index 8bb5100f..4a33f86e 100644 --- a/spiffe/expect.go +++ b/spiffe/expect.go @@ -22,7 +22,7 @@ func ExpectAnyPeer() ExpectPeerFunc { func ExpectPeer(expectedID string) ExpectPeerFunc { return func(peerID string, _ [][]*x509.Certificate) error { if peerID != expectedID { - return fmt.Errorf("unexpected peer ID %q", peerID) + return fmt.Errorf("unexpected peer ID %q: expected %q", peerID, expectedID) } return nil } @@ -36,7 +36,7 @@ func ExpectPeers(expectedIDs ...string) ExpectPeerFunc { } return func(peerID string, _ [][]*x509.Certificate) error { if _, ok := m[peerID]; !ok { - return fmt.Errorf("unexpected peer ID %q", peerID) + return fmt.Errorf("unexpected peer ID %q: expected one of %q", peerID, expectedIDs) } return nil } @@ -47,7 +47,7 @@ func ExpectPeers(expectedIDs ...string) ExpectPeerFunc { func ExpectPeerInDomain(expectedDomain string) ExpectPeerFunc { return func(peerID string, _ [][]*x509.Certificate) error { if domain := getPeerTrustDomain(peerID); domain != expectedDomain { - return fmt.Errorf("unexpected peer trust domain %q", domain) + return fmt.Errorf("unexpected trust domain %q for peer ID %q: expected trust domain %q", domain, peerID, expectedDomain) } return nil } diff --git a/spiffe/expect_test.go b/spiffe/expect_test.go index c0ded68f..d44f3c2a 100644 --- a/spiffe/expect_test.go +++ b/spiffe/expect_test.go @@ -18,7 +18,7 @@ func TestExpectPeer(t *testing.T) { expect := ExpectPeer("spiffe://domain.test/workload1") assert.NoError(t, expect("spiffe://domain.test/workload1", nil)) assert.EqualError(t, expect("spiffe://domain.test/workload2", nil), - `unexpected peer ID "spiffe://domain.test/workload2"`) + `unexpected peer ID "spiffe://domain.test/workload2": expected "spiffe://domain.test/workload1"`) } func TestExpectPeers(t *testing.T) { @@ -26,12 +26,12 @@ func TestExpectPeers(t *testing.T) { assert.NoError(t, expect("spiffe://domain.test/workload1", nil)) assert.NoError(t, expect("spiffe://domain.test/workload2", nil)) assert.EqualError(t, expect("spiffe://domain.test/workload3", nil), - `unexpected peer ID "spiffe://domain.test/workload3"`) + `unexpected peer ID "spiffe://domain.test/workload3": expected one of ["spiffe://domain.test/workload1" "spiffe://domain.test/workload2"]`) } func TestExpectPeerInDomain(t *testing.T) { expect := ExpectPeerInDomain("domain1.test") assert.NoError(t, expect("spiffe://domain1.test/workload", nil)) assert.EqualError(t, expect("spiffe://domain2.test/workload", nil), - `unexpected peer trust domain "domain2.test"`) + `unexpected trust domain "domain2.test" for peer ID "spiffe://domain2.test/workload": expected trust domain "domain1.test"`) } diff --git a/spiffe/tls_verify_test.go b/spiffe/tls_verify_test.go index 60d37c6e..92aebc71 100644 --- a/spiffe/tls_verify_test.go +++ b/spiffe/tls_verify_test.go @@ -64,7 +64,7 @@ func TestVerifyPeerCertificate(t *testing.T) { chain: peer1, roots: roots1, expect: ExpectPeer("spiffe://domain2.test/workload"), - err: `unexpected peer ID "spiffe://domain1.test/workload"`, + err: `unexpected peer ID "spiffe://domain1.test/workload": expected "spiffe://domain2.test/workload"`, }, { name: "bad peer id", diff --git a/v2/spiffetls/dial.go b/v2/spiffetls/dial.go index 6068ed73..5c80feaa 100644 --- a/v2/spiffetls/dial.go +++ b/v2/spiffetls/dial.go @@ -59,9 +59,9 @@ func DialWithMode(ctx context.Context, network, addr string, mode DialMode, opti case tlsClientMode: tlsconfig.HookTLSClientConfig(tlsConfig, m.bundle, m.authorizer) case mtlsClientMode: - tlsconfig.HookMTLSClientConfig(tlsConfig, m.svid, m.bundle, m.authorizer) + tlsconfig.HookMTLSClientConfig(tlsConfig, m.svid, m.bundle, m.authorizer, opt.tlsoptions...) case mtlsWebClientMode: - tlsconfig.HookMTLSWebClientConfig(tlsConfig, m.svid, m.roots) + tlsconfig.HookMTLSWebClientConfig(tlsConfig, m.svid, m.roots, opt.tlsoptions...) default: return nil, spiffetlsErr.New("unknown client mode: %v", m.mode) } diff --git a/v2/spiffetls/listen.go b/v2/spiffetls/listen.go index 5e0a0add..95076067 100644 --- a/v2/spiffetls/listen.go +++ b/v2/spiffetls/listen.go @@ -89,9 +89,9 @@ func NewListenerWithMode(ctx context.Context, inner net.Listener, mode ListenMod switch m.mode { case tlsServerMode: - tlsconfig.HookTLSServerConfig(tlsConfig, m.svid) + tlsconfig.HookTLSServerConfig(tlsConfig, m.svid, opt.tlsoptions...) case mtlsServerMode: - tlsconfig.HookMTLSServerConfig(tlsConfig, m.svid, m.bundle, m.authorizer) + tlsconfig.HookMTLSServerConfig(tlsConfig, m.svid, m.bundle, m.authorizer, opt.tlsoptions...) case mtlsWebServerMode: tlsconfig.HookMTLSWebServerConfig(tlsConfig, m.cert, m.bundle, m.authorizer) default: diff --git a/v2/spiffetls/option.go b/v2/spiffetls/option.go index a7dcf8d3..6d51d120 100644 --- a/v2/spiffetls/option.go +++ b/v2/spiffetls/option.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net" + "github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig" "github.com/zeebo/errs" ) @@ -23,12 +24,14 @@ func (fn dialOption) apply(c *dialConfig) { type dialConfig struct { baseTLSConf *tls.Config dialer *net.Dialer + tlsoptions []tlsconfig.Option } type listenOption func(*listenConfig) type listenConfig struct { baseTLSConf *tls.Config + tlsoptions []tlsconfig.Option } func (fn listenOption) apply(c *listenConfig) { @@ -44,6 +47,13 @@ func WithDialTLSConfigBase(base *tls.Config) DialOption { }) } +// WithDialTLSOptions provides options to use for the TLS config. +func WithDialTLSOptions(opts ...tlsconfig.Option) DialOption { + return dialOption(func(c *dialConfig) { + c.tlsoptions = opts + }) +} + // WithDialer provides a net dialer to use. If unset, the standard net dialer // will be used. func WithDialer(dialer *net.Dialer) DialOption { @@ -65,3 +75,10 @@ func WithListenTLSConfigBase(base *tls.Config) ListenOption { c.baseTLSConf = base }) } + +// WithListenTLSOptions provides options to use when doing Server mTLS. +func WithListenTLSOptions(opts ...tlsconfig.Option) ListenOption { + return listenOption(func(c *listenConfig) { + c.tlsoptions = opts + }) +} diff --git a/v2/spiffetls/tlsconfig/config.go b/v2/spiffetls/tlsconfig/config.go index 17a2b1a3..f9194b39 100644 --- a/v2/spiffetls/tlsconfig/config.go +++ b/v2/spiffetls/tlsconfig/config.go @@ -10,9 +10,9 @@ import ( // TLSClientConfig returns a TLS configuration which verifies and authorizes // the server X509-SVID. -func TLSClientConfig(bundle x509bundle.Source, authorizer Authorizer) *tls.Config { +func TLSClientConfig(bundle x509bundle.Source, authorizer Authorizer, opts ...Option) *tls.Config { config := new(tls.Config) - HookTLSClientConfig(config, bundle, authorizer) + HookTLSClientConfig(config, bundle, authorizer, opts...) return config } @@ -20,17 +20,46 @@ func TLSClientConfig(bundle x509bundle.Source, authorizer Authorizer) *tls.Confi // the server X509-SVID. If there is an existing callback set for // VerifyPeerCertificate it will be wrapped by by this package and invoked // after SPIFFE authentication has completed. -func HookTLSClientConfig(config *tls.Config, bundle x509bundle.Source, authorizer Authorizer) { +func HookTLSClientConfig(config *tls.Config, bundle x509bundle.Source, authorizer Authorizer, opts ...Option) { resetAuthFields(config) config.InsecureSkipVerify = true - config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer) + config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer, opts...) +} + +// A Option changes the defaults used to by mTLS ClientConfig functions. +type Option interface { + apply(*options) +} + +type option func(*options) + +func (fn option) apply(o *options) { fn(o) } + +type options struct { + trace Trace +} + +func newOptions(opts []Option) *options { + out := &options{} + for _, opt := range opts { + opt.apply(out) + } + return out +} + +// WithTrace will use the provided tracing callbacks +// when various TLS config functions gets invoked. +func WithTrace(trace Trace) Option { + return option(func(opts *options) { + opts.trace = trace + }) } // MTLSClientConfig returns a TLS configuration which presents an X509-SVID // to the server and verifies and authorizes the server X509-SVID. -func MTLSClientConfig(svid x509svid.Source, bundle x509bundle.Source, authorizer Authorizer) *tls.Config { +func MTLSClientConfig(svid x509svid.Source, bundle x509bundle.Source, authorizer Authorizer, opts ...Option) *tls.Config { config := new(tls.Config) - HookMTLSClientConfig(config, svid, bundle, authorizer) + HookMTLSClientConfig(config, svid, bundle, authorizer, opts...) return config } @@ -38,51 +67,51 @@ func MTLSClientConfig(svid x509svid.Source, bundle x509bundle.Source, authorizer // to the server and verify and authorize the server X509-SVID. If there is an // existing callback set for VerifyPeerCertificate it will be wrapped by by // this package and invoked after SPIFFE authentication has completed. -func HookMTLSClientConfig(config *tls.Config, svid x509svid.Source, bundle x509bundle.Source, authorizer Authorizer) { +func HookMTLSClientConfig(config *tls.Config, svid x509svid.Source, bundle x509bundle.Source, authorizer Authorizer, opts ...Option) { resetAuthFields(config) - config.GetClientCertificate = GetClientCertificate(svid) + config.GetClientCertificate = GetClientCertificate(svid, opts...) config.InsecureSkipVerify = true - config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer) + config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer, opts...) } // MTLSWebClientConfig returns a TLS configuration which presents an X509-SVID // to the server and verifies the server certificate using provided roots (or // the system roots if nil). -func MTLSWebClientConfig(svid x509svid.Source, roots *x509.CertPool) *tls.Config { +func MTLSWebClientConfig(svid x509svid.Source, roots *x509.CertPool, opts ...Option) *tls.Config { config := new(tls.Config) - HookMTLSWebClientConfig(config, svid, roots) + HookMTLSWebClientConfig(config, svid, roots, opts...) return config } // HookMTLSWebClientConfig sets up the TLS configuration to present an // X509-SVID to the server and verifies the server certificate using the // provided roots (or the system roots if nil). -func HookMTLSWebClientConfig(config *tls.Config, svid x509svid.Source, roots *x509.CertPool) { +func HookMTLSWebClientConfig(config *tls.Config, svid x509svid.Source, roots *x509.CertPool, opts ...Option) { resetAuthFields(config) - config.GetClientCertificate = GetClientCertificate(svid) + config.GetClientCertificate = GetClientCertificate(svid, opts...) config.RootCAs = roots } // TLSServerConfig returns a TLS configuration which presents an X509-SVID // to the client and does not require or verify client certificates. -func TLSServerConfig(svid x509svid.Source) *tls.Config { +func TLSServerConfig(svid x509svid.Source, opts ...Option) *tls.Config { config := new(tls.Config) - HookTLSServerConfig(config, svid) + HookTLSServerConfig(config, svid, opts...) return config } // HookTLSServerConfig sets up the TLS configuration to present an X509-SVID // to the client and to not require or verify client certificates. -func HookTLSServerConfig(config *tls.Config, svid x509svid.Source) { +func HookTLSServerConfig(config *tls.Config, svid x509svid.Source, opts ...Option) { resetAuthFields(config) - config.GetCertificate = GetCertificate(svid) + config.GetCertificate = GetCertificate(svid, opts...) } // MTLSServerConfig returns a TLS configuration which presents an X509-SVID // to the client and requires, verifies, and authorizes client X509-SVIDs. -func MTLSServerConfig(svid x509svid.Source, bundle x509bundle.Source, authorizer Authorizer) *tls.Config { +func MTLSServerConfig(svid x509svid.Source, bundle x509bundle.Source, authorizer Authorizer, opts ...Option) *tls.Config { config := new(tls.Config) - HookMTLSServerConfig(config, svid, bundle, authorizer) + HookMTLSServerConfig(config, svid, bundle, authorizer, opts...) return config } @@ -91,19 +120,19 @@ func MTLSServerConfig(svid x509svid.Source, bundle x509bundle.Source, authorizer // there is an existing callback set for VerifyPeerCertificate it will be // wrapped by by this package and invoked after SPIFFE authentication has // completed. -func HookMTLSServerConfig(config *tls.Config, svid x509svid.Source, bundle x509bundle.Source, authorizer Authorizer) { +func HookMTLSServerConfig(config *tls.Config, svid x509svid.Source, bundle x509bundle.Source, authorizer Authorizer, opts ...Option) { resetAuthFields(config) config.ClientAuth = tls.RequireAnyClientCert - config.GetCertificate = GetCertificate(svid) - config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer) + config.GetCertificate = GetCertificate(svid, opts...) + config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer, opts...) } // MTLSWebServerConfig returns a TLS configuration which presents a web // server certificate to the client and requires, verifies, and authorizes // client X509-SVIDs. -func MTLSWebServerConfig(cert *tls.Certificate, bundle x509bundle.Source, authorizer Authorizer) *tls.Config { +func MTLSWebServerConfig(cert *tls.Certificate, bundle x509bundle.Source, authorizer Authorizer, opts ...Option) *tls.Config { config := new(tls.Config) - HookMTLSWebServerConfig(config, cert, bundle, authorizer) + HookMTLSWebServerConfig(config, cert, bundle, authorizer, opts...) return config } @@ -112,34 +141,36 @@ func MTLSWebServerConfig(cert *tls.Certificate, bundle x509bundle.Source, author // X509-SVIDs. If there is an existing callback set for VerifyPeerCertificate // it will be wrapped by by this package and invoked after SPIFFE // authentication has completed. -func HookMTLSWebServerConfig(config *tls.Config, cert *tls.Certificate, bundle x509bundle.Source, authorizer Authorizer) { +func HookMTLSWebServerConfig(config *tls.Config, cert *tls.Certificate, bundle x509bundle.Source, authorizer Authorizer, opts ...Option) { resetAuthFields(config) config.ClientAuth = tls.RequireAnyClientCert config.Certificates = []tls.Certificate{*cert} - config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer) + config.VerifyPeerCertificate = WrapVerifyPeerCertificate(config.VerifyPeerCertificate, bundle, authorizer, opts...) } // GetCertificate returns a GetCertificate callback for tls.Config. It uses the // given X509-SVID getter to obtain a server X509-SVID for the TLS handshake. -func GetCertificate(svid x509svid.Source) func(*tls.ClientHelloInfo) (*tls.Certificate, error) { +func GetCertificate(svid x509svid.Source, opts ...Option) func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + opt := newOptions(opts) return func(*tls.ClientHelloInfo) (*tls.Certificate, error) { - return getTLSCertificate(svid) + return getTLSCertificate(svid, opt.trace) } } // GetClientCertificate returns a GetClientCertificate callback for tls.Config. // It uses the given X509-SVID getter to obtain a client X509-SVID for the TLS // handshake. -func GetClientCertificate(svid x509svid.Source) func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { +func GetClientCertificate(svid x509svid.Source, opts ...Option) func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + opt := newOptions(opts) return func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { - return getTLSCertificate(svid) + return getTLSCertificate(svid, opt.trace) } } // VerifyPeerCertificate returns a VerifyPeerCertificate callback for // tls.Config. It uses the given bundle source and authorizer to verify and // authorize X509-SVIDs provided by peers during the TLS handshake. -func VerifyPeerCertificate(bundle x509bundle.Source, authorizer Authorizer) func([][]byte, [][]*x509.Certificate) error { +func VerifyPeerCertificate(bundle x509bundle.Source, authorizer Authorizer, opts ...Option) func([][]byte, [][]*x509.Certificate) error { return func(raw [][]byte, _ [][]*x509.Certificate) error { id, certs, err := x509svid.ParseAndVerify(raw, bundle) if err != nil { @@ -154,9 +185,9 @@ func VerifyPeerCertificate(bundle x509bundle.Source, authorizer Authorizer) func // SPIFFE authentication against the peer certificates using the given bundle and // authorizer. The wrapped callback will be passed the verified chains. // Note: TLS clients must set `InsecureSkipVerify` when doing SPIFFE authentication to disable hostname verification. -func WrapVerifyPeerCertificate(wrapped func([][]byte, [][]*x509.Certificate) error, bundle x509bundle.Source, authorizer Authorizer) func([][]byte, [][]*x509.Certificate) error { +func WrapVerifyPeerCertificate(wrapped func([][]byte, [][]*x509.Certificate) error, bundle x509bundle.Source, authorizer Authorizer, opts ...Option) func([][]byte, [][]*x509.Certificate) error { if wrapped == nil { - return VerifyPeerCertificate(bundle, authorizer) + return VerifyPeerCertificate(bundle, authorizer, opts...) } return func(raw [][]byte, _ [][]*x509.Certificate) error { @@ -173,10 +204,18 @@ func WrapVerifyPeerCertificate(wrapped func([][]byte, [][]*x509.Certificate) err } } -func getTLSCertificate(svid x509svid.Source) (*tls.Certificate, error) { +func getTLSCertificate(svid x509svid.Source, trace Trace) (*tls.Certificate, error) { + var traceVal interface{} + if trace.GetCertificate != nil { + traceVal = trace.GetCertificate() + } + s, err := svid.GetX509SVID() if err != nil { - return nil, err + if trace.GotCertificate != nil { + trace.GotCertificate(traceVal, GotCertificateInfo{Err: err}) + return nil, err + } } cert := &tls.Certificate{ @@ -188,6 +227,10 @@ func getTLSCertificate(svid x509svid.Source) (*tls.Certificate, error) { cert.Certificate = append(cert.Certificate, svidCert.Raw) } + if trace.GotCertificate != nil { + trace.GotCertificate(traceVal, GotCertificateInfo{Cert: cert}) + } + return cert, nil } diff --git a/v2/spiffetls/tlsconfig/config_test.go b/v2/spiffetls/tlsconfig/config_test.go index 98ef6b3e..978c3ae2 100644 --- a/v2/spiffetls/tlsconfig/config_test.go +++ b/v2/spiffetls/tlsconfig/config_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "errors" + "fmt" "strings" "testing" "time" @@ -19,6 +20,16 @@ import ( "github.com/stretchr/testify/require" ) +var localTrace = tlsconfig.Trace{ + GetCertificate: func() interface{} { + fmt.Printf("got start of GetTLSCertificate\n") + return nil + }, + GotCertificate: func(interface{}, tlsconfig.GotCertificateInfo) { + fmt.Printf("got end of GetTLSCertificate\n") + }, +} + func TestTLSClientConfig(t *testing.T) { trustDomain := spiffeid.RequireTrustDomainFromString("test.domain") bundle := x509bundle.New(trustDomain) @@ -59,7 +70,9 @@ func TestMTLSClientConfig(t *testing.T) { bundle := x509bundle.New(trustDomain) svid := &x509svid.SVID{} - config := tlsconfig.MTLSClientConfig(svid, bundle, tlsconfig.AuthorizeAny()) + config := tlsconfig.MTLSClientConfig(svid, bundle, tlsconfig.AuthorizeAny(), + tlsconfig.WithTrace(localTrace), + ) assert.Nil(t, config.Certificates) assert.Equal(t, tls.NoClientCert, config.ClientAuth) @@ -78,7 +91,9 @@ func TestHookMTLSClientConfig(t *testing.T) { base := createBaseTLSConfig() config := createTestTLSConfig(base) - tlsconfig.HookMTLSClientConfig(config, svid, bundle, tlsconfig.AuthorizeAny()) + tlsconfig.HookMTLSClientConfig(config, svid, bundle, tlsconfig.AuthorizeAny(), + tlsconfig.WithTrace(localTrace), + ) assert.Nil(t, config.Certificates) assert.Equal(t, tls.NoClientCert, config.ClientAuth) @@ -95,7 +110,9 @@ func TestMTLSWebClientConfig(t *testing.T) { svid := &x509svid.SVID{} roots := x509.NewCertPool() - config := tlsconfig.MTLSWebClientConfig(svid, roots) + config := tlsconfig.MTLSWebClientConfig(svid, roots, + tlsconfig.WithTrace(localTrace), + ) assert.Nil(t, config.Certificates) assert.Equal(t, tls.NoClientCert, config.ClientAuth) @@ -113,7 +130,9 @@ func TestHookMTLSWebClientConfig(t *testing.T) { config := createTestTLSConfig(base) roots := x509.NewCertPool() - tlsconfig.HookMTLSWebClientConfig(config, svid, roots) + tlsconfig.HookMTLSWebClientConfig(config, svid, roots, + tlsconfig.WithTrace(localTrace), + ) // Expected AuthFields assert.Nil(t, config.Certificates) @@ -130,7 +149,9 @@ func TestHookMTLSWebClientConfig(t *testing.T) { func TestTLSServerConfig(t *testing.T) { svid := &x509svid.SVID{} - config := tlsconfig.TLSServerConfig(svid) + config := tlsconfig.TLSServerConfig(svid, + tlsconfig.WithTrace(localTrace), + ) assert.Nil(t, config.Certificates) assert.Equal(t, tls.NoClientCert, config.ClientAuth) @@ -147,7 +168,9 @@ func TestHookTLSServerConfig(t *testing.T) { base := createBaseTLSConfig() config := createTestTLSConfig(base) - tlsconfig.HookTLSServerConfig(config, svid) + tlsconfig.HookTLSServerConfig(config, svid, + tlsconfig.WithTrace(localTrace), + ) assert.Nil(t, config.Certificates) assert.Equal(t, tls.NoClientCert, config.ClientAuth) @@ -165,7 +188,9 @@ func TestMTLSServerConfig(t *testing.T) { bundle := x509bundle.New(trustDomain) svid := &x509svid.SVID{} - config := tlsconfig.MTLSServerConfig(svid, bundle, tlsconfig.AuthorizeAny()) + config := tlsconfig.MTLSServerConfig(svid, bundle, tlsconfig.AuthorizeAny(), + tlsconfig.WithTrace(localTrace), + ) assert.Nil(t, config.Certificates) assert.Equal(t, tls.RequireAnyClientCert, config.ClientAuth) @@ -184,7 +209,9 @@ func TestHookMTLSServerConfig(t *testing.T) { base := createBaseTLSConfig() config := createTestTLSConfig(base) - tlsconfig.HookMTLSServerConfig(config, svid, bundle, tlsconfig.AuthorizeAny()) + tlsconfig.HookMTLSServerConfig(config, svid, bundle, tlsconfig.AuthorizeAny(), + tlsconfig.WithTrace(localTrace), + ) assert.Nil(t, config.Certificates) assert.Equal(t, tls.RequireAnyClientCert, config.ClientAuth) @@ -234,6 +261,22 @@ func TestHookMTLSWebServerConfig(t *testing.T) { assertUnrelatedFieldsUntouched(t, base, config) } +func hookedTracer(onGetCertificate, onGotCertificate func()) tlsconfig.Trace { + return tlsconfig.Trace{ + GetCertificate: func() interface{} { + if onGetCertificate != nil { + onGetCertificate() + } + return nil + }, + GotCertificate: func(interface{}, tlsconfig.GotCertificateInfo) { + if onGotCertificate != nil { + onGotCertificate() + } + }, + } +} + func TestGetCertificate(t *testing.T) { testCases := []struct { name string @@ -266,7 +309,12 @@ func TestGetCertificate(t *testing.T) { for _, testCase := range testCases { testCase := testCase t.Run(testCase.name, func(t *testing.T) { - getCertificate := tlsconfig.GetCertificate(testCase.source) + getCertificateCalls := 0 + tracer := hookedTracer( + func() { getCertificateCalls++ }, + nil, + ) + getCertificate := tlsconfig.GetCertificate(testCase.source, tlsconfig.WithTrace(tracer)) require.NotNil(t, getCertificate) tlsCert, err := getCertificate(&tls.ClientHelloInfo{}) @@ -278,6 +326,7 @@ func TestGetCertificate(t *testing.T) { require.NoError(t, err) require.Equal(t, testCase.expectedCerts, tlsCert.Certificate) + require.Equal(t, 1, getCertificateCalls) }) } } @@ -314,7 +363,12 @@ func TestGetClientCertificate(t *testing.T) { for _, testCase := range testCases { testCase := testCase t.Run(testCase.name, func(t *testing.T) { - getClientCertificate := tlsconfig.GetClientCertificate(testCase.source) + getCertificateCalls := 0 + tracer := hookedTracer( + func() { getCertificateCalls++ }, + nil, + ) + getClientCertificate := tlsconfig.GetClientCertificate(testCase.source, tlsconfig.WithTrace(tracer)) require.NotNil(t, getClientCertificate) tlsCert, err := getClientCertificate(&tls.CertificateRequestInfo{}) @@ -326,6 +380,7 @@ func TestGetClientCertificate(t *testing.T) { require.NoError(t, err) require.Equal(t, testCase.expectedCerts, tlsCert.Certificate) + require.Equal(t, 1, getCertificateCalls) }) } } diff --git a/v2/spiffetls/tlsconfig/trace.go b/v2/spiffetls/tlsconfig/trace.go new file mode 100644 index 00000000..0fdd1269 --- /dev/null +++ b/v2/spiffetls/tlsconfig/trace.go @@ -0,0 +1,18 @@ +package tlsconfig + +import ( + "crypto/tls" +) + +// GotCertificateInfo provides err and TLS certificate info to Trace +type GotCertificateInfo struct { + Cert *tls.Certificate + Err error +} + +// Trace is the interface to define what functions are triggered when functions +// in tlsconfig are called +type Trace struct { + GetCertificate func() interface{} + GotCertificate func(interface{}, GotCertificateInfo) +}