Skip to content

Commit

Permalink
change signature and behavior of get_stored_schema
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Oct 14, 2024
1 parent 9354e86 commit 5f3dbdf
Show file tree
Hide file tree
Showing 17 changed files with 76 additions and 64 deletions.
7 changes: 5 additions & 2 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,11 @@ def __exit__(

class WithStateSync(ABC):
@abstractmethod
def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage, setting any_schema_name to true will return the newest schema regardless of the schema name"""
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""
Retrieves newest schema with given name from destination storage
If no name is provided, the newest schema found is retrieved.
"""
pass

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions dlt/destinations/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,15 @@ def _ensure_client_and_schema(self) -> None:
elif not self.schema and isinstance(self._provided_schema, str):
with self._destination_client(Schema(self._provided_schema)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema()
stored_schema = client.get_stored_schema(self._provided_schema)
if stored_schema:
self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema))

# no schema name given, load newest schema from destination
elif not self.schema:
with self._destination_client(Schema(self._dataset_name)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema(any_schema_name=True)
stored_schema = client.get_stored_schema()
if stored_schema:
self.schema = Schema.from_stored_schema(json.loads(stored_schema.schema))

Expand Down
8 changes: 4 additions & 4 deletions dlt/destinations/impl/filesystem/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def _iter_stored_schema_files(self) -> Iterator[Tuple[str, List[str]]]:
yield filepath, fileparts

def _get_stored_schema_by_hash_or_newest(
self, version_hash: str = None, any_schema_name: bool = False
self, version_hash: str = None, schema_name: str = None
) -> Optional[StorageSchemaInfo]:
"""Get the schema by the supplied hash; otherwise fall back to the newest version matching the given schema name, or the newest version of any schema when no name is given"""
version_hash = self._to_path_safe_string(version_hash)
Expand All @@ -661,7 +661,7 @@ def _get_stored_schema_by_hash_or_newest(
for filepath, fileparts in self._iter_stored_schema_files():
if (
not version_hash
and (fileparts[0] == self.schema.name or any_schema_name)
and (fileparts[0] == schema_name or (not schema_name))
and fileparts[1] > newest_load_id
):
newest_load_id = fileparts[1]
Expand Down Expand Up @@ -703,9 +703,9 @@ def _store_current_schema(self) -> None:
# we always keep tabs on what the current schema is
self._write_to_json_file(filepath, version_info)

def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
return self._get_stored_schema_by_hash_or_newest(any_schema_name=any_schema_name)
return self._get_stored_schema_by_hash_or_newest(schema_name=schema_name)

def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
return self._get_stored_schema_by_hash_or_newest(version_hash)
Expand Down
6 changes: 3 additions & 3 deletions dlt/destinations/impl/lancedb/lancedb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI
return None

@lancedb_error
def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage."""
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)

Expand All @@ -554,8 +554,8 @@ def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSc

try:
query = version_table.search()
if not any_schema_name:
query = query.where(f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True)
if schema_name:
query = query.where(f'`{p_schema_name}` = "{schema_name}"', prefilter=True)
schemas = query.to_list()

# LanceDB's ORDER BY clause doesn't seem to work.
Expand Down
22 changes: 13 additions & 9 deletions dlt/destinations/impl/qdrant/qdrant_job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,26 +377,30 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
raise DestinationUndefinedEntity(str(e)) from e
raise

def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
try:
scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name)
p_schema_name = self.schema.naming.normalize_identifier("schema_name")
p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")

name_filter = models.Filter(
must=[
models.FieldCondition(
key=p_schema_name,
match=models.MatchValue(value=self.schema.name),
)
]
name_filter = (
models.Filter(
must=[
models.FieldCondition(
key=p_schema_name,
match=models.MatchValue(value=schema_name),
)
]
)
if schema_name
else None
)

response = self.db_client.scroll(
scroll_table_name,
with_payload=True,
scroll_filter=None if any_schema_name else name_filter,
scroll_filter=name_filter,
limit=1,
order_by=models.OrderBy(
key=p_inserted_at,
Expand Down
9 changes: 3 additions & 6 deletions dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ def _get_stored_schema(
self,
version_hash: Optional[str] = None,
schema_name: Optional[str] = None,
any_schema_name: bool = False,
) -> Optional[StorageSchemaInfo]:
version_table = self.schema.tables[self.schema.version_table_name]
table_obj = self._to_table_object(version_table) # type: ignore[arg-type]
Expand All @@ -252,7 +251,7 @@ def _get_stored_schema(
if version_hash is not None:
version_hash_col = self.schema.naming.normalize_identifier("version_hash")
q = q.where(table_obj.c[version_hash_col] == version_hash)
if schema_name is not None and not any_schema_name:
if schema_name is not None:
schema_name_col = self.schema.naming.normalize_identifier("schema_name")
q = q.where(table_obj.c[schema_name_col] == schema_name)
inserted_at_col = self.schema.naming.normalize_identifier("inserted_at")
Expand All @@ -270,11 +269,9 @@ def _get_stored_schema(
def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
return self._get_stored_schema(version_hash)

def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Get the latest stored schema"""
return self._get_stored_schema(
schema_name=self.schema.name, any_schema_name=any_schema_name
)
return self._get_stored_schema(schema_name=schema_name)

def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.schema.tables.get(
Expand Down
18 changes: 11 additions & 7 deletions dlt/destinations/impl/weaviate/weaviate_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,22 +516,26 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
if len(load_records):
return StateInfo(**state)

def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
p_schema_name = self.schema.naming.normalize_identifier("schema_name")
p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")

name_filter = {
"path": [p_schema_name],
"operator": "Equal",
"valueString": self.schema.name,
}
name_filter = (
{
"path": [p_schema_name],
"operator": "Equal",
"valueString": self.schema.name,
}
if schema_name
else None
)

try:
record = self.get_records(
self.schema.version_table_name,
sort={"path": [p_inserted_at], "order": "desc"},
where=None if any_schema_name else name_filter,
where=name_filter,
limit=1,
)[0]
return StorageSchemaInfo(**record)
Expand Down
6 changes: 3 additions & 3 deletions dlt/destinations/job_client_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,10 @@ def _from_db_type(
) -> TColumnType:
pass

def get_stored_schema(self, any_schema_name: bool = False) -> StorageSchemaInfo:
def get_stored_schema(self, schema_name: str = None) -> StorageSchemaInfo:
name = self.sql_client.make_qualified_table_name(self.schema.version_table_name)
c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at")
if any_schema_name:
if not schema_name:
query = (
f"SELECT {self.version_table_schema_columns} FROM {name}"
f" ORDER BY {c_inserted_at} DESC;"
Expand All @@ -411,7 +411,7 @@ def get_stored_schema(self, any_schema_name: bool = False) -> StorageSchemaInfo:
f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query, self.schema.name)
return self._row_to_schema_info(query, schema_name)

def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name)
Expand Down
2 changes: 1 addition & 1 deletion dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,7 @@ def _get_schemas_from_destination(
f" {self.destination.destination_name}"
)
return restored_schemas
schema_info = job_client.get_stored_schema()
schema_info = job_client.get_stored_schema(schema_name)
if schema_info is None:
logger.info(
f"The schema {schema.name} was not found in the destination"
Expand Down
2 changes: 1 addition & 1 deletion tests/load/lancedb/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def some_data() -> Generator[DictStrStr, Any, None]:
client: LanceDBClient
with pipeline.destination_client() as client: # type: ignore
# Check if we can get a stored schema and state.
schema = client.get_stored_schema()
schema = client.get_stored_schema(client.schema.name)
print("Print dataset name", client.dataset_name)
assert schema
state = client.get_stored_state("test_pipeline_append")
Expand Down
8 changes: 4 additions & 4 deletions tests/load/pipeline/test_filesystem_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,8 +1134,8 @@ def _collect_table_counts(p, *items: str) -> Dict[str, int]:
"_dlt_pipeline_state": 2,
"_dlt_version": 2,
}
sc1_old = c1.get_stored_schema()
sc2_old = c2.get_stored_schema()
sc1_old = c1.get_stored_schema(c1.schema.name)
sc2_old = c2.get_stored_schema(c2.schema.name)
s1_old = c1.get_stored_state("p1")
s2_old = c1.get_stored_state("p2")

Expand Down Expand Up @@ -1172,8 +1172,8 @@ def some_data():
assert s2_old.version == s2.version

# test accessors for schema
sc1 = c1.get_stored_schema()
sc2 = c2.get_stored_schema()
sc1 = c1.get_stored_schema(c1.schema.name)
sc2 = c2.get_stored_schema(c2.schema.name)
assert sc1.version_hash != sc1_old.version_hash
assert sc2.version_hash == sc2_old.version_hash
assert sc1.version_hash != sc2.version_hash
Expand Down
4 changes: 2 additions & 2 deletions tests/load/pipeline/test_restore_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) -
p.sync_schema()
# check if schema exists
with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment]
stored_schema = job_client.get_stored_schema()
stored_schema = job_client.get_stored_schema(job_client.schema.name)
assert stored_schema is not None
# dataset exists, still no table
with pytest.raises(DestinationUndefinedEntity):
Expand All @@ -97,7 +97,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) -
# schema.bump_version()
p.sync_schema()
with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment]
stored_schema = job_client.get_stored_schema()
stored_schema = job_client.get_stored_schema(job_client.schema.name)
assert stored_schema is not None
# table is there but no state
assert load_pipeline_state_from_destination(p.pipeline_name, job_client) is None
Expand Down
2 changes: 1 addition & 1 deletion tests/load/qdrant/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def some_data():
client: QdrantClient
with pipeline.destination_client() as client: # type: ignore[assignment]
# check if we can get a stored schema and state
schema = client.get_stored_schema()
schema = client.get_stored_schema(client.schema.name)
print("Print dataset name", client.dataset_name)
assert schema
state = client.get_stored_state("test_pipeline_append")
Expand Down
2 changes: 1 addition & 1 deletion tests/load/redshift/test_redshift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_schema_string_exceeds_max_text_length(client: RedshiftClient) -> None:
schema_str = json.dumps(schema.to_dict())
assert len(schema_str.encode("utf-8")) > client.capabilities.max_text_data_type_length
client._update_schema_in_storage(schema)
schema_info = client.get_stored_schema()
schema_info = client.get_stored_schema(client.schema.name)
assert schema_info.schema == schema_str
# take base64 from db
with client.sql_client.execute_query(
Expand Down
32 changes: 18 additions & 14 deletions tests/load/test_job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_get_schema_on_empty_storage(naming: str, client: SqlJobClientBase) -> N
table_name, table_columns = list(client.get_storage_tables([version_table_name]))[0]
assert table_name == version_table_name
assert len(table_columns) == 0
schema_info = client.get_stored_schema()
schema_info = client.get_stored_schema(client.schema.name)
assert schema_info is None
schema_info = client.get_stored_schema_by_hash("8a0298298823928939")
assert schema_info is None
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None:
assert [len(table[1]) > 0 for table in storage_tables] == [True, True]
# verify if schemas stored
this_schema = client.get_stored_schema_by_hash(schema.version_hash)
newest_schema = client.get_stored_schema()
newest_schema = client.get_stored_schema(client.schema.name)
# should point to the same schema
assert this_schema == newest_schema
# check fields
Expand All @@ -151,7 +151,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None:
client._update_schema_in_storage(schema)
sleep(1)
this_schema = client.get_stored_schema_by_hash(schema.version_hash)
newest_schema = client.get_stored_schema()
newest_schema = client.get_stored_schema(client.schema.name)
assert this_schema == newest_schema
assert this_schema.version == schema.version == 3
assert this_schema.version_hash == schema.stored_version_hash
Expand All @@ -166,7 +166,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None:
sleep(1)
client._update_schema_in_storage(first_schema)
this_schema = client.get_stored_schema_by_hash(first_schema.version_hash)
newest_schema = client.get_stored_schema()
newest_schema = client.get_stored_schema(client.schema.name)
assert this_schema == newest_schema # error
assert this_schema.version == first_schema.version == 3
assert this_schema.version_hash == first_schema.stored_version_hash
Expand All @@ -176,17 +176,17 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None:

# mock other schema in client and get the newest schema. it should not exist...
client.schema = Schema("ethereum")
assert client.get_stored_schema() is None
assert client.get_stored_schema(client.schema.name) is None
client.schema._bump_version()
schema_update = client.update_stored_schema()
# no schema updates because schema has no tables
assert schema_update == {}
that_info = client.get_stored_schema()
that_info = client.get_stored_schema(client.schema.name)
assert that_info.schema_name == "ethereum"

# get event schema again
client.schema = Schema("event")
this_schema = client.get_stored_schema()
this_schema = client.get_stored_schema(client.schema.name)
assert this_schema == newest_schema


Expand Down Expand Up @@ -990,33 +990,37 @@ def add_schema_to_pipeline(s: Schema) -> None:
add_schema_to_pipeline(s1_v1)
p.default_schema_name = s1_v1.name
with p.destination_client() as client: # type: ignore[assignment]
assert client.get_stored_schema("schema_1").version_hash == s1_v1.version_hash
assert client.get_stored_schema().version_hash == s1_v1.version_hash
assert client.get_stored_schema(any_schema_name=True).version_hash == s1_v1.version_hash
assert not client.get_stored_schema("other_schema")

# now we add a different schema
# but keep default schema name at v1
add_schema_to_pipeline(s2_v1)
p.default_schema_name = s1_v1.name
with p.destination_client() as client: # type: ignore[assignment]
assert client.get_stored_schema().version_hash == s1_v1.version_hash
assert client.get_stored_schema("schema_1").version_hash == s1_v1.version_hash
# here v2 will be selected as it is newer
assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v1.version_hash
assert client.get_stored_schema(None).version_hash == s2_v1.version_hash
assert not client.get_stored_schema("other_schema")

# add two more versions
add_schema_to_pipeline(s1_v2)
add_schema_to_pipeline(s2_v2)
p.default_schema_name = s1_v1.name
with p.destination_client() as client: # type: ignore[assignment]
assert client.get_stored_schema().version_hash == s1_v2.version_hash
assert client.get_stored_schema("schema_1").version_hash == s1_v2.version_hash
# here v2 will be selected as it is newer
assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v2.version_hash
assert client.get_stored_schema(None).version_hash == s2_v2.version_hash
assert not client.get_stored_schema("other_schema")

# check same setup with other default schema name
p.default_schema_name = s2_v1.name
with p.destination_client() as client: # type: ignore[assignment]
assert client.get_stored_schema().version_hash == s2_v2.version_hash
assert client.get_stored_schema("schema_2").version_hash == s2_v2.version_hash
# here v2 will be selected as it is newer
assert client.get_stored_schema(any_schema_name=True).version_hash == s2_v2.version_hash
assert client.get_stored_schema(None).version_hash == s2_v2.version_hash
assert not client.get_stored_schema("other_schema")


def prepare_schema(client: SqlJobClientBase, case: str) -> Tuple[List[Dict[str, Any]], str]:
Expand Down
Loading

0 comments on commit 5f3dbdf

Please sign in to comment.