Commit d73f6bb

add pushdown tests
Colin Ho authored and committed Oct 11, 2024
1 parent 115d342 commit d73f6bb
Showing 1 changed file with 54 additions and 0 deletions.
tests/dataframe/test_creation.py: 54 additions, 0 deletions
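
These new tests exercise filter pushdown against the file_path column that daft attaches when a read is given file_path_column. As a minimal sketch of the pattern under test, with hypothetical file names (the daft calls are the ones used in the diff below):

    import daft

    # file_path_column adds each row's source path as an extra string column.
    df = daft.read_csv(["a.csv", "b.csv"], file_path_column="file_path")

    # A filter on that column is a pushdown candidate: the intent is that rows
    # from b.csv are pruned at the scan rather than after materialization.
    df = df.where(daft.col("file_path") == "a.csv")
    print(df.to_pandas())

The same pattern is repeated below for CSV, JSON, and Parquet reads.
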
@@ -441,6 +441,25 @@ def test_create_dataframe_multiple_csvs_with_file_path_column(valid_data: list[dict[str, float]]) -> None:
        assert pd_df["file_path"].to_list() == [f1name] * len(valid_data) + [f2name] * len(valid_data)


def test_create_dataframe_csv_with_file_path_column_and_pushdowns(valid_data: list[dict[str, float]]) -> None:
    with create_temp_filename() as f1name, create_temp_filename() as f2name:
        with open(f1name, "w") as f1, open(f2name, "w") as f2:
            for f in (f1, f2):
                header = list(valid_data[0].keys())
                writer = csv.writer(f)
                writer.writerow(header)
                writer.writerows([[item[col] for col in header] for item in valid_data])
                f.flush()

        # Filter on the generated file_path column; this predicate is a pushdown
        # candidate, so only rows originating from f1name should survive.
        df = daft.read_csv([f1name, f2name], file_path_column="file_path").where(daft.col("file_path") == f1name)
        assert df.column_names == COL_NAMES + ["file_path"]

        pd_df = df.to_pandas()
        assert list(pd_df.columns) == COL_NAMES + ["file_path"]
        assert len(pd_df) == len(valid_data)
        assert pd_df["file_path"].to_list() == [f1name] * len(valid_data)


def test_create_dataframe_csv_with_file_path_column_duplicate_field_names() -> None:
    with create_temp_filename() as fname:
        with open(fname, "w") as f:
@@ -764,6 +783,24 @@ def test_create_dataframe_multiple_jsons_with_file_path_column(valid_data: list[dict[str, float]]) -> None:
        assert pd_df["file_path"].to_list() == [f1name] * len(valid_data) + [f2name] * len(valid_data)


def test_create_dataframe_json_with_file_path_column_and_pushdowns(valid_data: list[dict[str, float]]) -> None:
    with create_temp_filename() as f1name, create_temp_filename() as f2name:
        with open(f1name, "w") as f1, open(f2name, "w") as f2:
            for f in (f1, f2):
                for data in valid_data:
                    f.write(json.dumps(data))
                    f.write("\n")
                f.flush()

        # Same pushdown pattern as the CSV test: keep only rows read from f1name.
        df = daft.read_json([f1name, f2name], file_path_column="file_path").where(daft.col("file_path") == f1name)
        assert df.column_names == COL_NAMES + ["file_path"]

        pd_df = df.to_pandas()
        assert list(pd_df.columns) == COL_NAMES + ["file_path"]
        assert len(pd_df) == len(valid_data)
        assert pd_df["file_path"].to_list() == [f1name] * len(valid_data)


def test_create_dataframe_json_with_file_path_column_duplicate_field_names() -> None:
    with create_temp_filename() as fname:
        with open(fname, "w") as f:
@@ -1033,6 +1070,23 @@ def test_create_dataframe_multiple_parquets_with_file_path_column(valid_data: list[dict[str, float]]) -> None:
        assert pd_df["file_path"].to_list() == [f1name] * len(valid_data) + [f2name] * len(valid_data)


def test_create_dataframe_parquet_with_file_path_column_and_pushdowns(valid_data: list[dict[str, float]]) -> None:
    with create_temp_filename() as f1name, create_temp_filename() as f2name:
        with open(f1name, "w") as f1, open(f2name, "w") as f2:
            for f in (f1, f2):
                table = pa.Table.from_pydict({col: [d[col] for d in valid_data] for col in COL_NAMES})
                # pyarrow writes the parquet bytes to the path directly; the open
                # text handle is only used to materialize the temp file name.
                papq.write_table(table, f.name)
                f.flush()

        # Same pushdown pattern: the file_path filter should prune f2name entirely.
        df = daft.read_parquet([f1name, f2name], file_path_column="file_path").where(daft.col("file_path") == f1name)
        assert df.column_names == COL_NAMES + ["file_path"]

        pd_df = df.to_pandas()
        assert list(pd_df.columns) == COL_NAMES + ["file_path"]
        assert len(pd_df) == len(valid_data)
        assert pd_df["file_path"].to_list() == [f1name] * len(valid_data)


def test_create_dataframe_parquet_with_file_path_column_duplicate_field_names() -> None:
    with create_temp_filename() as fname:
        with open(fname, "w") as f:
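The assertions above check row counts and file_path values, so they verify the filter's result rather than the plan that produced it. For inspecting the plan locally, daft exposes DataFrame.explain; its exact signature is an assumption here (it is not part of this diff):

    import daft

    df = daft.read_csv(["a.csv", "b.csv"], file_path_column="file_path").where(daft.col("file_path") == "a.csv")
    # Assumption: explain(show_all=True) also prints the optimized plan, where a
    # pushed-down filter should appear at or near the scan node.
    df.explain(show_all=True)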
