From 5aa19d1c7ecc6567c200aff6c3460e17a3e7e379 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 15 Feb 2024 12:10:32 +0800 Subject: [PATCH 1/7] populate preview method --- daft/dataframe/dataframe.py | 51 ++++++++++++++++++------------- tests/cookbook/test_write.py | 12 ++++++++ tests/dataframe/test_decimals.py | 2 +- tests/dataframe/test_show.py | 45 --------------------------- tests/dataframe/test_temporals.py | 2 +- 5 files changed, 43 insertions(+), 69 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index a370d998ed..87a4fa5164 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -85,6 +85,7 @@ def __init__(self, builder: LogicalPlanBuilder) -> None: self.__builder = builder self._result_cache: Optional[PartitionCacheEntry] = None self._preview = DataFramePreview(preview_partition=None, dataframe_num_rows=None) + self._num_preview_rows = get_context().daft_execution_config.num_preview_rows @property def _builder(self) -> LogicalPlanBuilder: @@ -225,13 +226,37 @@ def iter_partitions(self) -> Iterator[Union[MicroPartition, "RayObjectRef"]]: for result in results_iter: yield result.partition() + def _populate_preview(self) -> None: + """Populates the preview of the DataFrame, if it is not already populated.""" + if self._result is None: + return + + preview_partition_invalid = ( + self._preview.preview_partition is None or len(self._preview.preview_partition) < self._num_preview_rows + ) + if preview_partition_invalid: + preview_df = self + if self._num_preview_rows < len(self): + preview_df = preview_df.limit(self._num_preview_rows) + preview_df._materialize_results() + preview_results = preview_df._result + assert preview_results is not None + + preview_partition = preview_results._get_merged_vpartition() + self._preview = DataFramePreview( + preview_partition=preview_partition, + dataframe_num_rows=len(self), + ) + @DataframePublicAPI def __repr__(self) -> str: + self._populate_preview() display = DataFrameDisplay(self._preview, self.schema()) return display.__repr__() @DataframePublicAPI def _repr_html_(self) -> str: + self._populate_preview() display = DataFrameDisplay(self._preview, self.schema()) return display._repr_html_() @@ -1113,7 +1138,7 @@ def _materialize_results(self) -> None: result.wait() @DataframePublicAPI - def collect(self, num_preview_rows: Optional[int] = 8) -> "DataFrame": + def collect(self, num_preview_rows: Optional[int] = None) -> "DataFrame": """Executes the entire DataFrame and materializes the results .. 
NOTE:: @@ -1128,31 +1153,13 @@ def collect(self, num_preview_rows: Optional[int] = 8) -> "DataFrame": self._materialize_results() assert self._result is not None - dataframe_len = len(self._result) - requested_rows = dataframe_len if num_preview_rows is None else num_preview_rows - - # Build a DataFramePreview and cache it if necessary - preview_partition_invalid = ( - self._preview.preview_partition is None or len(self._preview.preview_partition) < requested_rows - ) - if preview_partition_invalid: - preview_df = self - if num_preview_rows is not None: - preview_df = preview_df.limit(num_preview_rows) - preview_df._materialize_results() - preview_results = preview_df._result - assert preview_results is not None - - preview_partition = preview_results._get_merged_vpartition() - self._preview = DataFramePreview( - preview_partition=preview_partition, - dataframe_num_rows=dataframe_len, - ) - + if num_preview_rows is not None: + self._num_preview_rows = num_preview_rows return self def _construct_show_display(self, n: int) -> "DataFrameDisplay": """Helper for .show() which will construct the underlying DataFrameDisplay object""" + self._populate_preview() preview_partition = self._preview.preview_partition total_rows = self._preview.dataframe_num_rows diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index 2f611c4059..2df9833964 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -21,6 +21,8 @@ def test_parquet_write(tmp_path): assert_df_equals(df.to_pandas(), read_back_pd_df) assert len(pd_df) == 1 + assert pd_df._preview.preview_partition is None + pd_df._populate_preview() assert len(pd_df._preview.preview_partition) == 1 @@ -33,6 +35,8 @@ def test_parquet_write_with_partitioning(tmp_path): assert_df_equals(df.to_pandas(), read_back_pd_df) assert len(pd_df) == 5 + assert pd_df._preview.preview_partition is None + pd_df._populate_preview() assert len(pd_df._preview.preview_partition) == 5 @@ -41,6 +45,8 @@ def test_empty_parquet_write_without_partitioning(tmp_path): df = df.where(daft.lit(False)) output_files = df.write_parquet(tmp_path) assert len(output_files) == 0 + assert output_files._preview.preview_partition is None + output_files._populate_preview() assert len(output_files._preview.preview_partition) == 0 @@ -49,6 +55,8 @@ def test_empty_parquet_write_with_partitioning(tmp_path): df = df.where(daft.lit(False)) output_files = df.write_parquet(tmp_path, partition_cols=["Borough"]) assert len(output_files) == 0 + assert output_files._preview.preview_partition is None + output_files._populate_preview() assert len(output_files._preview.preview_partition) == 0 @@ -69,6 +77,8 @@ def test_parquet_write_with_partitioning_readback_values(tmp_path): assert_df_equals(df.to_pandas(), read_back_pd_df) assert len(output_files) == 5 + assert output_files._preview.preview_partition is None + output_files._populate_preview() assert len(output_files._preview.preview_partition) == 5 @@ -193,6 +203,8 @@ def test_csv_write(tmp_path): assert_df_equals(df.to_pandas(), read_back_pd_df) assert len(pd_df) == 1 + assert pd_df._preview.preview_partition is None + pd_df._populate_preview() assert len(pd_df._preview.preview_partition) == 1 diff --git a/tests/dataframe/test_decimals.py b/tests/dataframe/test_decimals.py index 530146a098..2005790c9d 100644 --- a/tests/dataframe/test_decimals.py +++ b/tests/dataframe/test_decimals.py @@ -24,7 +24,7 @@ def test_decimal_parquet_roundtrip() -> None: df.write_parquet(dirname) df_readback = 
daft.read_parquet(dirname).collect() - assert str(df.to_pydict()["decimal128"]) == str(df_readback.to_pydict()["decimal128"]) + assert str(df.to_pydict()["decimal128"]) == str(df_readback.to_pydict()["decimal128"]) def test_arrow_decimal() -> None: diff --git a/tests/dataframe/test_show.py b/tests/dataframe/test_show.py index df32865551..f0933da5b2 100644 --- a/tests/dataframe/test_show.py +++ b/tests/dataframe/test_show.py @@ -24,48 +24,3 @@ def test_show_some(make_df, valid_data, data_source): elif variant == "arrow": assert df_display.preview.dataframe_num_rows == len(valid_data) assert df_display.num_rows == 1 - - -def test_show_from_cached_collect(make_df, valid_data): - df = make_df(valid_data) - df = df.collect() - collected_preview = df._preview - df_display = df._construct_show_display(8) - - # Check that cached preview from df.collect() was used. - assert df_display.preview is collected_preview - assert df_display.schema == df.schema() - assert len(df_display.preview.preview_partition) == len(valid_data) - assert df_display.preview.dataframe_num_rows == 3 - assert df_display.num_rows == 3 - - -def test_show_from_cached_collect_prefix(make_df, valid_data): - df = make_df(valid_data) - df = df.collect(3) - df_display = df._construct_show_display(2) - - assert df_display.schema == df.schema() - assert len(df_display.preview.preview_partition) == 2 - # Check that a prefix of the cached preview from df.collect() was used, so dataframe_num_rows should be set. - assert df_display.preview.dataframe_num_rows == 3 - assert df_display.num_rows == 2 - - -def test_show_not_from_cached_collect(make_df, valid_data, data_source): - df = make_df(valid_data) - df = df.collect(2) - collected_preview = df._preview - df_display = df._construct_show_display(8) - - variant = data_source - if variant == "parquet": - # Cached preview from df.collect() is NOT USED because data was not materialized from parquet. - assert df_display.preview != collected_preview - elif variant == "arrow": - # Cached preview from df.collect() is USED because data was materialized from arrow. 
- assert df_display.preview == collected_preview - assert df_display.schema == df.schema() - assert len(df_display.preview.preview_partition) == len(valid_data) - assert df_display.preview.dataframe_num_rows == 3 - assert df_display.num_rows == 3 diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index 1c0cbcda9c..3973d3c7ce 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -90,7 +90,7 @@ def test_temporal_file_roundtrip(format, use_native_downloader) -> None: df.write_parquet(dirname) df_readback = daft.read_parquet(dirname, use_native_downloader=use_native_downloader).collect() - assert df.to_pydict() == df_readback.to_pydict() + assert df.to_pydict() == df_readback.to_pydict() @pytest.mark.parametrize( From 11f8fc295080f862ff40412a594f0fc6b87f7ea6 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 15 Feb 2024 12:33:01 +0800 Subject: [PATCH 2/7] add some tests --- daft/dataframe/dataframe.py | 4 +++- tests/dataframe/test_repr.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 87a4fa5164..0f6ecef9d2 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1138,7 +1138,7 @@ def _materialize_results(self) -> None: result.wait() @DataframePublicAPI - def collect(self, num_preview_rows: Optional[int] = None) -> "DataFrame": + def collect(self, num_preview_rows: Optional[int] = 8) -> "DataFrame": """Executes the entire DataFrame and materializes the results .. NOTE:: @@ -1155,6 +1155,8 @@ def collect(self, num_preview_rows: Optional[int] = None) -> "DataFrame": assert self._result is not None if num_preview_rows is not None: self._num_preview_rows = num_preview_rows + else: + self._num_preview_rows = len(self._result) return self def _construct_show_display(self, n: int) -> "DataFrameDisplay": diff --git a/tests/dataframe/test_repr.py b/tests/dataframe/test_repr.py index f84e13c0e7..1c02e81dd9 100644 --- a/tests/dataframe/test_repr.py +++ b/tests/dataframe/test_repr.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +import pytest from PIL import Image import daft @@ -86,6 +87,16 @@ def test_empty_repr(make_df): assert df._repr_html_() == "(No data to display: Dataframe has no columns)" +@pytest.mark.parametrize("num_preview_rows", [None]) +def test_repr_with_non_default_preview_rows(make_df, num_preview_rows): + df = make_df({"A": [i for i in range(10)], "B": [i for i in range(10)]}) + df.collect(num_preview_rows=num_preview_rows) + df.__repr__() + + assert df._preview.dataframe_num_rows == 10 + assert len(df._preview.preview_partition) == (num_preview_rows if num_preview_rows is not None else 10) + + def test_empty_df_repr(make_df): df = make_df({"A": [1, 2, 3], "B": ["a", "b", "c"]}) df = df.where(df["A"] > 10) From 10c5e568b78f62f0358c38b339aebdb637c674c2 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 15 Feb 2024 12:44:19 +0800 Subject: [PATCH 3/7] fix tests --- tests/integration/io/test_url_download_http.py | 2 ++ tests/udf_library/test_url_udfs.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/integration/io/test_url_download_http.py b/tests/integration/io/test_url_download_http.py index 531b89907a..3b0dc3a499 100644 --- a/tests/integration/io/test_url_download_http.py +++ b/tests/integration/io/test_url_download_http.py @@ -30,6 +30,7 @@ def test_url_download_http_error_codes(nginx_config, use_native_downloader, stat if status_code == 404: with pytest.raises(FileNotFoundError): 
df.collect() + df.__repr__() # When using fsspec, other error codes are bubbled up to the user as aiohttp.client_exceptions.ClientResponseError elif not use_native_downloader: # Ray runner has a pretty catastrophic failure when raising non-pickleable exceptions (ClientResponseError is not pickleable) @@ -45,3 +46,4 @@ def test_url_download_http_error_codes(nginx_config, use_native_downloader, stat # user-facing I/O error with the error code with pytest.raises(ValueError, match=f"{status_code}") as e: df.collect() + df.__repr__() diff --git a/tests/udf_library/test_url_udfs.py b/tests/udf_library/test_url_udfs.py index a8daee4596..aee414c674 100644 --- a/tests/udf_library/test_url_udfs.py +++ b/tests/udf_library/test_url_udfs.py @@ -87,6 +87,7 @@ def test_download_with_missing_urls_reraise_errors(files, use_native_downloader) # TODO: Change to a FileNotFound Error with pytest.raises(FileNotFoundError): df.collect() + df.__repr__() @pytest.mark.parametrize("use_native_downloader", [False, True]) From 5dfa3d29fdef3755e4f0c3b2387f8617dbfd64ec Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 15 Feb 2024 12:53:54 +0800 Subject: [PATCH 4/7] dont change tests --- daft/dataframe/dataframe.py | 5 ++++- tests/integration/io/test_url_download_http.py | 2 -- tests/udf_library/test_url_udfs.py | 1 - 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 0f6ecef9d2..236daca198 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -242,7 +242,9 @@ def _populate_preview(self) -> None: preview_results = preview_df._result assert preview_results is not None + print("here") preview_partition = preview_results._get_merged_vpartition() + print("not here") self._preview = DataFramePreview( preview_partition=preview_partition, dataframe_num_rows=len(self), @@ -1153,10 +1155,11 @@ def collect(self, num_preview_rows: Optional[int] = 8) -> "DataFrame": self._materialize_results() assert self._result is not None + dataframe_len = len(self._result) if num_preview_rows is not None: self._num_preview_rows = num_preview_rows else: - self._num_preview_rows = len(self._result) + self._num_preview_rows = dataframe_len return self def _construct_show_display(self, n: int) -> "DataFrameDisplay": diff --git a/tests/integration/io/test_url_download_http.py b/tests/integration/io/test_url_download_http.py index 3b0dc3a499..531b89907a 100644 --- a/tests/integration/io/test_url_download_http.py +++ b/tests/integration/io/test_url_download_http.py @@ -30,7 +30,6 @@ def test_url_download_http_error_codes(nginx_config, use_native_downloader, stat if status_code == 404: with pytest.raises(FileNotFoundError): df.collect() - df.__repr__() # When using fsspec, other error codes are bubbled up to the user as aiohttp.client_exceptions.ClientResponseError elif not use_native_downloader: # Ray runner has a pretty catastrophic failure when raising non-pickleable exceptions (ClientResponseError is not pickleable) @@ -46,4 +45,3 @@ def test_url_download_http_error_codes(nginx_config, use_native_downloader, stat # user-facing I/O error with the error code with pytest.raises(ValueError, match=f"{status_code}") as e: df.collect() - df.__repr__() diff --git a/tests/udf_library/test_url_udfs.py b/tests/udf_library/test_url_udfs.py index aee414c674..a8daee4596 100644 --- a/tests/udf_library/test_url_udfs.py +++ b/tests/udf_library/test_url_udfs.py @@ -87,7 +87,6 @@ def test_download_with_missing_urls_reraise_errors(files, use_native_downloader) # TODO: 
Change to a FileNotFound Error with pytest.raises(FileNotFoundError): df.collect() - df.__repr__() @pytest.mark.parametrize("use_native_downloader", [False, True]) From e8ab2524f1e7c236d784f69862b5876c26454a01 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 15 Feb 2024 13:35:13 +0800 Subject: [PATCH 5/7] cleanup --- daft/dataframe/dataframe.py | 2 -- tests/cookbook/test_write.py | 12 ++++++------ tests/dataframe/test_repr.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 236daca198..fd462bd9d5 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -242,9 +242,7 @@ def _populate_preview(self) -> None: preview_results = preview_df._result assert preview_results is not None - print("here") preview_partition = preview_results._get_merged_vpartition() - print("not here") self._preview = DataFramePreview( preview_partition=preview_partition, dataframe_num_rows=len(self), diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index 2df9833964..bb1dba9668 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -22,7 +22,7 @@ def test_parquet_write(tmp_path): assert len(pd_df) == 1 assert pd_df._preview.preview_partition is None - pd_df._populate_preview() + pd_df.__repr__() assert len(pd_df._preview.preview_partition) == 1 @@ -36,7 +36,7 @@ def test_parquet_write_with_partitioning(tmp_path): assert len(pd_df) == 5 assert pd_df._preview.preview_partition is None - pd_df._populate_preview() + pd_df.__repr__() assert len(pd_df._preview.preview_partition) == 5 @@ -46,7 +46,7 @@ def test_empty_parquet_write_without_partitioning(tmp_path): output_files = df.write_parquet(tmp_path) assert len(output_files) == 0 assert output_files._preview.preview_partition is None - output_files._populate_preview() + output_files.__repr__() assert len(output_files._preview.preview_partition) == 0 @@ -56,7 +56,7 @@ def test_empty_parquet_write_with_partitioning(tmp_path): output_files = df.write_parquet(tmp_path, partition_cols=["Borough"]) assert len(output_files) == 0 assert output_files._preview.preview_partition is None - output_files._populate_preview() + output_files.__repr__() assert len(output_files._preview.preview_partition) == 0 @@ -78,7 +78,7 @@ def test_parquet_write_with_partitioning_readback_values(tmp_path): assert len(output_files) == 5 assert output_files._preview.preview_partition is None - output_files._populate_preview() + output_files.__repr__() assert len(output_files._preview.preview_partition) == 5 @@ -204,7 +204,7 @@ def test_csv_write(tmp_path): assert len(pd_df) == 1 assert pd_df._preview.preview_partition is None - pd_df._populate_preview() + pd_df.__repr__() assert len(pd_df._preview.preview_partition) == 1 diff --git a/tests/dataframe/test_repr.py b/tests/dataframe/test_repr.py index 1c02e81dd9..636552978a 100644 --- a/tests/dataframe/test_repr.py +++ b/tests/dataframe/test_repr.py @@ -87,7 +87,7 @@ def test_empty_repr(make_df): assert df._repr_html_() == "(No data to display: Dataframe has no columns)" -@pytest.mark.parametrize("num_preview_rows", [None]) +@pytest.mark.parametrize("num_preview_rows", [9, 10, None]) def test_repr_with_non_default_preview_rows(make_df, num_preview_rows): df = make_df({"A": [i for i in range(10)], "B": [i for i in range(10)]}) df.collect(num_preview_rows=num_preview_rows) From d2093f6fdd0ec6b8d581d445a2b5a9d7d7977b19 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 16 Feb 2024 07:35:32 +0800 Subject: 
[PATCH 6/7] efficiency --- daft/dataframe/dataframe.py | 43 +++++++++++-------------------------- 1 file changed, 13 insertions(+), 30 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index fd462bd9d5..9827481ac0 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -235,12 +235,18 @@ def _populate_preview(self) -> None: self._preview.preview_partition is None or len(self._preview.preview_partition) < self._num_preview_rows ) if preview_partition_invalid: - preview_df = self - if self._num_preview_rows < len(self): - preview_df = preview_df.limit(self._num_preview_rows) - preview_df._materialize_results() - preview_results = preview_df._result - assert preview_results is not None + need = self._num_preview_rows + preview_parts = [] + for part in self._result.values(): + part_len = len(part) + if part_len >= need: # if this part has enough rows, take what we need and break + preview_parts.append(part.slice(0, need)) + break + else: # otherwise, take the whole part and keep going + need -= part_len + preview_parts.append(part) + + preview_results = LocalPartitionSet({i: part for i, part in enumerate(preview_parts)}) preview_partition = preview_results._get_merged_vpartition() self._preview = DataFramePreview( @@ -330,30 +336,7 @@ def _from_tables(cls, *parts: MicroPartition) -> "DataFrame": df._result_cache = cache_entry # build preview - num_preview_rows = context.daft_execution_config.num_preview_rows - dataframe_num_rows = len(df) - if dataframe_num_rows > num_preview_rows: - need = num_preview_rows - preview_parts = [] - for part in parts: - part_len = len(part) - if part_len >= need: # if this part has enough rows, take what we need and break - preview_parts.append(part.slice(0, need)) - break - else: # otherwise, take the whole part and keep going - need -= part_len - preview_parts.append(part) - - preview_results = LocalPartitionSet({i: part for i, part in enumerate(preview_parts)}) - else: - preview_results = result_pset - - # set preview - preview_partition = preview_results._get_merged_vpartition() - df._preview = DataFramePreview( - preview_partition=preview_partition, - dataframe_num_rows=dataframe_num_rows, - ) + df._populate_preview() return df ### From 38fa51b3013e84fdb0b89f9e02ebe78bbbf835a8 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 20 Feb 2024 13:58:55 -0800 Subject: [PATCH 7/7] preview implementation for py and ray runner --- daft/dataframe/dataframe.py | 13 +--------- daft/runners/partitioning.py | 3 +++ daft/runners/pyrunner.py | 13 ++++++++++ daft/runners/ray_runner.py | 14 +++++++++++ tests/dataframe/test_show.py | 48 ++++++++++++++++++++++++++++++++++++ 5 files changed, 79 insertions(+), 12 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 9827481ac0..32f8bb726e 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -235,17 +235,7 @@ def _populate_preview(self) -> None: self._preview.preview_partition is None or len(self._preview.preview_partition) < self._num_preview_rows ) if preview_partition_invalid: - need = self._num_preview_rows - preview_parts = [] - for part in self._result.values(): - part_len = len(part) - if part_len >= need: # if this part has enough rows, take what we need and break - preview_parts.append(part.slice(0, need)) - break - else: # otherwise, take the whole part and keep going - need -= part_len - preview_parts.append(part) - + preview_parts = self._result._get_preview_vpartition(self._num_preview_rows) 
preview_results = LocalPartitionSet({i: part for i, part in enumerate(preview_parts)}) preview_partition = preview_results._get_merged_vpartition() @@ -1145,7 +1135,6 @@ def collect(self, num_preview_rows: Optional[int] = 8) -> "DataFrame": def _construct_show_display(self, n: int) -> "DataFrameDisplay": """Helper for .show() which will construct the underlying DataFrameDisplay object""" - self._populate_preview() preview_partition = self._preview.preview_partition total_rows = self._preview.dataframe_num_rows diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 8836a3bc5c..56fda08a3a 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -209,6 +209,9 @@ class PartitionSet(Generic[PartitionT]): def _get_merged_vpartition(self) -> MicroPartition: raise NotImplementedError() + def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]: + raise NotImplementedError() + def to_pydict(self) -> dict[str, list[Any]]: """Retrieves all the data in a PartitionSet as a Python dictionary. Values are the raw data from each Block.""" merged_partition = self._get_merged_vpartition() diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 2be28e4c54..41b435ff24 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -51,6 +51,19 @@ def _get_merged_vpartition(self) -> MicroPartition: assert ids_and_partitions[-1][0] + 1 == len(ids_and_partitions) return MicroPartition.concat([part for id, part in ids_and_partitions]) + def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]: + ids_and_partitions = self.items() + preview_parts = [] + for _, part in ids_and_partitions: + part_len = len(part) + if part_len >= num_rows: # if this part has enough rows, take what we need and break + preview_parts.append(part.slice(0, num_rows)) + break + else: # otherwise, take the whole part and keep going + num_rows -= part_len + preview_parts.append(part) + return preview_parts + def get_partition(self, idx: PartID) -> MicroPartition: return self._partitions[idx] diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index f8803995c0..d206aea5b5 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -151,6 +151,20 @@ def _get_merged_vpartition(self) -> MicroPartition: all_partitions = ray.get([part for id, part in ids_and_partitions]) return MicroPartition.concat(all_partitions) + def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]: + ids_and_partitions = self.items() + preview_parts = [] + for _, part in ids_and_partitions: + part = ray.get(part) + part_len = len(part) + if part_len >= num_rows: # if this part has enough rows, take what we need and break + preview_parts.append(part.slice(0, num_rows)) + break + else: # otherwise, take the whole part and keep going + num_rows -= part_len + preview_parts.append(part) + return preview_parts + def to_ray_dataset(self) -> RayDataset: if not _RAY_FROM_ARROW_REFS_AVAILABLE: raise ImportError( diff --git a/tests/dataframe/test_show.py b/tests/dataframe/test_show.py index f0933da5b2..dd5ced328f 100644 --- a/tests/dataframe/test_show.py +++ b/tests/dataframe/test_show.py @@ -24,3 +24,51 @@ def test_show_some(make_df, valid_data, data_source): elif variant == "arrow": assert df_display.preview.dataframe_num_rows == len(valid_data) assert df_display.num_rows == 1 + + +def test_show_from_cached_repr(make_df, valid_data): + df = make_df(valid_data) + df = df.collect() + df.__repr__() + collected_preview = df._preview + 
df_display = df._construct_show_display(8) + + # Check that cached preview from df.__repr__() was used. + assert df_display.preview is collected_preview + assert df_display.schema == df.schema() + assert len(df_display.preview.preview_partition) == len(valid_data) + assert df_display.preview.dataframe_num_rows == 3 + assert df_display.num_rows == 3 + + +def test_show_from_cached_repr_prefix(make_df, valid_data): + df = make_df(valid_data) + df = df.collect(3) + df.__repr__() + df_display = df._construct_show_display(2) + + assert df_display.schema == df.schema() + assert len(df_display.preview.preview_partition) == 2 + # Check that a prefix of the cached preview from df.__repr__() was used, so dataframe_num_rows should be set. + assert df_display.preview.dataframe_num_rows == 3 + assert df_display.num_rows == 2 + + +def test_show_not_from_cached_repr(make_df, valid_data, data_source): + df = make_df(valid_data) + df = df.collect(2) + df.__repr__() + collected_preview = df._preview + df_display = df._construct_show_display(8) + + variant = data_source + if variant == "parquet": + # Cached preview from df.__repr__() is NOT USED because data was not materialized from parquet. + assert df_display.preview != collected_preview + elif variant == "arrow": + # Cached preview from df.__repr__() is USED because data was materialized from arrow. + assert df_display.preview == collected_preview + assert df_display.schema == df.schema() + assert len(df_display.preview.preview_partition) == len(valid_data) + assert df_display.preview.dataframe_num_rows == 3 + assert df_display.num_rows == 3
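
Note on the preview-slicing strategy introduced in patches 6 and 7: rather than re-running a limit(num_preview_rows) plan against the already-materialized result (as the patch-1 version of _populate_preview did), the final design asks the runner's PartitionSet for just enough leading partitions to cover the preview budget — taking whole partitions until the remainder fits inside one, then slicing that partition and stopping. On the Ray runner this also means only the partitions actually needed for the preview are fetched with ray.get. The snippet below is a minimal, self-contained sketch of that walk, using plain Python lists as stand-ins for Daft's MicroPartition objects; get_preview_partitions is a hypothetical helper name, not part of the Daft API.

    # Minimal sketch (not Daft's actual classes): plain lists stand in for
    # materialized partitions, and get_preview_partitions is a hypothetical
    # helper mirroring the _get_preview_vpartition loop added in patch 7.
    def get_preview_partitions(partitions, num_rows):
        """Take just enough leading partitions to cover num_rows preview rows."""
        preview_parts = []
        need = num_rows
        for part in partitions:
            if len(part) >= need:  # this partition covers the remaining budget
                preview_parts.append(part[:need])
                break
            need -= len(part)  # take the whole partition and keep going
            preview_parts.append(part)
        return preview_parts

    if __name__ == "__main__":
        parts = [[0, 1, 2], [3, 4], [5, 6, 7, 8]]
        # budget of 4: first partition taken whole, one row sliced from the second
        assert get_preview_partitions(parts, 4) == [[0, 1, 2], [3]]
        # budget >= total rows: every partition is taken whole
        assert get_preview_partitions(parts, 9) == parts
        print("preview slicing ok")

This is also why the preview cache is keyed on size in _populate_preview: a cached preview_partition shorter than _num_preview_rows is treated as invalid and rebuilt, while a sufficiently large one is reused as-is, which is the behavior the test_show tests added in patch 7 exercise.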