From 9e0147b2870a54d818866fac25f99233d92ab194 Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 10 Oct 2024 19:06:50 +0200 Subject: [PATCH 01/13] create first version of dataset factory --- dlt/__init__.py | 2 + dlt/common/destination/reference.py | 6 +- dlt/destinations/dataset.py | 88 +++++++++++++++++++++++++++-- dlt/destinations/job_client_impl.py | 19 +++++-- dlt/pipeline/pipeline.py | 16 +++--- tests/load/test_read_interfaces.py | 41 ++++++++++++++ 6 files changed, 151 insertions(+), 21 deletions(-) diff --git a/dlt/__init__.py b/dlt/__init__.py index 328817efd2..1dfd17e769 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -42,6 +42,7 @@ ) from dlt.pipeline import progress from dlt import destinations +from dlt.destinations.dataset import dataset pipeline = _pipeline current = _current @@ -79,6 +80,7 @@ "TCredentials", "sources", "destinations", + "dataset", ] # verify that no injection context was created diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 0c572379de..059b487a13 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -66,6 +66,8 @@ TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration") TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase") TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration") +TDatasetType = Literal["dbapi", "ibis"] + DEFAULT_FILE_LAYOUT = "{table_name}/{load_id}.{file_id}.{ext}" @@ -657,8 +659,8 @@ def __exit__( class WithStateSync(ABC): @abstractmethod - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: - """Retrieves newest schema from destination storage""" + def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + """Retrieves newest schema from destination storage, setting any_schema_name to true will return the newest schema regardless of the schema name""" pass @abstractmethod diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index a5584851e9..fcb90e3e66 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -1,9 +1,14 @@ -from typing import Any, Generator, AnyStr, Optional +from typing import Any, Generator, Optional, Union +from dlt.common.json import json from contextlib import contextmanager from dlt.common.destination.reference import ( SupportsReadableRelation, SupportsReadableDataset, + TDatasetType, + TDestinationReferenceArg, + Destination, + JobClientBase, ) from dlt.common.schema.typing import TTableSchemaColumns @@ -71,25 +76,85 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: class ReadableDBAPIDataset(SupportsReadableDataset): """Access to dataframes and arrowtables in the destination dataset via dbapi""" - def __init__(self, client: SqlClientBase[Any], schema: Optional[Schema]) -> None: - self.client = client - self.schema = schema + def __init__( + self, + destination: TDestinationReferenceArg, + dataset_name: str, + schema: Union[Schema, str, None] = None, + ) -> None: + self._destination = Destination.from_reference(destination) + self._schema = schema + self._resolved_schema: Schema = None + self._dataset_name = dataset_name + self._sql_client: SqlClientBase[Any] = None + + def _destination_client(self, schema: Schema) -> JobClientBase: + client_spec = self._destination.spec() + client_spec._bind_dataset_name( + dataset_name=self._dataset_name, default_schema_name=schema.name + ) + return self._destination.client(schema, client_spec) + + def 
_ensure_client_and_schema(self) -> None: + """Lazy load schema and client""" + # full schema given, nothing to do + if not self._resolved_schema and isinstance(self._schema, Schema): + self._resolved_schema = self._schema + + # schema name given, resolve it from destination by name + elif not self._resolved_schema and isinstance(self._schema, str): + with self._destination_client(Schema(self._schema)) as client: + stored_schema = client.get_stored_schema() + if stored_schema: + self._resolved_schema = Schema.from_stored_schema( + json.loads(stored_schema.schema) + ) + + # no schema name given, load newest schema from destination + elif not self._resolved_schema: + with self._destination_client(Schema(self._dataset_name)) as client: + stored_schema = client.get_stored_schema(any_schema_name=True) + if stored_schema: + self._resolved_schema = Schema.from_stored_schema( + json.loads(stored_schema.schema) + ) + + # default to empty schema with dataset name if nothing found + if not self._resolved_schema: + self._resolved_schema = Schema(self._dataset_name) + + # here we create the client bound to the resolved schema + # TODO: ensure that this destination supports the sql_client. otherwise error + if not self._sql_client: + self._sql_client = self._destination_client(self._resolved_schema).sql_client def __call__( self, query: Any, schema_columns: TTableSchemaColumns = None ) -> ReadableDBAPIRelation: schema_columns = schema_columns or {} - return ReadableDBAPIRelation(client=self.client, query=query, schema_columns=schema_columns) # type: ignore[abstract] + return ReadableDBAPIRelation(client=self.sql_client, query=query, schema_columns=schema_columns) # type: ignore[abstract] def table(self, table_name: str) -> SupportsReadableRelation: # prepare query for table relation schema_columns = ( self.schema.tables.get(table_name, {}).get("columns", {}) if self.schema else {} ) - table_name = self.client.make_qualified_table_name(table_name) + table_name = self.sql_client.make_qualified_table_name(table_name) query = f"SELECT * FROM {table_name}" return self(query, schema_columns) + @property + def schema(self) -> Schema: + """Lazy load schema from destination""" + self._ensure_client_and_schema() + return self._resolved_schema + + @property + def sql_client(self) -> SqlClientBase[Any]: + """Lazy instantiate client""" + self._ensure_client_and_schema() + return self._sql_client + def __getitem__(self, table_name: str) -> SupportsReadableRelation: """access of table via dict notation""" return self.table(table_name) @@ -97,3 +162,14 @@ def __getitem__(self, table_name: str) -> SupportsReadableRelation: def __getattr__(self, table_name: str) -> SupportsReadableRelation: """access of table via property notation""" return self.table(table_name) + + +def dataset( + destination: TDestinationReferenceArg, + dataset_name: str, + schema: Union[Schema, str, None] = None, + dataset_type: TDatasetType = "dbapi", +) -> SupportsReadableDataset: + if dataset_type == "dbapi": + return ReadableDBAPIDataset(destination, dataset_name, schema) + raise NotImplementedError(f"Dataset of type {dataset_type} not implemented") diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 0fca64d7ba..90c00530dc 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -397,14 +397,21 @@ def _from_db_type( ) -> TColumnType: pass - def get_stored_schema(self) -> StorageSchemaInfo: + def get_stored_schema(self, any_schema_name: bool = False) -> StorageSchemaInfo: 
name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at") - query = ( - f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s" - f" ORDER BY {c_inserted_at} DESC;" - ) - return self._row_to_schema_info(query, self.schema.name) + if any_schema_name: + query = ( + f"SELECT {self.version_table_schema_columns} FROM {name}" + f" ORDER BY {c_inserted_at} DESC;" + ) + return self._row_to_schema_info(query) + else: + query = ( + f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s" + f" ORDER BY {c_inserted_at} DESC;" + ) + return self._row_to_schema_info(query, self.schema.name) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 348f445967..c9a3950722 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -84,6 +84,7 @@ DestinationClientStagingConfiguration, DestinationClientDwhWithStagingConfiguration, SupportsReadableDataset, + TDatasetType, ) from dlt.common.normalizers.naming import NamingConvention from dlt.common.pipeline import ( @@ -113,7 +114,7 @@ from dlt.destinations.sql_client import SqlClientBase, WithSqlClient from dlt.destinations.fs_client import FSClientBase from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.destinations.dataset import ReadableDBAPIDataset +from dlt.destinations.dataset import dataset from dlt.load.configuration import LoaderConfiguration from dlt.load import Load @@ -1717,10 +1718,11 @@ def __getstate__(self) -> Any: # pickle only the SupportsPipeline protocol fields return {"pipeline_name": self.pipeline_name} - def _dataset(self, dataset_type: Literal["dbapi", "ibis"] = "dbapi") -> SupportsReadableDataset: + def _dataset(self, dataset_type: TDatasetType = "dbapi") -> SupportsReadableDataset: """Access helper to dataset""" - if dataset_type == "dbapi": - return ReadableDBAPIDataset( - self.sql_client(), schema=self.default_schema if self.default_schema_name else None - ) - raise NotImplementedError(f"Dataset of type {dataset_type} not implemented") + return dataset( + self.destination, + self.dataset_name, + schema=(self.default_schema if self.default_schema_name else None), + dataset_type=dataset_type, + ) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index e093e4d670..3aad23b1ae 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -212,6 +212,47 @@ def double_items(): loads_table = pipeline._dataset()[pipeline.default_schema.loads_table_name] loads_table.fetchall() + # check dataset factory + dataset = dlt.dataset( + destination=destination_config.destination_type, dataset_name=pipeline.dataset_name + ) + table_relationship = dataset.items + table = table_relationship.fetchall() + assert len(table) == total_records + + # check that schema is loaded by name + dataset = dlt.dataset( + destination=destination_config.destination_type, + dataset_name=pipeline.dataset_name, + schema=pipeline.default_schema_name, + ) + assert dataset.schema.tables["items"]["write_disposition"] == "replace" + + # check that schema is not loaded when wrong name given + dataset = dlt.dataset( + destination=destination_config.destination_type, + dataset_name=pipeline.dataset_name, + schema="wrong_schema_name", + ) + assert "items" not in 
dataset.schema.tables + assert dataset.schema.name == pipeline.dataset_name + + # check that schema is loaded if no schema name given + dataset = dlt.dataset( + destination=destination_config.destination_type, + dataset_name=pipeline.dataset_name, + ) + assert dataset.schema.name == pipeline.default_schema_name + assert dataset.schema.tables["items"]["write_disposition"] == "replace" + + # check that there is no error when creating dataset without schema table + dataset = dlt.dataset( + destination=destination_config.destination_type, + dataset_name="unknown_dataset", + ) + assert dataset.schema.name == "unknown_dataset" + assert "items" not in dataset.schema.tables + @pytest.mark.essential @pytest.mark.parametrize( From c9ad58f426ee32fd43d13d1a15accbbb9686c672 Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 11 Oct 2024 13:52:37 +0200 Subject: [PATCH 02/13] update all destination implementations for getting the newest schema, fixed linter errors, made dataset aware of config types --- dlt/destinations/dataset.py | 43 ++++++++++++------- .../impl/filesystem/filesystem.py | 8 ++-- .../impl/lancedb/lancedb_client.py | 11 +++-- .../impl/qdrant/qdrant_job_client.py | 21 +++++---- .../impl/sqlalchemy/sqlalchemy_job_client.py | 13 ++++-- .../impl/weaviate/weaviate_client.py | 15 ++++--- tests/load/test_read_interfaces.py | 14 +++--- 7 files changed, 73 insertions(+), 52 deletions(-) diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index fcb90e3e66..666a13cbfc 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -9,10 +9,12 @@ TDestinationReferenceArg, Destination, JobClientBase, + WithStateSync, + DestinationClientDwhConfiguration, ) from dlt.common.schema.typing import TTableSchemaColumns -from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.sql_client import SqlClientBase, WithSqlClient from dlt.common.schema import Schema @@ -90,9 +92,10 @@ def __init__( def _destination_client(self, schema: Schema) -> JobClientBase: client_spec = self._destination.spec() - client_spec._bind_dataset_name( - dataset_name=self._dataset_name, default_schema_name=schema.name - ) + if isinstance(client_spec, DestinationClientDwhConfiguration): + client_spec._bind_dataset_name( + dataset_name=self._dataset_name, default_schema_name=schema.name + ) return self._destination.client(schema, client_spec) def _ensure_client_and_schema(self) -> None: @@ -104,29 +107,37 @@ def _ensure_client_and_schema(self) -> None: # schema name given, resolve it from destination by name elif not self._resolved_schema and isinstance(self._schema, str): with self._destination_client(Schema(self._schema)) as client: - stored_schema = client.get_stored_schema() - if stored_schema: - self._resolved_schema = Schema.from_stored_schema( - json.loads(stored_schema.schema) - ) + if isinstance(client, WithStateSync): + stored_schema = client.get_stored_schema() + if stored_schema: + self._resolved_schema = Schema.from_stored_schema( + json.loads(stored_schema.schema) + ) # no schema name given, load newest schema from destination elif not self._resolved_schema: with self._destination_client(Schema(self._dataset_name)) as client: - stored_schema = client.get_stored_schema(any_schema_name=True) - if stored_schema: - self._resolved_schema = Schema.from_stored_schema( - json.loads(stored_schema.schema) - ) + if isinstance(client, WithStateSync): + stored_schema = client.get_stored_schema(any_schema_name=True) + if stored_schema: + self._resolved_schema = Schema.from_stored_schema( + 
json.loads(stored_schema.schema) + ) # default to empty schema with dataset name if nothing found if not self._resolved_schema: self._resolved_schema = Schema(self._dataset_name) # here we create the client bound to the resolved schema - # TODO: ensure that this destination supports the sql_client. otherwise error if not self._sql_client: - self._sql_client = self._destination_client(self._resolved_schema).sql_client + destination_client = self._destination_client(self._resolved_schema) + if isinstance(destination_client, WithSqlClient): + self._sql_client = destination_client.sql_client + else: + raise Exception( + f"Destination {destination_client.config.destination_type} does not support" + " SqlClient." + ) def __call__( self, query: Any, schema_columns: TTableSchemaColumns = None diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index d6d9865a06..1c2965f84a 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -650,7 +650,7 @@ def _iter_stored_schema_files(self) -> Iterator[Tuple[str, List[str]]]: yield filepath, fileparts def _get_stored_schema_by_hash_or_newest( - self, version_hash: str = None + self, version_hash: str = None, any_schema_name: bool = False ) -> Optional[StorageSchemaInfo]: """Get the schema by supplied hash, falls back to getting the newest version matching the existing schema name""" version_hash = self._to_path_safe_string(version_hash) @@ -660,7 +660,7 @@ def _get_stored_schema_by_hash_or_newest( for filepath, fileparts in self._iter_stored_schema_files(): if ( not version_hash - and fileparts[0] == self.schema.name + and (fileparts[0] == self.schema.name or any_schema_name) and fileparts[1] > newest_load_id ): newest_load_id = fileparts[1] @@ -699,9 +699,9 @@ def _store_current_schema(self) -> None: # we always keep tabs on what the current schema is self._write_to_json_file(filepath, version_info) - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" - return self._get_stored_schema_by_hash_or_newest() + return self._get_stored_schema_by_hash_or_newest(any_schema_name=any_schema_name) def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: return self._get_stored_schema_by_hash_or_newest(version_hash) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index ffa556797e..d0a840f292 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -539,7 +539,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI return None @lancedb_error - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) @@ -553,11 +553,10 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = ( - version_table.search().where( - f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True - ) - ).to_list() + query = version_table.search() + if not any_schema_name: + query = query.where(f'`{p_schema_name}` = 
"{self.schema.name}"', prefilter=True) + schemas = query.to_list() # LanceDB's ORDER BY clause doesn't seem to work. # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 diff --git a/dlt/destinations/impl/qdrant/qdrant_job_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py index 2536bd369d..90b5cc29c0 100644 --- a/dlt/destinations/impl/qdrant/qdrant_job_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -377,23 +377,26 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: raise DestinationUndefinedEntity(str(e)) from e raise - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" try: scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) p_schema_name = self.schema.naming.normalize_identifier("schema_name") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") + + name_filter = models.Filter( + must=[ + models.FieldCondition( + key=p_schema_name, + match=models.MatchValue(value=self.schema.name), + ) + ] + ) + response = self.db_client.scroll( scroll_table_name, with_payload=True, - scroll_filter=models.Filter( - must=[ - models.FieldCondition( - key=p_schema_name, - match=models.MatchValue(value=self.schema.name), - ) - ] - ), + scroll_filter=None if any_schema_name else name_filter, limit=1, order_by=models.OrderBy( key=p_inserted_at, diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index c5a6442d8a..1c39dde239 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -240,7 +240,10 @@ def _update_schema_in_storage(self, schema: Schema) -> None: self.sql_client.execute_sql(table_obj.insert().values(schema_mapping)) def _get_stored_schema( - self, version_hash: Optional[str] = None, schema_name: Optional[str] = None + self, + version_hash: Optional[str] = None, + schema_name: Optional[str] = None, + any_schema_name: bool = False, ) -> Optional[StorageSchemaInfo]: version_table = self.schema.tables[self.schema.version_table_name] table_obj = self._to_table_object(version_table) # type: ignore[arg-type] @@ -249,7 +252,7 @@ def _get_stored_schema( if version_hash is not None: version_hash_col = self.schema.naming.normalize_identifier("version_hash") q = q.where(table_obj.c[version_hash_col] == version_hash) - if schema_name is not None: + if schema_name is not None and not any_schema_name: schema_name_col = self.schema.naming.normalize_identifier("schema_name") q = q.where(table_obj.c[schema_name_col] == schema_name) inserted_at_col = self.schema.naming.normalize_identifier("inserted_at") @@ -267,9 +270,11 @@ def _get_stored_schema( def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: return self._get_stored_schema(version_hash) - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: """Get the latest stored schema""" - return self._get_stored_schema(schema_name=self.schema.name) + return self._get_stored_schema( + schema_name=self.schema.name, any_schema_name=any_schema_name + ) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.schema.tables.get( diff --git 
a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index 76e5fd8b1e..fc6eb8a94b 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -516,19 +516,22 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: if len(load_records): return StateInfo(**state) - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" p_schema_name = self.schema.naming.normalize_identifier("schema_name") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") + + name_filter = { + "path": [p_schema_name], + "operator": "Equal", + "valueString": self.schema.name, + } + try: record = self.get_records( self.schema.version_table_name, sort={"path": [p_inserted_at], "order": "desc"}, - where={ - "path": [p_schema_name], - "operator": "Equal", - "valueString": self.schema.name, - }, + where=None if any_schema_name else name_filter, limit=1, )[0] return StorageSchemaInfo(**record) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index 3aad23b1ae..8850dc409d 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -226,7 +226,7 @@ def double_items(): dataset_name=pipeline.dataset_name, schema=pipeline.default_schema_name, ) - assert dataset.schema.tables["items"]["write_disposition"] == "replace" + assert dataset.schema.tables["items"]["write_disposition"] == "replace" # type: ignore # check that schema is not loaded when wrong name given dataset = dlt.dataset( @@ -234,24 +234,24 @@ def double_items(): dataset_name=pipeline.dataset_name, schema="wrong_schema_name", ) - assert "items" not in dataset.schema.tables - assert dataset.schema.name == pipeline.dataset_name + assert "items" not in dataset.schema.tables # type: ignore + assert dataset.schema.name == pipeline.dataset_name # type: ignore # check that schema is loaded if no schema name given dataset = dlt.dataset( destination=destination_config.destination_type, dataset_name=pipeline.dataset_name, ) - assert dataset.schema.name == pipeline.default_schema_name - assert dataset.schema.tables["items"]["write_disposition"] == "replace" + assert dataset.schema.name == pipeline.default_schema_name # type: ignore + assert dataset.schema.tables["items"]["write_disposition"] == "replace" # type: ignore # check that there is no error when creating dataset without schema table dataset = dlt.dataset( destination=destination_config.destination_type, dataset_name="unknown_dataset", ) - assert dataset.schema.name == "unknown_dataset" - assert "items" not in dataset.schema.tables + assert dataset.schema.name == "unknown_dataset" # type: ignore + assert "items" not in dataset.schema.tables # type: ignore @pytest.mark.essential From a715f1aad2cd1fd37bf28e32b58b22b4922d47bf Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 11 Oct 2024 15:45:04 +0200 Subject: [PATCH 03/13] test retrieval of schema for all destinations (except custom destination) --- tests/load/test_job_client.py | 68 +++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 84d08a5a89..0bb88c4dd3 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -31,6 +31,7 @@ StateInfo, WithStagingDataset, DestinationClientConfiguration, + 
WithStateSync, ) from dlt.common.time import ensure_pendulum_datetime @@ -951,6 +952,73 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: ) +# NOTE: this could be folded into the above tests, but these only run on sql_client destinations for now +# but we want to test filesystem and vector db here too +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, default_vector_configs=True, all_buckets_filesystem_configs=True + ), + ids=lambda x: x.name, +) +def test_schema_retrieval(destination_config: DestinationTestConfiguration) -> None: + p = destination_config.setup_pipeline("schema_test", dev_mode=True) + from dlt.common.schema import utils + + # we create 2 versions of 2 schemas + s1_v1 = Schema("schema_1") + s1_v2 = s1_v1.clone() + s1_v2.tables["items"] = utils.new_table("items") + s2_v1 = Schema("schema_2") + s2_v2 = s2_v1.clone() + s2_v2.tables["other_items"] = utils.new_table("other_items") + + # sanity check + assert s1_v1.version_hash != s1_v2.version_hash + assert s2_v1.version_hash != s2_v2.version_hash + + client: WithStateSync + + def add_schema_to_pipeline(s: Schema) -> None: + p._inject_schema(s) + p.default_schema_name = s.name + with p.destination_client() as client: + client.initialize_storage() + client.update_stored_schema() + + # check what happens if there is only one + add_schema_to_pipeline(s1_v1) + p.default_schema_name = s1_v1.name + with p.destination_client() as client: # type: ignore[assignment] + assert client.get_stored_schema().version_hash == s1_v1.version_hash + assert client.get_stored_schema(any_schema_name=True).version_hash == s1_v1.version_hash + + # now we add a different schema + # but keep default schema name at v1 + add_schema_to_pipeline(s2_v1) + p.default_schema_name = s1_v1.name + with p.destination_client() as client: # type: ignore[assignment] + assert client.get_stored_schema().version_hash == s1_v1.version_hash + # here v2 will be selected as it is newer + assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v1.version_hash + + # add two more version, + add_schema_to_pipeline(s1_v2) + add_schema_to_pipeline(s2_v2) + p.default_schema_name = s1_v1.name + with p.destination_client() as client: # type: ignore[assignment] + assert client.get_stored_schema().version_hash == s1_v2.version_hash + # here v2 will be selected as it is newer + assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v2.version_hash + + # check same setup with other default schema name + p.default_schema_name = s2_v1.name + with p.destination_client() as client: # type: ignore[assignment] + assert client.get_stored_schema().version_hash == s2_v2.version_hash + # here v2 will be selected as it is newer + assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v2.version_hash + + def prepare_schema(client: SqlJobClientBase, case: str) -> Tuple[List[Dict[str, Any]], str]: client.update_stored_schema() rows = load_json_case(case) From 213b89ceeaf5f0feabc715f58ae16924172cb018 Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 11 Oct 2024 16:12:53 +0200 Subject: [PATCH 04/13] add simple tests for schema selection in dataset tests --- tests/load/test_read_interfaces.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index 8850dc409d..fc6e5fdebe 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -216,6 +216,9 @@ def 
double_items(): dataset = dlt.dataset( destination=destination_config.destination_type, dataset_name=pipeline.dataset_name ) + # verfiy that sql client and schema are lazy loaded + assert not dataset._schema + assert not dataset._sql_client table_relationship = dataset.items table = table_relationship.fetchall() assert len(table) == total_records @@ -253,6 +256,25 @@ def double_items(): assert dataset.schema.name == "unknown_dataset" # type: ignore assert "items" not in dataset.schema.tables # type: ignore + # create a newer schema with different name and see wether this is loaded + from dlt.common.schema import Schema + from dlt.common.schema import utils + + other_schema = Schema("some_other_schema") + other_schema.tables["other_table"] = utils.new_table("other_table") + + pipeline._inject_schema(other_schema) + pipeline.default_schema_name = other_schema.name + with pipeline.destination_client() as client: + client.update_stored_schema() + + dataset = dlt.dataset( + destination=destination_config.destination_type, + dataset_name=pipeline.dataset_name, + ) + assert dataset.schema.name == "some_other_schema" # type: ignore + assert "other_table" in dataset.schema.tables # type: ignore + @pytest.mark.essential @pytest.mark.parametrize( From 50a5d479fd58c51f314a5a09ea1862a156d3a754 Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 11 Oct 2024 16:58:32 +0200 Subject: [PATCH 05/13] unify filesystem schema behavior with other destinations --- .../impl/filesystem/filesystem.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 1c2965f84a..bccf8ec686 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -655,24 +655,28 @@ def _get_stored_schema_by_hash_or_newest( """Get the schema by supplied hash, falls back to getting the newest version matching the existing schema name""" version_hash = self._to_path_safe_string(version_hash) # find newest schema for pipeline or by version hash - selected_path = None - newest_load_id = "0" - for filepath, fileparts in self._iter_stored_schema_files(): - if ( - not version_hash - and (fileparts[0] == self.schema.name or any_schema_name) - and fileparts[1] > newest_load_id - ): - newest_load_id = fileparts[1] - selected_path = filepath - elif fileparts[2] == version_hash: - selected_path = filepath - break + try: + selected_path = None + newest_load_id = "0" + for filepath, fileparts in self._iter_stored_schema_files(): + if ( + not version_hash + and (fileparts[0] == self.schema.name or any_schema_name) + and fileparts[1] > newest_load_id + ): + newest_load_id = fileparts[1] + selected_path = filepath + elif fileparts[2] == version_hash: + selected_path = filepath + break - if selected_path: - return StorageSchemaInfo( - **json.loads(self.fs_client.read_text(selected_path, encoding="utf-8")) - ) + if selected_path: + return StorageSchemaInfo( + **json.loads(self.fs_client.read_text(selected_path, encoding="utf-8")) + ) + except DestinationUndefinedEntity: + # ignore missing table + pass return None From d9ab96c7adc495fc6f39cbd44d4d631ad455c7eb Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 11 Oct 2024 17:07:05 +0200 Subject: [PATCH 06/13] fix gcs delta tests --- tests/load/test_read_interfaces.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index 
fc6e5fdebe..54c820337c 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -20,6 +20,7 @@ ) from dlt.destinations import filesystem from tests.utils import TEST_STORAGE_ROOT +from dlt.common.destination.reference import TDestinationReferenceArg def _run_dataset_checks( @@ -212,10 +213,14 @@ def double_items(): loads_table = pipeline._dataset()[pipeline.default_schema.loads_table_name] loads_table.fetchall() - # check dataset factory - dataset = dlt.dataset( - destination=destination_config.destination_type, dataset_name=pipeline.dataset_name + destination_for_dataset: TDestinationReferenceArg = ( + alternate_access_pipeline.destination + if alternate_access_pipeline + else destination_config.destination_type ) + + # check dataset factory + dataset = dlt.dataset(destination=destination_for_dataset, dataset_name=pipeline.dataset_name) # verfiy that sql client and schema are lazy loaded assert not dataset._schema assert not dataset._sql_client @@ -225,7 +230,7 @@ def double_items(): # check that schema is loaded by name dataset = dlt.dataset( - destination=destination_config.destination_type, + destination=destination_for_dataset, dataset_name=pipeline.dataset_name, schema=pipeline.default_schema_name, ) @@ -233,7 +238,7 @@ def double_items(): # check that schema is not loaded when wrong name given dataset = dlt.dataset( - destination=destination_config.destination_type, + destination=destination_for_dataset, dataset_name=pipeline.dataset_name, schema="wrong_schema_name", ) @@ -242,7 +247,7 @@ def double_items(): # check that schema is loaded if no schema name given dataset = dlt.dataset( - destination=destination_config.destination_type, + destination=destination_for_dataset, dataset_name=pipeline.dataset_name, ) assert dataset.schema.name == pipeline.default_schema_name # type: ignore @@ -250,7 +255,7 @@ def double_items(): # check that there is no error when creating dataset without schema table dataset = dlt.dataset( - destination=destination_config.destination_type, + destination=destination_for_dataset, dataset_name="unknown_dataset", ) assert dataset.schema.name == "unknown_dataset" # type: ignore @@ -269,7 +274,7 @@ def double_items(): client.update_stored_schema() dataset = dlt.dataset( - destination=destination_config.destination_type, + destination=destination_for_dataset, dataset_name=pipeline.dataset_name, ) assert dataset.schema.name == "some_other_schema" # type: ignore From c6f178bed90b89730857c95d2026f4cf5b5d0599 Mon Sep 17 00:00:00 2001 From: Dave Date: Sun, 13 Oct 2024 16:59:22 +0200 Subject: [PATCH 07/13] try to fix ci errors --- dlt/destinations/dataset.py | 48 +++++++----------- tests/load/test_read_interfaces.py | 79 +++++++++++++++++++----------- 2 files changed, 67 insertions(+), 60 deletions(-) diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index 666a13cbfc..736dc8e2ed 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -85,10 +85,10 @@ def __init__( schema: Union[Schema, str, None] = None, ) -> None: self._destination = Destination.from_reference(destination) - self._schema = schema - self._resolved_schema: Schema = None + self._provided_schema = schema self._dataset_name = dataset_name - self._sql_client: SqlClientBase[Any] = None + self.sql_client: SqlClientBase[Any] = None + self.schema: Schema = None def _destination_client(self, schema: Schema) -> JobClientBase: client_spec = self._destination.spec() @@ -101,38 +101,34 @@ def _destination_client(self, schema: Schema) -> 
JobClientBase: def _ensure_client_and_schema(self) -> None: """Lazy load schema and client""" # full schema given, nothing to do - if not self._resolved_schema and isinstance(self._schema, Schema): - self._resolved_schema = self._schema + if not self.schema and isinstance(self._provided_schema, Schema): + self.schema = self._provided_schema # schema name given, resolve it from destination by name - elif not self._resolved_schema and isinstance(self._schema, str): - with self._destination_client(Schema(self._schema)) as client: + elif not self.schema and isinstance(self._provided_schema, str): + with self._destination_client(Schema(self._provided_schema)) as client: if isinstance(client, WithStateSync): stored_schema = client.get_stored_schema() if stored_schema: - self._resolved_schema = Schema.from_stored_schema( - json.loads(stored_schema.schema) - ) + self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) # no schema name given, load newest schema from destination - elif not self._resolved_schema: + elif not self.schema: with self._destination_client(Schema(self._dataset_name)) as client: if isinstance(client, WithStateSync): stored_schema = client.get_stored_schema(any_schema_name=True) if stored_schema: - self._resolved_schema = Schema.from_stored_schema( - json.loads(stored_schema.schema) - ) + self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) # default to empty schema with dataset name if nothing found - if not self._resolved_schema: - self._resolved_schema = Schema(self._dataset_name) + if not self.schema: + self.schema = Schema(self._dataset_name) # here we create the client bound to the resolved schema - if not self._sql_client: - destination_client = self._destination_client(self._resolved_schema) + if not self.sql_client: + destination_client = self._destination_client(self.schema) if isinstance(destination_client, WithSqlClient): - self._sql_client = destination_client.sql_client + self.sql_client = destination_client.sql_client else: raise Exception( f"Destination {destination_client.config.destination_type} does not support" @@ -142,11 +138,13 @@ def _ensure_client_and_schema(self) -> None: def __call__( self, query: Any, schema_columns: TTableSchemaColumns = None ) -> ReadableDBAPIRelation: + self._ensure_client_and_schema() schema_columns = schema_columns or {} return ReadableDBAPIRelation(client=self.sql_client, query=query, schema_columns=schema_columns) # type: ignore[abstract] def table(self, table_name: str) -> SupportsReadableRelation: # prepare query for table relation + self._ensure_client_and_schema() schema_columns = ( self.schema.tables.get(table_name, {}).get("columns", {}) if self.schema else {} ) @@ -154,18 +152,6 @@ def table(self, table_name: str) -> SupportsReadableRelation: query = f"SELECT * FROM {table_name}" return self(query, schema_columns) - @property - def schema(self) -> Schema: - """Lazy load schema from destination""" - self._ensure_client_and_schema() - return self._resolved_schema - - @property - def sql_client(self) -> SqlClientBase[Any]: - """Lazy instantiate client""" - self._ensure_client_and_schema() - return self._sql_client - def __getitem__(self, table_name: str) -> SupportsReadableRelation: """access of table via dict notation""" return self.table(table_name) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index 54c820337c..cc35a47540 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -1,4 +1,4 @@ -from typing 
import Any +from typing import Any, cast import pytest import dlt @@ -21,6 +21,7 @@ from dlt.destinations import filesystem from tests.utils import TEST_STORAGE_ROOT from dlt.common.destination.reference import TDestinationReferenceArg +from dlt.destinations.dataset import ReadableDBAPIDataset def _run_dataset_checks( @@ -222,44 +223,60 @@ def double_items(): # check dataset factory dataset = dlt.dataset(destination=destination_for_dataset, dataset_name=pipeline.dataset_name) # verfiy that sql client and schema are lazy loaded - assert not dataset._schema - assert not dataset._sql_client + assert not dataset.schema + assert not dataset.sql_client table_relationship = dataset.items table = table_relationship.fetchall() assert len(table) == total_records # check that schema is loaded by name - dataset = dlt.dataset( - destination=destination_for_dataset, - dataset_name=pipeline.dataset_name, - schema=pipeline.default_schema_name, + dataset = cast( + ReadableDBAPIDataset, + dlt.dataset( + destination=destination_for_dataset, + dataset_name=pipeline.dataset_name, + schema=pipeline.default_schema_name, + ), ) - assert dataset.schema.tables["items"]["write_disposition"] == "replace" # type: ignore + dataset._ensure_client_and_schema() + assert dataset.schema.tables["items"]["write_disposition"] == "replace" # check that schema is not loaded when wrong name given - dataset = dlt.dataset( - destination=destination_for_dataset, - dataset_name=pipeline.dataset_name, - schema="wrong_schema_name", + dataset = cast( + ReadableDBAPIDataset, + dlt.dataset( + destination=destination_for_dataset, + dataset_name=pipeline.dataset_name, + schema="wrong_schema_name", + ), ) - assert "items" not in dataset.schema.tables # type: ignore - assert dataset.schema.name == pipeline.dataset_name # type: ignore + dataset._ensure_client_and_schema() + assert "items" not in dataset.schema.tables + assert dataset.schema.name == pipeline.dataset_name # check that schema is loaded if no schema name given - dataset = dlt.dataset( - destination=destination_for_dataset, - dataset_name=pipeline.dataset_name, + dataset = cast( + ReadableDBAPIDataset, + dlt.dataset( + destination=destination_for_dataset, + dataset_name=pipeline.dataset_name, + ), ) - assert dataset.schema.name == pipeline.default_schema_name # type: ignore - assert dataset.schema.tables["items"]["write_disposition"] == "replace" # type: ignore + dataset._ensure_client_and_schema() + assert dataset.schema.name == pipeline.default_schema_name + assert dataset.schema.tables["items"]["write_disposition"] == "replace" # check that there is no error when creating dataset without schema table - dataset = dlt.dataset( - destination=destination_for_dataset, - dataset_name="unknown_dataset", + dataset = cast( + ReadableDBAPIDataset, + dlt.dataset( + destination=destination_for_dataset, + dataset_name="unknown_dataset", + ), ) - assert dataset.schema.name == "unknown_dataset" # type: ignore - assert "items" not in dataset.schema.tables # type: ignore + dataset._ensure_client_and_schema() + assert dataset.schema.name == "unknown_dataset" + assert "items" not in dataset.schema.tables # create a newer schema with different name and see wether this is loaded from dlt.common.schema import Schema @@ -273,12 +290,16 @@ def double_items(): with pipeline.destination_client() as client: client.update_stored_schema() - dataset = dlt.dataset( - destination=destination_for_dataset, - dataset_name=pipeline.dataset_name, + dataset = cast( + ReadableDBAPIDataset, + dlt.dataset( + 
destination=destination_for_dataset, + dataset_name=pipeline.dataset_name, + ), ) - assert dataset.schema.name == "some_other_schema" # type: ignore - assert "other_table" in dataset.schema.tables # type: ignore + dataset._ensure_client_and_schema() + assert dataset.schema.name == "some_other_schema" + assert "other_table" in dataset.schema.tables @pytest.mark.essential From 1e78212538fe6e2fd5c696c999ffd437957aeeab Mon Sep 17 00:00:00 2001 From: dave Date: Mon, 14 Oct 2024 09:59:29 +0200 Subject: [PATCH 08/13] allow athena in a kind of "read only" mode --- dlt/destinations/impl/athena/athena.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index a2e2566a76..c7e30aaf55 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -318,11 +318,12 @@ def __init__( # verify if staging layout is valid for Athena # this will raise if the table prefix is not properly defined # we actually that {table_name} is first, no {schema_name} is allowed - self.table_prefix_layout = path_utils.get_table_prefix_layout( - config.staging_config.layout, - supported_prefix_placeholders=[], - table_needs_own_folder=True, - ) + if config.staging_config: + self.table_prefix_layout = path_utils.get_table_prefix_layout( + config.staging_config.layout, + supported_prefix_placeholders=[], + table_needs_own_folder=True, + ) sql_client = AthenaSQLClient( config.normalize_dataset_name(schema), From f49c3ca037f971504d48ee132b2c7b47778734dd Mon Sep 17 00:00:00 2001 From: dave Date: Mon, 14 Oct 2024 12:35:57 +0200 Subject: [PATCH 09/13] fix delta table tests? --- tests/load/test_read_interfaces.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index cc35a47540..ab84ef0698 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -367,8 +367,7 @@ def test_delta_tables(destination_config: DestinationTestConfiguration) -> None: os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "700" pipeline = destination_config.setup_pipeline( - "read_pipeline", - dataset_name="read_test", + "read_pipeline", dataset_name="read_test", dev_mode=True ) # in case of gcs we use the s3 compat layer for reading From 9354e8694be2342d295bad7461de7457ce521032 Mon Sep 17 00:00:00 2001 From: dave Date: Mon, 14 Oct 2024 12:38:39 +0200 Subject: [PATCH 10/13] mark dataset factory as private --- dlt/__init__.py | 4 ++-- tests/load/test_read_interfaces.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dlt/__init__.py b/dlt/__init__.py index 1dfd17e769..e8a1b7bf92 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -42,7 +42,7 @@ ) from dlt.pipeline import progress from dlt import destinations -from dlt.destinations.dataset import dataset +from dlt.destinations.dataset import dataset as _dataset pipeline = _pipeline current = _current @@ -80,7 +80,7 @@ "TCredentials", "sources", "destinations", - "dataset", + "_dataset", ] # verify that no injection context was created diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index ab84ef0698..7b467598ff 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -221,7 +221,7 @@ def double_items(): ) # check dataset factory - dataset = dlt.dataset(destination=destination_for_dataset, dataset_name=pipeline.dataset_name) + dataset = 
dlt._dataset(destination=destination_for_dataset, dataset_name=pipeline.dataset_name) # verfiy that sql client and schema are lazy loaded assert not dataset.schema assert not dataset.sql_client @@ -232,7 +232,7 @@ def double_items(): # check that schema is loaded by name dataset = cast( ReadableDBAPIDataset, - dlt.dataset( + dlt._dataset( destination=destination_for_dataset, dataset_name=pipeline.dataset_name, schema=pipeline.default_schema_name, @@ -244,7 +244,7 @@ def double_items(): # check that schema is not loaded when wrong name given dataset = cast( ReadableDBAPIDataset, - dlt.dataset( + dlt._dataset( destination=destination_for_dataset, dataset_name=pipeline.dataset_name, schema="wrong_schema_name", @@ -257,7 +257,7 @@ def double_items(): # check that schema is loaded if no schema name given dataset = cast( ReadableDBAPIDataset, - dlt.dataset( + dlt._dataset( destination=destination_for_dataset, dataset_name=pipeline.dataset_name, ), @@ -269,7 +269,7 @@ def double_items(): # check that there is no error when creating dataset without schema table dataset = cast( ReadableDBAPIDataset, - dlt.dataset( + dlt._dataset( destination=destination_for_dataset, dataset_name="unknown_dataset", ), @@ -292,7 +292,7 @@ def double_items(): dataset = cast( ReadableDBAPIDataset, - dlt.dataset( + dlt._dataset( destination=destination_for_dataset, dataset_name=pipeline.dataset_name, ), From 5f3dbdf2ab5799b93659c417d277dfd7eeecfe61 Mon Sep 17 00:00:00 2001 From: dave Date: Mon, 14 Oct 2024 13:00:04 +0200 Subject: [PATCH 11/13] change signature and behavior of get_stored_schema --- dlt/common/destination/reference.py | 7 ++-- dlt/destinations/dataset.py | 4 +-- .../impl/filesystem/filesystem.py | 8 ++--- .../impl/lancedb/lancedb_client.py | 6 ++-- .../impl/qdrant/qdrant_job_client.py | 22 +++++++------ .../impl/sqlalchemy/sqlalchemy_job_client.py | 9 ++---- .../impl/weaviate/weaviate_client.py | 18 +++++++---- dlt/destinations/job_client_impl.py | 6 ++-- dlt/pipeline/pipeline.py | 2 +- tests/load/lancedb/test_pipeline.py | 2 +- .../load/pipeline/test_filesystem_pipeline.py | 8 ++--- tests/load/pipeline/test_restore_state.py | 4 +-- tests/load/qdrant/test_pipeline.py | 2 +- tests/load/redshift/test_redshift_client.py | 2 +- tests/load/test_job_client.py | 32 +++++++++++-------- tests/load/test_sql_client.py | 6 ++-- tests/load/weaviate/test_pipeline.py | 2 +- 17 files changed, 76 insertions(+), 64 deletions(-) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 059b487a13..8b3819e32b 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -659,8 +659,11 @@ def __exit__( class WithStateSync(ABC): @abstractmethod - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: - """Retrieves newest schema from destination storage, setting any_schema_name to true will return the newest schema regardless of the schema name""" + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: + """ + Retrieves newest schema with given name from destination storage + If no name is provided, the newest schema found is retrieved. 
+ """ pass @abstractmethod diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index 736dc8e2ed..33d2f4aac5 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -108,7 +108,7 @@ def _ensure_client_and_schema(self) -> None: elif not self.schema and isinstance(self._provided_schema, str): with self._destination_client(Schema(self._provided_schema)) as client: if isinstance(client, WithStateSync): - stored_schema = client.get_stored_schema() + stored_schema = client.get_stored_schema(self._provided_schema) if stored_schema: self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) @@ -116,7 +116,7 @@ def _ensure_client_and_schema(self) -> None: elif not self.schema: with self._destination_client(Schema(self._dataset_name)) as client: if isinstance(client, WithStateSync): - stored_schema = client.get_stored_schema(any_schema_name=True) + stored_schema = client.get_stored_schema() if stored_schema: self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index bccf8ec686..0cf63b3ac9 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -650,7 +650,7 @@ def _iter_stored_schema_files(self) -> Iterator[Tuple[str, List[str]]]: yield filepath, fileparts def _get_stored_schema_by_hash_or_newest( - self, version_hash: str = None, any_schema_name: bool = False + self, version_hash: str = None, schema_name: str = None ) -> Optional[StorageSchemaInfo]: """Get the schema by supplied hash, falls back to getting the newest version matching the existing schema name""" version_hash = self._to_path_safe_string(version_hash) @@ -661,7 +661,7 @@ def _get_stored_schema_by_hash_or_newest( for filepath, fileparts in self._iter_stored_schema_files(): if ( not version_hash - and (fileparts[0] == self.schema.name or any_schema_name) + and (fileparts[0] == schema_name or (not schema_name)) and fileparts[1] > newest_load_id ): newest_load_id = fileparts[1] @@ -703,9 +703,9 @@ def _store_current_schema(self) -> None: # we always keep tabs on what the current schema is self._write_to_json_file(filepath, version_info) - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" - return self._get_stored_schema_by_hash_or_newest(any_schema_name=any_schema_name) + return self._get_stored_schema_by_hash_or_newest(schema_name=schema_name) def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: return self._get_stored_schema_by_hash_or_newest(version_hash) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index d0a840f292..8a347989a0 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -539,7 +539,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI return None @lancedb_error - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) @@ -554,8 +554,8 @@ def 
get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSc try: query = version_table.search() - if not any_schema_name: - query = query.where(f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True) + if schema_name: + query = query.where(f'`{p_schema_name}` = "{schema_name}"', prefilter=True) schemas = query.to_list() # LanceDB's ORDER BY clause doesn't seem to work. diff --git a/dlt/destinations/impl/qdrant/qdrant_job_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py index 90b5cc29c0..6c8de52f98 100644 --- a/dlt/destinations/impl/qdrant/qdrant_job_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -377,26 +377,30 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: raise DestinationUndefinedEntity(str(e)) from e raise - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" try: scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) p_schema_name = self.schema.naming.normalize_identifier("schema_name") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") - name_filter = models.Filter( - must=[ - models.FieldCondition( - key=p_schema_name, - match=models.MatchValue(value=self.schema.name), - ) - ] + name_filter = ( + models.Filter( + must=[ + models.FieldCondition( + key=p_schema_name, + match=models.MatchValue(value=schema_name), + ) + ] + ) + if schema_name + else None ) response = self.db_client.scroll( scroll_table_name, with_payload=True, - scroll_filter=None if any_schema_name else name_filter, + scroll_filter=name_filter, limit=1, order_by=models.OrderBy( key=p_inserted_at, diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index 1c39dde239..ab73ecf502 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -243,7 +243,6 @@ def _get_stored_schema( self, version_hash: Optional[str] = None, schema_name: Optional[str] = None, - any_schema_name: bool = False, ) -> Optional[StorageSchemaInfo]: version_table = self.schema.tables[self.schema.version_table_name] table_obj = self._to_table_object(version_table) # type: ignore[arg-type] @@ -252,7 +251,7 @@ def _get_stored_schema( if version_hash is not None: version_hash_col = self.schema.naming.normalize_identifier("version_hash") q = q.where(table_obj.c[version_hash_col] == version_hash) - if schema_name is not None and not any_schema_name: + if schema_name is not None: schema_name_col = self.schema.naming.normalize_identifier("schema_name") q = q.where(table_obj.c[schema_name_col] == schema_name) inserted_at_col = self.schema.naming.normalize_identifier("inserted_at") @@ -270,11 +269,9 @@ def _get_stored_schema( def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: return self._get_stored_schema(version_hash) - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Get the latest stored schema""" - return self._get_stored_schema( - schema_name=self.schema.name, any_schema_name=any_schema_name - ) + return self._get_stored_schema(schema_name=schema_name) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table 
= self.schema.tables.get( diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index fc6eb8a94b..c4b32db20c 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -516,22 +516,26 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: if len(load_records): return StateInfo(**state) - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" p_schema_name = self.schema.naming.normalize_identifier("schema_name") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") - name_filter = { - "path": [p_schema_name], - "operator": "Equal", - "valueString": self.schema.name, - } + name_filter = ( + { + "path": [p_schema_name], + "operator": "Equal", + "valueString": self.schema.name, + } + if schema_name + else None + ) try: record = self.get_records( self.schema.version_table_name, sort={"path": [p_inserted_at], "order": "desc"}, - where=None if any_schema_name else name_filter, + where=name_filter, limit=1, )[0] return StorageSchemaInfo(**record) diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 90c00530dc..fab4d96112 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -397,10 +397,10 @@ def _from_db_type( ) -> TColumnType: pass - def get_stored_schema(self, any_schema_name: bool = False) -> StorageSchemaInfo: + def get_stored_schema(self, schema_name: str = None) -> StorageSchemaInfo: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at") - if any_schema_name: + if not schema_name: query = ( f"SELECT {self.version_table_schema_columns} FROM {name}" f" ORDER BY {c_inserted_at} DESC;" @@ -411,7 +411,7 @@ def get_stored_schema(self, any_schema_name: bool = False) -> StorageSchemaInfo: f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s" f" ORDER BY {c_inserted_at} DESC;" ) - return self._row_to_schema_info(query, self.schema.name) + return self._row_to_schema_info(query, schema_name) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index c9a3950722..5373bfb0cb 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -1547,7 +1547,7 @@ def _get_schemas_from_destination( f" {self.destination.destination_name}" ) return restored_schemas - schema_info = job_client.get_stored_schema() + schema_info = job_client.get_stored_schema(schema_name) if schema_info is None: logger.info( f"The schema {schema.name} was not found in the destination" diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 3dc2a999d4..6cd0abd587 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -75,7 +75,7 @@ def some_data() -> Generator[DictStrStr, Any, None]: client: LanceDBClient with pipeline.destination_client() as client: # type: ignore # Check if we can get a stored schema and state. 
- schema = client.get_stored_schema() + schema = client.get_stored_schema(client.schema.name) print("Print dataset name", client.dataset_name) assert schema state = client.get_stored_state("test_pipeline_append") diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 11e0c88451..b8cf66608c 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -1134,8 +1134,8 @@ def _collect_table_counts(p, *items: str) -> Dict[str, int]: "_dlt_pipeline_state": 2, "_dlt_version": 2, } - sc1_old = c1.get_stored_schema() - sc2_old = c2.get_stored_schema() + sc1_old = c1.get_stored_schema(c1.schema.name) + sc2_old = c2.get_stored_schema(c2.schema.name) s1_old = c1.get_stored_state("p1") s2_old = c1.get_stored_state("p2") @@ -1172,8 +1172,8 @@ def some_data(): assert s2_old.version == s2.version # test accessors for schema - sc1 = c1.get_stored_schema() - sc2 = c2.get_stored_schema() + sc1 = c1.get_stored_schema(c1.schema.name) + sc2 = c2.get_stored_schema(c2.schema.name) assert sc1.version_hash != sc1_old.version_hash assert sc2.version_hash == sc2_old.version_hash assert sc1.version_hash != sc2.version_hash diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 51cb392b29..b78306210f 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -70,7 +70,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - p.sync_schema() # check if schema exists with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] - stored_schema = job_client.get_stored_schema() + stored_schema = job_client.get_stored_schema(job_client.schema.name) assert stored_schema is not None # dataset exists, still no table with pytest.raises(DestinationUndefinedEntity): @@ -97,7 +97,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - # schema.bump_version() p.sync_schema() with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] - stored_schema = job_client.get_stored_schema() + stored_schema = job_client.get_stored_schema(job_client.schema.name) assert stored_schema is not None # table is there but no state assert load_pipeline_state_from_destination(p.pipeline_name, job_client) is None diff --git a/tests/load/qdrant/test_pipeline.py b/tests/load/qdrant/test_pipeline.py index 73f53221ed..48a180ac83 100644 --- a/tests/load/qdrant/test_pipeline.py +++ b/tests/load/qdrant/test_pipeline.py @@ -68,7 +68,7 @@ def some_data(): client: QdrantClient with pipeline.destination_client() as client: # type: ignore[assignment] # check if we can get a stored schema and state - schema = client.get_stored_schema() + schema = client.get_stored_schema(client.schema.name) print("Print dataset name", client.dataset_name) assert schema state = client.get_stored_state("test_pipeline_append") diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index 41287fcd2d..b60c6a8956 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -123,7 +123,7 @@ def test_schema_string_exceeds_max_text_length(client: RedshiftClient) -> None: schema_str = json.dumps(schema.to_dict()) assert len(schema_str.encode("utf-8")) > client.capabilities.max_text_data_type_length client._update_schema_in_storage(schema) - schema_info = 
client.get_stored_schema() + schema_info = client.get_stored_schema(client.schema.name) assert schema_info.schema == schema_str # take base64 from db with client.sql_client.execute_query( diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 0bb88c4dd3..9f64722a1e 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -100,7 +100,7 @@ def test_get_schema_on_empty_storage(naming: str, client: SqlJobClientBase) -> N table_name, table_columns = list(client.get_storage_tables([version_table_name]))[0] assert table_name == version_table_name assert len(table_columns) == 0 - schema_info = client.get_stored_schema() + schema_info = client.get_stored_schema(client.schema.name) assert schema_info is None schema_info = client.get_stored_schema_by_hash("8a0298298823928939") assert schema_info is None @@ -128,7 +128,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: assert [len(table[1]) > 0 for table in storage_tables] == [True, True] # verify if schemas stored this_schema = client.get_stored_schema_by_hash(schema.version_hash) - newest_schema = client.get_stored_schema() + newest_schema = client.get_stored_schema(client.schema.name) # should point to the same schema assert this_schema == newest_schema # check fields @@ -151,7 +151,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: client._update_schema_in_storage(schema) sleep(1) this_schema = client.get_stored_schema_by_hash(schema.version_hash) - newest_schema = client.get_stored_schema() + newest_schema = client.get_stored_schema(client.schema.name) assert this_schema == newest_schema assert this_schema.version == schema.version == 3 assert this_schema.version_hash == schema.stored_version_hash @@ -166,7 +166,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: sleep(1) client._update_schema_in_storage(first_schema) this_schema = client.get_stored_schema_by_hash(first_schema.version_hash) - newest_schema = client.get_stored_schema() + newest_schema = client.get_stored_schema(client.schema.name) assert this_schema == newest_schema # error assert this_schema.version == first_schema.version == 3 assert this_schema.version_hash == first_schema.stored_version_hash @@ -176,17 +176,17 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: # mock other schema in client and get the newest schema. it should not exist... 
client.schema = Schema("ethereum") - assert client.get_stored_schema() is None + assert client.get_stored_schema(client.schema.name) is None client.schema._bump_version() schema_update = client.update_stored_schema() # no schema updates because schema has no tables assert schema_update == {} - that_info = client.get_stored_schema() + that_info = client.get_stored_schema(client.schema.name) assert that_info.schema_name == "ethereum" # get event schema again client.schema = Schema("event") - this_schema = client.get_stored_schema() + this_schema = client.get_stored_schema(client.schema.name) assert this_schema == newest_schema @@ -990,33 +990,37 @@ def add_schema_to_pipeline(s: Schema) -> None: add_schema_to_pipeline(s1_v1) p.default_schema_name = s1_v1.name with p.destination_client() as client: # type: ignore[assignment] + assert client.get_stored_schema("schema_1").version_hash == s1_v1.version_hash assert client.get_stored_schema().version_hash == s1_v1.version_hash - assert client.get_stored_schema(any_schema_name=True).version_hash == s1_v1.version_hash + assert not client.get_stored_schema("other_schema") # now we add a different schema # but keep default schema name at v1 add_schema_to_pipeline(s2_v1) p.default_schema_name = s1_v1.name with p.destination_client() as client: # type: ignore[assignment] - assert client.get_stored_schema().version_hash == s1_v1.version_hash + assert client.get_stored_schema("schema_1").version_hash == s1_v1.version_hash # here v2 will be selected as it is newer - assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v1.version_hash + assert client.get_stored_schema(None).version_hash == s2_v1.version_hash + assert not client.get_stored_schema("other_schema") # add two more version, add_schema_to_pipeline(s1_v2) add_schema_to_pipeline(s2_v2) p.default_schema_name = s1_v1.name with p.destination_client() as client: # type: ignore[assignment] - assert client.get_stored_schema().version_hash == s1_v2.version_hash + assert client.get_stored_schema("schema_1").version_hash == s1_v2.version_hash # here v2 will be selected as it is newer - assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v2.version_hash + assert client.get_stored_schema(None).version_hash == s2_v2.version_hash + assert not client.get_stored_schema("other_schema") # check same setup with other default schema name p.default_schema_name = s2_v1.name with p.destination_client() as client: # type: ignore[assignment] - assert client.get_stored_schema().version_hash == s2_v2.version_hash + assert client.get_stored_schema("schema_2").version_hash == s2_v2.version_hash # here v2 will be selected as it is newer - assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v2.version_hash + assert client.get_stored_schema(None).version_hash == s2_v2.version_hash + assert not client.get_stored_schema("other_schema") def prepare_schema(client: SqlJobClientBase, case: str) -> Tuple[List[Dict[str, Any]], str]: diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 3636b3e53a..0aaa18eac1 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -661,7 +661,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: client.sql_client.execute_sql(sql) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "ProgrammingError") # still can execute dml and selects - assert client.get_stored_schema() is not None + assert client.get_stored_schema(client.schema.name) is not None 
client.complete_load("ABC") assert_load_id(client.sql_client, "ABC") @@ -670,7 +670,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: with pytest.raises(DatabaseTransientException): client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "ProgrammingError") - assert client.get_stored_schema() is not None + assert client.get_stored_schema(client.schema.name) is not None client.complete_load("EFG") assert_load_id(client.sql_client, "EFG") @@ -685,7 +685,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "IntegrityError") # assert isinstance(term_ex.value.dbapi_exception, (psycopg2.InternalError, psycopg2.)) - assert client.get_stored_schema() is not None + assert client.get_stored_schema(client.schema.name) is not None client.complete_load("HJK") assert_load_id(client.sql_client, "HJK") diff --git a/tests/load/weaviate/test_pipeline.py b/tests/load/weaviate/test_pipeline.py index fc46d00d05..6fcb9b7e4f 100644 --- a/tests/load/weaviate/test_pipeline.py +++ b/tests/load/weaviate/test_pipeline.py @@ -72,7 +72,7 @@ def some_data(): client: WeaviateClient with pipeline.destination_client() as client: # type: ignore[assignment] # check if we can get a stored schema and state - schema = client.get_stored_schema() + schema = client.get_stored_schema(client.schema.name) assert schema state = client.get_stored_state("test_pipeline_append") assert state From b52a4a8f9b618d47ab7cd766e648c34af0ca8681 Mon Sep 17 00:00:00 2001 From: dave Date: Mon, 14 Oct 2024 14:53:48 +0200 Subject: [PATCH 12/13] fix weaviate schema retrieval --- dlt/destinations/impl/weaviate/weaviate_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index c4b32db20c..e9d6a76a17 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -525,7 +525,7 @@ def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaIn { "path": [p_schema_name], "operator": "Equal", - "valueString": self.schema.name, + "valueString": schema_name, } if schema_name else None From d367c6884a428d97591529fe8f4b98f5931cde15 Mon Sep 17 00:00:00 2001 From: dave Date: Mon, 14 Oct 2024 15:39:23 +0200 Subject: [PATCH 13/13] switch back to properties --- dlt/destinations/dataset.py | 38 ++++++++++++++++++------------ tests/load/test_read_interfaces.py | 9 ++----- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index 33d2f4aac5..40583c6a9c 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -87,8 +87,18 @@ def __init__( self._destination = Destination.from_reference(destination) self._provided_schema = schema self._dataset_name = dataset_name - self.sql_client: SqlClientBase[Any] = None - self.schema: Schema = None + self._sql_client: SqlClientBase[Any] = None + self._schema: Schema = None + + @property + def schema(self) -> Schema: + self._ensure_client_and_schema() + return self._schema + + @property + def sql_client(self) -> SqlClientBase[Any]: + self._ensure_client_and_schema() + return self._sql_client def _destination_client(self, schema: Schema) -> JobClientBase: client_spec = self._destination.spec() @@ -101,34 +111,34 @@ def _destination_client(self, 
schema: Schema) -> JobClientBase: def _ensure_client_and_schema(self) -> None: """Lazy load schema and client""" # full schema given, nothing to do - if not self.schema and isinstance(self._provided_schema, Schema): - self.schema = self._provided_schema + if not self._schema and isinstance(self._provided_schema, Schema): + self._schema = self._provided_schema # schema name given, resolve it from destination by name - elif not self.schema and isinstance(self._provided_schema, str): + elif not self._schema and isinstance(self._provided_schema, str): with self._destination_client(Schema(self._provided_schema)) as client: if isinstance(client, WithStateSync): stored_schema = client.get_stored_schema(self._provided_schema) if stored_schema: - self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) + self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) # no schema name given, load newest schema from destination - elif not self.schema: + elif not self._schema: with self._destination_client(Schema(self._dataset_name)) as client: if isinstance(client, WithStateSync): stored_schema = client.get_stored_schema() if stored_schema: - self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) + self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) # default to empty schema with dataset name if nothing found - if not self.schema: - self.schema = Schema(self._dataset_name) + if not self._schema: + self._schema = Schema(self._dataset_name) # here we create the client bound to the resolved schema - if not self.sql_client: - destination_client = self._destination_client(self.schema) + if not self._sql_client: + destination_client = self._destination_client(self._schema) if isinstance(destination_client, WithSqlClient): - self.sql_client = destination_client.sql_client + self._sql_client = destination_client.sql_client else: raise Exception( f"Destination {destination_client.config.destination_type} does not support" @@ -138,13 +148,11 @@ def _ensure_client_and_schema(self) -> None: def __call__( self, query: Any, schema_columns: TTableSchemaColumns = None ) -> ReadableDBAPIRelation: - self._ensure_client_and_schema() schema_columns = schema_columns or {} return ReadableDBAPIRelation(client=self.sql_client, query=query, schema_columns=schema_columns) # type: ignore[abstract] def table(self, table_name: str) -> SupportsReadableRelation: # prepare query for table relation - self._ensure_client_and_schema() schema_columns = ( self.schema.tables.get(table_name, {}).get("columns", {}) if self.schema else {} ) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index 7b467598ff..ef73cbd509 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -223,8 +223,8 @@ def double_items(): # check dataset factory dataset = dlt._dataset(destination=destination_for_dataset, dataset_name=pipeline.dataset_name) # verfiy that sql client and schema are lazy loaded - assert not dataset.schema - assert not dataset.sql_client + assert not dataset._schema + assert not dataset._sql_client table_relationship = dataset.items table = table_relationship.fetchall() assert len(table) == total_records @@ -238,7 +238,6 @@ def double_items(): schema=pipeline.default_schema_name, ), ) - dataset._ensure_client_and_schema() assert dataset.schema.tables["items"]["write_disposition"] == "replace" # check that schema is not loaded when wrong name given @@ -250,7 +249,6 @@ def double_items(): 
schema="wrong_schema_name", ), ) - dataset._ensure_client_and_schema() assert "items" not in dataset.schema.tables assert dataset.schema.name == pipeline.dataset_name @@ -262,7 +260,6 @@ def double_items(): dataset_name=pipeline.dataset_name, ), ) - dataset._ensure_client_and_schema() assert dataset.schema.name == pipeline.default_schema_name assert dataset.schema.tables["items"]["write_disposition"] == "replace" @@ -274,7 +271,6 @@ def double_items(): dataset_name="unknown_dataset", ), ) - dataset._ensure_client_and_schema() assert dataset.schema.name == "unknown_dataset" assert "items" not in dataset.schema.tables @@ -297,7 +293,6 @@ def double_items(): dataset_name=pipeline.dataset_name, ), ) - dataset._ensure_client_and_schema() assert dataset.schema.name == "some_other_schema" assert "other_table" in dataset.schema.tables