Commit 00918dc (1 parent: e27a765), showing 17 changed files with 618 additions and 415 deletions.
@@ -0,0 +1,123 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
from typing import Callable, Sequence

from flax import linen, nnx
from transformers.modeling_flax_utils import FlaxPreTrainedModel

from .base_tester import BaseTester
from .comparison import ComparisonConfig
from .device_runner import DeviceRunner
from .utils import Model, Tensor


class TestType(Enum):
    INFERENCE = "inference"
    TRAINING = "training"


class BaseModelTester(BaseTester, ABC):
    """
    Abstract base class all model testers must inherit.

    Derived classes must provide implementations of:
    ```
    _get_model() -> Model
    _get_model_inputs() -> Sequence[Tensor]
    _get_model_forward_pass_method_name() -> str
    ```
    """

    def __init__(
        self,
        comparison_config: ComparisonConfig,
        test_type: TestType = TestType.INFERENCE,
    ) -> None:
        super().__init__(comparison_config)
        self._test_type = test_type

    @staticmethod
    @abstractmethod
    def _get_model() -> Model:
        """Returns the model instance."""
        raise NotImplementedError("Subclasses should implement this method.")

    @staticmethod
    @abstractmethod
    def _get_model_inputs() -> Sequence[Tensor]:
        """Returns the inputs to the model's forward pass."""
        raise NotImplementedError("Subclasses should implement this method.")

    @staticmethod
    @abstractmethod
    def _get_model_forward_pass_method_name() -> str:
        """
        Returns the string name of the forward pass method.

        By default it is the `Model.__call__` method, which is the most convenient one.
        """
        return "__call__"

    def test(self) -> None:
        """Tests the model depending on the test type the tester was configured with."""
        model = self._get_model()
        inputs = self._get_model_inputs()

        if self._test_type == TestType.INFERENCE:
            self._test_inference(model, inputs)
        else:
            self._test_training(model, inputs)

    def _test_inference(self, model: Model, inputs: Sequence[Tensor]) -> None:
        """
        Tests the model by running inference on the TT device and on the CPU and
        comparing the results.
        """
        self._configure_model_for_inference(model)
        compiled_model = self._compile(model)

        tt_res = DeviceRunner.run_on_tt_device(compiled_model, inputs)
        cpu_res = DeviceRunner.run_on_cpu(compiled_model, inputs)

        self._compare(tt_res, cpu_res)

    def _test_training(self, model: Model, inputs: Sequence[Tensor]) -> None:
        """TODO"""
        # self._configure_model_for_training(model)
        raise NotImplementedError("Support for training not implemented")

    def _configure_model_for_inference(self, model: Model) -> None:
        """Configures the model for inference."""
        if isinstance(model, nnx.Module):
            model.eval()
        elif isinstance(model, (linen.Module, FlaxPreTrainedModel)):
            # TODO: does linen have something akin to nnx.Module.eval()?
            pass
        else:
            raise TypeError(f"Unknown model type: {type(model)}")

    def _configure_model_for_training(self, model: Model) -> None:
        """Configures the model for training."""
        if isinstance(model, nnx.Module):
            model.train()
        elif isinstance(model, (linen.Module, FlaxPreTrainedModel)):
            # TODO: does linen have something akin to nnx.Module.train()?
            pass
        else:
            raise TypeError(f"Unknown model type: {type(model)}")

    def _compile(self, model: Model) -> Callable:
        """JIT-compiles the model's forward pass into optimized kernels."""
        forward_pass_method_name = self._get_model_forward_pass_method_name()
        assert hasattr(
            model, forward_pass_method_name
        ), f"Model {model} does not have a {forward_pass_method_name} method."

        forward_pass_method = getattr(model, forward_pass_method_name)
        return super()._compile(forward_pass_method)
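To make the required hooks concrete, here is a minimal sketch of what a derived tester could look like. Everything named `Example*`, the tensor shapes, and the `infra` import path are illustrative assumptions rather than part of this commit, and running the test end to end still requires the `DeviceRunner` backend from this repository.

```python
# Hypothetical example: a concrete tester for a tiny flax.nnx MLP.
from typing import Sequence

import jax.numpy as jnp
from flax import nnx

# Assumed import path for the classes shown above; adjust to the actual package.
from infra import BaseModelTester, ComparisonConfig, TestType


class ExampleMLP(nnx.Module):
    """Two-layer MLP used only to illustrate the tester contract."""

    def __init__(self, rngs: nnx.Rngs) -> None:
        self.fc1 = nnx.Linear(32, 64, rngs=rngs)
        self.fc2 = nnx.Linear(64, 10, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.fc2(nnx.relu(self.fc1(x)))


class ExampleMLPTester(BaseModelTester):
    """Implements the three hooks BaseModelTester requires."""

    @staticmethod
    def _get_model() -> ExampleMLP:
        return ExampleMLP(rngs=nnx.Rngs(0))

    @staticmethod
    def _get_model_inputs() -> Sequence[jnp.ndarray]:
        return [jnp.ones((8, 32), dtype=jnp.float32)]

    @staticmethod
    def _get_model_forward_pass_method_name() -> str:
        return "__call__"


# Runs inference on the TT device and on the CPU and compares the outputs.
ExampleMLPTester(ComparisonConfig(), TestType.INFERENCE).test()
```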
@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC
from typing import Callable, Sequence

import jax

from .comparison import (
    ComparisonConfig,
    compare_allclose,
    compare_atol,
    compare_equal,
    compare_pcc,
)
from .device_runner import DeviceRunner
from .utils import Tensor


class BaseTester(ABC):
    """
    Abstract base class all testers must inherit.

    Provides the methods common to all testers: JIT compilation and output
    comparison.
    """

    def __init__(
        self, comparison_config: ComparisonConfig = ComparisonConfig()
    ) -> None:
        self._comparison_config = comparison_config

    @staticmethod
    def _compile(f: Callable) -> Callable:
        """Sets up `f` for just-in-time compilation."""
        return jax.jit(f)

    def _compare(
        self,
        device_out: Tensor,
        golden_out: Tensor,
    ) -> None:
        device_output, golden_output = DeviceRunner.put_on_cpu(device_out, golden_out)
        device_output, golden_output = self._match_data_types(
            device_output, golden_output
        )

        if self._comparison_config.equal.enabled:
            compare_equal(device_output, golden_output)
        if self._comparison_config.atol.enabled:
            compare_atol(device_output, golden_output, self._comparison_config.atol)
        if self._comparison_config.pcc.enabled:
            compare_pcc(device_output, golden_output, self._comparison_config.pcc)
        if self._comparison_config.allclose.enabled:
            compare_allclose(
                device_output, golden_output, self._comparison_config.allclose
            )

    def _match_data_types(self, *tensors: Tensor) -> Sequence[Tensor]:
        """
        Casts all tensors to float32 if they are not already in that format.

        Tensors need to be in the same data format in order to compare them.
        """
        return [
            tensor.astype("float32") if str(tensor.dtype) != "float32" else tensor
            for tensor in tensors
        ]
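For orientation, the snippet below approximates what `_compile` and `_match_data_types` do using plain JAX on the host. `DeviceRunner` and the TT device are deliberately left out, so this is a sketch of the comparison flow rather than the tester's actual device routing.

```python
import jax
import jax.numpy as jnp


def forward(x):
    # Stand-in for a model's forward pass; any jittable function works.
    return jnp.tanh(x) * 2.0


compiled = jax.jit(forward)              # what BaseTester._compile does
golden_out = forward(jnp.ones((2, 2)))   # eager execution as the reference
device_out = compiled(jnp.ones((2, 2)))  # compiled execution

# Mirror _match_data_types: cast both sides to float32 before comparing.
device_out, golden_out = [
    t.astype("float32") if str(t.dtype) != "float32" else t
    for t in (device_out, golden_out)
]
assert jnp.allclose(device_out, golden_out, rtol=1e-2, atol=1e-2)
```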
@@ -0,0 +1,109 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass, field

import jax.numpy as jnp

from .utils import Tensor


@dataclass
class ConfigBase:
    enabled: bool = True

    def enable(self) -> None:
        self.enabled = True

    def disable(self) -> None:
        self.enabled = False


@dataclass
class EqualConfig(ConfigBase):
    pass


@dataclass
class AtolConfig(ConfigBase):
    required_atol: float = 1e-1


@dataclass
class PccConfig(ConfigBase):
    required_pcc: float = 0.99


@dataclass
class AllcloseConfig(ConfigBase):
    rtol: float = 1e-2
    atol: float = 1e-2


@dataclass
class ComparisonConfig:
    # Dataclass instances are mutable, so they must be supplied through
    # default_factory rather than as plain default values.
    equal: EqualConfig = field(default_factory=lambda: EqualConfig(False))
    atol: AtolConfig = field(default_factory=AtolConfig)
    pcc: PccConfig = field(default_factory=PccConfig)
    allclose: AllcloseConfig = field(default_factory=AllcloseConfig)

    def enable_all(self) -> None:
        self.equal.enable()
        self.atol.enable()
        self.allclose.enable()
        self.pcc.enable()

    def disable_all(self) -> None:
        self.equal.disable()
        self.atol.disable()
        self.allclose.disable()
        self.pcc.disable()


def compare_equal(device_output: Tensor, golden_output: Tensor) -> None:
    eq = (device_output == golden_output).all()

    assert eq, "Equal comparison failed"


def compare_atol(
    device_output: Tensor, golden_output: Tensor, atol_config: AtolConfig
) -> None:
    atol = jnp.max(jnp.abs(device_output - golden_output))

    assert (
        atol <= atol_config.required_atol
    ), f"Atol comparison failed. Calculated atol={atol}"


def compare_pcc(
    device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
) -> None:
    # If tensors are really close, pcc will be nan. Handle that by computing pcc
    # only when the tensors are not already allclose.
    try:
        compare_allclose(
            device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2)
        )
    except AssertionError:
        pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten())
        pcc = jnp.min(pcc)

        assert (
            pcc >= pcc_config.required_pcc
        ), f"PCC comparison failed. Calculated pcc={pcc}"


def compare_allclose(
    device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
) -> None:
    allclose = jnp.allclose(
        device_output,
        golden_output,
        rtol=allclose_config.rtol,
        atol=allclose_config.atol,
    )

    assert allclose, "Allclose comparison failed."