diff --git a/provider/provider_aws_adfs.go b/provider/provider_aws_adfs.go index d88c05e..6dd66d1 100644 --- a/provider/provider_aws_adfs.go +++ b/provider/provider_aws_adfs.go @@ -30,6 +30,15 @@ type AWSRole struct { Principal string `json:"principal"` } +type AWSCreds struct { + AccessKey string `json:"accesskey"` + SecretKey string `json:"secretkey"` + SessionToken string `json:"sessiontoken"` + SecurityToken string `json:"securitytoken"` + Principal string `json:"principal"` + Expires time.Time `json:"expires"` +} + func NewProviderAwsAdfs(loginURL string, clusterID string) (*ProviderAwsAdfs, error) { // Check if login URL is valid @@ -129,26 +138,26 @@ func (p *ProviderAwsAdfs) Login(user, password string) (string, map[string]strin return samlAssertion, roles, nil } -// Token returns a Kubernetes token based on SAML assertion and role -func (p *ProviderAwsAdfs) Token(SAMLAssertion, role string) (string, error) { +// AssumeRole will take a SAML assertion and role and return AWS credentials +func (p *ProviderAwsAdfs) AssumeRole(SAMLAssertion, role string) (AWSCreds, error) { // Decode base64 role decodedRole, err := base64.StdEncoding.DecodeString(role) if err != nil { - return "", fmt.Errorf("error decoding role: %w", err) + return AWSCreds{}, fmt.Errorf("error decoding role: %w", err) } // Unmarshal the role var awsRole AWSRole err = json.Unmarshal(decodedRole, &awsRole) if err != nil { - return "", fmt.Errorf("error unmarshalling role: %w", err) + return AWSCreds{}, fmt.Errorf("error unmarshalling role: %w", err) } // Create a new STS session sess, err := session.NewSession(&aws.Config{}) if err != nil { - return "", fmt.Errorf("error creating aws session: %w", err) + return AWSCreds{}, fmt.Errorf("error creating aws session: %w", err) } svc := sts.New(sess) @@ -161,25 +170,42 @@ func (p *ProviderAwsAdfs) Token(SAMLAssertion, role string) (string, error) { } resp, err := svc.AssumeRoleWithSAML(assumeInput) if err != nil { - return "", fmt.Errorf("assuming role failed: %w", err) + return AWSCreds{}, fmt.Errorf("assuming role failed: %w", err) } + + creds := AWSCreds{ + AccessKey: aws.StringValue(resp.Credentials.AccessKeyId), + SecretKey: aws.StringValue(resp.Credentials.SecretAccessKey), + SessionToken: aws.StringValue(resp.Credentials.SessionToken), + SecurityToken: aws.StringValue(resp.Credentials.SessionToken), + Principal: aws.StringValue(resp.AssumedRoleUser.Arn), + Expires: resp.Credentials.Expiration.Local(), + } + + return creds, nil +} + +// Token returns a Kubernetes token based on temporary AWS credentials +func (p *ProviderAwsAdfs) Token(tmpCreds AWSCreds) (string, error) { + + // Set temporary credentials creds := &awsconfig.AWSCredentials{ - AWSAccessKey: aws.StringValue(resp.Credentials.AccessKeyId), - AWSSecretKey: aws.StringValue(resp.Credentials.SecretAccessKey), - AWSSessionToken: aws.StringValue(resp.Credentials.SessionToken), - AWSSecurityToken: aws.StringValue(resp.Credentials.SessionToken), - PrincipalARN: aws.StringValue(resp.AssumedRoleUser.Arn), - Expires: resp.Credentials.Expiration.Local(), + AWSAccessKey: tmpCreds.AccessKey, + AWSSecretKey: tmpCreds.SecretKey, + AWSSessionToken: tmpCreds.SessionToken, + AWSSecurityToken: tmpCreds.SecurityToken, + PrincipalARN: tmpCreds.Principal, + Expires: tmpCreds.Expires, } // Create a new AWS STS session with the temporary credentials - sess, err = session.NewSession(&aws.Config{ + sess, err := session.NewSession(&aws.Config{ Credentials: credentials.NewStaticCredentials(creds.AWSAccessKey, creds.AWSSecretKey, creds.AWSSessionToken), }) if err != nil { return "", fmt.Errorf("error creating session: %w", err) } - svc = sts.New(sess) + svc := sts.New(sess) // Generate Kubernetes token with the STS session // it generates a token with a presigned URL to the STS session diff --git a/proxy/loginHandler.go b/proxy/loginHandler.go index 0dcb7e3..c229e2a 100644 --- a/proxy/loginHandler.go +++ b/proxy/loginHandler.go @@ -2,6 +2,8 @@ package proxy import ( _ "embed" //embed web resources for login page + "encoding/base64" + "encoding/json" "fmt" "html/template" "log" @@ -112,13 +114,31 @@ func loginPostHandler(authProvider provider.Provider) func(w http.ResponseWriter } // Assume role with SAML assertion - token, err := adfsProvider.Token(assertion, role) + creds, err := adfsProvider.AssumeRole(assertion, role) if err != nil { log.Printf("failed to assume role: %s", err) http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to assume role")), http.StatusFound) return } + // Marshal credentials to base64 encoded JSON and store them in proxy_aws_creds cookie + jsonCreds, err := json.Marshal(creds) + if err != nil { + log.Printf("failed to marshal credentials: %s", err) + http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to marshal credentials")), http.StatusFound) + return + } + b64JsonCreds := base64.StdEncoding.EncodeToString(jsonCreds) + http.SetCookie(w, &http.Cookie{Name: "proxy_aws_creds", Value: b64JsonCreds}) + + // Get token from credentials + token, err := adfsProvider.Token(creds) + if err != nil { + log.Printf("failed to get token: %s", err) + http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to get token")), http.StatusFound) + return + } + // As stated by RFC, cookie size limit must be at least 4096 bytes // so we split the token below this size to be compatible with all // browsers https://stackoverflow.com/a/52492934 diff --git a/proxy/proxyHandler.go b/proxy/proxyHandler.go index a24bb4e..cabb10f 100644 --- a/proxy/proxyHandler.go +++ b/proxy/proxyHandler.go @@ -1,6 +1,8 @@ package proxy import ( + "encoding/base64" + "encoding/json" "fmt" "log" "net/http" @@ -21,16 +23,74 @@ func proxyHandler(target string, authProvider provider.Provider) func(w http.Res return } - // Check if token is valid - err = authProvider.Valid(token) - if err != nil { + // Switch on auth provider + switch authProvider := authProvider.(type) { + case *provider.ProviderAwsAdfs: + + // Check if token is valid + err = authProvider.Valid(token) + if err == nil { + break + } + + // Get cookier proxy_aws_creds + b64Creds, err := r.Cookie("proxy_aws_creds") + if err != nil || b64Creds.Value == "" { + log.Printf("failed to get cookie proxy_aws_creds: %s", err) + http.Redirect(w, r, "/login", http.StatusFound) + return + } + + // Extract creds from cookie + decodedCreds, err := base64.StdEncoding.DecodeString(b64Creds.Value) + if err != nil { + log.Printf("failed to decode cookie proxy_aws_creds: %s", err) + http.Redirect(w, r, "/login", http.StatusFound) + return + } + var creds provider.AWSCreds + err = json.Unmarshal(decodedCreds, &creds) + if err != nil { + log.Printf("failed to unmarshal cookie proxy_aws_creds: %s", err) + http.Redirect(w, r, "/login", http.StatusFound) + return + } + + // Try to refresh token with existing creds + newToken, err := authProvider.Token(creds) + if err != nil { + log.Printf("failed to refresh token: %s", err) + http.Redirect(w, r, "/login", http.StatusFound) + return + } + token = newToken + + // As stated by RFC, cookie size limit must be at least 4096 bytes + // so we split the token below this size to be compatible with all + // browsers https://stackoverflow.com/a/52492934 + setTokenCookie(w, token, 4000) + + case *provider.ProviderTanzu: + + // Check if token is valid + err = authProvider.Valid(token) + if err == nil { + break + } + log.Printf("failed to check if token is valid: %s", err) http.Redirect(w, r, "/login", http.StatusFound) return + } // create the reverse proxy - url, _ := url.Parse(target) + url, err := url.Parse(target) + if err != nil { + log.Printf("failed to parse target URL: %s", err) + http.Redirect(w, r, "/login", http.StatusFound) + return + } proxy := httputil.NewSingleHostReverseProxy(url) // add token as authorization header