From e5e6d4b4ecb9c6dce7e027d1f9e070c05834370b Mon Sep 17 00:00:00 2001 From: gracewilcox Date: Thu, 3 Oct 2024 16:07:07 -0700 Subject: [PATCH] cae-tests --- .../internal/challenge_policy_test.go | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/sdk/security/keyvault/internal/challenge_policy_test.go b/sdk/security/keyvault/internal/challenge_policy_test.go index 0529b0a73c8e..e8a685a79942 100644 --- a/sdk/security/keyvault/internal/challenge_policy_test.go +++ b/sdk/security/keyvault/internal/challenge_policy_test.go @@ -100,6 +100,150 @@ func TestChallengePolicy(t *testing.T) { } } +var ( + accessTk = "***" + kvChallenge = `Bearer authorization="https://login.microsoftonline.com/tenant", resource="https://vault.azure.net"` + caeChallenge1 = `Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="dGVzdGluZzE="` + caeChallenge2 = `Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="dGVzdGluZzI="` +) + +func TestChallengePolicy_CAE(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", kvChallenge), + mock.WithStatusCode(401), + ) + srv.AppendResponse() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", caeChallenge1), + mock.WithStatusCode(401), + ) + srv.AppendResponse() + + tkReqs := 0 + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + require.True(t, tro.EnableCAE) + tkReqs += 1 + switch tkReqs { + case 1: + require.Empty(t, tro.Claims) + case 2: + // second call should include challenge claims + require.Equal(t, "testing1", tro.Claims) + default: + t.Fatal("unexpected token request") + } + return azcore.AccessToken{Token: accessTk, ExpiresOn: time.Now().Add(time.Hour)}, nil + }) + p := NewKeyVaultChallengePolicy(cred, nil) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + + // req 1 kv then regular + req, err := runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net") + require.NoError(t, err) + res, err := pl.Do(req) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + require.Equal(t, tkReqs, 1) + + // req 2 cae + req, err = runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net") + require.NoError(t, err) + res, err = pl.Do(req) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + require.Equal(t, tkReqs, 2) +} + +func TestChallengePolicy_KVThenCAE(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", kvChallenge), + mock.WithStatusCode(401), + ) + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", caeChallenge1), + mock.WithStatusCode(401), + ) + srv.AppendResponse() + + tkReqs := 0 + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + require.True(t, tro.EnableCAE) + tkReqs += 1 + switch tkReqs { + case 1: + require.Empty(t, tro.Claims) + case 2: + // second call should include challenge claims + require.Equal(t, "testing1", tro.Claims) + default: + t.Fatal("unexpected token request") + } + return azcore.AccessToken{Token: accessTk, ExpiresOn: time.Now().Add(time.Hour)}, nil + }) + p := NewKeyVaultChallengePolicy(cred, nil) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + req, err := runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net") + require.NoError(t, err) + res, err := pl.Do(req) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + require.Equal(t, tkReqs, 2) +} + +func TestChallengePolicy_TwoCAEChallenges(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", kvChallenge), + mock.WithStatusCode(401), + ) + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", caeChallenge1), + mock.WithStatusCode(401), + ) + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", caeChallenge2), + mock.WithStatusCode(401), + ) + tkReqs := 0 + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + require.True(t, tro.EnableCAE) + tkReqs += 1 + switch tkReqs { + case 1: + require.Empty(t, tro.Claims) + case 2: + // second call should include challenge claims + require.Equal(t, "testing1", tro.Claims) + default: + t.Fatal("unexpected token request") + } + return azcore.AccessToken{Token: accessTk, ExpiresOn: time.Now().Add(time.Hour)}, nil + }) + p := NewKeyVaultChallengePolicy(cred, nil) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + req, err := runtime.NewRequest(context.Background(), "GET", "https://42.vault.azure.net") + require.NoError(t, err) + res, err := pl.Do(req) + require.NoError(t, err) + require.Equal(t, 401, res.StatusCode) + require.Equal(t, caeChallenge2, res.Header.Get("WWW-Authenticate")) + require.Equal(t, tkReqs, 2) +} + func TestParseTenant(t *testing.T) { actual := parseTenant("") require.Empty(t, actual)