Skip to content

Commit

Permalink
Provided one global DeviceConnector instance. Moved typedefs from uti…
Browse files Browse the repository at this point in the history
…ls to a new file.
  • Loading branch information
kmitrovicTT committed Dec 24, 2024
1 parent 49bba2a commit 7d40787
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 57 deletions.
2 changes: 1 addition & 1 deletion tests/infra/base_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
compare_pcc,
)
from .device_runner import DeviceRunner
from .utils import Tensor
from .types import Tensor


class BaseTester(ABC):
Expand Down
2 changes: 1 addition & 1 deletion tests/infra/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jax.numpy as jnp

from .device_runner import run_on_cpu
from .utils import Tensor
from .types import Tensor


@dataclass
Expand Down
14 changes: 5 additions & 9 deletions tests/infra/device_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,6 @@ def __init__(self) -> None:
self._plugin_path = plugin_path
self._initialize_backend()

@staticmethod
def get_instance() -> DeviceConnector:
"""
Factory method returning singleton connector instance.
Use this method instead of constructor.
"""
return DeviceConnector()

def is_initialized(self) -> bool:
"""Checks if connector is already initialized."""
if hasattr(self, "_initialized") and self._initialized == True:
Expand Down Expand Up @@ -119,3 +110,8 @@ def _initialize_backend(self) -> None:
jax.config.update("jax_platforms", self._supported_devices_str())

self._initialized = True


# `DeviceConnector._initialize_backend` must be executed before anything jax related is
# called. By providing this global instance, that is secured.
device_connector = DeviceConnector()
14 changes: 5 additions & 9 deletions tests/infra/device_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import jax

from .device_connector import DeviceConnector, DeviceType
from .utils import Tensor, Workload
from .device_connector import DeviceType, device_connector
from .types import Tensor, Workload


class DeviceRunner:
Expand Down Expand Up @@ -64,27 +64,23 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor:
"""Runs `workload` on device identified by `device_type`."""
device_workload = DeviceRunner._put_on_device(device_type, workload)

connector = DeviceConnector().get_instance()
device = connector.connect_device(device_type)
device = device_connector.connect_device(device_type)

with jax.default_device(device):
return device_workload.execute()

@staticmethod
def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload:
"""Puts `workload` on device and returns it."""
connector = DeviceConnector().get_instance()
device = connector.connect_device(device_type)
device = device_connector.connect_device(device_type)
return DeviceRunner._safely_put_workload_on_device(workload, device)

@staticmethod
def _put_tensors_on_device(
device_type: DeviceType, tensors: Sequence[Tensor]
) -> Sequence[Tensor]:
"""Puts `tensors` on device identified by `device_type`."""
connector = DeviceConnector().get_instance()
device = connector.connect_device(device_type)
device = device_connector.connect_device(device_type)
return [jax.device_put(t, device) for t in tensors]

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/infra/graph_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .comparison import ComparisonConfig
from .op_tester import OpTester
from .utils import Tensor, Workload
from .types import Tensor, Workload


class GraphTester(OpTester):
Expand Down
2 changes: 1 addition & 1 deletion tests/infra/model_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .base_tester import BaseTester
from .comparison import ComparisonConfig
from .device_runner import DeviceRunner
from .utils import Model, Workload
from .types import Model, Workload


class RunMode(Enum):
Expand Down
3 changes: 2 additions & 1 deletion tests/infra/op_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from .base_tester import BaseTester
from .comparison import ComparisonConfig
from .device_runner import DeviceRunner
from .utils import Tensor, Workload, random_tensor
from .types import Tensor, Workload
from .utils import random_tensor


class OpTester(BaseTester):
Expand Down
39 changes: 39 additions & 0 deletions tests/infra/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Mapping, Optional, Sequence, Union

import jax
from flax import linen, nnx

# Convenience alias. Could be used to represent jax.Array, torch.Tensor, np.ndarray, etc.
Tensor = Union[jax.Array]

# Convenience alias. Could be used to represent nnx.Module, torch.nn.Module, etc.
# NOTE nnx.Module is the newest API, linen.Module is legacy but it is used in some
# huggingface models.
Model = Union[nnx.Module, linen.Module]


@dataclass
class Workload:
executable: Callable
args: Sequence[Any]
kwargs: Optional[Mapping[str, Any]] = None

def __post_init__(self):
# If kwargs is None, initialize it to an empty dictionary.
if self.kwargs is None:
self.kwargs = {}

def execute(self) -> Any:
return self.executable(*self.args, **self.kwargs)


class Framework(Enum):
JAX = "jax"
TORCH = "torch"
NUMPY = "numpy"
37 changes: 3 additions & 34 deletions tests/infra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Mapping, Optional, Sequence, Union

import jax
import jax.numpy as jnp
from flax import linen, nnx
from jax import export


@dataclass
class Workload:
executable: Callable
args: Sequence[Any]
kwargs: Optional[Mapping[str, Any]] = None

def __post_init__(self):
# If kwargs is None, initialize it to an empty dictionary.
if self.kwargs is None:
self.kwargs = {}

def execute(self) -> Any:
return self.executable(*self.args, **self.kwargs)


class Framework(Enum):
JAX = "jax"
TORCH = "torch"
NUMPY = "numpy"


# Convenience alias. Could be used to represent jax.Array, torch.Tensor, np.ndarray, etc.
Tensor = Union[jax.Array]

# Convenience alias. Could be used to represent nnx.Module, torch.nn.Module, etc.
# NOTE nnx.Module is the newest API, linen.Module is legacy but it is used in some
# huggingface models.
Model = Union[nnx.Module, linen.Module]
from .device_runner import run_on_cpu
from .types import Framework, Tensor, Workload


def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
Expand All @@ -50,6 +18,7 @@ def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
raise ValueError(f"Unsupported framework: {framework.value}.")


@run_on_cpu
def random_tensor(
shape: tuple,
dtype: str = "float32",
Expand Down

0 comments on commit 7d40787

Please sign in to comment.