From 00918dc6655e21f5c6ea580b6529543094077e96 Mon Sep 17 00:00:00 2001
From: Kristijan Mitrovic
Date: Thu, 12 Dec 2024 12:57:41 +0000
Subject: [PATCH] Code review changes

---
 tests/conftest.py                            |  11 +-
 tests/infra/__init__.py                      |   7 +-
 tests/infra/base_model_tester.py             | 123 +++++++++++++++
 tests/infra/base_tester.py                   |  70 +++++++++
 tests/infra/comparison.py                    | 109 ++++++++++++++
 tests/infra/device_connector.py              |  45 +++---
 tests/infra/device_runner.py                 | 106 ++++---------
 tests/infra/graph_tester.py                  |  47 ++++++
 tests/infra/module_tester.py                 | 141 ------------------
 tests/infra/op_tester.py                     |  69 +++++++++
 tests/infra/test_model.py                    |  44 ------
 tests/infra/test_module.py                   |  54 -------
 tests/infra/utils.py                         |  86 ++++++-----
 .../graphs/test_arbitrary_op_chain.py        |  13 +-
 .../test_flax_distil_bert_for_masked_lm.py   |  51 +++++++
 tests/{ => jax}/models/test_simple_nn.py     |  48 +++---
 tests/{ => jax}/ops/test_add.py              |   9 +-
 17 files changed, 618 insertions(+), 415 deletions(-)
 create mode 100644 tests/infra/base_model_tester.py
 create mode 100644 tests/infra/base_tester.py
 create mode 100644 tests/infra/comparison.py
 create mode 100644 tests/infra/graph_tester.py
 delete mode 100644 tests/infra/module_tester.py
 create mode 100644 tests/infra/op_tester.py
 delete mode 100644 tests/infra/test_model.py
 delete mode 100644 tests/infra/test_module.py
 rename tests/{ => jax}/graphs/test_arbitrary_op_chain.py (59%)
 create mode 100644 tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py
 rename tests/{ => jax}/models/test_simple_nn.py (50%)
 rename tests/{ => jax}/ops/test_add.py (55%)

diff --git a/tests/conftest.py b/tests/conftest.py
index ae2fb1a..06c5e39 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,10 +2,12 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import pytest
 import os
+
 import jax
 import jax._src.xla_bridge as xb
+import pytest
+from infra.device_connector import DeviceConnector
 
 
 def initialize():
@@ -21,4 +23,9 @@ def initialize():
 
 @pytest.fixture(scope="session", autouse=True)
 def setup_session():
-    initialize()
+    # Added to prevent `PJRT_Api already exists for device type tt` error.
+    # Will be removed completely soon.
+    connector = DeviceConnector.get_instance()
+
+    if not connector.is_initialized():
+        initialize()
diff --git a/tests/infra/__init__.py b/tests/infra/__init__.py
index 8bd0773..32030b2 100644
--- a/tests/infra/__init__.py
+++ b/tests/infra/__init__.py
@@ -2,5 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from .device_runner import run_on_cpu, run_on_tt_device
-from .module_tester import ComparisonMetric, TestType, test, test_with_random_inputs
+# Exposes only what is really needed to write tests, nothing else.
+from .base_model_tester import BaseModelTester
+from .comparison import ComparisonConfig
+from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
+from .op_tester import run_op_test, run_op_test_with_random_inputs
diff --git a/tests/infra/base_model_tester.py b/tests/infra/base_model_tester.py
new file mode 100644
index 0000000..87c4b6f
--- /dev/null
+++ b/tests/infra/base_model_tester.py
@@ -0,0 +1,123 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Callable, Sequence
+
+from flax import linen, nnx
+from transformers.modeling_flax_utils import FlaxPreTrainedModel
+
+from .base_tester import BaseTester
+from .comparison import ComparisonConfig
+from .device_runner import DeviceRunner
+from .utils import Model, Tensor
+
+
+class TestType(Enum):
+    INFERENCE = "inference"
+    TRAINING = "training"
+
+
+class BaseModelTester(BaseTester, ABC):
+    """
+    Abstract base class all model testers must inherit from.
+
+    Derived classes must provide implementations of:
+    ```
+    _get_model() -> Model
+    _get_model_inputs() -> Sequence[Tensor]
+    _get_model_forward_pass_method_name() -> str
+    ```
+    """
+
+    def __init__(
+        self,
+        comparison_config: ComparisonConfig,
+        test_type: TestType = TestType.INFERENCE,
+    ) -> None:
+        super().__init__(comparison_config)
+        self._test_type = test_type
+
+    @staticmethod
+    @abstractmethod
+    def _get_model() -> Model:
+        """Returns model instance."""
+        raise NotImplementedError("Subclasses should implement this method.")
+
+    @staticmethod
+    @abstractmethod
+    def _get_model_inputs() -> Sequence[Tensor]:
+        """Returns inputs to the model's forward pass."""
+        raise NotImplementedError("Subclasses should implement this method.")
+
+    @staticmethod
+    @abstractmethod
+    def _get_model_forward_pass_method_name() -> str:
+        """
+        Returns the string name of the model's forward pass method.
+
+        By default this is the `__call__` method, which is the most convenient one.
+        """
+        return "__call__"
+
+    def test(self) -> None:
+        """Tests the model according to the test type the tester was configured with."""
+        model = self._get_model()
+        inputs = self._get_model_inputs()
+
+        if self._test_type == TestType.INFERENCE:
+            self._test_inference(model, inputs)
+        else:
+            self._test_training(model, inputs)
+
+    def _test_inference(self, model: Model, inputs: Sequence[Tensor]) -> None:
+        """
+        Tests the model by running inference on TT device and on CPU and comparing the
+        results.
+        """
+        self._configure_model_for_inference(model)
+        compiled_model = self._compile(model)
+
+        tt_res = DeviceRunner.run_on_tt_device(compiled_model, inputs)
+        cpu_res = DeviceRunner.run_on_cpu(compiled_model, inputs)
+
+        self._compare(tt_res, cpu_res)
+
+    def _test_training(self, model: Model, inputs: Sequence[Tensor]) -> None:
+        """TODO"""
+        # self._configure_model_for_training(model)
+        raise NotImplementedError("Support for training not implemented")
+
+    def _configure_model_for_inference(self, model: Model) -> None:
+        """Configures model for inference."""
+        if isinstance(model, nnx.Module):
+            model.eval()
+        elif isinstance(model, linen.Module) or isinstance(model, FlaxPreTrainedModel):
+            # TODO does linen have something akin to nnx.Module.eval()?
+            pass
+        else:
+            raise TypeError(f"Unknown model type: {type(model)}")
+
+    def _configure_model_for_training(self, model: Model) -> None:
+        """Configures model for training."""
+        if isinstance(model, nnx.Module):
+            model.train()
+        elif isinstance(model, linen.Module) or isinstance(model, FlaxPreTrainedModel):
+            # TODO does linen have something akin to nnx.Module.train()?
+            pass
+        else:
+            raise TypeError(f"Unknown model type: {type(model)}")
+
+    def _compile(self, model: Model) -> Callable:
+        """JIT-compiles model into optimized kernels."""
+        forward_pass_method_name = self._get_model_forward_pass_method_name()
+        assert hasattr(
+            model, forward_pass_method_name
+        ), f"Model {model} does not have {forward_pass_method_name} method."
+
+        forward_pass_method = getattr(model, forward_pass_method_name)
+        return super()._compile(forward_pass_method)
diff --git a/tests/infra/base_tester.py b/tests/infra/base_tester.py
new file mode 100644
index 0000000..fcaeb65
--- /dev/null
+++ b/tests/infra/base_tester.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import ABC
+from typing import Callable, Sequence
+
+import jax
+
+from .comparison import (
+    ComparisonConfig,
+    compare_allclose,
+    compare_atol,
+    compare_equal,
+    compare_pcc,
+)
+from .device_runner import DeviceRunner
+from .utils import Tensor
+
+
+class BaseTester(ABC):
+    """
+    Abstract base class all testers must inherit from.
+
+    Provides just a couple of common methods.
+    """
+
+    def __init__(
+        self, comparison_config: ComparisonConfig = ComparisonConfig()
+    ) -> None:
+        self._comparison_config = comparison_config
+
+    @staticmethod
+    def _compile(f: Callable) -> Callable:
+        """Sets up `f` for just-in-time compilation."""
+        return jax.jit(f)
+
+    def _compare(
+        self,
+        device_out: Tensor,
+        golden_out: Tensor,
+    ) -> None:
+        device_output, golden_output = DeviceRunner.put_on_cpu(device_out, golden_out)
+        device_output, golden_output = self._match_data_types(
+            device_output, golden_output
+        )
+
+        if self._comparison_config.equal.enabled:
+            compare_equal(device_output, golden_output)
+        if self._comparison_config.atol.enabled:
+            compare_atol(device_output, golden_output, self._comparison_config.atol)
+        if self._comparison_config.pcc.enabled:
+            compare_pcc(device_output, golden_output, self._comparison_config.pcc)
+        if self._comparison_config.allclose.enabled:
+            compare_allclose(
+                device_output, golden_output, self._comparison_config.allclose
+            )
+
+    def _match_data_types(self, *tensors: Tensor) -> Sequence[Tensor]:
+        """
+        Casts all tensors to float32 if not already in that format.
+
+        Tensors need to be in the same data format in order to compare them.
+        """
+        return [
+            (tensor.astype("float32") if str(tensor.dtype) != "float32" else tensor)
+            for tensor in tensors
+        ]
diff --git a/tests/infra/comparison.py b/tests/infra/comparison.py
new file mode 100644
index 0000000..bb57fd4
--- /dev/null
+++ b/tests/infra/comparison.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+import jax.numpy as jnp
+
+from .utils import Tensor
+
+
+@dataclass
+class ConfigBase:
+    enabled: bool = True
+
+    def enable(self) -> None:
+        self.enabled = True
+
+    def disable(self) -> None:
+        self.enabled = False
+
+
+@dataclass
+class EqualConfig(ConfigBase):
+    pass
+
+
+@dataclass
+class AtolConfig(ConfigBase):
+    required_atol: float = 1e-1
+
+
+@dataclass
+class PccConfig(ConfigBase):
+    required_pcc: float = 0.99
+
+
+@dataclass
+class AllcloseConfig(ConfigBase):
+    rtol: float = 1e-2
+    atol: float = 1e-2
+
+
+@dataclass
+class ComparisonConfig:
+    equal: EqualConfig = field(default_factory=lambda: EqualConfig(False))
+    atol: AtolConfig = field(default_factory=AtolConfig)
+    pcc: PccConfig = field(default_factory=PccConfig)
+    allclose: AllcloseConfig = field(default_factory=AllcloseConfig)
+
+    def enable_all(self) -> None:
+        self.equal.enable()
+        self.atol.enable()
+        self.allclose.enable()
+        self.pcc.enable()
+
+    def disable_all(self) -> None:
+        self.equal.disable()
+        self.atol.disable()
+        self.allclose.disable()
+        self.pcc.disable()
+
+
+def compare_equal(device_output: Tensor, golden_output: Tensor) -> None:
+    eq = (device_output == golden_output).all()
+
+    assert eq, "Equal comparison failed"
+
+
+def compare_atol(
+    device_output: Tensor, golden_output: Tensor, atol_config: AtolConfig
+) -> None:
+    atol = jnp.max(jnp.abs(device_output - golden_output))
+
+    assert (
+        atol <= atol_config.required_atol
+    ), f"Atol comparison failed. Calculated atol={atol}"
+
+
+def compare_pcc(
+    device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
+) -> None:
+    # If tensors are really close, pcc will be nan. Handle that before calculating pcc.
+    try:
+        compare_allclose(
+            device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2)
+        )
+    except AssertionError:
+        pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten())
+        pcc = jnp.min(pcc)
+
+        assert (
+            pcc >= pcc_config.required_pcc
+        ), f"PCC comparison failed. Calculated pcc={pcc}"
+
+
+def compare_allclose(
+    device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
+) -> None:
+    allclose = jnp.allclose(
+        device_output,
+        golden_output,
+        rtol=allclose_config.rtol,
+        atol=allclose_config.atol,
+    )
+
+    assert allclose, "Allclose comparison failed."
diff --git a/tests/infra/device_connector.py b/tests/infra/device_connector.py
index 5bc3295..8dd581b 100644
--- a/tests/infra/device_connector.py
+++ b/tests/infra/device_connector.py
@@ -2,6 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import os
 from enum import Enum
 from typing import Sequence
@@ -9,6 +11,9 @@
 import jax
 import jax._src.xla_bridge as xb
 
+# Relative path to PJRT plugin for TT devices.
+TT_PJRT_PLUGIN_RELPATH = "build/src/tt/pjrt_plugin_tt.so"
+
 
 class DeviceType(Enum):
     CPU = "cpu"
@@ -25,7 +30,9 @@ class DeviceConnector:
     registered exactly once. Registration needs to happen before any other jax
     commands are executed. Registering it multiple times would cause error.
 
-    Do not instantiate this class. Use provided global instance.
+    Do not instantiate this class directly. Use the provided factory method instead.
+
+    TODO (kmitrovic) see how to make this class a thread-safe singleton if needed.
     """
 
     _instance = None
 
@@ -37,17 +44,16 @@ def __new__(cls, *args, **kwargs):
 
         return cls._instance
 
-    def __init__(
-        self, tt_pjrt_plugin_relpath: str = "build/src/tt/pjrt_plugin_tt.so"
-    ) -> None:
+    def __init__(self) -> None:
+        """Don't use directly; use the provided factory method instead."""
         # We need to ensure __init__ body is executed once. It will be called each time
         # `DeviceConnector()` is called.
-        if self._is_initialized():
+        if self.is_initialized():
             return
 
         self._initialized = False
 
-        plugin_path = os.path.join(os.getcwd(), tt_pjrt_plugin_relpath)
+        plugin_path = os.path.join(os.getcwd(), TT_PJRT_PLUGIN_RELPATH)
         if not os.path.exists(plugin_path):
             raise FileNotFoundError(
                 f"Could not find tt_pjrt C API plugin at {plugin_path}"
@@ -56,6 +62,22 @@ def __init__(
         self._plugin_path = plugin_path
         self._initialize_backend()
 
+    @staticmethod
+    def get_instance() -> DeviceConnector:
+        """
+        Factory method returning singleton connector instance.
+
+        Use this method instead of the constructor.
+        """
+        return DeviceConnector()
+
+    def is_initialized(self) -> bool:
+        """Checks if connector is already initialized."""
+        if hasattr(self, "_initialized") and self._initialized:
+            return True
+
+        return False
+
     def connect_tt_device(self) -> jax.Device:
         """Returns TTDevice handle."""
         return self.connect_device(DeviceType.TT)
@@ -97,14 +119,3 @@ def _initialize_backend(self) -> None:
         jax.config.update("jax_platforms", self._supported_devices_str())
 
         self._initialized = True
-
-    def _is_initialized(self) -> bool:
-        """Checks if connector is already initialized."""
-        if hasattr(self, "_initialized") and self._initialized == True:
-            return True
-
-        return False
-
-
-# Global instance of DeviceConnector to be used from other modules.
-connector = DeviceConnector()
diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py
index c4bae8c..7643db8 100644
--- a/tests/infra/device_runner.py
+++ b/tests/infra/device_runner.py
@@ -6,107 +6,65 @@
 
 import jax
 
-from .device_connector import DeviceType, connector
-from .test_module import TestModule
+from .device_connector import DeviceConnector, DeviceType
+from .utils import Tensor
 
 
 class DeviceRunner:
+    """
+    Class providing methods to run a workload on any supported device.
+
+    TODO (kmitrovic) can we make this more general so inputs can be anything and provide
+    decorators for any function to be run on any device?
+    """
+
     @staticmethod
-    def run_on_tt_device(module: TestModule) -> jax.Array:
-        """Runs test module on TT device."""
-        return DeviceRunner._run_on_device(module, DeviceType.TT)
+    def run_on_tt_device(f: Callable, inputs: Sequence[Tensor]) -> Tensor:
+        """Runs `f(inputs)` on TT device."""
+        return DeviceRunner._run_on_device(DeviceType.TT, f, inputs)
 
     @staticmethod
-    def run_on_cpu(module: TestModule) -> jax.Array:
-        """Runs test module on CPU."""
-        return DeviceRunner._run_on_device(module, DeviceType.CPU)
+    def run_on_cpu(f: Callable, inputs: Sequence[Tensor]) -> Tensor:
+        """Runs `f(inputs)` on CPU."""
+        return DeviceRunner._run_on_device(DeviceType.CPU, f, inputs)
 
     @staticmethod
-    def run_on_gpu(module: TestModule) -> jax.Array:
-        """Runs test module on GPU."""
+    def run_on_gpu(f: Callable, inputs: Sequence[Tensor]) -> Tensor:
+        """Runs `f(inputs)` on GPU."""
         raise NotImplementedError("Support for GPUs not implemented")
 
     @staticmethod
-    def put_on_tt_device(*tensors: jax.Array) -> Sequence[jax.Array]:
+    def put_on_tt_device(*tensors: Tensor) -> Sequence[Tensor]:
         """Puts `tensors` on TT device."""
-        return DeviceRunner._put_on_device(tensors, DeviceType.TT)
+        return DeviceRunner._put_on_device(DeviceType.TT, tensors)
 
     @staticmethod
-    def put_on_cpu(*tensors: jax.Array) -> Sequence[jax.Array]:
+    def put_on_cpu(*tensors: Tensor) -> Sequence[Tensor]:
         """Puts `tensors` on CPU."""
-        return DeviceRunner._put_on_device(tensors, DeviceType.CPU)
+        return DeviceRunner._put_on_device(DeviceType.CPU, tensors)
 
     @staticmethod
-    def put_on_gpu(*tensors: jax.Array) -> Sequence[jax.Array]:
+    def put_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
         """Puts `tensors` on GPU."""
         raise NotImplementedError("Support for GPUs not implemented")
 
     @staticmethod
     def _run_on_device(
-        module: TestModule,
-        device_type: DeviceType,
-        raise_ex_if_jit_failed: bool = False,
-    ) -> jax.Array:
-        """
-        Runs test module on device.
-
-        If jitted graph fails to run, will try to run non-jitted graph if
-        `raise_ex_if_jit_failed` is False.
-        """
+        device_type: DeviceType, f: Callable, inputs: Sequence[Tensor]
+    ) -> Tensor:
+        """Runs `f(inputs)` on device identified by `device_type`."""
+        connector = DeviceConnector.get_instance()
         device = connector.connect_device(device_type)
+        inputs = DeviceRunner._put_on_device(device_type, inputs)
 
-        # TODO is there a better way to check if function can be jitted than runtime fail?
-        try:
-            graph = module.get_jit_graph()
-            inputs = DeviceRunner._put_on_device(module.get_inputs(), device_type)
-
-            with jax.default_device(device):
-                return graph(*inputs)
-
-        except Exception as e:
-            if raise_ex_if_jit_failed:
-                raise e
-
-            with jax.default_device(device):
-                return module()
+        with jax.default_device(device):
+            return f(*inputs)
 
     @staticmethod
     def _put_on_device(
-        tensors: Sequence[jax.Array], device_type: DeviceType
-    ) -> Sequence[jax.Array]:
+        device_type: DeviceType, tensors: Sequence[Tensor]
+    ) -> Sequence[Tensor]:
         """Puts `tensors` on device identified by `device_type`."""
+        connector = DeviceConnector.get_instance()
         device = connector.connect_device(device_type)
         return [jax.device_put(t, device) for t in tensors]
-
-
-# ----- Convenience decorators -----
-
-
-def run_on_tt_device(f: Callable):
-    """Runs any decorated function on TT device."""
-
-    def wrapper(*args, **kwargs):
-        module = TestModule(f, args=args, kwargs=kwargs)
-        return DeviceRunner.run_on_tt_device(module)
-
-    return wrapper
-
-
-def run_on_cpu(f: Callable):
-    """Runs any decorated function on CPU."""
-
-    def wrapper(*args, **kwargs):
-        module = TestModule(f, args=args, kwargs=kwargs)
-        return DeviceRunner.run_on_cpu(module)
-
-    return wrapper
-
-
-def run_on_gpu(f: Callable):
-    """Runs any decorated function on GPU."""
-
-    def wrapper(*args, **kwargs):
-        module = TestModule(f, args=args, kwargs=kwargs)
-        return DeviceRunner.run_on_gpu(module)
-
-    return wrapper
diff --git a/tests/infra/graph_tester.py b/tests/infra/graph_tester.py
new file mode 100644
index 0000000..5553651
--- /dev/null
+++ b/tests/infra/graph_tester.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import Callable, Sequence
+
+from .comparison import ComparisonConfig
+from .op_tester import OpTester
+from .utils import Tensor
+
+
+class GraphTester(OpTester):
+    """
+    Specific tester for graphs.
+
+    Currently the same as OpTester.
+    """
+
+    pass
+
+
+def run_graph_test(
+    graph: Callable,
+    inputs: Sequence[Tensor],
+    comparison_config: ComparisonConfig = ComparisonConfig(),
+) -> None:
+    """
+    Tests `graph` with `inputs` by running it on TT device and CPU and comparing the
+    results based on `comparison_config`.
+    """
+    tester = GraphTester(comparison_config)
+    tester.test(graph, inputs)
+
+
+def run_graph_test_with_random_inputs(
+    graph: Callable,
+    input_shapes: Sequence[tuple],
+    comparison_config: ComparisonConfig = ComparisonConfig(),
+) -> None:
+    """
+    Tests `graph` with random inputs by running it on TT device and CPU and comparing
+    the results based on `comparison_config`.
+    """
+    tester = GraphTester(comparison_config)
+    tester.test_with_random_inputs(graph, input_shapes)
diff --git a/tests/infra/module_tester.py b/tests/infra/module_tester.py
deleted file mode 100644
index 9262a7a..0000000
--- a/tests/infra/module_tester.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
-#
-# SPDX-License-Identifier: Apache-2.0
-
-from enum import Enum
-from typing import Callable, Sequence, Union
-
-import jax
-
-from .device_runner import DeviceRunner
-from .test_model import TestModel
-from .test_module import TestModule
-from .utils import (
-    compare_allclose,
-    compare_atol,
-    compare_equal,
-    compare_pcc,
-    random_tensor,
-)
-
-
-class TestType(Enum):
-    INFERENCE = "inference"
-    TRAINING = "training"
-
-
-class ComparisonMetric(Enum):
-    EQUAL = "equal"
-    PCC = "pcc"
-    ATOL = "atol"
-    ALLCLOSE = "allclose"
-
-
-class ModuleTester:
-    """
-    Class providing infrastructure for testing test modules.
-
-    Testing consists of comparing output results of running provided test module on
-    different devices. Frequently, run on CPU or GPU is taken as a ground truth
-    ("golden"), while custom device run is compared to it.
-
-    It supports testing inference and training modes.
-    """
-
-    def __init__(
-        self,
-        test_type: TestType = TestType.INFERENCE,
-        comparison_metric: ComparisonMetric = ComparisonMetric.PCC,
-    ) -> None:
-        self._test_type = test_type
-        self._comparison_metric = comparison_metric
-
-    def __call__(self, module: TestModule) -> bool:
-        """
-        The only public method providing testing hook.
-
-        Call tester with passed test module and it will run tests on it.
-        """
-        return self._test(module)
-
-    def _test(self, module: TestModule) -> bool:
-        if self._test_type == TestType.INFERENCE:
-            return self._test_inference(module)
-        else:
-            return self._test_training(module)
-
-    def _test_inference(self, module: TestModule) -> bool:
-        tt_res = DeviceRunner.run_on_tt_device(module)
-        cpu_res = DeviceRunner.run_on_cpu(module)
-        return self._compare(tt_res, cpu_res)
-
-    def _test_training(self, module: TestModule) -> bool:
-        raise NotImplementedError("Support for training not implemented")
-
-    def _compare(
-        self, device_out: jax.Array, golden_out: jax.Array, assert_on_fail: bool = False
-    ) -> bool:
-        device_output, golden_output = DeviceRunner.put_on_cpu(device_out, golden_out)
-        device_output, golden_output = self._match_data_types(
-            device_output, golden_output
-        )
-
-        if self._comparison_metric == ComparisonMetric.EQUAL:
-            comp = compare_equal(device_output, golden_output)
-        elif self._comparison_metric == ComparisonMetric.PCC:
-            comp = compare_pcc(device_output, golden_output)
-        elif self._comparison_metric == ComparisonMetric.ATOL:
-            comp = compare_atol(device_output, golden_output)
-        elif self._comparison_metric == ComparisonMetric.ALLCLOSE:
-            comp = compare_allclose(device_output, golden_output)
-
-        if assert_on_fail:
-            assert comp, f"{self._comparison_metric.value} comparison failed!"
-
-        return comp
-
-    def _match_data_types(self, *tensors: jax.Array) -> Sequence[jax.Array]:
-        """
-        Casts all tensors to float32 if not already in that format.
-
-        Tensors need to be in same data format in order to compare them.
-        """
-        return [
-            tensor.astype("float32") if tensor.dtype.str != "float32" else tensor
-            for tensor in tensors
-        ]
-
-
-def _test(
-    f: Union[Callable, TestModel],
-    inputs: Sequence[jax.Array],
-    comparison_metric: ComparisonMetric,
-) -> bool:
-    """Helper 'protected' method, don't use, use provided public methods below instead."""
-    tester = ModuleTester(comparison_metric=comparison_metric)
-    module = (
-        f.as_test_module(inputs)
-        if isinstance(f, TestModel)
-        else TestModule(f, args=inputs)
-    )
-    return tester(module)
-
-
-def test(
-    f: Union[Callable, TestModel],
-    inputs: Sequence[jax.Array],
-    comparison_metric: ComparisonMetric,
-) -> bool:
-    return _test(f, inputs, comparison_metric)
-
-
-def test_with_random_inputs(
-    f: Union[Callable, TestModel],
-    input_shapes: Sequence[tuple],
-    comparison_metric: ComparisonMetric,
-) -> bool:
-    inputs = [random_tensor(shape) for shape in input_shapes]
-    return _test(f, inputs, comparison_metric)
-
-
-# TODO expose multiple functions for each of the comparisons since their args may differ.
diff --git a/tests/infra/op_tester.py b/tests/infra/op_tester.py
new file mode 100644
index 0000000..3e1566b
--- /dev/null
+++ b/tests/infra/op_tester.py
@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import Callable, Sequence
+
+from .base_tester import BaseTester
+from .comparison import ComparisonConfig
+from .device_runner import DeviceRunner
+from .utils import Tensor, random_tensor
+
+
+class OpTester(BaseTester):
+    """Specific tester for ops."""
+
+    def __init__(
+        self, comparison_config: ComparisonConfig = ComparisonConfig()
+    ) -> None:
+        super().__init__(comparison_config)
+
+    def test(self, f: Callable, inputs: Sequence[Tensor]) -> None:
+        """
+        Tests `f` with `inputs` by running it on TT device and CPU and comparing the
+        results.
+        """
+        compiled_f = self._compile(f)
+
+        tt_res = DeviceRunner.run_on_tt_device(compiled_f, inputs)
+        cpu_res = DeviceRunner.run_on_cpu(compiled_f, inputs)
+
+        self._compare(tt_res, cpu_res)
+
+    def test_with_random_inputs(
+        self, f: Callable, input_shapes: Sequence[tuple]
+    ) -> None:
+        """
+        Tests `f` by running it with random inputs on TT device and CPU and comparing
+        the results.
+        """
+        inputs = [random_tensor(shape) for shape in input_shapes]
+        self.test(f, inputs)
+
+
+def run_op_test(
+    op: Callable,
+    inputs: Sequence[Tensor],
+    comparison_config: ComparisonConfig = ComparisonConfig(),
+) -> None:
+    """
+    Tests `op` with `inputs` by running it on TT device and CPU and comparing the
+    results based on `comparison_config`.
+    """
+    tester = OpTester(comparison_config)
+    tester.test(op, inputs)
+
+
+def run_op_test_with_random_inputs(
+    op: Callable,
+    input_shapes: Sequence[tuple],
+    comparison_config: ComparisonConfig = ComparisonConfig(),
+) -> None:
+    """
+    Tests `op` with random inputs by running it on TT device and CPU and comparing the
+    results based on `comparison_config`.
+    """
+    tester = OpTester(comparison_config)
+    tester.test_with_random_inputs(op, input_shapes)
diff --git a/tests/infra/test_model.py b/tests/infra/test_model.py
deleted file mode 100644
index 7f4a2f3..0000000
--- a/tests/infra/test_model.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
-#
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-from abc import abstractmethod
-from typing import Any, Optional, Sequence
-
-import jax
-
-from .test_module import TestModule
-
-
-class TestModel:
-    """
-    Interface class for testing models.
-
-    Provides methods which models to be tested must override. Provides a way to export
-    self to `TestModule` which is then used throughout test infra.
-    """
-
-    @abstractmethod
-    def __call__(self, *args: Any, **kwargs: Any) -> jax.Array:
-        raise NotImplementedError("Subclasses should implement this method")
-
-    @staticmethod
-    @abstractmethod
-    def get_model() -> TestModel:
-        raise NotImplementedError("Subclasses should implement this method")
-
-    @staticmethod
-    @abstractmethod
-    def get_model_inputs() -> Sequence[jax.Array]:
-        raise NotImplementedError("Subclasses should implement this method")
-
-    def as_test_module(
-        self, inputs: Optional[Sequence[jax.Array]] = None
-    ) -> TestModule:
-        ins = inputs if inputs is not None else self.get_model_inputs()
-        return TestModule(self.__call__, ins)
-
-    def __repr__(self) -> str:
-        return f"TestModel: {self.__class__.__qualname__}"
diff --git a/tests/infra/test_module.py b/tests/infra/test_module.py
deleted file mode 100644
index 4ed50ad..0000000
--- a/tests/infra/test_module.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
-#
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-from typing import Any, Callable, Mapping, Optional, Sequence
-
-import jax
-from jax import export
-
-
-class TestModule:
-    """
-    Wrapper around a callable and its arguments.
-
-    Single-op or multi-op graphs defined as a python function are wrapped in a
-    TestModule for convenience. TestModule is then used throughout test infra.
-    """
-
-    def __init__(
-        self,
-        f: Callable,
-        args: Sequence[Any],
-        kwargs: Optional[Mapping[str, Any]] = None,
-    ) -> None:
-        self._f = f
-        self._args = args
-        self._kwargs = kwargs if kwargs is not None else {}
-
-        self._inputs = tuple(self._args) + tuple(self._kwargs.values())
-
-    def get_inputs(self) -> Sequence[Any]:
-        return self._inputs
-
-    def get_jit_graph(self):
-        return jax.jit(self._f)
-
-    def __call__(self):
-        """Calls underlying callable with passed underlying args."""
-        return self._f(*self._inputs)
-
-    def __repr__(self) -> str:
-        return f"TestModule: {self._f.__qualname__}"
-
-    def as_mlir_module(self) -> str:
-        """
-        Returns jitted graph as a mlir module string.
-
-        Note that this only works if test module can be successfully run in jitted form.
-        """
-        s = export.export(self.get_jit_graph())(*self.get_inputs()).mlir_module()
-        # Remove all #loc lines for cleaner output.
-        return "\n".join(line for line in s.splitlines() if not line.startswith("#loc"))
diff --git a/tests/infra/utils.py b/tests/infra/utils.py
index 210d888..2361731 100644
--- a/tests/infra/utils.py
+++ b/tests/infra/utils.py
@@ -2,43 +2,53 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from enum import Enum
+from typing import Union
+
 import jax
 import jax.numpy as jnp
-
-from .device_runner import run_on_cpu
-
-
-@run_on_cpu
-def random_tensor(shape: tuple, dtype=jnp.float32, random_seed: int = 0) -> jax.Array:
-    """Generates random tensor of `shape` and `dtype` on CPU."""
-    prng_key = jax.random.key(random_seed)
-    return jax.random.uniform(key=prng_key, shape=shape, dtype=dtype)
-
-
-def compare_pcc(
-    device_output: jax.Array, golden_output: jax.Array, required_pcc: float = 0.99
-) -> bool:
-    # If tensors are really close, pcc will be nan. Handle that before calculating pcc.
-    if compare_allclose(device_output, golden_output, 1e-3, 1e-3):
-        return True
-
-    pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten())
-    return jnp.min(pcc) >= required_pcc
-
-
-def compare_atol(
-    device_output: jax.Array, golden_output: jax.Array, required_atol: float = 1e-2
-) -> bool:
-    atol = jnp.max(jnp.abs(device_output - golden_output))
-    return atol <= required_atol
-
-
-def compare_equal(device_output: jax.Array, golden_output: jax.Array) -> bool:
-    return (device_output == golden_output).all()
-
-
-def compare_allclose(
-    device_output: jax.Array, golden_output: jax.Array, rtol=1e-2, atol=1e-2
-) -> bool:
-    allclose = jnp.allclose(device_output, golden_output, rtol=rtol, atol=atol)
-    return allclose
+from flax import linen, nnx
+
+
+class Framework(Enum):
+    JAX = "jax"
+    TORCH = "torch"
+    NUMPY = "numpy"
+
+
+# Convenience alias. Could be used to represent jax.Array, torch.Tensor, np.ndarray, etc.
+Tensor = Union[jax.Array]
+
+# Convenience alias. Could be used to represent nnx.Module, torch.nn.Module, etc.
+# NOTE nnx.Module is the newest API; linen.Module is legacy, but it is used in some
+# huggingface models.
+Model = Union[nnx.Module, linen.Module]
+
+
+def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
+    """Convert a string dtype to the corresponding framework-specific dtype."""
+    if framework == Framework.JAX:
+        return jnp.dtype(dtype_str)
+    else:
+        raise ValueError(f"Unsupported framework: {framework.value}.")
+
+
+def random_tensor(
+    shape: tuple,
+    dtype: str = "float32",
+    random_seed: int = 0,
+    framework: Framework = Framework.JAX,
+) -> Tensor:
+    """
+    Generates a random tensor with the given `shape` and `dtype`, seeded with
+    `random_seed`, for the desired `framework`.
+    """
+    # Convert dtype string to actual dtype for the selected framework.
+    dtype_converted = _str_to_dtype(dtype, framework)
+
+    # Generate random tensor based on framework type.
+    if framework == Framework.JAX:
+        prng_key = jax.random.PRNGKey(random_seed)
+        return jax.random.uniform(key=prng_key, shape=shape, dtype=dtype_converted)
+    else:
+        raise ValueError(f"Unsupported framework: {framework.value}.")
diff --git a/tests/graphs/test_arbitrary_op_chain.py b/tests/jax/graphs/test_arbitrary_op_chain.py
similarity index 59%
rename from tests/graphs/test_arbitrary_op_chain.py
rename to tests/jax/graphs/test_arbitrary_op_chain.py
index 0a89225..317244e 100644
--- a/tests/graphs/test_arbitrary_op_chain.py
+++ b/tests/jax/graphs/test_arbitrary_op_chain.py
@@ -4,7 +4,7 @@
 
 import jax
 import pytest
-from infra import ComparisonMetric, test_with_random_inputs
+from infra import run_graph_test_with_random_inputs
 from jax import numpy as jnp
 
 
@@ -24,13 +24,4 @@ def arbitrary_op_chain(x: jax.Array, y: jax.Array) -> jax.Array:
     ],
 )
 def test_arbitrary_op_chain(x_shape: tuple, y_shape: tuple):
-    assert test_with_random_inputs(
-        arbitrary_op_chain, [x_shape, y_shape], ComparisonMetric.ALLCLOSE
-    )
-
-
-if __name__ == "__main__":
-    x_shape = y_shape = (32, 32)
-    assert test_with_random_inputs(
-        arbitrary_op_chain, [x_shape, y_shape], ComparisonMetric.ALLCLOSE
-    )
+    run_graph_test_with_random_inputs(arbitrary_op_chain, [x_shape, y_shape])
diff --git a/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py b/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py
new file mode 100644
index 0000000..e802e18
--- /dev/null
+++ b/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Sequence
+
+import jax
+import pytest
+from flax import linen as nn
+from infra import BaseModelTester, ComparisonConfig
+from transformers import AutoTokenizer, FlaxDistilBertForMaskedLM
+
+MODEL = "distilbert/distilbert-base-uncased"
+
+
+class FlaxDistilBertForMaskedLMTester(BaseModelTester):
+    """Tester for DistilBert model with a `language modeling` head on top."""
+
+    # @override
+    @staticmethod
+    def _get_model() -> nn.Module:
+        return FlaxDistilBertForMaskedLM.from_pretrained(MODEL)
+
+    # @override
+    @staticmethod
+    def _get_model_inputs() -> Sequence[jax.Array]:
+        tokenizer = AutoTokenizer.from_pretrained(MODEL)
+        inputs = tokenizer("Hello [MASK].", return_tensors="np")
+        return [inputs["input_ids"]]
+
+    # @override
+    @staticmethod
+    def _get_model_forward_pass_method_name() -> str:
+        return "__call__"
+
+
+@pytest.fixture
+def comparison_config() -> ComparisonConfig:
+    config = ComparisonConfig()
+    config.disable_all()
+    config.pcc.enable()
+    return config
+
+
+@pytest.fixture
+def tester(comparison_config: ComparisonConfig) -> FlaxDistilBertForMaskedLMTester:
+    return FlaxDistilBertForMaskedLMTester(comparison_config)
+
+
+def test_flax_distil_bert_for_masked_lm(tester: FlaxDistilBertForMaskedLMTester):
+    tester.test()
diff --git a/tests/models/test_simple_nn.py b/tests/jax/models/test_simple_nn.py
similarity index 50%
rename from tests/models/test_simple_nn.py
rename to tests/jax/models/test_simple_nn.py
index 468ab3b..acab2be 100644
--- a/tests/models/test_simple_nn.py
+++ b/tests/jax/models/test_simple_nn.py
@@ -7,13 +7,11 @@
 import jax
 import jax.numpy as jnp
 import pytest
-from infra.module_tester import ComparisonMetric, test, test_with_random_inputs
-from infra.test_model import TestModel
+from flax import nnx
+from infra import BaseModelTester, ComparisonConfig
 
 
-class SimpleNN(TestModel):  # TODO what's the benefit of inheriting nnx.Module?
-    # TODO upgrade to python 3.13 to enable this decorator
-    # @override
+class SimpleNN(nnx.Module):
     def __call__(
         self, act: jax.Array, w0: jax.Array, b0: jax.Array, w1: jax.Array, b1: jax.Array
     ) -> jax.Array:
@@ -21,14 +19,17 @@ def __call__(
         x = jnp.matmul(x, w1) + b1
         return x
 
+
+class SimpleNNTester(BaseModelTester):
+    # TODO (kmitrovic) upgrade env to python 3.13 to enable this decorator
     # @override
     @staticmethod
-    def get_model() -> TestModel:
+    def _get_model() -> nnx.Module:
         return SimpleNN()
 
     # @override
     @staticmethod
-    def get_model_inputs() -> Sequence[jax.Array]:
+    def _get_model_inputs() -> Sequence[jax.Array]:
         act_shape, w0_shape, b0_shape, w1_shape, b1_shape = (
             (32, 784),
             (784, 128),
@@ -45,27 +46,24 @@ def get_model_inputs() -> Sequence[jax.Array]:
 
         return [act, w0, b0, w1, b1]
 
+    # @override
+    @staticmethod
+    def _get_model_forward_pass_method_name() -> str:
+        return "__call__"
 
-def test_simple_nn():
-    model = SimpleNN.get_model()
-    inputs = SimpleNN.get_model_inputs()
-
-    assert test(model, inputs, ComparisonMetric.PCC)
 
+@pytest.fixture
+def comparison_config() -> ComparisonConfig:
+    config = ComparisonConfig()
+    config.disable_all()
+    config.pcc.enable()
+    return config
 
-@pytest.mark.parametrize(
-    ["act", "w0", "b0", "w1", "b1"],
-    [
-        [(32, 784), (784, 128), (1, 128), (128, 128), (1, 128)],
-    ],
-)
-def test_simple_nn_with_random_inputs(
-    act: tuple, w0: tuple, b0: tuple, w1: tuple, b1: tuple
-):
-    model = SimpleNN.get_model()
-    assert test_with_random_inputs(model, [act, w0, b0, w1, b1], ComparisonMetric.PCC)
 
+@pytest.fixture
+def tester(comparison_config: ComparisonConfig) -> SimpleNNTester:
+    return SimpleNNTester(comparison_config)
 
-if __name__ == "__main__":
-    test_simple_nn()
+
+def test_simple_nn(tester: SimpleNNTester):
+    tester.test()
diff --git a/tests/ops/test_add.py b/tests/jax/ops/test_add.py
similarity index 55%
rename from tests/ops/test_add.py
rename to tests/jax/ops/test_add.py
index 425f24c..0e41353 100644
--- a/tests/ops/test_add.py
+++ b/tests/jax/ops/test_add.py
@@ -4,7 +4,7 @@
 
 import jax
 import pytest
-from infra.module_tester import ComparisonMetric, test_with_random_inputs
+from infra import run_op_test_with_random_inputs
 
 
 def add(x: jax.Array, y: jax.Array) -> jax.Array:
@@ -19,9 +19,4 @@ def add(x: jax.Array, y: jax.Array) -> jax.Array:
     ],
 )
 def test_add(x_shape: tuple, y_shape: tuple):
-    assert test_with_random_inputs(add, [x_shape, y_shape], ComparisonMetric.PCC)
-
-
-if __name__ == "__main__":
-    x_shape = y_shape = (32, 32)
-    assert test_with_random_inputs(add, [x_shape, y_shape], ComparisonMetric.PCC)
+    run_op_test_with_random_inputs(add, [x_shape, y_shape])