Skip to content

Commit

Permalink
PoC: req_ds_cnf
Browse files Browse the repository at this point in the history
Refactor req_ds_cnf to delegation_scope_key

Move logic to allow token cache to work

wip
  • Loading branch information
rayluo committed Oct 18, 2024
1 parent 6d80cc5 commit b69dda8
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 10 deletions.
68 changes: 65 additions & 3 deletions msal/application.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import datetime
import functools
import json
import time
Expand Down Expand Up @@ -165,6 +167,17 @@ def _preferred_browser():
return None


def _build_req_cnf(jwk:dict, remove_padding:bool = False) -> str:
"""req_cnf usually requires base64url encoding.
https://datatracker.ietf.org/doc/html/draft-ietf-oauth-pop-key-distribution-07#section-4.2.1
https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e967ebeb-9e9f-443e-857a-5208802943c2
"""
raw = json.dumps(jwk)
encoded = base64.urlsafe_b64encode(raw.encode('utf-8')).decode('utf-8')
return encoded.rstrip('=') if remove_padding else encoded


class _ClientWithCcsRoutingInfo(Client):

def initiate_auth_code_flow(self, **kwargs):
Expand Down Expand Up @@ -231,13 +244,22 @@ class ClientApplication(object):
_TOKEN_SOURCE_IDP = "identity_provider"
_TOKEN_SOURCE_CACHE = "cache"
_TOKEN_SOURCE_BROKER = "broker"
_XMS_DS_NONCE = "xms_ds_nonce"

_enable_broker = False
_AUTH_SCHEME_UNSUPPORTED = (
"auth_scheme is currently only available from broker. "
"You can enable broker by following these instructions. "
"https://msal-python.readthedocs.io/en/latest/#publicclientapplication")

@functools.lru_cache(maxsize=2)
def __get_rsa_key(self, _bucket): # _bucket is used with lru_cache pattern
from .crypto import _generate_rsa_key
return _generate_rsa_key()

def _get_rsa_key(self, _bucket=None): # Return the same RSA key, cached for a day
return self.__get_rsa_key(_bucket or datetime.date.today())

def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
Expand Down Expand Up @@ -1552,6 +1574,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
"expires_in": int(expires_in), # OAuth2 specs defines it as int
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
}
if self._XMS_DS_NONCE in entry: # CDT needs this
access_token_from_cache[self._XMS_DS_NONCE] = entry[
self._XMS_DS_NONCE]
if "refresh_on" in entry:
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
if int(entry["refresh_on"]) < now: # aging
Expand Down Expand Up @@ -2340,7 +2365,16 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app
except that ``allow_broker`` parameter shall remain ``None``.
"""

def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
def acquire_token_for_client(
self,
scopes,
claims_challenge=None,
*,
delegation_constraints: Optional[list] = None,
delegation_confirmation_key=None, # A Cyprtography's RSAPrivateKey-like object
# TODO: Support ECC key? https://github.com/pyca/cryptography/issues/4093
**kwargs
):
"""Acquires token for the current confidential client, not for an end user.
Since MSAL Python 1.23, it will automatically look for token from cache,
Expand All @@ -2363,8 +2397,36 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
raise ValueError( # We choose to disallow force_refresh
"Historically, this method does not support force_refresh behavior. "
)
return _clean_up(self._acquire_token_silent_with_error(
scopes, None, claims_challenge=claims_challenge, **kwargs))
if delegation_constraints:
private_key = delegation_confirmation_key or self._get_rsa_key()
from .crypto import _convert_rsa_keys
_, jwk = _convert_rsa_keys(private_key)
result = _clean_up(self._acquire_token_silent_with_error(
scopes, None, claims_challenge=claims_challenge, data=dict(
kwargs.pop("data", {}),
req_ds_cnf=_build_req_cnf(jwk) # It is part of token cache key
if delegation_constraints else None,
),
**kwargs))
if delegation_constraints and not result.get("error"):
if not result.get(self._XMS_DS_NONCE): # Available in cached token, too
raise ValueError(
"The resource did not opt in to xms_ds_cnf claim. "
"After its opt-in, call this function again with "
"a new app object or a new delegation_confirmation_key"
# in order to invalidate the token in cache
)
import jwt # Lazy loading
cdt_envelope = jwt.encode({
"constraints": delegation_constraints,
self._XMS_DS_NONCE: result[self._XMS_DS_NONCE],
}, private_key, algorithm="PS256")
result["access_token"] = jwt.encode({
"t": result["access_token"],
"c": cdt_envelope,
}, None, algorithm=None, headers={"typ": "cdt+jwt"})
del result[self._XMS_DS_NONCE] # Caller shouldn't need to know that
return result

def _acquire_token_for_client(
self,
Expand Down
37 changes: 37 additions & 0 deletions msal/crypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from cryptography.hazmat.primitives.asymmetric import rsa


def _urlsafe_b64encode(n:int, bit_size:int) -> str:
from base64 import urlsafe_b64encode
return urlsafe_b64encode(n.to_bytes(
length=int(bit_size/8),
byteorder="big",
)).decode("utf-8").rstrip("=")


def _to_jwk(public_key: rsa.RSAPublicKey) -> dict:
"""Equivalent to:
numbers = public_key.public_numbers()
result = {
"kty": "RSA",
"n": _urlsafe_b64encode(numbers.n, public_key.key_size),
"e": _urlsafe_b64encode(numbers.e, 24),
}
return result
"""
import jwt
return jwt.get_algorithm_by_name( # PyJWT 2.5.0 https://github.com/jpadilla/pyjwt/releases/tag/2.5.0
"RS256"
).to_jwk(
public_key,
as_dict=True, # PyJWT 2.7.0 https://github.com/jpadilla/pyjwt/releases/tag/2.7.0
)

def _convert_rsa_keys(private_key: rsa.RSAPrivateKey):
return "pairs.private_bytes()", _to_jwk(private_key.public_key())

def _generate_rsa_key() -> rsa.RSAPrivateKey:
# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/rsa/#cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key
return rsa.generate_private_key(public_exponent=65537, key_size=2048)

14 changes: 13 additions & 1 deletion msal/token_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import hashlib
import json
import threading
import time
import logging
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(self):
realm=None, target=None,
# Note: New field(s) can be added here
#key_id=None,
req_ds_cnf=None,
**ignored_payload_from_a_real_token:
"-".join([ # Note: Could use a hash here to shorten key length
home_account_id or "",
Expand All @@ -70,6 +72,13 @@ def __init__(self):
realm or "",
target or "",
#key_id or "", # So ATs of different key_id can coexist
hashlib.sha256(req_ds_cnf.encode()).hexdigest()
# TODO: Could hash the entire key eventually.
# But before that project, we better first
# change the scope to use input scope
# instead of response scope,
# so that a search() can probably have O(1) hit.
if req_ds_cnf else "", # CDT
]).lower(),
self.CredentialType.ID_TOKEN:
lambda home_account_id=None, environment=None, client_id=None,
Expand Down Expand Up @@ -267,10 +276,13 @@ def __add(self, event, now=None):
"expires_on": str(now + expires_in), # Same here
"extended_expires_on": str(now + ext_expires_in) # Same here
}
if response.get("xms_ds_nonce"): # Available for CDT
at["xms_ds_nonce"] = response["xms_ds_nonce"]
at.update({k: data[k] for k in data if k in {
# Also store extra data which we explicitly allow
# So that we won't accidentally store a user's password etc.
"key_id", # It happens in SSH-cert or POP scenario
"req_ds_cnf", # Used in CDT
}})
if "refresh_in" in response:
refresh_in = response["refresh_in"] # It is an integer
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ install_requires =

# MSAL does not use jwt.decode(),
# therefore is insusceptible to CVE-2022-29217 so no need to bump to PyJWT 2.4+
PyJWT[crypto]>=1.0.0,<3
PyJWT[crypto]>=2.7.0,<3

# load_key_and_certificates() is available since 2.5
# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/serialization/#cryptography.hazmat.primitives.serialization.pkcs12.load_key_and_certificates
Expand Down
83 changes: 83 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ClientApplication, PublicClientApplication, ConfidentialClientApplication,
_str2bytes, _merge_claims_challenge_and_capabilities,
)
from msal.oauth2cli.oidc import decode_part
from tests import unittest
from tests.test_token_cache import build_id_token, build_response
from tests.http_client import MinimalHttpClient, MinimalResponse
Expand Down Expand Up @@ -856,3 +857,85 @@ def test_app_did_not_register_redirect_uri_should_error_out(self):
)
self.assertEqual(result.get("error"), "broker_error")


class CdtTestCase(unittest.TestCase):

def createConstraint(self, typ: str, action: str, targets: list[str]) -> dict:
return {"ver": "1.0", "typ": typ, "a": action, "target": [
{"val": t} for t in targets
]}

def test_constraint_format(self):
self.assertEqual([
self.createConstraint("ns:usr", "create", ["guid1", "guid2"]),
self.createConstraint("ns:app", "update", ["guid3", "guid4"]),
self.createConstraint("ns:subscription", "read", ["guid5", "guid6"]),
], [ # Format defined in https://microsoft-my.sharepoint-df.com/:w:/p/rohitshende/EZgP9niwOvhKn-CUbj1NgG4BTZ6FSD9_16vXvsaXTiUzkg?e=j5DcQu&nav=eyJoIjoiODU5NDAyNjI4In0
{"ver": "1.0", "typ": "ns:usr", "a": "create", "target": [
{"val": "guid1"}, {"val": "guid2"},
],
},
{"ver": "1.0", "typ": "ns:app", "a": "update", "target": [
{"val": "guid3"}, {"val": "guid4"},
],
},
{"ver": "1.0", "typ": "ns:subscription", "a": "read", "target": [
{"val": "guid5"}, {"val": "guid6"},
],
},
], "Constraint format is correct") # MSAL actually accepts arbitrary JSON blob

def assertCdt(self, result: dict, constraints: list[dict]) -> None:
self.assertIsNotNone(
result.get("access_token"), "Encountered {}: {}".format(
result.get("error"), result.get("error_description")))
_expectancy = "The return value should look like a Bearer response"
self.assertEqual(result["token_type"], "Bearer", _expectancy)
self.assertNotIn("xms_ds_nonce", result, _expectancy)
headers = json.loads(decode_part(result["access_token"].split(".")[0]))
self.assertEqual(headers.get("typ"), "cdt+jwt", "typ should be cdt+jwt")
payload = json.loads(decode_part(result["access_token"].split(".")[1]))
self.assertIsNotNone(payload.get("t") and payload.get("c"))
cdt_envelope = json.loads(decode_part(payload["c"].split(".")[1]))
self.assertIn("xms_ds_nonce", cdt_envelope)
self.assertEqual(cdt_envelope["constraints"], constraints)

def assertAppObtainsCdt(self, client_app, scopes) -> None:
constraints1 = [self.createConstraint("ns:usr", "create", ["guid1"])]
result = client_app.acquire_token_for_client(
scopes, delegation_constraints=constraints1,
)
self.assertCdt(result, constraints1)

constraints2 = [self.createConstraint("ns:app", "update", ["guid2"])]
result = client_app.acquire_token_for_client(
scopes, delegation_constraints=constraints2,
)
self.assertEqual(result["token_source"], "cache", "App token Should hit cache")
self.assertCdt(result, constraints2)

result = client_app.acquire_token_for_client(
scopes, delegation_constraints=constraints2,
delegation_confirmation_key=client_app._get_rsa_key("new"),
)
self.assertEqual(
result["token_source"], "identity_provider",
"Different key should result in a new app token")
self.assertCdt(result, constraints2)

@patch("msal.authority.tenant_discovery", new=Mock(return_value={
"authorization_endpoint": "https://contoso.com/placeholder",
"token_endpoint": "https://contoso.com/placeholder",
}))
def test_acquire_token_for_client_should_return_a_cdt(self):
app = msal.ConfidentialClientApplication("id", client_credential="secret")
with patch.object(app.http_client, "post", return_value=MinimalResponse(
status_code=200, text=json.dumps({
"token_type": "Bearer",
"access_token": "app token",
"expires_in": 3600,
"xms_ds_nonce": "nonce",
}))) as mocked_post:
self.assertAppObtainsCdt(app, ["scope1", "scope2"])
mocked_post.assert_called_once()

12 changes: 12 additions & 0 deletions tests/test_crypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from unittest import TestCase

from msal.crypto import _generate_rsa_key, _convert_rsa_keys


class CryptoTestCase(TestCase):
def test_key_generation(self):
key = _generate_rsa_key()
_, jwk = _convert_rsa_keys(key)
self.assertEqual(jwk.get("kty"), "RSA")
self.assertIsNotNone(jwk.get("n") and jwk.get("e"))

29 changes: 24 additions & 5 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@

import msal
from tests.http_client import MinimalHttpClient, MinimalResponse
from tests.test_application import CdtTestCase
from msal.oauth2cli import AuthCodeReceiver
from msal.oauth2cli.oidc import decode_part
from msal.application import _build_req_cnf

try:
import pymsalruntime
Expand Down Expand Up @@ -533,7 +535,7 @@ def tearDownClass(cls):
cls.session.close()

@classmethod
def get_lab_app_object(cls, client_id=None, **query): # https://msidlab.com/swagger/index.html
def get_lab_app_object(cls, client_id=None, **query) -> dict: # https://msidlab.com/swagger/index.html
url = "https://msidlab.com/api/app/{}".format(client_id or "")
resp = cls.session.get(url, params=query)
result = resp.json()[0]
Expand Down Expand Up @@ -791,12 +793,12 @@ def test_user_account(self):
self._test_user_account()


def _data_for_pop(key):
raw_req_cnf = json.dumps({"kid": key, "xms_ksl": "sw"})
def _data_for_pop(key_id):
return { # Sampled from Azure CLI's plugin connectedk8s
'token_type': 'pop',
'key_id': key,
"req_cnf": base64.urlsafe_b64encode(raw_req_cnf.encode('utf-8')).decode('utf-8').rstrip('='),
'key_id': key_id,
"req_cnf": _build_req_cnf(
{"kid": key_id, "xms_ksl": "sw"}, remove_padding=True),
# Note: Sending raw_req_cnf without base64 encoding would result in an http 500 error
} # See also https://github.com/Azure/azure-cli-extensions/blob/main/src/connectedk8s/azext_connectedk8s/_clientproxyutils.py#L86-L92

Expand All @@ -817,6 +819,23 @@ def test_user_account(self):
self._test_user_account()


class CdtTestCase(LabBasedTestCase, CdtTestCase):
def test_acquire_token_for_client_should_return_a_cdt(self):
resource = self.get_lab_app_object( # This resource has opted in to CDT
publicClient="no", signinAudience="AzureAdMyOrg")
client_app = msal.ConfidentialClientApplication(
# Any CCA can use a CDT, as long as the resource opted in for a CDT
# Here we use the OBO app which is in same tenant as the resource.
os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"),
client_credential=os.getenv("LAB_OBO_CLIENT_SECRET"),
authority="{}{}.onmicrosoft.com".format(
resource["authority"],
resource["labName"].lower().rstrip(".com"),
),
)
self.assertAppObtainsCdt(client_app, [f"{resource['appId']}/.default"])


class WorldWideTestCase(LabBasedTestCase):

def test_aad_managed_user(self): # Pure cloud
Expand Down

0 comments on commit b69dda8

Please sign in to comment.