Skip to content

Commit

Permalink
Merge pull request #10 from mblackgeo/fix-key-error
Browse files Browse the repository at this point in the history
Add app context to decorators
  • Loading branch information
mblackgeo authored Jul 6, 2022
2 parents 366c862 + bb8bd38 commit 401ce6b
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 104 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ repos:
hooks:
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 22.1.0
rev: 22.6.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
Expand Down
213 changes: 110 additions & 103 deletions src/flask_cognito_lib/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,41 +26,44 @@

def remove_from_session(keys: Iterable[str]):
"""Remove an entry from the session"""
for key in keys:
if key in session:
session.pop(key)
with app.app_context():
for key in keys:
if key in session:
session.pop(key)


def cognito_login(fn):
"""A decorator that redirects to the Cognito hosted UI"""

@wraps(fn)
def wrapper(*args, **kwargs):
# store parameters in the session that are passed to Cognito
# and required for JWT verification
code_verifier = generate_code_verifier()
cognito_session = {
"code_verifier": code_verifier,
"code_challenge": generate_code_challenge(code_verifier),
"nonce": secure_random(),
}
session.update(cognito_session)

# Add suport for custom state values which are appended to a secure
# random value for additional CRSF protection
state = secure_random()
custom_state = session.get("state", default=None)
if custom_state:
state += f"__{custom_state}"

session.update({"state": state})

login_url = cognito_auth.cognito_service.get_sign_in_url(
code_challenge=session["code_challenge"],
state=session["state"],
nonce=session["nonce"],
scopes=cfg.cognito_scopes,
)
with app.app_context():
# store parameters in the session that are passed to Cognito
# and required for JWT verification
code_verifier = generate_code_verifier()
cognito_session = {
"code_verifier": code_verifier,
"code_challenge": generate_code_challenge(code_verifier),
"nonce": secure_random(),
}
session.update(cognito_session)

# Add suport for custom state values which are appended to a secure
# random value for additional CRSF protection
state = secure_random()
custom_state = session.get("state", default=None)
if custom_state:
state += f"__{custom_state}"

session.update({"state": state})

login_url = cognito_auth.cognito_service.get_sign_in_url(
code_challenge=session["code_challenge"],
state=session["state"],
nonce=session["nonce"],
scopes=cfg.cognito_scopes,
)

return redirect(login_url)

return wrapper
Expand All @@ -74,54 +77,56 @@ def cognito_login_callback(fn):

@wraps(fn)
def wrapper(*args, **kwargs):
# Get the access token return after auth flow with Cognito
code_verifier = session["code_verifier"]
state = session["state"]
nonce = session["nonce"]

# exchange the code for an access token
# also confirms the returned state is correct
tokens = cognito_auth.get_tokens(
request_args=request.args,
expected_state=state,
code_verifier=code_verifier,
)

# validate the JWT and get the claims
claims = cognito_auth.verify_access_token(
token=tokens.access_token,
leeway=cfg.cognito_expiration_leeway,
)
session.update({"claims": claims})

# Grab the user info from the user endpoint and store in the session
if tokens.id_token is not None:
user_info = cognito_auth.verify_id_token(
token=tokens.id_token,
nonce=nonce,
with app.app_context():
# Get the access token return after auth flow with Cognito
code_verifier = session["code_verifier"]
state = session["state"]
nonce = session["nonce"]

# exchange the code for an access token
# also confirms the returned state is correct
tokens = cognito_auth.get_tokens(
request_args=request.args,
expected_state=state,
code_verifier=code_verifier,
)

# validate the JWT and get the claims
claims = cognito_auth.verify_access_token(
token=tokens.access_token,
leeway=cfg.cognito_expiration_leeway,
)
session.update({"user_info": user_info})

# Remove one-time use variables now we have completed the auth flow
remove_from_session(("code_challenge", "code_verifier", "nonce"))

# split out the random part of the state value (in case the user
# specified their own custom state value)
state = session.get("state").split("__")[-1]
session.update({"state": state})

# return and set the JWT as a http only cookie
resp = fn(*args, **kwargs)

# Store the access token in a HTTP only secure cookie
resp.set_cookie(
key=cfg.COOKIE_NAME,
value=tokens.access_token,
max_age=cfg.max_cookie_age_seconds,
httponly=True,
secure=True,
)
session.update({"claims": claims})

# Grab the user info from the user endpoint and store in the session
if tokens.id_token is not None:
user_info = cognito_auth.verify_id_token(
token=tokens.id_token,
nonce=nonce,
leeway=cfg.cognito_expiration_leeway,
)
session.update({"user_info": user_info})

# Remove one-time use variables now we have completed the auth flow
remove_from_session(("code_challenge", "code_verifier", "nonce"))

# split out the random part of the state value (in case the user
# specified their own custom state value)
state = session.get("state").split("__")[-1]
session.update({"state": state})

# return and set the JWT as a http only cookie
resp = fn(*args, **kwargs)

# Store the access token in a HTTP only secure cookie
resp.set_cookie(
key=cfg.COOKIE_NAME,
value=tokens.access_token,
max_age=cfg.max_cookie_age_seconds,
httponly=True,
secure=True,
)

return resp

return wrapper
Expand All @@ -132,9 +137,10 @@ def cognito_logout(fn):

@wraps(fn)
def wrapper(*args, **kwargs):
# logout at cognito and remove the cookies
resp = redirect(cfg.logout_endpoint)
resp.delete_cookie(key=cfg.COOKIE_NAME)
with app.app_context():
# logout at cognito and remove the cookies
resp = redirect(cfg.logout_endpoint)
resp.delete_cookie(key=cfg.COOKIE_NAME)

# Cognito will redirect to the sign-out URL (if set) or else use
# the callback URL
Expand All @@ -149,32 +155,33 @@ def auth_required(groups: Optional[Iterable[str]] = None):
def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
# return early if the extension is disabled
if cfg.disabled:
return fn(*args, **kwargs)

# Try and validate the access token stored in the cookie
try:
access_token = request.cookies.get(cfg.COOKIE_NAME)
claims = cognito_auth.verify_access_token(
token=access_token,
leeway=cfg.cognito_expiration_leeway,
)
valid = True

# Check for required group membership
if groups:
valid = all(g in claims["cognito:groups"] for g in groups)
if not valid:
raise CognitoGroupRequiredError

except (TokenVerifyError, KeyError):
valid = False

if valid:
return fn(*args, **kwargs)

raise AuthorisationRequiredError
with app.app_context():
# return early if the extension is disabled
if cfg.disabled:
return fn(*args, **kwargs)

# Try and validate the access token stored in the cookie
try:
access_token = request.cookies.get(cfg.COOKIE_NAME)
claims = cognito_auth.verify_access_token(
token=access_token,
leeway=cfg.cognito_expiration_leeway,
)
valid = True

# Check for required group membership
if groups:
valid = all(g in claims["cognito:groups"] for g in groups)
if not valid:
raise CognitoGroupRequiredError

except (TokenVerifyError, KeyError):
valid = False

if valid:
return fn(*args, **kwargs)

raise AuthorisationRequiredError

return decorator

Expand Down

0 comments on commit 401ce6b

Please sign in to comment.