diff --git a/daft/io/_range.py b/daft/io/_range.py index 76e0e8a6b9..8a28bed8ab 100644 --- a/daft/io/_range.py +++ b/daft/io/_range.py @@ -11,19 +11,30 @@ from daft.table.table import Table -def _range_generators(start: int, end: int, step: int) -> Iterator[Callable[[], Iterator[Table]]]: - def generator_for_value(value: int) -> Callable[[], Iterator[Table]]: - def generator() -> Iterator[Table]: - yield Table.from_pydict({"id": [value]}) +def _range_generators(start: int, end: int, step: int, partitions: int) -> Iterator[Callable[[], Iterator[Table]]]: + # TODO: Partitioning with range scan is currently untested and unused. + # Known issue: partition bounds are an even split of [start, end) that ignores `step`, so partitions after the first can start off the step grid (e.g. start=0, end=10, step=3, partitions=2 yields ids [0, 3] and [5, 8] instead of 0, 3, 6, 9). - return generator + # Calculate partition bounds upfront + partition_size = (end - start) // partitions + partition_bounds = [ + (start + (i * partition_size), start + ((i + 1) * partition_size) if i < partitions - 1 else end) + for i in range(partitions) + ] - for value in range(start, end, step): - yield generator_for_value(value) + def generator(partition_idx: int) -> Iterator[Table]: + partition_start, partition_end = partition_bounds[partition_idx] + values = list(range(partition_start, partition_end, step)) + yield Table.from_pydict({"id": values}) + + from functools import partial + + for partition_idx in range(partitions): + yield partial(generator, partition_idx) class RangeScanOperator(GeneratorScanOperator): - def __init__(self, start: int, end: int, step: int = 1) -> None: + def __init__(self, start: int, end: int, step: int = 1, partitions: int = 1) -> None: schema = Schema._from_field_name_and_types([("id", DataType.int64())]) - super().__init__(schema=schema, generators=_range_generators(start, end, step)) + super().__init__(schema=schema, generators=_range_generators(start, end, step, partitions)) diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index d25c7cecc6..58255e2ef9 100644 --- 
a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -33,9 +33,9 @@ fn range(range: Range) -> eyre::Result { num_partitions, } = range; - if let Some(partitions) = num_partitions { - warn!("{partitions} ignored"); - } + let partitions = num_partitions.unwrap_or(1); + + ensure!(partitions > 0, "num_partitions must be greater than 0"); let start = start.unwrap_or(0); @@ -51,7 +51,7 @@ fn range(range: Range) -> eyre::Result { .wrap_err("Failed to get range function")?; let range = range - .call1((start, end, step)) + .call1((start, end, step, partitions)) .wrap_err("Failed to create range scan operator")? .to_object(py); diff --git a/tests/connect/test_range_simple.py b/tests/connect/test_range_simple.py index 35b69069a2..86f348470e 100644 --- a/tests/connect/test_range_simple.py +++ b/tests/connect/test_range_simple.py @@ -1,6 +1,5 @@ from __future__ import annotations -# import time import pytest from pyspark.sql import SparkSession diff --git a/xyz b/xyz new file mode 100644 index 0000000000..e69de29bb2