[CHORE] Refactor RayRunner so that we can add tracing (#3163)
This PR refactors the RayRunner so that it is easier to add tracing.

I also added more docstrings to make it clearer what is happening.

## Code Changes

Below, I highlight the changes that were made in the code for easier review.

1. I removed the `next_step` state and instead expose a new `has_next`
return value from `self._construct_dispatch_batch`. This cleans up
the code because iteration over the physical plan now ONLY happens inside
of `self._construct_dispatch_batch` instead of being scattered across
the scheduling loop.
2. I cleaned up the logic in `self._await_tasks`: it first waits on a
single task (with `timeout=None`) so that it blocks until any task
completes, and then performs a second wait on all tasks (with
`timeout=0.01`) to retrieve every task that is already ready (see the
sketch after this list).
3. I pulled `self._is_active` and `self._place_in_queue` out into
methods instead of leaving them as locally-defined functions.
---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
jaychia and Jay Chia authored Nov 5, 2024
1 parent 944c6da commit 64e35f8
339 changes: 204 additions & 135 deletions daft/runners/ray_runner.py
@@ -73,6 +73,7 @@
from ray.data.block import Block as RayDatasetBlock
from ray.data.dataset import Dataset as RayDataset

from daft.execution.physical_plan import MaterializedPhysicalPlan
from daft.logical.builder import LogicalPlanBuilder
from daft.plan_scheduler import PhysicalPlanScheduler

@@ -655,6 +656,141 @@ def teardown_actor_pool(self, name: str) -> None:
self._actor_pools[name].teardown()
del self._actor_pools[name]

def _construct_dispatch_batch(
self,
execution_id: str,
tasks: MaterializedPhysicalPlan,
dispatches_allowed: int,
) -> tuple[list[PartitionTask], bool]:
"""Constructs a batch of PartitionTasks that should be dispatched
Args:
execution_id: The ID of the current execution.
tasks: The iterator over the physical plan.
dispatches_allowed (int): The maximum number of tasks that can be dispatched in this batch.
Returns:
tuple[list[PartitionTask], bool]: A tuple containing:
- A list of PartitionTasks to be dispatched.
- A boolean indicating whether more tasks can be obtained by calling _construct_dispatch_batch again.
"""
tasks_to_dispatch: list[PartitionTask] = []

# Loop until:
# - Reached the limit of the number of tasks we are allowed to dispatch
# - Encounter a `None` as the next step (short-circuit and return has_next=False)
while len(tasks_to_dispatch) < dispatches_allowed and self._is_active(execution_id):
next_step = next(tasks)

# CASE: Blocked on already dispatched tasks
# Early terminate and mark has_next=False
if next_step is None:
return tasks_to_dispatch, False

# CASE: A final result
# Place it in the result queue (potentially block on space to be available)
elif isinstance(next_step, MaterializedResult):
self._place_in_queue(execution_id, next_step)

# CASE: No-op task
# Just run it locally immediately.
elif len(next_step.instructions) == 0:
logger.debug("Running task synchronously in main thread: %s", next_step)
assert (
len(next_step.partial_metadatas) == 1
), "No-op tasks must have one output by definition, since there are no instructions to run"
[single_partial] = next_step.partial_metadatas
if single_partial.num_rows is None:
[single_meta] = ray.get(get_metas.remote(next_step.inputs))
accessor = PartitionMetadataAccessor.from_metadata_list(
[single_meta.merge_with_partial(single_partial)]
)
else:
accessor = PartitionMetadataAccessor.from_metadata_list(
[
PartitionMetadata(
num_rows=single_partial.num_rows,
size_bytes=single_partial.size_bytes,
boundaries=single_partial.boundaries,
)
]
)

next_step.set_result([RayMaterializedResult(partition, accessor, 0) for partition in next_step.inputs])
next_step.set_done()

# CASE: Actual task that needs to be dispatched
else:
tasks_to_dispatch.append(next_step)

return tasks_to_dispatch, True

def _dispatch_tasks(
self,
tasks_to_dispatch: list[PartitionTask],
daft_execution_config: PyDaftExecutionConfig,
) -> Iterator[tuple[PartitionTask, list[ray.ObjectRef]]]:
"""Iteratively Dispatches a batch of tasks to the Ray backend"""

for task in tasks_to_dispatch:
if task.actor_pool_id is None:
results = _build_partitions(daft_execution_config, task)
else:
actor_pool = self._actor_pools.get(task.actor_pool_id)
assert actor_pool is not None, "Ray actor pool must live for as long as the tasks."
results = _build_partitions_on_actor_pool(task, actor_pool)
logger.debug("%s -> %s", task, results)

yield task, results

def _await_tasks(
self,
inflight_ref_to_task_id: dict[ray.ObjectRef, str],
) -> list[ray.ObjectRef]:
"""Awaits for tasks to be completed. Returns tasks that are ready.
NOTE: This method blocks until at least 1 task is ready. Then it will return as many ready tasks as it can.
"""
if len(inflight_ref_to_task_id) == 0:
return []

# Await on (any) task to be ready with an unlimited timeout
ray.wait(
list(inflight_ref_to_task_id.keys()),
num_returns=1,
timeout=None,
fetch_local=False,
)

# Now, grab as many ready tasks as possible with a 0.01s timeout
timeout = 0.01
num_returns = len(inflight_ref_to_task_id)
readies, _ = ray.wait(
list(inflight_ref_to_task_id.keys()),
num_returns=num_returns,
timeout=timeout,
fetch_local=False,
)

return readies

def _is_active(self, execution_id: str):
"""Checks if the execution for the provided `execution_id` is still active"""
return self.active_by_df.get(execution_id, False)

def _place_in_queue(self, execution_id: str, item: ray.ObjectRef):
"""Places a result into the queue for the provided `execution_id
NOTE: This will block and poll busily until space is available on the queue
`"""
while self._is_active(execution_id):
try:
self.results_by_df[execution_id].put(item, timeout=0.1)
break
except Full:
pass

def _run_plan(
self,
plan_scheduler: PhysicalPlanScheduler,
@@ -663,14 +799,6 @@ def _run_plan(
) -> None:
# Get executable tasks from plan scheduler.
results_buffer_size = self.results_buffer_size_by_df[result_uuid]
tasks = plan_scheduler.to_partition_tasks(
psets,
self,
# Attempt to subtract 1 from results_buffer_size because the return Queue size is already 1
# If results_buffer_size=1 though, we can't do much and the total buffer size actually has to be >= 2
# because we have two buffers (the Queue and the buffer inside the `materialize` generator)
None if results_buffer_size is None else max(results_buffer_size - 1, 1),
)

daft_execution_config = self.execution_configs_objref_by_df[result_uuid]
inflight_tasks: dict[str, PartitionTask[ray.ObjectRef]] = dict()
@@ -684,25 +812,35 @@ def _run_plan(
f"{datetime.replace(datetime.now(), second=0, microsecond=0).isoformat()[:-3]}.json"
)

def is_active():
return self.active_by_df.get(result_uuid, False)

def place_in_queue(item):
while is_active():
try:
self.results_by_df[result_uuid].put(item, timeout=0.1)
break
except Full:
pass

with profiler(profile_filename):
tasks = plan_scheduler.to_partition_tasks(
psets,
self,
# Attempt to subtract 1 from results_buffer_size because the return Queue size is already 1
# If results_buffer_size=1 though, we can't do much and the total buffer size actually has to be >= 2
# because we have two buffers (the Queue and the buffer inside the `materialize` generator)
None if results_buffer_size is None else max(results_buffer_size - 1, 1),
)
try:
next_step = next(tasks)

while is_active(): # Loop: Dispatch -> await.
while is_active(): # Loop: Dispatch (get tasks -> batch dispatch).
tasks_to_dispatch: list[PartitionTask] = []

###
# Scheduling Loop:
#
#    DispatchBatching ─► Dispatch
#    ▲                      │      ───────►  Await
#    └──────────────────────┘                  │
#    ▲                                         │
#    └─────────────────────────────────────────┘
###
while self._is_active(result_uuid):
###
# Dispatch Loop:
#
#    DispatchBatching ─► Dispatch
#    ▲                      │
#    └──────────────────────┘
###
while self._is_active(result_uuid):
# Update available cluster resources
# TODO: improve control loop code to be more understandable and dynamically adjust backlog
cores: int = max(
next(num_cpus_provider) - self.reserved_cores, 1
@@ -711,136 +849,67 @@ def place_in_queue(item):
dispatches_allowed = max_inflight_tasks - len(inflight_tasks)
dispatches_allowed = min(cores, dispatches_allowed)

# Loop: Get a batch of tasks.
while len(tasks_to_dispatch) < dispatches_allowed and is_active():
if next_step is None:
# Blocked on already dispatched tasks; await some tasks.
break

elif isinstance(next_step, MaterializedResult):
# A final result.
place_in_queue(next_step)
next_step = next(tasks)

# next_step is a task.

# If it is a no-op task, just run it locally immediately.
elif len(next_step.instructions) == 0:
logger.debug("Running task synchronously in main thread: %s", next_step)
assert (
len(next_step.partial_metadatas) == 1
), "No-op tasks must have one output by definition, since there are no instructions to run"
[single_partial] = next_step.partial_metadatas
if single_partial.num_rows is None:
[single_meta] = ray.get(get_metas.remote(next_step.inputs))
accessor = PartitionMetadataAccessor.from_metadata_list(
[single_meta.merge_with_partial(single_partial)]
)
else:
accessor = PartitionMetadataAccessor.from_metadata_list(
[
PartitionMetadata(
num_rows=single_partial.num_rows,
size_bytes=single_partial.size_bytes,
boundaries=single_partial.boundaries,
)
]
)

next_step.set_result(
[RayMaterializedResult(partition, accessor, 0) for partition in next_step.inputs]
)
next_step.set_done()
next_step = next(tasks)

else:
# Add the task to the batch.
tasks_to_dispatch.append(next_step)
next_step = next(tasks)

# Dispatch the batch of tasks.
# Dispatch Batching
tasks_to_dispatch, has_next = self._construct_dispatch_batch(
result_uuid,
tasks,
dispatches_allowed,
)

logger.debug(
"%ss: RayRunner dispatching %s tasks",
(datetime.now() - start).total_seconds(),
len(tasks_to_dispatch),
)

if not is_active():
if not self._is_active(result_uuid):
break

for task in tasks_to_dispatch:
if task.actor_pool_id is None:
results = _build_partitions(daft_execution_config, task)
else:
actor_pool = self._actor_pools.get(task.actor_pool_id)
assert actor_pool is not None, "Ray actor pool must live for as long as the tasks."
results = _build_partitions_on_actor_pool(task, actor_pool)
logger.debug("%s -> %s", task, results)
# Dispatch
for task, result_obj_refs in self._dispatch_tasks(
tasks_to_dispatch,
daft_execution_config,
):
inflight_tasks[task.id()] = task
for result in results:
for result in result_obj_refs:
inflight_ref_to_task[result] = task.id()

pbar.mark_task_start(task)

if dispatches_allowed == 0 or next_step is None:
break

# Await a batch of tasks.
# (Awaits the next task, and then the next batch of tasks within 10ms.)

dispatch = datetime.now()
completed_task_ids = []
for wait_for in ("next_one", "next_batch"):
if not is_active():
break

if wait_for == "next_one":
num_returns = 1
timeout = None
elif wait_for == "next_batch":
num_returns = len(inflight_ref_to_task)
timeout = 0.01 # 10ms

if num_returns == 0:
# Break the dispatch batching/dispatch loop if no more dispatches allowed, or physical plan
# needs work for forward progress
if dispatches_allowed == 0 or not has_next:
break

readies, _ = ray.wait(
list(inflight_ref_to_task.keys()),
num_returns=num_returns,
timeout=timeout,
fetch_local=False,
)

for ready in readies:
if ready in inflight_ref_to_task:
task_id = inflight_ref_to_task[ready]
completed_task_ids.append(task_id)
# Mark the entire task associated with the result as done.
task = inflight_tasks[task_id]
task.set_done()

if isinstance(task, SingleOutputPartitionTask):
del inflight_ref_to_task[ready]
elif isinstance(task, MultiOutputPartitionTask):
for partition in task.partitions():
del inflight_ref_to_task[partition]

pbar.mark_task_done(task)
del inflight_tasks[task_id]

logger.debug(
"%ss to await results from %s", (datetime.now() - dispatch).total_seconds(), completed_task_ids
)

if next_step is None:
next_step = next(tasks)
###
# Await:
# Wait for some work to be completed from the current wave's dispatch
# Then we perform the necessary record-keeping on tasks that were retrieved as ready.
###
readies = self._await_tasks(inflight_ref_to_task)
for ready in readies:
if ready in inflight_ref_to_task:
task_id = inflight_ref_to_task[ready]

# Mark the entire task associated with the result as done.
task = inflight_tasks[task_id]
task.set_done()

if isinstance(task, SingleOutputPartitionTask):
del inflight_ref_to_task[ready]
elif isinstance(task, MultiOutputPartitionTask):
for partition in task.partitions():
del inflight_ref_to_task[partition]

pbar.mark_task_done(task)
del inflight_tasks[task_id]

except StopIteration as e:
place_in_queue(e)
self._place_in_queue(result_uuid, e)

# Ensure that all Exceptions are correctly propagated to the consumer before reraising to kill thread
except Exception as e:
place_in_queue(e)
self._place_in_queue(result_uuid, e)
pbar.close()
raise

