Skip to content

Commit

Permalink
Improved test infrastructure.
Browse files Browse the repository at this point in the history
- Provided layer which uses jax to connect to various devices
- Provided layer which uses jax to run payload on various devices
- Provided infra which makes it easy to unit test either simple ops, or graphs or even full models

Fixes #80, #93.
  • Loading branch information
kmitrovicTT committed Dec 18, 2024
1 parent 8af0c4f commit 4d23305
Show file tree
Hide file tree
Showing 25 changed files with 1,315 additions and 18 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ pre-commit
lit
pybind11
pytest
transformers
3 changes: 1 addition & 2 deletions src/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,7 @@ void DeviceDescription::BindApi(PJRT_Api *api) {
api->PJRT_DeviceDescription_ToString =
+[](PJRT_DeviceDescription_ToString_Args *args) -> PJRT_Error * {
DLOG_F(LOG_DEBUG, "DeviceDescription::PJRT_DeviceDescription_ToString");
auto sv =
DeviceDescription::Unwrap(args->device_description)->user_string();
auto sv = DeviceDescription::Unwrap(args->device_description)->to_string();
args->to_string = sv.data();
args->to_string_size = sv.size();
return nullptr;
Expand Down
12 changes: 7 additions & 5 deletions src/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,11 @@ class DeviceDescription {
}

std::string_view kind_string() { return kind_string_; }
std::string_view debug_string() { return debug_string_; }
std::string_view user_string() {
std::string_view debug_string() { return to_string(); }
std::string_view to_string() {
std::stringstream ss;
ss << "TTDevice(id=" << device_id() << ")";
ss << kind_string_ << "(id=" << device_id() << ", arch=" << arch_string_
<< ")";
user_string_ = ss.str();
return user_string_;
}
Expand All @@ -160,8 +161,9 @@ class DeviceDescription {

private:
int client_id_;
std::string kind_string_ = "wormhole";
std::string debug_string_ = "debug_string";
std::string kind_string_ = "TTDevice";
// TODO should not be hardcoded
std::string arch_string_ = "Wormhole";
std::string user_string_ = "";
};

Expand Down
11 changes: 9 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import os

import jax
import jax._src.xla_bridge as xb
import pytest
from infra.device_connector import DeviceConnector


def initialize():
Expand All @@ -21,4 +23,9 @@ def initialize():

@pytest.fixture(scope="session", autouse=True)
def setup_session():
initialize()
# Added to prevent `PJRT_Api already exists for device type tt` error.
# Will be removed completely soon.
connector = DeviceConnector.get_instance()

if not connector.is_initialized():
initialize()
9 changes: 9 additions & 0 deletions tests/infra/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

# Exposes only what is really needed to write tests, nothing else.
from .comparison import ComparisonConfig
from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
from .model_tester import ModelTester, RunMode
from .op_tester import run_op_test, run_op_test_with_random_inputs
72 changes: 72 additions & 0 deletions tests/infra/base_tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC
from typing import Callable, Sequence

import jax

from .comparison import (
ComparisonConfig,
compare_allclose,
compare_atol,
compare_equal,
compare_pcc,
)
from .device_runner import DeviceRunner
from .utils import Tensor


class BaseTester(ABC):
"""
Abstract base class all testers must inherit.
Provides just a couple of common methods.
"""

def __init__(
self, comparison_config: ComparisonConfig = ComparisonConfig()
) -> None:
self._comparison_config = comparison_config

@staticmethod
def _compile(executable: Callable) -> Callable:
"""Sets up `executable` for just-in-time compile."""
return jax.jit(executable)

def _compare(
self,
device_out: Tensor,
golden_out: Tensor,
) -> None:
device_output, golden_output = DeviceRunner.put_tensors_on_cpu(
device_out, golden_out
)
device_output, golden_output = self._match_data_types(
device_output, golden_output
)

if self._comparison_config.equal.enabled:
compare_equal(device_output, golden_output)
if self._comparison_config.atol.enabled:
compare_atol(device_output, golden_output, self._comparison_config.atol)
if self._comparison_config.pcc.enabled:
compare_pcc(device_output, golden_output, self._comparison_config.pcc)
if self._comparison_config.allclose.enabled:
compare_allclose(
device_output, golden_output, self._comparison_config.allclose
)

def _match_data_types(self, *tensors: Tensor) -> Sequence[Tensor]:
"""
Casts all tensors to float32 if not already in that format.
Tensors need to be in same data format in order to compare them.
"""
return [
tensor.astype("float32") if tensor.dtype.str != "float32" else tensor
for tensor in tensors
]
130 changes: 130 additions & 0 deletions tests/infra/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from .utils import Tensor


@dataclass
class ConfigBase:
enabled: bool = True

def enable(self) -> None:
self.enabled = True

def disable(self) -> None:
self.enabled = False


@dataclass
class EqualConfig(ConfigBase):
pass


@dataclass
class AtolConfig(ConfigBase):
required_atol: float = 1e-1


@dataclass
class PccConfig(ConfigBase):
required_pcc: float = 0.99


@dataclass
class AllcloseConfig(ConfigBase):
rtol: float = 1e-2
atol: float = 1e-2


@dataclass
class ComparisonConfig:
equal: EqualConfig = EqualConfig(False)
atol: AtolConfig = AtolConfig()
pcc: PccConfig = PccConfig()
allclose: AllcloseConfig = AllcloseConfig()

def enable_all(self) -> None:
self.equal.enable()
self.atol.enable()
self.allclose.enable()
self.pcc.enable()

def disable_all(self) -> None:
self.equal.disable()
self.atol.disable()
self.allclose.disable()
self.pcc.disable()


# TODO functions below rely on jax functions, should be generalized for all supported
# frameworks in the future.


def compare_equal(device_output: Tensor, golden_output: Tensor) -> None:
assert isinstance(device_output, jax.Array) and isinstance(
golden_output, jax.Array
), f"Currently only jax.Array is supported"

eq = (device_output == golden_output).all()

assert eq, f"Equal comparison failed"


def compare_atol(
device_output: Tensor, golden_output: Tensor, atol_config: AtolConfig
) -> None:
assert isinstance(device_output, jax.Array) and isinstance(
golden_output, jax.Array
), f"Currently only jax.Array is supported {type(device_output)}, {type(golden_output)}"

atol = jnp.max(jnp.abs(device_output - golden_output))

assert (
atol <= atol_config.required_atol
), f"Atol comparison failed. Calculated atol={atol}"


def compare_pcc(
device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
) -> None:
assert isinstance(device_output, jax.Array) and isinstance(
golden_output, jax.Array
), f"Currently only jax.Array is supported"

# If tensors are really close, pcc will be nan. Handle that before calculating pcc.
try:
compare_allclose(
device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2)
)
except AssertionError:
pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten())
pcc = jnp.min(pcc)

assert (
pcc >= pcc_config.required_pcc
), f"PCC comparison failed. Calculated pcc={pcc}"


def compare_allclose(
device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
) -> None:
assert isinstance(device_output, jax.Array) and isinstance(
golden_output, jax.Array
), f"Currently only jax.Array is supported"

allclose = jnp.allclose(
device_output,
golden_output,
rtol=allclose_config.rtol,
atol=allclose_config.atol,
)

assert allclose, f"Allclose comparison failed."
Loading

0 comments on commit 4d23305

Please sign in to comment.