[do not merge] dataset factory test #1952

Closed
wants to merge 13 commits
2 changes: 2 additions & 0 deletions dlt/__init__.py
@@ -42,6 +42,7 @@
)
from dlt.pipeline import progress
from dlt import destinations
from dlt.destinations.dataset import dataset as _dataset

pipeline = _pipeline
current = _current
@@ -79,6 +80,7 @@
"TCredentials",
"sources",
"destinations",
"_dataset",
]

# verify that no injection context was created
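A minimal usage sketch of the newly exported factory, assuming a configured duckdb destination that already holds a dataset loaded by a prior pipeline run (all names here are illustrative):

import dlt

# obtain a readable dataset handle without constructing a pipeline
ds = dlt._dataset("duckdb", "my_dataset")
# tables are reachable via attribute or item access
relation = ds.my_table  # same as ds["my_table"]
# the relation exposes the SupportsReadableRelation accessors, e.g. relation.df()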
9 changes: 7 additions & 2 deletions dlt/common/destination/reference.py
@@ -66,6 +66,8 @@
TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration")
TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase")
TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration")
TDatasetType = Literal["dbapi", "ibis"]


DEFAULT_FILE_LAYOUT = "{table_name}/{load_id}.{file_id}.{ext}"

@@ -657,8 +659,11 @@ def __exit__(

class WithStateSync(ABC):
@abstractmethod
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""
Retrieves the newest schema with the given name from destination storage.
If no name is provided, the newest schema found is retrieved.
"""
pass

@abstractmethod
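A hedged sketch of calling the widened interface, assuming `client` is a destination job client implementing WithStateSync and "my_schema" is a schema name present in storage:

info = client.get_stored_schema("my_schema")  # newest stored version of "my_schema"
info = client.get_stored_schema()             # newest stored schema of any name
if info:
    print(info.version_hash, info.inserted_at)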
95 changes: 88 additions & 7 deletions dlt/destinations/dataset.py
@@ -1,13 +1,20 @@
from typing import Any, Generator, AnyStr, Optional
from typing import Any, Generator, Optional, Union
from dlt.common.json import json

from contextlib import contextmanager
from dlt.common.destination.reference import (
SupportsReadableRelation,
SupportsReadableDataset,
TDatasetType,
TDestinationReferenceArg,
Destination,
JobClientBase,
WithStateSync,
DestinationClientDwhConfiguration,
)

from dlt.common.schema.typing import TTableSchemaColumns
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
from dlt.common.schema import Schema


@@ -71,22 +78,85 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
class ReadableDBAPIDataset(SupportsReadableDataset):
"""Access to dataframes and arrowtables in the destination dataset via dbapi"""

def __init__(self, client: SqlClientBase[Any], schema: Optional[Schema]) -> None:
self.client = client
self.schema = schema
def __init__(
self,
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
) -> None:
self._destination = Destination.from_reference(destination)
self._provided_schema = schema
self._dataset_name = dataset_name
self._sql_client: SqlClientBase[Any] = None
self._schema: Schema = None

@property
def schema(self) -> Schema:
self._ensure_client_and_schema()
return self._schema

@property
def sql_client(self) -> SqlClientBase[Any]:
self._ensure_client_and_schema()
return self._sql_client

def _destination_client(self, schema: Schema) -> JobClientBase:
client_spec = self._destination.spec()
if isinstance(client_spec, DestinationClientDwhConfiguration):
client_spec._bind_dataset_name(
dataset_name=self._dataset_name, default_schema_name=schema.name
)
return self._destination.client(schema, client_spec)

def _ensure_client_and_schema(self) -> None:
"""Lazy load schema and client"""
# full schema given, nothing to do
if not self._schema and isinstance(self._provided_schema, Schema):
self._schema = self._provided_schema

# schema name given, resolve it from destination by name
elif not self._schema and isinstance(self._provided_schema, str):
with self._destination_client(Schema(self._provided_schema)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema(self._provided_schema)
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))

# no schema name given, load newest schema from destination
elif not self._schema:
with self._destination_client(Schema(self._dataset_name)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema()
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))

# default to empty schema with dataset name if nothing found
if not self._schema:
self._schema = Schema(self._dataset_name)

# here we create the client bound to the resolved schema
if not self._sql_client:
destination_client = self._destination_client(self._schema)
if isinstance(destination_client, WithSqlClient):
self._sql_client = destination_client.sql_client
else:
raise Exception(
f"Destination {destination_client.config.destination_type} does not support"
" SqlClient."
)

def __call__(
self, query: Any, schema_columns: TTableSchemaColumns = None
) -> ReadableDBAPIRelation:
schema_columns = schema_columns or {}
return ReadableDBAPIRelation(client=self.client, query=query, schema_columns=schema_columns) # type: ignore[abstract]
return ReadableDBAPIRelation(client=self.sql_client, query=query, schema_columns=schema_columns) # type: ignore[abstract]

def table(self, table_name: str) -> SupportsReadableRelation:
# prepare query for table relation
schema_columns = (
self.schema.tables.get(table_name, {}).get("columns", {}) if self.schema else {}
)
table_name = self.client.make_qualified_table_name(table_name)
table_name = self.sql_client.make_qualified_table_name(table_name)
query = f"SELECT * FROM {table_name}"
return self(query, schema_columns)

@@ -97,3 +167,14 @@ def __getitem__(self, table_name: str) -> SupportsReadableRelation:
def __getattr__(self, table_name: str) -> SupportsReadableRelation:
"""access of table via property notation"""
return self.table(table_name)


def dataset(
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
dataset_type: TDatasetType = "dbapi",
) -> SupportsReadableDataset:
if dataset_type == "dbapi":
return ReadableDBAPIDataset(destination, dataset_name, schema)
raise NotImplementedError(f"Dataset of type {dataset_type} not implemented")
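For reference, a sketch of the three resolution modes that _ensure_client_and_schema distinguishes, assuming a configured duckdb destination (names are illustrative):

from dlt.common.schema import Schema
from dlt.destinations.dataset import dataset

my_schema = Schema("my_schema")  # stand-in for a schema obtained elsewhere
ds = dataset("duckdb", "my_dataset", schema=my_schema)    # explicit Schema object, used as-is
ds = dataset("duckdb", "my_dataset", schema="my_schema")  # resolved by name from stored schemas
ds = dataset("duckdb", "my_dataset")                      # newest stored schema, else an empty one
print(ds.schema.name)  # first access triggers the lazy client and schema resolution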
11 changes: 6 additions & 5 deletions dlt/destinations/impl/athena/athena.py
@@ -318,11 +318,12 @@ def __init__(
# verify if staging layout is valid for Athena
# this will raise if the table prefix is not properly defined
# we actually require that {table_name} is first, no {schema_name} is allowed
self.table_prefix_layout = path_utils.get_table_prefix_layout(
config.staging_config.layout,
supported_prefix_placeholders=[],
table_needs_own_folder=True,
)
if config.staging_config:
self.table_prefix_layout = path_utils.get_table_prefix_layout(
config.staging_config.layout,
supported_prefix_placeholders=[],
table_needs_own_folder=True,
)

sql_client = AthenaSQLClient(
config.normalize_dataset_name(schema),
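A self-contained illustration (not dlt code) of the failure mode the new guard avoids when Athena is configured without a staging destination:

class _Config:
    staging_config = None  # no staging bucket configured

config = _Config()
# config.staging_config.layout would raise AttributeError on None
if config.staging_config:  # the guard added above
    layout = config.staging_config.layout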
44 changes: 24 additions & 20 deletions dlt/destinations/impl/filesystem/filesystem.py
@@ -650,29 +650,33 @@ def _iter_stored_schema_files(self) -> Iterator[Tuple[str, List[str]]]:
yield filepath, fileparts

def _get_stored_schema_by_hash_or_newest(
self, version_hash: str = None
self, version_hash: str = None, schema_name: str = None
) -> Optional[StorageSchemaInfo]:
"""Get the schema by supplied hash, falls back to getting the newest version matching the existing schema name"""
version_hash = self._to_path_safe_string(version_hash)
# find newest schema for pipeline or by version hash
selected_path = None
newest_load_id = "0"
for filepath, fileparts in self._iter_stored_schema_files():
if (
not version_hash
and fileparts[0] == self.schema.name
and fileparts[1] > newest_load_id
):
newest_load_id = fileparts[1]
selected_path = filepath
elif fileparts[2] == version_hash:
selected_path = filepath
break
try:
selected_path = None
newest_load_id = "0"
for filepath, fileparts in self._iter_stored_schema_files():
if (
not version_hash
and (fileparts[0] == schema_name or (not schema_name))
and fileparts[1] > newest_load_id
):
newest_load_id = fileparts[1]
selected_path = filepath
elif fileparts[2] == version_hash:
selected_path = filepath
break

if selected_path:
return StorageSchemaInfo(
**json.loads(self.fs_client.read_text(selected_path, encoding="utf-8"))
)
if selected_path:
return StorageSchemaInfo(
**json.loads(self.fs_client.read_text(selected_path, encoding="utf-8"))
)
except DestinationUndefinedEntity:
# ignore missing table
pass

return None

@@ -699,9 +703,9 @@ def _store_current_schema(self) -> None:
# we always keep tabs on what the current schema is
self._write_to_json_file(filepath, version_info)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
return self._get_stored_schema_by_hash_or_newest()
return self._get_stored_schema_by_hash_or_newest(schema_name=schema_name)

def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
return self._get_stored_schema_by_hash_or_newest(version_hash)
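The selection rule above condenses to a small pure function, assuming schema files split into fileparts of (schema_name, load_id, version_hash) as in _iter_stored_schema_files:

def select_schema_file(files, version_hash=None, schema_name=None):
    # files: iterable of (filepath, (schema_name, load_id, version_hash)) tuples
    selected, newest_load_id = None, "0"
    for path, (name, load_id, v_hash) in files:
        if not version_hash and (not schema_name or name == schema_name) and load_id > newest_load_id:
            # newest schema so far, optionally restricted to one schema name
            newest_load_id, selected = load_id, path
        elif v_hash == version_hash:
            return path  # exact hash match wins immediately
    return selected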
11 changes: 5 additions & 6 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -539,7 +539,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI
return None

@lancedb_error
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage."""
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)

@@ -553,11 +553,10 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
p_schema = self.schema.naming.normalize_identifier("schema")

try:
schemas = (
version_table.search().where(
f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True
)
).to_list()
query = version_table.search()
if schema_name:
query = query.where(f'`{p_schema_name}` = "{schema_name}"', prefilter=True)
schemas = query.to_list()

# LanceDB's ORDER BY clause doesn't seem to work.
# See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341
21 changes: 14 additions & 7 deletions dlt/destinations/impl/qdrant/qdrant_job_client.py
@@ -377,23 +377,30 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
raise DestinationUndefinedEntity(str(e)) from e
raise

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
try:
scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name)
p_schema_name = self.schema.naming.normalize_identifier("schema_name")
p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")
response = self.db_client.scroll(
scroll_table_name,
with_payload=True,
scroll_filter=models.Filter(

name_filter = (
models.Filter(
must=[
models.FieldCondition(
key=p_schema_name,
match=models.MatchValue(value=self.schema.name),
match=models.MatchValue(value=schema_name),
)
]
),
)
if schema_name
else None
)

response = self.db_client.scroll(
scroll_table_name,
with_payload=True,
scroll_filter=name_filter,
limit=1,
order_by=models.OrderBy(
key=p_inserted_at,
8 changes: 5 additions & 3 deletions dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -240,7 +240,9 @@ def _update_schema_in_storage(self, schema: Schema) -> None:
self.sql_client.execute_sql(table_obj.insert().values(schema_mapping))

def _get_stored_schema(
self, version_hash: Optional[str] = None, schema_name: Optional[str] = None
self,
version_hash: Optional[str] = None,
schema_name: Optional[str] = None,
) -> Optional[StorageSchemaInfo]:
version_table = self.schema.tables[self.schema.version_table_name]
table_obj = self._to_table_object(version_table) # type: ignore[arg-type]
@@ -267,9 +269,9 @@ def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
return self._get_stored_schema(version_hash)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Get the latest stored schema"""
return self._get_stored_schema(schema_name=self.schema.name)
return self._get_stored_schema(schema_name=schema_name)

def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.schema.tables.get(
19 changes: 13 additions & 6 deletions dlt/destinations/impl/weaviate/weaviate_client.py
@@ -516,19 +516,26 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
if len(load_records):
return StateInfo(**state)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
p_schema_name = self.schema.naming.normalize_identifier("schema_name")
p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")

name_filter = (
{
"path": [p_schema_name],
"operator": "Equal",
"valueString": schema_name,
}
if schema_name
else None
)

try:
record = self.get_records(
self.schema.version_table_name,
sort={"path": [p_inserted_at], "order": "desc"},
where={
"path": [p_schema_name],
"operator": "Equal",
"valueString": self.schema.name,
},
where=name_filter,
limit=1,
)[0]
return StorageSchemaInfo(**record)
19 changes: 13 additions & 6 deletions dlt/destinations/job_client_impl.py
@@ -397,14 +397,21 @@
) -> TColumnType:
pass

def get_stored_schema(self) -> StorageSchemaInfo:
def get_stored_schema(self, schema_name: str = None) -> StorageSchemaInfo:
name = self.sql_client.make_qualified_table_name(self.schema.version_table_name)
c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at")
query = (
f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query, self.schema.name)
if not schema_name:
query = (
f"SELECT {self.version_table_schema_columns} FROM {name}"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query)
else:
query = (
f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query, schema_name)

def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name)
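Together with the sqlalchemy change above, this is the key behavioral shift in these client changes: called without an argument, get_stored_schema no longer pins the query to the client's own schema name. Illustrated on a hypothetical SQL job client:

client.get_stored_schema()             # before: newest version of client.schema.name
                                       # after:  newest stored schema of any name
client.get_stored_schema("my_schema")  # after:  newest stored version of "my_schema"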