
Commit

Merge branch 'master' into fix/deprecate_full_refresh
rudolfix authored Oct 1, 2024
2 parents b824c68 + 5312382 commit e8f7dc7
Showing 11 changed files with 209 additions and 52 deletions.
3 changes: 1 addition & 2 deletions sources/filesystem/requirements.txt
@@ -1,2 +1 @@
-dlt>=0.5.1
-openpyxl>=3.0.0
+dlt>=0.5.1, <1
16 changes: 14 additions & 2 deletions sources/mongodb/__init__.py
@@ -1,6 +1,6 @@
"""Source that loads collections form any a mongo database, supports incremental loads."""

from typing import Any, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional

import dlt
from dlt.common.data_writers import TDataItemFormat
@@ -23,6 +23,7 @@ def mongodb(
write_disposition: Optional[str] = dlt.config.value,
parallel: Optional[bool] = dlt.config.value,
limit: Optional[int] = None,
filter_: Optional[Dict[str, Any]] = None,
) -> Iterable[DltResource]:
"""
A DLT source which loads data from a mongo database using PyMongo.
@@ -39,6 +40,7 @@ def mongodb(
limit (Optional[int]):
The maximum number of documents to load. The limit is
applied to each requested collection separately.
filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -64,7 +66,14 @@ def mongodb(
primary_key="_id",
write_disposition=write_disposition,
spec=MongoDbCollectionConfiguration,
)(client, collection, incremental=incremental, parallel=parallel, limit=limit)
)(
client,
collection,
incremental=incremental,
parallel=parallel,
limit=limit,
filter_=filter_ or {},
)


@dlt.common.configuration.with_config(
@@ -80,6 +89,7 @@ def mongodb_collection(
limit: Optional[int] = None,
chunk_size: Optional[int] = 10000,
data_item_format: Optional[TDataItemFormat] = "object",
filter_: Optional[Dict[str, Any]] = None,
) -> Any:
"""
A DLT source which loads a collection from a mongo database using PyMongo.
@@ -98,6 +108,7 @@ def mongodb_collection(
Supported formats:
object - Python objects (dicts, lists).
arrow - Apache Arrow tables.
filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -124,4 +135,5 @@ def mongodb_collection(
limit=limit,
chunk_size=chunk_size,
data_item_format=data_item_format,
filter_=filter_ or {},
)
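
For orientation, here is a minimal usage sketch of the new `filter_` argument (not part of the commit): the connection URL, database, and collection names are placeholders, and the import assumes the pipeline script sits next to the `mongodb` source package, as in the repository's example pipelines.

```python
import dlt

from mongodb import mongodb_collection

pipeline = dlt.pipeline(
    pipeline_name="mongodb_filtered",
    destination="duckdb",
    dataset_name="mongo_select",
)

# Pass a pymongo-style query via `filter_` to narrow what gets loaded.
# Fields already constrained by an `incremental` hint must not be repeated
# here, otherwise the source raises a ValueError (see helpers.py below).
movies = mongodb_collection(
    connection_url="mongodb://localhost:27017",
    database="sample_mflix",
    collection="movies",
    filter_={"runtime": {"$gt": 120}},
)

print(pipeline.run(movies))
```

The `mongodb` source accepts the same argument and applies the filter to every selected collection.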
116 changes: 97 additions & 19 deletions sources/mongodb/helpers.py
@@ -29,6 +29,13 @@
TCollection = Any
TCursor = Any

try:
import pymongoarrow # type: ignore

PYMONGOARROW_AVAILABLE = True
except ImportError:
PYMONGOARROW_AVAILABLE = False


class CollectionLoader:
def __init__(
@@ -100,7 +107,7 @@ def _filter_op(self) -> Dict[str, Any]:

return filt

def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> Cursor: # type: ignore
def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> TCursor: # type: ignore
"""Apply a limit to the cursor, if needed.
Args:
@@ -120,16 +127,23 @@ def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> Cursor: # type

return cursor

def load_documents(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
def load_documents(
self, filter_: Dict[str, Any], limit: Optional[int] = None
) -> Iterator[TDataItem]:
"""Construct the query and load the documents from the collection.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
"""
cursor = self.collection.find(self._filter_op)
filter_op = self._filter_op
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

cursor = self.collection.find(filter=filter_op)
if self._sort_op:
cursor = cursor.sort(self._sort_op)

@@ -157,8 +171,20 @@ def _create_batches(self, limit: Optional[int] = None) -> List[Dict[str, int]]:

return batches

def _get_cursor(self) -> TCursor:
cursor = self.collection.find(filter=self._filter_op)
def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
"""Get a reading cursor for the collection.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
Returns:
Cursor: The cursor for the collection.
"""
filter_op = self._filter_op
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

cursor = self.collection.find(filter=filter_op)
if self._sort_op:
cursor = cursor.sort(self._sort_op)

@@ -174,31 +200,37 @@ def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:

return data

def _get_all_batches(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
def _get_all_batches(
self, filter_: Dict[str, Any], limit: Optional[int] = None
) -> Iterator[TDataItem]:
"""Load all documents from the collection in parallel batches.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The maximum number of documents to load.
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
"""
batches = self._create_batches(limit)
cursor = self._get_cursor()
batches = self._create_batches(limit=limit)
cursor = self._get_cursor(filter_=filter_)

for batch in batches:
yield self._run_batch(cursor=cursor, batch=batch)

def load_documents(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
def load_documents(
self, filter_: Dict[str, Any], limit: Optional[int] = None
) -> Iterator[TDataItem]:
"""Load documents from the collection in parallel.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
"""
for document in self._get_all_batches(limit):
for document in self._get_all_batches(limit=limit, filter_=filter_):
yield document


@@ -208,11 +240,14 @@ class CollectionArrowLoader(CollectionLoader):
Apache Arrow for data processing.
"""

def load_documents(self, limit: Optional[int] = None) -> Iterator[Any]:
def load_documents(
self, filter_: Dict[str, Any], limit: Optional[int] = None
) -> Iterator[Any]:
"""
Load documents from the collection in Apache Arrow format.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
Yields:
@@ -225,9 +260,11 @@ def load_documents(self, limit: Optional[int] = None) -> Iterator[Any]:
None, codec_options=self.collection.codec_options
)

cursor = self.collection.find_raw_batches(
self._filter_op, batch_size=self.chunk_size
)
filter_op = self._filter_op
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

cursor = self.collection.find_raw_batches(filter_op, batch_size=self.chunk_size)
if self._sort_op:
cursor = cursor.sort(self._sort_op) # type: ignore

@@ -246,9 +283,21 @@ class CollectionArrowLoaderParallel(CollectionLoaderParallel):
Apache Arrow for data processing.
"""

def _get_cursor(self) -> TCursor:
def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
"""Get a reading cursor for the collection.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
Returns:
Cursor: The cursor for the collection.
"""
filter_op = self._filter_op
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

cursor = self.collection.find_raw_batches(
filter=self._filter_op, batch_size=self.chunk_size
filter=filter_op, batch_size=self.chunk_size
)
if self._sort_op:
cursor = cursor.sort(self._sort_op) # type: ignore
@@ -276,6 +325,7 @@ def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:
def collection_documents(
client: TMongoClient,
collection: TCollection,
filter_: Dict[str, Any],
incremental: Optional[dlt.sources.incremental[Any]] = None,
parallel: bool = False,
limit: Optional[int] = None,
@@ -289,6 +339,7 @@ def collection_documents(
Args:
client (MongoClient): The PyMongo client `pymongo.MongoClient` instance.
collection (Collection): The collection `pymongo.collection.Collection` to load.
filter_ (Dict[str, Any]): The filter to apply to the collection.
incremental (Optional[dlt.sources.incremental[Any]]): The incremental configuration.
parallel (bool): Option to enable parallel loading for the collection. Default is False.
limit (Optional[int]): The maximum number of documents to load.
@@ -301,21 +352,27 @@ def collection_documents(
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
"""
if data_item_format == "arrow" and not PYMONGOARROW_AVAILABLE:
dlt.common.logger.warn(
"'pymongoarrow' is not installed; falling back to standard MongoDB CollectionLoader."
)
data_item_format = "object"

if parallel:
if data_item_format == "arrow":
LoaderClass = CollectionArrowLoaderParallel
elif data_item_format == "object":
else:
LoaderClass = CollectionLoaderParallel # type: ignore
else:
if data_item_format == "arrow":
LoaderClass = CollectionArrowLoader # type: ignore
elif data_item_format == "object":
else:
LoaderClass = CollectionLoader # type: ignore

loader = LoaderClass(
client, collection, incremental=incremental, chunk_size=chunk_size
)
for data in loader.load_documents(limit=limit):
for data in loader.load_documents(limit=limit, filter_=filter_):
yield data


Expand Down Expand Up @@ -377,6 +434,27 @@ def client_from_credentials(connection_url: str) -> TMongoClient:
return client


def _raise_if_intersection(filter1: Dict[str, Any], filter2: Dict[str, Any]) -> None:
"""
Raise an exception if the given filters' fields intersect.
Args:
filter1 (Dict[str, Any]): The first filter.
filter2 (Dict[str, Any]): The second filter.
"""
field_inter = filter1.keys() & filter2.keys()
for field in field_inter:
if filter1[field].keys() & filter2[field].keys():
str_repr = str({field: filter1[field]})
raise ValueError(
(
f"Filtering operator {str_repr} is already used by the "
"incremental and can't be used in the filter."
)
)
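
A small illustration of how this guard behaves; the field name and dates are made up, with `created_at` standing in for an incremental cursor field:

```python
incremental_filter = {"created_at": {"$gte": "2024-01-01"}}

# A different operator on the same field does not clash with the incremental filter.
_raise_if_intersection(incremental_filter, {"created_at": {"$lt": "2024-06-01"}})

# Re-using the same field *and* operator is rejected.
try:
    _raise_if_intersection(incremental_filter, {"created_at": {"$gte": "2024-06-01"}})
except ValueError as exc:
    print(exc)
    # Filtering operator {'created_at': {'$gte': '2024-01-01'}} is already used
    # by the incremental and can't be used in the filter.
```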


@configspec
class MongoDbCollectionConfiguration(BaseConfiguration):
incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg]
3 changes: 1 addition & 2 deletions sources/mongodb/requirements.txt
@@ -1,3 +1,2 @@
-pymongo>=4.3.3
-pymongoarrow>=1.3.0
+pymongo>=3
dlt>=0.5.1
2 changes: 1 addition & 1 deletion sources/rest_api/README.md
@@ -137,7 +137,7 @@ Possible paginators are:
Usage example of the `JSONLinkPaginator`, for a response with the URL of the next page located at `paging.next`:
```python
"paginator": JSONLinkPaginator(
next_url_path="paging.next"]
next_url_path="paging.next"
)
```
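
For context, a hedged sketch of where that paginator sits in a full `rest_api_source` configuration; the base URL and resource name are placeholders, and the import path is the one used by recent `dlt` releases:

```python
from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator

from rest_api import rest_api_source

source = rest_api_source(
    {
        "client": {
            "base_url": "https://api.example.com/v1/",
            # Follow the URL found at `paging.next` in each JSON response.
            "paginator": JSONLinkPaginator(next_url_path="paging.next"),
        },
        "resources": ["posts"],
    }
)
```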

2 changes: 1 addition & 1 deletion sources/rest_api/requirements.txt
@@ -1 +1 @@
-dlt>=0.5.2
+dlt>=0.5.2, <1
2 changes: 1 addition & 1 deletion sources/sql_database/requirements.txt
@@ -1,2 +1,2 @@
sqlalchemy>=1.4
-dlt>=0.5.1
+dlt>=0.5.1, <1
6 changes: 3 additions & 3 deletions sources/sql_database_pipeline.py
@@ -338,9 +338,9 @@ def specify_columns_to_load() -> None:
)

# Columns can be specified per table in env var (json array) or in `.dlt/config.toml`
os.environ["SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"] = (
'["rfam_acc", "description"]'
)
os.environ[
"SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"
] = '["rfam_acc", "description"]'

sql_alchemy_source = sql_database(
"mysql+pymysql://[email protected]:4497/Rfam?&binary_prefix=true",