From 096946fddeb0604ae63e30332849e038716ba8d0 Mon Sep 17 00:00:00 2001 From: Willi Date: Thu, 18 Jul 2024 16:26:09 +0530 Subject: [PATCH] adds type check at runtime to prevent wrong usage and clarify intention. Removes unnecessary dataclass and returns to simple dictionary. --- sources/rest_api/config_setup.py | 56 ++++++++++----------------- tests/rest_api/test_configurations.py | 4 +- tests/rest_api/test_paginators.py | 10 ++++- 3 files changed, 32 insertions(+), 38 deletions(-) diff --git a/sources/rest_api/config_setup.py b/sources/rest_api/config_setup.py index 53cbab222..4e1396bc2 100644 --- a/sources/rest_api/config_setup.py +++ b/sources/rest_api/config_setup.py @@ -10,11 +10,9 @@ Callable, cast, NamedTuple, - KeysView, ) import graphlib # type: ignore[import,unused-ignore] import string -from dataclasses import dataclass, field import dlt from dlt.common import logger @@ -61,6 +59,16 @@ from .utils import exclude_keys +PAGINATOR_MAP: Dict[str, Type[BasePaginator]] = { + "json_response": JSONResponsePaginator, + "header_link": HeaderLinkPaginator, + "auto": None, + "single_page": SinglePagePaginator, + "cursor": JSONResponseCursorPaginator, + "offset": OffsetPaginator, + "page_number": PageNumberPaginator, +} + AUTH_MAP: Dict[AuthType, Type[AuthConfigBase]] = { "bearer": BearerTokenAuth, "api_key": APIKeyAuth, @@ -73,45 +81,23 @@ class IncrementalParam(NamedTuple): end: Optional[str] -@dataclass -class PaginatorMap: - _map: Dict[str, Type[BasePaginator]] = field(default_factory=dict) - - def __post_init__(self) -> None: - self._map = { - "json_response": JSONResponsePaginator, - "header_link": HeaderLinkPaginator, - "auto": None, - "single_page": SinglePagePaginator, - "cursor": JSONResponseCursorPaginator, - "offset": OffsetPaginator, - "page_number": PageNumberPaginator, - } - - def __getitem__(self, key: str) -> Type[BasePaginator]: - return self._map[key] - - def __setitem__(self, key: str, value: Type[BasePaginator]) -> None: - self._map[key] = value - - def keys(self) -> KeysView[str]: - return self._map.keys() - - -paginator_map = PaginatorMap() - - def register_paginator( - paginator_name: str, paginator_class: Type[BasePaginator] + paginator_name: str, + paginator_class: Type[BasePaginator], ) -> None: - paginator_map[paginator_name] = paginator_class + if not issubclass(paginator_class, BasePaginator): + raise ValueError( + f"Invalid paginator: {paginator_class.__name__}. " + "Your custom paginator has to be a subclass of BasePaginator" + ) + PAGINATOR_MAP[paginator_name] = paginator_class def get_paginator_class(paginator_type: PaginatorType) -> Type[BasePaginator]: try: - return paginator_map[paginator_type] + return PAGINATOR_MAP[paginator_type] except KeyError: - available_options = ", ".join(paginator_map.keys()) + available_options = ", ".join(PAGINATOR_MAP.keys()) raise ValueError( f"Invalid paginator: {paginator_type}. " f"Available options: {available_options}" @@ -127,7 +113,7 @@ def create_paginator( if isinstance(paginator_config, str): paginator_class = get_paginator_class(paginator_config) try: - # `auto` has no associated class in `paginator_map` + # `auto` has no associated class in `PAGINATOR_MAP` return paginator_class() if paginator_class else None except TypeError: raise ValueError( diff --git a/tests/rest_api/test_configurations.py b/tests/rest_api/test_configurations.py index 5a9383d83..7e6e1c0fa 100644 --- a/tests/rest_api/test_configurations.py +++ b/tests/rest_api/test_configurations.py @@ -26,7 +26,7 @@ from sources.rest_api.config_setup import ( AUTH_MAP, - paginator_map, + PAGINATOR_MAP, IncrementalParam, _bind_path_params, _setup_single_entity_endpoint, @@ -108,7 +108,7 @@ def test_paginator_type_configs(paginator_type_config: PaginatorTypeConfig) -> N assert paginator is None else: # assert types and default params - assert isinstance(paginator, paginator_map[paginator_type_config["type"]]) + assert isinstance(paginator, PAGINATOR_MAP[paginator_type_config["type"]]) # check if params are bound if isinstance(paginator, HeaderLinkPaginator): assert paginator.links_next_key == "next_page" diff --git a/tests/rest_api/test_paginators.py b/tests/rest_api/test_paginators.py index 064283804..0e2cbfb7c 100644 --- a/tests/rest_api/test_paginators.py +++ b/tests/rest_api/test_paginators.py @@ -40,7 +40,7 @@ def test_not_registering_throws_error(self, custom_paginator_config) -> None: assert e.match("Invalid paginator: custom_paginator.") - def test_registering_adds_to_paginator_map(self, custom_paginator_config) -> None: + def test_registering_adds_to_PAGINATOR_MAP(self, custom_paginator_config) -> None: rest_api.config_setup.register_paginator("custom_paginator", CustomPaginator) cls = rest_api.config_setup.get_paginator_class("custom_paginator") assert cls is CustomPaginator @@ -50,3 +50,11 @@ def test_registering_allows_usage(self, custom_paginator_config) -> None: paginator = rest_api.config_setup.create_paginator(custom_paginator_config) assert paginator.has_next_page is True assert str(paginator.next_url_path) == "response.next_page_link" + + def test_registering_not_base_paginator_throws_error(self) -> None: + class NotAPaginator: + pass + + with pytest.raises(ValueError) as e: + rest_api.config_setup.register_paginator("not_a_paginator", NotAPaginator) + assert e.match("Invalid paginator: NotAPaginator.")