Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1356 implements OAuth2 Client Credentials Flow #1357

Merged
merged 31 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
111149f
implements OAuth2 implicit flow for server-to-server authentication w…
willi-mueller May 14, 2024
cfd9c05
generalizes OAuth2 implementation. Provides example for Zoom Video co…
willi-mueller May 21, 2024
02acc95
make obtain_token a template method for greater compatibility. Remove…
willi-mueller May 24, 2024
4204764
authentication -> authorization
willi-mueller May 24, 2024
6bdbc26
implements response parsing that works for many cases
willi-mueller May 24, 2024
823f4fa
fixes types, uses parse_native_representation, removes access_token f…
willi-mueller May 25, 2024
755a680
removes access_token_url from constructor to build_access_token_request
willi-mueller May 25, 2024
eb636e4
documentation draft
willi-mueller May 25, 2024
f509877
Merge branch 'devel' into 1356-oauth2-implicit
burnash Jun 12, 2024
5bde898
Add a test case for a successful auth for client credentials flow
burnash Jun 13, 2024
f3e0ea1
Fix mypy errors
burnash Jun 13, 2024
7684d11
Oauth tests: checks for valid client_id & client_secret, tests invali…
willi-mueller Jun 13, 2024
f425d44
formats code
willi-mueller Jun 13, 2024
38f3d19
code formatting
willi-mueller Jun 14, 2024
82c9416
OAuth: tests token refresh
willi-mueller Jun 14, 2024
235c7e5
removes redundant import
willi-mueller Jun 14, 2024
ac80b4b
formatting and DRYing test code
willi-mueller Jun 14, 2024
73a0422
formatting and improving clarity
willi-mueller Jun 14, 2024
fda0b4d
Corrects OAuth2 name from "Implicit Flow" to "Client Credentials flow"
willi-mueller Jun 14, 2024
df1dd7b
fixes mypy error
willi-mueller Jun 17, 2024
ec4f632
removes non-existing alias from the docs
willi-mueller Jun 17, 2024
a410b35
OAuth2ClientCredentialsFlow works without subclassing it.
willi-mueller Jun 18, 2024
879bccb
renames OAuth classes, formats code
willi-mueller Jun 18, 2024
a14fb59
fixes linter (B006)
willi-mueller Jun 18, 2024
760db76
updates docs
willi-mueller Jun 18, 2024
352b5d7
adds the possibility for custom access_token_request_data without sub…
willi-mueller Jun 18, 2024
edfff5b
adds session to OAuth constructor so that we can mock the session to …
willi-mueller Jun 20, 2024
41fff55
Use `self.session` for making a request in obtain_token
burnash Jun 20, 2024
8fdd7f4
Update dlt/sources/helpers/rest_client/auth.py
burnash Jun 21, 2024
8ab73cc
Update dlt/sources/helpers/rest_client/auth.py
burnash Jun 21, 2024
8da3672
spies on obtaining OAuth token to ensure it got called
willi-mueller Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 78 additions & 8 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from base64 import b64encode
import dataclasses
import math
import dataclasses
from abc import abstractmethod
from base64 import b64encode
from typing import (
List,
TYPE_CHECKING,
Any,
Dict,
Final,
Iterable,
List,
Literal,
Optional,
Union,
Any,
cast,
Iterable,
TYPE_CHECKING,
)
from typing_extensions import Annotated
from requests.auth import AuthBase
Expand All @@ -24,7 +25,6 @@
from dlt.common.configuration.specs.exceptions import NativeValueError
from dlt.common.pendulum import pendulum
from dlt.common.typing import TSecretStrValue

from dlt.sources.helpers import requests

if TYPE_CHECKING:
Expand Down Expand Up @@ -144,6 +144,76 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
return request


@configspec
class OAuth2ClientCredentials(OAuth2AuthBase):
    """
    This class implements OAuth2 Client Credentials flow where the authorization service
    gives permission without the end user approving.
    This is often used for machine-to-machine authorization.
    The client sends its client ID and client secret to the authorization service which replies
    with a temporary access token.
    With the access token, the client can access resource services.

    Subclasses can adapt the flow to a specific authorization server by overriding
    `build_access_token_request`, `parse_access_token` or `parse_expiration_in_seconds`.
    """

    def __init__(
        self,
        access_token_url: TSecretStrValue,
        client_id: TSecretStrValue,
        client_secret: TSecretStrValue,
        access_token_request_data: Optional[Dict[str, Any]] = None,
        default_token_expiration: int = 3600,
        session: Annotated[BaseSession, NotResolved()] = None,
    ) -> None:
        """
        Args:
            access_token_url: URL the access token request is POSTed to.
            client_id: Client ID sent in the token request.
            client_secret: Client secret sent in the token request.
            access_token_request_data: Extra form fields merged into the token request;
                they override the default fields on key collision. `None` means no extras.
            default_token_expiration: Token lifetime in seconds used when the token
                response carries no `expires_in`.
            session: Session used to request the token. Falls back to
                `requests.client.session` when not given.
        """
        super().__init__()
        self.access_token_url = access_token_url
        self.client_id = client_id
        self.client_secret = client_secret
        if access_token_request_data is None:
            self.access_token_request_data = {}
        else:
            self.access_token_request_data = access_token_request_data
        self.default_token_expiration = default_token_expiration
        # Start out "expired" so that the first request always obtains a token.
        self.token_expiry: pendulum.DateTime = pendulum.now()

        self.session = session if session is not None else requests.client.session

    def __call__(self, request: PreparedRequest) -> PreparedRequest:
        """Attach a Bearer token to `request`, obtaining a new one if missing or expired."""
        if self.access_token is None or self.is_token_expired():
            self.obtain_token()
        request.headers["Authorization"] = f"Bearer {self.access_token}"
        return request

    def is_token_expired(self) -> bool:
        """Return True once the current time has reached `token_expiry`."""
        return pendulum.now() >= self.token_expiry

    def obtain_token(self) -> None:
        """Request a new access token and record when it expires.

        Raises whatever `response.raise_for_status()` raises for an error response,
        and `ValueError` when the response contains no access token.
        """
        response = self.session.post(self.access_token_url, **self.build_access_token_request())
        response.raise_for_status()
        response_json = response.json()
        # NOTE(review): parse_native_representation is defined outside this block;
        # presumably it stores the token as `self.access_token` - confirm.
        self.parse_native_representation(self.parse_access_token(response_json))
        expires_in_seconds = self.parse_expiration_in_seconds(response_json)
        self.token_expiry = pendulum.now().add(seconds=expires_in_seconds)

    def build_access_token_request(self) -> Dict[str, Any]:
        """Return the keyword arguments passed to `session.post` for the token request."""
        return {
            "headers": {
                "Content-Type": "application/x-www-form-urlencoded",
            },
            "data": {
                "client_id": self.client_id,
                "client_secret": self.client_secret,
                "grant_type": "client_credentials",
                # Spread last so that caller-supplied fields win over the defaults.
                **self.access_token_request_data,
            },
        }

    def parse_expiration_in_seconds(self, response_json: Any) -> int:
        """Return the token lifetime in seconds from the token response.

        Falls back to `default_token_expiration` when `expires_in` is absent or null.
        """
        expires_in = response_json.get("expires_in")
        if expires_in is None:
            # An explicit JSON null would otherwise make int() raise a TypeError.
            expires_in = self.default_token_expiration
        return int(expires_in)

    def parse_access_token(self, response_json: Any) -> str:
        """Return the access token from the token response.

        Raises:
            ValueError: If the response has no (or an empty) `access_token`. Without
                this check the missing value would be stringified to "None" and sent
                as a bearer token.
        """
        access_token = response_json.get("access_token")
        if not access_token:
            raise ValueError(
                f"No access_token found in the response from {self.access_token_url}"
            )
        return str(access_token)


@configspec
class OAuthJWTAuth(BearerTokenAuth):
"""This is a form of Bearer auth, actually there's not standard way to declare it in openAPI"""
Expand All @@ -164,7 +234,7 @@ def __post_init__(self) -> None:
self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
self.token = None
self.token_expiry: Optional[pendulum.DateTime] = None
# use default system session is not specified
# use default system session unless specified otherwise
if self.session is None:
self.session = requests.client.session

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ Available authentication types:
| [BearerTokenAuth](../../general-usage/http/rest-client.md#bearer-token-authentication) | `bearer` | Bearer token authentication. |
| [HTTPBasicAuth](../../general-usage/http/rest-client.md#http-basic-authentication) | `http_basic` | Basic HTTP authentication. |
| [APIKeyAuth](../../general-usage/http/rest-client.md#api-key-authentication) | `api_key` | API key authentication with key defined in the query parameters or in the headers. |
| [OAuth2ClientCredentials](../../general-usage/http/rest-client.md#oauth-20-authorization) | N/A | OAuth 2.0 authorization with a temporary access token obtained from the authorization server. |

To specify the authentication configuration, use the `auth` field in the [client](#client) configuration:

Expand Down
54 changes: 54 additions & 0 deletions docs/website/docs/general-usage/http/rest-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,11 @@ The available authentication methods are defined in the `dlt.sources.helpers.res
- [BearerTokenAuth](#bearer-token-authentication)
- [APIKeyAuth](#api-key-authentication)
- [HttpBasicAuth](#http-basic-authentication)
- [OAuth2ClientCredentials](#oauth-20-authorization)

For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library.
For specific flavors of OAuth 2.0, you can [implement custom OAuth 2.0](#oauth-20-authorization)
by subclassing `OAuth2ClientCredentials`.

### Bearer token authentication

Expand Down Expand Up @@ -477,6 +480,57 @@ client = RESTClient(base_url="https://api.example.com", auth=auth)
response = client.get("/protected/resource")
```

### OAuth 2.0 authorization

OAuth 2.0 is a common protocol for authorization. We have implemented two-legged authorization employed for server-to-server authorization because the end user (resource owner) does not need to grant approval.
The REST client acts as the OAuth client which obtains a temporary access token from the authorization server. This access token is then sent to the resource server to access protected content. If the access token is expired, the OAuth client automatically refreshes it.

Unfortunately, most OAuth 2.0 implementations vary, and thus you might need to subclass `OAuth2ClientCredentials` and implement `build_access_token_request()` to suit the requirements of the specific authorization server you want to interact with.

**Parameters:**
- `access_token_url`: The url to obtain the temporary access token.
- `client_id`: Client credential to obtain authorization. Usually issued via a developer portal.
- `client_secret`: Client credential to obtain authorization. Usually issued via a developer portal.
- `access_token_request_data`: A dictionary with data required by the authorization server apart from the `client_id`, `client_secret`, and `"grant_type": "client_credentials"`. Defaults to `None`.
- `default_token_expiration`: The time in seconds after which the temporary access token expires. Defaults to 3600.

**Example:**

```py
from base64 import b64encode
from typing import Any, Dict

import dlt
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.auth import OAuth2ClientCredentials

class OAuth2ClientCredentialsHTTPBasic(OAuth2ClientCredentials):
    """Used e.g. by Zoom Video Communications, Inc."""
    def build_access_token_request(self) -> Dict[str, Any]:
        # The client credentials go into an HTTP Basic Authorization header
        # instead of the form body.
        authentication: str = b64encode(
            f"{self.client_id}:{self.client_secret}".encode()
        ).decode()
        return {
            "headers": {
                "Authorization": f"Basic {authentication}",
                "Content-Type": "application/x-www-form-urlencoded",
            },
            # Only the custom fields are sent as form data.
            "data": self.access_token_request_data,
        }

auth = OAuth2ClientCredentialsHTTPBasic(
    access_token_url=dlt.secrets["sources.zoom.access_token_url"],  # "https://zoom.us/oauth/token"
    client_id=dlt.secrets["sources.zoom.client_id"],
    client_secret=dlt.secrets["sources.zoom.client_secret"],
    access_token_request_data={
        "grant_type": "account_credentials",
        "account_id": dlt.secrets["sources.zoom.account_id"],
    },
)
client = RESTClient(base_url="https://api.zoom.us/v2", auth=auth)

response = client.get("/users")
```



### Implementing custom authentication

You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method:
Expand Down
41 changes: 38 additions & 3 deletions tests/sources/helpers/rest_client/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import re
from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING, Dict, List, Any
from typing import NamedTuple, Callable, Pattern, Union, TYPE_CHECKING, Dict, List, Any
import base64

from urllib.parse import urlsplit, urlunsplit
from urllib.parse import parse_qs, urlsplit, urlunsplit

import pytest
import requests_mock
Expand Down Expand Up @@ -207,7 +207,17 @@ def protected_api_key(request, context):

@router.post("/oauth/token")
def oauth_token(request, context):
    """Mock token endpoint: issue a token valid for 3600 seconds, or reply 401.

    The token is only returned when `oauth_authorize` accepts the request;
    an unconditional return placed before that check would make it unreachable.
    """
    if oauth_authorize(request):
        return {"access_token": "test-token", "expires_in": 3600}
    context.status_code = 401
    return {"error": "Unauthorized"}

@router.post("/oauth/token-expires-now")
def oauth_token_expires_now(request, context):
    """Mock token endpoint whose token has `expires_in` 0, or reply 401."""
    authorized = oauth_authorize(request)
    if not authorized:
        # Reject the request before any token is issued.
        context.status_code = 401
        return {"error": "Unauthorized"}
    return {"access_token": "test-token", "expires_in": 0}

@router.post("/auth/refresh")
def refresh_token(request, context):
Expand All @@ -217,11 +227,36 @@ def refresh_token(request, context):
context.status_code = 401
return {"error": "Invalid refresh token"}

@router.post("/custom-oauth/token")
def custom_oauth_token(request, context):
    """Mock token endpoint for a custom flow using HTTP Basic client authentication.

    Issues a token only when the form body carries the expected `grant_type` and
    `account_id` and the Authorization header matches the expected Basic credentials
    (base64 of "test-account-id:test-client-secret"). Any other request, including
    one with missing fields or no Authorization header, gets a 401.
    """
    qs = parse_qs(request.text or "")
    # parse_qs maps every key to a list of values; default to [None] so that a
    # missing field results in a 401 instead of a TypeError.
    grant_type = qs.get("grant_type", [None])[0]
    account_id = qs.get("account_id", [None])[0]
    # .get() avoids a KeyError when the header is absent.
    authorization = request.headers.get("Authorization")
    if (
        grant_type == "account_credentials"
        and account_id == "test-account-id"
        and authorization == "Basic dGVzdC1hY2NvdW50LWlkOnRlc3QtY2xpZW50LXNlY3JldA=="
    ):
        return {"access_token": "test-token", "expires_in": 3600}
    context.status_code = 401
    return {"error": "Unauthorized"}

router.register_routes(m)

yield m


def oauth_authorize(request):
    """Decide whether a mock token request is authorized.

    The form-encoded body in `request.text` is inspected:
    - a `grant_type` containing "jwt-bearer" is always accepted;
    - a `grant_type` containing "client_credentials" is accepted only with the
      expected test `client_id` and `client_secret`;
    - anything else, including a missing body or missing fields, is rejected.

    Returns:
        True if the request is authorized, False otherwise.
    """
    qs = parse_qs(request.text or "")

    def first(key):
        # parse_qs maps each key to a list of values; None signals "not sent".
        values = qs.get(key)
        return values[0] if values else None

    grant_type = first("grant_type") or ""
    if "jwt-bearer" in grant_type:
        return True
    if "client_credentials" in grant_type:
        return (
            first("client_secret") == "test-client-secret"
            and first("client_id") == "test-client-id"
        )
    return False


def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10):
assert len(pages) == total_pages
for i, page in enumerate(pages):
Expand Down
Loading
Loading