From 45e386b4862871cf251f1d538e08e01dbf68ccb1 Mon Sep 17 00:00:00 2001
From: Steinthor Palsson
Date: Wed, 11 Sep 2024 19:32:59 -0400
Subject: [PATCH] Update tests destination config

---
 tests/load/pipeline/test_arrow_loading.py                | 4 +++-
 tests/load/pipeline/test_merge_disposition.py            | 5 ++++-
 tests/load/pipeline/test_pipelines.py                    | 2 +-
 tests/load/pipeline/test_stage_loading.py                | 9 ++++++---
 tests/load/sources/filesystem/test_filesystem_source.py  | 4 ++--
 .../test_sql_database_source_all_destinations.py         | 8 ++++----
 tests/load/test_job_client.py                            | 4 ++--
 7 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py
index 369359d61a..5bebf6f7ed 100644
--- a/tests/load/pipeline/test_arrow_loading.py
+++ b/tests/load/pipeline/test_arrow_loading.py
@@ -88,7 +88,9 @@ def some_data():
 
     # use csv for postgres to get native arrow processing
     destination_config.file_format = (
-        destination_config.file_format if destination_config.destination_type != "postgres" else "csv"
+        destination_config.file_format
+        if destination_config.destination_type != "postgres"
+        else "csv"
     )
 
     load_info = pipeline.run(some_data(), **destination_config.run_kwargs)
diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py
index d1082263dd..b1244de336 100644
--- a/tests/load/pipeline/test_merge_disposition.py
+++ b/tests/load/pipeline/test_merge_disposition.py
@@ -496,7 +496,10 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration)
     assert_load_info(info)
     # make sure it was parquet or sql inserts
     files = p.get_load_package_info(p.list_completed_load_packages()[1]).jobs["completed_jobs"]
-    if destination_config.destination == "athena" and destination_config.table_format == "iceberg":
+    if (
+        destination_config.destination_type == "athena"
+        and destination_config.table_format == "iceberg"
+    ):
         # iceberg uses sql to copy tables
         expected_formats.append("sql")
     assert all(f.job_file_info.file_format in expected_formats for f in files)
diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py
index 2a29a0a24d..659bca6cb9 100644
--- a/tests/load/pipeline/test_pipelines.py
+++ b/tests/load/pipeline/test_pipelines.py
@@ -561,7 +561,7 @@ def some_source():
     if destination_config.supports_merge:
         expected_completed_jobs += 1
     # add iceberg copy jobs
-    if destination_config.destination == "athena":
+    if destination_config.destination_type == "athena":
         expected_completed_jobs += 2  # if destination_config.supports_merge else 4
     assert len(package_info.jobs["completed_jobs"]) == expected_completed_jobs
 
diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py
index de30615a6a..cc8175b677 100644
--- a/tests/load/pipeline/test_stage_loading.py
+++ b/tests/load/pipeline/test_stage_loading.py
@@ -231,7 +231,7 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati
     with staging_client:
         # except Athena + Iceberg which does not store tables in staging dataset
         if (
-            destination_config.destination == "athena"
+            destination_config.destination_type == "athena"
             and destination_config.table_format == "iceberg"
         ):
             table_count = 0
@@ -257,7 +257,7 @@
     _, staging_client = pipeline._get_destination_clients(pipeline.default_schema)
     with staging_client:
         # except for Athena which does not delete staging destination tables
-        if destination_config.destination == "athena":
+        if destination_config.destination_type == "athena":
             if destination_config.table_format == "iceberg":
                 table_count = 0
             else:
@@ -302,7 +302,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non
     ):
         # Redshift can't load fixed width binary columns from parquet
         exclude_columns.append("col7_precision")
-    if destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl":
+    if (
+        destination_config.destination_type == "databricks"
+        and destination_config.file_format == "jsonl"
+    ):
         exclude_types.extend(["decimal", "binary", "wei", "json", "date"])
         exclude_columns.append("col1_precision")
 
diff --git a/tests/load/sources/filesystem/test_filesystem_source.py b/tests/load/sources/filesystem/test_filesystem_source.py
index 947e7e9e1c..15a1079cca 100644
--- a/tests/load/sources/filesystem/test_filesystem_source.py
+++ b/tests/load/sources/filesystem/test_filesystem_source.py
@@ -126,7 +126,7 @@ def test_csv_transformers(
 
     # print(pipeline.last_trace.last_normalize_info)
     # must contain 24 rows of A881
-    if not destination_config.destination == "filesystem":
+    if not destination_config.destination_type == "filesystem":
         # TODO: comment out when filesystem destination supports queries (data pond PR)
         assert_query_data(pipeline, "SELECT code FROM met_csv", ["A881"] * 24)
 
@@ -138,7 +138,7 @@
     assert_load_info(load_info)
     # print(pipeline.last_trace.last_normalize_info)
     # must contain 48 rows of A803
-    if not destination_config.destination == "filesystem":
+    if not destination_config.destination_type == "filesystem":
         # TODO: comment out when filesystem destination supports queries (data pond PR)
         assert_query_data(pipeline, "SELECT code FROM met_csv", ["A803"] * 48)
         # and 48 rows in total -> A881 got replaced
diff --git a/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py
index 7012602b4a..4f4e876fb6 100644
--- a/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py
+++ b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py
@@ -51,10 +51,10 @@ def test_load_sql_schema_loads_all_tables(
         schema=sql_source_db.schema,
         backend=backend,
         reflection_level="minimal",
-        type_adapter_callback=default_test_callback(destination_config.destination, backend),
+        type_adapter_callback=default_test_callback(destination_config.destination_type, backend),
     )
 
-    if destination_config.destination == "bigquery" and backend == "connectorx":
+    if destination_config.destination_type == "bigquery" and backend == "connectorx":
         # connectorx generates nanoseconds time which bigquery cannot load
         source.has_precision.add_map(convert_time_to_us)
         source.has_precision_nullable.add_map(convert_time_to_us)
@@ -91,10 +91,10 @@ def test_load_sql_schema_loads_all_tables_parallel(
         schema=sql_source_db.schema,
         backend=backend,
         reflection_level="minimal",
-        type_adapter_callback=default_test_callback(destination_config.destination, backend),
+        type_adapter_callback=default_test_callback(destination_config.destination_type, backend),
     ).parallelize()
 
-    if destination_config.destination == "bigquery" and backend == "connectorx":
+    if destination_config.destination_type == "bigquery" and backend == "connectorx":
         # connectorx generates nanoseconds time which bigquery cannot load
         source.has_precision.add_map(convert_time_to_us)
         source.has_precision_nullable.add_map(convert_time_to_us)
diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py
index ba3baec17c..84d08a5a89 100644
--- a/tests/load/test_job_client.py
+++ b/tests/load/test_job_client.py
@@ -592,9 +592,9 @@ def test_load_with_all_types(
     client.update_stored_schema()
 
     if isinstance(client, WithStagingDataset):
-        should_load_to_staging = client.should_load_data_to_staging_dataset(table_name)  # type: ignore[attr-defined]
+        should_load_to_staging = client.should_load_data_to_staging_dataset(table_name)
         if should_load_to_staging:
-            with client.with_staging_dataset():  # type: ignore[attr-defined]
+            with client.with_staging_dataset():
                 # create staging for merge dataset
                 client.initialize_storage()
                 client.update_stored_schema()