diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py
index 5bd5e483df..a70ab0a1da 100644
--- a/daft/execution/execution_step.py
+++ b/daft/execution/execution_step.py
@@ -11,7 +11,7 @@ else:
     from typing import Protocol
 
-from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest
+from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest, ScanTask
 from daft.expressions import Expression, ExpressionsProjection, col
 from daft.logical.map_partition_ops import MapPartitionOp
 from daft.logical.schema import Schema
@@ -290,6 +290,47 @@ def num_outputs(self) -> int:
         return 1
 
 
+@dataclass(frozen=True)
+class ScanWithTask(SingleOutputInstruction):
+    scan_task: ScanTask
+
+    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        return self._scan(inputs)
+
+    def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        assert len(inputs) == 0
+        table = MicroPartition._from_scan_task(self.scan_task)
+        return [table]
+
+    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
+        assert len(input_metadatas) == 0
+
+        return [
+            PartialPartitionMetadata(
+                num_rows=self.scan_task.num_rows(),
+                size_bytes=self.scan_task.size_bytes(),
+            )
+        ]
+
+
+@dataclass(frozen=True)
+class EmptyScan(SingleOutputInstruction):
+    schema: Schema
+
+    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        return [MicroPartition.empty(self.schema)]
+
+    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
+        assert len(input_metadatas) == 0
+
+        return [
+            PartialPartitionMetadata(
+                num_rows=0,
+                size_bytes=0,
+            )
+        ]
+
+
 @dataclass(frozen=True)
 class WriteFile(SingleOutputInstruction):
     file_format: FileFormat
diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py
index 9b3c94b9f8..a0b1b8de8f 100644
--- a/daft/execution/rust_physical_plan_shim.py
+++ b/daft/execution/rust_physical_plan_shim.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-
 from daft.daft import (
     FileFormat,
     IOConfig,
@@ -15,7 +13,7 @@
 from daft.expressions import Expression, ExpressionsProjection
 from daft.logical.map_partition_ops import MapPartitionOp
 from daft.logical.schema import Schema
-from daft.runners.partitioning import PartialPartitionMetadata, PartitionT
+from daft.runners.partitioning import PartitionT
 from daft.table import MicroPartition
 
 
@@ -30,7 +28,7 @@ def scan_with_tasks(
     # We can instead right-size and bundle the ScanTask into single-instruction bulk reads.
     for scan_task in scan_tasks:
         scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction(
-            instruction=ScanWithTask(scan_task),
+            instruction=execution_step.ScanWithTask(scan_task),
             # Set the filesize as the memory request.
             # (Note: this is very conservative; file readers empirically use much more peak memory than 1x file size.)
             resource_request=ResourceRequest(memory_bytes=scan_task.size_bytes()),
@@ -43,53 +41,12 @@ def empty_scan(
 ) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
     """yield a plan to create an empty Partition"""
     scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction(
-        instruction=EmptyScan(schema=schema),
+        instruction=execution_step.EmptyScan(schema=schema),
         resource_request=ResourceRequest(memory_bytes=0),
     )
     yield scan_step
 
-
-@dataclass(frozen=True)
-class ScanWithTask(execution_step.SingleOutputInstruction):
-    scan_task: ScanTask
-
-    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
-        return self._scan(inputs)
-
-    def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
-        assert len(inputs) == 0
-        table = MicroPartition._from_scan_task(self.scan_task)
-        return [table]
-
-    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
-        assert len(input_metadatas) == 0
-
-        return [
-            PartialPartitionMetadata(
-                num_rows=self.scan_task.num_rows(),
-                size_bytes=self.scan_task.size_bytes(),
-            )
-        ]
-
-
-@dataclass(frozen=True)
-class EmptyScan(execution_step.SingleOutputInstruction):
-    schema: Schema
-
-    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
-        return [MicroPartition.empty(self.schema)]
-
-    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
-        assert len(input_metadatas) == 0
-
-        return [
-            PartialPartitionMetadata(
-                num_rows=0,
-                size_bytes=0,
-            )
-        ]
-
-
 def project(
     input: physical_plan.InProgressPhysicalPlan[PartitionT], projection: list[PyExpr], resource_request: ResourceRequest
 ) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index 33cd2cf568..2ff03ea6cc 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -40,6 +40,7 @@
     MultiOutputPartitionTask,
     PartitionTask,
     ReduceInstruction,
+    ScanWithTask,
     SingleOutputPartitionTask,
 )
 from daft.filesystem import glob_path_with_stats
@@ -652,6 +653,8 @@ def _build_partitions(
     build_remote = (
         fanout_pipeline if isinstance(task.instructions[-1], FanoutInstruction) else single_partition_pipeline
     )
+    if isinstance(task.instructions[0], ScanWithTask):
+        ray_options["scheduling_strategy"] = "SPREAD"
     build_remote = build_remote.options(**ray_options)
     [metadatas_ref, *partitions] = build_remote.remote(
         daft_execution_config_objref, task.instructions, *task.inputs
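For reviewers unfamiliar with Ray's scheduling strategies: the new branch in _build_partitions sets ray_options["scheduling_strategy"] = "SPREAD" only when a task's first instruction is a ScanWithTask, so the initial file reads are spread across the cluster rather than packed onto the nodes Ray's default strategy favors. The snippet below is a minimal, standalone sketch of that Ray option, not Daft code; scan_partition and the file paths are made up for illustration.

import ray

ray.init()  # assumes a local or already-running Ray cluster

@ray.remote
def scan_partition(path: str) -> str:
    # Stand-in for the pipeline that runs a ScanWithTask instruction;
    # a real Daft task would return a materialized MicroPartition.
    return f"scanned {path}"

paths = [f"part-{i}.parquet" for i in range(8)]  # hypothetical inputs

# Mirrors the ray_runner.py change: ask Ray to SPREAD the scan tasks across
# nodes so a single node does not absorb all of the initial read memory pressure.
refs = [scan_partition.options(scheduling_strategy="SPREAD").remote(p) for p in paths]
print(ray.get(refs))

Note that SPREAD is best-effort: Ray tries to place the tasks on distinct nodes but falls back to whatever nodes have capacity.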