[BUG] Fix actor pool initialization in ray client mode (#3028)
This PR moves the `actor_pool_context` method from the Ray runner to the
scheduler, and routes the relevant `ActorPoolManager` implementation into
`actor_pool_project`. That way, the plan does not accidentally pick up the
wrong actor pool context when the scheduler is running on a Ray actor, which
is the case in Ray client mode.

With this change, in addition to separating the `actor_pool_context` method
out of `Runner` into `ActorPoolManager`, I also move some other things around
to clean things up, in particular so that `ray_runner.py` no longer depends on
`pyrunner.py`.
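
For reference, a minimal, hypothetical usage sketch of the new `force_client_mode` option added in this PR (the data and script are illustrative only):

```python
import daft

# Force the Ray runner to behave as it does in Ray client mode, i.e. run the
# scheduler on a Ray actor instead of on the driver process.
daft.context.set_runner_ray(force_client_mode=True)

df = daft.from_pydict({"x": [1, 2, 3]})
print(df.collect())
```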

Note:
- One of the tests was cut down because it needed 3 CPUs to run, which,
combined with the 1 CPU reserved for the scheduler actor, meant no more tasks
could be scheduled at all. I plan to add more informative errors/warnings for
this situation in a future PR.
kevinzwang authored Oct 21, 2024
1 parent 6173006 commit 23d4a1f
Showing 13 changed files with 249 additions and 197 deletions.
18 changes: 16 additions & 2 deletions daft/context.py
@@ -32,6 +32,7 @@ class _RayRunnerConfig(_RunnerConfig):
name = "ray"
address: str | None
max_task_backlog: int | None
force_client_mode: bool


def _get_runner_config_from_env() -> _RunnerConfig:
@@ -43,10 +44,18 @@ def _get_runner_config_from_env() -> _RunnerConfig:
2. RayRunner: set DAFT_RUNNER=ray and optionally RAY_ADDRESS=ray://...
"""
runner_from_envvar = os.getenv("DAFT_RUNNER")

task_backlog_env = os.getenv("DAFT_DEVELOPER_RAY_MAX_TASK_BACKLOG")
task_backlog = int(task_backlog_env) if task_backlog_env is not None else None

use_thread_pool_env = os.getenv("DAFT_DEVELOPER_USE_THREAD_POOL")
use_thread_pool = bool(int(use_thread_pool_env)) if use_thread_pool_env is not None else None

ray_force_client_mode_env = os.getenv("DAFT_RAY_FORCE_CLIENT_MODE")
ray_force_client_mode = (
ray_force_client_mode_env.strip().lower() in ["1", "true"] if ray_force_client_mode_env else False
)

ray_is_initialized = False
in_ray_worker = False
try:
@@ -71,7 +80,8 @@ def _get_runner_config_from_env() -> _RunnerConfig:
ray_address = os.getenv("RAY_ADDRESS")
return _RayRunnerConfig(
address=ray_address,
max_task_backlog=int(task_backlog_env) if task_backlog_env else None,
max_task_backlog=task_backlog,
force_client_mode=ray_force_client_mode,
)
elif runner_from_envvar and runner_from_envvar.upper() == "PY":
return _PyRunnerConfig(use_thread_pool=use_thread_pool)
@@ -82,7 +92,8 @@ def _get_runner_config_from_env() -> _RunnerConfig:
elif ray_is_initialized and not in_ray_worker:
return _RayRunnerConfig(
address=None, # No address supplied, use the existing connection
max_task_backlog=int(task_backlog_env) if task_backlog_env else None,
max_task_backlog=task_backlog,
force_client_mode=ray_force_client_mode,
)

# Fall back on PyRunner
@@ -155,6 +166,7 @@ def _get_runner(self) -> Runner:
self._runner = RayRunner(
address=runner_config.address,
max_task_backlog=runner_config.max_task_backlog,
force_client_mode=runner_config.force_client_mode,
)
elif runner_config.name == "py":
from daft.runners.pyrunner import PyRunner
@@ -189,6 +201,7 @@ def set_runner_ray(
address: str | None = None,
noop_if_initialized: bool = False,
max_task_backlog: int | None = None,
force_client_mode: bool = False,
) -> DaftContext:
"""Set the runner for executing Daft dataframes to a Ray cluster
@@ -222,6 +235,7 @@ def set_runner_ray(
ctx._runner_config = _RayRunnerConfig(
address=address,
max_task_backlog=max_task_backlog,
force_client_mode=force_client_mode,
)
ctx._disallow_set_runner = True
return ctx
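
As a complementary sketch, the same configuration can come from environment variables parsed by `_get_runner_config_from_env` above; the values below are illustrative, and "1"/"true" are the accepted truthy strings for the new flag:

```python
import os

# Set these before Daft creates its runner (values are illustrative).
os.environ["DAFT_RUNNER"] = "ray"
os.environ["DAFT_RAY_FORCE_CLIENT_MODE"] = "true"  # "1" is also accepted
os.environ["DAFT_DEVELOPER_RAY_MAX_TASK_BACKLOG"] = "16"

import daft  # the runner config is read from the environment on first use

print(daft.from_pydict({"x": [1, 2, 3]}).collect())
```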
4 changes: 3 additions & 1 deletion daft/daft/__init__.pyi
@@ -1639,7 +1639,9 @@ class PhysicalPlanScheduler:
def repr_ascii(self, simple: bool) -> str: ...
def repr_mermaid(self, options: MermaidOptions) -> str: ...
def to_json_string(self) -> str: ...
def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.InProgressPhysicalPlan: ...
def to_partition_tasks(
self, psets: dict[str, list[PartitionT]], actor_pool_manager: Any
) -> physical_plan.InProgressPhysicalPlan: ...
def run(self, psets: dict[str, list[PartitionT]]) -> Iterator[PyMicroPartition]: ...

class AdaptivePhysicalPlanScheduler:
3 changes: 1 addition & 2 deletions daft/dataframe/dataframe.py
@@ -38,8 +38,7 @@
from daft.errors import ExpressionTypeError
from daft.expressions import Expression, ExpressionsProjection, col, lit
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import PartitionCacheEntry, PartitionSet
from daft.runners.pyrunner import LocalPartitionSet
from daft.runners.partitioning import LocalPartitionSet, PartitionCacheEntry, PartitionSet
from daft.table import MicroPartition
from daft.viz import DataFrameDisplay

8 changes: 4 additions & 4 deletions daft/execution/native_executor.py
@@ -11,10 +11,10 @@
if TYPE_CHECKING:
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import (
LocalMaterializedResult,
MaterializedResult,
PartitionT,
)
from daft.runners.pyrunner import PyMaterializedResult


class NativeExecutor:
@@ -31,13 +31,13 @@ def run(
psets: dict[str, list[MaterializedResult[PartitionT]]],
daft_execution_config: PyDaftExecutionConfig,
results_buffer_size: int | None,
) -> Iterator[PyMaterializedResult]:
from daft.runners.pyrunner import PyMaterializedResult
) -> Iterator[LocalMaterializedResult]:
from daft.runners.partitioning import LocalMaterializedResult

psets_mp = {
part_id: [part.micropartition()._micropartition for part in parts] for part_id, parts in psets.items()
}
return (
PyMaterializedResult(MicroPartition._from_pymicropartition(part))
LocalMaterializedResult(MicroPartition._from_pymicropartition(part))
for part in self._executor.run(psets_mp, daft_execution_config, results_buffer_size)
)
31 changes: 30 additions & 1 deletion daft/execution/physical_plan.py
@@ -14,9 +14,11 @@
from __future__ import annotations

import collections
import contextlib
import itertools
import logging
import math
from abc import abstractmethod
from collections import deque
from typing import (
TYPE_CHECKING,
@@ -205,9 +207,36 @@ def pipeline_instruction(
)


class ActorPoolManager:
@abstractmethod
@contextlib.contextmanager
def actor_pool_context(
self,
name: str,
actor_resource_request: ResourceRequest,
task_resource_request: ResourceRequest,
num_actors: int,
projection: ExpressionsProjection,
) -> Iterator[str]:
"""Creates a pool of actors which can execute work, and yield a context in which the pool can be used.
Also yields a `str` ID which clients can use to refer to the actor pool when submitting tasks.
Note that attempting to do work outside this context will result in errors!
Args:
name: Name of the actor pool for debugging/observability
resource_request: Requested amount of resources for each actor
num_actors: Number of actors to spin up
projection: Projection to be run on the incoming data (contains Stateful UDFs as well as other stateless expressions such as aliases)
"""
...


def actor_pool_project(
child_plan: InProgressPhysicalPlan[PartitionT],
projection: ExpressionsProjection,
actor_pool_manager: ActorPoolManager,
resource_request: execution_step.ResourceRequest,
num_actors: int,
) -> InProgressPhysicalPlan[PartitionT]:
@@ -238,7 +267,7 @@ def actor_pool_project(
num_gpus=resource_request.num_gpus, memory_bytes=resource_request.memory_bytes
)

with get_context().runner().actor_pool_context(
with actor_pool_manager.actor_pool_context(
actor_pool_name,
actor_resource_request,
task_resource_request,
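
To make the new interface concrete, here is a minimal, hypothetical `ActorPoolManager` implementation against the abstract method added above; the class name and its no-op pooling are illustrative only and not code from this PR:

```python
from __future__ import annotations

import contextlib
import uuid
from typing import Iterator

from daft.execution.execution_step import ResourceRequest  # mirrors execution_step.ResourceRequest used above
from daft.execution.physical_plan import ActorPoolManager
from daft.expressions import ExpressionsProjection


class InProcessActorPoolManager(ActorPoolManager):
    """Hypothetical manager whose 'pool' lives in the current process."""

    @contextlib.contextmanager
    def actor_pool_context(
        self,
        name: str,
        actor_resource_request: ResourceRequest,
        task_resource_request: ResourceRequest,
        num_actors: int,
        projection: ExpressionsProjection,
    ) -> Iterator[str]:
        pool_id = f"{name}-{uuid.uuid4()}"
        try:
            # A real implementation would spin up `num_actors` workers here and
            # route tasks submitted with `pool_id` to them while the context is open.
            yield pool_id
        finally:
            # Tear the pool down so no further work can be submitted against `pool_id`.
            pass
```

With this shape, `actor_pool_project` receives whichever `ActorPoolManager` the active runner provides, instead of asking the global context for a runner.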
2 changes: 2 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
@@ -83,6 +83,7 @@ def project(
def actor_pool_project(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
projection: list[PyExpr],
actor_pool_manager: physical_plan.ActorPoolManager,
resource_request: ResourceRequest | None,
num_actors: int,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
@@ -93,6 +94,7 @@ def actor_pool_project(
return physical_plan.actor_pool_project(
child_plan=input,
projection=expr_projection,
actor_pool_manager=actor_pool_manager,
resource_request=resource_request,
num_actors=num_actors,
)
2 changes: 1 addition & 1 deletion daft/io/file_path.py
@@ -8,7 +8,7 @@
from daft.daft import IOConfig
from daft.dataframe import DataFrame
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.pyrunner import LocalPartitionSet
from daft.runners.partitioning import LocalPartitionSet
from daft.table import MicroPartition


11 changes: 9 additions & 2 deletions daft/plan_scheduler/physical_plan_scheduler.py
@@ -57,9 +57,16 @@ def to_json_string(self) -> str:
return self._scheduler.to_json_string()

def to_partition_tasks(
self, psets: dict[str, list[PartitionT]], results_buffer_size: int | None
self,
psets: dict[str, list[PartitionT]],
actor_pool_manager: physical_plan.ActorPoolManager,
results_buffer_size: int | None,
) -> physical_plan.MaterializedPhysicalPlan:
return iter(physical_plan.Materialize(self._scheduler.to_partition_tasks(psets), results_buffer_size))
return iter(
physical_plan.Materialize(
self._scheduler.to_partition_tasks(psets, actor_pool_manager), results_buffer_size
)
)


class AdaptivePhysicalPlanScheduler:
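
A brief, hypothetical call-site sketch for the new `to_partition_tasks` signature; `plan_scheduler`, `psets`, and `self` stand in for whatever the runner actually holds and are not taken from this diff:

```python
# Inside a runner that implements ActorPoolManager, drive the scheduler like so:
materialized_plan = plan_scheduler.to_partition_tasks(
    psets=psets,               # dict[str, list[PartitionT]] of already-cached partitions
    actor_pool_manager=self,   # the runner supplies its own ActorPoolManager
    results_buffer_size=None,  # no limit on buffered results
)
for step in materialized_plan:
    ...  # dispatch partition tasks / collect materialized results
```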
88 changes: 87 additions & 1 deletion daft/runners/partitioning.py
@@ -8,14 +8,14 @@
from uuid import uuid4

from daft.datatype import TimeUnit
from daft.table import MicroPartition

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa

from daft.expressions.expressions import Expression
from daft.logical.schema import Schema
from daft.table import MicroPartition

PartID = int

@@ -271,6 +271,92 @@ def wait(self) -> None:
raise NotImplementedError()


class LocalPartitionSet(PartitionSet[MicroPartition]):
_partitions: dict[PartID, MaterializedResult[MicroPartition]]

def __init__(self) -> None:
super().__init__()
self._partitions = {}

def items(self) -> list[tuple[PartID, MaterializedResult[MicroPartition]]]:
return sorted(self._partitions.items())

def _get_merged_micropartition(self) -> MicroPartition:
ids_and_partitions = self.items()
assert ids_and_partitions[0][0] == 0
assert ids_and_partitions[-1][0] + 1 == len(ids_and_partitions)
return MicroPartition.concat([part.partition() for id, part in ids_and_partitions])

def _get_preview_micropartitions(self, num_rows: int) -> list[MicroPartition]:
ids_and_partitions = self.items()
preview_parts = []
for _, mat_result in ids_and_partitions:
part: MicroPartition = mat_result.partition()
part_len = len(part)
if part_len >= num_rows: # if this part has enough rows, take what we need and break
preview_parts.append(part.slice(0, num_rows))
break
else: # otherwise, take the whole part and keep going
num_rows -= part_len
preview_parts.append(part)
return preview_parts

def get_partition(self, idx: PartID) -> MaterializedResult[MicroPartition]:
return self._partitions[idx]

def set_partition(self, idx: PartID, part: MaterializedResult[MicroPartition]) -> None:
self._partitions[idx] = part

def set_partition_from_table(self, idx: PartID, part: MicroPartition) -> None:
self._partitions[idx] = LocalMaterializedResult(part, PartitionMetadata.from_table(part))

def delete_partition(self, idx: PartID) -> None:
del self._partitions[idx]

def has_partition(self, idx: PartID) -> bool:
return idx in self._partitions

def __len__(self) -> int:
return sum(len(partition.partition()) for partition in self._partitions.values())

def size_bytes(self) -> int | None:
size_bytes_ = [partition.partition().size_bytes() for partition in self._partitions.values()]
size_bytes: list[int] = [size for size in size_bytes_ if size is not None]
if len(size_bytes) != len(size_bytes_):
return None
else:
return sum(size_bytes)

def num_partitions(self) -> int:
return len(self._partitions)

def wait(self) -> None:
pass


@dataclass
class LocalMaterializedResult(MaterializedResult[MicroPartition]):
_partition: MicroPartition
_metadata: PartitionMetadata | None = None

def partition(self) -> MicroPartition:
return self._partition

def micropartition(self) -> MicroPartition:
return self._partition

def metadata(self) -> PartitionMetadata:
if self._metadata is None:
self._metadata = PartitionMetadata.from_table(self._partition)
return self._metadata

def cancel(self) -> None:
return None

def _noop(self, _: MicroPartition) -> None:
return None


@dataclass(eq=False, repr=False)
class PartitionCacheEntry:
key: str
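
As a usage illustration for the relocated classes, a small sketch assuming two in-memory micropartitions (the data is made up):

```python
from daft.runners.partitioning import LocalMaterializedResult, LocalPartitionSet
from daft.table import MicroPartition

pset = LocalPartitionSet()
# Wrap a table directly; metadata is computed from the table.
pset.set_partition_from_table(0, MicroPartition.from_pydict({"x": [1, 2]}))
# Or supply an already-materialized result.
pset.set_partition(1, LocalMaterializedResult(MicroPartition.from_pydict({"x": [3]})))

assert pset.num_partitions() == 2
assert len(pset) == 3  # __len__ sums rows across partitions
merged = pset._get_merged_micropartition()  # concatenates partitions in PartID order
```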
