-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Provided layer which uses jax to connect to various devices - Provided layer which uses jax to run payload on various devices - Provided infra which makes it easy to unit test either simple ops, or graphs or even full models Fixes #80, #93.
- Loading branch information
1 parent
4eb910d
commit 16d97ce
Showing
26 changed files
with
1,347 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,4 @@ pre-commit | |
lit | ||
pybind11 | ||
pytest | ||
transformers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Exposes only what is really needed to write tests, nothing else. | ||
from .comparison import ComparisonConfig | ||
from .graph_tester import run_graph_test, run_graph_test_with_random_inputs | ||
from .model_tester import ModelTester, RunMode | ||
from .op_tester import run_op_test, run_op_test_with_random_inputs | ||
from .utils import random_tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations | ||
|
||
from abc import ABC | ||
from typing import Callable, Sequence | ||
|
||
import jax | ||
|
||
from .comparison import ( | ||
ComparisonConfig, | ||
compare_allclose, | ||
compare_atol, | ||
compare_equal, | ||
compare_pcc, | ||
) | ||
from .device_runner import DeviceRunner | ||
from .utils import Tensor | ||
|
||
|
||
class BaseTester(ABC):
    """
    Abstract base class all testers must inherit.

    Provides just a couple of common methods: setting an executable up for
    just-in-time compilation and comparing a device output against a golden
    output according to a `ComparisonConfig`.
    """

    def __init__(self, comparison_config: ComparisonConfig | None = None) -> None:
        """
        Stores the comparison settings later used by `_compare`.

        If `comparison_config` is omitted (or None), a new `ComparisonConfig` is
        created for this tester.
        """
        # `ComparisonConfig` is mutable (its checks can be enabled/disabled), so the
        # default must be built per instance. A `ComparisonConfig()` default in the
        # signature is evaluated only once and would be shared by every tester.
        self._comparison_config = (
            ComparisonConfig() if comparison_config is None else comparison_config
        )

    @staticmethod
    def _compile(executable: Callable) -> Callable:
        """Sets up `executable` for just-in-time compile."""
        return jax.jit(executable)

    def _compare(
        self,
        device_out: Tensor,
        golden_out: Tensor,
    ) -> None:
        """
        Compares `device_out` against `golden_out` using every enabled check.

        Both tensors are first passed through `DeviceRunner.put_tensors_on_cpu` and
        then brought to a common data type. Afterwards each check enabled in the
        stored comparison config (equal, atol, pcc, allclose) is run in that order.
        """
        device_output, golden_output = DeviceRunner.put_tensors_on_cpu(
            device_out, golden_out
        )
        device_output, golden_output = self._match_data_types(
            device_output, golden_output
        )

        if self._comparison_config.equal.enabled:
            compare_equal(device_output, golden_output)
        if self._comparison_config.atol.enabled:
            compare_atol(device_output, golden_output, self._comparison_config.atol)
        if self._comparison_config.pcc.enabled:
            compare_pcc(device_output, golden_output, self._comparison_config.pcc)
        if self._comparison_config.allclose.enabled:
            compare_allclose(
                device_output, golden_output, self._comparison_config.allclose
            )

    def _match_data_types(self, *tensors: Tensor) -> Sequence[Tensor]:
        """
        Casts all tensors to float32 if not already in that format.
        Tensors need to be in same data format in order to compare them.

        Returns a list with one entry per input tensor, in the same order.
        """
        # Compare the dtype itself with the type name. `dtype.str` holds the
        # array-protocol code (e.g. "<f4"), which never equals "float32", so checking
        # it would make every tensor look like it needs a cast.
        return [
            tensor.astype("float32") if tensor.dtype != "float32" else tensor
            for tensor in tensors
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations

from dataclasses import dataclass, field

import jax
import jax.numpy as jnp

from .device_runner import run_on_cpu
from .utils import Tensor
|
||
|
||
@dataclass
class ConfigBase:
    """Common base of the per-check configs: holds the on/off switch of a check."""

    # Whether the check this config belongs to should be run.
    enabled: bool = True

    def enable(self) -> None:
        """Switches the check on."""
        self.enabled = True

    def disable(self) -> None:
        """Switches the check off."""
        self.enabled = False
|
||
|
||
@dataclass
class EqualConfig(ConfigBase):
    """Config of the exact-equality check; it only has the `enabled` switch."""
|
||
|
||
@dataclass
class AtolConfig(ConfigBase):
    """Config of the absolute-tolerance check."""

    # Upper bound `compare_atol` accepts for the largest absolute difference.
    required_atol: float = 0.1
|
||
|
||
@dataclass
class PccConfig(ConfigBase):
    """Config of the PCC (correlation coefficient) check."""

    # Lower bound `compare_pcc` accepts for the calculated pcc.
    required_pcc: float = 0.99
|
||
|
||
@dataclass
class AllcloseConfig(ConfigBase):
    """Config of the allclose check; both values are forwarded to `jnp.allclose`."""

    # Relative tolerance.
    rtol: float = 0.01
    # Absolute tolerance.
    atol: float = 0.01
|
||
|
||
@dataclass
class ComparisonConfig:
    """
    Groups the configs of all available comparison checks.

    By default the equal check is disabled, while the atol, pcc and allclose checks
    are enabled with their default thresholds.
    """

    # Every field uses `default_factory` so that each `ComparisonConfig` owns its own
    # sub-config objects. Plain dataclass instances as defaults are shared between all
    # `ComparisonConfig` instances (enabling/disabling a check on one would change all
    # of them) and are rejected as unhashable defaults by newer Python versions.
    equal: EqualConfig = field(default_factory=lambda: EqualConfig(False))
    atol: AtolConfig = field(default_factory=AtolConfig)
    pcc: PccConfig = field(default_factory=PccConfig)
    allclose: AllcloseConfig = field(default_factory=AllcloseConfig)

    def enable_all(self) -> None:
        """Enables every comparison check."""
        self.equal.enable()
        self.atol.enable()
        self.allclose.enable()
        self.pcc.enable()

    def disable_all(self) -> None:
        """Disables every comparison check."""
        self.equal.disable()
        self.atol.disable()
        self.allclose.disable()
        self.pcc.disable()
|
||
|
||
# TODO: The functions below rely on jax functions; they should be generalized for all | ||
# supported frameworks in the future. | ||
|
||
|
||
@run_on_cpu
def compare_equal(device_output: Tensor, golden_output: Tensor) -> None:
    """
    Asserts that `device_output` and `golden_output` are equal element by element.

    Raises AssertionError if either input is not a `jax.Array` or if any pair of
    elements differs.
    """
    both_jax_arrays = all(
        isinstance(tensor, jax.Array) for tensor in (device_output, golden_output)
    )
    assert both_jax_arrays, f"Currently only jax.Array is supported"

    all_equal = jnp.all(device_output == golden_output)

    assert all_equal, f"Equal comparison failed"
|
||
|
||
@run_on_cpu
def compare_atol(
    device_output: Tensor, golden_output: Tensor, atol_config: AtolConfig
) -> None:
    """
    Asserts that the largest absolute element-wise difference between the two
    outputs does not exceed `atol_config.required_atol`.

    Raises AssertionError if either input is not a `jax.Array` or if the
    calculated value is above the required one.
    """
    both_jax_arrays = isinstance(device_output, jax.Array) and isinstance(
        golden_output, jax.Array
    )
    assert (
        both_jax_arrays
    ), f"Currently only jax.Array is supported {type(device_output)}, {type(golden_output)}"

    difference = device_output - golden_output
    atol = jnp.abs(difference).max()

    within_tolerance = atol <= atol_config.required_atol
    assert within_tolerance, f"Atol comparison failed. Calculated atol={atol}"
|
||
|
||
@run_on_cpu
def compare_pcc(
    device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
) -> None:
    """
    Asserts that the correlation coefficient (pcc) of the flattened outputs is at
    least `pcc_config.required_pcc`.

    The check passes right away, without calculating pcc, when the outputs are
    already allclose with rtol=1e-2 and atol=1e-2.

    Raises AssertionError if either input is not a `jax.Array` or if the
    calculated pcc is below the required one.
    """
    assert isinstance(device_output, jax.Array) and isinstance(
        golden_output, jax.Array
    ), f"Currently only jax.Array is supported"

    # If tensors are really close, pcc will be nan. Handle that before calculating pcc.
    # The closeness is queried directly instead of catching the AssertionError of
    # `compare_allclose`: that keeps the shortcut working when asserts are stripped
    # (`python -O`), keeps an unrelated "Allclose comparison failed" out of the
    # traceback of a pcc failure and never leaves `pcc` undefined.
    if jnp.allclose(device_output, golden_output, rtol=1e-2, atol=1e-2):
        return

    # `corrcoef` yields a correlation matrix; its smallest entry is taken as pcc.
    pcc = jnp.min(jnp.corrcoef(device_output.flatten(), golden_output.flatten()))

    assert (
        pcc >= pcc_config.required_pcc
    ), f"PCC comparison failed. Calculated pcc={pcc}"
|
||
|
||
@run_on_cpu
def compare_allclose(
    device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
) -> None:
    """
    Asserts that the two outputs are allclose, using the relative and absolute
    tolerances stored in `allclose_config`.

    Raises AssertionError if either input is not a `jax.Array` or if
    `jnp.allclose` reports that the outputs are not close.
    """
    both_jax_arrays = all(
        isinstance(tensor, jax.Array) for tensor in (device_output, golden_output)
    )
    assert both_jax_arrays, f"Currently only jax.Array is supported"

    tolerances = {"rtol": allclose_config.rtol, "atol": allclose_config.atol}
    allclose = jnp.allclose(device_output, golden_output, **tolerances)

    assert allclose, f"Allclose comparison failed."
Oops, something went wrong.