Skip to content

Commit

Permalink
adds type check at runtime to prevent wrong usage and clarify intenti…
Browse files Browse the repository at this point in the history
…on. Removes unnecessary dataclass and returns to simple dictionary.
  • Loading branch information
willi-mueller committed Jul 18, 2024
1 parent 3dd41cf commit 096946f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 38 deletions.
56 changes: 21 additions & 35 deletions sources/rest_api/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
Callable,
cast,
NamedTuple,
KeysView,
)
import graphlib # type: ignore[import,unused-ignore]
import string
from dataclasses import dataclass, field

import dlt
from dlt.common import logger
Expand Down Expand Up @@ -61,6 +59,16 @@
from .utils import exclude_keys


PAGINATOR_MAP: Dict[str, Type[BasePaginator]] = {
"json_response": JSONResponsePaginator,
"header_link": HeaderLinkPaginator,
"auto": None,
"single_page": SinglePagePaginator,
"cursor": JSONResponseCursorPaginator,
"offset": OffsetPaginator,
"page_number": PageNumberPaginator,
}

AUTH_MAP: Dict[AuthType, Type[AuthConfigBase]] = {
"bearer": BearerTokenAuth,
"api_key": APIKeyAuth,
Expand All @@ -73,45 +81,23 @@ class IncrementalParam(NamedTuple):
end: Optional[str]


@dataclass
class PaginatorMap:
_map: Dict[str, Type[BasePaginator]] = field(default_factory=dict)

def __post_init__(self) -> None:
self._map = {
"json_response": JSONResponsePaginator,
"header_link": HeaderLinkPaginator,
"auto": None,
"single_page": SinglePagePaginator,
"cursor": JSONResponseCursorPaginator,
"offset": OffsetPaginator,
"page_number": PageNumberPaginator,
}

def __getitem__(self, key: str) -> Type[BasePaginator]:
return self._map[key]

def __setitem__(self, key: str, value: Type[BasePaginator]) -> None:
self._map[key] = value

def keys(self) -> KeysView[str]:
return self._map.keys()


paginator_map = PaginatorMap()


def register_paginator(
paginator_name: str, paginator_class: Type[BasePaginator]
paginator_name: str,
paginator_class: Type[BasePaginator],
) -> None:
paginator_map[paginator_name] = paginator_class
if not issubclass(paginator_class, BasePaginator):
raise ValueError(
f"Invalid paginator: {paginator_class.__name__}. "
"Your custom paginator has to be a subclass of BasePaginator"
)
PAGINATOR_MAP[paginator_name] = paginator_class


def get_paginator_class(paginator_type: PaginatorType) -> Type[BasePaginator]:
try:
return paginator_map[paginator_type]
return PAGINATOR_MAP[paginator_type]
except KeyError:
available_options = ", ".join(paginator_map.keys())
available_options = ", ".join(PAGINATOR_MAP.keys())
raise ValueError(
f"Invalid paginator: {paginator_type}. "
f"Available options: {available_options}"
Expand All @@ -127,7 +113,7 @@ def create_paginator(
if isinstance(paginator_config, str):
paginator_class = get_paginator_class(paginator_config)
try:
# `auto` has no associated class in `paginator_map`
# `auto` has no associated class in `PAGINATOR_MAP`
return paginator_class() if paginator_class else None
except TypeError:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions tests/rest_api/test_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from sources.rest_api.config_setup import (
AUTH_MAP,
paginator_map,
PAGINATOR_MAP,
IncrementalParam,
_bind_path_params,
_setup_single_entity_endpoint,
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_paginator_type_configs(paginator_type_config: PaginatorTypeConfig) -> N
assert paginator is None
else:
# assert types and default params
assert isinstance(paginator, paginator_map[paginator_type_config["type"]])
assert isinstance(paginator, PAGINATOR_MAP[paginator_type_config["type"]])
# check if params are bound
if isinstance(paginator, HeaderLinkPaginator):
assert paginator.links_next_key == "next_page"
Expand Down
10 changes: 9 additions & 1 deletion tests/rest_api/test_paginators.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_not_registering_throws_error(self, custom_paginator_config) -> None:

assert e.match("Invalid paginator: custom_paginator.")

def test_registering_adds_to_paginator_map(self, custom_paginator_config) -> None:
def test_registering_adds_to_PAGINATOR_MAP(self, custom_paginator_config) -> None:
rest_api.config_setup.register_paginator("custom_paginator", CustomPaginator)
cls = rest_api.config_setup.get_paginator_class("custom_paginator")
assert cls is CustomPaginator
Expand All @@ -50,3 +50,11 @@ def test_registering_allows_usage(self, custom_paginator_config) -> None:
paginator = rest_api.config_setup.create_paginator(custom_paginator_config)
assert paginator.has_next_page is True
assert str(paginator.next_url_path) == "response.next_page_link"

def test_registering_not_base_paginator_throws_error(self) -> None:
class NotAPaginator:
pass

with pytest.raises(ValueError) as e:
rest_api.config_setup.register_paginator("not_a_paginator", NotAPaginator)
assert e.match("Invalid paginator: NotAPaginator.")

0 comments on commit 096946f

Please sign in to comment.