Skip to content

Commit

Permalink
[BUG] Separate PartitionTask done from results (#3155)
Browse files Browse the repository at this point in the history
This PR marks `PartitionTasks` as done only after they have been
explicitly marked as done by the runner.

Previously, we used the existence of the `.results` on a PartitionTask
to determine whether or not it is done. However, this is not quite
correct in the case of the RayRunner, which will attach a result
containing a Ray ObjectRef, which is a future. This future may not (and
is likely not) be completed yet at the time of PartitionTask creation.

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Nov 1, 2024
1 parent 8817a08 commit c435e92
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
22 changes: 14 additions & 8 deletions daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,33 @@ class PartitionTask(Generic[PartitionT]):
# This is used when a specific executor (e.g. an Actor pool) must be provisioned and used for the task
actor_pool_id: str | None

# Indicates if the PartitionTask is "done" or not
is_done: bool = False

_id: int = field(default_factory=lambda: next(ID_GEN))

def id(self) -> str:
return f"{self.__class__.__name__}_{self._id}"

def done(self) -> bool:
"""Whether the PartitionT result of this task is available."""
raise NotImplementedError()
return self.is_done

def set_done(self):
"""Sets the PartitionTask as done."""
assert not self.is_done, "Cannot set PartitionTask as done more than once"
self.is_done = True

def cancel(self) -> None:
"""If possible, cancel the execution of this PartitionTask."""
raise NotImplementedError()

def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None:
"""Set the result of this Task. For use by the Task executor."""
"""Set the result of this Task. For use by the Task executor.
NOTE: A PartitionTask may contain a `result` without being `.done()`. This is because
results can potentially contain futures which are yet to be completed.
"""
raise NotImplementedError

def is_empty(self) -> bool:
Expand Down Expand Up @@ -189,9 +201,6 @@ def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None:
[partition] = result
self._result = partition

def done(self) -> bool:
return self._result is not None

def result(self) -> MaterializedResult[PartitionT]:
assert self._result is not None, "Cannot call .result() on a PartitionTask that is not done"
return self._result
Expand Down Expand Up @@ -237,9 +246,6 @@ def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None:
assert self._results is None, f"Cannot set result twice. Result is already {self._results}"
self._results = result

def done(self) -> bool:
return self._results is not None

def cancel(self) -> None:
if self._results is not None:
for result in self._results:
Expand Down
2 changes: 2 additions & 0 deletions daft/runners/pyrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def _physical_plan_to_partitions(
self._resources.release(resources)

next_step.set_result(materialized_results)
next_step.set_done()

else:
# Submit the task for execution.
Expand Down Expand Up @@ -572,6 +573,7 @@ def _physical_plan_to_partitions(
)

done_task.set_result(materialized_results)
done_task.set_done()

if next_step is None:
next_step = next(plan)
Expand Down
3 changes: 3 additions & 0 deletions daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ def place_in_queue(item):
next_step.set_result(
[RayMaterializedResult(partition, accessor, 0) for partition in next_step.inputs]
)
next_step.set_done()
next_step = next(tasks)

else:
Expand Down Expand Up @@ -816,6 +817,8 @@ def place_in_queue(item):
completed_task_ids.append(task_id)
# Mark the entire task associated with the result as done.
task = inflight_tasks[task_id]
task.set_done()

if isinstance(task, SingleOutputPartitionTask):
del inflight_ref_to_task[ready]
elif isinstance(task, MultiOutputPartitionTask):
Expand Down
12 changes: 12 additions & 0 deletions tests/physical_plan/test_physical_plan_buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ def test_single_non_buffered_plan():

# Manually "complete" the tasks
task1.set_result([result1])
task1.set_done()
task2.set_result([result2])
task2.set_done()
task3.set_result([result3])
task3.set_done()

# Results should be as we expect
assert next(plan) == result1
Expand All @@ -57,7 +60,9 @@ def test_single_non_buffered_plan_done_while_planning():

# Manually "complete" the tasks
task1.set_result([result1])
task1.set_done()
task2.set_result([result2])
task2.set_done()

# On the next iteration, we should receive the result
assert next(plan) == result1
Expand All @@ -70,6 +75,7 @@ def test_single_non_buffered_plan_done_while_planning():

# Manually "complete" the last task
task3.set_result([result3])
task3.set_done()

# Results should be as we expect
assert next(plan) == result3
Expand All @@ -95,8 +101,10 @@ def test_single_plan_with_buffer_slow_tasks():
# Plan cannot make forward progress until task1 finishes
assert next(plan) is None
task2.set_result([result2])
task2.set_done()
assert next(plan) is None
task1.set_result([result1])
task1.set_done()

# Plan should fill its buffer with new tasks before starting to yield results again
task3 = next(plan)
Expand All @@ -106,6 +114,7 @@ def test_single_plan_with_buffer_slow_tasks():

# Finish the last task
task3.set_result([result3])
task3.set_done()
assert next(plan) == result3

with pytest.raises(StopIteration):
Expand All @@ -126,6 +135,7 @@ def test_single_plan_with_buffer_saturation_fast_tasks():

# Finish up on task 1 (task is "fast" and completes so quickly even before the next plan loop call)
task1.set_result([result1])
task1.set_done()

# Plan should fill its buffer completely with new tasks before starting to yield results again
task2 = next(plan)
Expand All @@ -139,7 +149,9 @@ def test_single_plan_with_buffer_saturation_fast_tasks():

# Finish the last task(s)
task2.set_result([result2])
task2.set_done()
task3.set_result([result3])
task3.set_done()
assert next(plan) == result2
assert next(plan) == result3

Expand Down

0 comments on commit c435e92

Please sign in to comment.