From a085b49404bcc11ffe651b4aa5a1e54b72bc8a7d Mon Sep 17 00:00:00 2001 From: Ryan Block Date: Wed, 9 Aug 2023 19:09:31 -0700 Subject: [PATCH] Fix issue with clearing sessions via `arc.http.res` Expand unit tests a bit --- arc/_lib/__init__.py | 6 +++ arc/http/helpers/res_fmt.py | 26 +++++++---- arc/http/session/__init__.py | 17 ++------ tests/test_http_res.py | 84 ++++++++++++++++++++++++++++++++++++ tests/test_http_sessions.py | 22 ++++++++++ 5 files changed, 134 insertions(+), 21 deletions(-) diff --git a/arc/_lib/__init__.py b/arc/_lib/__init__.py index 1a53a64..06e31cd 100644 --- a/arc/_lib/__init__.py +++ b/arc/_lib/__init__.py @@ -50,3 +50,9 @@ def use_aws(): # Assumed to be AWS return True + + +def get_session_table(): + return os.environ.get( + "ARC_SESSION_TABLE_NAME", os.environ.get("SESSION_TABLE_NAME") + ) diff --git a/arc/http/helpers/res_fmt.py b/arc/http/helpers/res_fmt.py index 6387fab..c1c77be 100644 --- a/arc/http/helpers/res_fmt.py +++ b/arc/http/helpers/res_fmt.py @@ -1,9 +1,11 @@ import gzip +import os from base64 import b64encode import simplejson as _json from arc.http.session import session_read, session_write from .binary_types import binary_types +from ..._lib import get_session_table def res(req, params): @@ -222,13 +224,21 @@ def compress(body): res["isBase64Encoded"] = True # Save the passed session - if params.get("session"): - sesh = params["session"] - if not sesh.get("_idx"): - session = session_read(req) - session.update(params["session"]) - sesh = session - - res["headers"]["set-cookie"] = session_write(sesh) + if params.get("session") is not None: + # In JWE any passed session payload is the new session + session = params["session"] + if get_session_table() != "jwe": + # In Dynamo, we have to figure out which session we're using + read = session_read(req) + # Set up session object prioritizing passed session payload over db + meta = { + "_idx": session.get("_idx", read.get("_idx")), + "_secret": session.get("_secret", read.get("_secret")), + "_ttl": session.get("_ttl", read.get("_ttl")), + } + # Then merge passed session payload data + session.update(meta) + + res["headers"]["set-cookie"] = session_write(session) return res diff --git a/arc/http/session/__init__.py b/arc/http/session/__init__.py index 68bac45..d42f638 100644 --- a/arc/http/session/__init__.py +++ b/arc/http/session/__init__.py @@ -3,6 +3,7 @@ from .cookies import _write_cookie, _read_cookie from .jwe import jwe_read, jwe_write from .ddb import ddb_read, ddb_write +from ..._lib import get_session_table _session_table_cache = None @@ -14,9 +15,7 @@ def _get_session_table(): if _session_table_cache and not testing: return _session_table_cache - table_name = os.environ.get( - "ARC_SESSION_TABLE_NAME", os.environ.get("SESSION_TABLE_NAME") - ) + table_name = get_session_table() if not table_name: raise TypeError( "To use sessions, ensure the session table name is specified in your ARC_SESSION_TABLE_NAME env var" @@ -37,13 +36,9 @@ def _get_session_table(): def session_read(req): - is_jwe = ( - os.environ.get("ARC_SESSION_TABLE_NAME", os.environ.get("SESSION_TABLE_NAME")) - == "jwe" - ) try: cookie = _read_cookie(req) - if is_jwe: + if get_session_table() == "jwe": return jwe_read(cookie) else: _get_session_table() @@ -53,11 +48,7 @@ def session_read(req): def session_write(payload): - is_jwe = ( - os.environ.get("ARC_SESSION_TABLE_NAME", os.environ.get("SESSION_TABLE_NAME")) - == "jwe" - ) - if is_jwe: + if get_session_table() == "jwe": cookie = jwe_write(payload) else: _get_session_table() diff --git a/tests/test_http_res.py b/tests/test_http_res.py index f353d13..9b7247c 100644 --- a/tests/test_http_res.py +++ b/tests/test_http_res.py @@ -1,8 +1,21 @@ # -*- coding: utf-8 -*- +import copy import json import arc # TODO: implement tests when Arc's req/res fixtures can be ported to py; see: https://github.com/architect/req-res-fixtures +simple_req_mock = { + "version": "2.0", + "headers": {}, +} +session_mock = {"foo": {"bar": 123}, "yak": None} + + +def make_req(res): + new = copy.deepcopy(simple_req_mock) + res_cookie = res["headers"]["set-cookie"] + new["headers"]["cookie"] = res_cookie.split(";")[0] + return new def test_method_exists(): @@ -12,3 +25,74 @@ def test_method_exists(): assert res.get("headers") assert res.get("statusCode") == 200 assert res.get("body") == json.dumps(ok) + + +def test_jwe_session(monkeypatch): + monkeypatch.setenv("ARC_SESSION_TABLE_NAME", "jwe") + + # Create + write a session + payload = {"session": copy.deepcopy(session_mock)} + response = arc.http.res(simple_req_mock, payload) + req = make_req(response) + session = arc.http.session_read(req) + del session["iat"] # delete issued at timestamp + assert session == session_mock + + # Mutate / destroy a session + payload = {"session": {}} + response = arc.http.res(simple_req_mock, payload) + req = make_req(response) + session = arc.http.session_read(req) + del session["iat"] # delete issued at timestamp + assert session == {} + + +def test_ddb_session(monkeypatch, arc_services, ddb_client): + tablename = "sessions" + monkeypatch.setenv("ARC_SESSION_TABLE_NAME", tablename) + ddb_client.create_table( + TableName=tablename, + KeySchema=[{"AttributeName": "_idx", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "_idx", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + tables = ddb_client.list_tables() + arc_services(params={f"tables/{tablename}": tables["TableNames"][0]}) + + # Create + write a session + payload = {"session": copy.deepcopy(session_mock)} + response = arc.http.res(simple_req_mock, payload) + # In the case of DynamoDB sessions, we can keep reusing the same request + req = make_req(response) + session = arc.http.session_read(req) + assert bool(session["_idx"]) + assert bool(session["_secret"]) + assert bool(session["_ttl"]) + session_match = copy.deepcopy(session) + del session_match["_idx"] + del session_match["_secret"] + del session_match["_ttl"] + assert session_match == session_mock + + # Mutate a session + payload = {"session": copy.deepcopy(session)} + payload["session"]["count"] = 0 + arc.http.res(req, payload) + mutated_session = arc.http.session_read(req) + # Ensure a new session wasn't created + assert mutated_session["_idx"] == session["_idx"] + assert mutated_session["_secret"] == session["_secret"] + assert mutated_session["_ttl"] == session["_ttl"] + assert "count" in mutated_session + assert mutated_session["count"] == 0 + + # Destroy session contents + payload = {"session": {}} + arc.http.res(req, payload) + destroyed_session = arc.http.session_read(req) + assert destroyed_session["_idx"] == session["_idx"] + assert destroyed_session["_secret"] == session["_secret"] + assert destroyed_session["_ttl"] == session["_ttl"] + assert "count" not in destroyed_session diff --git a/tests/test_http_sessions.py b/tests/test_http_sessions.py index d06832b..cfd9248 100644 --- a/tests/test_http_sessions.py +++ b/tests/test_http_sessions.py @@ -18,6 +18,7 @@ def test_jwe_read_write(): def test_jwe_session(monkeypatch): monkeypatch.setenv("ARC_SESSION_TABLE_NAME", "jwe") + # Write a session cookie = arc.http.session_write({"count": 0}) mock = { "headers": { @@ -27,6 +28,15 @@ def test_jwe_session(monkeypatch): session = arc.http.session_read(mock) assert "count" in session assert session["count"] == 0 + # Destroy a session + cookie = arc.http.session_write({}) + mock = { + "headers": { + "cookie": cookie, + }, + } + session = arc.http.session_read(mock) + assert "count" not in session def test_custom_key(monkeypatch): @@ -68,6 +78,8 @@ def test_ddb_session(monkeypatch, arc_services, ddb_client): ) tables = ddb_client.list_tables() arc_services(params={f"tables/{tablename}": tables["TableNames"][0]}) + + # Write a session payload = {"_idx": "abc", "count": 0} cookie = arc.http.session_write(payload) mock = { @@ -91,6 +103,16 @@ def test_ddb_session(monkeypatch, arc_services, ddb_client): assert "count" in session assert session["count"] == 0 + # Destroy a session + cookie = arc.http.session_write({"_idx": "abc"}) + mock = { + "headers": { + "cookie": cookie, + } + } + session = arc.http.session_read(mock) + assert "count" not in session + def test_ddb_sign_unsign(): original = "123456"