Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Oct 5, 2024
1 parent cf871d9 commit 2e7f0b2
Showing 1 changed file with 13 additions and 77 deletions.
90 changes: 13 additions & 77 deletions tests/expressions/test_stateful_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __call__(self, data):

assert actor_context.rank == self._rank

import time

time.sleep(0.1)

return [self._rank] * len(data)

GetRank = GetRank.with_concurrency(concurrency)
Expand Down Expand Up @@ -69,6 +73,10 @@ def __call__(self, data):
actor_context = get_actor_context()
assert actor_context.resource_request == self.resource_request

import time

time.sleep(0.1)

return data

TestResourceRequest = TestResourceRequest.with_concurrency(concurrency)
Expand Down Expand Up @@ -97,6 +105,10 @@ def __init__(self):
def __call__(self, data):
assert os.environ["CUDA_VISIBLE_DEVICES"] == self.cuda_visible_devices

import time

time.sleep(0.1)

return [self.cuda_visible_devices] * len(data)

GetCudaVisibleDevices = GetCudaVisibleDevices.with_concurrency(concurrency)
Expand All @@ -110,81 +122,5 @@ def __call__(self, data):
unique_visible_devices = set(result["x"])
assert len(unique_visible_devices) == concurrency

all_devices = (",".join(cuda_visible_devices())).split(",")
all_devices = (",".join(unique_visible_devices)).split(",")
assert len(all_devices) == concurrency * num_gpus


@pytest.mark.skipif(len(cuda_visible_devices()) < 2, reason="Not enough GPUs available")
def test_stateful_udf_cuda_env_var_filtered():
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

@udf(return_dtype=DataType.string(), num_gpus=1)
class GetCudaVisibleDevices:
def __init__(self):
self.cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]

def __call__(self, data):
assert os.environ["CUDA_VISIBLE_DEVICES"] == self.cuda_visible_devices

return [self.cuda_visible_devices] * len(data)

GetCudaVisibleDevices = GetCudaVisibleDevices.with_concurrency(1)

df = daft.from_pydict({"x": [1]})
df = df.select(GetCudaVisibleDevices(df["x"]))

result = df.to_pydict()
assert result == {"x": ["1"]}


@pytest.mark.skipif(len(cuda_visible_devices()) < 3, reason="Not enough GPUs available")
def test_stateful_udf_cuda_env_var_filtered_multiple_gpus():
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

@udf(return_dtype=DataType.string(), num_gpus=1)
class GetCudaVisibleDevices:
def __init__(self):
self.cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]

def __call__(self, data):
assert os.environ["CUDA_VISIBLE_DEVICES"] == self.cuda_visible_devices

return [self.cuda_visible_devices] * len(data)

GetCudaVisibleDevices = GetCudaVisibleDevices.with_concurrency(1)

df = daft.from_pydict({"x": [1]})
df = df.select(GetCudaVisibleDevices(df["x"]))

result = df.to_pydict()
assert result == {"x": ["1,2"]} or result == {"x": ["2,1"]}


@pytest.mark.skipif(len(cuda_visible_devices()) < 3, reason="Not enough GPUs available")
def test_stateful_udf_cuda_env_var_filtered_multiple_concurrency():
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

@udf(return_dtype=DataType.string(), num_gpus=1)
class GetCudaVisibleDevices:
def __init__(self):
self.cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]

def __call__(self, data):
assert os.environ["CUDA_VISIBLE_DEVICES"] == self.cuda_visible_devices

return [self.cuda_visible_devices] * len(data)

GetCudaVisibleDevices = GetCudaVisibleDevices.with_concurrency(2)

df = daft.from_pydict({"x": [1, 2]})
df = df.into_partitions(2)
df = df.select(GetCudaVisibleDevices(df["x"])).sort("x")

result = df.to_pydict()
assert result == {"x": ["1", "2"]}

0 comments on commit 2e7f0b2

Please sign in to comment.