diff --git a/sources/rest_api/__init__.py b/sources/rest_api/__init__.py index 0d841c933..9eb6c2fbd 100644 --- a/sources/rest_api/__init__.py +++ b/sources/rest_api/__init__.py @@ -1,4 +1,5 @@ """Generic API Source""" + from copy import deepcopy from typing import Type, Any, Dict, List, Optional, Generator, Callable, cast, Union import graphlib # type: ignore[import,unused-ignore] @@ -33,6 +34,7 @@ IncrementalParamConfig, RESTAPIConfig, ParamBindType, + ProcessingSteps, ) from .config_setup import ( IncrementalParam, @@ -222,6 +224,7 @@ def create_resources( request_params = endpoint_config.get("params", {}) request_json = endpoint_config.get("json", None) paginator = create_paginator(endpoint_config.get("paginator")) + processing_steps = endpoint_resource.pop("processing_steps", []) resolved_param: ResolvedParam = resolved_param_map[resource_name] @@ -253,6 +256,17 @@ def create_resources( endpoint_resource, {"endpoint", "include_from_parent"} ) + def process( + resource: DltResource, + processing_steps: List[ProcessingSteps], + ) -> Any: + for step in processing_steps: + if "filter" in step: + resource.add_filter(step["filter"]) + if "map" in step: + resource.add_map(step["map"]) + return resource + if resolved_param is None: def paginate_resource( @@ -300,6 +314,9 @@ def paginate_resource( data_selector=endpoint_config.get("data_selector"), hooks=hooks, ) + resources[resource_name] = process( + resources[resource_name], processing_steps + ) else: predecessor = resources[resolved_param.resolve_config["resource"]] @@ -361,6 +378,9 @@ def paginate_dependent_resource( data_selector=endpoint_config.get("data_selector"), hooks=hooks, ) + resources[resource_name] = process( + resources[resource_name], processing_steps + ) return resources @@ -384,7 +404,14 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig: has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS) if ( - isinstance(auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth)) + isinstance( + auth_config, + ( + APIKeyAuth, + BearerTokenAuth, + HttpBasicAuth, + ), + ) or has_sensitive_key ): return _mask_secrets_dict(auth_config) diff --git a/sources/rest_api/typing.py b/sources/rest_api/typing.py index 8926adaaa..3d83570f4 100644 --- a/sources/rest_api/typing.py +++ b/sources/rest_api/typing.py @@ -240,6 +240,11 @@ class Endpoint(TypedDict, total=False): incremental: Optional[IncrementalConfig] +class ProcessingSteps(TypedDict): + filter: Optional[Callable[[Any], bool]] # noqa: A003 + map: Optional[Callable[[Any], Any]] # noqa: A003 + + class ResourceBase(TypedDict, total=False): """Defines hints that may be passed to `dlt.resource` decorator""" @@ -254,6 +259,7 @@ class ResourceBase(TypedDict, total=False): table_format: Optional[TTableHintTemplate[TTableFormat]] selected: Optional[bool] parallelized: Optional[bool] + processing_steps: Optional[List[ProcessingSteps]] class EndpointResourceBase(ResourceBase, total=False): diff --git a/tests/rest_api/test_rest_api_source_processed.py b/tests/rest_api/test_rest_api_source_processed.py new file mode 100644 index 000000000..cc04c27a6 --- /dev/null +++ b/tests/rest_api/test_rest_api_source_processed.py @@ -0,0 +1,254 @@ +from typing import Callable, List + +import dlt +import pytest +from dlt.sources.helpers.rest_client.paginators import SinglePagePaginator +from dlt.extract.source import DltResource +from sources.rest_api import RESTAPIConfig, rest_api_source +from tests.utils import ALL_DESTINATIONS, assert_load_info, load_table_counts + + +def _make_pipeline(destination_name: str): + return dlt.pipeline( + pipeline_name="rest_api", + destination=destination_name, + dataset_name="rest_api_data", + full_refresh=True, + ) + + +def test_rest_api_source_filtered(mock_api_server) -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"filter": lambda x: x["id"] == 1}, + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + assert len(data) == 1 + assert data[0]["title"] == "Post 1" + + +def test_rest_api_source_exclude_columns(mock_api_server) -> None: + + def exclude_columns(columns: List[str]) -> Callable: + def pop_columns(resource: DltResource) -> DltResource: + for col in columns: + resource.pop(col) + return resource + + return pop_columns + + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + { + "map": exclude_columns(["title"]), + }, + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + + assert all("title" not in record for record in data) + + +def test_rest_api_source_anonymize_columns(mock_api_server) -> None: + + def anonymize_columns(columns: List[str]) -> Callable: + def empty_columns(resource: DltResource) -> DltResource: + for col in columns: + resource[col] = "dummy" + return resource + + return empty_columns + + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + { + "map": anonymize_columns(["title"]), + }, + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + + assert all(record["title"] == "dummy" for record in data) + + +def test_rest_api_source_map(mock_api_server) -> None: + + def lower_title(row): + row["title"] = row["title"].lower() + return row + + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"map": lower_title}, + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + + assert all(record["title"].startswith("post ") for record in data) + + +def test_rest_api_source_filter_and_map(mock_api_server) -> None: + + def id_by_10(row): + row["id"] = row["id"] * 10 + return row + + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"map": id_by_10}, + {"filter": lambda x: x["id"] == 10}, + ], + }, + { + "name": "posts_2", + "endpoint": "posts", + "processing_steps": [ + {"filter": lambda x: x["id"] == 10}, + {"map": id_by_10}, + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("posts")) + assert len(data) == 1 + assert data[0]["title"] == "Post 1" + + data = list(mock_source.with_resources("posts_2")) + assert len(data) == 1 + assert data[0]["id"] == 100 + assert data[0]["title"] == "Post 10" + + +def test_rest_api_source_filtered_child(mock_api_server) -> None: + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"filter": lambda x: x["id"] in (1, 2)}, + ], + }, + { + "name": "comments", + "endpoint": { + "path": "/posts/{post_id}/comments", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + } + }, + }, + "processing_steps": [ + {"filter": lambda x: x["id"] == 1}, + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("comments")) + assert len(data) == 2 + + +def test_rest_api_source_filtered_and_map_child(mock_api_server) -> None: + + def extend_body(row): + row["body"] = f"{row['_posts_title']} - {row['body']}" + return row + + config: RESTAPIConfig = { + "client": { + "base_url": "https://api.example.com", + }, + "resources": [ + { + "name": "posts", + "endpoint": "posts", + "processing_steps": [ + {"filter": lambda x: x["id"] in (1, 2)}, + ], + }, + { + "name": "comments", + "endpoint": { + "path": "/posts/{post_id}/comments", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + } + }, + }, + "include_from_parent": ["title"], + "processing_steps": [ + {"map": extend_body}, + {"filter": lambda x: x["body"].startswith("Post 2")}, + ], + }, + ], + } + mock_source = rest_api_source(config) + + data = list(mock_source.with_resources("comments")) + assert data[0]["body"] == "Post 2 - Comment 0 for post 2"