Skip to content

Commit

Permalink
[FEAT] Use cached preview from df.collect() in df.show(). (#1651)
Browse files Browse the repository at this point in the history
This PR ensures that a preview cached via a `df.collect()` is used in
`df.show()`, rather than recomputing output tables from scratch.
  • Loading branch information
clarkzinzow authored Nov 21, 2023
1 parent aaf279e commit c25f644
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 35 deletions.
87 changes: 54 additions & 33 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,39 +154,6 @@ def columns(self) -> List[Expression]:
"""
return [col(field.name) for field in self.__builder.schema()]

@DataframePublicAPI
def show(self, n: int = 8) -> "DataFrameDisplay":
"""Executes enough of the DataFrame in order to display the first ``n`` rows
.. NOTE::
This call is **blocking** and will execute the DataFrame when called
Args:
n: number of rows to show. Defaults to 8.
Returns:
DataFrameDisplay: object that has a rich tabular display
"""
builder = self._builder.limit(n, eager=True)

# Iteratively retrieve partitions until enough data has been materialized
tables = []
seen = 0
for table in get_context().runner().run_iter_tables(builder, results_buffer_size=1):
tables.append(table)
seen += len(table)
if seen >= n:
break

preview_partition = Table.concat(tables)
preview_partition = preview_partition if len(preview_partition) <= n else preview_partition.slice(0, n)
preview = DataFramePreview(
preview_partition=preview_partition,
# We do not know the size of the entire (un-limited) dataframe when showing
dataframe_num_rows=None,
)
return DataFrameDisplay(preview, self.schema(), num_rows=n)

@DataframePublicAPI
def __iter__(self) -> Iterator[Dict[str, Any]]:
"""Return an iterator of rows for this dataframe.
Expand Down Expand Up @@ -1044,6 +1011,60 @@ def collect(self, num_preview_rows: Optional[int] = 8) -> "DataFrame":

return self

@DataframePublicAPI
def show(self, n: int = 8) -> "DataFrameDisplay":
"""Executes enough of the DataFrame in order to display the first ``n`` rows
.. NOTE::
This call is **blocking** and will execute the DataFrame when called
Args:
n: number of rows to show. Defaults to 8.
Returns:
DataFrameDisplay: object that has a rich tabular display
"""
preview_partition = self._preview.preview_partition
total_rows = self._preview.dataframe_num_rows

# Truncate n to the length of the DataFrame, if we have it.
if total_rows is not None and n > total_rows:
n = total_rows

if preview_partition is None or len(preview_partition) < n:
# Preview partition doesn't exist or doesn't contain enough rows, so we need to compute a
# new one from scratch.
builder = self._builder.limit(n, eager=True)

# Iteratively retrieve partitions until enough data has been materialized
tables = []
seen = 0
for table in get_context().runner().run_iter_tables(builder, results_buffer_size=1):
tables.append(table)
seen += len(table)
if seen >= n:
break

preview_partition = Table.concat(tables)
if len(preview_partition) > n:
preview_partition = preview_partition.slice(0, n)
elif len(preview_partition) < n:
# Iterator short-circuited before reaching n, so we know that we have the full DataFrame.
total_rows = n = len(preview_partition)
preview = DataFramePreview(
preview_partition=preview_partition,
dataframe_num_rows=total_rows,
)
elif len(preview_partition) > n:
# Preview partition is cached but has more rows that we need, so use the appropriate slice.
truncated_preview_partition = preview_partition.slice(0, n)
preview = DataFramePreview(preview_partition=truncated_preview_partition, dataframe_num_rows=total_rows)
else:
assert len(preview_partition) == n
# Preview partition is cached and has exactly the number of rows that we need, so use it directly.
preview = self._preview
return DataFrameDisplay(preview, self.schema(), num_rows=n)

def __len__(self):
"""Returns the count of rows when dataframe is materialized.
If dataframe is not materialized yet, raises a runtime error.
Expand Down
45 changes: 43 additions & 2 deletions tests/dataframe/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ def test_show_default(valid_data):

assert df_display.schema == df.schema()
assert len(df_display.preview.preview_partition) == len(valid_data)
assert df_display.preview.dataframe_num_rows is None
assert df_display.num_rows == 8
assert df_display.preview.dataframe_num_rows == 3
assert df_display.num_rows == 3


def test_show_some(valid_data):
Expand All @@ -19,5 +19,46 @@ def test_show_some(valid_data):

assert df_display.schema == df.schema()
assert len(df_display.preview.preview_partition) == 1
# Limit is less than DataFrame length, so we don't know the full DataFrame length.
assert df_display.preview.dataframe_num_rows is None
assert df_display.num_rows == 1


def test_show_from_cached_collect(valid_data):
df = daft.from_pylist(valid_data)
df = df.collect()
collected_preview = df._preview
df_display = df.show()

# Check that cached preview from df.collect() was used.
assert df_display.preview is collected_preview
assert df_display.schema == df.schema()
assert len(df_display.preview.preview_partition) == len(valid_data)
assert df_display.preview.dataframe_num_rows == 3
assert df_display.num_rows == 3


def test_show_from_cached_collect_prefix(valid_data):
df = daft.from_pylist(valid_data)
df = df.collect(3)
df_display = df.show(2)

assert df_display.schema == df.schema()
assert len(df_display.preview.preview_partition) == 2
# Check that a prefix of the cached preview from df.collect() was used, so dataframe_num_rows should be set.
assert df_display.preview.dataframe_num_rows == 3
assert df_display.num_rows == 2


def test_show_not_from_cached_collect(valid_data):
df = daft.from_pylist(valid_data)
df = df.collect(2)
collected_preview = df._preview
df_display = df.show()

# Check that cached preview from df.collect() was NOT used, since it didn't have enough rows.
assert df_display.preview != collected_preview
assert df_display.schema == df.schema()
assert len(df_display.preview.preview_partition) == len(valid_data)
assert df_display.preview.dataframe_num_rows == 3
assert df_display.num_rows == 3

0 comments on commit c25f644

Please sign in to comment.