Skip to content

Commit

Permalink
Proof-of-Concept: MI via CCA
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Jun 13, 2023
1 parent 3dd50a2 commit 36f5bb8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
9 changes: 9 additions & 0 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .region import _detect_region
from .throttled_http_client import ThrottledHttpClient
from .cloudshell import _is_running_in_cloud_shell
from .imds import ManagedIdentityClient, ManagedIdentity, _scope_to_resource


# The __init__.py will import this. Not the other way around.
Expand Down Expand Up @@ -2021,6 +2022,14 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
- an error response would contain "error" and usually "error_description".
"""
# TBD: force_refresh behavior
if ManagedIdentity.is_managed_identity(self.client_id):
if len(scopes) != 1:
raise ValueError("Managed Identity supports only one scope/resource")
if claims_challenge:
raise ValueError("Managed Identity does not support claims_challenge")
return ManagedIdentityClient(
self.http_client, self.client_id, self.token_cache
).acquire_token(_scope_to_resource(scopes[0]))
if self.authority.tenant.lower() in ["common", "organizations"]:
warnings.warn(
"Using /common or /organizations authority "
Expand Down
29 changes: 13 additions & 16 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

from tests.http_client import MinimalResponse
from msal import (
ConfidentialClientApplication,
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
ManagedIdentityClient)
)


class ManagedIdentityTestCase(unittest.TestCase):
Expand All @@ -39,26 +40,22 @@ class ClientTestCase(unittest.TestCase):
maxDiff = None

def setUp(self):
self.app = ManagedIdentityClient(
{ # Here we test it with the raw dict form, to test that
# the client has no hard dependency on ManagedIdentity object
"ManagedIdentityIdType": "SystemAssigned", "Id": None,
},
requests.Session(),
)
system_assigned = {"ManagedIdentityIdType": "SystemAssigned", "Id": None}
self.app = ConfidentialClientApplication(client_id=system_assigned)

def _test_token_cache(self, app):
cache = app._token_cache._cache
cache = app.token_cache._cache
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
at = list(cache["AccessToken"].values())[0]
self.assertEqual(
app._managed_identity.get("Id", "SYSTEM_ASSIGNED_MANAGED_IDENTITY"),
app.client_id.get("Id", "SYSTEM_ASSIGNED_MANAGED_IDENTITY"),
at["client_id"],
"Should have expected client_id")
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")

def _test_happy_path(self, app, mocked_http):
result = app.acquire_token_for_client(resource="R")
#result = app.acquire_token_for_client(resource="R")
result = app.acquire_token_for_client(["R"])
mocked_http.assert_called()
self.assertEqual({
"access_token": "AT",
Expand All @@ -68,29 +65,29 @@ def _test_happy_path(self, app, mocked_http):
}, result, "Should obtain a token response")
self.assertEqual(
result["access_token"],
app.acquire_token_for_client(resource="R").get("access_token"),
app.acquire_token_for_client(["R"]).get("access_token"),
"Should hit the same token from cache")
self._test_token_cache(app)


class VmTestCase(ClientTestCase):

def test_happy_path(self):
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
with patch.object(self.app.http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)

def test_vm_error_should_be_returned_as_is(self):
raw_error = '{"raw": "error format is undefined"}'
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
with patch.object(self.app.http_client, "get", return_value=MinimalResponse(
status_code=400,
text=raw_error,
)) as mocked_method:
self.assertEqual(
json.loads(raw_error), self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)
json.loads(raw_error), self.app.acquire_token_for_client(["R"]))
self.assertEqual({}, self.app.token_cache._cache)


@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"})
Expand Down

0 comments on commit 36f5bb8

Please sign in to comment.