Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

oidc: Reimplement userinfo for fine-grained error handling #31

Merged
merged 5 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ e2e: publish
# Run E2E tests
cd e2e/manifests/authservice/base && \
kustomize edit set image gcr.io/arrikto/kubeflow/oidc-authservice=$(IMG):$(TAG)
go test ./e2e -v
# Use -count=1 to skip Go's test cache
go test -v -count=1 ./e2e

publish: docker-build docker-push

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ Can be used with:

For E2E tests, we use [K3d](https://k3d.io/), a very lightweight way to run a K8s cluster
locally using Docker. For E2E tests to work, you need the following external tools:
* `go (>=1.13)`
* `kustomize`
* `kubectl`
* `k3d`
Expand Down
22 changes: 13 additions & 9 deletions e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,24 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
appsv1 "k8s.io/api/apps/v1"
"k8s.io/apimachinery/pkg/types"
"net/http"
"net/url"
"os"
"os/exec"
controllerruntime "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"path"
"strconv"
"strings"
"testing"
"time"

"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
appsv1 "k8s.io/api/apps/v1"
"k8s.io/apimachinery/pkg/types"
controllerruntime "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
)

type E2ETestSuite struct {
Expand Down Expand Up @@ -250,7 +252,9 @@ func createK3DCluster() error {
if err != nil {
return err
}
return exec.Command("k3d", "get", "kubeconfig", "e2e-test-cluster", "--switch").Run()
kubeconfigPath := path.Join(os.Getenv("HOME"), ".kube/config")
return exec.Command("k3d", "get", "kubeconfig", "e2e-test-cluster",
"--switch", "--output", kubeconfigPath).Run()
}

func deleteK3DCluster() error {
Expand Down
9 changes: 6 additions & 3 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ package main

import (
"fmt"
"net/http"
)

var _ error = &requestError{}

type requestError struct {
StatusCode int
Err error
Response *http.Response
Body []byte
Err error
}

func (e *requestError) Error() string {
return fmt.Sprintf("status: %d, err: %v", e.StatusCode, e.Err)
return fmt.Sprintf("status: %d, body: %s, err: %v", e.Response.StatusCode,
e.Body, e.Err)
}
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ require (
github.com/boltdb/bolt v1.3.1
github.com/cenkalti/backoff/v4 v4.0.2
github.com/coreos/go-oidc v2.1.0+incompatible
github.com/golang/protobuf v1.3.2 // indirect
github.com/gorilla/handlers v1.4.2
github.com/gorilla/mux v1.7.3
github.com/gorilla/sessions v1.2.0
github.com/kelseyhightower/envconfig v1.4.0
github.com/pkg/errors v0.8.1
github.com/pkg/errors v0.9.1
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/quasoft/memstore v0.0.0-20180925164028-84a050167438
github.com/sirupsen/logrus v1.4.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/cachecontrol v0.0.0-20171018203845-0dec1b30a021/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
Expand Down
86 changes: 86 additions & 0 deletions oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package main

import (
"context"
"encoding/json"
"io/ioutil"
"net/http"

"github.com/coreos/go-oidc"
"github.com/pkg/errors"
"golang.org/x/oauth2"
)

// UserInfo represents the OpenID Connect userinfo claims.
type UserInfo struct {
Subject string `json:"sub"`
Profile string `json:"profile"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`

RawClaims []byte
}

// Claims unmarshals the raw JSON object claims into the provided object.
func (u *UserInfo) Claims(v interface{}) error {
if u.RawClaims == nil {
return errors.New("oidc: claims not set")
}
return json.Unmarshal(u.RawClaims, v)
}

// GetUserInfo uses the token source to query the provider's user info endpoint.
// We reimplement UserInfo [1] instead of using the go-oidc's library UserInfo, in
// order to include HTTP response information in case of an error during
// contacting the UserInfo endpoint.
//
// [1]: https://github.com/coreos/go-oidc/blob/v2.1.0/oidc.go#L180
func GetUserInfo(ctx context.Context, provider *oidc.Provider, tokenSource oauth2.TokenSource) (*UserInfo, error) {

discoveryClaims := &struct {
UserInfoURL string `json:"userinfo_endpoint"`
}{}
if err := provider.Claims(discoveryClaims); err != nil {
return nil, errors.Errorf("Error unmarshalling OIDC discovery document claims: %v", err)
}

userInfoURL := discoveryClaims.UserInfoURL
if userInfoURL == "" {
return nil, errors.New("oidc: user info endpoint is not supported by this provider")
}

req, err := http.NewRequest("GET", userInfoURL, nil)
if err != nil {
return nil, errors.Errorf("oidc: create GET request: %v", err)
}

token, err := tokenSource.Token()
if err != nil {
return nil, errors.Errorf("oidc: get access token: %v", err)
}
token.SetAuthHeader(req)

resp, err := doRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, &requestError{
Response: resp,
Body: body,
Err: errors.Errorf("oidc: Calling UserInfo endpoint failed. body: %s", body),
}
}

var userInfo UserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, errors.Errorf("oidc: failed to decode userinfo: %v", err)
}
userInfo.RawClaims = body
return &userInfo, nil
}
129 changes: 129 additions & 0 deletions oidc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package main

import (
"context"
"html/template"
"net/http"
"testing"
"time"

"github.com/coreos/go-oidc"
"github.com/gorilla/mux"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
)

func startFakeOIDCProvider(addr string) {

discoveryDoc := `
{
"issuer": "{{.Address}}",
"authorization_endpoint": "{{.Address}}/auth",
"userinfo_endpoint": "{{.Address}}/userinfo",
"revocation_endpoint": "{{.Address}}/revoke",
"jwks_uri": "{{.Address}}/jwks",
"response_types_supported": [
"code",
"token",
"id_token",
"code token",
"code id_token",
"token id_token",
"code token id_token",
"none"
],
"subject_types_supported": [
"public"
],
"id_token_signing_alg_values_supported": [
"RS256"
],
"scopes_supported": [
"openid",
"email",
"profile"
],
"token_endpoint_auth_methods_supported": [
"client_secret_post",
"client_secret_basic"
],
"claims_supported": [
"aud",
"email",
"email_verified",
"exp",
"family_name",
"given_name",
"iat",
"iss",
"locale",
"name",
"picture",
"sub"
],
"code_challenge_methods_supported": [
"plain",
"S256"
],
"grant_types_supported": [
"authorization_code",
"refresh_token",
"urn:ietf:params:oauth:grant-type:device_code",
"urn:ietf:params:oauth:grant-type:jwt-bearer"
]
}
`

tmpl, err := template.New("oidc_discovery_doc").Parse(discoveryDoc)
if err != nil {
log.Fatalf("Error parsing discovery doc template: %v", err)
}

discoveryHandler := func(w http.ResponseWriter, r *http.Request) {
err := tmpl.Execute(w, struct{ Address string }{Address: addr})
if err != nil {
log.Errorf("Error executing oidc discovery doc template: %v", err)
}
}

userinfoHandler := func(w http.ResponseWriter, r *http.Request) {
log.Info("Userinfo handler, returning 401...")
w.WriteHeader(http.StatusUnauthorized)
}

router := mux.NewRouter()
router.HandleFunc("/.well-known/openid-configuration", discoveryHandler)
router.HandleFunc("/userinfo", userinfoHandler)
log.Infof("Starting fake OIDC Provider at address: %v", addr)
if err := http.ListenAndServe("localhost:9999", router); err != nil {
log.Fatalf("Error in fake OIDC Provider server: %v", err)
}
}

func TestGetUserInfo_ContextCancelled(t *testing.T) {

// Start fake OIDC provider
oidcProviderAddr := "http://localhost:9999"
go startFakeOIDCProvider(oidcProviderAddr)
time.Sleep(5 * time.Second)
provider, err := oidc.NewProvider(context.Background(), oidcProviderAddr)
if err != nil {
t.Fatalf("Error creating OIDC Provider: %v", err)
}

// Make a UserInfo request
_, err = GetUserInfo(context.Background(), provider,
oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "test"}))

// Check that we find a wrapped requestError
var reqErr *requestError
if !errors.As(err, &reqErr) {
log.Fatalf("Returned error is not a requestError. Got: %+v", reqErr)
}

if reqErr.Response.StatusCode != http.StatusUnauthorized {
t.Fatalf("Got wrong status code. Got '%v', expected '%v'.",
reqErr.Response.StatusCode, http.StatusUnauthorized)
}
}
17 changes: 10 additions & 7 deletions revoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package main
import (
"context"
"fmt"
"github.com/coreos/go-oidc"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"io/ioutil"
"net/http"
"net/url"
"strings"

"github.com/coreos/go-oidc"
"github.com/pkg/errors"
"golang.org/x/oauth2"
)

// revocationEndpoint parses the OIDC Provider claims from the discovery document
Expand Down Expand Up @@ -77,13 +78,15 @@ func revokeToken(ctx context.Context, revocationEndpoint string, token, tokenTyp
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return &requestError{
StatusCode: resp.StatusCode,
Err: errors.New(fmt.Sprintf("Revocation endpoint returned code %v, failed to read body: %v", code, err)),
Response: resp,
Body: body,
Err: errors.New(fmt.Sprintf("Revocation endpoint returned code %v, failed to read body: %v", code, err)),
}
}
return &requestError{
StatusCode: resp.StatusCode,
Err: errors.New(fmt.Sprintf("Revocation endpoint returned code %v, server returned: %v", code, body)),
Response: resp,
Body: body,
Err: errors.New(fmt.Sprintf("Revocation endpoint returned code %v, server returned: %v", code, body)),
}
}
return nil
Expand Down
Loading