From 2ae44c1138f823ee2ccc4aafd9a7c36914bc9bf9 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Wed, 9 Oct 2024 14:24:33 -0700 Subject: [PATCH] make test cleaner --- tests/actor_pool/test_actor_context.py | 33 +++++++++++--------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/tests/actor_pool/test_actor_context.py b/tests/actor_pool/test_actor_context.py index eff59d3b7d..ff6495fed7 100644 --- a/tests/actor_pool/test_actor_context.py +++ b/tests/actor_pool/test_actor_context.py @@ -35,32 +35,27 @@ def enable_actor_pool(): @contextmanager def reset_runner_with_gpus(num_gpus, monkeypatch): """If current runner does not have enough GPUs, create a new runner with mocked GPU resources""" - - using_ray = get_context().runner_config.name == "ray" - insufficient_gpus = len(cuda_visible_devices()) < num_gpus - - original_runner = daft.context.get_context()._runner - - try: - if insufficient_gpus: - if using_ray: - if ray.is_initialized(): - ray.shutdown() + if len(cuda_visible_devices()) < num_gpus: + if get_context().runner_config.name == "ray": + try: + ray.shutdown() ray.init(num_gpus=num_gpus) - else: + yield + finally: + ray.shutdown() + ray.init() + else: + try: monkeypatch.setenv("CUDA_VISIBLE_DEVICES", ",".join(str(i) for i in range(num_gpus))) # Need to reset runner to recompute resources original_runner = daft.context.get_context()._runner daft.context.get_context()._runner = None - yield - finally: - if insufficient_gpus: - if using_ray: - ray.shutdown() - ray.init() - else: + yield + finally: daft.context.get_context()._runner = original_runner + else: + yield @pytest.mark.parametrize("concurrency", [1, 2, 3])