diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py
index 9827481ac0..32f8bb726e 100644
--- a/daft/dataframe/dataframe.py
+++ b/daft/dataframe/dataframe.py
@@ -235,17 +235,7 @@ def _populate_preview(self) -> None:
             self._preview.preview_partition is None or len(self._preview.preview_partition) < self._num_preview_rows
         )
         if preview_partition_invalid:
-            need = self._num_preview_rows
-            preview_parts = []
-            for part in self._result.values():
-                part_len = len(part)
-                if part_len >= need:  # if this part has enough rows, take what we need and break
-                    preview_parts.append(part.slice(0, need))
-                    break
-                else:  # otherwise, take the whole part and keep going
-                    need -= part_len
-                    preview_parts.append(part)
-
+            preview_parts = self._result._get_preview_vpartition(self._num_preview_rows)
             preview_results = LocalPartitionSet({i: part for i, part in enumerate(preview_parts)})
 
             preview_partition = preview_results._get_merged_vpartition()
@@ -1145,7 +1135,6 @@ def collect(self, num_preview_rows: Optional[int] = 8) -> "DataFrame":
 
     def _construct_show_display(self, n: int) -> "DataFrameDisplay":
         """Helper for .show() which will construct the underlying DataFrameDisplay object"""
-        self._populate_preview()
         preview_partition = self._preview.preview_partition
         total_rows = self._preview.dataframe_num_rows
 
diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py
index 8836a3bc5c..56fda08a3a 100644
--- a/daft/runners/partitioning.py
+++ b/daft/runners/partitioning.py
@@ -209,6 +209,9 @@ class PartitionSet(Generic[PartitionT]):
     def _get_merged_vpartition(self) -> MicroPartition:
         raise NotImplementedError()
 
+    def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]:
+        raise NotImplementedError()
+
     def to_pydict(self) -> dict[str, list[Any]]:
         """Retrieves all the data in a PartitionSet as a Python dictionary. Values are the raw data from each Block."""
         merged_partition = self._get_merged_vpartition()
diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py
index 2be28e4c54..41b435ff24 100644
--- a/daft/runners/pyrunner.py
+++ b/daft/runners/pyrunner.py
@@ -51,6 +51,19 @@ def _get_merged_vpartition(self) -> MicroPartition:
         assert ids_and_partitions[-1][0] + 1 == len(ids_and_partitions)
         return MicroPartition.concat([part for id, part in ids_and_partitions])
 
+    def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]:
+        ids_and_partitions = self.items()
+        preview_parts = []
+        for _, part in ids_and_partitions:
+            part_len = len(part)
+            if part_len >= num_rows:  # if this part has enough rows, take what we need and break
+                preview_parts.append(part.slice(0, num_rows))
+                break
+            else:  # otherwise, take the whole part and keep going
+                num_rows -= part_len
+                preview_parts.append(part)
+        return preview_parts
+
     def get_partition(self, idx: PartID) -> MicroPartition:
         return self._partitions[idx]
 
diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index f8803995c0..d206aea5b5 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -151,6 +151,20 @@ def _get_merged_vpartition(self) -> MicroPartition:
         all_partitions = ray.get([part for id, part in ids_and_partitions])
         return MicroPartition.concat(all_partitions)
 
+    def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]:
+        ids_and_partitions = self.items()
+        preview_parts = []
+        for _, part in ids_and_partitions:
+            part = ray.get(part)
+            part_len = len(part)
+            if part_len >= num_rows:  # if this part has enough rows, take what we need and break
+                preview_parts.append(part.slice(0, num_rows))
+                break
+            else:  # otherwise, take the whole part and keep going
+                num_rows -= part_len
+                preview_parts.append(part)
+        return preview_parts
+
     def to_ray_dataset(self) -> RayDataset:
         if not _RAY_FROM_ARROW_REFS_AVAILABLE:
             raise ImportError(
diff --git a/tests/dataframe/test_show.py b/tests/dataframe/test_show.py
index f0933da5b2..dd5ced328f 100644
--- a/tests/dataframe/test_show.py
+++ b/tests/dataframe/test_show.py
@@ -24,3 +24,51 @@ def test_show_some(make_df, valid_data, data_source):
     elif variant == "arrow":
         assert df_display.preview.dataframe_num_rows == len(valid_data)
     assert df_display.num_rows == 1
+
+
+def test_show_from_cached_repr(make_df, valid_data):
+    df = make_df(valid_data)
+    df = df.collect()
+    df.__repr__()
+    collected_preview = df._preview
+    df_display = df._construct_show_display(8)
+
+    # Check that the cached preview from df.__repr__() was used.
+    assert df_display.preview is collected_preview
+    assert df_display.schema == df.schema()
+    assert len(df_display.preview.preview_partition) == len(valid_data)
+    assert df_display.preview.dataframe_num_rows == 3
+    assert df_display.num_rows == 3
+
+
+def test_show_from_cached_repr_prefix(make_df, valid_data):
+    df = make_df(valid_data)
+    df = df.collect(3)
+    df.__repr__()
+    df_display = df._construct_show_display(2)
+
+    assert df_display.schema == df.schema()
+    assert len(df_display.preview.preview_partition) == 2
+    # Check that a prefix of the cached preview from df.__repr__() was used, so dataframe_num_rows should be set.
+    assert df_display.preview.dataframe_num_rows == 3
+    assert df_display.num_rows == 2
+
+
+def test_show_not_from_cached_repr(make_df, valid_data, data_source):
+    df = make_df(valid_data)
+    df = df.collect(2)
+    df.__repr__()
+    collected_preview = df._preview
+    df_display = df._construct_show_display(8)
+
+    variant = data_source
+    if variant == "parquet":
+        # The cached preview from df.__repr__() is NOT used because the data was not fully materialized from parquet.
+        assert df_display.preview != collected_preview
+    elif variant == "arrow":
+        # The cached preview from df.__repr__() is used because the data was already materialized from arrow.
+        assert df_display.preview == collected_preview
+    assert df_display.schema == df.schema()
+    assert len(df_display.preview.preview_partition) == len(valid_data)
+    assert df_display.preview.dataframe_num_rows == 3
+    assert df_display.num_rows == 3
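Note on the extracted logic: both runner implementations of _get_preview_vpartition
walk the partitions in order, taking whole partitions until fewer than num_rows rows
remain, then slicing a prefix of the final partition. Below is a minimal standalone
sketch of that accumulation, using plain Python lists in place of MicroPartition;
get_preview_parts is a hypothetical name for illustration, not part of Daft's API:

    def get_preview_parts(partitions: list[list], num_rows: int) -> list[list]:
        # Hypothetical stand-in: lists play the role of MicroPartition here,
        # so part[:num_rows] corresponds to part.slice(0, num_rows).
        preview_parts = []
        for part in partitions:
            if len(part) >= num_rows:
                # This partition covers the remaining quota: take a prefix and stop.
                preview_parts.append(part[:num_rows])
                break
            # Otherwise take the whole partition and keep scanning.
            num_rows -= len(part)
            preview_parts.append(part)
        return preview_parts

    # Rows are taken in partition order: 2 from the first, 1 from the second.
    assert get_preview_parts([[1, 2], [3, 4, 5]], 3) == [[1, 2], [3]]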