From 697c5f2c98781ad5e49fb3e818ecb024abf8803b Mon Sep 17 00:00:00 2001
From: clarkzinzow <clarkzinzow@gmail.com>
Date: Mon, 26 Feb 2024 12:33:34 -0800
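Subject: [PATCH] Spread scan tasks over Ray cluster.

Move the ScanWithTask and EmptyScan instructions from
daft/execution/rust_physical_plan_shim.py into
daft/execution/execution_step.py so that the Ray runner can import
ScanWithTask alongside the other instruction types. In the Ray runner,
set scheduling_strategy="SPREAD" for any task whose first instruction
is a ScanWithTask.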

---
 daft/execution/execution_step.py          | 43 +++++++++++++++++++-
 daft/execution/rust_physical_plan_shim.py | 49 ++---------------------
 daft/runners/ray_runner.py                |  3 ++
 3 files changed, 48 insertions(+), 47 deletions(-)

diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py
index 5bd5e483df..a70ab0a1da 100644
--- a/daft/execution/execution_step.py
+++ b/daft/execution/execution_step.py
@@ -11,7 +11,7 @@
 else:
     from typing import Protocol
 
-from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest
+from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest, ScanTask
 from daft.expressions import Expression, ExpressionsProjection, col
 from daft.logical.map_partition_ops import MapPartitionOp
 from daft.logical.schema import Schema
@@ -290,6 +290,47 @@ def num_outputs(self) -> int:
         return 1
 
 
+@dataclass(frozen=True)
+class ScanWithTask(SingleOutputInstruction):
+    """Instruction that builds one MicroPartition from `scan_task`."""
+
+    scan_task: ScanTask
+
+    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        """Run the scan by delegating to `_scan`."""
+        return self._scan(inputs)
+
+    def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        # This instruction must be given no input partitions.
+        assert len(inputs) == 0
+        return [MicroPartition._from_scan_task(self.scan_task)]
+
+    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
+        """Describe the single output using the scan task's num_rows() and size_bytes()."""
+        assert len(input_metadatas) == 0
+        task = self.scan_task
+        partial = PartialPartitionMetadata(num_rows=task.num_rows(), size_bytes=task.size_bytes())
+        return [partial]
+
+
+@dataclass(frozen=True)
+class EmptyScan(SingleOutputInstruction):
+    """Instruction that produces a single empty MicroPartition of `schema`."""
+
+    schema: Schema
+
+    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        """Return a one-element list holding an empty partition; `inputs` is not read."""
+        empty_partition = MicroPartition.empty(self.schema)
+        return [empty_partition]
+
+    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
+        """Return metadata for the empty output: zero rows and zero bytes."""
+        assert len(input_metadatas) == 0
+        nothing = PartialPartitionMetadata(num_rows=0, size_bytes=0)
+        return [nothing]
+
+
 @dataclass(frozen=True)
 class WriteFile(SingleOutputInstruction):
     file_format: FileFormat
diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py
index 9b3c94b9f8..a0b1b8de8f 100644
--- a/daft/execution/rust_physical_plan_shim.py
+++ b/daft/execution/rust_physical_plan_shim.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-
 from daft.daft import (
     FileFormat,
     IOConfig,
@@ -15,7 +13,7 @@
 from daft.expressions import Expression, ExpressionsProjection
 from daft.logical.map_partition_ops import MapPartitionOp
 from daft.logical.schema import Schema
-from daft.runners.partitioning import PartialPartitionMetadata, PartitionT
+from daft.runners.partitioning import PartitionT
 from daft.table import MicroPartition
 
 
@@ -30,7 +28,7 @@ def scan_with_tasks(
     # We can instead right-size and bundle the ScanTask into single-instruction bulk reads.
     for scan_task in scan_tasks:
         scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction(
-            instruction=ScanWithTask(scan_task),
+            instruction=execution_step.ScanWithTask(scan_task),
             # Set the filesize as the memory request.
             # (Note: this is very conservative; file readers empirically use much more peak memory than 1x file size.)
             resource_request=ResourceRequest(memory_bytes=scan_task.size_bytes()),
@@ -43,53 +41,12 @@ def empty_scan(
 ) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
     """yield a plan to create an empty Partition"""
     scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction(
-        instruction=EmptyScan(schema=schema),
+        instruction=execution_step.EmptyScan(schema=schema),
         resource_request=ResourceRequest(memory_bytes=0),
     )
     yield scan_step
 
 
-@dataclass(frozen=True)
-class ScanWithTask(execution_step.SingleOutputInstruction):
-    scan_task: ScanTask
-
-    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
-        return self._scan(inputs)
-
-    def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
-        assert len(inputs) == 0
-        table = MicroPartition._from_scan_task(self.scan_task)
-        return [table]
-
-    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
-        assert len(input_metadatas) == 0
-
-        return [
-            PartialPartitionMetadata(
-                num_rows=self.scan_task.num_rows(),
-                size_bytes=self.scan_task.size_bytes(),
-            )
-        ]
-
-
-@dataclass(frozen=True)
-class EmptyScan(execution_step.SingleOutputInstruction):
-    schema: Schema
-
-    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
-        return [MicroPartition.empty(self.schema)]
-
-    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
-        assert len(input_metadatas) == 0
-
-        return [
-            PartialPartitionMetadata(
-                num_rows=0,
-                size_bytes=0,
-            )
-        ]
-
-
 def project(
     input: physical_plan.InProgressPhysicalPlan[PartitionT], projection: list[PyExpr], resource_request: ResourceRequest
 ) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index 33cd2cf568..2ff03ea6cc 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -40,6 +40,7 @@
     MultiOutputPartitionTask,
     PartitionTask,
     ReduceInstruction,
+    ScanWithTask,
     SingleOutputPartitionTask,
 )
 from daft.filesystem import glob_path_with_stats
@@ -652,6 +653,8 @@ def _build_partitions(
         build_remote = (
             fanout_pipeline if isinstance(task.instructions[-1], FanoutInstruction) else single_partition_pipeline
         )
+        if isinstance(task.instructions[0], ScanWithTask):
+            ray_options["scheduling_strategy"] = "SPREAD"
         build_remote = build_remote.options(**ray_options)
         [metadatas_ref, *partitions] = build_remote.remote(
             daft_execution_config_objref, task.instructions, *task.inputs