[FEAT] Add time travel to read_deltalake #3022

Merged: 1 commit, Oct 8, 2024
8 changes: 7 additions & 1 deletion daft/delta_lake/delta_lake_scan.py
@@ -22,12 +22,15 @@

if TYPE_CHECKING:
from collections.abc import Iterator
from datetime import datetime

logger = logging.getLogger(__name__)


class DeltaLakeScanOperator(ScanOperator):
def __init__(self, table_uri: str, storage_config: StorageConfig) -> None:
def __init__(
self, table_uri: str, storage_config: StorageConfig, version: int | str | datetime | None = None
) -> None:
super().__init__()

# Unfortunately delta-rs doesn't do very good inference of credentials for S3. Thus the current Daft behavior of passing
@@ -67,6 +70,9 @@
table_uri, storage_options=io_config_to_storage_options(deltalake_sdk_io_config, table_uri)
)

if version is not None:
self._table.load_as_version(version)

self._storage_config = storage_config
self._schema = Schema.from_pyarrow_schema(self._table.schema().to_pyarrow())
partition_columns = set(self._table.metadata().partition_columns)
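For context, the time travel itself is delegated to delta-rs: DeltaTable.load_as_version accepts an integer version or a timestamp given as a string or datetime, which is why the new parameter is simply passed through. A minimal sketch of that primitive, assuming a Delta table already exists at a hypothetical local path:

# Not part of the diff; illustrates the delta-rs call the scan operator now makes.
from datetime import datetime, timezone
from deltalake import DeltaTable

table = DeltaTable("/tmp/some_table")  # hypothetical table URI
table.load_as_version(0)  # pin to an integer version
table.load_as_version("2024-10-01T00:00:00Z")  # or an RFC 3339 timestamp string
table.load_as_version(datetime(2024, 10, 1, tzinfo=timezone.utc))  # or a datetime (naive datetimes treated as UTC)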
12 changes: 9 additions & 3 deletions daft/io/_deltalake.py
@@ -11,12 +11,15 @@
from daft.logical.builder import LogicalPlanBuilder

if TYPE_CHECKING:
from datetime import datetime

from daft.unity_catalog import UnityCatalogTable


@PublicAPI
def read_deltalake(
table: Union[str, DataCatalogTable, "UnityCatalogTable"],
version: Optional[Union[int, str, "datetime"]] = None,
io_config: Optional["IOConfig"] = None,
_multithreaded_io: Optional[bool] = None,
) -> DataFrame:
@@ -37,8 +40,11 @@
Args:
table: Either a URI for the Delta Lake table or a :class:`~daft.io.catalog.DataCatalogTable` instance
referencing a table in a data catalog, such as AWS Glue Data Catalog or Databricks Unity Catalog.
io_config: A custom :class:`~daft.daft.IOConfig` to use when accessing Delta Lake object storage data. Defaults to None.
_multithreaded_io: Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing
version (optional): If an int is passed, read the table at that version number. If a string or datetime is passed,
read the table as of that timestamp. Strings must be in RFC 3339 / ISO 8601 date-time format.
Datetimes are assumed to be in UTC unless a timezone is specified. By default, reads the latest version of the table.
io_config (optional): A custom :class:`~daft.daft.IOConfig` to use when accessing Delta Lake object storage data. Defaults to None.
_multithreaded_io (optional): Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing
the amount of system resources (number of connections and thread contention) when running in the Ray runner.
Defaults to None, which will let Daft decide based on the runner it is currently using.

@@ -69,7 +75,7 @@
raise ValueError(
f"table argument must be a table URI string, DataCatalogTable or UnityCatalogTable instance, but got: {type(table)}, {table}"
)
delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config)
delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config, version=version)

handle = ScanOperatorHandle.from_python_scan_operator(delta_lake_operator)
builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
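Taken together, the public API change can be exercised as in the hedged sketch below (the table URI is hypothetical):

import daft

# Latest version of the table (default behavior, unchanged).
df_latest = daft.read_deltalake("s3://bucket/some_table")

# Time travel to an integer version.
df_v0 = daft.read_deltalake("s3://bucket/some_table", version=0)

# Time travel to a timestamp (RFC 3339 string; naive datetimes are assumed UTC).
df_ts = daft.read_deltalake("s3://bucket/some_table", version="2024-10-01T00:00:00Z")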
22 changes: 22 additions & 0 deletions tests/io/delta_lake/test_table_read.py
@@ -94,3 +94,25 @@ def test_deltalake_read_row_group_splits_with_limit(tmp_path, base_table):
df = df.limit(2)
df.collect()
assert len(df) == 2, "Length of non-materialized data when read through deltalake should be correct"


def test_deltalake_read_versioned(tmp_path, base_table):
deltalake = pytest.importorskip("deltalake")
path = tmp_path / "some_table"
deltalake.write_deltalake(path, base_table)

updated_columns = base_table.columns + [pa.array(["x", "y", "z"])]
updated_column_names = base_table.column_names + ["new_column"]
updated_table = pa.Table.from_arrays(updated_columns, names=updated_column_names)
deltalake.write_deltalake(path, updated_table, mode="overwrite", schema_mode="overwrite")

for version in [None, 1]:
df = daft.read_deltalake(str(path), version=version)
expected_schema = Schema.from_pyarrow_schema(deltalake.DeltaTable(path).schema().to_pyarrow())
assert df.schema() == expected_schema
assert_pyarrow_tables_equal(df.to_arrow(), updated_table)

df = daft.read_deltalake(str(path), version=0)
expected_schema = Schema.from_pyarrow_schema(deltalake.DeltaTable(path, version=0).schema().to_pyarrow())
assert df.schema() == expected_schema
assert_pyarrow_tables_equal(df.to_arrow(), base_table)
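The committed test covers integer versions. A sketch of how a timestamp-based read could be exercised against the same kind of two-version table, not part of this PR and written under the assumption that the two commits receive distinct timestamps:

def test_deltalake_read_versioned_by_timestamp(tmp_path, base_table):  # hypothetical test, not in this diff
    deltalake = pytest.importorskip("deltalake")
    from datetime import datetime, timezone
    import time

    path = tmp_path / "some_table"
    deltalake.write_deltalake(path, base_table)  # version 0

    time.sleep(1)  # ensure the two commits have distinct timestamps
    cutoff = datetime.now(timezone.utc)  # falls between the two commits

    updated_table = pa.Table.from_arrays(
        base_table.columns + [pa.array(["x", "y", "z"])],
        names=base_table.column_names + ["new_column"],
    )
    deltalake.write_deltalake(path, updated_table, mode="overwrite", schema_mode="overwrite")  # version 1

    # Reading as of the cutoff timestamp should resolve to version 0.
    df = daft.read_deltalake(str(path), version=cutoff)
    assert_pyarrow_tables_equal(df.to_arrow(), base_table)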