Skip to content

Commit

Permalink
assert that GetCertificate gets called as expected
Browse files Browse the repository at this point in the history
Signed-off-by: Antoine Grondin <[email protected]>
  • Loading branch information
aybabtme committed Sep 14, 2020
1 parent cefd677 commit c6b9a8f
Showing 1 changed file with 39 additions and 11 deletions.
50 changes: 39 additions & 11 deletions v2/spiffetls/tlsconfig/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/require"
)

var LocalTrace = tlsconfig.Trace{
var localTrace = tlsconfig.Trace{
GetCertificate: func() interface{} {
fmt.Printf("got start of GetTLSCertificate\n")
return nil
Expand Down Expand Up @@ -71,7 +71,7 @@ func TestMTLSClientConfig(t *testing.T) {
svid := &x509svid.SVID{}

config := tlsconfig.MTLSClientConfig(svid, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithTrace(LocalTrace),
tlsconfig.WithTrace(localTrace),
)

assert.Nil(t, config.Certificates)
Expand All @@ -92,7 +92,7 @@ func TestHookMTLSClientConfig(t *testing.T) {
config := createTestTLSConfig(base)

tlsconfig.HookMTLSClientConfig(config, svid, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithTrace(LocalTrace),
tlsconfig.WithTrace(localTrace),
)

assert.Nil(t, config.Certificates)
Expand All @@ -111,7 +111,7 @@ func TestMTLSWebClientConfig(t *testing.T) {
roots := x509.NewCertPool()

config := tlsconfig.MTLSWebClientConfig(svid, roots,
tlsconfig.WithTrace(LocalTrace),
tlsconfig.WithTrace(localTrace),
)

assert.Nil(t, config.Certificates)
Expand All @@ -131,7 +131,7 @@ func TestHookMTLSWebClientConfig(t *testing.T) {
roots := x509.NewCertPool()

tlsconfig.HookMTLSWebClientConfig(config, svid, roots,
tlsconfig.WithTrace(LocalTrace),
tlsconfig.WithTrace(localTrace),
)

// Expected AuthFields
Expand All @@ -150,7 +150,7 @@ func TestTLSServerConfig(t *testing.T) {
svid := &x509svid.SVID{}

config := tlsconfig.TLSServerConfig(svid,
tlsconfig.WithTrace(LocalTrace),
tlsconfig.WithTrace(localTrace),
)

assert.Nil(t, config.Certificates)
Expand All @@ -169,7 +169,7 @@ func TestHookTLSServerConfig(t *testing.T) {
config := createTestTLSConfig(base)

tlsconfig.HookTLSServerConfig(config, svid,
tlsconfig.WithTrace(LocalTrace),
tlsconfig.WithTrace(localTrace),
)

assert.Nil(t, config.Certificates)
Expand All @@ -189,7 +189,7 @@ func TestMTLSServerConfig(t *testing.T) {
svid := &x509svid.SVID{}

config := tlsconfig.MTLSServerConfig(svid, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithTrace(LocalTrace),
tlsconfig.WithTrace(localTrace),
)

assert.Nil(t, config.Certificates)
Expand All @@ -210,7 +210,7 @@ func TestHookMTLSServerConfig(t *testing.T) {
config := createTestTLSConfig(base)

tlsconfig.HookMTLSServerConfig(config, svid, bundle, tlsconfig.AuthorizeAny(),
tlsconfig.WithTrace(LocalTrace),
tlsconfig.WithTrace(localTrace),
)

assert.Nil(t, config.Certificates)
Expand Down Expand Up @@ -261,6 +261,22 @@ func TestHookMTLSWebServerConfig(t *testing.T) {
assertUnrelatedFieldsUntouched(t, base, config)
}

func hookedTracer(onGetCertificate, onGotCertificate func()) tlsconfig.Trace {
return tlsconfig.Trace{
GetCertificate: func() interface{} {
if onGetCertificate != nil {
onGetCertificate()
}
return nil
},
GotCertificate: func(interface{}, tlsconfig.GotCertificateInfo) {
if onGotCertificate != nil {
onGotCertificate()
}
},
}
}

func TestGetCertificate(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -293,7 +309,12 @@ func TestGetCertificate(t *testing.T) {
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
getCertificate := tlsconfig.GetCertificate(testCase.source, tlsconfig.WithTrace(LocalTrace))
getCertificateCalls := 0
tracer := hookedTracer(
func() { getCertificateCalls++ },
nil,
)
getCertificate := tlsconfig.GetCertificate(testCase.source, tlsconfig.WithTrace(tracer))
require.NotNil(t, getCertificate)

tlsCert, err := getCertificate(&tls.ClientHelloInfo{})
Expand All @@ -305,6 +326,7 @@ func TestGetCertificate(t *testing.T) {

require.NoError(t, err)
require.Equal(t, testCase.expectedCerts, tlsCert.Certificate)
require.Equal(t, 1, getCertificateCalls)
})
}
}
Expand Down Expand Up @@ -341,7 +363,12 @@ func TestGetClientCertificate(t *testing.T) {
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
getClientCertificate := tlsconfig.GetClientCertificate(testCase.source, tlsconfig.WithTrace(LocalTrace))
getCertificateCalls := 0
tracer := hookedTracer(
func() { getCertificateCalls++ },
nil,
)
getClientCertificate := tlsconfig.GetClientCertificate(testCase.source, tlsconfig.WithTrace(tracer))
require.NotNil(t, getClientCertificate)

tlsCert, err := getClientCertificate(&tls.CertificateRequestInfo{})
Expand All @@ -353,6 +380,7 @@ func TestGetClientCertificate(t *testing.T) {

require.NoError(t, err)
require.Equal(t, testCase.expectedCerts, tlsCert.Certificate)
require.Equal(t, 1, getCertificateCalls)
})
}
}
Expand Down

0 comments on commit c6b9a8f

Please sign in to comment.