diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index ea490073ea..b0763b2c25 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -52,6 +52,9 @@ class PartitionTask(Generic[PartitionT]): # This is used when a specific executor (e.g. an Actor pool) must be provisioned and used for the task actor_pool_id: str | None + # Indicates if the PartitionTask is "done" or not + is_done: bool = False + _id: int = field(default_factory=lambda: next(ID_GEN)) def id(self) -> str: @@ -59,14 +62,23 @@ def id(self) -> str: def done(self) -> bool: """Whether the PartitionT result of this task is available.""" - raise NotImplementedError() + return self.is_done + + def set_done(self): + """Sets the PartitionTask as done.""" + assert not self.is_done, "Cannot set PartitionTask as done more than once" + self.is_done = True def cancel(self) -> None: """If possible, cancel the execution of this PartitionTask.""" raise NotImplementedError() def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: - """Set the result of this Task. For use by the Task executor.""" + """Set the result of this Task. For use by the Task executor. + + NOTE: A PartitionTask may contain a `result` without being `.done()`. This is because + results can potentially contain futures which are yet to be completed. + """ raise NotImplementedError def is_empty(self) -> bool: @@ -189,9 +201,6 @@ def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: [partition] = result self._result = partition - def done(self) -> bool: - return self._result is not None - def result(self) -> MaterializedResult[PartitionT]: assert self._result is not None, "Cannot call .result() on a PartitionTask that is not done" return self._result @@ -237,9 +246,6 @@ def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: assert self._results is None, f"Cannot set result twice. Result is already {self._results}" self._results = result - def done(self) -> bool: - return self._results is not None - def cancel(self) -> None: if self._results is not None: for result in self._results: diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 698ef55b1d..4934d481e7 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -510,6 +510,7 @@ def _physical_plan_to_partitions( self._resources.release(resources) next_step.set_result(materialized_results) + next_step.set_done() else: # Submit the task for execution. @@ -572,6 +573,7 @@ def _physical_plan_to_partitions( ) done_task.set_result(materialized_results) + done_task.set_done() if next_step is None: next_step = next(plan) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 55c065735d..364c0c08b5 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -750,6 +750,7 @@ def place_in_queue(item): next_step.set_result( [RayMaterializedResult(partition, accessor, 0) for partition in next_step.inputs] ) + next_step.set_done() next_step = next(tasks) else: @@ -816,6 +817,8 @@ def place_in_queue(item): completed_task_ids.append(task_id) # Mark the entire task associated with the result as done. task = inflight_tasks[task_id] + task.set_done() + if isinstance(task, SingleOutputPartitionTask): del inflight_ref_to_task[ready] elif isinstance(task, MultiOutputPartitionTask): diff --git a/tests/physical_plan/test_physical_plan_buffering.py b/tests/physical_plan/test_physical_plan_buffering.py index ca0e71817c..a5fc2d232c 100644 --- a/tests/physical_plan/test_physical_plan_buffering.py +++ b/tests/physical_plan/test_physical_plan_buffering.py @@ -29,8 +29,11 @@ def test_single_non_buffered_plan(): # Manually "complete" the tasks task1.set_result([result1]) + task1.set_done() task2.set_result([result2]) + task2.set_done() task3.set_result([result3]) + task3.set_done() # Results should be as we expect assert next(plan) == result1 @@ -57,7 +60,9 @@ def test_single_non_buffered_plan_done_while_planning(): # Manually "complete" the tasks task1.set_result([result1]) + task1.set_done() task2.set_result([result2]) + task2.set_done() # On the next iteration, we should receive the result assert next(plan) == result1 @@ -70,6 +75,7 @@ def test_single_non_buffered_plan_done_while_planning(): # Manually "complete" the last task task3.set_result([result3]) + task3.set_done() # Results should be as we expect assert next(plan) == result3 @@ -95,8 +101,10 @@ def test_single_plan_with_buffer_slow_tasks(): # Plan cannot make forward progress until task1 finishes assert next(plan) is None task2.set_result([result2]) + task2.set_done() assert next(plan) is None task1.set_result([result1]) + task1.set_done() # Plan should fill its buffer with new tasks before starting to yield results again task3 = next(plan) @@ -106,6 +114,7 @@ def test_single_plan_with_buffer_slow_tasks(): # Finish the last task task3.set_result([result3]) + task3.set_done() assert next(plan) == result3 with pytest.raises(StopIteration): @@ -126,6 +135,7 @@ def test_single_plan_with_buffer_saturation_fast_tasks(): # Finish up on task 1 (task is "fast" and completes so quickly even before the next plan loop call) task1.set_result([result1]) + task1.set_done() # Plan should fill its buffer completely with new tasks before starting to yield results again task2 = next(plan) @@ -139,7 +149,9 @@ def test_single_plan_with_buffer_saturation_fast_tasks(): # Finish the last task(s) task2.set_result([result2]) + task2.set_done() task3.set_result([result3]) + task3.set_done() assert next(plan) == result2 assert next(plan) == result3