diff --git a/tests/infra/base_tester.py b/tests/infra/base_tester.py index cd4f39e..97024c7 100644 --- a/tests/infra/base_tester.py +++ b/tests/infra/base_tester.py @@ -17,7 +17,7 @@ compare_pcc, ) from .device_runner import DeviceRunner -from .utils import Tensor +from .types import Tensor class BaseTester(ABC): diff --git a/tests/infra/comparison.py b/tests/infra/comparison.py index cb3c88f..ef85a00 100644 --- a/tests/infra/comparison.py +++ b/tests/infra/comparison.py @@ -10,7 +10,7 @@ import jax.numpy as jnp from .device_runner import run_on_cpu -from .utils import Tensor +from .types import Tensor @dataclass diff --git a/tests/infra/device_connector.py b/tests/infra/device_connector.py index 8dd581b..3575094 100644 --- a/tests/infra/device_connector.py +++ b/tests/infra/device_connector.py @@ -62,15 +62,6 @@ def __init__(self) -> None: self._plugin_path = plugin_path self._initialize_backend() - @staticmethod - def get_instance() -> DeviceConnector: - """ - Factory method returning singleton connector instance. - - Use this method instead of constructor. - """ - return DeviceConnector() - def is_initialized(self) -> bool: """Checks if connector is already initialized.""" if hasattr(self, "_initialized") and self._initialized == True: @@ -119,3 +110,8 @@ def _initialize_backend(self) -> None: jax.config.update("jax_platforms", self._supported_devices_str()) self._initialized = True + + +# `DeviceConnector._initialize_backend` must be executed before anything jax related is +# called. By providing this global instance, that is secured. +device_connector = DeviceConnector() diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py index 031bdd3..a246f11 100644 --- a/tests/infra/device_runner.py +++ b/tests/infra/device_runner.py @@ -6,8 +6,8 @@ import jax -from .device_connector import DeviceConnector, DeviceType -from .utils import Tensor, Workload +from .device_connector import DeviceType, device_connector +from .types import Tensor, Workload class DeviceRunner: @@ -64,9 +64,7 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]: def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor: """Runs `workload` on device identified by `device_type`.""" device_workload = DeviceRunner._put_on_device(device_type, workload) - - connector = DeviceConnector().get_instance() - device = connector.connect_device(device_type) + device = device_connector.connect_device(device_type) with jax.default_device(device): return device_workload.execute() @@ -74,8 +72,7 @@ def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor: @staticmethod def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload: """Puts `workload` on device and returns it.""" - connector = DeviceConnector().get_instance() - device = connector.connect_device(device_type) + device = device_connector.connect_device(device_type) return DeviceRunner._safely_put_workload_on_device(workload, device) @staticmethod @@ -83,8 +80,7 @@ def _put_tensors_on_device( device_type: DeviceType, tensors: Sequence[Tensor] ) -> Sequence[Tensor]: """Puts `tensors` on device identified by `device_type`.""" - connector = DeviceConnector().get_instance() - device = connector.connect_device(device_type) + device = device_connector.connect_device(device_type) return [jax.device_put(t, device) for t in tensors] @staticmethod diff --git a/tests/infra/graph_tester.py b/tests/infra/graph_tester.py index 4f63700..688de49 100644 --- a/tests/infra/graph_tester.py +++ b/tests/infra/graph_tester.py @@ -8,7 +8,7 @@ from .comparison import ComparisonConfig from .op_tester import OpTester -from .utils import Tensor, Workload +from .types import Tensor, Workload class GraphTester(OpTester): diff --git a/tests/infra/model_tester.py b/tests/infra/model_tester.py index 7faf601..07407e9 100644 --- a/tests/infra/model_tester.py +++ b/tests/infra/model_tester.py @@ -14,7 +14,7 @@ from .base_tester import BaseTester from .comparison import ComparisonConfig from .device_runner import DeviceRunner -from .utils import Model, Workload +from .types import Model, Workload class RunMode(Enum): diff --git a/tests/infra/op_tester.py b/tests/infra/op_tester.py index 7122d17..46f100c 100644 --- a/tests/infra/op_tester.py +++ b/tests/infra/op_tester.py @@ -9,7 +9,8 @@ from .base_tester import BaseTester from .comparison import ComparisonConfig from .device_runner import DeviceRunner -from .utils import Tensor, Workload, random_tensor +from .types import Tensor, Workload +from .utils import random_tensor class OpTester(BaseTester): diff --git a/tests/infra/types.py b/tests/infra/types.py new file mode 100644 index 0000000..b8f1e2e --- /dev/null +++ b/tests/infra/types.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Mapping, Optional, Sequence, Union + +import jax +from flax import linen, nnx + +# Convenience alias. Could be used to represent jax.Array, torch.Tensor, np.ndarray, etc. +Tensor = Union[jax.Array] + +# Convenience alias. Could be used to represent nnx.Module, torch.nn.Module, etc. +# NOTE nnx.Module is the newest API, linen.Module is legacy but it is used in some +# huggingface models. +Model = Union[nnx.Module, linen.Module] + + +@dataclass +class Workload: + executable: Callable + args: Sequence[Any] + kwargs: Optional[Mapping[str, Any]] = None + + def __post_init__(self): + # If kwargs is None, initialize it to an empty dictionary. + if self.kwargs is None: + self.kwargs = {} + + def execute(self) -> Any: + return self.executable(*self.args, **self.kwargs) + + +class Framework(Enum): + JAX = "jax" + TORCH = "torch" + NUMPY = "numpy" diff --git a/tests/infra/utils.py b/tests/infra/utils.py index 0cfb44e..2738794 100644 --- a/tests/infra/utils.py +++ b/tests/infra/utils.py @@ -2,44 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass -from enum import Enum -from typing import Any, Callable, Mapping, Optional, Sequence, Union - import jax import jax.numpy as jnp -from flax import linen, nnx from jax import export - -@dataclass -class Workload: - executable: Callable - args: Sequence[Any] - kwargs: Optional[Mapping[str, Any]] = None - - def __post_init__(self): - # If kwargs is None, initialize it to an empty dictionary. - if self.kwargs is None: - self.kwargs = {} - - def execute(self) -> Any: - return self.executable(*self.args, **self.kwargs) - - -class Framework(Enum): - JAX = "jax" - TORCH = "torch" - NUMPY = "numpy" - - -# Convenience alias. Could be used to represent jax.Array, torch.Tensor, np.ndarray, etc. -Tensor = Union[jax.Array] - -# Convenience alias. Could be used to represent nnx.Module, torch.nn.Module, etc. -# NOTE nnx.Module is the newest API, linen.Module is legacy but it is used in some -# huggingface models. -Model = Union[nnx.Module, linen.Module] +from .device_runner import run_on_cpu +from .types import Framework, Tensor, Workload def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX): @@ -50,6 +18,7 @@ def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX): raise ValueError(f"Unsupported framework: {framework.value}.") +@run_on_cpu def random_tensor( shape: tuple, dtype: str = "float32",