Skip to content

Commit

Permalink
Provided one global DeviceConnector instance. Moved typedefs from uti…
Browse files Browse the repository at this point in the history
…ls to a new file. (#131)

Also moved `Workload` to a new file.
  • Loading branch information
kmitrovicTT authored Dec 25, 2024
1 parent 3f849d6 commit afa672e
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 61 deletions.
6 changes: 2 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax
import jax._src.xla_bridge as xb
import pytest
from infra.device_connector import DeviceConnector
from infra.device_connector import device_connector


def initialize():
Expand All @@ -25,7 +25,5 @@ def initialize():
def setup_session():
# Added to prevent `PJRT_Api already exists for device type tt` error.
# Will be removed completely soon.
connector = DeviceConnector.get_instance()

if not connector.is_initialized():
if not device_connector.is_initialized():
initialize()
2 changes: 1 addition & 1 deletion tests/infra/base_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
compare_pcc,
)
from .device_runner import DeviceRunner
from .utils import Tensor
from .types import Tensor


class BaseTester(ABC):
Expand Down
2 changes: 1 addition & 1 deletion tests/infra/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jax.numpy as jnp

from .device_runner import run_on_cpu
from .utils import Tensor
from .types import Tensor


@dataclass
Expand Down
14 changes: 5 additions & 9 deletions tests/infra/device_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,6 @@ def __init__(self) -> None:
self._plugin_path = plugin_path
self._initialize_backend()

@staticmethod
def get_instance() -> DeviceConnector:
"""
Factory method returning singleton connector instance.
Use this method instead of constructor.
"""
return DeviceConnector()

def is_initialized(self) -> bool:
"""Checks if connector is already initialized."""
if hasattr(self, "_initialized") and self._initialized == True:
Expand Down Expand Up @@ -119,3 +110,8 @@ def _initialize_backend(self) -> None:
jax.config.update("jax_platforms", self._supported_devices_str())

self._initialized = True


# `DeviceConnector._initialize_backend` must be executed before anything jax related is
# called. By providing this global instance, that is secured.
device_connector = DeviceConnector()
15 changes: 6 additions & 9 deletions tests/infra/device_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import jax

from .device_connector import DeviceConnector, DeviceType
from .utils import Tensor, Workload
from .device_connector import DeviceType, device_connector
from .types import Tensor
from .workload import Workload


class DeviceRunner:
Expand Down Expand Up @@ -64,27 +65,23 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor:
"""Runs `workload` on device identified by `device_type`."""
device_workload = DeviceRunner._put_on_device(device_type, workload)

connector = DeviceConnector().get_instance()
device = connector.connect_device(device_type)
device = device_connector.connect_device(device_type)

with jax.default_device(device):
return device_workload.execute()

@staticmethod
def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload:
"""Puts `workload` on device and returns it."""
connector = DeviceConnector().get_instance()
device = connector.connect_device(device_type)
device = device_connector.connect_device(device_type)
return DeviceRunner._safely_put_workload_on_device(workload, device)

@staticmethod
def _put_tensors_on_device(
device_type: DeviceType, tensors: Sequence[Tensor]
) -> Sequence[Tensor]:
"""Puts `tensors` on device identified by `device_type`."""
connector = DeviceConnector().get_instance()
device = connector.connect_device(device_type)
device = device_connector.connect_device(device_type)
return [jax.device_put(t, device) for t in tensors]

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion tests/infra/graph_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from .comparison import ComparisonConfig
from .op_tester import OpTester
from .utils import Tensor, Workload
from .types import Tensor
from .workload import Workload


class GraphTester(OpTester):
Expand Down
3 changes: 2 additions & 1 deletion tests/infra/model_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from .base_tester import BaseTester
from .comparison import ComparisonConfig
from .device_runner import DeviceRunner
from .utils import Model, Workload
from .types import Model
from .workload import Workload


class RunMode(Enum):
Expand Down
4 changes: 3 additions & 1 deletion tests/infra/op_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from .base_tester import BaseTester
from .comparison import ComparisonConfig
from .device_runner import DeviceRunner
from .utils import Tensor, Workload, random_tensor
from .types import Tensor
from .utils import random_tensor
from .workload import Workload


class OpTester(BaseTester):
Expand Down
23 changes: 23 additions & 0 deletions tests/infra/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from typing import Union

import jax
from flax import linen, nnx

# Convenience alias. Could be used to represent jax.Array, torch.Tensor, np.ndarray, etc.
Tensor = Union[jax.Array]

# Convenience alias. Could be used to represent nnx.Module, torch.nn.Module, etc.
# NOTE nnx.Module is the newest API, linen.Module is legacy but it is used in some
# huggingface models.
Model = Union[nnx.Module, linen.Module]


class Framework(Enum):
JAX = "jax"
TORCH = "torch"
NUMPY = "numpy"
38 changes: 4 additions & 34 deletions tests/infra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Mapping, Optional, Sequence, Union

import jax
import jax.numpy as jnp
from flax import linen, nnx
from jax import export


@dataclass
class Workload:
executable: Callable
args: Sequence[Any]
kwargs: Optional[Mapping[str, Any]] = None

def __post_init__(self):
# If kwargs is None, initialize it to an empty dictionary.
if self.kwargs is None:
self.kwargs = {}

def execute(self) -> Any:
return self.executable(*self.args, **self.kwargs)


class Framework(Enum):
JAX = "jax"
TORCH = "torch"
NUMPY = "numpy"


# Convenience alias. Could be used to represent jax.Array, torch.Tensor, np.ndarray, etc.
Tensor = Union[jax.Array]

# Convenience alias. Could be used to represent nnx.Module, torch.nn.Module, etc.
# NOTE nnx.Module is the newest API, linen.Module is legacy but it is used in some
# huggingface models.
Model = Union[nnx.Module, linen.Module]
from .device_runner import run_on_cpu
from .types import Framework, Tensor
from .workload import Workload


def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
Expand All @@ -50,6 +19,7 @@ def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
raise ValueError(f"Unsupported framework: {framework.value}.")


@run_on_cpu
def random_tensor(
shape: tuple,
dtype: str = "float32",
Expand Down
26 changes: 26 additions & 0 deletions tests/infra/workload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Any, Callable, Mapping, Optional, Sequence


@dataclass
class Workload:
"""
Convenience dataclass storing a callable and its positional and keyword arguments.
"""

executable: Callable
args: Sequence[Any]
kwargs: Optional[Mapping[str, Any]] = None

def __post_init__(self):
# If kwargs is None, initialize it to an empty dictionary.
if self.kwargs is None:
self.kwargs = {}

def execute(self) -> Any:
"""Calls callable passing stored args and kwargs directly."""
return self.executable(*self.args, **self.kwargs)

0 comments on commit afa672e

Please sign in to comment.