Skip to content

Commit

Permalink
add tests for ray runner clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Nov 13, 2023
1 parent afb4611 commit d70c3dd
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 3 deletions.
15 changes: 12 additions & 3 deletions daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def next(self, result_uuid: str) -> ray.ObjectRef | StopIteration:

return result

def run_plan(
def start_plan(
self,
plan_scheduler: PhysicalPlanScheduler,
psets: dict[str, ray.ObjectRef],
Expand All @@ -439,6 +439,9 @@ def run_plan(
t.start()
self.threads_by_df[result_uuid] = t

def active_plans(self) -> list[str]:
    """Return the result UUIDs of every plan currently flagged as active.

    Walks ``self.active_by_df`` (result UUID -> active flag) and keeps the
    keys whose flag is truthy, in the mapping's iteration order.
    """
    running: list[str] = []
    for result_uuid, flag in self.active_by_df.items():
        if flag:
            running.append(result_uuid)
    return running

Check warning on line 443 in daft/runners/ray_runner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/ray_runner.py#L443

Added line #L443 was not covered by tests

def stop_plan(self, result_uuid: str) -> None:
if result_uuid in self.active_by_df:
# Mark df as non-active
Expand Down Expand Up @@ -631,6 +634,12 @@ def __init__(
max_task_backlog=max_task_backlog,
)

def active_plans(self) -> list[str]:
    """Return the result UUIDs of the plans the scheduler reports as active.

    When ``self.ray_context`` is a Ray client context, the query is sent to
    ``self.scheduler_actor`` and the answer is fetched with ``ray.get``;
    otherwise ``self.scheduler`` is asked directly.
    """
    uses_ray_client = isinstance(self.ray_context, ray.client_builder.ClientContext)
    if not uses_ray_client:
        return self.scheduler.active_plans()
    # NOTE(review): Codecov flagged this client-context path as not covered
    # by tests in the commit that introduced it.
    pending = self.scheduler_actor.active_plans.remote()
    return ray.get(pending)

Check warning on line 641 in daft/runners/ray_runner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/ray_runner.py#L641

Added line #L641 was not covered by tests

def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[ray.ObjectRef]:
# Optimize the logical plan.
builder = builder.optimize()
Expand All @@ -646,7 +655,7 @@ def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None
result_uuid = str(uuid.uuid4())
if isinstance(self.ray_context, ray.client_builder.ClientContext):
ray.get(
self.scheduler_actor.run_plan.remote(
self.scheduler_actor.start_plan.remote(
plan_scheduler=plan_scheduler,
psets=psets,
result_uuid=result_uuid,
Expand All @@ -655,7 +664,7 @@ def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None
)

else:
self.scheduler.run_plan(
self.scheduler.start_plan(
plan_scheduler=plan_scheduler,
psets=psets,
result_uuid=result_uuid,
Expand Down
48 changes: 48 additions & 0 deletions tests/ray/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

import pytest

import daft
from daft.context import get_context


@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
def test_active_plan_clean_up_df_show():
    """After ``DataFrame.show()`` returns, the runner should report no active plans."""
    parquet_file = "tests/assets/parquet-data/mvp.parquet"
    # The same file is read twice, as in the other tests in this module.
    frame = daft.read_parquet([parquet_file] * 2)
    frame.show()
    still_active = get_context().runner().active_plans()
    assert len(still_active) == 0


@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
def test_active_plan_single_iter_partitions():
    """A partially consumed partition iterator keeps its plan active until it is dropped.

    The test pulls one partition, checks that exactly one plan is active, then
    deletes the iterator and checks that no plan remains active.
    """
    path = "tests/assets/parquet-data/mvp.parquet"
    df = daft.read_parquet([path, path])
    # Named ``partition_iter`` rather than ``iter`` so the builtin ``iter()``
    # is not shadowed inside this test.
    partition_iter = df.iter_partitions()
    next(partition_iter)
    runner = get_context().runner()
    assert len(runner.active_plans()) == 1
    # Deleting the only reference to the iterator is expected to deregister
    # its plan from the runner.
    del partition_iter
    assert len(runner.active_plans()) == 0


@pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
def test_active_plan_multiple_iter_partitions():
    """Two partially consumed iterators are tracked, and released, independently.

    Each started iterator should add one active plan, and deleting each
    iterator should remove exactly one.
    """
    path = "tests/assets/parquet-data/mvp.parquet"
    df = daft.read_parquet([path, path])
    # Named ``first_iter`` / ``second_iter`` rather than ``iter`` / ``iter2``
    # so the builtin ``iter()`` is not shadowed inside this test.
    first_iter = df.iter_partitions()
    next(first_iter)
    runner = get_context().runner()
    assert len(runner.active_plans()) == 1

    df2 = daft.read_parquet([path, path])
    second_iter = df2.iter_partitions()
    next(second_iter)
    assert len(runner.active_plans()) == 2

    # Dropping the first iterator should leave only the second plan active.
    del first_iter
    assert len(runner.active_plans()) == 1

    # Dropping the second iterator should leave no active plans.
    del second_iter
    assert len(runner.active_plans()) == 0

0 comments on commit d70c3dd

Please sign in to comment.