[CHORE] Add unit tests for int96 timestamps (#1229)
Closes: #1215

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
jaychia and Jay Chia authored Aug 3, 2023
1 parent c6552dd commit bc65aaf
Showing 2 changed files with 96 additions and 55 deletions.
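
For context, int96 is a deprecated Parquet physical type that legacy writers (notably older Spark and Impala) used for nanosecond-precision timestamps. The new test below only round-trips a nanosecond column when int96 is enabled, because it otherwise coerces timestamps to microseconds on write. A minimal sketch of the pyarrow behavior being exercised (the file paths are illustrative, not part of this commit):

import pyarrow as pa
import pyarrow.parquet as papq

# Multiples of 1000 ns, so coercion to microseconds is lossless.
table = pa.table({"ts": pa.array([1000, 2000, 3000], pa.timestamp("ns"))})

# Legacy int96 encoding preserves nanosecond resolution.
papq.write_table(table, "/tmp/int96.parquet", use_deprecated_int96_timestamps=True)

# Standard int64 encoding, coerced to microseconds as in the new test.
papq.write_table(table, "/tmp/int64.parquet", coerce_timestamps="us")

print(papq.read_table("/tmp/int96.parquet").schema)  # ts: timestamp[ns]
print(papq.read_table("/tmp/int64.parquet").schema)  # ts: timestamp[us]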
143 changes: 92 additions & 51 deletions tests/table/table_io/test_parquet.py
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
+import contextlib
 import datetime
-import io
 import pathlib
+import tempfile
 
 import pyarrow as pa
 import pyarrow.parquet as papq
 import pytest
 
 import daft
-from daft.datatype import DataType
+from daft.datatype import DataType, TimeUnit
 from daft.logical.schema import Schema
 from daft.runners.partitioning import TableReadOptions
 from daft.table import Table, schema_inference, table_io
@@ -31,11 +32,11 @@ def test_read_input(tmpdir):
     assert table_io.read_parquet(f, schema=schema).to_arrow() == data
 
 
-def _parquet_write_helper(data: pa.Table, row_group_size: int = None):
-    f = io.BytesIO()
-    papq.write_table(data, f, row_group_size=row_group_size)
-    f.seek(0)
-    return f
+@contextlib.contextmanager
+def _parquet_write_helper(data: pa.Table, row_group_size: int = None, papq_write_table_kwargs: dict = {}):
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        papq.write_table(data, tmpfile.name, row_group_size=row_group_size, **papq_write_table_kwargs)
+        yield tmpfile.name
 
 
 @pytest.mark.parametrize(
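
A note on the helper rewrite in the hunk above: yielding a real file path from a NamedTemporaryFile, rather than returning an in-memory BytesIO, lets the same tests cover readers that take a path instead of a file object, and the context manager guarantees the file is cleaned up. A stripped-down sketch of the pattern, standard library only (the helper name here is illustrative); note that re-opening a NamedTemporaryFile by name while it is still open does not work on Windows, so this is POSIX-oriented:

import contextlib
import tempfile

@contextlib.contextmanager
def _tmp_path(data: bytes):
    # The file (and its path) disappear when the context exits.
    with tempfile.NamedTemporaryFile() as f:
        f.write(data)
        f.flush()
        yield f.name

with _tmp_path(b"hello") as path:
    with open(path, "rb") as f:
        print(f.read())  # b'hello'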
@@ -54,18 +55,18 @@ def _parquet_write_helper(data: pa.Table, row_group_size: int = None):
         ([1, None, 2], DataType.list("item", DataType.int64())),
     ],
 )
-def test_parquet_infer_schema(data, expected_dtype):
-    f = _parquet_write_helper(
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_parquet_infer_schema(data, expected_dtype, use_native_downloader):
+    with _parquet_write_helper(
         pa.Table.from_pydict(
             {
                 "id": [1, 2, 3],
                 "data": [data, data, None],
             }
         )
-    )
-
-    schema = schema_inference.from_parquet(f)
-    assert schema == Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_dtype)])
+    ) as f:
+        schema = schema_inference.from_parquet(f, use_native_downloader=use_native_downloader)
+        assert schema == Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_dtype)])
 
 
 @pytest.mark.parametrize(
@@ -85,65 +86,105 @@ def test_parquet_infer_schema(data, expected_dtype):
({"foo": 1}, daft.Series.from_pylist([{"foo": 1}, {"foo": 1}, None])),
],
)
def test_parquet_read_data(data, expected_data_series):
f = _parquet_write_helper(
@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_parquet_read_data(data, expected_data_series, use_native_downloader):
with _parquet_write_helper(
pa.Table.from_pydict(
{
"id": [1, 2, 3],
"data": [data, data, None],
}
)
)

schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_data_series.datatype())])
expected = Table.from_pydict(
{
"id": [1, 2, 3],
"data": expected_data_series,
}
)
table = table_io.read_parquet(f, schema)
assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
) as f:
schema = Schema._from_field_name_and_types(
[("id", DataType.int64()), ("data", expected_data_series.datatype())]
)
expected = Table.from_pydict(
{
"id": [1, 2, 3],
"data": expected_data_series,
}
)
table = table_io.read_parquet(f, schema, use_native_downloader=use_native_downloader)
assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"


@pytest.mark.parametrize("row_group_size", [None, 1, 3])
def test_parquet_read_data_limit_rows(row_group_size):
f = _parquet_write_helper(
@pytest.mark.parametrize("use_native_downloader", [True, False])
def test_parquet_read_data_limit_rows(row_group_size, use_native_downloader):
with _parquet_write_helper(
pa.Table.from_pydict(
{
"id": [1, 2, 3],
"data": [1, 2, None],
}
),
row_group_size=row_group_size,
)

schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())])
expected = Table.from_pydict(
{
"id": [1, 2],
"data": [1, 2],
}
)
table = table_io.read_parquet(f, schema, read_options=TableReadOptions(num_rows=2))
assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
) as f:
schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())])
expected = Table.from_pydict(
{
"id": [1, 2],
"data": [1, 2],
}
)
table = table_io.read_parquet(
f, schema, read_options=TableReadOptions(num_rows=2), use_native_downloader=use_native_downloader
)
assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"


-def test_parquet_read_data_select_columns():
-    f = _parquet_write_helper(
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_parquet_read_data_select_columns(use_native_downloader):
+    with _parquet_write_helper(
         pa.Table.from_pydict(
             {
                 "id": [1, 2, 3],
                 "data": [1, 2, None],
             }
         )
-    )
-
-    schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())])
-    expected = Table.from_pydict(
-        {
-            "data": [1, 2, None],
-        }
-    )
-    table = table_io.read_parquet(f, schema, read_options=TableReadOptions(column_names=["data"]))
-    assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
+    ) as f:
+        schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())])
+        expected = Table.from_pydict(
+            {
+                "data": [1, 2, None],
+            }
+        )
+        table = table_io.read_parquet(
+            f, schema, read_options=TableReadOptions(column_names=["data"]), use_native_downloader=use_native_downloader
+        )
+        assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
+
+
@pytest.mark.parametrize("use_native_downloader", [True, False])
@pytest.mark.parametrize("use_deprecated_int96_timestamps", [True, False])
def test_parquet_read_timestamps(use_native_downloader, use_deprecated_int96_timestamps):
data = {
"timestamp_ms": pa.array([1, 2, 3], pa.timestamp("ms")),
"timestamp_us": pa.array([1, 2, 3], pa.timestamp("us")),
}
schema = [
("timestamp_ms", DataType.timestamp(TimeUnit.ms())),
("timestamp_us", DataType.timestamp(TimeUnit.us())),
]
# int64 timestamps cannot support nanosecond resolutions
if use_deprecated_int96_timestamps:
data["timestamp_ns"] = pa.array([1, 2, 3], pa.timestamp("ns"))
schema.append(("timestamp_ns", DataType.timestamp(TimeUnit.ns())))

with _parquet_write_helper(
pa.Table.from_pydict(data),
papq_write_table_kwargs={
"use_deprecated_int96_timestamps": use_deprecated_int96_timestamps,
"coerce_timestamps": "us" if not use_deprecated_int96_timestamps else None,
},
) as f:
schema = Schema._from_field_name_and_types(schema)
expected = Table.from_pydict(data)
table = table_io.read_parquet(
f,
schema,
read_options=TableReadOptions(column_names=schema.column_names()),
use_native_downloader=use_native_downloader,
)
assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
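
To confirm which physical type actually landed on disk, pyarrow exposes it through the file metadata. A quick sketch (again illustrative, not part of the commit):

import pyarrow as pa
import pyarrow.parquet as papq

t = pa.table({"ts": pa.array([0], pa.timestamp("ns"))})
papq.write_table(t, "/tmp/check.parquet", use_deprecated_int96_timestamps=True)

# ColumnChunkMetaData.physical_type reports the on-disk encoding.
meta = papq.ParquetFile("/tmp/check.parquet").metadata
print(meta.row_group(0).column(0).physical_type)  # INT96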
8 changes: 4 additions & 4 deletions tests/table/table_io/test_read_time_cast.py
@@ -41,7 +41,7 @@
     ],
 )
 def test_parquet_cast_at_read_time(data, schema, expected):
-    f = _parquet_write_helper(data)
-    table = table_io.read_parquet(f, schema)
-    assert table.schema() == schema
-    assert table.to_arrow() == expected.to_arrow()
+    with _parquet_write_helper(data) as f:
+        table = table_io.read_parquet(f, schema)
+        assert table.schema() == schema
+        assert table.to_arrow() == expected.to_arrow()
