diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 059b487a13..8b3819e32b 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -659,8 +659,11 @@ def __exit__( class WithStateSync(ABC): @abstractmethod - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: - """Retrieves newest schema from destination storage, setting any_schema_name to true will return the newest schema regardless of the schema name""" + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: + """ + Retrieves newest schema with given name from destination storage + If no name is provided, the newest schema found is retrieved. + """ pass @abstractmethod diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index 736dc8e2ed..33d2f4aac5 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -108,7 +108,7 @@ def _ensure_client_and_schema(self) -> None: elif not self.schema and isinstance(self._provided_schema, str): with self._destination_client(Schema(self._provided_schema)) as client: if isinstance(client, WithStateSync): - stored_schema = client.get_stored_schema() + stored_schema = client.get_stored_schema(self._provided_schema) if stored_schema: self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) @@ -116,7 +116,7 @@ def _ensure_client_and_schema(self) -> None: elif not self.schema: with self._destination_client(Schema(self._dataset_name)) as client: if isinstance(client, WithStateSync): - stored_schema = client.get_stored_schema(any_schema_name=True) + stored_schema = client.get_stored_schema() if stored_schema: self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index bccf8ec686..0cf63b3ac9 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ 
b/dlt/destinations/impl/filesystem/filesystem.py @@ -650,7 +650,7 @@ def _iter_stored_schema_files(self) -> Iterator[Tuple[str, List[str]]]: yield filepath, fileparts def _get_stored_schema_by_hash_or_newest( - self, version_hash: str = None, any_schema_name: bool = False + self, version_hash: str = None, schema_name: str = None ) -> Optional[StorageSchemaInfo]: """Get the schema by supplied hash, falls back to getting the newest version matching the existing schema name""" version_hash = self._to_path_safe_string(version_hash) @@ -661,7 +661,7 @@ def _get_stored_schema_by_hash_or_newest( for filepath, fileparts in self._iter_stored_schema_files(): if ( not version_hash - and (fileparts[0] == self.schema.name or any_schema_name) + and (fileparts[0] == schema_name or (not schema_name)) and fileparts[1] > newest_load_id ): newest_load_id = fileparts[1] @@ -703,9 +703,9 @@ def _store_current_schema(self) -> None: # we always keep tabs on what the current schema is self._write_to_json_file(filepath, version_info) - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" - return self._get_stored_schema_by_hash_or_newest(any_schema_name=any_schema_name) + return self._get_stored_schema_by_hash_or_newest(schema_name=schema_name) def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: return self._get_stored_schema_by_hash_or_newest(version_hash) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index d0a840f292..8a347989a0 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -539,7 +539,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI return None @lancedb_error - def get_stored_schema(self, any_schema_name: 
bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) @@ -554,8 +554,8 @@ def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSc try: query = version_table.search() - if not any_schema_name: - query = query.where(f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True) + if schema_name: + query = query.where(f'`{p_schema_name}` = "{schema_name}"', prefilter=True) schemas = query.to_list() # LanceDB's ORDER BY clause doesn't seem to work. diff --git a/dlt/destinations/impl/qdrant/qdrant_job_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py index 90b5cc29c0..6c8de52f98 100644 --- a/dlt/destinations/impl/qdrant/qdrant_job_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -377,26 +377,30 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: raise DestinationUndefinedEntity(str(e)) from e raise - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" try: scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) p_schema_name = self.schema.naming.normalize_identifier("schema_name") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") - name_filter = models.Filter( - must=[ - models.FieldCondition( - key=p_schema_name, - match=models.MatchValue(value=self.schema.name), - ) - ] + name_filter = ( + models.Filter( + must=[ + models.FieldCondition( + key=p_schema_name, + match=models.MatchValue(value=schema_name), + ) + ] + ) + if schema_name + else None ) response = self.db_client.scroll( scroll_table_name, with_payload=True, - scroll_filter=None if 
any_schema_name else name_filter, + scroll_filter=name_filter, limit=1, order_by=models.OrderBy( key=p_inserted_at, diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index 1c39dde239..ab73ecf502 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -243,7 +243,6 @@ def _get_stored_schema( self, version_hash: Optional[str] = None, schema_name: Optional[str] = None, - any_schema_name: bool = False, ) -> Optional[StorageSchemaInfo]: version_table = self.schema.tables[self.schema.version_table_name] table_obj = self._to_table_object(version_table) # type: ignore[arg-type] @@ -252,7 +251,7 @@ def _get_stored_schema( if version_hash is not None: version_hash_col = self.schema.naming.normalize_identifier("version_hash") q = q.where(table_obj.c[version_hash_col] == version_hash) - if schema_name is not None and not any_schema_name: + if schema_name is not None: schema_name_col = self.schema.naming.normalize_identifier("schema_name") q = q.where(table_obj.c[schema_name_col] == schema_name) inserted_at_col = self.schema.naming.normalize_identifier("inserted_at") @@ -270,11 +269,9 @@ def _get_stored_schema( def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: return self._get_stored_schema(version_hash) - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Get the latest stored schema""" - return self._get_stored_schema( - schema_name=self.schema.name, any_schema_name=any_schema_name - ) + return self._get_stored_schema(schema_name=schema_name) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.schema.tables.get( diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py 
index fc6eb8a94b..c4b32db20c 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -516,22 +516,26 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: if len(load_records): return StateInfo(**state) - def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]: + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" p_schema_name = self.schema.naming.normalize_identifier("schema_name") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") - name_filter = { - "path": [p_schema_name], - "operator": "Equal", - "valueString": self.schema.name, - } + name_filter = ( + { + "path": [p_schema_name], + "operator": "Equal", + "valueString": schema_name, + } + if schema_name + else None + ) try: record = self.get_records( self.schema.version_table_name, sort={"path": [p_inserted_at], "order": "desc"}, - where=None if any_schema_name else name_filter, + where=name_filter, limit=1, )[0] return StorageSchemaInfo(**record) diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 90c00530dc..fab4d96112 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -397,10 +397,10 @@ def _from_db_type( ) -> TColumnType: pass - def get_stored_schema(self, any_schema_name: bool = False) -> StorageSchemaInfo: + def get_stored_schema(self, schema_name: str = None) -> StorageSchemaInfo: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at") - if any_schema_name: + if not schema_name: query = ( f"SELECT {self.version_table_schema_columns} FROM {name}" f" ORDER BY {c_inserted_at} DESC;" @@ -411,7 +411,7 @@ def get_stored_schema(self, any_schema_name: bool = False) -> 
StorageSchemaInfo: f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s" f" ORDER BY {c_inserted_at} DESC;" ) - return self._row_to_schema_info(query, self.schema.name) + return self._row_to_schema_info(query, schema_name) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index c9a3950722..5373bfb0cb 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -1547,7 +1547,7 @@ def _get_schemas_from_destination( f" {self.destination.destination_name}" ) return restored_schemas - schema_info = job_client.get_stored_schema() + schema_info = job_client.get_stored_schema(schema_name) if schema_info is None: logger.info( f"The schema {schema.name} was not found in the destination" diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 3dc2a999d4..6cd0abd587 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -75,7 +75,7 @@ def some_data() -> Generator[DictStrStr, Any, None]: client: LanceDBClient with pipeline.destination_client() as client: # type: ignore # Check if we can get a stored schema and state. 
- schema = client.get_stored_schema() + schema = client.get_stored_schema(client.schema.name) print("Print dataset name", client.dataset_name) assert schema state = client.get_stored_state("test_pipeline_append") diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 11e0c88451..b8cf66608c 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -1134,8 +1134,8 @@ def _collect_table_counts(p, *items: str) -> Dict[str, int]: "_dlt_pipeline_state": 2, "_dlt_version": 2, } - sc1_old = c1.get_stored_schema() - sc2_old = c2.get_stored_schema() + sc1_old = c1.get_stored_schema(c1.schema.name) + sc2_old = c2.get_stored_schema(c2.schema.name) s1_old = c1.get_stored_state("p1") s2_old = c1.get_stored_state("p2") @@ -1172,8 +1172,8 @@ def some_data(): assert s2_old.version == s2.version # test accessors for schema - sc1 = c1.get_stored_schema() - sc2 = c2.get_stored_schema() + sc1 = c1.get_stored_schema(c1.schema.name) + sc2 = c2.get_stored_schema(c2.schema.name) assert sc1.version_hash != sc1_old.version_hash assert sc2.version_hash == sc2_old.version_hash assert sc1.version_hash != sc2.version_hash diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 51cb392b29..b78306210f 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -70,7 +70,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - p.sync_schema() # check if schema exists with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] - stored_schema = job_client.get_stored_schema() + stored_schema = job_client.get_stored_schema(job_client.schema.name) assert stored_schema is not None # dataset exists, still no table with pytest.raises(DestinationUndefinedEntity): @@ -97,7 +97,7 @@ def test_restore_state_utils(destination_config: 
DestinationTestConfiguration) - # schema.bump_version() p.sync_schema() with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] - stored_schema = job_client.get_stored_schema() + stored_schema = job_client.get_stored_schema(job_client.schema.name) assert stored_schema is not None # table is there but no state assert load_pipeline_state_from_destination(p.pipeline_name, job_client) is None diff --git a/tests/load/qdrant/test_pipeline.py b/tests/load/qdrant/test_pipeline.py index 73f53221ed..48a180ac83 100644 --- a/tests/load/qdrant/test_pipeline.py +++ b/tests/load/qdrant/test_pipeline.py @@ -68,7 +68,7 @@ def some_data(): client: QdrantClient with pipeline.destination_client() as client: # type: ignore[assignment] # check if we can get a stored schema and state - schema = client.get_stored_schema() + schema = client.get_stored_schema(client.schema.name) print("Print dataset name", client.dataset_name) assert schema state = client.get_stored_state("test_pipeline_append") diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index 41287fcd2d..b60c6a8956 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -123,7 +123,7 @@ def test_schema_string_exceeds_max_text_length(client: RedshiftClient) -> None: schema_str = json.dumps(schema.to_dict()) assert len(schema_str.encode("utf-8")) > client.capabilities.max_text_data_type_length client._update_schema_in_storage(schema) - schema_info = client.get_stored_schema() + schema_info = client.get_stored_schema(client.schema.name) assert schema_info.schema == schema_str # take base64 from db with client.sql_client.execute_query( diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 0bb88c4dd3..9f64722a1e 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -100,7 +100,7 @@ def test_get_schema_on_empty_storage(naming: str, client: 
SqlJobClientBase) -> N table_name, table_columns = list(client.get_storage_tables([version_table_name]))[0] assert table_name == version_table_name assert len(table_columns) == 0 - schema_info = client.get_stored_schema() + schema_info = client.get_stored_schema(client.schema.name) assert schema_info is None schema_info = client.get_stored_schema_by_hash("8a0298298823928939") assert schema_info is None @@ -128,7 +128,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: assert [len(table[1]) > 0 for table in storage_tables] == [True, True] # verify if schemas stored this_schema = client.get_stored_schema_by_hash(schema.version_hash) - newest_schema = client.get_stored_schema() + newest_schema = client.get_stored_schema(client.schema.name) # should point to the same schema assert this_schema == newest_schema # check fields @@ -151,7 +151,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: client._update_schema_in_storage(schema) sleep(1) this_schema = client.get_stored_schema_by_hash(schema.version_hash) - newest_schema = client.get_stored_schema() + newest_schema = client.get_stored_schema(client.schema.name) assert this_schema == newest_schema assert this_schema.version == schema.version == 3 assert this_schema.version_hash == schema.stored_version_hash @@ -166,7 +166,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: sleep(1) client._update_schema_in_storage(first_schema) this_schema = client.get_stored_schema_by_hash(first_schema.version_hash) - newest_schema = client.get_stored_schema() + newest_schema = client.get_stored_schema(client.schema.name) assert this_schema == newest_schema # error assert this_schema.version == first_schema.version == 3 assert this_schema.version_hash == first_schema.stored_version_hash @@ -176,17 +176,17 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: # mock other schema in client and get the newest schema. it should not exist... 
client.schema = Schema("ethereum") - assert client.get_stored_schema() is None + assert client.get_stored_schema(client.schema.name) is None client.schema._bump_version() schema_update = client.update_stored_schema() # no schema updates because schema has no tables assert schema_update == {} - that_info = client.get_stored_schema() + that_info = client.get_stored_schema(client.schema.name) assert that_info.schema_name == "ethereum" # get event schema again client.schema = Schema("event") - this_schema = client.get_stored_schema() + this_schema = client.get_stored_schema(client.schema.name) assert this_schema == newest_schema @@ -990,33 +990,37 @@ def add_schema_to_pipeline(s: Schema) -> None: add_schema_to_pipeline(s1_v1) p.default_schema_name = s1_v1.name with p.destination_client() as client: # type: ignore[assignment] + assert client.get_stored_schema("schema_1").version_hash == s1_v1.version_hash assert client.get_stored_schema().version_hash == s1_v1.version_hash - assert client.get_stored_schema(any_schema_name=True).version_hash == s1_v1.version_hash + assert not client.get_stored_schema("other_schema") # now we add a different schema # but keep default schema name at v1 add_schema_to_pipeline(s2_v1) p.default_schema_name = s1_v1.name with p.destination_client() as client: # type: ignore[assignment] - assert client.get_stored_schema().version_hash == s1_v1.version_hash + assert client.get_stored_schema("schema_1").version_hash == s1_v1.version_hash # here v2 will be selected as it is newer - assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v1.version_hash + assert client.get_stored_schema(None).version_hash == s2_v1.version_hash + assert not client.get_stored_schema("other_schema") # add two more version, add_schema_to_pipeline(s1_v2) add_schema_to_pipeline(s2_v2) p.default_schema_name = s1_v1.name with p.destination_client() as client: # type: ignore[assignment] - assert client.get_stored_schema().version_hash == s1_v2.version_hash + 
assert client.get_stored_schema("schema_1").version_hash == s1_v2.version_hash # here v2 will be selected as it is newer - assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v2.version_hash + assert client.get_stored_schema(None).version_hash == s2_v2.version_hash + assert not client.get_stored_schema("other_schema") # check same setup with other default schema name p.default_schema_name = s2_v1.name with p.destination_client() as client: # type: ignore[assignment] - assert client.get_stored_schema().version_hash == s2_v2.version_hash + assert client.get_stored_schema("schema_2").version_hash == s2_v2.version_hash # here v2 will be selected as it is newer - assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v2.version_hash + assert client.get_stored_schema(None).version_hash == s2_v2.version_hash + assert not client.get_stored_schema("other_schema") def prepare_schema(client: SqlJobClientBase, case: str) -> Tuple[List[Dict[str, Any]], str]: diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 3636b3e53a..0aaa18eac1 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -661,7 +661,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: client.sql_client.execute_sql(sql) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "ProgrammingError") # still can execute dml and selects - assert client.get_stored_schema() is not None + assert client.get_stored_schema(client.schema.name) is not None client.complete_load("ABC") assert_load_id(client.sql_client, "ABC") @@ -670,7 +670,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: with pytest.raises(DatabaseTransientException): client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "ProgrammingError") - assert client.get_stored_schema() is not None + assert client.get_stored_schema(client.schema.name) is not None 
client.complete_load("EFG") assert_load_id(client.sql_client, "EFG") @@ -685,7 +685,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "IntegrityError") # assert isinstance(term_ex.value.dbapi_exception, (psycopg2.InternalError, psycopg2.)) - assert client.get_stored_schema() is not None + assert client.get_stored_schema(client.schema.name) is not None client.complete_load("HJK") assert_load_id(client.sql_client, "HJK") diff --git a/tests/load/weaviate/test_pipeline.py b/tests/load/weaviate/test_pipeline.py index fc46d00d05..6fcb9b7e4f 100644 --- a/tests/load/weaviate/test_pipeline.py +++ b/tests/load/weaviate/test_pipeline.py @@ -72,7 +72,7 @@ def some_data(): client: WeaviateClient with pipeline.destination_client() as client: # type: ignore[assignment] # check if we can get a stored schema and state - schema = client.get_stored_schema() + schema = client.get_stored_schema(client.schema.name) assert schema state = client.get_stored_state("test_pipeline_append") assert state