Merge pull request #9 from bmsuisse/dev
Better decimal support for DuckDB
aersam authored Jun 18, 2024
2 parents b75f17d + b6115c1 commit 0ad4879
Showing 7 changed files with 143 additions and 106 deletions.
deltalake2db/duckdb.py (18 changes: 14 additions & 4 deletions)
@@ -13,7 +13,7 @@
 import duckdb
 
 
-def _cast(s: ex.Expression, t: Optional[ex.DataType.Type]):
+def _cast(s: ex.Expression, t: Optional[ex.DATA_TYPE]):
     if t is None:
         return s
     return ex.cast(s, t)
@@ -25,7 +25,10 @@ def _dummy_expr(
     from deltalake.schema import PrimitiveType
 
     if isinstance(field_type, PrimitiveType):
-        cast_as = type_map.get(field_type.type)
+        if str(field_type).startswith("decimal("):
+            cast_as = ex.DataType.build(str(field_type))
+        else:
+            cast_as = type_map.get(field_type.type)
         return _cast(ex.Null(), cast_as)
     elif isinstance(field_type, StructType):
         return squ.struct(
@@ -227,6 +230,7 @@ def apply_storage_options(
     "short": ex.DataType.Type.SMALLINT,
     "binary": ex.DataType.Type.BINARY,
     "timestampNtz": ex.DataType.Type.TIMESTAMP,
+    "timestamp_ntz": ex.DataType.Type.TIMESTAMP,
     "decimal": ex.DataType.Type.DECIMAL,
 }

@@ -316,8 +320,12 @@ def get_sql_for_delta_expr(
             )
 
             cast_as = None
+
             if isinstance(field.type, PrimitiveType):
-                cast_as = type_map.get(field.type.type)
+                if str(field.type).startswith("decimal("):
+                    cast_as = ex.DataType.build(str(field.type))
+                else:
+                    cast_as = type_map.get(field.type.type)
             if "partition_values" in ac and phys_name in ac["partition_values"]:
                 cols_sql.append(
                     _cast(
@@ -349,7 +357,9 @@ get_sql_for_delta_expr(
             else:
                 cols_sql.append(ex.Null().as_(field_name))
 
-        select_pq = ex.select(*cols_sql).from_(
+        select_pq = ex.select(
+            *cols_sql
+        ).from_(
             read_parquet(ex.convert(fullpath))
         )  # "SELECT " + ", ".join(cols_sql) + " FROM read_parquet('" + fullpath + "')"
         file_selects.append(select_pq)
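
A minimal sketch of the decimal handling introduced above, assuming sqlglot is installed; the "decimal(10,2)" type string is an illustrative example of what str(field_type) yields for a Delta decimal column, not data from this repository:

# Sketch only: compare the new parameterized cast with the old generic DECIMAL fallback.
from sqlglot import expressions as ex

type_str = "decimal(10,2)"  # illustrative str(field_type) for a Delta decimal column
if type_str.startswith("decimal("):
    cast_as = ex.DataType.build(type_str)  # keeps precision and scale
else:
    cast_as = ex.DataType.Type.DECIMAL  # previous fallback: precision and scale lost

# Same pattern as the _cast helper above: wrap NULL in a cast to the resolved type.
expr = ex.cast(ex.Null(), cast_as)
print(expr.sql(dialect="duckdb"))  # expected to render roughly as CAST(NULL AS DECIMAL(10, 2))
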
deltalake2db/polars.py (5 changes: 3 additions & 2 deletions)
@@ -121,7 +121,8 @@ def _get_type(dtype: "DataType") -> "pl.PolarsDataType":
     elif dtype_str == "binary":
         return pl.Binary
     elif dtype_str.startswith("decimal"):
-        return pl.Decimal
+        precision, scale = dtype_str.split("(")[1].split(")")[0].split(",")
+        return pl.Decimal(int(precision), int(scale))
     elif dtype_str == "short":
         return pl.Int16
     elif dtype == "byte":
@@ -172,7 +173,7 @@ def scan_delta_union(
         base_ds = pl.scan_parquet(
             fullpath, storage_options=delta_table._storage_options
         )
-        parquet_schema = base_ds.limit(0).schema
+        parquet_schema = base_ds.limit(0).collect_schema()
         selects = []
         for field in all_fields:
             pl_dtype = _get_type(field.type)
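
The Polars side applies the same idea by parsing precision and scale out of the type string; the switch to collect_schema() follows the Polars 1.0 API for resolving a LazyFrame schema. A small sketch of the parsing, with "decimal(10,2)" again an assumed example string and Polars >= 1.0 installed:

# Sketch only: parse precision/scale from a Delta decimal type string, as in the hunk above.
import polars as pl

dtype_str = "decimal(10,2)"  # assumed example of str(field.type) for a decimal column
precision, scale = dtype_str.split("(")[1].split(")")[0].split(",")
dtype = pl.Decimal(int(precision), int(scale))
print(dtype)  # expected: Decimal(precision=10, scale=2)
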
poetry.lock (191 changes: 100 additions & 91 deletions)

Large diffs are not rendered by default.

pyproject.toml (6 changes: 3 additions & 3 deletions)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "deltalake2db"
-version = "0.3.6"
+version = "0.3.9"
 description = ""
 authors = ["Adrian Ehrsam <[email protected]>"]
 license = "MIT"
@@ -14,14 +14,14 @@ azure-identity = { version = "^1.16.0", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 pyright = "^1.1.352"
-polars = "^0.20.16"
+polars = { version = ">=1.0.0-beta.1", allow-prereleases = true }
 duckdb = "^1.0.0"
 ruff = "^0.4.3"
 
 [tool.poetry.group.test.dependencies]
 pytest-cov = "^4.1.0"
 pytest = "^8.1.0"
-polars = "^0.20.13"
+polars = { version = ">=1.0.0-beta.1", allow-prereleases = true }
 duckdb = "^1.0.0"
 docker = "^7.0.0"
 azure-storage-blob = "^12.19.1"
tests/test_duckdb.py (11 changes: 11 additions & 0 deletions)
@@ -62,9 +62,20 @@ def test_filter_number():
         duckdb_create_view_for_delta(con, dt, "delta_table", conditions={"Age": 23.0})
         con.execute("select FirstName from delta_table")
         col_names = [c[0] for c in con.description]
         assert col_names == ["FirstName"]
         names = con.fetchall()
         assert len(names) == 1
         assert names[0][0] == "Peter"
+    with duckdb.connect() as con:
+        duckdb_create_view_for_delta(con, dt, "delta_table_1", conditions={"Age": 23.0})
+        con.execute("select * from delta_table_1")
+        col_types_1 = [c[1] for c in con.description]
+        duckdb_create_view_for_delta(
+            con, dt, "delta_table_2", conditions={"Age": 500.0}
+        )
+        con.execute("select * from delta_table_2")
+        col_types_2 = [c[1] for c in con.description]
+        assert col_types_1 == col_types_2
 
 
 def test_filter_name():
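
For orientation, a condensed usage sketch of the API this test exercises; it assumes duckdb_create_view_for_delta is exported at the package top level (as the tests imply) and that the local tests/data/user table is available:

# Usage sketch based on the test above; the import path and table path are assumptions.
import duckdb
from deltalake import DeltaTable
from deltalake2db import duckdb_create_view_for_delta

dt = DeltaTable("tests/data/user")
with duckdb.connect() as con:
    # Register a DuckDB view over the Delta table, filtered to Age == 23.0.
    duckdb_create_view_for_delta(con, dt, "delta_table", conditions={"Age": 23.0})
    rows = con.execute("select FirstName from delta_table").fetchall()
    print(rows)  # the test expects a single row, ('Peter',)
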
tests/test_duckdb_az.py (3 changes: 2 additions & 1 deletion)
@@ -83,6 +83,7 @@ def test_col_mapping(storage_options):
 def test_empty_struct(storage_options):
     # >>> duckdb.execute("""Select { 'lat': 1 } as tester union all select Null""").fetchall()
     import pyarrow as pa
+    import pyarrow.compute as pc
 
     dt = DeltaTable("az://testlakedb/td/delta/fake", storage_options=storage_options)

@@ -96,7 +97,7 @@ def test_empty_struct(storage_options):
         df = con.execute("select * from delta_table").fetch_arrow_table()
         print(df)
         mc = (
-            df.filter(pa.compute.field("new_name") == "Hans Heiri")
+            df.filter(pc.field("new_name") == "Hans Heiri")
             .select(["main_coord"])
             .to_pylist()
         )
tests/test_polars.py (15 changes: 10 additions & 5 deletions)
@@ -98,6 +98,10 @@ def test_filter_number():
     assert len(res) == 1
     assert res[0]["FirstName"] == "Peter"
 
+    df2 = polars_scan_delta(dt, conditions={"Age": 500})
+
+    assert df.schema == df2.schema, "Schema does not match"
+
 
 def test_filter_name():
     dt = DeltaTable("tests/data/user")
@@ -111,14 +115,15 @@ def test_filter_name():
 
 
 def test_schema():
-    dt = DeltaTable("tests/data/user")
-
     from deltalake2db import polars_scan_delta, get_polars_schema
 
-    df = polars_scan_delta(dt)
-    schema = get_polars_schema(dt)
+    for tbl in ["user", "faker2", "user_empty"]:
+        dt = DeltaTable("tests/data/" + tbl)
+
+        df = polars_scan_delta(dt)
+        schema = get_polars_schema(dt)
 
-    assert df.schema == schema
+        assert df.schema == schema, f"Schema for {tbl} does not match"
 
 
 @pytest.mark.skip(reason="Polars reads null structs as structs, so no luck")
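
The added lines in test_filter_number make the same point on the Polars side: a filter that matches nothing must still yield the same schema as one that matches rows, which only holds if decimal precision and scale are preserved. A sketch of that check, assuming the tests/data/user table exists and that polars_scan_delta accepts a conditions mapping as shown in the tests:

# Sketch only: schema stability under filtering, mirroring the updated tests.
from deltalake import DeltaTable
from deltalake2db import polars_scan_delta

dt = DeltaTable("tests/data/user")
df = polars_scan_delta(dt, conditions={"Age": 23.0})  # matches at least one row
df2 = polars_scan_delta(dt, conditions={"Age": 500})  # matches no rows
assert df.schema == df2.schema, "filtering must not change column types"
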
