[BUG] Allow nulls in partition column (#2344)
Closes #2292

Currently, Daft panics if there are nulls in a partition column; the detailed error message can be found in the linked issue.

A simple reproduction:
```python
from deltalake import write_deltalake
import pandas as pd
import daft

df = pd.DataFrame(
    {
        "group": [1, 2, 3, None],
        "num": list(range(4)),
    }
)
write_deltalake("z", df, partition_by="group", mode="overwrite")

df = daft.read_deltalake("z")
df.show()
```

This PR modifies the partition spec equality logic and partition pruning
semantics to allow reading nulls in partition columns.
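
With the fix in place, the reproduction above is expected to complete instead of panicking, and filters on the partition column treat the null partition the way a SQL comparison would. A short illustration of the expected behavior (hypothetical usage continuing the snippet above; `where` and column indexing are daft's regular DataFrame API, and the comments describe expected results rather than captured output):

```python
import daft

# Continuing the reproduction: reading the table no longer panics, and the
# row whose "group" partition value is null is preserved.
df = daft.read_deltalake("z")
df.show()  # expected: all four rows, including the null "group"

# Under the new pruning semantics, a comparison that evaluates to null on the
# null partition behaves like false, so that partition's file is skipped.
df.where(df["group"] == 2).show()  # expected: only the row with group == 2
```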
colin-ho authored Jun 6, 2024
1 parent 408f977 commit 87f6706
Showing 5 changed files with 21 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/daft-scan/src/python.rs
```diff
@@ -329,8 +329,8 @@ pub mod pylib {
         assert_eq!(boolean.len(), 1);
         let value = boolean.get(0);
         match value {
-            Some(false) => return Ok(None),
-            None | Some(true) => {}
+            None | Some(false) => return Ok(None),
+            Some(true) => {}
         }
     }
     // TODO(Clark): Filter out scan tasks with pushed down filters + table stats?
```
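
This hunk is the pruning-semantics change: previously a null result from evaluating the partition filter kept the scan task, whereas now null is grouped with false and the task is pruned. A plain-Python sketch of the new decision rule (illustrative only; the real code matches on an `Option<bool>` taken from a Daft boolean array):

```python
from typing import Optional

def keep_scan_task(predicate_result: Optional[bool]) -> bool:
    # Mirrors the new match arms: None and Some(false) both prune the task;
    # only Some(true) keeps it.
    return predicate_result is True

assert keep_scan_task(True)
assert not keep_scan_task(False)
assert not keep_scan_task(None)  # before this change, None kept the task
```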
17 changes: 13 additions & 4 deletions src/daft-stats/src/partition_spec.rs
```diff
@@ -1,6 +1,6 @@
 use std::collections::HashMap;
 
-use daft_core::array::ops::DaftCompare;
+use daft_core::array::ops::{DaftCompare, DaftLogical};
 use daft_dsl::{ExprRef, Literal};
 use daft_table::Table;
 
@@ -42,9 +42,18 @@ impl PartialEq for PartitionSpec {
         for field_name in self.keys.schema.as_ref().fields.keys() {
             let self_column = self.keys.get_column(field_name).unwrap();
             let other_column = other.keys.get_column(field_name).unwrap();
-            let value_eq = self_column.equal(other_column).unwrap().get(0).unwrap();
-            if !value_eq {
-                return false;
+            if let Some(value_eq) = self_column.equal(other_column).unwrap().get(0) {
+                if !value_eq {
+                    return false;
+                }
+            } else {
+                // For partition spec, we treat null as equal to null, in order to allow for
+                // partitioning on columns that may have nulls.
+                let self_null = self_column.is_null().unwrap();
+                let other_null = other_column.is_null().unwrap();
+                if self_null.xor(&other_null).unwrap().get(0).unwrap() {
+                    return false;
+                }
             }
         }
 
```
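
The equality change handles the case where comparing two partition values yields null (i.e. at least one side is null): two nulls are treated as equal, while a null on only one side makes the specs unequal, which is what the `is_null` / `xor` combination computes. A small Python analogue of the per-key rule (not the actual Rust types):

```python
from typing import Any, Optional

def partition_key_equal(a: Optional[Any], b: Optional[Any]) -> bool:
    if a is None or b is None:
        # The comparison result is null, so fall back to comparing nullness:
        # equal only if both sides are null (xor of nullness must be false).
        return (a is None) == (b is None)
    return a == b

assert partition_key_equal(2, 2)
assert partition_key_equal(None, None)   # null partition values now match
assert not partition_key_equal(None, 2)  # mismatched nullness => not equal
```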
5 changes: 3 additions & 2 deletions tests/io/delta_lake/conftest.py
```diff
@@ -38,6 +38,7 @@ def num_partitions(request) -> int:
         pytest.param((lambda i: datetime.datetime(2024, 2, i + 1), "f"), id="timestamp_partitioned"),
         pytest.param((lambda i: datetime.date(2024, 2, i + 1), "g"), id="date_partitioned"),
         pytest.param((lambda i: decimal.Decimal(str(1000 + i) + ".567"), "h"), id="decimal_partitioned"),
+        pytest.param((lambda i: i if i % 2 == 0 else None, "a"), id="partitioned_with_nulls"),
     ]
 )
 def partition_generator(request) -> tuple[callable, str]:
@@ -424,14 +425,14 @@ def cloud_paths(request) -> tuple[str, daft.io.IOConfig | None, DataCatalogTable
 def deltalake_table(
     cloud_paths, base_table: pa.Table, num_partitions: int, partition_generator: callable
 ) -> tuple[str, daft.io.IOConfig | None, dict[str, str], list[pa.Table]]:
-    partition_generator, _ = partition_generator
+    partition_generator, col = partition_generator
     path, io_config, catalog_table = cloud_paths
     storage_options = io_config_to_storage_options(io_config, path) if io_config is not None else None
     parts = []
     for i in range(num_partitions):
         # Generate partition value and add partition column.
         part_value = partition_generator(i)
-        part = base_table.append_column("part_idx", pa.array([part_value if part_value is not None else i] * 3))
+        part = base_table.append_column("part_idx", pa.array([part_value] * 3, type=base_table.column(col).type))
         parts.append(part)
     table = pa.concat_tables(parts)
     deltalake = pytest.importorskip("deltalake")
```
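
Two things change in this fixture: a new `partitioned_with_nulls` parameter whose generator yields a value for even indices and `None` for odd ones, and the partition column is now built with an explicit Arrow type taken from the source column `col`, so all-null chunks don't degrade to pyarrow's `null` type. A quick illustration of both points (hypothetical values mirroring the fixture's lambda; `int64` is just an example type):

```python
import pyarrow as pa

def partition_value(i: int):
    # The new parameter's generator: even indices get a value, odd get None.
    return i if i % 2 == 0 else None

print([partition_value(i) for i in range(4)])      # [0, None, 2, None]

# Without an explicit type, an all-null chunk is typed as "null"; pinning the
# type keeps every chunk compatible with the source column when concatenated.
print(pa.array([None] * 3).type)                   # null
print(pa.array([None] * 3, type=pa.int64()).type)  # int64
```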
2 changes: 1 addition & 1 deletion tests/io/delta_lake/test_table_read.py
```diff
@@ -48,7 +48,7 @@ def test_deltalake_read_full(deltalake_table):
     delta_schema = deltalake.DeltaTable(path, storage_options=io_config_to_storage_options(io_config, path)).schema()
     expected_schema = Schema.from_pyarrow_schema(delta_schema.to_pyarrow())
     assert df.schema() == expected_schema
-    assert_pyarrow_tables_equal(df.to_arrow().sort_by("part_idx"), pa.concat_tables(parts))
+    assert_pyarrow_tables_equal(df.to_arrow().sort_by("part_idx"), pa.concat_tables(parts).sort_by("part_idx"))
 
 
 def test_deltalake_read_show(deltalake_table):
```
3 changes: 2 additions & 1 deletion tests/io/delta_lake/test_table_read_pushdowns.py
```diff
@@ -33,7 +33,8 @@ def test_read_predicate_pushdown_on_data(deltalake_table):
     expected_schema = Schema.from_pyarrow_schema(delta_schema.to_pyarrow())
     assert df.schema() == expected_schema
     assert_pyarrow_tables_equal(
-        df.to_arrow().sort_by("part_idx"), pa.concat_tables([table.filter(pc.field("a") == 2) for table in tables])
+        df.to_arrow().sort_by("part_idx"),
+        pa.concat_tables([table.filter(pc.field("a") == 2) for table in tables]).sort_by("part_idx"),
     )
```
