Skip to content

Commit

Permalink
make test cleaner
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Oct 9, 2024
1 parent 9a89ba8 commit 2ae44c1
Showing 1 changed file with 14 additions and 19 deletions.
33 changes: 14 additions & 19 deletions tests/actor_pool/test_actor_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,27 @@ def enable_actor_pool():
@contextmanager
def reset_runner_with_gpus(num_gpus, monkeypatch):
"""If current runner does not have enough GPUs, create a new runner with mocked GPU resources"""

using_ray = get_context().runner_config.name == "ray"
insufficient_gpus = len(cuda_visible_devices()) < num_gpus

original_runner = daft.context.get_context()._runner

try:
if insufficient_gpus:
if using_ray:
if ray.is_initialized():
ray.shutdown()
if len(cuda_visible_devices()) < num_gpus:
if get_context().runner_config.name == "ray":
try:
ray.shutdown()
ray.init(num_gpus=num_gpus)
else:
yield
finally:
ray.shutdown()
ray.init()
else:
try:
monkeypatch.setenv("CUDA_VISIBLE_DEVICES", ",".join(str(i) for i in range(num_gpus)))

# Need to reset runner to recompute resources
original_runner = daft.context.get_context()._runner
daft.context.get_context()._runner = None
yield
finally:
if insufficient_gpus:
if using_ray:
ray.shutdown()
ray.init()
else:
yield
finally:
daft.context.get_context()._runner = original_runner
else:
yield


@pytest.mark.parametrize("concurrency", [1, 2, 3])
Expand Down

0 comments on commit 2ae44c1

Please sign in to comment.