SQL database: included_columns option
steinitzu committed Aug 8, 2024
1 parent c2a49a6 commit d97b5bb
Showing 6 changed files with 138 additions and 34 deletions.
9 changes: 7 additions & 2 deletions sources/sql_database/__init__.py
@@ -111,7 +111,7 @@ def sql_database(
name=table.name,
primary_key=get_primary_key(table),
spec=SqlDatabaseTableConfiguration,
columns=table_to_columns(table, reflection_level, type_adapter_callback),
# columns hint will be set at runtime, after included_columns setting has been resolved
)(
engine,
table,
@@ -142,6 +142,7 @@ def sql_table(
table_adapter_callback: Callable[[Table], None] = None,
backend_kwargs: Dict[str, Any] = None,
type_adapter_callback: Optional[TTypeAdapter] = None,
included_columns: Optional[List[str]] = None,
) -> DltResource:
"""
A dlt resource which loads data from an SQL database table using SQLAlchemy.
@@ -170,6 +171,7 @@ def sql_table(
backend_kwargs (**kwargs): kwargs passed to table backend ie. "conn" is used to pass specialized connection string to connectorx.
type_adapter_callback(Optional[Callable]): Callable to override type inference when reflecting columns.
Argument is a single sqlalchemy data type (`TypeEngine` instance) and it should return another sqlalchemy data type, or `None` (type will be inferred from data)
included_columns (Optional[List[str]]): List of column names to select from the table. If not provided, all columns are loaded.
Returns:
DltResource: The dlt resource for loading data from the SQL database table.
@@ -197,7 +199,9 @@ def sql_table(
table_rows,
name=table_obj.name,
primary_key=get_primary_key(table_obj),
columns=table_to_columns(table_obj, reflection_level, type_adapter_callback),
columns=table_to_columns(
table_obj, reflection_level, type_adapter_callback, included_columns
),
)(
engine,
table_obj,
@@ -209,4 +213,5 @@ def sql_table(
table_adapter_callback=table_adapter_callback,
backend_kwargs=backend_kwargs,
type_adapter_callback=type_adapter_callback,
included_columns=included_columns,
)
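For reference, a minimal usage sketch of the new argument (the connection string, pipeline names, and column list below are illustrative placeholders, not part of this commit):

import dlt
from sources.sql_database import sql_table

# reflect and load only two columns of a hypothetical chat_message table
chat_message = sql_table(
    credentials="mysql+pymysql://user:password@localhost:3306/db",  # placeholder DSN
    table="chat_message",
    included_columns=["id", "created_at"],  # omit to load all columns
)

pipeline = dlt.pipeline(
    pipeline_name="demo", destination="duckdb", dataset_name="demo_data"
)
print(pipeline.run(chat_message))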
33 changes: 21 additions & 12 deletions sources/sql_database/helpers.py
@@ -32,7 +32,7 @@
TTypeAdapter,
)

from sqlalchemy import Table, create_engine
from sqlalchemy import Table, create_engine, select
from sqlalchemy.engine import Engine
from sqlalchemy.exc import CompileError

@@ -74,7 +74,7 @@ def __init__(

def make_query(self) -> SelectAny:
table = self.table
query = table.select()
query = select(*[c for c in table.c if c.name in self.columns])
if not self.incremental:
return query
last_value_func = self.incremental.last_value_func
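Swapping table.select() for an explicit select(...) is what restricts the emitted SQL to the resolved column subset. A standalone SQLAlchemy sketch of the difference (the table definition is illustrative):

import sqlalchemy as sa

metadata = sa.MetaData()
chat_message = sa.Table(
    "chat_message",
    metadata,
    sa.Column("id", sa.Integer),
    sa.Column("content", sa.String),
    sa.Column("created_at", sa.DateTime),
)

included = {"id", "created_at"}
# old behaviour: every column is selected
all_columns = chat_message.select()
# new behaviour: only the included columns appear in the statement
subset = sa.select(*[c for c in chat_message.c if c.name in included])
print(subset)  # SELECT chat_message.id, chat_message.created_at FROM chat_message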
@@ -189,6 +189,7 @@ def table_rows(
reflection_level: ReflectionLevel = "minimal",
backend_kwargs: Dict[str, Any] = None,
type_adapter_callback: Optional[TTypeAdapter] = None,
included_columns: Optional[List[str]] = None,
) -> Iterator[TDataItem]:
columns: TTableSchemaColumns = None
if defer_table_reflect:
@@ -198,24 +199,30 @@ def table_rows(
default_table_adapter(table)
if table_adapter_callback:
table_adapter_callback(table)
columns = table_to_columns(table, reflection_level, type_adapter_callback)
columns = table_to_columns(
table, reflection_level, type_adapter_callback, included_columns
)

# set the primary_key in the incremental
if incremental and incremental.primary_key is None:
primary_key = get_primary_key(table)
if primary_key is not None:
incremental.primary_key = primary_key
# yield empty record to set hints
yield dlt.mark.with_hints(
[],
dlt.mark.make_hints(
primary_key=get_primary_key(table),
columns=columns,
),
)

else:
# table was already reflected
columns = table_to_columns(table, reflection_level, type_adapter_callback)
columns = table_to_columns(
table, reflection_level, type_adapter_callback, included_columns
)

# yield empty record to set hints
yield dlt.mark.with_hints(
[],
dlt.mark.make_hints(
primary_key=get_primary_key(table),
columns=columns,
),
)

loader = TableLoader(
engine, backend, table, columns, incremental=incremental, chunk_size=chunk_size
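The dlt.mark.with_hints call above yields an empty item purely to register schema hints before any rows are extracted. A self-contained sketch of that pattern, assuming an illustrative resource name and column schema:

import dlt

@dlt.resource
def chat_message():
    # the first yield carries no data, only hints, so the primary key and
    # column types reach the schema even if extraction returns zero rows
    yield dlt.mark.with_hints(
        [],
        dlt.mark.make_hints(
            primary_key="id",
            columns={
                "id": {"data_type": "bigint"},
                "created_at": {"data_type": "timestamp"},
            },
        ),
    )
    yield [{"id": 1, "created_at": "2024-08-08T00:00:00+00:00"}]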
@@ -282,6 +289,7 @@ def _detect_precision_hints_deprecated(value: Optional[bool]) -> None:
@configspec
class SqlDatabaseTableConfiguration(BaseConfiguration):
incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg]
included_columns: Optional[List[str]] = None


@configspec
Expand All @@ -295,3 +303,4 @@ class SqlTableResourceConfiguration(BaseConfiguration):
detect_precision_hints: Optional[bool] = None
defer_table_reflect: Optional[bool] = False
reflection_level: Optional[ReflectionLevel] = "full"
included_columns: Optional[List[str]] = None
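Because included_columns is declared on both configspecs, it resolves through dlt's config providers like any other field. A sketch of per-table injection through an environment variable, mirroring the tests further down (the chat_message table name is illustrative):

import os

# JSON array, scoped to the source and table section
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCLUDED_COLUMNS"] = '["id", "created_at"]'

# the equivalent .dlt/config.toml entry would presumably be:
# [sources.sql_database.chat_message]
# included_columns = ["id", "created_at"]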
2 changes: 2 additions & 0 deletions sources/sql_database/schema_types.py
@@ -139,13 +139,15 @@ def table_to_columns(
table: Table,
reflection_level: ReflectionLevel = "full",
type_conversion_fallback: Optional[TTypeAdapter] = None,
included_columns: Optional[List[str]] = None,
) -> TTableSchemaColumns:
"""Convert an sqlalchemy table to a dlt table schema."""
return {
col["name"]: col
for col in (
sqla_col_to_column_schema(c, reflection_level, type_conversion_fallback)
for c in table.columns
if included_columns is None or c.name in included_columns
)
if col is not None
}
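The filtering contract added here: included_columns=None keeps every reflected column, while a list keeps only the named ones. A plain-Python restatement of the guard (dummy names, no SQLAlchemy required):

from typing import List, Optional

def keep(names: List[str], included_columns: Optional[List[str]] = None) -> List[str]:
    # mirrors the `included_columns is None or c.name in included_columns` predicate
    return [n for n in names if included_columns is None or n in included_columns]

assert keep(["id", "content", "created_at"]) == ["id", "content", "created_at"]
assert keep(["id", "content", "created_at"], ["id", "created_at"]) == ["id", "created_at"]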
24 changes: 24 additions & 0 deletions sources/sql_database_pipeline.py
@@ -1,6 +1,7 @@
import sqlalchemy as sa
import humanize
from typing import Any
import os

import dlt
from dlt.common import pendulum
@@ -328,6 +329,29 @@ def type_adapter(sql_type: TypeEngine[Any]) -> TypeEngine[Any]:
print(info)


def specify_columns_to_load() -> None:
"""Run the SQL database source with a subset of table columns loaded"""
pipeline = dlt.pipeline(
pipeline_name="dummy",
destination="postgres",
dataset_name="dummy",
)

# Columns can be specified per table in env var (json array) or in `.dlt/config.toml`
os.environ["SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"] = (
'["rfam_acc", "description"]'
)

sql_alchemy_source = sql_database(
"mysql+pymysql://[email protected]:4497/Rfam?&binary_prefix=true",
backend="pyarrow",
reflection_level="full_with_precision",
).with_resources("family", "genome")

info = pipeline.run(sql_alchemy_source)
print(info)


if __name__ == "__main__":
# Load selected tables with different settings
# load_select_tables_from_database()
9 changes: 5 additions & 4 deletions tests/sql_database/test_helpers.py
@@ -2,6 +2,7 @@

import dlt
from dlt.common.typing import TDataItem
import sqlalchemy as sa

from sources.sql_database.helpers import TableLoader, TableBackend
from sources.sql_database.schema_types import table_to_columns
@@ -49,7 +50,7 @@ class MockIncremental:

query = loader.make_query()
expected = (
table.select()
sa.select(*table.c)
.order_by(table.c.created_at.asc())
.where(table.c.created_at >= MockIncremental.last_value)
)
@@ -79,7 +80,7 @@ class MockIncremental:

query = loader.make_query()
expected = (
table.select()
sa.select(*table.c)
.order_by(table.c.created_at.asc()) # `min` func swaps order
.where(table.c.created_at <= MockIncremental.last_value)
)
@@ -111,7 +112,7 @@ class MockIncremental:

query = loader.make_query()
expected = (
table.select()
sa.select(*table.c)
.where(table.c.created_at <= MockIncremental.last_value)
.where(table.c.created_at > MockIncremental.end_value)
)
@@ -140,7 +141,7 @@ class MockIncremental:
)

query = loader.make_query()
expected = table.select()
expected = sa.select(*table.c)

assert query.compare(expected)

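The assertions above rely on SQLAlchemy's structural comparison of expression trees, which is why the expected queries are rebuilt with sa.select(*table.c) rather than compared as SQL strings. A minimal sketch (hypothetical table):

import sqlalchemy as sa

metadata = sa.MetaData()
t = sa.Table("t", metadata, sa.Column("id", sa.Integer), sa.Column("v", sa.String))

# compare() checks structural equality, so an explicit column list and the
# expanded *t.c form count as the same query
assert sa.select(*t.c).compare(sa.select(t.c.id, t.c.v))
assert not sa.select(t.c.id).compare(sa.select(*t.c))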
95 changes: 79 additions & 16 deletions tests/sql_database/test_sql_database_source.py
@@ -29,6 +29,7 @@
load_table_counts,
load_tables_to_dicts,
assert_schema_on_data,
preserve_environ,
)
from tests.sql_database.sql_source import SQLAlchemySourceDB

@@ -96,9 +97,9 @@ def test_pass_engine_credentials(sql_source_db: SQLAlchemySourceDB) -> None:

def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None:
# set the credentials per table name
os.environ[
"SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"
] = sql_source_db.engine.url.render_as_string(False)
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"] = (
sql_source_db.engine.url.render_as_string(False)
)
table = sql_table(table="chat_message", schema=sql_source_db.schema)
assert table.name == "chat_message"
assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"]
@@ -118,9 +119,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None:
assert len(list(table)) == 10

# make it fail on cursor
os.environ[
"SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"
] = "updated_at_x"
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = (
"updated_at_x"
)
table = sql_table(table="chat_message", schema=sql_source_db.schema)
with pytest.raises(ResourceExtractionError) as ext_ex:
len(list(table))
@@ -129,9 +130,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None:

def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None:
# set the credentials per table name
os.environ[
"SOURCES__SQL_DATABASE__CREDENTIALS"
] = sql_source_db.engine.url.render_as_string(False)
os.environ["SOURCES__SQL_DATABASE__CREDENTIALS"] = (
sql_source_db.engine.url.render_as_string(False)
)
# applies to both sql table and sql database
table = sql_table(table="chat_message", schema=sql_source_db.schema)
assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"]
@@ -154,9 +155,9 @@ def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None:
assert len(list(database)) == 10

# make it fail on cursor
os.environ[
"SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"
] = "updated_at_x"
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = (
"updated_at_x"
)
table = sql_table(table="chat_message", schema=sql_source_db.schema)
with pytest.raises(ResourceExtractionError) as ext_ex:
len(list(table))
@@ -274,9 +275,9 @@ def test_load_sql_table_incremental(
"""Run pipeline twice. Insert more rows after first run
and ensure only those rows are stored after the second run.
"""
os.environ[
"SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"
] = "updated_at"
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = (
"updated_at"
)

pipeline = make_pipeline(destination_name)
tables = ["chat_message"]
@@ -821,7 +822,7 @@ def _assert_incremental(item):
# assert _r.incremental._incremental is updated_at
if len(item) == 0:
# not yet propagated
assert _r.incremental.primary_key == ["id"]
assert _r.incremental.primary_key is None
else:
assert _r.incremental.primary_key == ["id"]
assert _r.incremental._incremental.primary_key == ["id"]
@@ -1161,6 +1162,68 @@ def dummy_source():
assert isinstance(json.loads(rows[0]["unsupported_array_1"]), list)


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
@pytest.mark.parametrize("defer_table_reflect", (False, True))
def test_sql_database_included_columns(
sql_source_db: SQLAlchemySourceDB, backend: TableBackend, defer_table_reflect: bool
) -> None:
# include only some columns from the table
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCLUDED_COLUMNS"] = json.dumps(
["id", "created_at"]
)

source = sql_database(
credentials=sql_source_db.credentials,
schema=sql_source_db.schema,
table_names=["chat_message"],
reflection_level="full",
defer_table_reflect=defer_table_reflect,
backend=backend,
)

pipeline = make_pipeline("duckdb")
pipeline.run(source)

schema = pipeline.default_schema
schema_cols = set(
col
for col in schema.get_table_columns("chat_message", include_incomplete=True)
if not col.startswith("_dlt_")
)
assert schema_cols == {"id", "created_at"}

assert_row_counts(pipeline, sql_source_db, ["chat_message"])


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
@pytest.mark.parametrize("defer_table_reflect", (False, True))
def test_sql_table_included_columns(
sql_source_db: SQLAlchemySourceDB, backend: TableBackend, defer_table_reflect: bool
) -> None:
source = sql_table(
credentials=sql_source_db.credentials,
schema=sql_source_db.schema,
table="chat_message",
reflection_level="full",
defer_table_reflect=defer_table_reflect,
backend=backend,
included_columns=["id", "created_at"],
)

pipeline = make_pipeline("duckdb")
pipeline.run(source)

schema = pipeline.default_schema
schema_cols = set(
col
for col in schema.get_table_columns("chat_message", include_incomplete=True)
if not col.startswith("_dlt_")
)
assert schema_cols == {"id", "created_at"}

assert_row_counts(pipeline, sql_source_db, ["chat_message"])


def assert_row_counts(
pipeline: dlt.Pipeline,
sql_source_db: SQLAlchemySourceDB,
