Skip to content

Commit

Permalink
Code review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Dec 13, 2024
1 parent e27a765 commit 00918dc
Show file tree
Hide file tree
Showing 17 changed files with 618 additions and 415 deletions.
11 changes: 9 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import os

import jax
import jax._src.xla_bridge as xb
import pytest
from infra.device_connector import DeviceConnector


def initialize():
Expand All @@ -21,4 +23,9 @@ def initialize():

@pytest.fixture(scope="session", autouse=True)
def setup_session():
    """Session-wide, automatic setup: registers the TT PJRT plugin exactly once.

    Guards against double registration, which would fail with
    `PJRT_Api already exists for device type tt`. Temporary workaround,
    to be removed soon.
    """
    if not DeviceConnector.get_instance().is_initialized():
        initialize()
7 changes: 5 additions & 2 deletions tests/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

from .device_runner import run_on_cpu, run_on_tt_device
from .module_tester import ComparisonMetric, TestType, test, test_with_random_inputs
# Exposes only what is really needed to write tests, nothing else.
from .base_model_tester import BaseModelTester
from .comparison import ComparisonConfig
from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
from .op_tester import run_op_test, run_op_test_with_random_inputs
123 changes: 123 additions & 0 deletions tests/infra/base_model_tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
from typing import Callable, Sequence

from flax import linen, nnx
from transformers.modeling_flax_utils import FlaxPreTrainedModel

from .base_tester import BaseTester
from .comparison import ComparisonConfig
from .device_runner import DeviceRunner
from .utils import Model, Tensor


class TestType(Enum):
    # Run the model forward pass and compare TT-device vs CPU outputs.
    INFERENCE = "inference"
    # Train the model (not implemented yet; see BaseModelTester._test_training).
    TRAINING = "training"


class BaseModelTester(BaseTester, ABC):
    """
    Abstract base class all model testers must inherit.

    Derived classes must provide implementations of:
    ```
    _get_model() -> Model
    _get_model_inputs() -> Sequence[Tensor]
    ```
    and may optionally override:
    ```
    _get_model_forward_pass_method_name() -> str
    ```
    which defaults to `"__call__"`.
    """

    def __init__(
        self,
        comparison_config: ComparisonConfig,
        test_type: TestType = TestType.INFERENCE,
    ) -> None:
        """Stores test type; comparison config is handled by `BaseTester`."""
        super().__init__(comparison_config)
        self._test_type = test_type

    @staticmethod
    @abstractmethod
    def _get_model() -> Model:
        """Returns model instance."""
        raise NotImplementedError("Subclasses should implement this method.")

    @staticmethod
    @abstractmethod
    def _get_model_inputs() -> Sequence[Tensor]:
        """Returns inputs to the model's forward pass."""
        raise NotImplementedError("Subclasses should implement this method.")

    # NOTE: intentionally NOT abstract. The docstring promises a default
    # (`__call__`), so subclasses only override when their forward pass lives
    # on a differently named method.
    @staticmethod
    def _get_model_forward_pass_method_name() -> str:
        """
        Returns string name of a forward pass method.
        By default it is `Model.__call__` method which is the most convenient one.
        """
        return "__call__"

    def test(self) -> None:
        """Tests the model depending on test type with which tester was configured."""
        model = self._get_model()
        inputs = self._get_model_inputs()

        if self._test_type == TestType.INFERENCE:
            self._test_inference(model, inputs)
        else:
            self._test_training(model, inputs)

    def _test_inference(self, model: Model, inputs: Sequence[Tensor]) -> None:
        """
        Tests the model by running inference on TT device and on CPU and comparing the
        results.
        """
        self._configure_model_for_inference(model)
        compiled_model = self._compile(model)

        tt_res = DeviceRunner.run_on_tt_device(compiled_model, inputs)
        cpu_res = DeviceRunner.run_on_cpu(compiled_model, inputs)

        self._compare(tt_res, cpu_res)

    def _test_training(self, model: Model, inputs: Sequence[Tensor]) -> None:
        """Tests the model by training it. Not implemented yet."""
        # TODO: call self._configure_model_for_training(model) once supported.
        raise NotImplementedError("Support for training not implemented")

    def _configure_model_for_inference(self, model: Model) -> None:
        """Configures model for inference."""
        if isinstance(model, nnx.Module):
            model.eval()
        elif isinstance(model, (linen.Module, FlaxPreTrainedModel)):
            # TODO does linen have something alike nnx.Module.eval()?
            pass
        else:
            raise TypeError(f"Unknown model type: {type(model)}")

    def _configure_model_for_training(self, model: Model) -> None:
        """Configures model for training."""
        if isinstance(model, nnx.Module):
            model.train()
        elif isinstance(model, (linen.Module, FlaxPreTrainedModel)):
            # TODO does linen have something alike nnx.Module.train()?
            pass
        else:
            raise TypeError(f"Unknown model type: {type(model)}")

    def _compile(self, model: Model) -> Callable:
        """JIT-compiles model's forward pass method into optimized kernels."""
        forward_pass_method_name = self._get_model_forward_pass_method_name()
        assert hasattr(
            model, forward_pass_method_name
        ), f"Model {model} does not have {forward_pass_method_name} method."

        forward_pass_method = getattr(model, forward_pass_method_name)
        return super()._compile(forward_pass_method)
70 changes: 70 additions & 0 deletions tests/infra/base_tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC
from typing import Callable, Sequence

import jax

from .comparison import (
ComparisonConfig,
compare_allclose,
compare_atol,
compare_equal,
compare_pcc,
)
from .device_runner import DeviceRunner
from .utils import Tensor


class BaseTester(ABC):
    """
    Abstract base class all testers must inherit.
    Provides just a couple of common methods.
    """

    def __init__(self, comparison_config: ComparisonConfig | None = None) -> None:
        """
        Stores comparison config for later use in `_compare`.

        `None` (the default) means a fresh `ComparisonConfig()`. A `None`
        sentinel is used instead of `ComparisonConfig()` as the default value,
        which would be a single mutable instance shared by every tester.
        """
        self._comparison_config = (
            comparison_config if comparison_config is not None else ComparisonConfig()
        )

    @staticmethod
    def _compile(f: Callable) -> Callable:
        """Sets up `f` for just-in-time compile."""
        return jax.jit(f)

    def _compare(
        self,
        device_out: Tensor,
        golden_out: Tensor,
    ) -> None:
        """
        Compares device output against golden output using every comparison
        enabled in the comparison config. Each enabled comparator asserts on failure.
        """
        # Comparisons run on host, so move both outputs to CPU first.
        device_output, golden_output = DeviceRunner.put_on_cpu(device_out, golden_out)
        device_output, golden_output = self._match_data_types(
            device_output, golden_output
        )

        if self._comparison_config.equal.enabled:
            compare_equal(device_output, golden_output)
        if self._comparison_config.atol.enabled:
            compare_atol(device_output, golden_output, self._comparison_config.atol)
        if self._comparison_config.pcc.enabled:
            compare_pcc(device_output, golden_output, self._comparison_config.pcc)
        if self._comparison_config.allclose.enabled:
            compare_allclose(
                device_output, golden_output, self._comparison_config.allclose
            )

    def _match_data_types(self, *tensors: Tensor) -> Sequence[Tensor]:
        """
        Casts all tensors to float32 if not already in that format.
        Tensors need to be in same data format in order to compare them.
        """
        # NOTE: `dtype.name` (e.g. "float32"), not `dtype.str` (e.g. "<f4") —
        # the latter never equals "float32", which made the old code cast
        # every tensor unconditionally.
        return [
            (tensor.astype("float32") if tensor.dtype.name != "float32" else tensor)
            for tensor in tensors
        ]
109 changes: 109 additions & 0 deletions tests/infra/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass, field

import jax.numpy as jnp

from .utils import Tensor


@dataclass
class ConfigBase:
    # Master on/off switch for the comparison this config describes.
    enabled: bool = True

    def enable(self) -> None:
        self.enabled = True

    def disable(self) -> None:
        self.enabled = False


@dataclass
class EqualConfig(ConfigBase):
    pass


@dataclass
class AtolConfig(ConfigBase):
    # Max allowed absolute difference between device and golden output.
    required_atol: float = 1e-1


@dataclass
class PccConfig(ConfigBase):
    # Min required Pearson correlation coefficient.
    required_pcc: float = 0.99


@dataclass
class AllcloseConfig(ConfigBase):
    rtol: float = 1e-2
    atol: float = 1e-2


@dataclass
class ComparisonConfig:
    """
    Bundle of all comparison configs, one per supported comparison.

    `default_factory` is required so each `ComparisonConfig` instance owns
    fresh sub-config objects. Bare instance defaults (e.g. `= AtolConfig()`)
    would be shared by every `ComparisonConfig`, so enabling/disabling one
    config would silently affect all others (and Python >= 3.11 rejects such
    mutable dataclass defaults outright).
    """

    equal: EqualConfig = field(default_factory=lambda: EqualConfig(False))
    atol: AtolConfig = field(default_factory=AtolConfig)
    pcc: PccConfig = field(default_factory=PccConfig)
    allclose: AllcloseConfig = field(default_factory=AllcloseConfig)

    def enable_all(self) -> None:
        self.equal.enable()
        self.atol.enable()
        self.allclose.enable()
        self.pcc.enable()

    def disable_all(self) -> None:
        self.equal.disable()
        self.atol.disable()
        self.allclose.disable()
        self.pcc.disable()


def compare_equal(device_output: Tensor, golden_output: Tensor) -> None:
    """Asserts that both outputs are element-wise identical."""
    all_equal = (device_output == golden_output).all()
    assert all_equal, f"Equal comparison failed"


def compare_atol(
    device_output: Tensor, golden_output: Tensor, atol_config: AtolConfig
) -> None:
    """Asserts that the max absolute difference is within the configured atol."""
    max_abs_diff = jnp.max(jnp.abs(device_output - golden_output))
    assert (
        max_abs_diff <= atol_config.required_atol
    ), f"Atol comparison failed. Calculated atol={max_abs_diff}"


def compare_pcc(
    device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
) -> None:
    """
    Asserts that the Pearson correlation coefficient between outputs meets the
    configured minimum.

    If tensors are (nearly) identical, pcc will be nan, so close tensors are
    treated as an immediate pass instead of computing corrcoef. Checking
    `jnp.allclose` directly (rather than calling `compare_allclose` and
    catching its AssertionError) keeps exceptions out of the control flow and
    avoids a chained "During handling of the above exception" traceback when
    the pcc check itself fails.
    """
    # rtol/atol match AllcloseConfig defaults; kept inline so this check is
    # independent of config-object construction.
    if jnp.allclose(device_output, golden_output, rtol=1e-2, atol=1e-2):
        return

    pcc = jnp.min(jnp.corrcoef(device_output.flatten(), golden_output.flatten()))

    assert (
        pcc >= pcc_config.required_pcc
    ), f"PCC comparison failed. Calculated pcc={pcc}"


def compare_allclose(
    device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
) -> None:
    """Asserts outputs are element-wise close within the configured rtol/atol."""
    is_close = jnp.allclose(
        device_output,
        golden_output,
        rtol=allclose_config.rtol,
        atol=allclose_config.atol,
    )
    assert is_close, f"Allclose comparison failed."
Loading

0 comments on commit 00918dc

Please sign in to comment.