From 464a7a1efe5f308edea29086bda5d992ea5d7176 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Wed, 2 Oct 2024 16:20:42 -0700 Subject: [PATCH] [BUG] Add proper resources to Ray stateful UDF actor --- daft/execution/physical_plan.py | 1 + daft/runners/pyrunner.py | 15 +++++++++++---- daft/runners/ray_runner.py | 16 ++++++++++++++-- daft/runners/runner.py | 3 ++- tests/actor_pool/test_pyactor_pool.py | 4 +++- 5 files changed, 31 insertions(+), 8 deletions(-) diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 91be85e02a..03c88b4779 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -241,6 +241,7 @@ def actor_pool_project( with get_context().runner().actor_pool_context( actor_pool_name, actor_resource_request, + task_resource_request, num_actors, projection, ) as actor_pool_id: diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 31c56c3ad4..16e939fbca 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -337,20 +337,27 @@ def run_iter_tables( @contextlib.contextmanager def actor_pool_context( - self, name: str, resource_request: ResourceRequest, num_actors: int, projection: ExpressionsProjection + self, + name: str, + actor_resource_request: ResourceRequest, + _: ResourceRequest, + num_actors: int, + projection: ExpressionsProjection, ) -> Iterator[str]: actor_pool_id = f"py_actor_pool-{name}" - total_resource_request = resource_request * num_actors + total_resource_request = actor_resource_request * num_actors admitted = self._attempt_admit_task(total_resource_request) if not admitted: raise RuntimeError( - f"Not enough resources available to admit {num_actors} actors, each with resource request: {resource_request}" + f"Not enough resources available to admit {num_actors} actors, each with resource request: {actor_resource_request}" ) try: - self._actor_pools[actor_pool_id] = PyActorPool(actor_pool_id, num_actors, resource_request, projection) + self._actor_pools[actor_pool_id] = PyActorPool( + actor_pool_id, num_actors, actor_resource_request, projection + ) self._actor_pools[actor_pool_id].setup() logger.debug("Created actor pool %s with resources: %s", actor_pool_id, total_resource_request) yield actor_pool_id diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index d29a15c9f2..c0579e11e5 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -986,8 +986,12 @@ def __init__( self._projection = projection def setup(self) -> None: + ray_options = _get_ray_task_options(self._resource_request_per_actor) + self._actors = [ - DaftRayActor.options(name=f"rank={rank}-{self._id}").remote(self._execution_config, self._projection) # type: ignore + DaftRayActor.options(name=f"rank={rank}-{self._id}", **ray_options).remote( # type: ignore + self._execution_config, self._projection + ) for rank in range(self._num_actors) ] @@ -1155,8 +1159,16 @@ def run_iter_tables( @contextlib.contextmanager def actor_pool_context( - self, name: str, resource_request: ResourceRequest, num_actors: PartID, projection: ExpressionsProjection + self, + name: str, + actor_resource_request: ResourceRequest, + task_resource_request: ResourceRequest, + num_actors: PartID, + projection: ExpressionsProjection, ) -> Iterator[str]: + # Ray runs actor methods serially, so the resource request for an actor should be both the actor's resources and the task's resources + resource_request = actor_resource_request + task_resource_request + execution_config = get_context().daft_execution_config if self.ray_client_mode: try: diff --git a/daft/runners/runner.py b/daft/runners/runner.py index c1dd30f64e..730f5e1a4a 100644 --- a/daft/runners/runner.py +++ b/daft/runners/runner.py @@ -67,7 +67,8 @@ def run_iter_tables( def actor_pool_context( self, name: str, - resource_request: ResourceRequest, + actor_resource_request: ResourceRequest, + task_resource_request: ResourceRequest, num_actors: int, projection: ExpressionsProjection, ) -> Iterator[str]: diff --git a/tests/actor_pool/test_pyactor_pool.py b/tests/actor_pool/test_pyactor_pool.py index e95feec9ed..f34d91bd7b 100644 --- a/tests/actor_pool/test_pyactor_pool.py +++ b/tests/actor_pool/test_pyactor_pool.py @@ -69,5 +69,7 @@ def test_pyactor_pool_not_enough_resources(): assert isinstance(runner, PyRunner) with pytest.raises(RuntimeError, match=f"Requested {float(cpu_count + 1)} CPUs but found only"): - with runner.actor_pool_context("my-pool", ResourceRequest(num_cpus=1), cpu_count + 1, projection) as _: + with runner.actor_pool_context( + "my-pool", ResourceRequest(num_cpus=1), ResourceRequest(), cpu_count + 1, projection + ) as _: pass