diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py
index 37dea4d822..56caaf8c9f 100644
--- a/daft/dataframe/dataframe.py
+++ b/daft/dataframe/dataframe.py
@@ -286,16 +286,16 @@ def iter_rows(self, results_buffer_size: Optional[int] = NUM_CPUS) -> Iterator[D
             yield row

     @DataframePublicAPI
-    def to_arrow_iter(self, results_buffer_size: Optional[int] = 1) -> Iterator["pyarrow.Table"]:
+    def to_arrow_iter(self, results_buffer_size: Optional[int] = 1) -> Iterator["pyarrow.RecordBatch"]:
         """
-        Return an iterator of pyarrow tables for this dataframe.
+        Return an iterator of pyarrow record batches for this dataframe.
         """
         if results_buffer_size is not None and not results_buffer_size > 0:
             raise ValueError(f"Provided `results_buffer_size` value must be > 0, received: {results_buffer_size}")

         if self._result is not None:
             # If the dataframe has already finished executing,
             # use the precomputed results.
-            yield self.to_arrow()
+            yield from self.to_arrow().to_batches()
         else:
             # Execute the dataframe in a streaming fashion.
@@ -304,7 +304,7 @@ def to_arrow_iter(self, results_buffer_size: Optional[int] = 1) -> Iterator["pya
             # Iterate through partitions.
             for partition in partitions_iter:
-                yield partition.to_arrow()
+                yield from partition.to_arrow().to_batches()

     @DataframePublicAPI
     def iter_partitions(
diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py
index b8939d630d..52efe533f3 100644
--- a/tests/table/test_from_py.py
+++ b/tests/table/test_from_py.py
@@ -669,4 +669,4 @@ def __iter__(self):
 def test_to_arrow_iterator() -> None:
     df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
     it = df.to_arrow_iter()
-    assert isinstance(next(it), pa.Table)
+    assert isinstance(next(it), pa.RecordBatch)
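
A minimal sketch of consuming the updated iterator, assuming `daft` and `pyarrow` are installed; it mirrors the behavior exercised by the new test above and is not part of the diff itself:

    import daft
    import pyarrow as pa

    df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})

    # to_arrow_iter() now yields pyarrow.RecordBatch objects rather than
    # whole pyarrow.Table objects per partition.
    batches = []
    for batch in df.to_arrow_iter():
        batches.append(batch)
        print(batch.num_rows, batch.schema)

    # Batches from one dataframe share a schema, so callers that still want
    # a single table can reassemble one explicitly.
    table = pa.Table.from_batches(batches)

Yielding record batches keeps the streaming path streaming: callers can process each batch as it arrives instead of waiting for a full per-partition table, and `yield from ... .to_batches()` preserves that granularity on both the precomputed and streaming code paths.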