diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index fae345eb2e..6401af3dc7 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -6,7 +6,7 @@
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
-from queue import Queue
+from queue import Full, Queue
 from typing import TYPE_CHECKING, Any, Generator, Iterable, Iterator
 
 import pyarrow as pa
@@ -466,12 +466,24 @@ def _run_plan(
             f"profile_RayRunner.run()_"
             f"{datetime.replace(datetime.now(), second=0, microsecond=0).isoformat()[:-3]}.json"
         )
+
+        def is_active():
+            return self.active_by_df.get(result_uuid, False)
+
+        def place_in_queue(item):
+            while is_active():
+                try:
+                    self.results_by_df[result_uuid].put(item, timeout=0.1)
+                    break
+                except Full:
+                    pass
+
         with profiler(profile_filename):
             try:
                 next_step = next(tasks)
 
-                while self.active_by_df.get(result_uuid, False):  # Loop: Dispatch -> await.
-                    while self.active_by_df.get(result_uuid, False):  # Loop: Dispatch (get tasks -> batch dispatch).
+                while is_active():  # Loop: Dispatch -> await.
+                    while is_active():  # Loop: Dispatch (get tasks -> batch dispatch).
                         tasks_to_dispatch: list[PartitionTask] = []
 
                         cores: int = next(num_cpus_provider) - self.reserved_cores
@@ -480,14 +492,14 @@ def _run_plan(
                         dispatches_allowed = min(cores, dispatches_allowed)
 
                         # Loop: Get a batch of tasks.
-                        while len(tasks_to_dispatch) < dispatches_allowed:
+                        while len(tasks_to_dispatch) < dispatches_allowed and is_active():
                             if next_step is None:
                                 # Blocked on already dispatched tasks; await some tasks.
                                 break
 
                             elif isinstance(next_step, MaterializedResult):
                                 # A final result.
-                                self.results_by_df[result_uuid].put(next_step)
+                                place_in_queue(next_step)
                                 next_step = next(tasks)
 
                             # next_step is a task.
@@ -512,6 +524,10 @@ def _run_plan(
                             (datetime.now() - start).total_seconds(),
                             len(tasks_to_dispatch),
                         )
+
+                        if not is_active():
+                            break
+
                         for task in tasks_to_dispatch:
                             results = _build_partitions(task)
                             logger.debug("%s -> %s", task, results)
@@ -530,6 +546,10 @@ def _run_plan(
                     dispatch = datetime.now()
                     completed_task_ids = []
                     for wait_for in ("next_one", "next_batch"):
+
+                        if not is_active():
+                            break
+
                         if wait_for == "next_one":
                             num_returns = 1
                             timeout = None
@@ -570,11 +590,11 @@ def _run_plan(
                         next_step = next(tasks)
 
             except StopIteration as e:
-                self.results_by_df[result_uuid].put(e)
+                place_in_queue(e)
 
             # Ensure that all Exceptions are correctly propagated to the consumer before reraising to kill thread
             except Exception as e:
-                self.results_by_df[result_uuid].put(e)
+                place_in_queue(e)
                 pbar.close()
                 raise
 
diff --git a/tests/ray/runner.py b/tests/ray/runner.py
index 16026db235..9b1e306762 100644
--- a/tests/ray/runner.py
+++ b/tests/ray/runner.py
@@ -46,3 +46,15 @@ def test_active_plan_multiple_iter_partitions():
     del iter2
 
     assert len(runner.active_plans()) == 0
+
+
+@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
+def test_active_plan_with_show_and_write_parquet(tmpdir):
+    df = daft.read_parquet("tests/assets/parquet-data/mvp.parquet")
+    df = df.into_partitions(8)
+    df = df.join(df, on="a")
+    df.show()
+    runner = get_context().runner()
+    assert len(runner.active_plans()) == 0
+    df.write_parquet(tmpdir.dirname)
+    assert len(runner.active_plans()) == 0
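For context, the core of this patch is a bounded-put pattern: instead of blocking forever on Queue.put while the consumer may have stopped draining results (e.g. after df.show() only takes a few partitions), the producer thread retries the put with a short timeout and re-checks the plan's "active" flag between attempts, so the dispatch loop can wind down cleanly. The following is a minimal standalone sketch of that pattern only; the names active, results, produce, and place_in_queue here are illustrative stand-ins, not Daft internals.

import threading
from queue import Full, Queue

# Small queue so the producer quickly runs ahead of the consumer.
results: Queue = Queue(maxsize=2)
active = threading.Event()
active.set()


def place_in_queue(item) -> None:
    # Retry the put with a short timeout so the producer can observe
    # deactivation instead of blocking forever on a full queue.
    while active.is_set():
        try:
            results.put(item, timeout=0.1)
            break
        except Full:
            pass


def produce(n: int) -> None:
    for i in range(n):
        if not active.is_set():
            break  # Consumer went away; stop producing.
        place_in_queue(i)


producer = threading.Thread(target=produce, args=(100,))
producer.start()

print(results.get())  # Consume a couple of items...
print(results.get())
active.clear()        # ...then abandon the plan.
producer.join()       # Returns promptly instead of hanging on the full queue.

With a plain results.put(item) and no active check, the producer thread in this sketch would block indefinitely once the consumer stops calling get(), which is the hang the is_active()/place_in_queue changes above are guarding against.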