From 94fe48904213f1f9d28e64c23eba181927ec7bd3 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Mon, 19 Aug 2024 12:34:56 -0700 Subject: [PATCH] PoC: req_ds_cnf Refactor req_ds_cnf to delegation_scope_key Move logic to allow token cache to work wip --- msal/application.py | 68 ++++++++++++++++++++++++++++++-- msal/crypto.py | 23 +++++++++++ msal/token_cache.py | 14 ++++++- tests/test_application.py | 83 +++++++++++++++++++++++++++++++++++++++ tests/test_crypto.py | 12 ++++++ tests/test_e2e.py | 29 +++++++++++--- 6 files changed, 220 insertions(+), 9 deletions(-) create mode 100644 msal/crypto.py create mode 100644 tests/test_crypto.py diff --git a/msal/application.py b/msal/application.py index 260d80e0..0d208ea9 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1,3 +1,5 @@ +import base64 +import datetime import functools import json import time @@ -165,6 +167,17 @@ def _preferred_browser(): return None +def _build_req_cnf(jwk:dict, remove_padding:bool = False) -> str: + """req_cnf usually requires base64url encoding. + + https://datatracker.ietf.org/doc/html/draft-ietf-oauth-pop-key-distribution-07#section-4.2.1 + https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e967ebeb-9e9f-443e-857a-5208802943c2 + """ + raw = json.dumps(jwk) + encoded = base64.urlsafe_b64encode(raw.encode('utf-8')).decode('utf-8') + return encoded.rstrip('=') if remove_padding else encoded + + class _ClientWithCcsRoutingInfo(Client): def initiate_auth_code_flow(self, **kwargs): @@ -231,6 +244,7 @@ class ClientApplication(object): _TOKEN_SOURCE_IDP = "identity_provider" _TOKEN_SOURCE_CACHE = "cache" _TOKEN_SOURCE_BROKER = "broker" + _XMS_DS_NONCE = "xms_ds_nonce" _enable_broker = False _AUTH_SCHEME_UNSUPPORTED = ( @@ -238,6 +252,14 @@ class ClientApplication(object): "You can enable broker by following these instructions. " "https://msal-python.readthedocs.io/en/latest/#publicclientapplication") + @functools.lru_cache(maxsize=2) + def __get_rsa_key(self, _bucket): # _bucket is used with lru_cache pattern + from .crypto import _generate_rsa_key + return _generate_rsa_key() + + def _get_rsa_key(self, _bucket=None): # Return the same RSA key, cached for a day + return self.__get_rsa_key(_bucket or datetime.date.today()) + def __init__( self, client_id, client_credential=None, authority=None, validate_authority=True, @@ -1552,6 +1574,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( "expires_in": int(expires_in), # OAuth2 specs defines it as int self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE, } + if self._XMS_DS_NONCE in entry: # CDT needs this + access_token_from_cache[self._XMS_DS_NONCE] = entry[ + self._XMS_DS_NONCE] if "refresh_on" in entry: access_token_from_cache["refresh_on"] = int(entry["refresh_on"]) if int(entry["refresh_on"]) < now: # aging @@ -2340,7 +2365,16 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app except that ``allow_broker`` parameter shall remain ``None``. """ - def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): + def acquire_token_for_client( + self, + scopes, + claims_challenge=None, + *, + delegation_constraints: Optional[list] = None, + delegation_confirmation_key=None, # A Cyprtography's RSAPrivateKey-like object + # TODO: Support ECC key? https://github.com/pyca/cryptography/issues/4093 + **kwargs + ): """Acquires token for the current confidential client, not for an end user. Since MSAL Python 1.23, it will automatically look for token from cache, @@ -2363,8 +2397,36 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): raise ValueError( # We choose to disallow force_refresh "Historically, this method does not support force_refresh behavior. " ) - return _clean_up(self._acquire_token_silent_with_error( - scopes, None, claims_challenge=claims_challenge, **kwargs)) + if delegation_constraints: + private_key = delegation_confirmation_key or self._get_rsa_key() + from .crypto import _convert_rsa_keys + _, jwk = _convert_rsa_keys(private_key) + result = _clean_up(self._acquire_token_silent_with_error( + scopes, None, claims_challenge=claims_challenge, data=dict( + kwargs.pop("data", {}), + req_ds_cnf=_build_req_cnf(jwk) # It is part of token cache key + if delegation_constraints else None, + ), + **kwargs)) + if delegation_constraints and not result.get("error"): + if not result.get(self._XMS_DS_NONCE): # Available in cached token, too + raise ValueError( + "The resource did not opt in to xms_ds_cnf claim. " + "After its opt-in, call this function again with " + "a new app object or a new delegation_confirmation_key" + # in order to invalidate the token in cache + ) + import jwt # Lazy loading + cdt_envelope = jwt.encode({ + "constraints": delegation_constraints, + self._XMS_DS_NONCE: result[self._XMS_DS_NONCE], + }, private_key, algorithm="PS256") + result["access_token"] = jwt.encode({ + "t": result["access_token"], + "c": cdt_envelope, + }, None, algorithm=None, headers={"typ": "cdt+jwt"}) + del result[self._XMS_DS_NONCE] # Caller shouldn't need to know that + return result def _acquire_token_for_client( self, diff --git a/msal/crypto.py b/msal/crypto.py new file mode 100644 index 00000000..dedb78b6 --- /dev/null +++ b/msal/crypto.py @@ -0,0 +1,23 @@ +from base64 import urlsafe_b64encode + +from cryptography.hazmat.primitives.asymmetric import rsa + + +def _urlsafe_b64encode(n:int, bit_size:int) -> str: + return urlsafe_b64encode(n.to_bytes(length=int(bit_size/8))).decode("utf-8") + + +def _to_jwk(public_key: rsa.RSAPublicKey) -> dict: + numbers = public_key.public_numbers() + return { + "kty": "RSA", + "n": _urlsafe_b64encode(numbers.n, public_key.key_size), + "e": _urlsafe_b64encode(numbers.e, 24), # TODO: TBD. PyJWT/jwt/algorithms.py RSAAlgorithm.to_jwk() + } + +def _convert_rsa_keys(private_key: rsa.RSAPrivateKey): + return "pairs.private_bytes()", _to_jwk(private_key.public_key()) + +def _generate_rsa_key() -> rsa.RSAPrivateKey: + return rsa.generate_private_key(public_exponent=65537, key_size=2048) + diff --git a/msal/token_cache.py b/msal/token_cache.py index 66be5c9f..0d93da21 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -1,4 +1,5 @@ -import json +import hashlib +import json import threading import time import logging @@ -61,6 +62,7 @@ def __init__(self): realm=None, target=None, # Note: New field(s) can be added here #key_id=None, + req_ds_cnf=None, **ignored_payload_from_a_real_token: "-".join([ # Note: Could use a hash here to shorten key length home_account_id or "", @@ -70,6 +72,13 @@ def __init__(self): realm or "", target or "", #key_id or "", # So ATs of different key_id can coexist + hashlib.sha256(req_ds_cnf.encode()).hexdigest() + # TODO: Could hash the entire key eventually. + # But before that project, we better first + # change the scope to use input scope + # instead of response scope, + # so that a search() can probably have O(1) hit. + if req_ds_cnf else "", # CDT ]).lower(), self.CredentialType.ID_TOKEN: lambda home_account_id=None, environment=None, client_id=None, @@ -267,10 +276,13 @@ def __add(self, event, now=None): "expires_on": str(now + expires_in), # Same here "extended_expires_on": str(now + ext_expires_in) # Same here } + if response.get("xms_ds_nonce"): # Available for CDT + at["xms_ds_nonce"] = response["xms_ds_nonce"] at.update({k: data[k] for k in data if k in { # Also store extra data which we explicitly allow # So that we won't accidentally store a user's password etc. "key_id", # It happens in SSH-cert or POP scenario + "req_ds_cnf", # Used in CDT }}) if "refresh_in" in response: refresh_in = response["refresh_in"] # It is an integer diff --git a/tests/test_application.py b/tests/test_application.py index e565e105..70806bbe 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -11,6 +11,7 @@ ClientApplication, PublicClientApplication, ConfidentialClientApplication, _str2bytes, _merge_claims_challenge_and_capabilities, ) +from msal.oauth2cli.oidc import decode_part from tests import unittest from tests.test_token_cache import build_id_token, build_response from tests.http_client import MinimalHttpClient, MinimalResponse @@ -856,3 +857,85 @@ def test_app_did_not_register_redirect_uri_should_error_out(self): ) self.assertEqual(result.get("error"), "broker_error") + +class CdtTestCase(unittest.TestCase): + + def createConstraint(self, typ: str, action: str, targets: list[str]) -> dict: + return {"ver": "1.0", "typ": typ, "a": action, "target": [ + {"val": t} for t in targets + ]} + + def test_constraint_format(self): + self.assertEqual([ + self.createConstraint("ns:usr", "create", ["guid1", "guid2"]), + self.createConstraint("ns:app", "update", ["guid3", "guid4"]), + self.createConstraint("ns:subscription", "read", ["guid5", "guid6"]), + ], [ # Format defined in https://microsoft-my.sharepoint-df.com/:w:/p/rohitshende/EZgP9niwOvhKn-CUbj1NgG4BTZ6FSD9_16vXvsaXTiUzkg?e=j5DcQu&nav=eyJoIjoiODU5NDAyNjI4In0 + {"ver": "1.0", "typ": "ns:usr", "a": "create", "target": [ + {"val": "guid1"}, {"val": "guid2"}, + ], + }, + {"ver": "1.0", "typ": "ns:app", "a": "update", "target": [ + {"val": "guid3"}, {"val": "guid4"}, + ], + }, + {"ver": "1.0", "typ": "ns:subscription", "a": "read", "target": [ + {"val": "guid5"}, {"val": "guid6"}, + ], + }, + ], "Constraint format is correct") # MSAL actually accepts arbitrary JSON blob + + def assertCdt(self, result: dict, constraints: list[dict]) -> None: + self.assertIsNotNone( + result.get("access_token"), "Encountered {}: {}".format( + result.get("error"), result.get("error_description"))) + _expectancy = "The return value should look like a Bearer response" + self.assertEqual(result["token_type"], "Bearer", _expectancy) + self.assertNotIn("xms_ds_nonce", result, _expectancy) + headers = json.loads(decode_part(result["access_token"].split(".")[0])) + self.assertEqual(headers.get("typ"), "cdt+jwt", "typ should be cdt+jwt") + payload = json.loads(decode_part(result["access_token"].split(".")[1])) + self.assertIsNotNone(payload.get("t") and payload.get("c")) + cdt_envelope = json.loads(decode_part(payload["c"].split(".")[1])) + self.assertIn("xms_ds_nonce", cdt_envelope) + self.assertEqual(cdt_envelope["constraints"], constraints) + + def assertAppObtainsCdt(self, client_app, scopes) -> None: + constraints1 = [self.createConstraint("ns:usr", "create", ["guid1"])] + result = client_app.acquire_token_for_client( + scopes, delegation_constraints=constraints1, + ) + self.assertCdt(result, constraints1) + + constraints2 = [self.createConstraint("ns:app", "update", ["guid2"])] + result = client_app.acquire_token_for_client( + scopes, delegation_constraints=constraints2, + ) + self.assertEqual(result["token_source"], "cache", "App token Should hit cache") + self.assertCdt(result, constraints2) + + result = client_app.acquire_token_for_client( + scopes, delegation_constraints=constraints2, + delegation_confirmation_key=client_app._get_rsa_key("new"), + ) + self.assertEqual( + result["token_source"], "identity_provider", + "Different key should result in a new app token") + self.assertCdt(result, constraints2) + + @patch("msal.authority.tenant_discovery", new=Mock(return_value={ + "authorization_endpoint": "https://contoso.com/placeholder", + "token_endpoint": "https://contoso.com/placeholder", + })) + def test_acquire_token_for_client_should_return_a_cdt(self): + app = msal.ConfidentialClientApplication("id", client_credential="secret") + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=200, text=json.dumps({ + "token_type": "Bearer", + "access_token": "app token", + "expires_in": 3600, + "xms_ds_nonce": "nonce", + }))) as mocked_post: + self.assertAppObtainsCdt(app, ["scope1", "scope2"]) + mocked_post.assert_called_once() + diff --git a/tests/test_crypto.py b/tests/test_crypto.py new file mode 100644 index 00000000..b3b589be --- /dev/null +++ b/tests/test_crypto.py @@ -0,0 +1,12 @@ +from unittest import TestCase + +from msal.crypto import _generate_rsa_key, _convert_rsa_keys + + +class CryptoTestCase(TestCase): + def test_key_generation(self): + key = _generate_rsa_key() + _, jwk = _convert_rsa_keys(key) + self.assertEqual(jwk.get("kty"), "RSA") + self.assertIsNotNone(jwk.get("n") and jwk.get("e")) + diff --git a/tests/test_e2e.py b/tests/test_e2e.py index a0796547..753559bf 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -27,8 +27,10 @@ import msal from tests.http_client import MinimalHttpClient, MinimalResponse +from tests.test_application import CdtTestCase from msal.oauth2cli import AuthCodeReceiver from msal.oauth2cli.oidc import decode_part +from msal.application import _build_req_cnf try: import pymsalruntime @@ -533,7 +535,7 @@ def tearDownClass(cls): cls.session.close() @classmethod - def get_lab_app_object(cls, client_id=None, **query): # https://msidlab.com/swagger/index.html + def get_lab_app_object(cls, client_id=None, **query) -> dict: # https://msidlab.com/swagger/index.html url = "https://msidlab.com/api/app/{}".format(client_id or "") resp = cls.session.get(url, params=query) result = resp.json()[0] @@ -791,12 +793,12 @@ def test_user_account(self): self._test_user_account() -def _data_for_pop(key): - raw_req_cnf = json.dumps({"kid": key, "xms_ksl": "sw"}) +def _data_for_pop(key_id): return { # Sampled from Azure CLI's plugin connectedk8s 'token_type': 'pop', - 'key_id': key, - "req_cnf": base64.urlsafe_b64encode(raw_req_cnf.encode('utf-8')).decode('utf-8').rstrip('='), + 'key_id': key_id, + "req_cnf": _build_req_cnf( + {"kid": key_id, "xms_ksl": "sw"}, remove_padding=True), # Note: Sending raw_req_cnf without base64 encoding would result in an http 500 error } # See also https://github.com/Azure/azure-cli-extensions/blob/main/src/connectedk8s/azext_connectedk8s/_clientproxyutils.py#L86-L92 @@ -817,6 +819,23 @@ def test_user_account(self): self._test_user_account() +class CdtTestCase(LabBasedTestCase, CdtTestCase): + def test_acquire_token_for_client_should_return_a_cdt(self): + resource = self.get_lab_app_object( # This resource has opted in to CDT + publicClient="no", signinAudience="AzureAdMyOrg") + client_app = msal.ConfidentialClientApplication( + # Any CCA can use a CDT, as long as the resource opted in for a CDT + # Here we use the OBO app which is in same tenant as the resource. + os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"), + client_credential=os.getenv("LAB_OBO_CLIENT_SECRET"), + authority="{}{}.onmicrosoft.com".format( + resource["authority"], + resource["labName"].lower().rstrip(".com"), + ), + ) + self.assertAppObtainsCdt(client_app, [f"{resource['appId']}/.default"]) + + class WorldWideTestCase(LabBasedTestCase): def test_aad_managed_user(self): # Pure cloud