Skip to content

Commit

Permalink
[PERF] Spread scan tasks over Ray cluster. (#1950)
Browse files Browse the repository at this point in the history
This PR forces a `SPREAD` scheduling strategy for scan tasks when using
the Ray runner. This should result in better load balancing of read
tasks across the Ray cluster, yielding:
- better utilization of the aggregate network bandwidth of the cluster,
- better memory stability due to a more even post-read object distribution,
- better performance of downstream parallel compute operations due to a more even distribution of data over the compute bandwidth of the cluster.

Closes #1940
  • Loading branch information
clarkzinzow authored Feb 26, 2024
1 parent e8697b2 commit 1a94752
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 49 deletions.
43 changes: 42 additions & 1 deletion daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
else:
from typing import Protocol

from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest
from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest, ScanTask
from daft.expressions import Expression, ExpressionsProjection, col
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
Expand Down Expand Up @@ -290,6 +290,47 @@ def num_outputs(self) -> int:
return 1


@dataclass(frozen=True)
class ScanWithTask(SingleOutputInstruction):
    """Instruction that turns a single ``ScanTask`` into one output partition.

    The instruction is a source: both ``run`` and ``run_partial_metadata``
    assert that they receive no inputs.
    """

    # The scan task handed to MicroPartition._from_scan_task when run.
    scan_task: ScanTask

    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
        """Execute the scan by delegating to ``_scan``."""
        return self._scan(inputs)

    def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
        """Return a one-element list holding the partition built from ``scan_task``."""
        # A scan has no upstream; any input partition indicates a planning error.
        assert len(inputs) == 0
        return [MicroPartition._from_scan_task(self.scan_task)]

    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
        """Return metadata whose row count and byte size come from the scan task."""
        assert len(input_metadatas) == 0

        task = self.scan_task
        metadata = PartialPartitionMetadata(
            num_rows=task.num_rows(),
            size_bytes=task.size_bytes(),
        )
        return [metadata]


@dataclass(frozen=True)
class EmptyScan(SingleOutputInstruction):
    """Instruction that produces one empty partition for a given schema."""

    # Schema passed to MicroPartition.empty when the instruction runs.
    schema: Schema

    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
        """Return a one-element list with an empty partition; ``inputs`` is not read."""
        empty_partition = MicroPartition.empty(self.schema)
        return [empty_partition]

    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
        """Return metadata reporting zero rows and zero bytes."""
        # Like the scan itself, the metadata pass expects no upstream metadata.
        assert len(input_metadatas) == 0

        zero_metadata = PartialPartitionMetadata(num_rows=0, size_bytes=0)
        return [zero_metadata]


@dataclass(frozen=True)
class WriteFile(SingleOutputInstruction):
file_format: FileFormat
Expand Down
49 changes: 3 additions & 46 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass

from daft.daft import (
FileFormat,
IOConfig,
Expand All @@ -15,7 +13,7 @@
from daft.expressions import Expression, ExpressionsProjection
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.runners.partitioning import PartialPartitionMetadata, PartitionT
from daft.runners.partitioning import PartitionT
from daft.table import MicroPartition


Expand All @@ -30,7 +28,7 @@ def scan_with_tasks(
# We can instead right-size and bundle the ScanTask into single-instruction bulk reads.
for scan_task in scan_tasks:
scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction(
instruction=ScanWithTask(scan_task),
instruction=execution_step.ScanWithTask(scan_task),
# Set the filesize as the memory request.
# (Note: this is very conservative; file readers empirically use much more peak memory than 1x file size.)
resource_request=ResourceRequest(memory_bytes=scan_task.size_bytes()),
Expand All @@ -43,53 +41,12 @@ def empty_scan(
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
"""yield a plan to create an empty Partition"""
scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction(
instruction=EmptyScan(schema=schema),
instruction=execution_step.EmptyScan(schema=schema),
resource_request=ResourceRequest(memory_bytes=0),
)
yield scan_step


@dataclass(frozen=True)
class ScanWithTask(execution_step.SingleOutputInstruction):
    """Instruction that turns a single ``ScanTask`` into one output partition.

    Both ``run`` and ``run_partial_metadata`` assert that they receive no inputs.
    """

    # The scan task handed to MicroPartition._from_scan_task when run.
    scan_task: ScanTask

    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
        """Execute the scan by delegating to ``_scan``."""
        return self._scan(inputs)

    def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
        """Return a one-element list holding the partition built from ``scan_task``."""
        # A scan has no upstream partitions to consume.
        assert len(inputs) == 0
        return [MicroPartition._from_scan_task(self.scan_task)]

    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
        """Return metadata whose row count and byte size come from the scan task."""
        assert len(input_metadatas) == 0

        task = self.scan_task
        metadata = PartialPartitionMetadata(
            num_rows=task.num_rows(),
            size_bytes=task.size_bytes(),
        )
        return [metadata]


@dataclass(frozen=True)
class EmptyScan(execution_step.SingleOutputInstruction):
    """Instruction that produces one empty partition for a given schema."""

    # Schema passed to MicroPartition.empty when the instruction runs.
    schema: Schema

    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
        """Return a one-element list with an empty partition; ``inputs`` is not read."""
        empty_partition = MicroPartition.empty(self.schema)
        return [empty_partition]

    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
        """Return metadata reporting zero rows and zero bytes."""
        # The metadata pass expects no upstream metadata.
        assert len(input_metadatas) == 0

        zero_metadata = PartialPartitionMetadata(num_rows=0, size_bytes=0)
        return [zero_metadata]


def project(
input: physical_plan.InProgressPhysicalPlan[PartitionT], projection: list[PyExpr], resource_request: ResourceRequest
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
Expand Down
13 changes: 11 additions & 2 deletions daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
MultiOutputPartitionTask,
PartitionTask,
ReduceInstruction,
ScanWithTask,
SingleOutputPartitionTask,
)
from daft.filesystem import glob_path_with_stats
Expand Down Expand Up @@ -644,14 +645,22 @@ def _build_partitions(
ray_options = {**ray_options, **_get_ray_task_options(task.resource_request)}

if isinstance(task.instructions[0], ReduceInstruction):
build_remote = reduce_and_fanout if isinstance(task.instructions[-1], FanoutInstruction) else reduce_pipeline
build_remote = (
reduce_and_fanout
if task.instructions and isinstance(task.instructions[-1], FanoutInstruction)
else reduce_pipeline
)
build_remote = build_remote.options(**ray_options)
[metadatas_ref, *partitions] = build_remote.remote(daft_execution_config_objref, task.instructions, task.inputs)

else:
build_remote = (
fanout_pipeline if isinstance(task.instructions[-1], FanoutInstruction) else single_partition_pipeline
fanout_pipeline
if task.instructions and isinstance(task.instructions[-1], FanoutInstruction)
else single_partition_pipeline
)
if task.instructions and isinstance(task.instructions[0], ScanWithTask):
ray_options["scheduling_strategy"] = "SPREAD"
build_remote = build_remote.options(**ray_options)
[metadatas_ref, *partitions] = build_remote.remote(
daft_execution_config_objref, task.instructions, *task.inputs
Expand Down

0 comments on commit 1a94752

Please sign in to comment.