Skip to content

Commit

Permalink
feat: add callback handler for OIDC
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 committed Oct 16, 2024
1 parent c1c9df9 commit c98f5be
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 3 deletions.
57 changes: 56 additions & 1 deletion cmd/api/src/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type Authenticator interface {
ValidateSecret(ctx context.Context, secret string, authSecret model.AuthSecret) error
ValidateRequestSignature(tokenID uuid.UUID, request *http.Request, serverTime time.Time) (auth.Context, int, error)
CreateSession(ctx context.Context, user model.User, authProvider any) (string, error)
CreateSSOSession(request *http.Request, response http.ResponseWriter, principalNameOrEmail string, authProvider any)
ValidateSession(ctx context.Context, jwtTokenString string) (auth.Context, error)
}

Expand Down Expand Up @@ -329,6 +330,56 @@ func SetSecureBrowserCookie(request *http.Request, response http.ResponseWriter,
})
}

func (s authenticator) CreateSSOSession(request *http.Request, response http.ResponseWriter, principalNameOrEmail string, authProvider any) {
var (
hostURL = *ctx.FromRequest(request).Host
requestCtx = request.Context()
)

if user, err := s.db.LookupUser(requestCtx, principalNameOrEmail); err != nil {
if !errors.Is(err, database.ErrNotFound) {
HandleDatabaseError(request, response, err)
} else {
WriteErrorResponse(request.Context(), BuildErrorResponse(http.StatusForbidden, "user is not allowed", request), response)
}
} else {
switch typedAuthProvider := authProvider.(type) {
case model.SAMLProvider:
if !user.SAMLProviderID.Valid || typedAuthProvider.ID != user.SAMLProvider.ID {
WriteErrorResponse(request.Context(), BuildErrorResponse(http.StatusForbidden, "user is not allowed", request), response)
return
}
case model.OIDCProvider:
//todo connect to db provider table
break
case model.AuthSecret:
WriteErrorResponse(request.Context(), BuildErrorResponse(http.StatusBadRequest, "invalid auth provider", request), response)
return
default:
WriteErrorResponse(request.Context(), BuildErrorResponse(http.StatusBadRequest, "invalid auth provider", request), response)
return
}

if sessionJWT, err := s.CreateSession(requestCtx, user, authProvider); err != nil {
if locationURL := URLJoinPath(hostURL, UserDisabledPath); errors.Is(err, ErrUserDisabled) {
response.Header().Add(headers.Location.String(), locationURL.String())
response.WriteHeader(http.StatusFound)
} else {
WriteErrorResponse(request.Context(), BuildErrorResponse(http.StatusInternalServerError, "session creation failure", request), response)
}
} else {
locationURL := URLJoinPath(hostURL, UserInterfacePath)

// Set the token cookie
SetSecureBrowserCookie(request, response, AuthTokenCookieName, sessionJWT, time.Now().UTC().Add(s.cfg.AuthSessionTTL()))

// Redirect back to the UI landing page
response.Header().Add(headers.Location.String(), locationURL.String())
response.WriteHeader(http.StatusFound)
}
}
}

func (s authenticator) CreateSession(ctx context.Context, user model.User, authProvider any) (string, error) {
if user.IsDisabled {
return "", ErrUserDisabled
Expand All @@ -346,10 +397,14 @@ func (s authenticator) CreateSession(ctx context.Context, user model.User, authP
case model.AuthSecret:
userSession.AuthProviderType = model.SessionAuthProviderSecret
userSession.AuthProviderID = typedAuthProvider.ID

case model.SAMLProvider:
userSession.AuthProviderType = model.SessionAuthProviderSAML
userSession.AuthProviderID = typedAuthProvider.ID
case model.OIDCProvider:
userSession.AuthProviderType = model.SessionAuthProviderOIDC
userSession.AuthProviderID = int32(typedAuthProvider.ID)
default:
return "", errors.New("invalid auth provider")
}

if newSession, err := s.db.CreateUserSession(ctx, userSession); err != nil {
Expand Down
2 changes: 2 additions & 0 deletions cmd/api/src/api/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ const (
QueryParameterHydrateDomains = "hydrate_domains"
QueryParameterHydrateOUs = "hydrate_ous"
QueryParameterScope = "scope"
QueryParameterState = "state"
QueryParameterCode = "code"

// URI path parameters
URIPathVariableApplicationConfigurationParameter = "parameter"
Expand Down
1 change: 1 addition & 0 deletions cmd/api/src/api/registration/v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func registerV2Auth(cfg config.Configuration, db database.Database, permissions
routerInst.POST("/api/v2/sso-providers/oidc", managementResource.CreateOIDCProvider).CheckFeatureFlag(db, appcfg.FeatureOIDCSupport).RequirePermissions(permissions.AuthManageProviders),
routerInst.GET("/api/v2/sso-providers", managementResource.ListAuthProviders).CheckFeatureFlag(db, appcfg.FeatureOIDCSupport),
routerInst.GET(fmt.Sprintf("/api/v2/sso/{%s}/login", api.URIPathVariableSSOProviderSlug), managementResource.SSOLoginHandler).CheckFeatureFlag(db, appcfg.FeatureOIDCSupport),
routerInst.GET(fmt.Sprintf("/api/v2/sso/{%s}/callback", api.URIPathVariableSSOProviderSlug), managementResource.SSOCallbackHandler).CheckFeatureFlag(db, appcfg.FeatureOIDCSupport),

// Permissions
routerInst.GET("/api/v2/permissions", managementResource.ListPermissions).RequirePermissions(permissions.AuthManageSelf),
Expand Down
53 changes: 53 additions & 0 deletions cmd/api/src/api/v2/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,56 @@ func (s ManagementResource) OIDCLoginHandler(response http.ResponseWriter, reque
response.WriteHeader(http.StatusFound)
}
}

func (s ManagementResource) OIDCCallbackHandler(response http.ResponseWriter, request *http.Request, ssoProvider model.SSOProvider, oidcProvider model.OIDCProvider) {
var (
queryParams = request.URL.Query()
state = queryParams[api.QueryParameterState]
code = queryParams[api.QueryParameterCode]
)

if len(code) == 0 {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "missing code", request), response)
} else if pkceVerifier, err := request.Cookie(api.AuthPKCECookieName); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "missing pkce verifier", request), response)
} else if len(state) == 0 {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "missing state", request), response)
} else if stateCookie, err := request.Cookie(api.AuthStateCookieName); err != nil || stateCookie.Value != state[0] {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "bad state", request), response)
} else if provider, err := oidc.NewProvider(request.Context(), oidcProvider.Issuer); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, err.Error(), request), response)
} else {
var (
oidcVerifier = provider.Verifier(&oidc.Config{ClientID: oidcProvider.ClientID})
oauth2Conf = &oauth2.Config{
ClientID: oidcProvider.ClientID,
Endpoint: provider.Endpoint(),
RedirectURL: getRedirectURL(request, ssoProvider), // Required as verification check
}
)

if token, err := oauth2Conf.Exchange(request.Context(), code[0], oauth2.VerifierOption(pkceVerifier.Value)); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusForbidden, api.ErrorResponseDetailsForbidden, request), response)
} else if rawIDToken, ok := token.Extra("id_token").(string); !ok { // Extract the ID Token from OAuth2 token
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "missing id token", request), response)
} else if idToken, err := oidcVerifier.Verify(request.Context(), rawIDToken); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "invalid id token", request), response)
} else {
log.Debugf("GOT A TOKEN %+v\n", token)
log.Debugf("ID TOKEN %+v\n", rawIDToken)
// Extract custom claims
var claims struct {
Name string `json:"name"`
DisplayName string `json:"given_name"`
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, err.Error(), request), response)
} else {
log.Debugf("ID CLAIMS %+v\n", claims)
s.authenticator.CreateSSOSession(request, response, claims.Email, oidcProvider)
}
}
}
}
27 changes: 25 additions & 2 deletions cmd/api/src/api/v2/auth/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,38 @@ func (s ManagementResource) SSOLoginHandler(response http.ResponseWriter, reques
switch ssoProvider.Type {
case model.SessionAuthProviderSAML:
//todo handle saml login
return
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusNotFound, api.ErrorResponseDetailsResourceNotFound, request), response)
case model.SessionAuthProviderOIDC:
if oidcProvider, err := s.db.GetOIDCProviderBySSOProviderID(request.Context(), ssoProvider.ID); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
s.OIDCLoginHandler(response, request, ssoProvider, oidcProvider)
}
default:
return
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusNotFound, api.ErrorResponseDetailsResourceNotFound, request), response)
}
}
}

func (s ManagementResource) SSOCallbackHandler(response http.ResponseWriter, request *http.Request) {
ssoProviderSlug := mux.Vars(request)[api.URIPathVariableSSOProviderSlug]
log.Debugf("HERE I AM IN CALLBACK - provider %s", ssoProviderSlug)

if ssoProvider, err := s.db.GetSSOProviderBySlug(request.Context(), ssoProviderSlug); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
switch ssoProvider.Type {
case model.SessionAuthProviderSAML:
//todo handle saml callback
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusNotFound, api.ErrorResponseDetailsResourceNotFound, request), response)
case model.SessionAuthProviderOIDC:
if oidcProvider, err := s.db.GetOIDCProviderBySSOProviderID(request.Context(), ssoProvider.ID); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
s.OIDCCallbackHandler(response, request, ssoProvider, oidcProvider)
}
default:
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusNotFound, api.ErrorResponseDetailsResourceNotFound, request), response)
}
}
}

0 comments on commit c98f5be

Please sign in to comment.