Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1356 implements OAuth2 Client Credentials Flow #1357

Merged
merged 31 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
111149f
implements OAuth2 implicit flow for server-to-server authentication w…
willi-mueller May 14, 2024
cfd9c05
generalizes OAuth2 implementation. Provides example for Zoom Video co…
willi-mueller May 21, 2024
02acc95
make obtain_token a template method for greater compatibility. Remove…
willi-mueller May 24, 2024
4204764
authentication -> authorization
willi-mueller May 24, 2024
6bdbc26
implements response parsing that works for many cases
willi-mueller May 24, 2024
823f4fa
fixes types, uses parse_native_representation, removes access_token f…
willi-mueller May 25, 2024
755a680
removes access_token_url from constructor to build_access_token_request
willi-mueller May 25, 2024
eb636e4
documentation draft
willi-mueller May 25, 2024
f509877
Merge branch 'devel' into 1356-oauth2-implicit
burnash Jun 12, 2024
5bde898
Add a test case for a successful auth for client credentials flow
burnash Jun 13, 2024
f3e0ea1
Fix mypy errors
burnash Jun 13, 2024
7684d11
Oauth tests: checks for valid client_id & client_secret, tests invali…
willi-mueller Jun 13, 2024
f425d44
formats code
willi-mueller Jun 13, 2024
38f3d19
code formatting
willi-mueller Jun 14, 2024
82c9416
OAuth: tests token refresh
willi-mueller Jun 14, 2024
235c7e5
removes redundant import
willi-mueller Jun 14, 2024
ac80b4b
formatting and DRYing test code
willi-mueller Jun 14, 2024
73a0422
formatting and improving clarity
willi-mueller Jun 14, 2024
fda0b4d
Corrects OAuth2 name from "Implicit Flow" to "Client Credentials flow"
willi-mueller Jun 14, 2024
df1dd7b
fixes mypy error
willi-mueller Jun 17, 2024
ec4f632
removes non-existing alias from the docs
willi-mueller Jun 17, 2024
a410b35
OAuth2ClientCredentialsFlow works without subclassing it.
willi-mueller Jun 18, 2024
879bccb
renames OAuth classes, formats code
willi-mueller Jun 18, 2024
a14fb59
fixes linter (B006)
willi-mueller Jun 18, 2024
760db76
updates docs
willi-mueller Jun 18, 2024
352b5d7
adds the possibility for custom access_token_request_data without sub…
willi-mueller Jun 18, 2024
edfff5b
adds session to OAuth constructor so that we can mock the session to …
willi-mueller Jun 20, 2024
41fff55
Use `self.session` for making a request in obtain_token
burnash Jun 20, 2024
8fdd7f4
Update dlt/sources/helpers/rest_client/auth.py
burnash Jun 21, 2024
8ab73cc
Update dlt/sources/helpers/rest_client/auth.py
burnash Jun 21, 2024
8da3672
spies on obtaining OAuth token to ensure it got called
willi-mueller Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 61 additions & 7 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from base64 import b64encode
import dataclasses
import math
import dataclasses
from abc import abstractmethod
from base64 import b64encode
from typing import (
List,
TYPE_CHECKING,
Any,
Dict,
Final,
Iterable,
List,
Literal,
Optional,
Union,
Any,
cast,
Iterable,
TYPE_CHECKING,
)
from typing_extensions import Annotated
from requests.auth import AuthBase
Expand All @@ -21,10 +22,11 @@
from dlt.common.exceptions import MissingDependencyException
from dlt.common.configuration.specs.base_configuration import configspec, NotResolved
from dlt.common.configuration.specs import CredentialsConfiguration
from dlt.common.configuration.specs.base_configuration import configspec
from dlt.common.configuration.specs.exceptions import NativeValueError
burnash marked this conversation as resolved.
Show resolved Hide resolved
from dlt.common.exceptions import MissingDependencyException
from dlt.common.pendulum import pendulum
burnash marked this conversation as resolved.
Show resolved Hide resolved
from dlt.common.typing import TSecretStrValue

from dlt.sources.helpers import requests

if TYPE_CHECKING:
Expand Down Expand Up @@ -144,6 +146,58 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
return request


@configspec
class OAuth2ImplicitFlow(OAuth2AuthBase):
    """Bearer-token auth that fetches and renews its own OAuth 2.0 access token.

    The client obtains a temporary access token from the authorization service
    and sends it as ``Authorization: Bearer <token>`` on every request. The end
    user does not have to grant permission, which is why this exchange is
    typically used for server-to-server authorization.

    NOTE: despite the class name, this is the two-legged server-to-server
    exchange (client credentials style), not the browser-based implicit grant.

    Subclasses must implement ``build_access_token_request`` and may override
    ``parse_access_token`` / ``parse_expiration_in_seconds`` when the
    authorization server answers with a non-standard payload.

    Args:
        access_token_request_data: Extra data for the token request; it is only
            stored here and is meant to be used by ``build_access_token_request``.
        client_id: Client credential used to obtain the token.
        client_secret: Client credential used to obtain the token.
        default_token_expiration: Token lifetime in seconds that is assumed when
            the token response carries no ``expires_in`` value.
    """

    def __init__(
        self,
        access_token_request_data: Dict[str, Any],
        client_id: TSecretStrValue,
        client_secret: TSecretStrValue,
        default_token_expiration: int = 3600,
    ) -> None:
        super().__init__()
        self.access_token_request_data = access_token_request_data
        self.client_id = client_id
        self.client_secret = client_secret
        self.default_token_expiration = default_token_expiration
        # Initialised to "now", so the token counts as expired until the first
        # successful call to obtain_token().
        self.token_expiry: pendulum.DateTime = pendulum.now()

    def __call__(self, request: PreparedRequest) -> PreparedRequest:
        """Attach a bearer token to ``request``, fetching a new one if needed."""
        if self.access_token is None or self.is_token_expired():
            self.obtain_token()
        request.headers["Authorization"] = f"Bearer {self.access_token}"
        return request

    def is_token_expired(self) -> bool:
        """Return True once the current time has reached ``token_expiry``."""
        return pendulum.now() >= self.token_expiry

    def obtain_token(self) -> None:
        """Request a new access token and record when it expires.

        Raises:
            ValueError: If the token response contains no access token.
            Exception: Whatever ``response.raise_for_status()`` raises when the
                authorization server answers with an error status.
        """
        response = requests.post(**self.build_access_token_request())
        response.raise_for_status()
        response_json = response.json()
        # NOTE(review): parse_native_representation is inherited; presumably it
        # stores the token so that self.access_token returns it - confirm.
        self.parse_native_representation(self.parse_access_token(response_json))
        expires_in_seconds = self.parse_expiration_in_seconds(response_json)
        self.token_expiry = pendulum.now().add(seconds=expires_in_seconds)

    @abstractmethod
    def build_access_token_request(self) -> Dict[str, Any]:
        """Return the keyword arguments passed to ``requests.post`` to get a token."""
        pass

    def parse_expiration_in_seconds(self, response_json: Any) -> int:
        """Return the token lifetime in seconds from the token response.

        Falls back to ``default_token_expiration`` when ``expires_in`` is
        missing or explicitly ``null``.
        """
        expires_in = response_json.get("expires_in")
        if expires_in is None:
            return int(self.default_token_expiration)
        return int(expires_in)

    def parse_access_token(self, response_json: Any) -> str:
        """Return the access token from the token response.

        Raises:
            ValueError: If the response has no (or an empty) ``access_token``.
                Without this check the literal string "None" would be sent as
                the bearer token.
        """
        access_token = response_json.get("access_token")
        if not access_token:
            raise ValueError(
                "The authorization server response does not contain an access_token."
            )
        return str(access_token)


@configspec
class OAuthJWTAuth(BearerTokenAuth):
"""This is a form of Bearer auth, actually there's not standard way to declare it in openAPI"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ Available authentication types:
| [BearerTokenAuth](../../general-usage/http/rest-client.md#bearer-token-authentication) | `bearer` | Bearer token authentication. |
| [HTTPBasicAuth](../../general-usage/http/rest-client.md#http-basic-authentication) | `http_basic` | Basic HTTP authentication. |
| [APIKeyAuth](../../general-usage/http/rest-client.md#api-key-authentication) | `api_key` | API key authentication with key defined in the query parameters or in the headers. |
| [OAuth2ImplicitFlow](../../general-usage/http/rest-client.md#oauth-20-authorization) | `oauth2` | OAuth 2.0 authorization with a temporary access token obtained from the authorization server. |

To specify the authentication configuration, use the `auth` field in the [client](#client) configuration:

Expand Down
51 changes: 51 additions & 0 deletions docs/website/docs/general-usage/http/rest-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,11 @@ The available authentication methods are defined in the `dlt.sources.helpers.res
- [BearerTokenAuth](#bearer-token-authentication)
- [APIKeyAuth](#api-key-authentication)
- [HttpBasicAuth](#http-basic-authentication)
- [OAuth2ImplicitFlow](#oauth-20-authorization)

For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library.
For specific flavors of OAuth 2.0, you can [implement custom OAuth 2.0](#oauth-20-authorization)
by subclassing `OAuth2ImplicitFlow`.

### Bearer token authentication

Expand Down Expand Up @@ -477,6 +480,54 @@ client = RESTClient(base_url="https://api.example.com", auth=auth)
response = client.get("/protected/resource")
```

### OAuth 2.0 authorization

OAuth 2.0 is a common protocol for authorization. The REST client implements the two-legged flow, which is used for server-to-server authorization because the end user (resource owner) does not need to grant approval.
The REST client acts as the OAuth client which obtains a temporary access token from the authorization server. This access token is then sent to the resource server to access protected content.

Unfortunately, most OAuth 2.0 implementations vary, and thus you need to subclass `OAuth2ImplicitFlow` and implement `build_access_token_request()` to suit the requirements of the specific authorization server you want to interact with.

**Parameters:**

- `access_token_request_data`: A dictionary with data required by the authorization server. Typically includes a `grant_type` key.
- `client_id`: Client credential to obtain authorization. Usually issued via a developer portal.
- `client_secret`: Client credential to obtain authorization. Usually issued via a developer portal.
- `default_token_expiration`: The time in seconds after which the temporary access token expires. Defaults to 3600.

**Example:**

```py
from base64 import b64encode
from typing import Any, Dict

import dlt
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.auth import OAuth2ImplicitFlow


class OAuth2Zoom(OAuth2ImplicitFlow):
    """Token request for Zoom: client credentials go into a Basic auth header."""

    def build_access_token_request(self) -> Dict[str, Any]:
        # "client_id:client_secret", base64-encoded, as the Basic credentials.
        authentication: str = b64encode(
            f"{self.client_id}:{self.client_secret}".encode()
        ).decode()
        return {
            "url": "https://zoom.us/oauth/token",
            "headers": {
                "Authorization": f"Basic {authentication}",
                "Content-Type": "application/x-www-form-urlencoded",
            },
            "data": self.access_token_request_data,
        }


auth = OAuth2Zoom(
    access_token_request_data={
        "grant_type": "account_credentials",
        "account_id": dlt.secrets["sources.zoom.account_id"],
    },
    client_id=dlt.secrets["sources.zoom.client_id"],
    client_secret=dlt.secrets["sources.zoom.client_secret"],
)
client = RESTClient(base_url="https://api.zoom.us/v2", auth=auth)

response = client.get("/users")
```



### Implementing custom authentication

You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method:
Expand Down
28 changes: 25 additions & 3 deletions tests/sources/helpers/rest_client/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import re
from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING, Dict, List, Any
from typing import NamedTuple, Callable, Pattern, Union, TYPE_CHECKING, Dict, List, Any
import base64

from urllib.parse import urlsplit, urlunsplit
from urllib.parse import parse_qs, urlsplit, urlunsplit

import pytest
import requests_mock
Expand Down Expand Up @@ -207,7 +207,17 @@ def protected_api_key(request, context):

@router.post("/oauth/token")
def oauth_token(request, context):
    """Mock token endpoint: issue a one-hour token to authorized requests only.

    Requests rejected by ``oauth_authorize`` get a 401 status and an error body.
    """
    # The credential check must come first; an unconditional return above it
    # would hand out tokens to every caller and make the 401 path unreachable.
    if oauth_authorize(request):
        return {"access_token": "test-token", "expires_in": 3600}
    context.status_code = 401
    return {"error": "Unauthorized"}

@router.post("/oauth/token-expires-now")
def oauth_token_expires_now(request, context):
    """Mock token endpoint whose tokens have a lifetime of zero seconds.

    Requests rejected by ``oauth_authorize`` get a 401 status and an error body.
    """
    if not oauth_authorize(request):
        context.status_code = 401
        return {"error": "Unauthorized"}
    return {"access_token": "test-token", "expires_in": 0}

@router.post("/auth/refresh")
def refresh_token(request, context):
Expand All @@ -222,6 +232,18 @@ def refresh_token(request, context):
yield m


def oauth_authorize(request):
    """Decide whether a mock token request is authorized.

    The form-encoded request body is inspected:

    * a ``grant_type`` containing ``jwt-bearer`` is always accepted;
    * a ``grant_type`` containing ``client_credentials`` is accepted only with
      the expected test ``client_id`` and ``client_secret``;
    * anything else, including a missing body or missing fields, is rejected.

    Args:
        request: Object exposing the request body as ``request.text``.

    Returns:
        bool: True if the request is authorized, False otherwise.
    """
    # A request without a body may expose text as None; treat it as empty.
    qs = parse_qs(request.text or "")
    # parse_qs maps each field to a non-empty list; default to an empty value
    # so that absent fields reject the request instead of raising.
    grant_type = qs.get("grant_type", [""])[0]
    if "jwt-bearer" in grant_type:
        return True
    if "client_credentials" in grant_type:
        return (
            qs.get("client_secret", [""])[0] == "test-client-secret"
            and qs.get("client_id", [""])[0] == "test-client-id"
        )
    return False


def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10):
assert len(pages) == total_pages
for i, page in enumerate(pages):
Expand Down
141 changes: 132 additions & 9 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
import os
import pytest
from typing import Any, cast
from dlt.common import logger

import pytest
from requests import PreparedRequest, Request, Response
from requests.auth import AuthBase
from requests.exceptions import HTTPError

from dlt.common import logger
from dlt.common.typing import TSecretStrValue
from dlt.sources.helpers.requests import Client
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.client import Hooks
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator

from dlt.sources.helpers.rest_client.auth import AuthConfigBase
from dlt.sources.helpers.rest_client.auth import (
BearerTokenAuth,
APIKeyAuth,
AuthConfigBase,
BearerTokenAuth,
HttpBasicAuth,
OAuth2ImplicitFlow,
OAuthJWTAuth,
)
from dlt.sources.helpers.rest_client.client import Hooks
from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator

from .conftest import assert_pagination

Expand All @@ -31,14 +34,73 @@ def load_private_key(name="private_key.pem"):
TEST_PRIVATE_KEY = load_private_key()


@pytest.fixture
def rest_client() -> RESTClient:
def build_rest_client(auth=None) -> RESTClient:
    """Create a RESTClient for ``https://api.example.com`` with optional auth.

    Args:
        auth: Auth object handed to the client; ``None`` means no auth.
    """
    http_session = Client().session
    json_headers = {"Accept": "application/json"}
    return RESTClient(
        base_url="https://api.example.com",
        headers=json_headers,
        session=http_session,
        auth=auth,
    )


@pytest.fixture
def rest_client() -> RESTClient:
    """Provide a REST client that was built without an auth object."""
    client_without_auth = build_rest_client()
    return client_without_auth


@pytest.fixture
def rest_client_oauth() -> RESTClient:
    """Provide a REST client authenticated with the valid test client credentials."""
    token_request_data = {"grant_type": "client_credentials"}
    oauth = OAuth2ClientCredentialsExample(
        access_token_request_data=token_request_data,
        client_id=cast(TSecretStrValue, "test-client-id"),
        client_secret=cast(TSecretStrValue, "test-client-secret"),
    )
    return build_rest_client(auth=oauth)


class OAuth2ClientCredentialsExample(OAuth2ImplicitFlow):
    """Test auth that sends the client credentials in the form body."""

    def build_access_token_request(self):
        """Return the token request for ``https://api.example.com/oauth/token``."""
        # Start from the configured request data, then add the credentials.
        form_data = dict(self.access_token_request_data)
        form_data["client_id"] = self.client_id
        form_data["client_secret"] = self.client_secret

        form_headers = {"Content-Type": "application/x-www-form-urlencoded"}
        return {
            "url": "https://api.example.com/oauth/token",
            "headers": form_headers,
            "data": form_data,
        }


@pytest.fixture
def rest_client_immediate_oauth_expiry(auth=None) -> RESTClient:
    """Provide a REST client whose auth uses the ``token-expires-now`` endpoint.

    The ``auth`` argument is not read; the client always gets the auth object
    constructed below.
    """

    class OAuth2ClientCredentialsExpiringNow(OAuth2ImplicitFlow):
        """Client-credentials auth pointed at the ``token-expires-now`` endpoint."""

        def build_access_token_request(self):
            form_data = dict(self.access_token_request_data)
            form_data["client_id"] = self.client_id
            form_data["client_secret"] = self.client_secret
            return {
                "url": "https://api.example.com/oauth/token-expires-now",
                "headers": {"Content-Type": "application/x-www-form-urlencoded"},
                "data": form_data,
            }

    expiring_auth = OAuth2ClientCredentialsExpiringNow(
        access_token_request_data={"grant_type": "client_credentials"},
        client_id=cast(TSecretStrValue, "test-client-id"),
        client_secret=cast(TSecretStrValue, "test-client-secret"),
    )
    return build_rest_client(auth=expiring_auth)


@pytest.mark.usefixtures("mock_api_server")
class TestRESTClient:
Expand Down Expand Up @@ -163,6 +225,67 @@ def test_api_key_auth_success(self, rest_client: RESTClient):
assert response.status_code == 200
assert response.json()["data"][0] == {"id": 0, "title": "Post 0"}

def test_oauth2_client_credentials_flow_auth_success(self, rest_client_oauth: RESTClient):
    """A valid client gets a token and can read and paginate the protected posts."""
    protected_path = "/protected/posts/bearer-token"

    response = rest_client_oauth.get(protected_path)
    assert response.status_code == 200
    # The token issued by the mock token endpoint must be sent as the credential.
    assert "test-token" in response.request.headers["Authorization"]

    pages = list(rest_client_oauth.paginate(protected_path))
    assert_pagination(pages)

def test_oauth2_client_credentials_flow_wrong_client_id(self, rest_client: RESTClient):
    """An unknown client id makes the token request fail with a 401 HTTPError."""
    bad_id_auth = OAuth2ClientCredentialsExample(
        access_token_request_data={"grant_type": "client_credentials"},
        client_id=cast(TSecretStrValue, "invalid-client-id"),
        client_secret=cast(TSecretStrValue, "test-client-secret"),
    )

    with pytest.raises(HTTPError) as e:
        rest_client.get("/protected/posts/bearer-token", auth=bad_id_auth)

    # Checked after the context manager so the assertions actually run.
    assert e.type is HTTPError
    assert e.match("401 Client Error")

def test_oauth2_client_credentials_flow_wrong_client_secret(self, rest_client: RESTClient):
    """A wrong client secret makes the token request fail with a 401 HTTPError."""
    bad_secret_auth = OAuth2ClientCredentialsExample(
        access_token_request_data={"grant_type": "client_credentials"},
        client_id=cast(TSecretStrValue, "test-client-id"),
        client_secret=cast(TSecretStrValue, "invalid-client-secret"),
    )

    with pytest.raises(HTTPError) as e:
        rest_client.get("/protected/posts/bearer-token", auth=bad_secret_auth)

    # Checked after the context manager so the assertions actually run.
    assert e.type is HTTPError
    assert e.match("401 Client Error")

def test_oauth_token_expired_refresh(self, rest_client_immediate_oauth_expiry: RESTClient):
    """An expired token is replaced on the next request, moving the expiry forward."""
    client = rest_client_immediate_oauth_expiry
    protected_path = "/protected/posts/bearer-token"

    # No token exists before the first request.
    assert client.auth.access_token is None
    first_response = client.get(protected_path)
    assert first_response.status_code == 200
    assert client.auth.access_token is not None

    # Push the expiry one second into the past to force the expired state.
    expiry_0 = client.auth.token_expiry
    client.auth.token_expiry = client.auth.token_expiry.subtract(seconds=1)
    expiry_1 = client.auth.token_expiry
    assert expiry_1 < expiry_0
    assert client.auth.is_token_expired()

    # The next request must obtain a new token and with it a later expiry.
    second_response = client.get(protected_path)
    assert second_response.status_code == 200
    expiry_2 = client.auth.token_expiry
    assert expiry_2 > expiry_1

def test_oauth_jwt_auth_success(self, rest_client: RESTClient):
auth = OAuthJWTAuth(
client_id="test-client-id",
Expand Down
Loading