Skip to content

Commit

Permalink
refactor: use abstracted sso create session in SAML
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 committed Oct 16, 2024
1 parent c98f5be commit 29905a9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 99 deletions.
98 changes: 16 additions & 82 deletions cmd/api/src/api/saml/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ import (
"encoding/xml"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
Expand All @@ -36,9 +33,7 @@ import (
"github.com/specterops/bloodhound/src/api"
"github.com/specterops/bloodhound/src/auth/bhsaml"
"github.com/specterops/bloodhound/src/config"
"github.com/specterops/bloodhound/src/ctx"
"github.com/specterops/bloodhound/src/database"
"github.com/specterops/bloodhound/src/model"
)

const (
Expand All @@ -56,7 +51,6 @@ const (
)

const (
ErrorUserDisabled = errors.Error("user disabled")
ErrorUserNotFound = errors.Error("User not found")
ErrorSAMLAssertion = errors.Error("SAML assertion error")
ErrorUserNotAuthorizedForProvider = errors.Error("User not authorized for this provider")
Expand Down Expand Up @@ -220,7 +214,7 @@ func (s ProviderResource) emailAttributeNames() []string {
return []string{bhsaml.ObjectIDEmail, bhsaml.XMLSOAPClaimsEmailAddress}
}

func (s ProviderResource) lookupSAMLUser(ctx context.Context, assertion *saml.Assertion) (model.User, error) {
func (s ProviderResource) getSAMLUserPrincipalNameFromAssertion(assertion *saml.Assertion) (string, error) {
for _, attrStmt := range assertion.AttributeStatements {
for _, attr := range attrStmt.Attributes {
for _, value := range attr.Values {
Expand All @@ -232,80 +226,9 @@ func (s ProviderResource) lookupSAMLUser(ctx context.Context, assertion *saml.As
// All SAML assertions must contain a eduPersonPrincipalName attribute. While this is not expected to be an email
// address, it may be formatted as such.
if principalName, err := assertionFindString(assertion, s.emailAttributeNames()...); err != nil {
return model.User{}, ErrorSAMLAssertion
return "", ErrorSAMLAssertion
} else {
if user, err := s.db.LookupUser(ctx, principalName); err != nil {
if !errors.Is(err, database.ErrNotFound) {
return model.User{}, api.FormatDatabaseError(err)
} else {
return model.User{}, ErrorUserNotFound
}
} else if !user.SAMLProviderID.Valid || s.serviceProvider.Config.ID != user.SAMLProvider.ID {
return model.User{}, ErrorUserNotAuthorizedForProvider
} else {
return user, nil
}
}
}

func sameSiteValue(host url.URL) http.SameSite {
if host.Scheme == "https" {
return http.SameSiteStrictMode
} else {
return http.SameSiteDefaultMode
}
}

// NOTE: Set-Cookie should generally have the Domain field blank to ensure the cookie is only included with requests against the host, excluding subdomains; however,
// most browsers will ignore Set-Cookie headers from localhost responses if the Domain field is not set explicitly.
func domainValue(host url.URL) string {
if strings.Contains(host.Hostname(), "localhost") {
return host.Hostname()
} else {
return ""
}
}

func (s ProviderResource) createSessionFromAssertion(request *http.Request, response http.ResponseWriter, expires time.Time, assertion *saml.Assertion) {
hostURL := *ctx.FromRequest(request).Host

if user, err := s.lookupSAMLUser(request.Context(), assertion); err != nil {
log.Errorf("[SAML] Failed to lookup user for SAML provider %s: %v", s.serviceProvider.Config.Name, err)

switch {
case errors.Is(err, ErrorSAMLAssertion):
s.writeAPIErrorResponse(request, response, http.StatusBadRequest, "session assertion does not meet the requirements for user lookup")
case errors.Is(err, ErrorUserNotFound), errors.Is(err, ErrorUserNotAuthorizedForProvider):
// This is a tiny bit more descriptive for the end user without leaking any sensitive information
s.writeAPIErrorResponse(request, response, http.StatusForbidden, "user is not allowed")
default:
s.writeAPIErrorResponse(request, response, http.StatusInternalServerError, "session creation failure")
}
} else if sessionJWT, err := s.authenticator.CreateSession(request.Context(), user, s.serviceProvider.Config); err != nil {
if locationURL := api.URLJoinPath(hostURL, api.UserDisabledPath); errors.Is(err, ErrorUserDisabled) {
response.Header().Add(headers.Location.String(), locationURL.String())
response.WriteHeader(http.StatusFound)
} else {
log.Errorf("[SAML] Failed to create user session for SAML provider %s: %v", s.serviceProvider.Config.Name, err)
s.writeAPIErrorResponse(request, response, http.StatusInternalServerError, "session creation failure")
}
} else {
locationURL := api.URLJoinPath(hostURL, api.UserInterfacePath)

// Set the token cookie
http.SetCookie(response, &http.Cookie{
Name: api.AuthTokenCookieName,
Value: sessionJWT,
Expires: expires,
Secure: hostURL.Scheme == "https",
SameSite: sameSiteValue(hostURL),
Path: "/",
Domain: domainValue(hostURL),
})

// Redirect back to the UI landing page
response.Header().Add(headers.Location.String(), locationURL.String())
response.WriteHeader(http.StatusFound)
return principalName, nil
}
}

Expand Down Expand Up @@ -353,7 +276,6 @@ func (s ProviderResource) bindingTypeAndLocation() (string, string) {
// HandleStartAuthFlow is called to start the SAML authentication process.
func (s ProviderResource) serveStartAuthFlow(response http.ResponseWriter, request *http.Request) {
binding, bindingLocation := s.bindingTypeAndLocation()

// relayState is limited to 80 bytes but also must be integrity protected.
// this means that we cannot use a JWT because it is way too long. Instead,
// we set a signed cookie that encodes the original URL which we'll check
Expand Down Expand Up @@ -424,8 +346,20 @@ func (s ProviderResource) serveAssertionConsumerService(response http.ResponseWr
}

s.writeAPIErrorResponse(request, response, http.StatusUnauthorized, api.ErrorResponseDetailsAuthenticationInvalid)
} else if principalName, err := s.getSAMLUserPrincipalNameFromAssertion(assertion); err != nil {
log.Errorf("[SAML] Failed to lookup user for SAML provider %s: %v", s.serviceProvider.Config.Name, err)

switch {
case errors.Is(err, ErrorSAMLAssertion):
s.writeAPIErrorResponse(request, response, http.StatusBadRequest, "session assertion does not meet the requirements for user lookup")
case errors.Is(err, ErrorUserNotFound), errors.Is(err, ErrorUserNotAuthorizedForProvider):
// This is a tiny bit more descriptive for the end user without leaking any sensitive information
s.writeAPIErrorResponse(request, response, http.StatusForbidden, "user is not allowed")
default:
s.writeAPIErrorResponse(request, response, http.StatusInternalServerError, "session creation failure")
}
} else {
s.createSessionFromAssertion(request, response, time.Now().UTC().Add(s.cfg.AuthSessionTTL()), assertion)
s.authenticator.CreateSSOSession(request, response, principalName, s.serviceProvider.Config)
}
}
}
17 changes: 0 additions & 17 deletions cmd/api/src/database/asset.go

This file was deleted.

3 changes: 3 additions & 0 deletions packages/go/crypto/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ const (

func tryParsePrivateKey(key string) (*rsa.PrivateKey, error) {
keyBlock, _ := pem.Decode([]byte(key))
if keyBlock == nil {
return nil, fmt.Errorf("unsupported key type")
}

if pkcs8PrivateKey, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes); err == nil {
if rsaPrivateKey, ok := pkcs8PrivateKey.(*rsa.PrivateKey); !ok {
Expand Down

0 comments on commit 29905a9

Please sign in to comment.