Skip to content

Commit

Permalink
Add AWS SAML token auto refresh (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
pablo-ruth authored Apr 27, 2023
1 parent 4d62a3f commit bd9faae
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 19 deletions.
54 changes: 40 additions & 14 deletions provider/provider_aws_adfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ type AWSRole struct {
Principal string `json:"principal"`
}

type AWSCreds struct {
AccessKey string `json:"accesskey"`
SecretKey string `json:"secretkey"`
SessionToken string `json:"sessiontoken"`
SecurityToken string `json:"securitytoken"`
Principal string `json:"principal"`
Expires time.Time `json:"expires"`
}

func NewProviderAwsAdfs(loginURL string, clusterID string) (*ProviderAwsAdfs, error) {

// Check if login URL is valid
Expand Down Expand Up @@ -129,26 +138,26 @@ func (p *ProviderAwsAdfs) Login(user, password string) (string, map[string]strin
return samlAssertion, roles, nil
}

// Token returns a Kubernetes token based on SAML assertion and role
func (p *ProviderAwsAdfs) Token(SAMLAssertion, role string) (string, error) {
// AssumeRole will take a SAML assertion and role and return AWS credentials
func (p *ProviderAwsAdfs) AssumeRole(SAMLAssertion, role string) (AWSCreds, error) {

// Decode base64 role
decodedRole, err := base64.StdEncoding.DecodeString(role)
if err != nil {
return "", fmt.Errorf("error decoding role: %w", err)
return AWSCreds{}, fmt.Errorf("error decoding role: %w", err)
}

// Unmarshal the role
var awsRole AWSRole
err = json.Unmarshal(decodedRole, &awsRole)
if err != nil {
return "", fmt.Errorf("error unmarshalling role: %w", err)
return AWSCreds{}, fmt.Errorf("error unmarshalling role: %w", err)
}

// Create a new STS session
sess, err := session.NewSession(&aws.Config{})
if err != nil {
return "", fmt.Errorf("error creating aws session: %w", err)
return AWSCreds{}, fmt.Errorf("error creating aws session: %w", err)
}
svc := sts.New(sess)

Expand All @@ -161,25 +170,42 @@ func (p *ProviderAwsAdfs) Token(SAMLAssertion, role string) (string, error) {
}
resp, err := svc.AssumeRoleWithSAML(assumeInput)
if err != nil {
return "", fmt.Errorf("assuming role failed: %w", err)
return AWSCreds{}, fmt.Errorf("assuming role failed: %w", err)
}

creds := AWSCreds{
AccessKey: aws.StringValue(resp.Credentials.AccessKeyId),
SecretKey: aws.StringValue(resp.Credentials.SecretAccessKey),
SessionToken: aws.StringValue(resp.Credentials.SessionToken),
SecurityToken: aws.StringValue(resp.Credentials.SessionToken),
Principal: aws.StringValue(resp.AssumedRoleUser.Arn),
Expires: resp.Credentials.Expiration.Local(),
}

return creds, nil
}

// Token returns a Kubernetes token based on temporary AWS credentials
func (p *ProviderAwsAdfs) Token(tmpCreds AWSCreds) (string, error) {

// Set temporary credentials
creds := &awsconfig.AWSCredentials{
AWSAccessKey: aws.StringValue(resp.Credentials.AccessKeyId),
AWSSecretKey: aws.StringValue(resp.Credentials.SecretAccessKey),
AWSSessionToken: aws.StringValue(resp.Credentials.SessionToken),
AWSSecurityToken: aws.StringValue(resp.Credentials.SessionToken),
PrincipalARN: aws.StringValue(resp.AssumedRoleUser.Arn),
Expires: resp.Credentials.Expiration.Local(),
AWSAccessKey: tmpCreds.AccessKey,
AWSSecretKey: tmpCreds.SecretKey,
AWSSessionToken: tmpCreds.SessionToken,
AWSSecurityToken: tmpCreds.SecurityToken,
PrincipalARN: tmpCreds.Principal,
Expires: tmpCreds.Expires,
}

// Create a new AWS STS session with the temporary credentials
sess, err = session.NewSession(&aws.Config{
sess, err := session.NewSession(&aws.Config{
Credentials: credentials.NewStaticCredentials(creds.AWSAccessKey, creds.AWSSecretKey, creds.AWSSessionToken),
})
if err != nil {
return "", fmt.Errorf("error creating session: %w", err)
}
svc = sts.New(sess)
svc := sts.New(sess)

// Generate Kubernetes token with the STS session
// it generates a token with a presigned URL to the STS session
Expand Down
22 changes: 21 additions & 1 deletion proxy/loginHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package proxy

import (
_ "embed" //embed web resources for login page
"encoding/base64"
"encoding/json"
"fmt"
"html/template"
"log"
Expand Down Expand Up @@ -112,13 +114,31 @@ func loginPostHandler(authProvider provider.Provider) func(w http.ResponseWriter
}

// Assume role with SAML assertion
token, err := adfsProvider.Token(assertion, role)
creds, err := adfsProvider.AssumeRole(assertion, role)
if err != nil {
log.Printf("failed to assume role: %s", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to assume role")), http.StatusFound)
return
}

// Marshal credentials to base64 encoded JSON and store them in proxy_aws_creds cookie
jsonCreds, err := json.Marshal(creds)
if err != nil {
log.Printf("failed to marshal credentials: %s", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to marshal credentials")), http.StatusFound)
return
}
b64JsonCreds := base64.StdEncoding.EncodeToString(jsonCreds)
http.SetCookie(w, &http.Cookie{Name: "proxy_aws_creds", Value: b64JsonCreds})

// Get token from credentials
token, err := adfsProvider.Token(creds)
if err != nil {
log.Printf("failed to get token: %s", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to get token")), http.StatusFound)
return
}

// As stated by RFC, cookie size limit must be at least 4096 bytes
// so we split the token below this size to be compatible with all
// browsers https://stackoverflow.com/a/52492934
Expand Down
68 changes: 64 additions & 4 deletions proxy/proxyHandler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package proxy

import (
"encoding/base64"
"encoding/json"
"fmt"
"log"
"net/http"
Expand All @@ -21,16 +23,74 @@ func proxyHandler(target string, authProvider provider.Provider) func(w http.Res
return
}

// Check if token is valid
err = authProvider.Valid(token)
if err != nil {
// Switch on auth provider
switch authProvider := authProvider.(type) {
case *provider.ProviderAwsAdfs:

// Check if token is valid
err = authProvider.Valid(token)
if err == nil {
break
}

// Get cookier proxy_aws_creds
b64Creds, err := r.Cookie("proxy_aws_creds")
if err != nil || b64Creds.Value == "" {
log.Printf("failed to get cookie proxy_aws_creds: %s", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}

// Extract creds from cookie
decodedCreds, err := base64.StdEncoding.DecodeString(b64Creds.Value)
if err != nil {
log.Printf("failed to decode cookie proxy_aws_creds: %s", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}
var creds provider.AWSCreds
err = json.Unmarshal(decodedCreds, &creds)
if err != nil {
log.Printf("failed to unmarshal cookie proxy_aws_creds: %s", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}

// Try to refresh token with existing creds
newToken, err := authProvider.Token(creds)
if err != nil {
log.Printf("failed to refresh token: %s", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}
token = newToken

// As stated by RFC, cookie size limit must be at least 4096 bytes
// so we split the token below this size to be compatible with all
// browsers https://stackoverflow.com/a/52492934
setTokenCookie(w, token, 4000)

case *provider.ProviderTanzu:

// Check if token is valid
err = authProvider.Valid(token)
if err == nil {
break
}

log.Printf("failed to check if token is valid: %s", err)
http.Redirect(w, r, "/login", http.StatusFound)
return

}

// create the reverse proxy
url, _ := url.Parse(target)
url, err := url.Parse(target)
if err != nil {
log.Printf("failed to parse target URL: %s", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}
proxy := httputil.NewSingleHostReverseProxy(url)

// add token as authorization header
Expand Down

0 comments on commit bd9faae

Please sign in to comment.