diff --git a/sources/filesystem/requirements.txt b/sources/filesystem/requirements.txt
index ef815263f..fd2471c94 100644
--- a/sources/filesystem/requirements.txt
+++ b/sources/filesystem/requirements.txt
@@ -1,2 +1 @@
-dlt>=0.5.1
-openpyxl>=3.0.0
\ No newline at end of file
+dlt>=0.5.1, <1
\ No newline at end of file
diff --git a/sources/mongodb/__init__.py b/sources/mongodb/__init__.py
index 4588f7001..db6b9d054 100644
--- a/sources/mongodb/__init__.py
+++ b/sources/mongodb/__init__.py
@@ -1,6 +1,6 @@
 """Source that loads collections form any a mongo database, supports incremental loads."""
 
-from typing import Any, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
 import dlt
 from dlt.common.data_writers import TDataItemFormat
@@ -23,6 +23,7 @@ def mongodb(
     write_disposition: Optional[str] = dlt.config.value,
     parallel: Optional[bool] = dlt.config.value,
     limit: Optional[int] = None,
+    filter_: Optional[Dict[str, Any]] = None,
 ) -> Iterable[DltResource]:
     """
     A DLT source which loads data from a mongo database using PyMongo.
@@ -39,6 +40,7 @@ def mongodb(
         limit (Optional[int]):
             The maximum number of documents to load. The limit is
             applied to each requested collection separately.
+        filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
 
     Returns:
         Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -64,7 +66,14 @@ def mongodb(
             primary_key="_id",
             write_disposition=write_disposition,
             spec=MongoDbCollectionConfiguration,
-        )(client, collection, incremental=incremental, parallel=parallel, limit=limit)
+        )(
+            client,
+            collection,
+            incremental=incremental,
+            parallel=parallel,
+            limit=limit,
+            filter_=filter_ or {},
+        )
 
 
 @dlt.common.configuration.with_config(
@@ -80,6 +89,7 @@ def mongodb_collection(
     limit: Optional[int] = None,
     chunk_size: Optional[int] = 10000,
     data_item_format: Optional[TDataItemFormat] = "object",
+    filter_: Optional[Dict[str, Any]] = None,
 ) -> Any:
     """
     A DLT source which loads a collection from a mongo database using PyMongo.
@@ -98,6 +108,7 @@ def mongodb_collection(
             Supported formats:
                 object - Python objects (dicts, lists).
                 arrow - Apache Arrow tables.
+        filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
 
     Returns:
         Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -124,4 +135,5 @@ def mongodb_collection(
         limit=limit,
         chunk_size=chunk_size,
         data_item_format=data_item_format,
+        filter_=filter_ or {},
     )
diff --git a/sources/mongodb/helpers.py b/sources/mongodb/helpers.py
index 9769e4fbd..fe5dcc69c 100644
--- a/sources/mongodb/helpers.py
+++ b/sources/mongodb/helpers.py
@@ -29,6 +29,13 @@
 TCollection = Any
 TCursor = Any
 
+try:
+    import pymongoarrow  # type: ignore
+
+    PYMONGOARROW_AVAILABLE = True
+except ImportError:
+    PYMONGOARROW_AVAILABLE = False
+
 
 class CollectionLoader:
     def __init__(
@@ -100,7 +107,7 @@ def _filter_op(self) -> Dict[str, Any]:
 
         return filt
 
-    def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> Cursor:  # type: ignore
+    def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> TCursor:  # type: ignore
         """Apply a limit to the cursor, if needed.
 
         Args:
@@ -120,16 +127,23 @@ def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> Cursor:  # type
 
         return cursor
 
-    def load_documents(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
+    def load_documents(
+        self, filter_: Dict[str, Any], limit: Optional[int] = None
+    ) -> Iterator[TDataItem]:
         """Construct the query and load the documents from the collection.
 
         Args:
+            filter_ (Dict[str, Any]): The filter to apply to the collection.
             limit (Optional[int]): The number of documents to load.
 
         Yields:
             Iterator[TDataItem]: An iterator of the loaded documents.
         """
-        cursor = self.collection.find(self._filter_op)
+        filter_op = self._filter_op
+        _raise_if_intersection(filter_op, filter_)
+        filter_op.update(filter_)
+
+        cursor = self.collection.find(filter=filter_op)
         if self._sort_op:
             cursor = cursor.sort(self._sort_op)
@@ -157,8 +171,20 @@ def _create_batches(self, limit: Optional[int] = None) -> List[Dict[str, int]]:
 
         return batches
 
-    def _get_cursor(self) -> TCursor:
-        cursor = self.collection.find(filter=self._filter_op)
+    def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
+        """Get a reading cursor for the collection.
+
+        Args:
+            filter_ (Dict[str, Any]): The filter to apply to the collection.
+
+        Returns:
+            Cursor: The cursor for the collection.
+        """
+        filter_op = self._filter_op
+        _raise_if_intersection(filter_op, filter_)
+        filter_op.update(filter_)
+
+        cursor = self.collection.find(filter=filter_op)
         if self._sort_op:
             cursor = cursor.sort(self._sort_op)
@@ -174,31 +200,37 @@ def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:
 
         return data
 
-    def _get_all_batches(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
+    def _get_all_batches(
+        self, filter_: Dict[str, Any], limit: Optional[int] = None
+    ) -> Iterator[TDataItem]:
         """Load all documents from the collection in parallel batches.
 
         Args:
+            filter_ (Dict[str, Any]): The filter to apply to the collection.
             limit (Optional[int]): The maximum number of documents to load.
 
         Yields:
             Iterator[TDataItem]: An iterator of the loaded documents.
         """
-        batches = self._create_batches(limit)
-        cursor = self._get_cursor()
+        batches = self._create_batches(limit=limit)
+        cursor = self._get_cursor(filter_=filter_)
 
         for batch in batches:
             yield self._run_batch(cursor=cursor, batch=batch)
 
-    def load_documents(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
+    def load_documents(
+        self, filter_: Dict[str, Any], limit: Optional[int] = None
+    ) -> Iterator[TDataItem]:
         """Load documents from the collection in parallel.
 
         Args:
+            filter_ (Dict[str, Any]): The filter to apply to the collection.
             limit (Optional[int]): The number of documents to load.
 
         Yields:
             Iterator[TDataItem]: An iterator of the loaded documents.
         """
-        for document in self._get_all_batches(limit):
+        for document in self._get_all_batches(limit=limit, filter_=filter_):
             yield document
 
 
@@ -208,11 +240,14 @@ class CollectionArrowLoader(CollectionLoader):
     Apache Arrow for data processing.
     """
 
-    def load_documents(self, limit: Optional[int] = None) -> Iterator[Any]:
+    def load_documents(
+        self, filter_: Dict[str, Any], limit: Optional[int] = None
+    ) -> Iterator[Any]:
         """
         Load documents from the collection in Apache Arrow format.
 
         Args:
+            filter_ (Dict[str, Any]): The filter to apply to the collection.
             limit (Optional[int]): The number of documents to load.
 
         Yields:
@@ -225,9 +260,11 @@ def load_documents(self, limit: Optional[int] = None) -> Iterator[Any]:
             None, codec_options=self.collection.codec_options
         )
 
-        cursor = self.collection.find_raw_batches(
-            self._filter_op, batch_size=self.chunk_size
-        )
+        filter_op = self._filter_op
+        _raise_if_intersection(filter_op, filter_)
+        filter_op.update(filter_)
+
+        cursor = self.collection.find_raw_batches(filter_op, batch_size=self.chunk_size)
 
         if self._sort_op:
             cursor = cursor.sort(self._sort_op)  # type: ignore
@@ -246,9 +283,21 @@ class CollectionArrowLoaderParallel(CollectionLoaderParallel):
     Apache Arrow for data processing.
     """
 
-    def _get_cursor(self) -> TCursor:
+    def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
+        """Get a reading cursor for the collection.
+
+        Args:
+            filter_ (Dict[str, Any]): The filter to apply to the collection.
+
+        Returns:
+            Cursor: The cursor for the collection.
+        """
+        filter_op = self._filter_op
+        _raise_if_intersection(filter_op, filter_)
+        filter_op.update(filter_)
+
         cursor = self.collection.find_raw_batches(
-            filter=self._filter_op, batch_size=self.chunk_size
+            filter=filter_op, batch_size=self.chunk_size
         )
         if self._sort_op:
             cursor = cursor.sort(self._sort_op)  # type: ignore
@@ -276,6 +325,7 @@ def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:
 def collection_documents(
     client: TMongoClient,
     collection: TCollection,
+    filter_: Dict[str, Any],
     incremental: Optional[dlt.sources.incremental[Any]] = None,
     parallel: bool = False,
     limit: Optional[int] = None,
@@ -289,6 +339,7 @@ def collection_documents(
     Args:
         client (MongoClient): The PyMongo client `pymongo.MongoClient` instance.
         collection (Collection): The collection `pymongo.collection.Collection` to load.
+        filter_ (Dict[str, Any]): The filter to apply to the collection.
         incremental (Optional[dlt.sources.incremental[Any]]): The incremental configuration.
         parallel (bool): Option to enable parallel loading for the collection. Default is False.
         limit (Optional[int]): The maximum number of documents to load.
@@ -301,21 +352,27 @@ def collection_documents(
     Returns:
         Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
     """
+    if data_item_format == "arrow" and not PYMONGOARROW_AVAILABLE:
+        dlt.common.logger.warn(
+            "'pymongoarrow' is not installed; falling back to standard MongoDB CollectionLoader."
+        )
+        data_item_format = "object"
+
     if parallel:
         if data_item_format == "arrow":
             LoaderClass = CollectionArrowLoaderParallel
-        elif data_item_format == "object":
+        else:
             LoaderClass = CollectionLoaderParallel  # type: ignore
     else:
         if data_item_format == "arrow":
             LoaderClass = CollectionArrowLoader  # type: ignore
-        elif data_item_format == "object":
+        else:
             LoaderClass = CollectionLoader  # type: ignore
 
     loader = LoaderClass(
         client, collection, incremental=incremental, chunk_size=chunk_size
     )
-    for data in loader.load_documents(limit=limit):
+    for data in loader.load_documents(limit=limit, filter_=filter_):
         yield data
@@ -377,6 +434,27 @@ def client_from_credentials(connection_url: str) -> TMongoClient:
     return client
 
 
+def _raise_if_intersection(filter1: Dict[str, Any], filter2: Dict[str, Any]) -> None:
+    """
+    Raise an exception if the given filters'
+    fields intersect.
+
+    Args:
+        filter1 (Dict[str, Any]): The first filter.
+        filter2 (Dict[str, Any]): The second filter.
+    """
+    field_inter = filter1.keys() & filter2.keys()
+    for field in field_inter:
+        if filter1[field].keys() & filter2[field].keys():
+            str_repr = str({field: filter1[field]})
+            raise ValueError(
+                (
+                    f"Filtering operator {str_repr} is already used by the "
+                    "incremental and can't be used in the filter."
+                )
+            )
+
+
 @configspec
 class MongoDbCollectionConfiguration(BaseConfiguration):
     incremental: Optional[dlt.sources.incremental] = None  # type: ignore[type-arg]
diff --git a/sources/mongodb/requirements.txt b/sources/mongodb/requirements.txt
index 45ac0bc3d..5240a44e2 100644
--- a/sources/mongodb/requirements.txt
+++ b/sources/mongodb/requirements.txt
@@ -1,3 +1,2 @@
-pymongo>=4.3.3
-pymongoarrow>=1.3.0
+pymongo>=3
 dlt>=0.5.1
diff --git a/sources/rest_api/README.md b/sources/rest_api/README.md
index d2878d062..df077cb34 100644
--- a/sources/rest_api/README.md
+++ b/sources/rest_api/README.md
@@ -137,7 +137,7 @@ Possible paginators are:
 Usage example of the `JSONLinkPaginator`, for a response with the URL of the next page located at `paging.next`:
 
 ```python
 "paginator": JSONLinkPaginator(
-    next_url_path="paging.next"]
+    next_url_path="paging.next"
 )
 ```
diff --git a/sources/rest_api/requirements.txt b/sources/rest_api/requirements.txt
index d6077ae57..369e301a7 100644
--- a/sources/rest_api/requirements.txt
+++ b/sources/rest_api/requirements.txt
@@ -1 +1 @@
-dlt>=0.5.2
+dlt>=0.5.2, <1
diff --git a/sources/sql_database/requirements.txt b/sources/sql_database/requirements.txt
index 3bf6b9829..69e47dad5 100644
--- a/sources/sql_database/requirements.txt
+++ b/sources/sql_database/requirements.txt
@@ -1,2 +1,2 @@
 sqlalchemy>=1.4
-dlt>=0.5.1
\ No newline at end of file
+dlt>=0.5.1, <1
\ No newline at end of file
diff --git a/sources/sql_database_pipeline.py b/sources/sql_database_pipeline.py
index 100323a5f..55c25c76b 100644
--- a/sources/sql_database_pipeline.py
+++ b/sources/sql_database_pipeline.py
@@ -338,9 +338,9 @@ def specify_columns_to_load() -> None:
     )
 
     # Columns can be specified per table in env var (json array) or in `.dlt/config.toml`
-    os.environ["SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"] = (
-        '["rfam_acc", "description"]'
-    )
+    os.environ[
+        "SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"
+    ] = '["rfam_acc", "description"]'
 
     sql_alchemy_source = sql_database(
         "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam?&binary_prefix=true",
diff --git a/tests/mongodb/test_mongodb_source.py b/tests/mongodb/test_mongodb_source.py
index 10071b916..b39ce6997 100644
--- a/tests/mongodb/test_mongodb_source.py
+++ b/tests/mongodb/test_mongodb_source.py
@@ -1,11 +1,15 @@
-import bson
 import json
+from unittest import mock
+
+import bson
+import dlt
 import pyarrow
 import pytest
 from pendulum import DateTime, timezone
 from unittest import mock
 
 import dlt
+from dlt.pipeline.exceptions import PipelineStepFailed
 
 from sources.mongodb import mongodb, mongodb_collection
 from sources.mongodb_pipeline import (
@@ -356,3 +360,73 @@ def test_arrow_types(destination_name):
     info = pipeline.run(res, table_name="types_test")
 
     assert info.loads_ids != []
+
+
+@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
+def test_filter(destination_name):
+    """
+    The field `runtime` is not set in some movies, so
+    incremental loading alone will not work. Adding an
+    explicit filter_ that keeps only documents where
+    `runtime` exists makes it work.
+    """
+    pipeline = dlt.pipeline(
+        pipeline_name="mongodb_test",
+        destination=destination_name,
+        dataset_name="mongodb_test_data",
+        full_refresh=True,
+    )
+    movies = mongodb_collection(
+        collection="movies",
+        incremental=dlt.sources.incremental("runtime", initial_value=500),
+        filter_={"runtime": {"$exists": True}},
+    )
+    pipeline.run(movies)
+
+    table_counts = load_table_counts(pipeline, "movies")
+    assert table_counts["movies"] == 23
+
+
+@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
+def test_filter_intersect(destination_name):
+    """
+    Check that using fields in filter_ that are already
+    used by the incremental is not allowed.
+    """
+    pipeline = dlt.pipeline(
+        pipeline_name="mongodb_test",
+        destination=destination_name,
+        dataset_name="mongodb_test_data",
+        full_refresh=True,
+    )
+    movies = mongodb_collection(
+        collection="movies",
+        incremental=dlt.sources.incremental("runtime", initial_value=20),
+        filter_={"runtime": {"$gte": 20}},
+    )
+
+    with pytest.raises(PipelineStepFailed):
+        pipeline.run(movies)
+
+
+@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
+@pytest.mark.parametrize("data_item_format", ["object", "arrow"])
+def test_mongodb_without_pymongoarrow(
+    destination_name: str, data_item_format: str
+) -> None:
+    with mock.patch.dict("sys.modules", {"pymongoarrow": None}):
+        pipeline = dlt.pipeline(
+            pipeline_name="test_mongodb_without_pymongoarrow",
+            destination=destination_name,
+            dataset_name="test_mongodb_without_pymongoarrow_data",
+            full_refresh=True,
+        )
+
+        comments = mongodb_collection(
+            collection="comments", limit=10, data_item_format=data_item_format
+        )
+        load_info = pipeline.run(comments)
+
+        assert load_info.loads_ids != []
+        table_counts = load_table_counts(pipeline, "comments")
+        assert table_counts["comments"] == 10
diff --git a/tests/rest_api/test_rest_api_source_processed.py b/tests/rest_api/test_rest_api_source_processed.py
index c19e16e3c..b0b61131b 100644
--- a/tests/rest_api/test_rest_api_source_processed.py
+++ b/tests/rest_api/test_rest_api_source_processed.py
@@ -40,7 +40,6 @@ def test_rest_api_source_filtered(mock_api_server) -> None:
 
 
 def test_rest_api_source_exclude_columns(mock_api_server) -> None:
-
     def exclude_columns(columns: List[str]) -> Callable:
         def pop_columns(resource: DltResource) -> DltResource:
             for col in columns:
@@ -73,7 +72,6 @@ def pop_columns(resource: DltResource) -> DltResource:
 
 
 def test_rest_api_source_anonymize_columns(mock_api_server) -> None:
-
     def anonymize_columns(columns: List[str]) -> Callable:
         def empty_columns(resource: DltResource) -> DltResource:
             for col in columns:
@@ -106,7 +104,6 @@ def empty_columns(resource: DltResource) -> DltResource:
 
 
 def test_rest_api_source_map(mock_api_server) -> None:
-
     def lower_title(row):
         row["title"] = row["title"].lower()
         return row
@@ -133,7 +130,6 @@ def lower_title(row):
 
 
 def test_rest_api_source_filter_and_map(mock_api_server) -> None:
-
     def id_by_10(row):
         row["id"] = row["id"] * 10
         return row
@@ -211,7 +207,6 @@ def test_rest_api_source_filtered_child(mock_api_server) -> None:
 
 
 def test_rest_api_source_filtered_and_map_child(mock_api_server) -> None:
-
     def extend_body(row):
         row["body"] = f"{row['_posts_title']} - {row['body']}"
         return row
diff --git a/tests/sql_database/test_sql_database_source.py b/tests/sql_database/test_sql_database_source.py
index e7ea70ac6..589e9616b 100644
--- a/tests/sql_database/test_sql_database_source.py
+++ b/tests/sql_database/test_sql_database_source.py
@@ -97,9 +97,9 @@ def test_pass_engine_credentials(sql_source_db: SQLAlchemySourceDB) -> None:
 
 def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None:
     # set the credentials per table name
-    os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"] = (
-        sql_source_db.engine.url.render_as_string(False)
-    )
+    os.environ[
+        "SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"
+    ] = sql_source_db.engine.url.render_as_string(False)
     table = sql_table(table="chat_message", schema=sql_source_db.schema)
     assert table.name == "chat_message"
     assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"]
@@ -119,9 +119,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None:
     assert len(list(table)) == 10
 
     # make it fail on cursor
-    os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = (
-        "updated_at_x"
-    )
+    os.environ[
+        "SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"
+    ] = "updated_at_x"
     table = sql_table(table="chat_message", schema=sql_source_db.schema)
     with pytest.raises(ResourceExtractionError) as ext_ex:
         len(list(table))
@@ -130,9 +130,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None:
 
 def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None:
     # set the credentials per table name
-    os.environ["SOURCES__SQL_DATABASE__CREDENTIALS"] = (
-        sql_source_db.engine.url.render_as_string(False)
-    )
+    os.environ[
+        "SOURCES__SQL_DATABASE__CREDENTIALS"
+    ] = sql_source_db.engine.url.render_as_string(False)
     # applies to both sql table and sql database
     table = sql_table(table="chat_message", schema=sql_source_db.schema)
     assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"]
@@ -155,9 +155,9 @@ def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None:
     assert len(list(database)) == 10
 
     # make it fail on cursor
-    os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = (
-        "updated_at_x"
-    )
+    os.environ[
+        "SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"
+    ] = "updated_at_x"
     table = sql_table(table="chat_message", schema=sql_source_db.schema)
     with pytest.raises(ResourceExtractionError) as ext_ex:
         len(list(table))
@@ -275,9 +275,9 @@ def test_load_sql_table_incremental(
     """Run pipeline twice. Insert more rows after first run
     and ensure only those rows are stored after the second run.
     """
-    os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = (
-        "updated_at"
-    )
+    os.environ[
+        "SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"
+    ] = "updated_at"
     pipeline = make_pipeline(destination_name)
     tables = ["chat_message"]
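A minimal usage sketch of the new `filter_` parameter added above, modeled on the new tests. The MongoDB connection URL and database are assumed to come from `.dlt/secrets.toml`/`.dlt/config.toml`, and the `duckdb` destination and pipeline names are only examples, not part of the change.

```python
import dlt

from sources.mongodb import mongodb_collection

# Keep only documents that actually have a `runtime` field while loading
# incrementally on `runtime`. Keys in `filter_` may target the same field as
# the incremental, but must not repeat the operators the incremental cursor
# builds for it (e.g. "$gte"), otherwise _raise_if_intersection raises ValueError.
movies = mongodb_collection(
    collection="movies",
    incremental=dlt.sources.incremental("runtime", initial_value=500),
    filter_={"runtime": {"$exists": True}},
)

pipeline = dlt.pipeline(
    pipeline_name="mongodb_filter_example",  # example name
    destination="duckdb",  # example destination
    dataset_name="mongodb_data",
)
print(pipeline.run(movies))
```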