diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 37c0de3db1..b42471e102 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -1,28 +1,28 @@ -from base64 import b64encode import math +from base64 import b64encode from typing import ( - List, + TYPE_CHECKING, + Any, Dict, Final, + Iterable, + List, Literal, Optional, Union, - Any, cast, - Iterable, - TYPE_CHECKING, ) -from requests.auth import AuthBase + from requests import PreparedRequest # noqa: I251 +from requests.auth import AuthBase from dlt.common import logger -from dlt.common.exceptions import MissingDependencyException -from dlt.common.configuration.specs.base_configuration import configspec from dlt.common.configuration.specs import CredentialsConfiguration +from dlt.common.configuration.specs.base_configuration import configspec from dlt.common.configuration.specs.exceptions import NativeValueError +from dlt.common.exceptions import MissingDependencyException from dlt.common.pendulum import pendulum from dlt.common.typing import TSecretStrValue - from dlt.sources.helpers import requests if TYPE_CHECKING: @@ -138,6 +138,59 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: return request +@configspec +class OAuth2ImplicitFlow(OAuth2AuthBase): + """ + This class implements OAuth2 implicit flow which does not require the end user to + give permission to the app to access their information. This is often used for + server-to-server authentication. + The client obtains a temporary access token from the authorization service. + With the access token, the client can access resource services + """ + + def __init__( + self, + access_token_url: str, + access_token_request_data, + client_id: TSecretStrValue, + client_secret: TSecretStrValue, + access_token: TSecretStrValue = None, + default_token_expiration: int = 3600, + ) -> None: + super().__init__(access_token) + self.access_token_request_data = access_token_request_data + self.access_token_url = access_token_url + self.client_id = client_id + self.client_secret = client_secret + self.default_token_expiration = default_token_expiration + self.token_expiry: pendulum.DateTime = pendulum.now() + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + if self.access_token is None or self.is_token_expired(): + self.obtain_token() + request.headers["Authorization"] = f"Bearer {self.access_token}" + return request + + def is_token_expired(self) -> bool: + return pendulum.now() >= self.token_expiry + + def obtain_token(self) -> None: + authentication: str = b64encode(f"{self.client_id}:{self.client_secret}".encode()).decode() + + response = requests.post( + url=self.access_token_url, + headers={ + "Authorization": f"Basic {authentication}", + "Content-Type": "application/x-www-form-urlencoded", + }, + data=self.access_token_request_data, + ) + response.raise_for_status() + self.access_token = response.json()["access_token"] + expires_in = response.json().get("expires_in", self.default_token_expiration) + self.token_expiry = pendulum.now().add(seconds=expires_in) + + @configspec class OAuthJWTAuth(BearerTokenAuth): """This is a form of Bearer auth, actually there's not standard way to declare it in openAPI"""