diff --git a/dlt/__init__.py b/dlt/__init__.py index 328817efd2..e8a1b7bf92 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -42,6 +42,7 @@ ) from dlt.pipeline import progress from dlt import destinations +from dlt.destinations.dataset import dataset as _dataset pipeline = _pipeline current = _current @@ -79,6 +80,7 @@ "TCredentials", "sources", "destinations", + "_dataset", ] # verify that no injection context was created diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 0c572379de..8b3819e32b 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -66,6 +66,8 @@ TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration") TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase") TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration") +TDatasetType = Literal["dbapi", "ibis"] + DEFAULT_FILE_LAYOUT = "{table_name}/{load_id}.{file_id}.{ext}" @@ -657,8 +659,11 @@ def __exit__( class WithStateSync(ABC): @abstractmethod - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: - """Retrieves newest schema from destination storage""" + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: + """ + Retrieves newest schema with given name from destination storage + If no name is provided, the newest schema found is retrieved. + """ pass @abstractmethod diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index a5584851e9..40583c6a9c 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -1,13 +1,20 @@ -from typing import Any, Generator, AnyStr, Optional +from typing import Any, Generator, Optional, Union +from dlt.common.json import json from contextlib import contextmanager from dlt.common.destination.reference import ( SupportsReadableRelation, SupportsReadableDataset, + TDatasetType, + TDestinationReferenceArg, + Destination, + JobClientBase, + WithStateSync, + DestinationClientDwhConfiguration, ) from dlt.common.schema.typing import TTableSchemaColumns -from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.sql_client import SqlClientBase, WithSqlClient from dlt.common.schema import Schema @@ -71,22 +78,85 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: class ReadableDBAPIDataset(SupportsReadableDataset): """Access to dataframes and arrowtables in the destination dataset via dbapi""" - def __init__(self, client: SqlClientBase[Any], schema: Optional[Schema]) -> None: - self.client = client - self.schema = schema + def __init__( + self, + destination: TDestinationReferenceArg, + dataset_name: str, + schema: Union[Schema, str, None] = None, + ) -> None: + self._destination = Destination.from_reference(destination) + self._provided_schema = schema + self._dataset_name = dataset_name + self._sql_client: SqlClientBase[Any] = None + self._schema: Schema = None + + @property + def schema(self) -> Schema: + self._ensure_client_and_schema() + return self._schema + + @property + def sql_client(self) -> SqlClientBase[Any]: + self._ensure_client_and_schema() + return self._sql_client + + def _destination_client(self, schema: Schema) -> JobClientBase: + client_spec = self._destination.spec() + if isinstance(client_spec, DestinationClientDwhConfiguration): + client_spec._bind_dataset_name( + dataset_name=self._dataset_name, default_schema_name=schema.name + ) + return self._destination.client(schema, client_spec) + + def 
_ensure_client_and_schema(self) -> None: + """Lazy load schema and client""" + # full schema given, nothing to do + if not self._schema and isinstance(self._provided_schema, Schema): + self._schema = self._provided_schema + + # schema name given, resolve it from destination by name + elif not self._schema and isinstance(self._provided_schema, str): + with self._destination_client(Schema(self._provided_schema)) as client: + if isinstance(client, WithStateSync): + stored_schema = client.get_stored_schema(self._provided_schema) + if stored_schema: + self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) + + # no schema name given, load newest schema from destination + elif not self._schema: + with self._destination_client(Schema(self._dataset_name)) as client: + if isinstance(client, WithStateSync): + stored_schema = client.get_stored_schema() + if stored_schema: + self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) + + # default to empty schema with dataset name if nothing found + if not self._schema: + self._schema = Schema(self._dataset_name) + + # here we create the client bound to the resolved schema + if not self._sql_client: + destination_client = self._destination_client(self._schema) + if isinstance(destination_client, WithSqlClient): + self._sql_client = destination_client.sql_client + else: + raise Exception( + f"Destination {destination_client.config.destination_type} does not support" + " SqlClient." + ) def __call__( self, query: Any, schema_columns: TTableSchemaColumns = None ) -> ReadableDBAPIRelation: schema_columns = schema_columns or {} - return ReadableDBAPIRelation(client=self.client, query=query, schema_columns=schema_columns) # type: ignore[abstract] + return ReadableDBAPIRelation(client=self.sql_client, query=query, schema_columns=schema_columns) # type: ignore[abstract] def table(self, table_name: str) -> SupportsReadableRelation: # prepare query for table relation schema_columns = ( self.schema.tables.get(table_name, {}).get("columns", {}) if self.schema else {} ) - table_name = self.client.make_qualified_table_name(table_name) + table_name = self.sql_client.make_qualified_table_name(table_name) query = f"SELECT * FROM {table_name}" return self(query, schema_columns) @@ -97,3 +167,14 @@ def __getitem__(self, table_name: str) -> SupportsReadableRelation: def __getattr__(self, table_name: str) -> SupportsReadableRelation: """access of table via property notation""" return self.table(table_name) + + +def dataset( + destination: TDestinationReferenceArg, + dataset_name: str, + schema: Union[Schema, str, None] = None, + dataset_type: TDatasetType = "dbapi", +) -> SupportsReadableDataset: + if dataset_type == "dbapi": + return ReadableDBAPIDataset(destination, dataset_name, schema) + raise NotImplementedError(f"Dataset of type {dataset_type} not implemented") diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index a2e2566a76..c7e30aaf55 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -318,11 +318,12 @@ def __init__( # verify if staging layout is valid for Athena # this will raise if the table prefix is not properly defined # we actually that {table_name} is first, no {schema_name} is allowed - self.table_prefix_layout = path_utils.get_table_prefix_layout( - config.staging_config.layout, - supported_prefix_placeholders=[], - table_needs_own_folder=True, - ) + if config.staging_config: + self.table_prefix_layout = 
path_utils.get_table_prefix_layout( + config.staging_config.layout, + supported_prefix_placeholders=[], + table_needs_own_folder=True, + ) sql_client = AthenaSQLClient( config.normalize_dataset_name(schema), diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index d6d9865a06..0cf63b3ac9 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -650,29 +650,33 @@ def _iter_stored_schema_files(self) -> Iterator[Tuple[str, List[str]]]: yield filepath, fileparts def _get_stored_schema_by_hash_or_newest( - self, version_hash: str = None + self, version_hash: str = None, schema_name: str = None ) -> Optional[StorageSchemaInfo]: """Get the schema by supplied hash, falls back to getting the newest version matching the existing schema name""" version_hash = self._to_path_safe_string(version_hash) # find newest schema for pipeline or by version hash - selected_path = None - newest_load_id = "0" - for filepath, fileparts in self._iter_stored_schema_files(): - if ( - not version_hash - and fileparts[0] == self.schema.name - and fileparts[1] > newest_load_id - ): - newest_load_id = fileparts[1] - selected_path = filepath - elif fileparts[2] == version_hash: - selected_path = filepath - break + try: + selected_path = None + newest_load_id = "0" + for filepath, fileparts in self._iter_stored_schema_files(): + if ( + not version_hash + and (fileparts[0] == schema_name or (not schema_name)) + and fileparts[1] > newest_load_id + ): + newest_load_id = fileparts[1] + selected_path = filepath + elif fileparts[2] == version_hash: + selected_path = filepath + break - if selected_path: - return StorageSchemaInfo( - **json.loads(self.fs_client.read_text(selected_path, encoding="utf-8")) - ) + if selected_path: + return StorageSchemaInfo( + **json.loads(self.fs_client.read_text(selected_path, encoding="utf-8")) + ) + except DestinationUndefinedEntity: + # ignore missing table + pass return None @@ -699,9 +703,9 @@ def _store_current_schema(self) -> None: # we always keep tabs on what the current schema is self._write_to_json_file(filepath, version_info) - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" - return self._get_stored_schema_by_hash_or_newest() + return self._get_stored_schema_by_hash_or_newest(schema_name=schema_name) def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: return self._get_stored_schema_by_hash_or_newest(version_hash) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index ffa556797e..8a347989a0 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -539,7 +539,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI return None @lancedb_error - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) @@ -553,11 +553,10 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = ( - version_table.search().where( - f'`{p_schema_name}` 
= "{self.schema.name}"', prefilter=True - ) - ).to_list() + query = version_table.search() + if schema_name: + query = query.where(f'`{p_schema_name}` = "{schema_name}"', prefilter=True) + schemas = query.to_list() # LanceDB's ORDER BY clause doesn't seem to work. # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 diff --git a/dlt/destinations/impl/qdrant/qdrant_job_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py index 2536bd369d..6c8de52f98 100644 --- a/dlt/destinations/impl/qdrant/qdrant_job_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -377,23 +377,30 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: raise DestinationUndefinedEntity(str(e)) from e raise - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" try: scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) p_schema_name = self.schema.naming.normalize_identifier("schema_name") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") - response = self.db_client.scroll( - scroll_table_name, - with_payload=True, - scroll_filter=models.Filter( + + name_filter = ( + models.Filter( must=[ models.FieldCondition( key=p_schema_name, - match=models.MatchValue(value=self.schema.name), + match=models.MatchValue(value=schema_name), ) ] - ), + ) + if schema_name + else None + ) + + response = self.db_client.scroll( + scroll_table_name, + with_payload=True, + scroll_filter=name_filter, limit=1, order_by=models.OrderBy( key=p_inserted_at, diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index c5a6442d8a..ab73ecf502 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -240,7 +240,9 @@ def _update_schema_in_storage(self, schema: Schema) -> None: self.sql_client.execute_sql(table_obj.insert().values(schema_mapping)) def _get_stored_schema( - self, version_hash: Optional[str] = None, schema_name: Optional[str] = None + self, + version_hash: Optional[str] = None, + schema_name: Optional[str] = None, ) -> Optional[StorageSchemaInfo]: version_table = self.schema.tables[self.schema.version_table_name] table_obj = self._to_table_object(version_table) # type: ignore[arg-type] @@ -267,9 +269,9 @@ def _get_stored_schema( def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: return self._get_stored_schema(version_hash) - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Get the latest stored schema""" - return self._get_stored_schema(schema_name=self.schema.name) + return self._get_stored_schema(schema_name=schema_name) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.schema.tables.get( diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index 76e5fd8b1e..e9d6a76a17 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -516,19 +516,26 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: if len(load_records): return StateInfo(**state) - def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + def 
get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" p_schema_name = self.schema.naming.normalize_identifier("schema_name") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") + + name_filter = ( + { + "path": [p_schema_name], + "operator": "Equal", + "valueString": schema_name, + } + if schema_name + else None + ) + try: record = self.get_records( self.schema.version_table_name, sort={"path": [p_inserted_at], "order": "desc"}, - where={ - "path": [p_schema_name], - "operator": "Equal", - "valueString": self.schema.name, - }, + where=name_filter, limit=1, )[0] return StorageSchemaInfo(**record) diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 0fca64d7ba..fab4d96112 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -397,14 +397,21 @@ def _from_db_type( ) -> TColumnType: pass - def get_stored_schema(self) -> StorageSchemaInfo: + def get_stored_schema(self, schema_name: str = None) -> StorageSchemaInfo: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at") - query = ( - f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s" - f" ORDER BY {c_inserted_at} DESC;" - ) - return self._row_to_schema_info(query, self.schema.name) + if not schema_name: + query = ( + f"SELECT {self.version_table_schema_columns} FROM {name}" + f" ORDER BY {c_inserted_at} DESC;" + ) + return self._row_to_schema_info(query) + else: + query = ( + f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s" + f" ORDER BY {c_inserted_at} DESC;" + ) + return self._row_to_schema_info(query, schema_name) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 348f445967..5373bfb0cb 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -84,6 +84,7 @@ DestinationClientStagingConfiguration, DestinationClientDwhWithStagingConfiguration, SupportsReadableDataset, + TDatasetType, ) from dlt.common.normalizers.naming import NamingConvention from dlt.common.pipeline import ( @@ -113,7 +114,7 @@ from dlt.destinations.sql_client import SqlClientBase, WithSqlClient from dlt.destinations.fs_client import FSClientBase from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.destinations.dataset import ReadableDBAPIDataset +from dlt.destinations.dataset import dataset from dlt.load.configuration import LoaderConfiguration from dlt.load import Load @@ -1546,7 +1547,7 @@ def _get_schemas_from_destination( f" {self.destination.destination_name}" ) return restored_schemas - schema_info = job_client.get_stored_schema() + schema_info = job_client.get_stored_schema(schema_name) if schema_info is None: logger.info( f"The schema {schema.name} was not found in the destination" @@ -1717,10 +1718,11 @@ def __getstate__(self) -> Any: # pickle only the SupportsPipeline protocol fields return {"pipeline_name": self.pipeline_name} - def _dataset(self, dataset_type: Literal["dbapi", "ibis"] = "dbapi") -> SupportsReadableDataset: + def _dataset(self, dataset_type: TDatasetType = "dbapi") -> SupportsReadableDataset: """Access helper to dataset""" - if dataset_type == "dbapi": - return ReadableDBAPIDataset( - 
self.sql_client(), schema=self.default_schema if self.default_schema_name else None - ) - raise NotImplementedError(f"Dataset of type {dataset_type} not implemented") + return dataset( + self.destination, + self.dataset_name, + schema=(self.default_schema if self.default_schema_name else None), + dataset_type=dataset_type, + ) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 3dc2a999d4..6cd0abd587 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -75,7 +75,7 @@ def some_data() -> Generator[DictStrStr, Any, None]: client: LanceDBClient with pipeline.destination_client() as client: # type: ignore # Check if we can get a stored schema and state. - schema = client.get_stored_schema() + schema = client.get_stored_schema(client.schema.name) print("Print dataset name", client.dataset_name) assert schema state = client.get_stored_state("test_pipeline_append") diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 11e0c88451..b8cf66608c 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -1134,8 +1134,8 @@ def _collect_table_counts(p, *items: str) -> Dict[str, int]: "_dlt_pipeline_state": 2, "_dlt_version": 2, } - sc1_old = c1.get_stored_schema() - sc2_old = c2.get_stored_schema() + sc1_old = c1.get_stored_schema(c1.schema.name) + sc2_old = c2.get_stored_schema(c2.schema.name) s1_old = c1.get_stored_state("p1") s2_old = c1.get_stored_state("p2") @@ -1172,8 +1172,8 @@ def some_data(): assert s2_old.version == s2.version # test accessors for schema - sc1 = c1.get_stored_schema() - sc2 = c2.get_stored_schema() + sc1 = c1.get_stored_schema(c1.schema.name) + sc2 = c2.get_stored_schema(c2.schema.name) assert sc1.version_hash != sc1_old.version_hash assert sc2.version_hash == sc2_old.version_hash assert sc1.version_hash != sc2.version_hash diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 51cb392b29..b78306210f 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -70,7 +70,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - p.sync_schema() # check if schema exists with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] - stored_schema = job_client.get_stored_schema() + stored_schema = job_client.get_stored_schema(job_client.schema.name) assert stored_schema is not None # dataset exists, still no table with pytest.raises(DestinationUndefinedEntity): @@ -97,7 +97,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - # schema.bump_version() p.sync_schema() with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] - stored_schema = job_client.get_stored_schema() + stored_schema = job_client.get_stored_schema(job_client.schema.name) assert stored_schema is not None # table is there but no state assert load_pipeline_state_from_destination(p.pipeline_name, job_client) is None diff --git a/tests/load/qdrant/test_pipeline.py b/tests/load/qdrant/test_pipeline.py index 73f53221ed..48a180ac83 100644 --- a/tests/load/qdrant/test_pipeline.py +++ b/tests/load/qdrant/test_pipeline.py @@ -68,7 +68,7 @@ def some_data(): client: QdrantClient with pipeline.destination_client() as client: # type: ignore[assignment] # check if we can get a stored schema and state - schema = 
client.get_stored_schema() + schema = client.get_stored_schema(client.schema.name) print("Print dataset name", client.dataset_name) assert schema state = client.get_stored_state("test_pipeline_append") diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index 41287fcd2d..b60c6a8956 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -123,7 +123,7 @@ def test_schema_string_exceeds_max_text_length(client: RedshiftClient) -> None: schema_str = json.dumps(schema.to_dict()) assert len(schema_str.encode("utf-8")) > client.capabilities.max_text_data_type_length client._update_schema_in_storage(schema) - schema_info = client.get_stored_schema() + schema_info = client.get_stored_schema(client.schema.name) assert schema_info.schema == schema_str # take base64 from db with client.sql_client.execute_query( diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 84d08a5a89..9f64722a1e 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -31,6 +31,7 @@ StateInfo, WithStagingDataset, DestinationClientConfiguration, + WithStateSync, ) from dlt.common.time import ensure_pendulum_datetime @@ -99,7 +100,7 @@ def test_get_schema_on_empty_storage(naming: str, client: SqlJobClientBase) -> N table_name, table_columns = list(client.get_storage_tables([version_table_name]))[0] assert table_name == version_table_name assert len(table_columns) == 0 - schema_info = client.get_stored_schema() + schema_info = client.get_stored_schema(client.schema.name) assert schema_info is None schema_info = client.get_stored_schema_by_hash("8a0298298823928939") assert schema_info is None @@ -127,7 +128,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: assert [len(table[1]) > 0 for table in storage_tables] == [True, True] # verify if schemas stored this_schema = client.get_stored_schema_by_hash(schema.version_hash) - newest_schema = client.get_stored_schema() + newest_schema = client.get_stored_schema(client.schema.name) # should point to the same schema assert this_schema == newest_schema # check fields @@ -150,7 +151,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: client._update_schema_in_storage(schema) sleep(1) this_schema = client.get_stored_schema_by_hash(schema.version_hash) - newest_schema = client.get_stored_schema() + newest_schema = client.get_stored_schema(client.schema.name) assert this_schema == newest_schema assert this_schema.version == schema.version == 3 assert this_schema.version_hash == schema.stored_version_hash @@ -165,7 +166,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: sleep(1) client._update_schema_in_storage(first_schema) this_schema = client.get_stored_schema_by_hash(first_schema.version_hash) - newest_schema = client.get_stored_schema() + newest_schema = client.get_stored_schema(client.schema.name) assert this_schema == newest_schema # error assert this_schema.version == first_schema.version == 3 assert this_schema.version_hash == first_schema.stored_version_hash @@ -175,17 +176,17 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: # mock other schema in client and get the newest schema. it should not exist... 
client.schema = Schema("ethereum")
-    assert client.get_stored_schema() is None
+    assert client.get_stored_schema(client.schema.name) is None
     client.schema._bump_version()
     schema_update = client.update_stored_schema()
     # no schema updates because schema has no tables
     assert schema_update == {}
-    that_info = client.get_stored_schema()
+    that_info = client.get_stored_schema(client.schema.name)
     assert that_info.schema_name == "ethereum"
     # get event schema again
     client.schema = Schema("event")
-    this_schema = client.get_stored_schema()
+    this_schema = client.get_stored_schema(client.schema.name)
     assert this_schema == newest_schema
 
 
@@ -951,6 +952,77 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None:
     )
 
 
+# NOTE: this could be folded into the above tests, but these only run on sql_client destinations for now
+# and we want to test filesystem and vector db here too
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(
+        default_sql_configs=True, default_vector_configs=True, all_buckets_filesystem_configs=True
+    ),
+    ids=lambda x: x.name,
+)
+def test_schema_retrieval(destination_config: DestinationTestConfiguration) -> None:
+    p = destination_config.setup_pipeline("schema_test", dev_mode=True)
+    from dlt.common.schema import utils
+
+    # we create 2 versions of 2 schemas
+    s1_v1 = Schema("schema_1")
+    s1_v2 = s1_v1.clone()
+    s1_v2.tables["items"] = utils.new_table("items")
+    s2_v1 = Schema("schema_2")
+    s2_v2 = s2_v1.clone()
+    s2_v2.tables["other_items"] = utils.new_table("other_items")
+
+    # sanity check
+    assert s1_v1.version_hash != s1_v2.version_hash
+    assert s2_v1.version_hash != s2_v2.version_hash
+
+    client: WithStateSync
+
+    def add_schema_to_pipeline(s: Schema) -> None:
+        p._inject_schema(s)
+        p.default_schema_name = s.name
+        with p.destination_client() as client:
+            client.initialize_storage()
+            client.update_stored_schema()
+
+    # check what happens if there is only one schema
+    add_schema_to_pipeline(s1_v1)
+    p.default_schema_name = s1_v1.name
+    with p.destination_client() as client:  # type: ignore[assignment]
+        assert client.get_stored_schema("schema_1").version_hash == s1_v1.version_hash
+        assert client.get_stored_schema().version_hash == s1_v1.version_hash
+        assert not client.get_stored_schema("other_schema")
+
+    # now we add a different schema
+    # but keep default schema name at v1
+    add_schema_to_pipeline(s2_v1)
+    p.default_schema_name = s1_v1.name
+    with p.destination_client() as client:  # type: ignore[assignment]
+        assert client.get_stored_schema("schema_1").version_hash == s1_v1.version_hash
+        # here schema_2 will be selected as it is newer
+        assert client.get_stored_schema(None).version_hash == s2_v1.version_hash
+        assert not client.get_stored_schema("other_schema")
+
+    # add two more versions
+    add_schema_to_pipeline(s1_v2)
+    add_schema_to_pipeline(s2_v2)
+    p.default_schema_name = s1_v1.name
+    with p.destination_client() as client:  # type: ignore[assignment]
+        assert client.get_stored_schema("schema_1").version_hash == s1_v2.version_hash
+        # here schema_2 will be selected as it is newer
+        assert client.get_stored_schema(None).version_hash == s2_v2.version_hash
+        assert not client.get_stored_schema("other_schema")
+
+    # check same setup with other default schema name
+    p.default_schema_name = s2_v1.name
+    with p.destination_client() as client:  # type: ignore[assignment]
+        assert client.get_stored_schema("schema_2").version_hash == s2_v2.version_hash
+        # here schema_2 will be selected as it is newer
+        assert client.get_stored_schema(None).version_hash == 
s2_v2.version_hash
+        assert not client.get_stored_schema("other_schema")
+
+
 def prepare_schema(client: SqlJobClientBase, case: str) -> Tuple[List[Dict[str, Any]], str]:
     client.update_stored_schema()
     rows = load_json_case(case)
diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py
index e093e4d670..ef73cbd509 100644
--- a/tests/load/test_read_interfaces.py
+++ b/tests/load/test_read_interfaces.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, cast
 
 import pytest
 import dlt
@@ -20,6 +20,8 @@
 )
 from dlt.destinations import filesystem
 from tests.utils import TEST_STORAGE_ROOT
+from dlt.common.destination.reference import TDestinationReferenceArg
+from dlt.destinations.dataset import ReadableDBAPIDataset
 
 
 def _run_dataset_checks(
@@ -212,6 +214,88 @@ def double_items():
     loads_table = pipeline._dataset()[pipeline.default_schema.loads_table_name]
     loads_table.fetchall()
 
+    destination_for_dataset: TDestinationReferenceArg = (
+        alternate_access_pipeline.destination
+        if alternate_access_pipeline
+        else destination_config.destination_type
+    )
+
+    # check dataset factory
+    dataset = dlt._dataset(destination=destination_for_dataset, dataset_name=pipeline.dataset_name)
+    # verify that sql client and schema are lazy loaded
+    assert not dataset._schema
+    assert not dataset._sql_client
+    table_relationship = dataset.items
+    table = table_relationship.fetchall()
+    assert len(table) == total_records
+
+    # check that schema is loaded by name
+    dataset = cast(
+        ReadableDBAPIDataset,
+        dlt._dataset(
+            destination=destination_for_dataset,
+            dataset_name=pipeline.dataset_name,
+            schema=pipeline.default_schema_name,
+        ),
+    )
+    assert dataset.schema.tables["items"]["write_disposition"] == "replace"
+
+    # check that schema is not loaded when wrong name given
+    dataset = cast(
+        ReadableDBAPIDataset,
+        dlt._dataset(
+            destination=destination_for_dataset,
+            dataset_name=pipeline.dataset_name,
+            schema="wrong_schema_name",
+        ),
+    )
+    assert "items" not in dataset.schema.tables
+    assert dataset.schema.name == pipeline.dataset_name
+
+    # check that schema is loaded if no schema name given
+    dataset = cast(
+        ReadableDBAPIDataset,
+        dlt._dataset(
+            destination=destination_for_dataset,
+            dataset_name=pipeline.dataset_name,
+        ),
+    )
+    assert dataset.schema.name == pipeline.default_schema_name
+    assert dataset.schema.tables["items"]["write_disposition"] == "replace"
+
+    # check that there is no error when creating dataset without schema table
+    dataset = cast(
+        ReadableDBAPIDataset,
+        dlt._dataset(
+            destination=destination_for_dataset,
+            dataset_name="unknown_dataset",
+        ),
+    )
+    assert dataset.schema.name == "unknown_dataset"
+    assert "items" not in dataset.schema.tables
+
+    # create a newer schema with a different name and see whether it is loaded
+    from dlt.common.schema import Schema
+    from dlt.common.schema import utils
+
+    other_schema = Schema("some_other_schema")
+    other_schema.tables["other_table"] = utils.new_table("other_table")
+
+    pipeline._inject_schema(other_schema)
+    pipeline.default_schema_name = other_schema.name
+    with pipeline.destination_client() as client:
+        client.update_stored_schema()
+
+    dataset = cast(
+        ReadableDBAPIDataset,
+        dlt._dataset(
+            destination=destination_for_dataset,
+            dataset_name=pipeline.dataset_name,
+        ),
+    )
+    assert dataset.schema.name == "some_other_schema"
+    assert "other_table" in dataset.schema.tables
+
 
 @pytest.mark.essential
 @pytest.mark.parametrize(
@@ -278,8 +362,7 @@ def test_delta_tables(destination_config: 
DestinationTestConfiguration) -> None: os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "700" pipeline = destination_config.setup_pipeline( - "read_pipeline", - dataset_name="read_test", + "read_pipeline", dataset_name="read_test", dev_mode=True ) # in case of gcs we use the s3 compat layer for reading diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 3636b3e53a..0aaa18eac1 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -661,7 +661,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: client.sql_client.execute_sql(sql) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "ProgrammingError") # still can execute dml and selects - assert client.get_stored_schema() is not None + assert client.get_stored_schema(client.schema.name) is not None client.complete_load("ABC") assert_load_id(client.sql_client, "ABC") @@ -670,7 +670,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: with pytest.raises(DatabaseTransientException): client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "ProgrammingError") - assert client.get_stored_schema() is not None + assert client.get_stored_schema(client.schema.name) is not None client.complete_load("EFG") assert_load_id(client.sql_client, "EFG") @@ -685,7 +685,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "IntegrityError") # assert isinstance(term_ex.value.dbapi_exception, (psycopg2.InternalError, psycopg2.)) - assert client.get_stored_schema() is not None + assert client.get_stored_schema(client.schema.name) is not None client.complete_load("HJK") assert_load_id(client.sql_client, "HJK") diff --git a/tests/load/weaviate/test_pipeline.py b/tests/load/weaviate/test_pipeline.py index fc46d00d05..6fcb9b7e4f 100644 --- a/tests/load/weaviate/test_pipeline.py +++ b/tests/load/weaviate/test_pipeline.py @@ -72,7 +72,7 @@ def some_data(): client: WeaviateClient with pipeline.destination_client() as client: # type: ignore[assignment] # check if we can get a stored schema and state - schema = client.get_stored_schema() + schema = client.get_stored_schema(client.schema.name) assert schema state = client.get_stored_state("test_pipeline_append") assert state
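
Note: below is a minimal usage sketch of the standalone dataset factory this diff adds in dlt/destinations/dataset.py and exports as dlt._dataset. The destination, dataset, schema and table names ("duckdb", "my_dataset", "my_schema", "items") are illustrative placeholders, not values taken from the diff.

    import dlt

    # build a readable dataset without a pipeline; schema and sql client are resolved lazily
    ds = dlt._dataset(destination="duckdb", dataset_name="my_dataset", schema="my_schema")

    # tables are reachable via attribute or item access and expose dbapi-style reads
    rows = ds.items.fetchall()
    rows = ds["items"].fetchall()

If schema is omitted, the newest schema found in the destination is resolved via get_stored_schema(); if none is stored, an empty schema named after the dataset is used.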