Skip to content

Commit

Permalink
Merge pull request #577 from AzureAD/silent-adjustment
Browse files Browse the repository at this point in the history
Remove acquire_token_silent(..., account=None) usage in a backward-compatible way
  • Loading branch information
rayluo authored Jul 22, 2023
2 parents e1e3d1c + 2288b77 commit 1b316e3
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 62 deletions.
90 changes: 63 additions & 27 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,32 +1209,24 @@ def acquire_token_silent(
**kwargs):
"""Acquire an access token for given account, without user interaction.
It is done either by finding a valid access token from cache,
or by finding a valid refresh token from cache and then automatically
use it to redeem a new access token.
It has same parameters as the :func:`~acquire_token_silent_with_error`.
The difference is the behavior of the return value.
This method will combine the cache empty and refresh error
into one return value, `None`.
If your app does not care about the exact token refresh error during
token cache look-up, then this method is easier and recommended.
Internally, this method calls :func:`~acquire_token_silent_with_error`.
:param claims_challenge:
The claims_challenge parameter requests specific claims requested by the resource provider
in the form of a claims_challenge directive in the www-authenticate header to be
returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token.
It is a string of a JSON object which contains lists of claims being requested from these locations.
:return:
- A dict containing no "error" key,
and typically contains an "access_token" key,
if cache lookup succeeded.
- None when cache lookup does not yield a token.
"""
result = self.acquire_token_silent_with_error(
if not account:
return None # A backward-compatible NO-OP to drop the account=None usage
result = _clean_up(self._acquire_token_silent_with_error(
scopes, account, authority=authority, force_refresh=force_refresh,
claims_challenge=claims_challenge, **kwargs)
claims_challenge=claims_challenge, **kwargs))
return result if result and "error" not in result else None

def acquire_token_silent_with_error(
Expand All @@ -1258,9 +1250,10 @@ def acquire_token_silent_with_error(
:param list[str] scopes: (Required)
Scopes requested to access a protected API (a resource).
:param account:
one of the account object returned by :func:`~get_accounts`,
or use None when you want to find an access token for this client.
:param account: (Required)
One of the account object returned by :func:`~get_accounts`.
Starting from MSAL Python 1.23,
a ``None`` input will become a NO-OP and always return ``None``.
:param force_refresh:
If True, it will skip Access Token look-up,
and try to find a Refresh Token to obtain a new Access Token.
Expand All @@ -1276,6 +1269,20 @@ def acquire_token_silent_with_error(
- None when there is simply no token in the cache.
- A dict containing an "error" key, when token refresh failed.
"""
if not account:
return None # A backward-compatible NO-OP to drop the account=None usage
return _clean_up(self._acquire_token_silent_with_error(
scopes, account, authority=authority, force_refresh=force_refresh,
claims_challenge=claims_challenge, **kwargs))

def _acquire_token_silent_with_error(
self,
scopes, # type: List[str]
account, # type: Optional[Account]
authority=None, # See get_authorization_request_url()
force_refresh=False, # type: Optional[boolean]
claims_challenge=None,
**kwargs):
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
correlation_id = msal.telemetry._get_new_correlation_id()
Expand Down Expand Up @@ -1335,7 +1342,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
force_refresh=False, # type: Optional[boolean]
claims_challenge=None,
correlation_id=None,
http_exceptions=None,
**kwargs):
# This internal method has two calling patterns:
# it accepts a non-empty account to find token for a user,
# and accepts account=None to find a token for the current app.
access_token_from_cache = None
if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims
query={
Expand Down Expand Up @@ -1372,6 +1383,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
else:
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
assert refresh_reason, "It should have been established at this point"
if not http_exceptions: # It can be a tuple of exceptions
# The exact HTTP exceptions are transportation-layer dependent
from requests.exceptions import RequestException # Lazy load
http_exceptions = (RequestException,)
try:
data = kwargs.get("data", {})
if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL:
Expand All @@ -1391,14 +1406,19 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
if response: # The broker provided a decisive outcome, so we use it
return self._process_broker_response(response, scopes, data)

result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, self._decorate_scope(scopes), account,
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
correlation_id=correlation_id,
**kwargs))
if account:
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, self._decorate_scope(scopes), account,
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
correlation_id=correlation_id,
**kwargs)
else: # The caller is acquire_token_for_client()
result = self._acquire_token_for_client(
scopes, refresh_reason, claims_challenge=claims_challenge,
**kwargs)
if (result and "error" not in result) or (not access_token_from_cache):
return result
except: # The exact HTTP exception is transportation-layer dependent
except http_exceptions:
# Typically network error. Potential AAD outage?
if not access_token_from_cache: # It means there is no fall back option
raise # We choose to bubble up the exception
Expand Down Expand Up @@ -2007,6 +2027,9 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app
def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
"""Acquires token for the current confidential client, not for an end user.
Since MSAL Python 1.23, it will automatically look for token from cache,
and only send request to Identity Provider when cache misses.
:param list[str] scopes: (Required)
Scopes requested to access a protected API (a resource).
:param claims_challenge:
Expand All @@ -2020,24 +2043,37 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
- A successful response would contain "access_token" key,
- an error response would contain "error" and usually "error_description".
"""
# TBD: force_refresh behavior
if kwargs.get("force_refresh"):
raise ValueError( # We choose to disallow force_refresh
"Historically, this method does not support force_refresh behavior. "
)
return _clean_up(self._acquire_token_silent_with_error(
scopes, None, claims_challenge=claims_challenge, **kwargs))

def _acquire_token_for_client(
self,
scopes,
refresh_reason,
claims_challenge=None,
**kwargs
):
if self.authority.tenant.lower() in ["common", "organizations"]:
warnings.warn(
"Using /common or /organizations authority "
"in acquire_token_for_client() is unreliable. "
"Please use a specific tenant instead.", DeprecationWarning)
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID)
self.ACQUIRE_TOKEN_FOR_CLIENT_ID, refresh_reason=refresh_reason)
client = self._regional_client or self.client
response = _clean_up(client.obtain_token_for_client(
response = client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers=telemetry_context.generate_headers(),
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
**kwargs))
**kwargs)
telemetry_context.update_telemetry(response)
return response

Expand Down
14 changes: 3 additions & 11 deletions sample/confidential_client_certificate_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,9 @@
# https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache
)

# The pattern to acquire a token looks like this.
result = None

# Firstly, looks up a token from cache
# Since we are looking for token for the current app, NOT for an end user,
# notice we give account parameter as None.
result = app.acquire_token_silent(config["scope"], account=None)

if not result:
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
result = app.acquire_token_for_client(scopes=config["scope"])
# Since MSAL 1.23, acquire_token_for_client(...) will automatically look up
# a token from cache, and fall back to acquire a fresh token when needed.
result = app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
# Calling graph using the access token
Expand Down
14 changes: 3 additions & 11 deletions sample/confidential_client_secret_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,9 @@
# https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache
)

# The pattern to acquire a token looks like this.
result = None

# Firstly, looks up a token from cache
# Since we are looking for token for the current app, NOT for an end user,
# notice we give account parameter as None.
result = app.acquire_token_silent(config["scope"], account=None)

if not result:
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
result = app.acquire_token_for_client(scopes=config["scope"])
# Since MSAL 1.23, acquire_token_for_client(...) will automatically look up
# a token from cache, and fall back to acquire a fresh token when needed.
result = app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
# Calling graph using the access token
Expand Down
31 changes: 25 additions & 6 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ def test_aging_token_and_unavailable_aad_should_return_old_token(self):
old_at = "old AT"
self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1)
def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|84,2|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=400, text=json.dumps({"error": error}))
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(old_at, result.get("access_token"))

Expand Down Expand Up @@ -549,12 +549,31 @@ def setUpClass(cls): # Initialization at runtime, not interpret-time
authority="https://login.microsoftonline.com/common")

def test_acquire_token_for_client(self):
at = "this is an access token"
def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|730,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
self.assertEqual("4|730,2|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": "AT 1",
"expires_in": 0,
}))
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
self.assertEqual(at, result.get("access_token"))
self.assertEqual("AT 1", result.get("access_token"), "Shall get a new token")

def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|730,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": "AT 2",
"expires_in": 3600,
"refresh_in": -100, # A hack to make sure it will attempt refresh
}))
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
self.assertEqual("AT 2", result.get("access_token"), "Shall get a new token")

def mock_post(url, headers=None, *args, **kwargs):
# 1/0 # TODO: Make sure this was called
self.assertEqual("4|730,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
self.assertEqual("AT 2", result.get("access_token"), "Shall get aging token")

def test_acquire_token_on_behalf_of(self):
at = "this is an access token"
Expand Down
12 changes: 5 additions & 7 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,15 @@ def assertCacheWorksForApp(self, result_from_wire, scope):
json.dumps(self.app.token_cache._cache, indent=4),
json.dumps(result_from_wire.get("id_token_claims"), indent=4),
)
# Going to test acquire_token_silent(...) to locate an AT from cache
result_from_cache = self.app.acquire_token_silent(scope, account=None)
self.assertIsNone(
self.app.acquire_token_silent(scope, account=None),
"acquire_token_silent(..., account=None) shall always return None")
# Going to test acquire_token_for_client(...) to locate an AT from cache
result_from_cache = self.app.acquire_token_for_client(scope)
self.assertIsNotNone(result_from_cache)
self.assertEqual(
result_from_wire['access_token'], result_from_cache['access_token'],
"We should get a cached AT")
self.app.acquire_token_silent(
# Result will typically be None, because client credential grant returns no RT.
# But we care more on this call should succeed without exception.
scope, account=None,
force_refresh=True) # Mimic the AT already expires

@classmethod
def _build_app(cls,
Expand Down

0 comments on commit 1b316e3

Please sign in to comment.