-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT] Adds a `read_generator` method that reads tables from a generator (#3258). `read_generator` takes in a generator function that yields `Table`s, with an optional parameter of `num_partitions` which will be the number of scan tasks that call this function. The function will be provided the partition number as the first argument, and whatever user args after that. Useful for testing shuffles. --------- Co-authored-by: Colin Ho <[email protected]> Co-authored-by: Colin Ho <[email protected]>
- Loading branch information
1 parent
7e89850
commit d1213a4
Showing
1 changed file
with
118 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# isort: dont-add-import: from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Callable, Iterator, List | ||
|
||
from daft.daft import Pushdowns, PyTable, ScanOperatorHandle, ScanTask | ||
from daft.dataframe import DataFrame | ||
from daft.io.scan import PartitionField, ScanOperator | ||
from daft.logical.builder import LogicalPlanBuilder | ||
from daft.logical.schema import Schema | ||
|
||
if TYPE_CHECKING: | ||
from daft.table.table import Table | ||
|
||
|
||
def _generator_factory_function(func: Callable[[], Iterator["Table"]]) -> Iterator["PyTable"]: | ||
for table in func(): | ||
yield table._table | ||
|
||
|
||
def read_generator(
    generators: Iterator[Callable[[], Iterator["Table"]]],
    schema: Schema,
) -> DataFrame:
    """Create a DataFrame whose partitions are produced by user-supplied generator functions.

    Each element of ``generators`` becomes one scan task (one partition): the callable
    is invoked with no arguments on the executor and must yield the ``Table``s for
    that partition. Useful for synthesizing data, e.g. for testing shuffles.

    Example:
        >>> import daft
        >>> from daft.io._generator import read_generator
        >>> from daft.table.table import Table
        >>> from functools import partial
        >>>
        >>> # Set runner to Ray for distributed processing
        >>> daft.context.set_runner_ray()
        >>>
        >>> # Helper function to generate data for each partition
        >>> def generate(num_rows: int):
        ...     data = {"ints": [i for i in range(num_rows)]}
        ...     yield Table.from_pydict(data)
        >>>
        >>> # Generator function that yields partial functions for each partition
        >>> def generator(num_partitions: int):
        ...     for i in range(num_partitions):
        ...         yield partial(generate, 100)
        >>>
        >>> # Create DataFrame using read_generator and repartition the data
        >>> df = (
        ...     read_generator(
        ...         generator(num_partitions=100),
        ...         daft.Schema._from_field_name_and_types([("ints", daft.DataType.uint64())]),
        ...     )
        ...     .repartition(100, "ints")
        ...     .collect()
        ... )

    Args:
        generators (Iterator[Callable[[], Iterator[Table]]]): an iterator of
            zero-argument callables, each yielding the ``Table``s for one partition
        schema (Schema): the schema of the generated data

    Returns:
        DataFrame: a DataFrame containing the generated data
    """

    generator_scan_operator = GeneratorScanOperator(
        generators=generators,
        schema=schema,
    )
    handle = ScanOperatorHandle.from_python_scan_operator(generator_scan_operator)
    builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
    return DataFrame(builder)
|
||
|
||
class GeneratorScanOperator(ScanOperator):
    """Scan operator that turns each user-supplied generator callable into one scan task."""

    def __init__(
        self,
        generators: Iterator[Callable[[], Iterator["Table"]]],
        schema: Schema,
    ):
        # Each callable yields the Tables for exactly one partition.
        self._generators = generators
        self._schema = schema

    def display_name(self) -> str:
        return "GeneratorScanOperator"

    def schema(self) -> Schema:
        return self._schema

    def partitioning_keys(self) -> List[PartitionField]:
        # Generated data carries no partitioning information.
        return []

    def can_absorb_filter(self) -> bool:
        return False

    def can_absorb_limit(self) -> bool:
        return False

    def can_absorb_select(self) -> bool:
        return False

    def multiline_display(self) -> List[str]:
        return [self.display_name(), f"Schema = {self.schema()}"]

    def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
        # Emit one scan task per generator. The task records the module/name of
        # the factory function so the executor can re-import it and call it with
        # the user's callable as its argument.
        py_schema = self.schema()._schema
        for gen in self._generators:
            yield ScanTask.python_factory_func_scan_task(
                module=_generator_factory_function.__module__,
                func_name=_generator_factory_function.__name__,
                func_args=(gen,),
                schema=py_schema,
                num_rows=None,
                size_bytes=None,
                pushdowns=pushdowns,
                stats=None,
            )