Spread scan tasks over Ray cluster.
clarkzinzow committed Feb 26, 2024
1 parent 0a51db1 commit 697c5f2
Showing 3 changed files with 48 additions and 47 deletions.
43 changes: 42 additions & 1 deletion daft/execution/execution_step.py
@@ -11,7 +11,7 @@
 else:
     from typing import Protocol

-from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest
+from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest, ScanTask
 from daft.expressions import Expression, ExpressionsProjection, col
 from daft.logical.map_partition_ops import MapPartitionOp
 from daft.logical.schema import Schema
@@ -290,6 +290,47 @@ def num_outputs(self) -> int:
         return 1


+@dataclass(frozen=True)
+class ScanWithTask(SingleOutputInstruction):
+    scan_task: ScanTask
+
+    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        return self._scan(inputs)
+
+    def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        assert len(inputs) == 0
+        table = MicroPartition._from_scan_task(self.scan_task)
+        return [table]
+
+    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
+        assert len(input_metadatas) == 0
+
+        return [
+            PartialPartitionMetadata(
+                num_rows=self.scan_task.num_rows(),
+                size_bytes=self.scan_task.size_bytes(),
+            )
+        ]
+
+
+@dataclass(frozen=True)
+class EmptyScan(SingleOutputInstruction):
+    schema: Schema
+
+    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        return [MicroPartition.empty(self.schema)]
+
+    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
+        assert len(input_metadatas) == 0
+
+        return [
+            PartialPartitionMetadata(
+                num_rows=0,
+                size_bytes=0,
+            )
+        ]
+
+
 @dataclass(frozen=True)
 class WriteFile(SingleOutputInstruction):
     file_format: FileFormat
49 changes: 3 additions & 46 deletions daft/execution/rust_physical_plan_shim.py
@@ -1,7 +1,5 @@
 from __future__ import annotations

-from dataclasses import dataclass
-
 from daft.daft import (
     FileFormat,
     IOConfig,
@@ -15,7 +13,7 @@
 from daft.expressions import Expression, ExpressionsProjection
 from daft.logical.map_partition_ops import MapPartitionOp
 from daft.logical.schema import Schema
-from daft.runners.partitioning import PartialPartitionMetadata, PartitionT
+from daft.runners.partitioning import PartitionT
 from daft.table import MicroPartition


@@ -30,7 +28,7 @@ def scan_with_tasks(
     # We can instead right-size and bundle the ScanTask into single-instruction bulk reads.
     for scan_task in scan_tasks:
         scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction(
-            instruction=ScanWithTask(scan_task),
+            instruction=execution_step.ScanWithTask(scan_task),
             # Set the filesize as the memory request.
             # (Note: this is very conservative; file readers empirically use much more peak memory than 1x file size.)
             resource_request=ResourceRequest(memory_bytes=scan_task.size_bytes()),
@@ -43,53 +41,12 @@ def empty_scan(
 ) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
     """yield a plan to create an empty Partition"""
     scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction(
-        instruction=EmptyScan(schema=schema),
+        instruction=execution_step.EmptyScan(schema=schema),
         resource_request=ResourceRequest(memory_bytes=0),
     )
     yield scan_step


-@dataclass(frozen=True)
-class ScanWithTask(execution_step.SingleOutputInstruction):
-    scan_task: ScanTask
-
-    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
-        return self._scan(inputs)
-
-    def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
-        assert len(inputs) == 0
-        table = MicroPartition._from_scan_task(self.scan_task)
-        return [table]
-
-    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
-        assert len(input_metadatas) == 0
-
-        return [
-            PartialPartitionMetadata(
-                num_rows=self.scan_task.num_rows(),
-                size_bytes=self.scan_task.size_bytes(),
-            )
-        ]
-
-
-@dataclass(frozen=True)
-class EmptyScan(execution_step.SingleOutputInstruction):
-    schema: Schema
-
-    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
-        return [MicroPartition.empty(self.schema)]
-
-    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
-        assert len(input_metadatas) == 0
-
-        return [
-            PartialPartitionMetadata(
-                num_rows=0,
-                size_bytes=0,
-            )
-        ]
-
-
 def project(
     input: physical_plan.InProgressPhysicalPlan[PartitionT], projection: list[PyExpr], resource_request: ResourceRequest
 ) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
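With ScanWithTask and EmptyScan now defined in execution_step.py rather than in this shim, the Ray runner below can import ScanWithTask directly and recognize scan tasks with a plain isinstance check on a task's first instruction.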
3 changes: 3 additions & 0 deletions daft/runners/ray_runner.py
@@ -40,6 +40,7 @@
     MultiOutputPartitionTask,
     PartitionTask,
     ReduceInstruction,
+    ScanWithTask,
     SingleOutputPartitionTask,
 )
 from daft.filesystem import glob_path_with_stats
@@ -652,6 +653,8 @@ def _build_partitions(
     build_remote = (
         fanout_pipeline if isinstance(task.instructions[-1], FanoutInstruction) else single_partition_pipeline
     )
+    if isinstance(task.instructions[0], ScanWithTask):
+        ray_options["scheduling_strategy"] = "SPREAD"
     build_remote = build_remote.options(**ray_options)
     [metadatas_ref, *partitions] = build_remote.remote(
         daft_execution_config_objref, task.instructions, *task.inputs
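For readers unfamiliar with the Ray option being set here: scheduling_strategy="SPREAD" asks Ray to place tasks across distinct nodes rather than packing them onto whichever node has free resources, which is what spreads the scan (file-read) tasks over the cluster. The snippet below is a minimal, standalone sketch of that Ray option and is not part of this commit; read_chunk is a hypothetical stand-in for a scan task, and get_node_id() assumes Ray 2.x.

import ray

ray.init()  # or ray.init(address="auto") to join an existing cluster


@ray.remote
def read_chunk(i: int) -> str:
    # Hypothetical stand-in for a scan task: report which node ran it.
    return ray.get_runtime_context().get_node_id()


# With the default strategy Ray tends to pack tasks onto nodes with spare
# resources; "SPREAD" round-robins them across the available nodes instead.
refs = [read_chunk.options(scheduling_strategy="SPREAD").remote(i) for i in range(8)]
print(set(ray.get(refs)))  # on a multi-node cluster this prints several node IDs

Applying the option only when a task's first instruction is a ScanWithTask, as in the change above, spreads the I/O-heavy reads without altering placement for the rest of the pipeline.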
