Run all refresh tests on local filesystem destination
steinitzu committed May 27, 2024
1 parent c79bddc commit 403ae8b
Showing 2 changed files with 117 additions and 130 deletions.
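The core of the change is mechanical: every refresh test in test_refresh_modes.py is now parametrized over both duckdb and the local filesystem destination, and the assertions are rewritten against destination-agnostic helpers instead of raw SQL cursors. A hedged sketch of the recurring parametrize pattern follows (the test name is a placeholder; the decorator mirrors the hunks below):

# Sketch only: the parametrization pattern repeated throughout this commit.
# "test_refresh_example" is a placeholder name, not part of the commit.
import pytest
from tests.load.utils import destinations_configs, DestinationTestConfiguration

@pytest.mark.parametrize(
    "destination_config",
    destinations_configs(
        default_sql_configs=True, local_filesystem_configs=True, subset=["duckdb", "filesystem"]
    ),
    ids=lambda x: x.name,
)
def test_refresh_example(destination_config: DestinationTestConfiguration):
    ...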
205 changes: 80 additions & 125 deletions tests/load/pipeline/test_refresh_modes.py
@@ -1,17 +1,20 @@
from unittest import mock
from typing import Sequence, Any, List
from typing import Any, List

import pytest
import dlt
from dlt.common.pipeline import resource_state
from dlt.destinations.exceptions import DatabaseUndefinedRelation
from dlt.destinations.sql_client import DBApiCursor
from dlt.pipeline.state_sync import load_pipeline_state_from_destination
from dlt.common.typing import DictStrAny
from dlt.common.pipeline import pipeline_state as current_pipeline_state

from tests.utils import clean_test_storage, preserve_environ
from tests.pipeline.utils import assert_load_info
from tests.pipeline.utils import (
assert_load_info,
load_tables_to_dicts,
assert_only_table_columns,
table_exists,
)
from tests.load.utils import destinations_configs, DestinationTestConfiguration


@@ -22,13 +25,6 @@ def assert_source_state_is_wiped(state: DictStrAny) -> None:
assert not value


def assert_only_table_columns(cursor: DBApiCursor, expected_columns: Sequence[str]) -> None:
"""Table has all and only the expected columns (excluding _dlt columns)"""
# Ignore _dlt columns
columns = [c[0] for c in cursor.native_cursor.description if not c[0].startswith("_")]
assert set(columns) == set(expected_columns)


def column_values(cursor: DBApiCursor, column_name: str) -> List[Any]:
"""Return all values in a column from a cursor"""
idx = [c[0] for c in cursor.native_cursor.description].index(column_name)
@@ -96,7 +92,9 @@ def some_data_4():

@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=["duckdb"]),
destinations_configs(
default_sql_configs=True, subset=["duckdb", "filesystem"], local_filesystem_configs=True
),
ids=lambda x: x.name,
)
def test_refresh_drop_sources(destination_config: DestinationTestConfiguration):
@@ -120,20 +118,15 @@ def test_refresh_drop_sources(destination_config: DestinationTestConfiguration):
"some_data_4",
}

# Confirm resource tables not selected on second run got dropped
with pytest.raises(DatabaseUndefinedRelation):
with pipeline.sql_client() as client:
result = client.execute_sql("SELECT * FROM some_data_3")

with pipeline.sql_client() as client:
with client.execute_query("SELECT * FROM some_data_1 ORDER BY id") as cursor:
# No "name" column should exist as table was dropped and re-created without it
assert_only_table_columns(cursor, ["id"])
result = column_values(cursor, "id")

# Only rows from second run should exist
assert result == [3, 4]
# No "name" column should exist as table was dropped and re-created without it
assert_only_table_columns(pipeline, "some_data_1", ["id"])
data = load_tables_to_dicts(pipeline, "some_data_1")["some_data_1"]
result = sorted([row["id"] for row in data])
# Only rows from second run should exist
assert result == [3, 4]

# Confirm resource tables not selected on second run got dropped
assert not table_exists(pipeline, "some_data_3")
# Loaded state is wiped
with pipeline.destination_client() as dest_client:
destination_state = load_pipeline_state_from_destination(
@@ -144,7 +137,9 @@ def test_refresh_drop_sources(destination_config: DestinationTestConfiguration):

@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=["duckdb"]),
destinations_configs(
default_sql_configs=True, local_filesystem_configs=True, subset=["duckdb", "filesystem"]
),
ids=lambda x: x.name,
)
def test_existing_schema_hash(destination_config: DestinationTestConfiguration):
@@ -174,9 +169,9 @@ def test_existing_schema_hash(destination_config: DestinationTestConfiguration):
# The new schema in this case should match the schema of the first run exactly
info = pipeline.run(refresh_source(first_run=True, drop_sources=True))
# Check table 3 was re-created
with pipeline.sql_client() as client:
result = client.execute_sql("SELECT id, name FROM some_data_3 ORDER BY id")
assert result == [(9, "Jack"), (10, "Jill")]
data = load_tables_to_dicts(pipeline, "some_data_3")["some_data_3"]
result = sorted([(row["id"], row["name"]) for row in data])
assert result == [(9, "Jack"), (10, "Jill")]

# Schema is identical to first schema
new_schema_hash = pipeline.default_schema.version_hash
@@ -185,10 +180,12 @@ def test_existing_schema_hash(destination_config: DestinationTestConfiguration):

@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=["duckdb"]),
destinations_configs(
default_sql_configs=True, local_filesystem_configs=True, subset=["duckdb", "filesystem"]
),
ids=lambda x: x.name,
)
def test_refresh_drop_tables(destination_config: DestinationTestConfiguration):
def test_refresh_drop_resources(destination_config: DestinationTestConfiguration):
# First run pipeline with load to destination so tables are created
pipeline = destination_config.setup_pipeline("refresh_full_test", refresh="drop_tables")

@@ -201,21 +198,16 @@ def test_refresh_drop_tables(destination_config: DestinationTestConfiguration):
)

# Confirm resource tables not selected on second run are untouched
with pipeline.sql_client() as client:
result = client.execute_sql("SELECT id FROM some_data_3 ORDER BY id")
assert result == [(9,), (10,)]

with pipeline.sql_client() as client:
# Check the columns to ensure the name column was dropped
with client.execute_query("SELECT * FROM some_data_1 ORDER BY id") as cursor:
columns = [c[0] for c in cursor.native_cursor.description]
assert "id" in columns
# Second run data contains no "name" column. Table was dropped and re-created so it should not exist
assert "name" not in columns
id_idx = columns.index("id")
result = [row[id_idx] for row in cursor.fetchall()]

assert result == [3, 4]
data = load_tables_to_dicts(pipeline, "some_data_3")["some_data_3"]
result = sorted([(row["id"], row["name"]) for row in data])
assert result == [(9, "Jack"), (10, "Jill")]

# Check the columns to ensure the name column was dropped
assert_only_table_columns(pipeline, "some_data_1", ["id"])
data = load_tables_to_dicts(pipeline, "some_data_1")["some_data_1"]
# Only second run data
result = sorted([row["id"] for row in data])
assert result == [3, 4]

# Loaded state contains only keys created in second run
with pipeline.destination_client() as dest_client:
@@ -236,7 +228,9 @@ def test_refresh_drop_tables(destination_config: DestinationTestConfiguration):

@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=["duckdb"]),
destinations_configs(
default_sql_configs=True, local_filesystem_configs=True, subset=["duckdb", "filesystem"]
),
ids=lambda x: x.name,
)
def test_refresh_drop_data_only(destination_config: DestinationTestConfiguration):
@@ -250,41 +244,35 @@ def test_refresh_drop_data_only(destination_config: DestinationTestConfiguration
first_schema_hash = pipeline.default_schema.version_hash

# Second run of pipeline with only selected resources
# Mock wrap sql client to capture all queries executed
from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient

with mock.patch.object(
DuckDbSqlClient, "execute_query", side_effect=DuckDbSqlClient.execute_query, autospec=True
) as mock_execute_query:
info = pipeline.run(
refresh_source(first_run=False).with_resources("some_data_1", "some_data_2"),
write_disposition="append",
)

info = pipeline.run(
refresh_source(first_run=False).with_resources("some_data_1", "some_data_2"),
write_disposition="append",
)
assert_load_info(info)

# Schema should not be mutated
assert pipeline.default_schema.version_hash == first_schema_hash

all_queries = [k[0][1] for k in mock_execute_query.call_args_list]
assert all_queries
for q in all_queries:
assert "drop table" not in q.lower() # Tables are only truncated, never dropped

# Tables selected in second run are truncated and should only have data from second run
with pipeline.sql_client() as client:
result = client.execute_sql("SELECT id, name FROM some_data_2 ORDER BY id")
# name column still remains when table was truncated instead of dropped
assert result == [(7, None), (8, None)]

with pipeline.sql_client() as client:
result = client.execute_sql("SELECT id, name FROM some_data_1 ORDER BY id")
data = load_tables_to_dicts(pipeline, "some_data_1", "some_data_2", "some_data_3")
# name column still remains when table was truncated instead of dropped
# (except on filesystem where truncate and drop are the same)
if destination_config.destination == "filesystem":
result = sorted([row["id"] for row in data["some_data_1"]])
assert result == [3, 4]

result = sorted([row["id"] for row in data["some_data_2"]])
assert result == [7, 8]
else:
result = sorted([(row["id"], row["name"]) for row in data["some_data_1"]])
assert result == [(3, None), (4, None)]

result = sorted([(row["id"], row["name"]) for row in data["some_data_2"]])
assert result == [(7, None), (8, None)]

# Other tables still have data from first run
with pipeline.sql_client() as client:
result = client.execute_sql("SELECT id, name FROM some_data_3 ORDER BY id")
assert result == [(9, "Jack"), (10, "Jill")]
result = sorted([(row["id"], row["name"]) for row in data["some_data_3"]])
assert result == [(9, "Jack"), (10, "Jill")]

# State of selected resources is wiped, source level state is kept
with pipeline.destination_client() as dest_client:
@@ -374,26 +362,27 @@ def source_2_data_2():
assert table_names == {"source_2_data_1"}

# Destination still has tables from source 1
with pipeline.sql_client() as client:
result = client.execute_sql("SELECT id, name FROM some_data_1 ORDER BY id")
assert result == [(1, "John"), (2, "Jane")]

# First table from source1 exists, with only first column
with pipeline.sql_client() as client:
with client.execute_query("SELECT * FROM source_2_data_1 ORDER BY product") as cursor:
assert_only_table_columns(cursor, ["product"])
result = column_values(cursor, "product")
assert result == ["orange", "pear"]
data = load_tables_to_dicts(pipeline, "some_data_1")
result = sorted([(row["id"], row["name"]) for row in data["some_data_1"]])
assert result == [(1, "John"), (2, "Jane")]

# First table from source2 exists, with only first column
data = load_tables_to_dicts(pipeline, "source_2_data_1", schema_name="refresh_source_2")
assert_only_table_columns(
pipeline, "source_2_data_1", ["product"], schema_name="refresh_source_2"
)
result = sorted([row["product"] for row in data["source_2_data_1"]])
assert result == ["orange", "pear"]

# Second table from source 2 is gone
with pytest.raises(DatabaseUndefinedRelation):
with pipeline.sql_client() as client:
result = client.execute_sql("SELECT * FROM source_2_data_2")
# Second table from source 2 is gone
assert not table_exists(pipeline, "source_2_data_2", schema_name="refresh_source_2")


@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=["duckdb"]),
destinations_configs(
default_sql_configs=True, local_filesystem_configs=True, subset=["duckdb", "filesystem"]
),
ids=lambda x: x.name,
)
def test_refresh_argument_to_run(destination_config: DestinationTestConfiguration):
@@ -423,7 +412,9 @@ def test_refresh_argument_to_run(destination_config: DestinationTestConfiguratio

@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=["duckdb"]),
destinations_configs(
default_sql_configs=True, local_filesystem_configs=True, subset=["duckdb", "filesystem"]
),
ids=lambda x: x.name,
)
def test_refresh_argument_to_extract(destination_config: DestinationTestConfiguration):
@@ -446,39 +437,3 @@ def test_refresh_argument_to_extract(destination_config: DestinationTestConfigur

tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True))
assert tables == {"some_data_2", "some_data_3", "some_data_4"}


@pytest.mark.parametrize(
"destination_config", destinations_configs(local_filesystem_configs=True), ids=lambda x: x.name
)
def test_refresh_drop_sources_local_filesystem(destination_config: DestinationTestConfiguration):
pipeline = destination_config.setup_pipeline("refresh_full_test", refresh="drop_data")

info = pipeline.run(refresh_source(first_run=True, drop_sources=False))
assert_load_info(info)
load_1_id = info.loads_ids[0]

info = pipeline.run(
refresh_source(first_run=False, drop_sources=False).with_resources(
"some_data_1", "some_data_2"
)
)
assert_load_info(info)
load_2_id = info.loads_ids[0]

client = pipeline._fs_client()

# Only contains files from load 2
file_names = client.list_table_files("some_data_1")
assert len(file_names) == 1
assert load_2_id in file_names[0]

# Only contains files from load 2
file_names = client.list_table_files("some_data_2")
assert len(file_names) == 1
assert load_2_id in file_names[0]

# Nothing dropped, only file from load 1
file_names = client.list_table_files("some_data_3")
assert len(file_names) == 1
assert load_1_id in file_names[0]
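The rewritten assertions above all go through the destination-agnostic helpers added to tests/pipeline/utils.py in the second file of this commit. A minimal sketch of that pattern, not part of the commit itself (the pipeline, table, and expected values are placeholders mirroring the refresh_source fixture):

# Sketch of the helper-based assertion style used in the tests above (illustrative only).
import dlt
from tests.pipeline.utils import (
    assert_only_table_columns,
    load_tables_to_dicts,
    table_exists,
)

def assert_refresh_results(pipeline: dlt.Pipeline) -> None:
    # Rows come back as plain dicts, so the same assertions run on duckdb
    # and on the local filesystem destination.
    rows = load_tables_to_dicts(pipeline, "some_data_1")["some_data_1"]
    assert sorted(row["id"] for row in rows) == [3, 4]

    # Column-level check without opening a DBApi cursor.
    assert_only_table_columns(pipeline, "some_data_1", ["id"])

    # Existence check replaces pytest.raises(DatabaseUndefinedRelation).
    assert not table_exists(pipeline, "some_data_3")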
42 changes: 37 additions & 5 deletions tests/pipeline/utils.py
@@ -14,6 +14,7 @@
from dlt.destinations.fs_client import FSClientBase
from dlt.pipeline.exceptions import SqlClientNotAvailable
from dlt.common.storages import FileStorage
from dlt.destinations.exceptions import DatabaseUndefinedRelation

from tests.utils import TEST_STORAGE_ROOT

@@ -172,12 +173,13 @@ def _load_tables_to_dicts_fs(p: dlt.Pipeline, *table_names: str) -> Dict[str, Li


def _load_tables_to_dicts_sql(
p: dlt.Pipeline, *table_names: str
p: dlt.Pipeline, *table_names: str, schema_name: str = None
) -> Dict[str, List[Dict[str, Any]]]:
result = {}
schema = p.default_schema if not schema_name else p.schemas[schema_name]
for table_name in table_names:
table_rows = []
columns = p.default_schema.get_table_columns(table_name).keys()
columns = schema.get_table_columns(table_name).keys()
query_columns = ",".join(map(p.sql_client().capabilities.escape_identifier, columns))

with p.sql_client() as c:
@@ -191,9 +193,23 @@ def _load_tables_to_dicts_sql(
return result


def load_tables_to_dicts(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[Dict[str, Any]]]:
func = _load_tables_to_dicts_fs if _is_filesystem(p) else _load_tables_to_dicts_sql
return func(p, *table_names)
def load_tables_to_dicts(
p: dlt.Pipeline, *table_names: str, schema_name: str = None
) -> Dict[str, List[Dict[str, Any]]]:
if _is_filesystem(p):
return _load_tables_to_dicts_fs(p, *table_names)
return _load_tables_to_dicts_sql(p, *table_names, schema_name=schema_name)


def assert_only_table_columns(
p: dlt.Pipeline, table_name: str, expected_columns: Sequence[str], schema_name: str = None
) -> None:
"""Table has all and only the expected columns (excluding _dlt columns)"""
rows = load_tables_to_dicts(p, table_name, schema_name=schema_name)[table_name]
assert rows, f"Table {table_name} is empty"
# Ignore _dlt columns
columns = set(col for col in rows[0].keys() if not col.startswith("_dlt"))
assert columns == set(expected_columns)


#
@@ -244,6 +260,22 @@ def assert_data_table_counts(p: dlt.Pipeline, expected_counts: DictStrAny) -> No
#


def table_exists(p: dlt.Pipeline, table_name: str, schema_name: str = None) -> bool:
"""Returns True if table exists in the destination database/filesystem"""
if _is_filesystem(p):
client = p._fs_client(schema_name=schema_name)
files = client.list_table_files(table_name)
return not not files

with p.sql_client(schema_name=schema_name) as c:
try:
qual_table_name = c.make_qualified_table_name(table_name)
c.execute_sql(f"SELECT 1 FROM {qual_table_name} LIMIT 1")
return True
except DatabaseUndefinedRelation:
return False
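

# Illustrative usage sketch, not part of this commit: `table_exists` checks for data files
# on a filesystem destination and probes the table with a trivial SELECT on SQL destinations,
# so the refresh tests no longer need pytest.raises(DatabaseUndefinedRelation).
# The pipeline, table, and schema names below are placeholders.
# import dlt
# from tests.pipeline.utils import table_exists
#
# def _example_table_exists_usage(pipeline: dlt.Pipeline) -> None:
#     assert table_exists(pipeline, "source_2_data_1", schema_name="refresh_source_2")
#     assert not table_exists(pipeline, "source_2_data_2", schema_name="refresh_source_2")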


def _assert_table_sql(
p: dlt.Pipeline,
table_name: str,