diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 29e6d8c77a..d2ca1c1ca6 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -1,17 +1,18 @@ -from base64 import b64encode -import dataclasses import math +import dataclasses +from abc import abstractmethod +from base64 import b64encode from typing import ( - List, + TYPE_CHECKING, + Any, Dict, Final, + Iterable, + List, Literal, Optional, Union, - Any, cast, - Iterable, - TYPE_CHECKING, ) from typing_extensions import Annotated from requests.auth import AuthBase @@ -24,7 +25,6 @@ from dlt.common.configuration.specs.exceptions import NativeValueError from dlt.common.pendulum import pendulum from dlt.common.typing import TSecretStrValue - from dlt.sources.helpers import requests if TYPE_CHECKING: @@ -144,6 +144,76 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: return request +@configspec +class OAuth2ClientCredentials(OAuth2AuthBase): + """ + This class implements OAuth2 Client Credentials flow where the authorization service + gives permission without the end user approving. + This is often used for machine-to-machine authorization. + The client sends its client ID and client secret to the authorization service which replies + with a temporary access token. + With the access token, the client can access resource services.
+ """ + + def __init__( + self, + access_token_url: TSecretStrValue, + client_id: TSecretStrValue, + client_secret: TSecretStrValue, + access_token_request_data: Dict[str, Any] = None, + default_token_expiration: int = 3600, + session: Annotated[BaseSession, NotResolved()] = None, + ) -> None: + super().__init__() + self.access_token_url = access_token_url + self.client_id = client_id + self.client_secret = client_secret + if access_token_request_data is None: + self.access_token_request_data = {} + else: + self.access_token_request_data = access_token_request_data + self.default_token_expiration = default_token_expiration + self.token_expiry: pendulum.DateTime = pendulum.now() + + self.session = session if session is not None else requests.client.session + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + if self.access_token is None or self.is_token_expired(): + self.obtain_token() + request.headers["Authorization"] = f"Bearer {self.access_token}" + return request + + def is_token_expired(self) -> bool: + return pendulum.now() >= self.token_expiry + + def obtain_token(self) -> None: + response = self.session.post(self.access_token_url, **self.build_access_token_request()) + response.raise_for_status() + response_json = response.json() + self.parse_native_representation(self.parse_access_token(response_json)) + expires_in_seconds = self.parse_expiration_in_seconds(response_json) + self.token_expiry = pendulum.now().add(seconds=expires_in_seconds) + + def build_access_token_request(self) -> Dict[str, Any]: + return { + "headers": { + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": { + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "client_credentials", + **self.access_token_request_data, + }, + } + + def parse_expiration_in_seconds(self, response_json: Any) -> int: + return int(response_json.get("expires_in", self.default_token_expiration)) + + def parse_access_token(self, response_json:
Any) -> str: + return str(response_json.get("access_token")) + + @configspec class OAuthJWTAuth(BearerTokenAuth): """This is a form of Bearer auth, actually there's not standard way to declare it in openAPI""" @@ -164,7 +234,7 @@ def __post_init__(self) -> None: self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) self.token = None self.token_expiry: Optional[pendulum.DateTime] = None - # use default system session is not specified + # use default system session unless specified otherwise if self.session is None: self.session = requests.client.session diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md index 98725627b9..11d09c89f7 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md @@ -416,6 +416,7 @@ Available authentication types: | [BearTokenAuth](../../general-usage/http/rest-client.md#bearer-token-authentication) | `bearer` | Bearer token authentication. | | [HTTPBasicAuth](../../general-usage/http/rest-client.md#http-basic-authentication) | `http_basic` | Basic HTTP authentication. | | [APIKeyAuth](../../general-usage/http/rest-client.md#api-key-authentication) | `api_key` | API key authentication with key defined in the query parameters or in the headers. | +| [OAuth2ClientCredentials](../../general-usage/http/rest-client.md#oauth20-authorization) | N/A | OAuth 2.0 authorization with a temporary access token obtained from the authorization server.
| To specify the authentication configuration, use the `auth` field in the [client](#client) configuration: diff --git a/docs/website/docs/general-usage/http/rest-client.md b/docs/website/docs/general-usage/http/rest-client.md index 1093428b0f..3a7276a534 100644 --- a/docs/website/docs/general-usage/http/rest-client.md +++ b/docs/website/docs/general-usage/http/rest-client.md @@ -406,8 +406,11 @@ The available authentication methods are defined in the `dlt.sources.helpers.res - [BearerTokenAuth](#bearer-token-authentication) - [APIKeyAuth](#api-key-authentication) - [HttpBasicAuth](#http-basic-authentication) +- [OAuth2ClientCredentials](#oauth-20-authorization) For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library. +For specific flavors of OAuth 2.0, you can [implement custom OAuth 2.0](#oauth-20-authorization) +by subclassing `OAuth2ClientCredentials`. ### Bearer token authentication @@ -477,6 +480,57 @@ client = RESTClient(base_url="https://api.example.com", auth=auth) response = client.get("/protected/resource") ``` +### OAuth 2.0 authorization + +OAuth 2.0 is a common protocol for authorization. We have implemented two-legged authorization, which is used for server-to-server authorization because the end user (resource owner) does not need to grant approval. +The REST client acts as the OAuth client, which obtains a temporary access token from the authorization server. This access token is then sent to the resource server to access protected content. If the access token has expired, the OAuth client automatically obtains a new one. + +Unfortunately, most OAuth 2.0 implementations vary, and thus you might need to subclass `OAuth2ClientCredentials` and override `build_access_token_request()` to suit the requirements of the specific authorization server you want to interact with. + +**Parameters:** +- `access_token_url`: The URL to obtain the temporary access token.
+- `client_id`: Client credential to obtain authorization. Usually issued via a developer portal. +- `client_secret`: Client credential to obtain authorization. Usually issued via a developer portal. +- `access_token_request_data`: A dictionary with data required by the authorization server apart from the `client_id`, `client_secret`, and `"grant_type": "client_credentials"`. Defaults to `None`. +- `default_token_expiration`: The token lifetime in seconds that is assumed when the authorization server's response does not include `expires_in`. Defaults to 3600. + +**Example:** + +```py +from base64 import b64encode +from dlt.sources.helpers.rest_client import RESTClient +from dlt.sources.helpers.rest_client.auth import OAuth2ClientCredentials + +class OAuth2ClientCredentialsHTTPBasic(OAuth2ClientCredentials): + """Used e.g. by Zoom Video Communications, Inc.""" + def build_access_token_request(self) -> Dict[str, Any]: + authentication: str = b64encode( + f"{self.client_id}:{self.client_secret}".encode() + ).decode() + return { + "headers": { + "Authorization": f"Basic {authentication}", + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": self.access_token_request_data, + } + +auth = OAuth2ClientCredentialsHTTPBasic( + access_token_url=dlt.secrets["sources.zoom.access_token_url"], # "https://zoom.us/oauth/token" + client_id=dlt.secrets["sources.zoom.client_id"], + client_secret=dlt.secrets["sources.zoom.client_secret"], + access_token_request_data={ + "grant_type": "account_credentials", + "account_id": dlt.secrets["sources.zoom.account_id"], + }, +) +client = RESTClient(base_url="https://api.zoom.us/v2", auth=auth) + +response = client.get("/users") +``` + + + ### Implementing custom authentication You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method: diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py index 7453c63d14..08233bc3a8 100644 ---
a/tests/sources/helpers/rest_client/conftest.py +++ b/tests/sources/helpers/rest_client/conftest.py @@ -1,8 +1,8 @@ import re -from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING, Dict, List, Any +from typing import NamedTuple, Callable, Pattern, Union, TYPE_CHECKING, Dict, List, Any import base64 -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import parse_qs, urlsplit, urlunsplit import pytest import requests_mock @@ -207,7 +207,17 @@ def protected_api_key(request, context): @router.post("/oauth/token") def oauth_token(request, context): - return {"access_token": "test-token", "expires_in": 3600} + if oauth_authorize(request): + return {"access_token": "test-token", "expires_in": 3600} + context.status_code = 401 + return {"error": "Unauthorized"} + + @router.post("/oauth/token-expires-now") + def oauth_token_expires_now(request, context): + if oauth_authorize(request): + return {"access_token": "test-token", "expires_in": 0} + context.status_code = 401 + return {"error": "Unauthorized"} @router.post("/auth/refresh") def refresh_token(request, context): @@ -217,11 +227,36 @@ def refresh_token(request, context): context.status_code = 401 return {"error": "Invalid refresh token"} + @router.post("/custom-oauth/token") + def custom_oauth_token(request, context): + qs = parse_qs(request.text) + if ( + qs.get("grant_type")[0] == "account_credentials" + and qs.get("account_id")[0] == "test-account-id" + and request.headers["Authorization"] + == "Basic dGVzdC1hY2NvdW50LWlkOnRlc3QtY2xpZW50LXNlY3JldA==" + ): + return {"access_token": "test-token", "expires_in": 3600} + context.status_code = 401 + return {"error": "Unauthorized"} + router.register_routes(m) yield m +def oauth_authorize(request): + qs = parse_qs(request.text) + grant_type = qs.get("grant_type")[0] + if "jwt-bearer" in grant_type: + return True + if "client_credentials" in grant_type: + return ( + qs["client_secret"][0] == "test-client-secret" + and 
qs["client_id"][0] == "test-client-id" + ) + + def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10): assert len(pages) == total_pages for i, page in enumerate(pages): diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index bd65affe62..7196ef3436 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -1,23 +1,28 @@ import os +from base64 import b64encode +from typing import Any, Dict, cast +from unittest.mock import patch + import pytest -from typing import Any, cast -from dlt.common import logger from requests import PreparedRequest, Request, Response from requests.auth import AuthBase +from requests.exceptions import HTTPError + +from dlt.common import logger from dlt.common.typing import TSecretStrValue from dlt.sources.helpers.requests import Client from dlt.sources.helpers.rest_client import RESTClient -from dlt.sources.helpers.rest_client.client import Hooks -from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator - -from dlt.sources.helpers.rest_client.auth import AuthConfigBase from dlt.sources.helpers.rest_client.auth import ( - BearerTokenAuth, APIKeyAuth, + AuthConfigBase, + BearerTokenAuth, HttpBasicAuth, + OAuth2ClientCredentials, OAuthJWTAuth, ) +from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException +from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator from .conftest import assert_pagination @@ -31,13 +36,40 @@ def load_private_key(name="private_key.pem"): TEST_PRIVATE_KEY = load_private_key() -@pytest.fixture -def rest_client() -> RESTClient: +def build_rest_client(auth=None) -> RESTClient: return RESTClient( base_url="https://api.example.com", headers={"Accept": "application/json"}, session=Client().session, + auth=auth, + ) + + +@pytest.fixture +def rest_client() -> 
RESTClient: + return build_rest_client() + + +@pytest.fixture +def rest_client_oauth() -> RESTClient: + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, ) + return build_rest_client(auth=auth) + + +@pytest.fixture +def rest_client_immediate_oauth_expiry(auth=None) -> RESTClient: + credentials_expiring_now = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token-expires-now"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, + ) + return build_rest_client(auth=credentials_expiring_now) @pytest.mark.usefixtures("mock_api_server") @@ -163,6 +195,114 @@ def test_api_key_auth_success(self, rest_client: RESTClient): assert response.status_code == 200 assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + def test_oauth2_client_credentials_flow_auth_success(self, rest_client_oauth: RESTClient): + response = rest_client_oauth.get("/protected/posts/bearer-token") + + assert response.status_code == 200 + assert "test-token" in response.request.headers["Authorization"] + + pages_iter = rest_client_oauth.paginate("/protected/posts/bearer-token") + + assert_pagination(list(pages_iter)) + + def test_oauth2_client_credentials_flow_wrong_client_id(self, rest_client: RESTClient): + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "invalid-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, + ) + + with pytest.raises(HTTPError) as e: + rest_client.get("/protected/posts/bearer-token", auth=auth) + assert e.type == HTTPError + assert e.match("401 Client Error") + + def 
test_oauth2_client_credentials_flow_wrong_client_secret(self, rest_client: RESTClient): + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "invalid-client-secret"), + session=Client().session, + ) + + with pytest.raises(HTTPError) as e: + rest_client.get( + "/protected/posts/bearer-token", + auth=auth, + ) + assert e.type == HTTPError + assert e.match("401 Client Error") + + + def test_oauth_token_expired_refresh(self, rest_client_immediate_oauth_expiry: RESTClient): + rest_client = rest_client_immediate_oauth_expiry + auth = cast(OAuth2ClientCredentials, rest_client.auth) + + with patch.object(auth, "obtain_token", wraps=auth.obtain_token) as mock_obtain_token: + assert auth.access_token is None + response = rest_client.get("/protected/posts/bearer-token") + mock_obtain_token.assert_called_once() + assert response.status_code == 200 + assert auth.access_token is not None + expiry_0 = auth.token_expiry + auth.token_expiry = auth.token_expiry.subtract(seconds=1) + expiry_1 = auth.token_expiry + assert expiry_0 > expiry_1 + assert auth.is_token_expired() + + response = rest_client.get("/protected/posts/bearer-token") + assert mock_obtain_token.call_count == 2 + assert response.status_code == 200 + expiry_2 = auth.token_expiry + assert expiry_2 > expiry_1 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + def test_oauth_customized_token_request(self, rest_client: RESTClient): + class OAuth2ClientCredentialsHTTPBasic(OAuth2ClientCredentials): + """OAuth 2.0 as required by e.g. 
Zoom Video Communications, Inc.""" + + def build_access_token_request(self) -> Dict[str, Any]: + authentication: str = b64encode( + f"{self.client_id}:{self.client_secret}".encode() + ).decode() + return { + "headers": { + "Authorization": f"Basic {authentication}", + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": { + "grant_type": "account_credentials", + **self.access_token_request_data, + }, + } + + auth = OAuth2ClientCredentialsHTTPBasic( + access_token_url=cast(TSecretStrValue, "https://api.example.com/custom-oauth/token"), + client_id=cast(TSecretStrValue, "test-account-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + access_token_request_data={ + "account_id": cast(TSecretStrValue, "test-account-id"), + }, + session=Client().session, + ) + + assert auth.build_access_token_request() == { + "headers": { + "Authorization": "Basic dGVzdC1hY2NvdW50LWlkOnRlc3QtY2xpZW50LXNlY3JldA==", + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": { + "grant_type": "account_credentials", + "account_id": "test-account-id", + }, + } + + rest_client.auth = auth + pages_iter = rest_client.paginate("/protected/posts/bearer-token") + + assert_pagination(list(pages_iter)) + def test_oauth_jwt_auth_success(self, rest_client: RESTClient): auth = OAuthJWTAuth( client_id="test-client-id",