[FEAT] iceberg writes unpartitioned (#2016)
samster25 authored Mar 20, 2024
1 parent e5231e1 commit c2db062
Showing 17 changed files with 823 additions and 34 deletions.
16 changes: 15 additions & 1 deletion daft/daft.pyi
@@ -1,13 +1,17 @@
import builtins
from enum import Enum
from typing import Any, Callable
from typing import Any, Callable, TYPE_CHECKING

from daft.runners.partitioning import PartitionCacheEntry
from daft.execution import physical_plan
from daft.plan_scheduler.physical_plan_scheduler import PartitionT
import pyarrow
from daft.io.scan import ScanOperator

if TYPE_CHECKING:
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties

class ImageMode(Enum):
"""
Supported image modes for Daft's image type.
@@ -1239,6 +1243,16 @@ class LogicalPlanBuilder:
compression: str | None = None,
io_config: IOConfig | None = None,
) -> LogicalPlanBuilder: ...
def iceberg_write(
self,
table_name: str,
table_location: str,
spec_id: int,
iceberg_schema: IcebergSchema,
iceberg_properties: IcebergTableProperties,
catalog_columns: list[str],
io_config: IOConfig | None = None,
) -> LogicalPlanBuilder: ...
def schema(self) -> PySchema: ...
def optimize(self) -> LogicalPlanBuilder: ...
def to_physical_plan_scheduler(self, cfg: PyDaftExecutionConfig) -> PhysicalPlanScheduler: ...
94 changes: 90 additions & 4 deletions daft/dataframe/dataframe.py
@@ -4,6 +4,7 @@
# in order to support runtime typechecking across different Python versions.
# For technical details, see https://github.com/Eventual-Inc/Daft/pull/630

import os
import pathlib
import warnings
from dataclasses import dataclass
@@ -45,6 +46,7 @@
import pandas as pd
import pyarrow as pa
import dask
from pyiceberg.table import Table as IcebergTable

from daft.logical.schema import Schema

@@ -341,8 +343,6 @@ def write_parquet(
Files will be written to ``<root_dir>/*`` with randomly generated UUIDs as the file names.
Currently generates a parquet file per partition unless `partition_cols` are used, in which case the number of files can equal the number of partitions times the number of values of the partition columns.
.. NOTE::
This call is **blocking** and will execute the DataFrame when called
@@ -393,8 +393,6 @@ def write_csv(
Files will be written to ``<root_dir>/*`` with randomly generated UUIDs as the file names.
Currently generates a csv file per partition unless `partition_cols` are used, in which case the number of files can equal the number of partitions times the number of values of the partition columns.
.. NOTE::
This call is **blocking** and will execute the DataFrame when called
@@ -429,6 +427,94 @@ def write_csv(
result_df._preview = write_df._preview
return result_df

@DataframePublicAPI
def write_iceberg(self, table: "IcebergTable", mode: str = "append") -> "DataFrame":
"""Writes the DataFrame to an Iceberg Table, returning a new DataFrame with the operations that occurred.
Can be run in either `append` or `overwrite` mode, which will respectively append the rows in the DataFrame, or delete the existing rows and then append the DataFrame rows.
.. NOTE::
This call is **blocking** and will execute the DataFrame when called
Args:
table (IcebergTable): Destination Iceberg table to write the DataFrame to.
mode (str, optional): Operation mode of the write: `append` rows to the Iceberg table or `overwrite` it. Defaults to "append".
Returns:
DataFrame: The operations that occurred with this write.
"""

if len(table.spec().fields) > 0:
raise ValueError("Cannot write to partitioned Iceberg tables")

import pyiceberg
from packaging.version import parse

if parse(pyiceberg.__version__) < parse("0.6.0"):
raise ValueError(f"Write Iceberg is only supported on pyiceberg>=0.6.0, found {pyiceberg.__version__}")

from pyiceberg.table import _MergingSnapshotProducer
from pyiceberg.table.snapshots import Operation

operations = []
path = []
rows = []
size = []

if mode == "append":
operation = Operation.APPEND
elif mode == "overwrite":
operation = Operation.OVERWRITE
else:
raise ValueError(f"Only support `append` or `overwrite` mode. {mode} is unsupported")

# We perform the merge here since IcebergTable is not pickle-able
# We should be able to move to a transaction API for iceberg 0.7.0
merge = _MergingSnapshotProducer(operation=operation, table=table)

builder = self._builder.write_iceberg(table)
write_df = DataFrame(builder)
write_df.collect()

write_result = write_df.to_pydict()
assert "data_file" in write_result
data_files = write_result["data_file"]

if operation == Operation.OVERWRITE:
deleted_files = table.scan().plan_files()
else:
deleted_files = []

for data_file in data_files:
merge.append_data_file(data_file)
operations.append("ADD")
path.append(data_file.file_path)
rows.append(data_file.record_count)
size.append(data_file.file_size_in_bytes)

for pf in deleted_files:
data_file = pf.file
operations.append("DELETE")
path.append(data_file.file_path)
rows.append(data_file.record_count)
size.append(data_file.file_size_in_bytes)

merge.commit()
import pyarrow as pa

from daft import from_pydict

with_operations = from_pydict(
{
"operation": pa.array(operations, type=pa.string()),
"rows": pa.array(rows, type=pa.int64()),
"file_size": pa.array(size, type=pa.int64()),
"file_name": pa.array([os.path.basename(fp) for fp in path], type=pa.string()),
}
)
# NOTE: We are losing the history of the plan here.
# This is because the logical plan of write_iceberg returns data files, but we want to return the summary above.
return with_operations
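
A minimal usage sketch of the new API (the catalog configuration, table identifier, and sample data below are placeholders, not part of this change; assumes pyiceberg>=0.6.0 and an existing unpartitioned table):

import daft
from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table -- substitute your own configuration.
catalog = load_catalog("default")
table = catalog.load_table("default.events")  # must be unpartitioned

df = daft.from_pydict({"id": [1, 2, 3], "value": ["a", "b", "c"]})
result = df.write_iceberg(table, mode="append")
result.show()  # one row per data file added (or deleted, in overwrite mode)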

###
# DataFrame operations
###
47 changes: 45 additions & 2 deletions daft/execution/execution_step.py
@@ -4,7 +4,7 @@
import pathlib
import sys
from dataclasses import dataclass, field
from typing import Generic
from typing import TYPE_CHECKING, Generic

if sys.version_info < (3, 8):
from typing_extensions import Protocol
@@ -24,6 +24,11 @@
)
from daft.table import MicroPartition, table_io

if TYPE_CHECKING:
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties


ID_GEN = itertools.count()


@@ -354,7 +359,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
assert len(input_metadatas) == 1
return [
PartialPartitionMetadata(
num_rows=1, # We currently write one file per partition.
num_rows=None, # we can write more than 1 file per partition
size_bytes=None,
)
]
@@ -371,6 +376,44 @@ def _handle_file_write(self, input: MicroPartition) -> MicroPartition:
)


@dataclass(frozen=True)
class WriteIceberg(SingleOutputInstruction):
base_path: str
iceberg_schema: IcebergSchema
iceberg_properties: IcebergTableProperties
spec_id: int
io_config: IOConfig | None

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
return self._write_iceberg(inputs)

def _write_iceberg(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
[input] = inputs
partition = self._handle_file_write(
input=input,
)
return [partition]

def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
assert len(input_metadatas) == 1
return [
PartialPartitionMetadata(
num_rows=None, # we can write more than 1 file per partition
size_bytes=None,
)
]

def _handle_file_write(self, input: MicroPartition) -> MicroPartition:
return table_io.write_iceberg(
input,
base_path=self.base_path,
schema=self.iceberg_schema,
properties=self.iceberg_properties,
spec_id=self.spec_id,
io_config=self.io_config,
)
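
For context, the DataFrame-level commit logic above consumes the output of this step as a single `data_file` column of pyiceberg `DataFile` objects. A rough, illustrative sketch of that shape (stand-in objects are used so the snippet runs standalone; it is not the actual `table_io.write_iceberg` output):

from types import SimpleNamespace

# Stand-in for a pyiceberg DataFile; the real object exposes the same
# attributes used by DataFrame.write_iceberg above.
fake_data_file = SimpleNamespace(
    file_path="s3://bucket/warehouse/events/data/00000-0-abc.parquet",
    record_count=3,
    file_size_in_bytes=1024,
)
write_result = {"data_file": [fake_data_file]}  # shape of write_df.to_pydict()
for data_file in write_result["data_file"]:
    print(data_file.file_path, data_file.record_count, data_file.file_size_in_bytes)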


@dataclass(frozen=True)
class Filter(SingleOutputInstruction):
predicate: ExpressionsProjection
32 changes: 31 additions & 1 deletion daft/execution/physical_plan.py
@@ -19,7 +19,7 @@
import math
import pathlib
from collections import deque
from typing import Generator, Generic, Iterable, Iterator, TypeVar, Union
from typing import TYPE_CHECKING, Generator, Generic, Iterable, Iterator, TypeVar, Union

from daft.context import get_context
from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest
@@ -45,6 +45,10 @@

T = TypeVar("T")

if TYPE_CHECKING:
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties


# A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks.
InProgressPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionTaskBuilder[PartitionT]]]
@@ -105,6 +109,32 @@ def file_write(
)


def iceberg_write(
child_plan: InProgressPhysicalPlan[PartitionT],
base_path: str,
iceberg_schema: IcebergSchema,
iceberg_properties: IcebergTableProperties,
spec_id: int,
io_config: IOConfig | None,
) -> InProgressPhysicalPlan[PartitionT]:
"""Write the results of `child_plan` into pyiceberg data files described by `write_info`."""

yield from (
step.add_instruction(
execution_step.WriteIceberg(
base_path=base_path,
iceberg_schema=iceberg_schema,
iceberg_properties=iceberg_properties,
spec_id=spec_id,
io_config=io_config,
),
)
if isinstance(step, PartitionTaskBuilder)
else step
for step in child_plan
)
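
The generator expression above uses the same conditional pipelining pattern as `file_write`: only steps that are still `PartitionTaskBuilder`s receive the new instruction, while already-materialized tasks (and `None` placeholders) pass through unchanged. A generic, Daft-independent illustration of the pattern, with lists standing in for task builders:

def add_instruction(plan, instruction):
    # Yield every step; extend only the ones that are still "builders".
    for step in plan:
        if isinstance(step, list):  # stand-in for PartitionTaskBuilder
            yield step + [instruction]
        else:  # already-materialized task or a None placeholder
            yield step

print(list(add_instruction(iter([["scan"], None, ["scan", "filter"]]), "write_iceberg")))
# [['scan', 'write_iceberg'], None, ['scan', 'filter', 'write_iceberg']]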


def pipeline_instruction(
child_plan: InProgressPhysicalPlan[PartitionT],
pipeable_instruction: Instruction,
24 changes: 24 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from daft.daft import (
FileFormat,
IOConfig,
@@ -16,6 +18,10 @@
from daft.runners.partitioning import PartitionT
from daft.table import MicroPartition

if TYPE_CHECKING:
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties


def scan_with_tasks(
scan_tasks: list[ScanTask],
@@ -253,3 +259,21 @@ def write_file(
expr_projection,
io_config,
)


def write_iceberg(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
base_path: str,
iceberg_schema: IcebergSchema,
iceberg_properties: IcebergTableProperties,
spec_id: int,
io_config: IOConfig | None,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
return physical_plan.iceberg_write(
input,
base_path=base_path,
iceberg_schema=iceberg_schema,
iceberg_properties=iceberg_properties,
spec_id=spec_id,
io_config=io_config,
)
15 changes: 15 additions & 0 deletions daft/logical/builder.py
Expand Up @@ -11,6 +11,8 @@
from daft.runners.partitioning import PartitionCacheEntry

if TYPE_CHECKING:
from pyiceberg.table import Table as IcebergTable

from daft.plan_scheduler.physical_plan_scheduler import PhysicalPlanScheduler


@@ -193,3 +195,16 @@ def write_tabular(
part_cols_pyexprs = [expr._expr for expr in partition_cols] if partition_cols is not None else None
builder = self._builder.table_write(str(root_dir), file_format, part_cols_pyexprs, compression, io_config)
return LogicalPlanBuilder(builder)

def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder:
from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config

name = ".".join(table.name())
location = f"{table.location()}/data"
spec_id = table.spec().spec_id
schema = table.schema()
props = table.properties
columns = [col.name for col in schema.columns]
io_config = _convert_iceberg_file_io_properties_to_io_config(table.io.properties)
builder = self._builder.iceberg_write(name, location, spec_id, schema, props, columns, io_config)
return LogicalPlanBuilder(builder)
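
For a hypothetical unpartitioned table, the values this method extracts and forwards to the Rust-side `iceberg_write` builder look roughly as follows (catalog and table names are placeholders):

from pyiceberg.catalog import load_catalog

table = load_catalog("default").load_table("default.events")  # hypothetical unpartitioned table
print(".".join(table.name()))                        # e.g. "default.events"
print(f"{table.location()}/data")                    # e.g. "s3://warehouse/default.db/events/data"
print(table.spec().spec_id)                          # 0 for the default (unpartitioned) spec
print([col.name for col in table.schema().columns])  # passed through as catalog_columns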