diff --git a/daft/daft.pyi b/daft/daft.pyi index 6f9362c2a4..dfaea18054 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1,6 +1,6 @@ import builtins from enum import Enum -from typing import Any, Callable +from typing import Any, Callable, TYPE_CHECKING from daft.runners.partitioning import PartitionCacheEntry from daft.execution import physical_plan @@ -8,6 +8,10 @@ from daft.plan_scheduler.physical_plan_scheduler import PartitionT import pyarrow from daft.io.scan import ScanOperator +if TYPE_CHECKING: + from pyiceberg.schema import Schema as IcebergSchema + from pyiceberg.table import TableProperties as IcebergTableProperties + class ImageMode(Enum): """ Supported image modes for Daft's image type. @@ -1239,6 +1243,16 @@ class LogicalPlanBuilder: compression: str | None = None, io_config: IOConfig | None = None, ) -> LogicalPlanBuilder: ... + def iceberg_write( + self, + table_name: str, + table_location: str, + spec_id: int, + iceberg_schema: IcebergSchema, + iceberg_properties: IcebergTableProperties, + catalog_columns: list[str], + io_config: IOConfig | None = None, + ) -> LogicalPlanBuilder: ... def schema(self) -> PySchema: ... def optimize(self) -> LogicalPlanBuilder: ... def to_physical_plan_scheduler(self, cfg: PyDaftExecutionConfig) -> PhysicalPlanScheduler: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index c3a1a6289f..e9a8af32d8 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -4,6 +4,7 @@ # in order to support runtime typechecking across different Python versions. # For technical details, see https://github.com/Eventual-Inc/Daft/pull/630 +import os import pathlib import warnings from dataclasses import dataclass @@ -45,6 +46,7 @@ import pandas as pd import pyarrow as pa import dask + from pyiceberg.table import Table as IcebergTable from daft.logical.schema import Schema @@ -341,8 +343,6 @@ def write_parquet( Files will be written to ``/*`` with randomly generated UUIDs as the file names. - Currently generates a parquet file per partition unless `partition_cols` are used, then the number of files can equal the number of partitions times the number of values of partition col. - .. NOTE:: This call is **blocking** and will execute the DataFrame when called @@ -393,8 +393,6 @@ def write_csv( Files will be written to ``/*`` with randomly generated UUIDs as the file names. - Currently generates a csv file per partition unless `partition_cols` are used, then the number of files can equal the number of partitions times the number of values of partition col. - .. NOTE:: This call is **blocking** and will execute the DataFrame when called @@ -429,6 +427,94 @@ def write_csv( result_df._preview = write_df._preview return result_df + @DataframePublicAPI + def write_iceberg(self, table: "IcebergTable", mode: str = "append") -> "DataFrame": + """Writes the DataFrame to an Iceberg Table, returning a new DataFrame with the operations that occurred. + Can be run in either `append` or `overwrite` mode which will either appends the rows in the DataFrame or will delete the existing rows and then append the DataFrame rows respectively. + + .. NOTE:: + This call is **blocking** and will execute the DataFrame when called + + Args: + table (IcebergTable): Destination Iceberg Table to write dataframe to. + mode (str, optional): Operation mode of the write. `append` or `overwrite` Iceberg Table. Defaults to "append". + + Returns: + DataFrame: The operations that occurred with this write. 
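A usage sketch for the new `DataFrame.write_iceberg` API, modeled on the tests added at the end of this diff. It assumes pyiceberg>=0.6.0 and a local pyiceberg `SqlCatalog`; the catalog paths and the `default.events` table name are illustrative placeholders, not part of this change.

```python
# Minimal sketch of the new write_iceberg API (assumes pyiceberg>=0.6.0).
# The SqlCatalog settings and table name below are illustrative only.
import daft
from pyiceberg.catalog.sql import SqlCatalog

catalog = SqlCatalog(
    "default",
    uri="sqlite:///tmp/pyiceberg_catalog.db",  # assumed local path
    warehouse="file:///tmp/warehouse",         # assumed local path
)
catalog.create_namespace("default")

df = daft.from_pydict({"x": [1, 2, 3, 4, 5]})
table = catalog.create_table("default.events", df.to_arrow().schema)

# Append the rows; the returned DataFrame describes the files that were written.
result = df.write_iceberg(table, mode="append")
result.show()  # columns: operation, rows, file_size, file_name

# Overwrite deletes the existing rows first; previously written files are
# reported back as DELETE operations alongside the new ADDs.
df.write_iceberg(table, mode="overwrite")
```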
+ """ + + if len(table.spec().fields) > 0: + raise ValueError("Cannot write to partitioned Iceberg tables") + + import pyiceberg + from packaging.version import parse + + if parse(pyiceberg.__version__) < parse("0.6.0"): + raise ValueError(f"Write Iceberg is only supported on pyiceberg>=0.6.0, found {pyiceberg.__version__}") + + from pyiceberg.table import _MergingSnapshotProducer + from pyiceberg.table.snapshots import Operation + + operations = [] + path = [] + rows = [] + size = [] + + if mode == "append": + operation = Operation.APPEND + elif mode == "overwrite": + operation = Operation.OVERWRITE + else: + raise ValueError(f"Only support `append` or `overwrite` mode. {mode} is unsupported") + + # We perform the merge here since IcebergTable is not pickle-able + # We should be able to move to a transaction API for iceberg 0.7.0 + merge = _MergingSnapshotProducer(operation=operation, table=table) + + builder = self._builder.write_iceberg(table) + write_df = DataFrame(builder) + write_df.collect() + + write_result = write_df.to_pydict() + assert "data_file" in write_result + data_files = write_result["data_file"] + + if operation == Operation.OVERWRITE: + deleted_files = table.scan().plan_files() + else: + deleted_files = [] + + for data_file in data_files: + merge.append_data_file(data_file) + operations.append("ADD") + path.append(data_file.file_path) + rows.append(data_file.record_count) + size.append(data_file.file_size_in_bytes) + + for pf in deleted_files: + data_file = pf.file + operations.append("DELETE") + path.append(data_file.file_path) + rows.append(data_file.record_count) + size.append(data_file.file_size_in_bytes) + + merge.commit() + import pyarrow as pa + + from daft import from_pydict + + with_operations = from_pydict( + { + "operation": pa.array(operations, type=pa.string()), + "rows": pa.array(rows, type=pa.int64()), + "file_size": pa.array(size, type=pa.int64()), + "file_name": pa.array([os.path.basename(fp) for fp in path], type=pa.string()), + } + ) + # NOTE: We are losing the history of the plan here. + # This is due to the fact that the logical plan of the write_iceberg returns datafiles but we want to return the above data + return with_operations + ### # DataFrame operations ### diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index a70ab0a1da..5f0c9b021a 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -4,7 +4,7 @@ import pathlib import sys from dataclasses import dataclass, field -from typing import Generic +from typing import TYPE_CHECKING, Generic if sys.version_info < (3, 8): from typing_extensions import Protocol @@ -24,6 +24,11 @@ ) from daft.table import MicroPartition, table_io +if TYPE_CHECKING: + from pyiceberg.schema import Schema as IcebergSchema + from pyiceberg.table import TableProperties as IcebergTableProperties + + ID_GEN = itertools.count() @@ -354,7 +359,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) assert len(input_metadatas) == 1 return [ PartialPartitionMetadata( - num_rows=1, # We currently write one file per partition. 
+ num_rows=None, # we can write more than 1 file per partition size_bytes=None, ) ] @@ -371,6 +376,44 @@ def _handle_file_write(self, input: MicroPartition) -> MicroPartition: ) +@dataclass(frozen=True) +class WriteIceberg(SingleOutputInstruction): + base_path: str + iceberg_schema: IcebergSchema + iceberg_properties: IcebergTableProperties + spec_id: int + io_config: IOConfig | None + + def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + return self._write_iceberg(inputs) + + def _write_iceberg(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + [input] = inputs + partition = self._handle_file_write( + input=input, + ) + return [partition] + + def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: + assert len(input_metadatas) == 1 + return [ + PartialPartitionMetadata( + num_rows=None, # we can write more than 1 file per partition + size_bytes=None, + ) + ] + + def _handle_file_write(self, input: MicroPartition) -> MicroPartition: + return table_io.write_iceberg( + input, + base_path=self.base_path, + schema=self.iceberg_schema, + properties=self.iceberg_properties, + spec_id=self.spec_id, + io_config=self.io_config, + ) + + @dataclass(frozen=True) class Filter(SingleOutputInstruction): predicate: ExpressionsProjection diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 6989915e77..12f0e823f8 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -19,7 +19,7 @@ import math import pathlib from collections import deque -from typing import Generator, Generic, Iterable, Iterator, TypeVar, Union +from typing import TYPE_CHECKING, Generator, Generic, Iterable, Iterator, TypeVar, Union from daft.context import get_context from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest @@ -45,6 +45,10 @@ T = TypeVar("T") +if TYPE_CHECKING: + from pyiceberg.schema import Schema as IcebergSchema + from pyiceberg.table import TableProperties as IcebergTableProperties + # A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks. 
InProgressPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionTaskBuilder[PartitionT]]] @@ -105,6 +109,32 @@ def file_write( ) +def iceberg_write( + child_plan: InProgressPhysicalPlan[PartitionT], + base_path: str, + iceberg_schema: IcebergSchema, + iceberg_properties: IcebergTableProperties, + spec_id: int, + io_config: IOConfig | None, +) -> InProgressPhysicalPlan[PartitionT]: + """Write the results of `child_plan` into pyiceberg data files described by `write_info`.""" + + yield from ( + step.add_instruction( + execution_step.WriteIceberg( + base_path=base_path, + iceberg_schema=iceberg_schema, + iceberg_properties=iceberg_properties, + spec_id=spec_id, + io_config=io_config, + ), + ) + if isinstance(step, PartitionTaskBuilder) + else step + for step in child_plan + ) + + def pipeline_instruction( child_plan: InProgressPhysicalPlan[PartitionT], pipeable_instruction: Instruction, diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index a0b1b8de8f..ce3de78b4b 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from daft.daft import ( FileFormat, IOConfig, @@ -16,6 +18,10 @@ from daft.runners.partitioning import PartitionT from daft.table import MicroPartition +if TYPE_CHECKING: + from pyiceberg.schema import Schema as IcebergSchema + from pyiceberg.table import TableProperties as IcebergTableProperties + def scan_with_tasks( scan_tasks: list[ScanTask], @@ -253,3 +259,21 @@ def write_file( expr_projection, io_config, ) + + +def write_iceberg( + input: physical_plan.InProgressPhysicalPlan[PartitionT], + base_path: str, + iceberg_schema: IcebergSchema, + iceberg_properties: IcebergTableProperties, + spec_id: int, + io_config: IOConfig | None, +) -> physical_plan.InProgressPhysicalPlan[PartitionT]: + return physical_plan.iceberg_write( + input, + base_path=base_path, + iceberg_schema=iceberg_schema, + iceberg_properties=iceberg_properties, + spec_id=spec_id, + io_config=io_config, + ) diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 18f147a389..a2c317f531 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -11,6 +11,8 @@ from daft.runners.partitioning import PartitionCacheEntry if TYPE_CHECKING: + from pyiceberg.table import Table as IcebergTable + from daft.plan_scheduler.physical_plan_scheduler import PhysicalPlanScheduler @@ -193,3 +195,16 @@ def write_tabular( part_cols_pyexprs = [expr._expr for expr in partition_cols] if partition_cols is not None else None builder = self._builder.table_write(str(root_dir), file_format, part_cols_pyexprs, compression, io_config) return LogicalPlanBuilder(builder) + + def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder: + from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config + + name = ".".join(table.name()) + location = f"{table.location()}/data" + spec_id = table.spec().spec_id + schema = table.schema() + props = table.properties + columns = [col.name for col in schema.columns] + io_config = _convert_iceberg_file_io_properties_to_io_config(table.io.properties) + builder = self._builder.iceberg_write(name, location, spec_id, schema, props, columns, io_config) + return LogicalPlanBuilder(builder) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 5eb0fcb332..9fc882c165 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -4,7 +4,7 @@ import math 
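For reviewers unfamiliar with the pyiceberg surface that `LogicalPlanBuilder.write_iceberg` (in `daft/logical/builder.py` above) leans on, here is a small sketch of the `Table` accessors it reads before handing everything to the Rust builder; `tbl` is assumed to be any unpartitioned pyiceberg `Table` loaded from a catalog.

```python
# Sketch of the pyiceberg Table accessors consumed by builder.write_iceberg;
# `tbl` is assumed to be a pyiceberg.table.Table obtained from some catalog.
def describe_iceberg_sink(tbl):
    name = ".".join(tbl.name())          # identifier tuple -> "namespace.table"
    location = f"{tbl.location()}/data"  # data files are written under /data
    spec_id = tbl.spec().spec_id         # current partition spec (must be unpartitioned)
    schema = tbl.schema()                # pyiceberg Schema, used for field IDs and stats
    props = tbl.properties               # table properties dict
    columns = [col.name for col in schema.columns]
    fileio_props = tbl.io.properties     # translated into a Daft IOConfig
    return name, location, spec_id, schema, props, columns, fileio_props
```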
import pathlib from collections.abc import Callable, Generator -from typing import IO, Any, Union +from typing import IO, TYPE_CHECKING, Any, Union from uuid import uuid4 import pyarrow as pa @@ -47,6 +47,10 @@ FileInput = Union[pathlib.Path, str, IO[bytes]] +if TYPE_CHECKING: + from pyiceberg.schema import Schema as IcebergSchema + from pyiceberg.table import TableProperties as IcebergTableProperties + @contextlib.contextmanager def _open_stream( @@ -410,7 +414,7 @@ def write_tabular( partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__", ) -> MicroPartition: - from daft.utils import ARROW_VERSION + pass [resolved_path], fs = _resolve_paths_and_filesystem(path, io_config=io_config) if isinstance(path, pathlib.Path): @@ -492,28 +496,17 @@ def file_visitor(written_file, i=i): visited_paths.append(written_file.path) partition_idx.append(i) - kwargs = dict() - - if ARROW_VERSION >= (7, 0, 0): - kwargs["max_rows_per_file"] = rows_per_file - kwargs["min_rows_per_group"] = rows_per_row_group - kwargs["max_rows_per_group"] = rows_per_row_group - - if ARROW_VERSION >= (8, 0, 0) and not is_local_fs: - kwargs["create_dir"] = False - - pads.write_dataset( - arrow_table, - base_dir=full_path, - basename_template=str(uuid4()) + "-{i}." + format.default_extname, + _write_tabular_arrow_table( + arrow_table=arrow_table, + schema=arrow_table.schema, + full_path=full_path, format=format, - partitioning=None, - file_options=opts, + opts=opts, + fs=fs, + rows_per_file=rows_per_file, + rows_per_row_group=rows_per_row_group, + create_dir=is_local_fs, file_visitor=file_visitor, - use_threads=True, - existing_data_behavior="overwrite_or_ignore", - filesystem=fs, - **kwargs, ) data_dict: dict[str, Any] = { @@ -527,3 +520,184 @@ def file_visitor(written_file, i=i): for c_name in partition_values.column_names(): data_dict[c_name] = partition_values.get_column(c_name).take(partition_idx_series) return MicroPartition.from_pydict(data_dict) + + +def coerce_pyarrow_table_to_schema(pa_table: pa.Table, input_schema: pa.Schema) -> pa.Table: + """Coerces a PyArrow table to the supplied schema + + 1. For each field in `pa_table`, cast it to the field in `input_schema` if one with a matching name + is available + 2. Reorder the fields in the casted table to the supplied schema, dropping any fields in `pa_table` + that do not exist in the supplied schema + 3. 
If any fields in the supplied schema are not present, add a null array of the correct type + + Args: + pa_table (pa.Table): Table to coerce + input_schema (pa.Schema): Schema to coerce to + + Returns: + pa.Table: Table with schema == `input_schema` + """ + input_schema_names = set(input_schema.names) + + # Perform casting of types to provided schema's types + cast_to_schema = [ + input_schema.field(inferred_field.name) if inferred_field.name in input_schema_names else inferred_field + for inferred_field in pa_table.schema + ] + casted_table = pa_table.cast(pa.schema(cast_to_schema)) + + # Reorder and pad columns with a null column where necessary + pa_table_column_names = set(casted_table.column_names) + columns = [] + for name in input_schema.names: + if name in pa_table_column_names: + columns.append(casted_table[name]) + else: + columns.append(pa.nulls(len(casted_table), type=input_schema.field(name).type)) + return pa.table(columns, schema=input_schema) + + +def write_iceberg( + mp: MicroPartition, + base_path: str, + schema: IcebergSchema, + properties: IcebergTableProperties, + spec_id: int | None, + io_config: IOConfig | None = None, +): + + from pyiceberg.io.pyarrow import ( + compute_statistics_plan, + fill_parquet_file_metadata, + parquet_path_to_id_mapping, + schema_to_pyarrow, + ) + from pyiceberg.manifest import DataFile, DataFileContent + from pyiceberg.manifest import FileFormat as IcebergFileFormat + from pyiceberg.typedef import Record + + [resolved_path], fs = _resolve_paths_and_filesystem(base_path, io_config=io_config) + if isinstance(base_path, pathlib.Path): + path_str = str(base_path) + else: + path_str = base_path + + protocol = get_protocol_from_path(path_str) + canonicalized_protocol = canonicalize_protocol(protocol) + + data_files = [] + + def file_visitor(written_file, protocol=protocol): + + file_path = f"{protocol}://{written_file.path}" + size = written_file.size + metadata = written_file.metadata + # TODO Version guard pyarrow version + data_file = DataFile( + content=DataFileContent.DATA, + file_path=file_path, + file_format=IcebergFileFormat.PARQUET, + partition=Record(), + file_size_in_bytes=size, + # After this has been fixed: + # https://github.com/apache/iceberg-python/issues/271 + # sort_order_id=task.sort_order_id, + sort_order_id=None, + # Just copy these from the table for now + spec_id=spec_id, + equality_ids=None, + key_metadata=None, + ) + fill_parquet_file_metadata( + data_file=data_file, + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), + ) + data_files.append(data_file) + + is_local_fs = canonicalized_protocol == "file" + + execution_config = get_context().daft_execution_config + inflation_factor = execution_config.parquet_inflation_factor + + # TODO: these should be populate by `properties` but pyiceberg doesn't support them yet + target_file_size = 512 * 1024 * 1024 + TARGET_ROW_GROUP_SIZE = 128 * 1024 * 1024 + + arrow_table = mp.to_arrow() + + file_schema = schema_to_pyarrow(schema) + + # This ensures that we populate field_id for iceberg as well as fill in null values where needed + # This might break for nested fields with large_strings + # we should test that behavior + arrow_table = coerce_pyarrow_table_to_schema(arrow_table, file_schema) + + size_bytes = arrow_table.nbytes + + target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1) + num_rows = len(arrow_table) + + rows_per_file = max(math.ceil(num_rows / 
target_num_files), 1) + + target_row_groups = max(math.ceil(size_bytes / TARGET_ROW_GROUP_SIZE / inflation_factor), 1) + rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) + + format = pads.ParquetFileFormat() + + _write_tabular_arrow_table( + arrow_table=arrow_table, + schema=file_schema, + full_path=resolved_path, + format=format, + opts=format.make_write_options(compression="zstd"), + fs=fs, + rows_per_file=rows_per_file, + rows_per_row_group=rows_per_row_group, + create_dir=is_local_fs, + file_visitor=file_visitor, + ) + + return MicroPartition.from_pydict({"data_file": Series.from_pylist(data_files, name="data_file", pyobj="force")}) + + +def _write_tabular_arrow_table( + arrow_table: pa.Table, + schema: pa.Schema | None, + full_path: str, + format: pads.FileFormat, + opts: pads.FileWriteOptions | None, + fs: Any, + rows_per_file: int, + rows_per_row_group: int, + create_dir: bool, + file_visitor: Callable | None, +): + kwargs = dict() + + from daft.utils import ARROW_VERSION + + if ARROW_VERSION >= (7, 0, 0): + kwargs["max_rows_per_file"] = rows_per_file + kwargs["min_rows_per_group"] = rows_per_row_group + kwargs["max_rows_per_group"] = rows_per_row_group + + if ARROW_VERSION >= (8, 0, 0) and not create_dir: + kwargs["create_dir"] = False + + pads.write_dataset( + arrow_table, + schema=schema, + base_dir=full_path, + basename_template=str(uuid4()) + "-{i}." + format.default_extname, + format=format, + partitioning=None, + file_options=opts, + file_visitor=file_visitor, + use_threads=True, + existing_data_behavior="overwrite_or_ignore", + filesystem=fs, + **kwargs, + ) diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 8c9c242001..e10eb057da 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -21,6 +21,7 @@ use daft_scan::{file_format::FileFormat, Pushdowns, ScanExternalInfo, ScanOperat #[cfg(feature = "python")] use { + crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, crate::{physical_plan::PhysicalPlanRef, source_info::InMemoryInfo}, common_daft_config::PyDaftExecutionConfig, daft_core::python::schema::PySchema, @@ -281,6 +282,35 @@ impl LogicalPlanBuilder { Ok(logical_plan.into()) } + #[cfg(feature = "python")] + #[allow(clippy::too_many_arguments)] + pub fn iceberg_write( + &self, + table_name: String, + table_location: String, + spec_id: i64, + iceberg_schema: PyObject, + iceberg_properties: PyObject, + io_config: Option, + catalog_columns: Vec, + ) -> DaftResult { + let sink_info = SinkInfo::CatalogInfo(CatalogInfo { + catalog: crate::sink_info::CatalogType::Iceberg(IcebergCatalogInfo { + table_name, + table_location, + spec_id, + iceberg_schema, + iceberg_properties, + io_config, + }), + catalog_columns, + }); + + let logical_plan: LogicalPlan = + logical_ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into(); + Ok(logical_plan.into()) + } + pub fn build(&self) -> Arc { self.plan.clone() } @@ -486,6 +516,31 @@ impl PyLogicalPlanBuilder { .into()) } + #[allow(clippy::too_many_arguments)] + pub fn iceberg_write( + &self, + table_name: String, + table_location: String, + spec_id: i64, + iceberg_schema: PyObject, + iceberg_properties: PyObject, + catalog_columns: Vec, + io_config: Option, + ) -> PyResult { + Ok(self + .builder + .iceberg_write( + table_name, + table_location, + spec_id, + iceberg_schema, + iceberg_properties, + io_config.map(|cfg| cfg.config), + catalog_columns, + )? 
+ .into()) + } + pub fn schema(&self) -> PyResult { Ok(self.builder.schema().into()) } diff --git a/src/daft-plan/src/logical_ops/sink.rs b/src/daft-plan/src/logical_ops/sink.rs index 287689d1da..af38c690d4 100644 --- a/src/daft-plan/src/logical_ops/sink.rs +++ b/src/daft-plan/src/logical_ops/sink.rs @@ -19,18 +19,26 @@ pub struct Sink { impl Sink { pub(crate) fn try_new(input: Arc, sink_info: Arc) -> DaftResult { - let mut fields = vec![Field::new("path", daft_core::DataType::Utf8)]; let schema = input.schema(); - match sink_info.as_ref() { + let fields = match sink_info.as_ref() { SinkInfo::OutputFileInfo(output_file_info) => { + let mut fields = vec![Field::new("path", daft_core::DataType::Utf8)]; if let Some(ref pcols) = output_file_info.partition_cols { for pc in pcols { fields.push(pc.to_field(&schema)?); } } + fields } - } + #[cfg(feature = "python")] + SinkInfo::CatalogInfo(..) => { + vec![ + // We have to return datafile since PyIceberg Table is not picklable yet + Field::new("data_file", daft_core::DataType::Python), + ] + } + }; let schema = Schema::new(fields)?.into(); Ok(Self { input, @@ -47,6 +55,13 @@ impl Sink { res.push(format!("Sink: {:?}", output_file_info.file_format)); res.extend(output_file_info.multiline_display()); } + #[cfg(feature = "python")] + SinkInfo::CatalogInfo(catalog_info) => match &catalog_info.catalog { + crate::sink_info::CatalogType::Iceberg(iceberg_info) => { + res.push(format!("Sink: Iceberg({})", iceberg_info.table_name)); + res.extend(iceberg_info.multiline_display()); + } + }, } res.push(format!("Output schema = {}", self.schema.short_string())); res diff --git a/src/daft-plan/src/physical_ops/iceberg_write.rs b/src/daft-plan/src/physical_ops/iceberg_write.rs new file mode 100644 index 0000000000..f425fd3c9f --- /dev/null +++ b/src/daft-plan/src/physical_ops/iceberg_write.rs @@ -0,0 +1,35 @@ +use daft_core::schema::SchemaRef; + +use crate::{physical_plan::PhysicalPlanRef, sink_info::IcebergCatalogInfo}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] + +pub struct IcebergWrite { + pub schema: SchemaRef, + pub iceberg_info: IcebergCatalogInfo, + // Upstream node. 
+ pub input: PhysicalPlanRef, +} + +impl IcebergWrite { + pub(crate) fn new( + schema: SchemaRef, + iceberg_info: IcebergCatalogInfo, + input: PhysicalPlanRef, + ) -> Self { + Self { + schema, + iceberg_info, + input, + } + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + res.push("IcebergWrite:".to_string()); + res.push(format!("Schema = {}", self.schema.short_string())); + res.extend(self.iceberg_info.multiline_display()); + res + } +} diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index 346462f796..c6ed657a6f 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -10,6 +10,8 @@ mod filter; mod flatten; mod hash_join; #[cfg(feature = "python")] +mod iceberg_write; +#[cfg(feature = "python")] mod in_memory; mod json; mod limit; @@ -35,6 +37,8 @@ pub use filter::Filter; pub use flatten::Flatten; pub use hash_join::HashJoin; #[cfg(feature = "python")] +pub use iceberg_write::IcebergWrite; +#[cfg(feature = "python")] pub use in_memory::InMemoryScan; pub use json::TabularWriteJson; pub use limit::Limit; diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index fd2c39a553..4e7dc166a8 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -27,6 +27,9 @@ use crate::{ physical_ops::*, }; +#[cfg(feature = "python")] +use crate::sink_info::IcebergCatalogInfo; + pub(crate) type PhysicalPlanRef = Arc; /// Physical plan for a Daft query. @@ -59,6 +62,8 @@ pub enum PhysicalPlan { TabularWriteParquet(TabularWriteParquet), TabularWriteJson(TabularWriteJson), TabularWriteCsv(TabularWriteCsv), + #[cfg(feature = "python")] + IcebergWrite(IcebergWrite), } impl PhysicalPlan { @@ -197,6 +202,10 @@ impl PhysicalPlan { Self::TabularWriteParquet(TabularWriteParquet { input, .. }) => input.clustering_spec(), Self::TabularWriteCsv(TabularWriteCsv { input, .. }) => input.clustering_spec(), Self::TabularWriteJson(TabularWriteJson { input, .. }) => input.clustering_spec(), + #[cfg(feature = "python")] + Self::IcebergWrite(..) => { + ClusteringSpec::Unknown(UnknownClusteringConfig::new(1)).into() + } } } @@ -265,6 +274,8 @@ impl PhysicalPlan { Self::TabularWriteParquet(_) | Self::TabularWriteCsv(_) | Self::TabularWriteJson(_) => { None } + #[cfg(feature = "python")] + Self::IcebergWrite(_) => None, } } @@ -290,6 +301,8 @@ impl PhysicalPlan { Self::TabularWriteParquet(TabularWriteParquet { input, .. }) => vec![input], Self::TabularWriteCsv(TabularWriteCsv { input, .. }) => vec![input], Self::TabularWriteJson(TabularWriteJson { input, .. }) => vec![input], + #[cfg(feature = "python")] + Self::IcebergWrite(IcebergWrite { input, .. }) => vec![input], Self::HashJoin(HashJoin { left, right, .. }) => vec![left, right], Self::BroadcastJoin(BroadcastJoin { broadcaster, @@ -328,6 +341,8 @@ impl PhysicalPlan { Self::TabularWriteParquet(TabularWriteParquet { schema, file_info, .. }) => Self::TabularWriteParquet(TabularWriteParquet::new(schema.clone(), file_info.clone(), input.clone())), Self::TabularWriteCsv(TabularWriteCsv { schema, file_info, .. }) => Self::TabularWriteCsv(TabularWriteCsv::new(schema.clone(), file_info.clone(), input.clone())), Self::TabularWriteJson(TabularWriteJson { schema, file_info, .. }) => Self::TabularWriteJson(TabularWriteJson::new(schema.clone(), file_info.clone(), input.clone())), + #[cfg(feature = "python")] + Self::IcebergWrite(IcebergWrite { schema, iceberg_info, .. 
}) => Self::IcebergWrite(IcebergWrite::new(schema.clone(), iceberg_info.clone(), input.clone())), _ => panic!("Physical op {:?} has two inputs, but got one", self), }, [input1, input2] => match self { @@ -379,6 +394,8 @@ impl PhysicalPlan { Self::TabularWriteCsv(..) => "TabularWriteCsv", Self::TabularWriteJson(..) => "TabularWriteJson", Self::MonotonicallyIncreasingId(..) => "MonotonicallyIncreasingId", + #[cfg(feature = "python")] + Self::IcebergWrite(..) => "IcebergWrite", }; name.to_string() } @@ -415,6 +432,8 @@ impl PhysicalPlan { Self::MonotonicallyIncreasingId(monotonically_increasing_id) => { monotonically_increasing_id.multiline_display() } + #[cfg(feature = "python")] + Self::IcebergWrite(iceberg_info) => iceberg_info.multiline_display(), } } @@ -518,6 +537,32 @@ fn tabular_write( Ok(py_iter.into()) } +#[allow(clippy::too_many_arguments)] +#[cfg(feature = "python")] +fn iceberg_write( + py: Python<'_>, + upstream_iter: PyObject, + iceberg_info: &IcebergCatalogInfo, +) -> PyResult { + let py_iter = py + .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? + .getattr(pyo3::intern!(py, "write_iceberg"))? + .call1(( + upstream_iter, + &iceberg_info.table_location, + &iceberg_info.iceberg_schema, + &iceberg_info.iceberg_properties, + iceberg_info.spec_id, + iceberg_info + .io_config + .as_ref() + .map(|cfg| common_io_config::python::IOConfig { + config: cfg.clone(), + }), + ))?; + Ok(py_iter.into()) +} + #[cfg(feature = "python")] impl PhysicalPlan { pub fn to_partition_tasks( @@ -953,6 +998,12 @@ impl PhysicalPlan { partition_cols, io_config, ), + #[cfg(feature = "python")] + PhysicalPlan::IcebergWrite(IcebergWrite { + schema: _, + iceberg_info, + input, + }) => iceberg_write(py, input.to_partition_tasks(py, psets)?, iceberg_info), } } } diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 9c7b667bd9..c691e78e7c 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -720,6 +720,16 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe )), } } + #[cfg(feature = "python")] + SinkInfo::CatalogInfo(catalog_info) => match &catalog_info.catalog { + crate::sink_info::CatalogType::Iceberg(iceberg_info) => { + Ok(PhysicalPlan::IcebergWrite(IcebergWrite::new( + schema.clone(), + iceberg_info.clone(), + input_physical.into(), + ))) + } + }, } } LogicalPlan::MonotonicallyIncreasingId(LogicalMonotonicallyIncreasingId { diff --git a/src/daft-plan/src/sink_info.rs b/src/daft-plan/src/sink_info.rs index 3009f6edb7..8647628793 100644 --- a/src/daft-plan/src/sink_info.rs +++ b/src/daft-plan/src/sink_info.rs @@ -1,13 +1,24 @@ +use std::hash::Hash; + use common_io_config::IOConfig; use daft_dsl::Expr; use itertools::Itertools; +#[cfg(feature = "python")] +use pyo3::PyObject; + use crate::FileFormat; use serde::{Deserialize, Serialize}; +#[cfg(feature = "python")] +use daft_scan::py_object_serde::{deserialize_py_object, serialize_py_object}; + +#[allow(clippy::large_enum_variant)] #[derive(Debug, PartialEq, Eq, Hash)] pub enum SinkInfo { OutputFileInfo(OutputFileInfo), + #[cfg(feature = "python")] + CatalogInfo(CatalogInfo), } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -19,6 +30,76 @@ pub struct OutputFileInfo { pub io_config: Option, } +#[cfg(feature = "python")] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct CatalogInfo { + pub catalog: CatalogType, + pub catalog_columns: Vec, +} + +#[cfg(feature = "python")] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, 
Deserialize)] +pub enum CatalogType { + Iceberg(IcebergCatalogInfo), +} + +#[cfg(feature = "python")] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IcebergCatalogInfo { + pub table_name: String, + pub table_location: String, + pub spec_id: i64, + #[serde( + serialize_with = "serialize_py_object", + deserialize_with = "deserialize_py_object" + )] + pub iceberg_schema: PyObject, + + #[serde( + serialize_with = "serialize_py_object", + deserialize_with = "deserialize_py_object" + )] + pub iceberg_properties: PyObject, + pub io_config: Option, +} + +#[cfg(feature = "python")] +impl PartialEq for IcebergCatalogInfo { + fn eq(&self, other: &Self) -> bool { + self.table_name == other.table_name + && self.table_location == other.table_location + && self.spec_id == other.spec_id + && self.io_config == other.io_config + } +} +#[cfg(feature = "python")] +impl Eq for IcebergCatalogInfo {} + +#[cfg(feature = "python")] +impl Hash for IcebergCatalogInfo { + fn hash(&self, state: &mut H) { + self.table_name.hash(state); + self.table_location.hash(state); + self.spec_id.hash(state); + self.io_config.hash(state); + } +} + +#[cfg(feature = "python")] +impl IcebergCatalogInfo { + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + res.push(format!("Table Name = {}", self.table_name)); + res.push(format!("Table Location = {}", self.table_location)); + res.push(format!("Spec ID = {}", self.spec_id)); + match &self.io_config { + None => res.push("IOConfig = None".to_string()), + Some(io_config) => res.push(format!("IOConfig = {}", io_config)), + }; + res + } +} + impl OutputFileInfo { pub fn new( root_dir: String, diff --git a/tests/integration/iceberg/test_pyiceberg_written_table_load.py b/tests/integration/iceberg/test_pyiceberg_written_table_load.py index c2af50f165..e7d1f982c3 100644 --- a/tests/integration/iceberg/test_pyiceberg_written_table_load.py +++ b/tests/integration/iceberg/test_pyiceberg_written_table_load.py @@ -28,10 +28,38 @@ def table_written_by_pyiceberg(local_iceberg_catalog): local_iceberg_catalog.drop_table("pyiceberg.map_table") +@contextlib.contextmanager +def table_written_by_daft(local_iceberg_catalog): + schema = pa.schema([("col", pa.int64()), ("mapCol", pa.map_(pa.int32(), pa.string()))]) + + data = {"col": [1, 2, 3], "mapCol": [[(1, "foo"), (2, "bar")], [(3, "baz")], [(4, "foobar")]]} + arrow_table = pa.Table.from_pydict(data, schema=schema) + try: + table = local_iceberg_catalog.create_table("pyiceberg.map_table", schema=schema) + df = daft.from_arrow(arrow_table) + df.write_iceberg(table, mode="overwrite") + table.refresh() + yield table + except Exception as e: + raise e + finally: + local_iceberg_catalog.drop_table("pyiceberg.map_table") + + @pytest.mark.integration() -def test_localdb_catalog(local_iceberg_catalog): +def test_pyiceberg_written_catalog(local_iceberg_catalog): with table_written_by_pyiceberg(local_iceberg_catalog) as catalog_table: df = daft.read_iceberg(catalog_table) daft_pandas = df.to_pandas() iceberg_pandas = catalog_table.scan().to_arrow().to_pandas() assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) + + +@pytest.mark.integration() +@pytest.mark.skip +def test_daft_written_catalog(local_iceberg_catalog): + with table_written_by_daft(local_iceberg_catalog) as catalog_table: + df = daft.read_iceberg(catalog_table) + daft_pandas = df.to_pandas() + iceberg_pandas = catalog_table.scan().to_arrow().to_pandas() + assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) diff --git a/tests/io/iceberg/__init__.py 
b/tests/io/iceberg/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/io/iceberg/test_iceberg_writes.py b/tests/io/iceberg/test_iceberg_writes.py new file mode 100644 index 0000000000..a731be4119 --- /dev/null +++ b/tests/io/iceberg/test_iceberg_writes.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import pyarrow as pa +import pytest + +pyiceberg = pytest.importorskip("pyiceberg") + +from pyiceberg.catalog.sql import SqlCatalog + +import daft + + +@pytest.fixture(scope="function") +def local_catalog(tmpdir): + catalog = SqlCatalog( + "default", + **{ + "uri": f"sqlite:///{tmpdir}/pyiceberg_catalog.db", + "warehouse": f"file://{tmpdir}", + }, + ) + catalog.create_namespace("default") + return catalog + + +def test_read_after_write_append(local_catalog): + df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) + as_arrow = df.to_arrow() + table = local_catalog.create_table("default.test", as_arrow.schema) + result = df.write_iceberg(table) + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD"] + assert as_dict["rows"] == [5] + read_back = daft.read_iceberg(table) + assert as_arrow == read_back.to_arrow() + + +def test_read_after_write_overwrite(local_catalog): + df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) + as_arrow = df.to_arrow() + table = local_catalog.create_table("default.test", as_arrow.schema) + result = df.write_iceberg(table) + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD"] + assert as_dict["rows"] == [5] + + # write again (in append) + result = df.write_iceberg(table) + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD"] + assert as_dict["rows"] == [5] + + read_back = daft.read_iceberg(table) + assert pa.concat_tables([as_arrow, as_arrow]) == read_back.to_arrow() + + # write again (in overwrite) + result = df.write_iceberg(table, mode="overwrite") + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD", "DELETE", "DELETE"] + assert as_dict["rows"] == [5, 5, 5] + + read_back = daft.read_iceberg(table) + assert as_arrow == read_back.to_arrow() + + +def test_read_and_overwrite(local_catalog): + df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) + as_arrow = df.to_arrow() + table = local_catalog.create_table("default.test", as_arrow.schema) + result = df.write_iceberg(table) + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD"] + assert as_dict["rows"] == [5] + + df = daft.read_iceberg(table).with_column("x", daft.col("x") + 1) + result = df.write_iceberg(table, mode="overwrite") + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD", "DELETE"] + assert as_dict["rows"] == [5, 5] + + read_back = daft.read_iceberg(table) + assert daft.from_pydict({"x": [2, 3, 4, 5, 6]}).to_arrow() == read_back.to_arrow() + + +def test_missing_columns_write(local_catalog): + df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) + as_arrow = df.to_arrow() + table = local_catalog.create_table("default.test", as_arrow.schema) + + df = daft.from_pydict({"y": [1, 2, 3, 4, 5]}) + result = df.write_iceberg(table) + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD"] + assert as_dict["rows"] == [5] + read_back = daft.read_iceberg(table) + assert read_back.to_pydict() == {"x": [None] * 5} + + +def test_too_many_columns_write(local_catalog): + df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) + as_arrow = df.to_arrow() + table = local_catalog.create_table("default.test", as_arrow.schema) + + df = daft.from_pydict({"x": [1, 2, 3, 4, 5], "y": [6, 7, 8, 9, 10]}) + result = 
df.write_iceberg(table) + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD"] + assert as_dict["rows"] == [5] + read_back = daft.read_iceberg(table) + assert as_arrow == read_back.to_arrow() + + +@pytest.mark.skip +def test_read_after_write_nested_fields(local_catalog): + # We need to cast Large Types such as LargeList and LargeString to the i32 variants + df = daft.from_pydict({"x": [["a", "b"], ["c", "d", "e"]]}) + as_arrow = df.to_arrow() + table = local_catalog.create_table("default.test", as_arrow.schema) + result = df.write_iceberg(table) + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD"] + assert as_dict["rows"] == [2] + read_back = daft.read_iceberg(table) + assert as_arrow == read_back.to_arrow()
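The `test_missing_columns_write` and `test_too_many_columns_write` cases above exercise the `coerce_pyarrow_table_to_schema` helper added in `daft/table/table_io.py`. The snippet below is a standalone, simplified sketch of that cast/reorder/null-pad behavior using plain pyarrow (the `coerce` name is just shorthand here), so the expected test outputs can be reproduced without a catalog.

```python
# Standalone sketch of the cast / reorder / null-pad coercion performed by
# coerce_pyarrow_table_to_schema in daft/table/table_io.py (plain pyarrow only).
import pyarrow as pa

def coerce(tbl: pa.Table, target: pa.Schema) -> pa.Table:
    # 1. Cast any column whose name exists in the target schema to the target type.
    casted = tbl.cast(
        pa.schema(
            [target.field(f.name) if f.name in target.names else f for f in tbl.schema]
        )
    )
    # 2. Reorder to the target schema, dropping columns the target does not have
    #    and null-padding columns the input table does not have.
    cols = [
        casted[name] if name in casted.column_names
        else pa.nulls(len(casted), type=target.field(name).type)
        for name in target.names
    ]
    return pa.table(cols, schema=target)

target = pa.schema([("x", pa.int64())])
print(coerce(pa.table({"y": [1, 2, 3]}), target)["x"])       # all nulls, as in test_missing_columns_write
print(coerce(pa.table({"x": [1, 2], "y": [3, 4]}), target))  # "y" dropped, as in test_too_many_columns_write
```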