From 16d97ce8e8d45f0d04dace4b019e90aff3a48e91 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Thu, 28 Nov 2024 20:30:38 +0100 Subject: [PATCH] Improved test infrastructure. - Provided layer which uses jax to connect to various devices - Provided layer which uses jax to run payload on various devices - Provided infra which makes it easy to unit test either simple ops, or graphs or even full models Fixes #80, #93. --- requirements.txt | 1 + src/common/api_impl.cc | 3 +- src/common/api_impl.h | 13 +- tests/TTIR/test_device.py | 6 +- tests/conftest.py | 11 +- tests/infra/__init__.py | 10 + tests/infra/base_tester.py | 72 +++++++ tests/infra/comparison.py | 135 +++++++++++++ tests/infra/device_connector.py | 121 ++++++++++++ tests/infra/device_runner.py | 159 ++++++++++++++++ tests/infra/graph_tester.py | 48 +++++ tests/infra/model_tester.py | 177 ++++++++++++++++++ tests/infra/op_tester.py | 72 +++++++ tests/infra/utils.py | 107 +++++++++++ tests/jax/graphs/test_example_graph.py | 27 +++ tests/jax/models/example_model/__init__.py | 0 .../mixed_args_and_kwargs/__init__.py | 0 ...est_example_model_mixed_args_and_kwargs.py | 92 +++++++++ tests/jax/models/example_model/model.py | 33 ++++ .../example_model/only_args/__init__.py | 0 .../only_args/test_example_model_only_args.py | 85 +++++++++ .../example_model/only_kwargs/__init__.py | 0 .../test_example_model_only_kwargs.py | 85 +++++++++ .../test_flax_distil_bert_for_masked_lm.py | 75 ++++++++ tests/jax/ops/test_add.py | 22 +++ venv/activate | 15 +- 26 files changed, 1347 insertions(+), 22 deletions(-) create mode 100644 tests/infra/__init__.py create mode 100644 tests/infra/base_tester.py create mode 100644 tests/infra/comparison.py create mode 100644 tests/infra/device_connector.py create mode 100644 tests/infra/device_runner.py create mode 100644 tests/infra/graph_tester.py create mode 100644 tests/infra/model_tester.py create mode 100644 tests/infra/op_tester.py create mode 100644 tests/infra/utils.py create mode 100644 tests/jax/graphs/test_example_graph.py create mode 100644 tests/jax/models/example_model/__init__.py create mode 100644 tests/jax/models/example_model/mixed_args_and_kwargs/__init__.py create mode 100644 tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py create mode 100644 tests/jax/models/example_model/model.py create mode 100644 tests/jax/models/example_model/only_args/__init__.py create mode 100644 tests/jax/models/example_model/only_args/test_example_model_only_args.py create mode 100644 tests/jax/models/example_model/only_kwargs/__init__.py create mode 100644 tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py create mode 100644 tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py create mode 100644 tests/jax/ops/test_add.py diff --git a/requirements.txt b/requirements.txt index 6a4f8fa..38d4414 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ pre-commit lit pybind11 pytest +transformers diff --git a/src/common/api_impl.cc b/src/common/api_impl.cc index c10edc3..276f258 100644 --- a/src/common/api_impl.cc +++ b/src/common/api_impl.cc @@ -428,8 +428,7 @@ void DeviceDescription::BindApi(PJRT_Api *api) { api->PJRT_DeviceDescription_ToString = +[](PJRT_DeviceDescription_ToString_Args *args) -> PJRT_Error * { DLOG_F(LOG_DEBUG, "DeviceDescription::PJRT_DeviceDescription_ToString"); - auto sv = - DeviceDescription::Unwrap(args->device_description)->user_string(); + auto sv = 
DeviceDescription::Unwrap(args->device_description)->to_string(); args->to_string = sv.data(); args->to_string_size = sv.size(); return nullptr; diff --git a/src/common/api_impl.h b/src/common/api_impl.h index 3b05460..d212487 100644 --- a/src/common/api_impl.h +++ b/src/common/api_impl.h @@ -150,10 +150,11 @@ class DeviceDescription { } std::string_view kind_string() { return kind_string_; } - std::string_view debug_string() { return debug_string_; } - std::string_view user_string() { + std::string_view debug_string() { return to_string(); } + std::string_view to_string() { std::stringstream ss; - ss << "TTDevice(id=" << device_id() << ")"; + ss << kind_string_ << "(id=" << device_id() << ", arch=" << arch_string_ + << ")"; user_string_ = ss.str(); return user_string_; } @@ -166,8 +167,10 @@ class DeviceDescription { private: int client_id_; - std::string kind_string_ = "wormhole"; - std::string debug_string_ = "debug_string"; + // TODO We should understand better how these are used. + // See https://github.com/tenstorrent/tt-xla/issues/125 + std::string kind_string_ = "TTDevice"; + std::string arch_string_ = "Wormhole"; std::string user_string_ = ""; }; diff --git a/tests/TTIR/test_device.py b/tests/TTIR/test_device.py index 5921d25..72da8a9 100644 --- a/tests/TTIR/test_device.py +++ b/tests/TTIR/test_device.py @@ -2,10 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 -import pytest import jax import jax.numpy as jnp - from infrastructure import random_input_tensor @@ -16,9 +14,9 @@ def test_num_devices(): def test_to_device(): cpu_array = random_input_tensor((32, 32)) - device = jax.devices()[0] + device = jax.devices("tt")[0] tt_array = jax.device_put(cpu_array, device) - assert tt_array.device.device_kind == "wormhole" + assert tt_array.device.device_kind == "TTDevice" def test_input_on_device(): diff --git a/tests/conftest.py b/tests/conftest.py index ae2fb1a..06c5e39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,10 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 -import pytest import os + import jax import jax._src.xla_bridge as xb +import pytest +from infra.device_connector import DeviceConnector def initialize(): @@ -21,4 +23,9 @@ def initialize(): @pytest.fixture(scope="session", autouse=True) def setup_session(): - initialize() + # Added to prevent `PJRT_Api already exists for device type tt` error. + # Will be removed completely soon. + connector = DeviceConnector.get_instance() + + if not connector.is_initialized(): + initialize() diff --git a/tests/infra/__init__.py b/tests/infra/__init__.py new file mode 100644 index 0000000..126f127 --- /dev/null +++ b/tests/infra/__init__.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +# Exposes only what is really needed to write tests, nothing else. 
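+#
+# A minimal usage sketch (illustrative only; assumes the TT PJRT plugin has been
+# built at its default location):
+#
+#   from infra import run_op_test_with_random_inputs
+#
+#   def add(x, y):
+#       return x + y
+#
+#   run_op_test_with_random_inputs(add, [(32, 32), (32, 32)])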
+from .comparison import ComparisonConfig
+from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
+from .model_tester import ModelTester, RunMode
+from .op_tester import run_op_test, run_op_test_with_random_inputs
+from .utils import random_tensor
diff --git a/tests/infra/base_tester.py b/tests/infra/base_tester.py
new file mode 100644
index 0000000..cd4f39e
--- /dev/null
+++ b/tests/infra/base_tester.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import ABC
+from typing import Callable, Sequence
+
+import jax
+
+from .comparison import (
+    ComparisonConfig,
+    compare_allclose,
+    compare_atol,
+    compare_equal,
+    compare_pcc,
+)
+from .device_runner import DeviceRunner
+from .utils import Tensor
+
+
+class BaseTester(ABC):
+    """
+    Abstract base class all testers must inherit from.
+
+    Provides a couple of common methods.
+    """
+
+    def __init__(
+        self, comparison_config: ComparisonConfig = ComparisonConfig()
+    ) -> None:
+        self._comparison_config = comparison_config
+
+    @staticmethod
+    def _compile(executable: Callable) -> Callable:
+        """Sets up `executable` for just-in-time compilation."""
+        return jax.jit(executable)
+
+    def _compare(
+        self,
+        device_out: Tensor,
+        golden_out: Tensor,
+    ) -> None:
+        device_output, golden_output = DeviceRunner.put_tensors_on_cpu(
+            device_out, golden_out
+        )
+        device_output, golden_output = self._match_data_types(
+            device_output, golden_output
+        )
+
+        if self._comparison_config.equal.enabled:
+            compare_equal(device_output, golden_output)
+        if self._comparison_config.atol.enabled:
+            compare_atol(device_output, golden_output, self._comparison_config.atol)
+        if self._comparison_config.pcc.enabled:
+            compare_pcc(device_output, golden_output, self._comparison_config.pcc)
+        if self._comparison_config.allclose.enabled:
+            compare_allclose(
+                device_output, golden_output, self._comparison_config.allclose
+            )
+
+    def _match_data_types(self, *tensors: Tensor) -> Sequence[Tensor]:
+        """
+        Casts all tensors to float32 if not already in that format.
+
+        Tensors need to be in the same data format in order to compare them.
+ """ + return [ + tensor.astype("float32") if tensor.dtype.str != "float32" else tensor + for tensor in tensors + ] diff --git a/tests/infra/comparison.py b/tests/infra/comparison.py new file mode 100644 index 0000000..cb3c88f --- /dev/null +++ b/tests/infra/comparison.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +import jax +import jax.numpy as jnp + +from .device_runner import run_on_cpu +from .utils import Tensor + + +@dataclass +class ConfigBase: + enabled: bool = True + + def enable(self) -> None: + self.enabled = True + + def disable(self) -> None: + self.enabled = False + + +@dataclass +class EqualConfig(ConfigBase): + pass + + +@dataclass +class AtolConfig(ConfigBase): + required_atol: float = 1e-1 + + +@dataclass +class PccConfig(ConfigBase): + required_pcc: float = 0.99 + + +@dataclass +class AllcloseConfig(ConfigBase): + rtol: float = 1e-2 + atol: float = 1e-2 + + +@dataclass +class ComparisonConfig: + equal: EqualConfig = EqualConfig(False) + atol: AtolConfig = AtolConfig() + pcc: PccConfig = PccConfig() + allclose: AllcloseConfig = AllcloseConfig() + + def enable_all(self) -> None: + self.equal.enable() + self.atol.enable() + self.allclose.enable() + self.pcc.enable() + + def disable_all(self) -> None: + self.equal.disable() + self.atol.disable() + self.allclose.disable() + self.pcc.disable() + + +# TODO functions below rely on jax functions, should be generalized for all supported +# frameworks in the future. + + +@run_on_cpu +def compare_equal(device_output: Tensor, golden_output: Tensor) -> None: + assert isinstance(device_output, jax.Array) and isinstance( + golden_output, jax.Array + ), f"Currently only jax.Array is supported" + + eq = (device_output == golden_output).all() + + assert eq, f"Equal comparison failed" + + +@run_on_cpu +def compare_atol( + device_output: Tensor, golden_output: Tensor, atol_config: AtolConfig +) -> None: + assert isinstance(device_output, jax.Array) and isinstance( + golden_output, jax.Array + ), f"Currently only jax.Array is supported {type(device_output)}, {type(golden_output)}" + + atol = jnp.max(jnp.abs(device_output - golden_output)) + + assert ( + atol <= atol_config.required_atol + ), f"Atol comparison failed. Calculated atol={atol}" + + +@run_on_cpu +def compare_pcc( + device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig +) -> None: + assert isinstance(device_output, jax.Array) and isinstance( + golden_output, jax.Array + ), f"Currently only jax.Array is supported" + + # If tensors are really close, pcc will be nan. Handle that before calculating pcc. + try: + compare_allclose( + device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2) + ) + except AssertionError: + pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten()) + pcc = jnp.min(pcc) + + assert ( + pcc >= pcc_config.required_pcc + ), f"PCC comparison failed. Calculated pcc={pcc}" + + +@run_on_cpu +def compare_allclose( + device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig +) -> None: + assert isinstance(device_output, jax.Array) and isinstance( + golden_output, jax.Array + ), f"Currently only jax.Array is supported" + + allclose = jnp.allclose( + device_output, + golden_output, + rtol=allclose_config.rtol, + atol=allclose_config.atol, + ) + + assert allclose, f"Allclose comparison failed." 
diff --git a/tests/infra/device_connector.py b/tests/infra/device_connector.py
new file mode 100644
index 0000000..8dd581b
--- /dev/null
+++ b/tests/infra/device_connector.py
@@ -0,0 +1,121 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import os
+from enum import Enum
+from typing import Sequence
+
+import jax
+import jax._src.xla_bridge as xb
+
+# Relative path to PJRT plugin for TT devices.
+TT_PJRT_PLUGIN_RELPATH = "build/src/tt/pjrt_plugin_tt.so"
+
+
+class DeviceType(Enum):
+    CPU = "cpu"
+    TT = "tt"
+    GPU = "gpu"
+
+
+class DeviceConnector:
+    """
+    Singleton class providing connections to devices on which jax commands will be
+    executed.
+
+    As a singleton it is instantiated only once, thus making sure that the PJRT
+    plugin is registered exactly once. Registration needs to happen before any other
+    jax commands are executed. Registering it multiple times would cause an error.
+
+    Do not instantiate this class directly. Use the provided factory method instead.
+
+    TODO (kmitrovic) see how to make this class a thread safe singleton if needed.
+    """
+
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        # Ensure that only one instance of the class is created.
+        if cls._instance is None:
+            cls._instance = super().__new__(cls, *args, **kwargs)
+
+        return cls._instance
+
+    def __init__(self) -> None:
+        """Don't use directly, use the provided factory method instead."""
+        # We need to ensure __init__ body is executed only once. It will be called
+        # each time `DeviceConnector()` is called.
+        if self.is_initialized():
+            return
+
+        self._initialized = False
+
+        plugin_path = os.path.join(os.getcwd(), TT_PJRT_PLUGIN_RELPATH)
+        if not os.path.exists(plugin_path):
+            raise FileNotFoundError(
+                f"Could not find tt_pjrt C API plugin at {plugin_path}"
+            )
+
+        self._plugin_path = plugin_path
+        self._initialize_backend()
+
+    @staticmethod
+    def get_instance() -> DeviceConnector:
+        """
+        Factory method returning the singleton connector instance.
+
+        Use this method instead of the constructor.
+        """
+        return DeviceConnector()
+
+    def is_initialized(self) -> bool:
+        """Checks if connector is already initialized."""
+        return getattr(self, "_initialized", False)
+
+    def connect_tt_device(self) -> jax.Device:
+        """Returns TTDevice handle."""
+        return self.connect_device(DeviceType.TT)
+
+    def connect_cpu(self) -> jax.Device:
+        """Returns CPUDevice handle."""
+        return self.connect_device(DeviceType.CPU)
+
+    def connect_gpu(self) -> jax.Device:
+        """Returns GPUDevice handle."""
+        return self.connect_device(DeviceType.GPU)
+
+    def connect_device(self, device_type: DeviceType) -> jax.Device:
+        """Returns handle for device identified by `device_type`."""
+        return jax.devices(device_type.value)[0]
+
+    def _supported_devices(self) -> Sequence[DeviceType]:
+        """Returns list of supported device types."""
+        # TODO support GPU
+        return [DeviceType.CPU, DeviceType.TT]
+
+    def _supported_devices_str(self) -> str:
+        """Returns comma separated list of supported devices as a string."""
+        # Note no space, only comma.
+        return ",".join([device.value for device in self._supported_devices()])
+
+    def _initialize_backend(self) -> None:
+        """
+        Registers TT plugin which will make TTDevice available in jax.
+
+        Needs to be called before any other jax command.
+ """ + xb.register_plugin( + DeviceType.TT.value, + priority=500, + library_path=self._plugin_path, + options=None, + ) + jax.config.update("jax_platforms", self._supported_devices_str()) + + self._initialized = True diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py new file mode 100644 index 0000000..031bdd3 --- /dev/null +++ b/tests/infra/device_runner.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Sequence + +import jax + +from .device_connector import DeviceConnector, DeviceType +from .utils import Tensor, Workload + + +class DeviceRunner: + """ + Class providing methods to put and run workload on any supported device. + """ + + @staticmethod + def run_on_tt_device(workload: Workload) -> Tensor: + """Runs `workload` on TT device.""" + return DeviceRunner._run_on_device(DeviceType.TT, workload) + + @staticmethod + def run_on_cpu(workload: Workload) -> Tensor: + """Runs `workload` on CPU.""" + return DeviceRunner._run_on_device(DeviceType.CPU, workload) + + @staticmethod + def run_on_gpu(workload: Workload) -> Tensor: + """Runs `workload` on GPU.""" + raise NotImplementedError("Support for GPUs not implemented") + + @staticmethod + def put_on_tt_device(workload: Workload) -> Workload: + """Puts `workload` on TT device.""" + return DeviceRunner._put_on_device(DeviceType.TT, workload) + + @staticmethod + def put_on_cpu(workload: Workload) -> Workload: + """Puts `workload` on CPU.""" + return DeviceRunner._put_on_device(DeviceType.CPU, workload) + + @staticmethod + def put_on_gpu(workload: Workload) -> Workload: + """Puts `workload` on GPU.""" + raise NotImplementedError("Support for GPUs not implemented") + + @staticmethod + def put_tensors_on_tt_device(*tensors: Tensor) -> Sequence[Tensor]: + """Puts `tensors` on TT device.""" + return DeviceRunner._put_tensors_on_device(DeviceType.TT, tensors) + + @staticmethod + def put_tensors_on_cpu(*tensors: Tensor) -> Sequence[Tensor]: + """Puts `tensors` on CPU.""" + return DeviceRunner._put_tensors_on_device(DeviceType.CPU, tensors) + + @staticmethod + def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]: + """Puts `tensors` on GPU.""" + raise NotImplementedError("Support for GPUs not implemented") + + @staticmethod + def _run_on_device(device_type: DeviceType, workload: Workload) -> Tensor: + """Runs `workload` on device identified by `device_type`.""" + device_workload = DeviceRunner._put_on_device(device_type, workload) + + connector = DeviceConnector().get_instance() + device = connector.connect_device(device_type) + + with jax.default_device(device): + return device_workload.execute() + + @staticmethod + def _put_on_device(device_type: DeviceType, workload: Workload) -> Workload: + """Puts `workload` on device and returns it.""" + connector = DeviceConnector().get_instance() + device = connector.connect_device(device_type) + return DeviceRunner._safely_put_workload_on_device(workload, device) + + @staticmethod + def _put_tensors_on_device( + device_type: DeviceType, tensors: Sequence[Tensor] + ) -> Sequence[Tensor]: + """Puts `tensors` on device identified by `device_type`.""" + connector = DeviceConnector().get_instance() + device = connector.connect_device(device_type) + return [jax.device_put(t, device) for t in tensors] + + @staticmethod + def _safely_put_workload_on_device( + workload: Workload, device: jax.Device + ) -> Workload: + """ + Puts workload's args and kwargs on device only if `jax.device_put` 
supports it + and returns new workload which is "on device". + + `jax.device_put` by docs accepts + ``An array, scalar, or (nested) standard Python container thereof`` + which is too vague and not easy to check. In best case, has to be done + recursively. + + To avoid that, we try to `jax.device_put` arg or kwarg, and if it doesn't + succeed, we leave it as is. + """ + args_on_device = [] + + for arg in workload.args: + try: + arg_on_device = jax.device_put(arg, device) + except: + arg_on_device = arg + + args_on_device.append(arg_on_device) + + kwargs_on_device = {} + + for key, value in workload.kwargs.items(): + try: + value_on_device = jax.device_put(value, device) + except: + value_on_device = value + + kwargs_on_device[key] = value_on_device + + return Workload(workload.executable, args_on_device, kwargs_on_device) + + +# --------------- Convenience decorators --------------- + + +def run_on_tt_device(f: Callable): + """Runs any decorated function `f` on TT device.""" + + def wrapper(*args, **kwargs): + workload = Workload(f, args, kwargs) + return DeviceRunner.run_on_tt_device(workload) + + return wrapper + + +def run_on_cpu(f: Callable): + """Runs any decorated function `f` on CPU.""" + + def wrapper(*args, **kwargs): + workload = Workload(f, args, kwargs) + return DeviceRunner.run_on_cpu(workload) + + return wrapper + + +def run_on_gpu(f: Callable): + """Runs any decorated function `f` on GPU.""" + + def wrapper(*args, **kwargs): + workload = Workload(f, args, kwargs) + return DeviceRunner.run_on_gpu(workload) + + return wrapper diff --git a/tests/infra/graph_tester.py b/tests/infra/graph_tester.py new file mode 100644 index 0000000..4f63700 --- /dev/null +++ b/tests/infra/graph_tester.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Callable, Sequence + +from .comparison import ComparisonConfig +from .op_tester import OpTester +from .utils import Tensor, Workload + + +class GraphTester(OpTester): + """ + Specific tester for graphs. + + Currently same as OpTester. + """ + + pass + + +def run_graph_test( + graph: Callable, + inputs: Sequence[Tensor], + comparison_config: ComparisonConfig = ComparisonConfig(), +) -> None: + """ + Tests `graph` with `inputs` by running it on TT device and CPU and comparing the + results based on `comparison_config`. + """ + tester = GraphTester(comparison_config) + workload = Workload(graph, inputs) + tester.test(workload) + + +def run_graph_test_with_random_inputs( + graph: Callable, + input_shapes: Sequence[tuple], + comparison_config: ComparisonConfig = ComparisonConfig(), +) -> None: + """ + Tests `graph` with random inputs by running it on TT device and CPU and comparing + the results based on `comparison_config`. 
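+
+    Example (a sketch):
+        run_graph_test_with_random_inputs(lambda x, y: x * y + x, [(32, 32)] * 2)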
+ """ + tester = GraphTester(comparison_config) + tester.test_with_random_inputs(graph, input_shapes) diff --git a/tests/infra/model_tester.py b/tests/infra/model_tester.py new file mode 100644 index 0000000..7faf601 --- /dev/null +++ b/tests/infra/model_tester.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Callable, Mapping, Sequence + +from flax import linen, nnx +from transformers.modeling_flax_utils import FlaxPreTrainedModel + +from .base_tester import BaseTester +from .comparison import ComparisonConfig +from .device_runner import DeviceRunner +from .utils import Model, Workload + + +class RunMode(Enum): + INFERENCE = "inference" + TRAINING = "training" + + +class ModelTester(BaseTester, ABC): + """ + Abstract base class all model testers must inherit. + + Derived classes must provide implementations of: + ``` + _get_model() -> Model + _get_input_activations() -> Sequence[Any] + _get_forward_method_name() -> str # Optional, has default behaviour. + # One of or both: + _get_forward_method_args(self) -> Sequence[Any] # Optional, has default behaviour. + _get_forward_method_kwargs(self) -> Mapping[str, Any] # Optional, has default behaviour. + ``` + """ + + def __init__( + self, + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + super().__init__(comparison_config) + + self._run_mode = run_mode + + self._init_model_hooks() + + def _init_model_hooks(self) -> None: + """ + Extracted init method which handles validation of provided interface methods + subclasses must implement and storing of some useful return values. + """ + # Store model instance. + self._model = self._get_model() + + args = self._get_forward_method_args() + kwargs = self._get_forward_method_kwargs() + + if len(args) == 0 and len(kwargs) == 0: + raise ValueError(f"Forward method args or kwargs or both must be provided") + + forward_method_name = self._get_forward_method_name() + + if not hasattr(self._model, forward_method_name): + raise ValueError( + f"Model does not have {forward_method_name} method provided." + ) + + forward_pass_method = getattr(self._model, forward_method_name) + + # Store model's forward pass method and its arguments as a workload. + self._workload = Workload(forward_pass_method, args, kwargs) + + @staticmethod + @abstractmethod + def _get_model() -> Model: + """Returns model instance.""" + raise NotImplementedError("Subclasses must implement this method.") + + @staticmethod + @abstractmethod + def _get_input_activations() -> Sequence[Any]: + """Returns input activations.""" + raise NotImplementedError("Subclasses must implement this method.") + + @staticmethod + def _get_forward_method_name() -> str: + """ + Returns string name of model's forward pass method. + + Returns "__call__" by default which is the most common one. "forward" and + "apply" are also common. + """ + return "__call__" + + def _get_forward_method_args(self) -> Sequence[Any]: + """ + Returns positional arguments for model's forward pass. + + By default returns empty list. + + `self` is provided for convenience, for example if some model attribute needs + to be fetched. + """ + return [] + + def _get_forward_method_kwargs(self) -> Mapping[str, Any]: + """ + Returns keyword arguments for model's forward pass. + + By default returns empty dict. 
+
+        `self` is provided for convenience, for example if some model attribute needs
+        to be fetched.
+        """
+        return {}
+
+    def test(self) -> None:
+        """Tests the model depending on the run mode with which the tester was configured."""
+        if self._run_mode == RunMode.INFERENCE:
+            self._test_inference()
+        else:
+            self._test_training()
+
+    def _test_inference(self) -> None:
+        """
+        Tests the model by running inference on TT device and on CPU and comparing the
+        results.
+        """
+        ModelTester._configure_model_for_inference(self._model)
+
+        compiled_forward_method = self._compile_model()
+
+        compiled_workload = Workload(
+            compiled_forward_method, self._workload.args, self._workload.kwargs
+        )
+
+        tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
+        cpu_res = DeviceRunner.run_on_cpu(compiled_workload)
+
+        self._compare(tt_res, cpu_res)
+
+    def _test_training(self):
+        """TODO"""
+        # self._configure_model_for_training(model)
+        raise NotImplementedError("Support for training not implemented")
+
+    @staticmethod
+    def _configure_model_for_inference(model: Model) -> None:
+        """Configures model for inference."""
+        if isinstance(model, nnx.Module):
+            model.eval()
+        elif isinstance(model, (linen.Module, FlaxPreTrainedModel)):
+            # TODO find another way to do this since model.eval() does not exist, maybe
+            # by passing train param as kwarg to __call__.
+            pass
+        else:
+            raise TypeError(f"Unknown model type: {type(model)}")
+
+    @staticmethod
+    def _configure_model_for_training(model: Model) -> None:
+        """Configures model for training."""
+        if isinstance(model, nnx.Module):
+            model.train()
+        elif isinstance(model, (linen.Module, FlaxPreTrainedModel)):
+            # TODO find another way to do this since model.train() does not exist, maybe
+            # by passing train param as kwarg to __call__.
+            pass
+        else:
+            raise TypeError(f"Unknown model type: {type(model)}")
+
+    def _compile_model(self) -> Callable:
+        """JIT-compiles model's forward pass into optimized kernels."""
+        return super()._compile(self._workload.executable)
diff --git a/tests/infra/op_tester.py b/tests/infra/op_tester.py
new file mode 100644
index 0000000..7122d17
--- /dev/null
+++ b/tests/infra/op_tester.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import Callable, Sequence
+
+from .base_tester import BaseTester
+from .comparison import ComparisonConfig
+from .device_runner import DeviceRunner
+from .utils import Tensor, Workload, random_tensor
+
+
+class OpTester(BaseTester):
+    """Specific tester for ops."""
+
+    def __init__(
+        self, comparison_config: ComparisonConfig = ComparisonConfig()
+    ) -> None:
+        super().__init__(comparison_config)
+
+    def test(self, workload: Workload) -> None:
+        """
+        Runs test by running `workload` on TT device and CPU and comparing the results.
+        """
+        compiled_executable = self._compile(workload.executable)
+
+        compiled_workload = Workload(
+            compiled_executable, workload.args, workload.kwargs
+        )
+        tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
+        cpu_res = DeviceRunner.run_on_cpu(compiled_workload)
+
+        self._compare(tt_res, cpu_res)
+
+    def test_with_random_inputs(
+        self, f: Callable, input_shapes: Sequence[tuple]
+    ) -> None:
+        """
+        Tests `f` by running it with random inputs on TT device and CPU and comparing
+        the results.
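+
+        Example (a sketch):
+            OpTester().test_with_random_inputs(lambda x, y: x + y, [(32, 32)] * 2)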
+ """ + workload = Workload(f, [random_tensor(shape) for shape in input_shapes]) + self.test(workload) + + +def run_op_test( + op: Callable, + inputs: Sequence[Tensor], + comparison_config: ComparisonConfig = ComparisonConfig(), +) -> None: + """ + Tests `op` with `inputs` by running it on TT device and CPU and comparing the + results based on `comparison_config`. + """ + tester = OpTester(comparison_config) + workload = Workload(op, inputs) + tester.test(workload) + + +def run_op_test_with_random_inputs( + op: Callable, + input_shapes: Sequence[tuple], + comparison_config: ComparisonConfig = ComparisonConfig(), +) -> None: + """ + Tests `op` with random inputs by running it on TT device and CPU and comparing the + results based on `comparison_config`. + """ + tester = OpTester(comparison_config) + tester.test_with_random_inputs(op, input_shapes) diff --git a/tests/infra/utils.py b/tests/infra/utils.py new file mode 100644 index 0000000..0cfb44e --- /dev/null +++ b/tests/infra/utils.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Mapping, Optional, Sequence, Union + +import jax +import jax.numpy as jnp +from flax import linen, nnx +from jax import export + + +@dataclass +class Workload: + executable: Callable + args: Sequence[Any] + kwargs: Optional[Mapping[str, Any]] = None + + def __post_init__(self): + # If kwargs is None, initialize it to an empty dictionary. + if self.kwargs is None: + self.kwargs = {} + + def execute(self) -> Any: + return self.executable(*self.args, **self.kwargs) + + +class Framework(Enum): + JAX = "jax" + TORCH = "torch" + NUMPY = "numpy" + + +# Convenience alias. Could be used to represent jax.Array, torch.Tensor, np.ndarray, etc. +Tensor = Union[jax.Array] + +# Convenience alias. Could be used to represent nnx.Module, torch.nn.Module, etc. +# NOTE nnx.Module is the newest API, linen.Module is legacy but it is used in some +# huggingface models. +Model = Union[nnx.Module, linen.Module] + + +def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX): + """Convert a string dtype to the corresponding framework-specific dtype.""" + if framework == Framework.JAX: + return jnp.dtype(dtype_str) + else: + raise ValueError(f"Unsupported framework: {framework.value}.") + + +def random_tensor( + shape: tuple, + dtype: str = "float32", + random_seed: int = 0, + minval: float = 0.0, + maxval: float = 1.0, + framework: Framework = Framework.JAX, +) -> Tensor: + """ + Generates a random tensor of `shape`, `dtype`, and `random_seed` in range + [`minval`, `maxval`) for the desired `framework`. + """ + # Convert dtype string to actual dtype for the selected framework. + dtype_converted = _str_to_dtype(dtype, framework) + + # Generate random tensor based on framework type + if framework == Framework.JAX: + prng_key = jax.random.PRNGKey(random_seed) + + return jax.random.uniform( + key=prng_key, + shape=shape, + dtype=dtype_converted, + minval=minval, + maxval=maxval, + ) + else: + raise ValueError(f"Unsupported framework: {framework.value}.") + + +def workload_as_mlir_module( + workload: Workload, framework: Framework = Framework.JAX +) -> str: + """ + Returns workload as mlir module string. + + Note that in case of jax, workload.executable must be the result of jit, otherwise + empty string will be returned. 
+ """ + + if framework == Framework.JAX: + try: + s = export.export(workload.executable)( + *workload.args, **workload.kwargs + ).mlir_module() + + # Remove all lines that start with "#loc" for cleaner output. + return "\n".join( + line for line in s.splitlines() if not line.startswith("#loc") + ) + + except ValueError: + return "" + else: + raise ValueError(f"Unsupported framework: {framework.value}.") diff --git a/tests/jax/graphs/test_example_graph.py b/tests/jax/graphs/test_example_graph.py new file mode 100644 index 0000000..b76d897 --- /dev/null +++ b/tests/jax/graphs/test_example_graph.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import jax +import pytest +from infra import run_graph_test_with_random_inputs +from jax import numpy as jnp + + +def example_graph(x: jax.Array, y: jax.Array) -> jax.Array: + a = jnp.abs(x) + b = jnp.add(a, y) + c = jnp.divide(a, b) + return jnp.exp(c) + + +@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(3, 3), (3, 3)], + [(32, 32), (32, 32)], + [(64, 64), (64, 64)], + ], +) +def test_example_graph(x_shape: tuple, y_shape: tuple): + run_graph_test_with_random_inputs(example_graph, [x_shape, y_shape]) diff --git a/tests/jax/models/example_model/__init__.py b/tests/jax/models/example_model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/jax/models/example_model/mixed_args_and_kwargs/__init__.py b/tests/jax/models/example_model/mixed_args_and_kwargs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py new file mode 100644 index 0000000..3f4c0fb --- /dev/null +++ b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Sequence + +import jax +import pytest +from flax import nnx +from infra import ComparisonConfig, ModelTester, RunMode + +from ..model import ExampleModel + +# ----- Tester ----- + + +class ExampleModelMixedArgsAndKwargsTester(ModelTester): + """ + Example tester showcasing how to use both positional and keyword arguments for + model's forward method. + + This is a completely artificial example. In most cases only one of + {`_get_forward_method_args`, `_get_forward_method_kwargs`} will suffice. 
+ """ + + # @override + @staticmethod + def _get_model() -> nnx.Module: + return ExampleModel() + + # @override + @staticmethod + def _get_input_activations() -> Sequence[jax.Array]: + act_shape = (32, 784) + act = jax.numpy.ones(act_shape) + return [act] + + # @override + @staticmethod + def _get_forward_method_name() -> str: + return "__call__" + + # @override + def _get_forward_method_args(self) -> Sequence[jax.Array]: + """Returns just input activations as positional arg.""" + input_activations = self._get_input_activations() + assert len(input_activations) == 1 + input_activation = input_activations[0] + return [input_activation] + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + """Returns weights and biases as keyword args.""" + assert hasattr(self._model, "w0") + assert hasattr(self._model, "w1") + assert hasattr(self._model, "b0") + assert hasattr(self._model, "b1") + + w0 = self._model.w0 + w1 = self._model.w1 + b0 = self._model.b0 + b1 = self._model.b1 + + # Order does not matter. + return {"b1": b1, "w1": w1, "w0": w0, "b0": b0} + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> ExampleModelMixedArgsAndKwargsTester: + return ExampleModelMixedArgsAndKwargsTester() + + +@pytest.fixture +def training_tester() -> ExampleModelMixedArgsAndKwargsTester: + return ExampleModelMixedArgsAndKwargsTester(RunMode.TRAINING) + + +# ----- Tests ----- + + +def test_example_model_inference( + inference_tester: ExampleModelMixedArgsAndKwargsTester, +): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_example_model_training(training_tester: ExampleModelMixedArgsAndKwargsTester): + training_tester.test() diff --git a/tests/jax/models/example_model/model.py b/tests/jax/models/example_model/model.py new file mode 100644 index 0000000..f3a4f33 --- /dev/null +++ b/tests/jax/models/example_model/model.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import jax +import jax.numpy as jnp +from flax import nnx +from infra import random_tensor + + +class ExampleModel(nnx.Module): + def __init__(self) -> None: + w0_shape, w1_shape, b0_shape, b1_shape = ( + (784, 128), + (128, 128), + (1, 128), + (1, 128), + ) + + self.w0 = random_tensor(w0_shape, minval=-0.01, maxval=0.01) + self.w1 = random_tensor(w1_shape, minval=-0.01, maxval=0.01) + self.b0 = random_tensor(b0_shape, minval=-0.01, maxval=0.01) + self.b1 = random_tensor(b1_shape, minval=-0.01, maxval=0.01) + + def __call__( + self, act: jax.Array, w0: jax.Array, b0: jax.Array, w1: jax.Array, b1: jax.Array + ) -> jax.Array: + # Note how activations, weights and biases are directly passed to the forward + # method as inputs, `self` is not accessed. Otherwise they would be embedded + # into jitted graph as constants. 
+ x = jnp.matmul(act, w0) + b0 + x = jnp.matmul(x, w1) + b1 + return x diff --git a/tests/jax/models/example_model/only_args/__init__.py b/tests/jax/models/example_model/only_args/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/jax/models/example_model/only_args/test_example_model_only_args.py b/tests/jax/models/example_model/only_args/test_example_model_only_args.py new file mode 100644 index 0000000..4beed15 --- /dev/null +++ b/tests/jax/models/example_model/only_args/test_example_model_only_args.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Sequence + +import jax +import pytest +from flax import nnx +from infra import ComparisonConfig, ModelTester, RunMode + +from ..model import ExampleModel + +# ----- Tester ----- + + +class ExampleModelOnlyArgsTester(ModelTester): + """ + Example tester showcasing how to use only positional arguments for model's forward + method. + """ + + # @override + @staticmethod + def _get_model() -> nnx.Module: + return ExampleModel() + + # @override + @staticmethod + def _get_input_activations() -> Sequence[jax.Array]: + act_shape = (32, 784) + act = jax.numpy.ones(act_shape) + return [act] + + # @override + @staticmethod + def _get_forward_method_name() -> str: + return "__call__" + + # @override + def _get_forward_method_args(self) -> Sequence[jax.Array]: + # Use stored `self._model` to fetch model attributes. + # Asserts are just sanity checks, no need to use them every time. + assert hasattr(self._model, "w0") + assert hasattr(self._model, "w1") + assert hasattr(self._model, "b0") + assert hasattr(self._model, "b1") + + w0 = self._model.w0 + w1 = self._model.w1 + b0 = self._model.b0 + b1 = self._model.b1 + + # Fetch activations. + input_activations = self._get_input_activations() + assert len(input_activations) == 1 + input_activation = input_activations[0] + + # Mix activations, weights and biases to match forward method signature. 
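+        # Order matters: it must match ExampleModel.__call__(act, w0, b0, w1, b1).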
+ return [input_activation, w0, b0, w1, b1] + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> ExampleModelOnlyArgsTester: + return ExampleModelOnlyArgsTester() + + +@pytest.fixture +def training_tester() -> ExampleModelOnlyArgsTester: + return ExampleModelOnlyArgsTester(RunMode.TRAINING) + + +# ----- Tests ----- + + +def test_example_model_inference(inference_tester: ExampleModelOnlyArgsTester): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_example_model_training(training_tester: ExampleModelOnlyArgsTester): + training_tester.test() diff --git a/tests/jax/models/example_model/only_kwargs/__init__.py b/tests/jax/models/example_model/only_kwargs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py new file mode 100644 index 0000000..99f2916 --- /dev/null +++ b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Sequence + +import jax +import pytest +from flax import nnx +from infra import ComparisonConfig, ModelTester, RunMode + +from ..model import ExampleModel + +# ----- Tester ----- + + +class ExampleModelOnlyKwargsTester(ModelTester): + """ + Example tester showcasing how to use only keyword arguments for model's forward + method. + """ + + # @override + @staticmethod + def _get_model() -> nnx.Module: + return ExampleModel() + + # @override + @staticmethod + def _get_input_activations() -> Sequence[jax.Array]: + act_shape = (32, 784) + act = jax.numpy.ones(act_shape) + return [act] + + # @override + @staticmethod + def _get_forward_method_name() -> str: + return "__call__" + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + # Use stored `self._model` to fetch model attributes. + # Asserts are just sanity checks, no need to use them every time. + assert hasattr(self._model, "w0") + assert hasattr(self._model, "w1") + assert hasattr(self._model, "b0") + assert hasattr(self._model, "b1") + + w0 = self._model.w0 + w1 = self._model.w1 + b0 = self._model.b0 + b1 = self._model.b1 + + # Fetch activations. + input_activations = self._get_input_activations() + assert len(input_activations) == 1 + input_activation = input_activations[0] + + # Mix activations, weights and biases to match forward method signature. 
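+        # Keys must match the parameter names of ExampleModel.__call__; unlike
+        # positional args, order is irrelevant here.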
+ return {"act": input_activation, "w0": w0, "b0": b0, "w1": w1, "b1": b1} + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> ExampleModelOnlyKwargsTester: + return ExampleModelOnlyKwargsTester() + + +@pytest.fixture +def training_tester() -> ExampleModelOnlyKwargsTester: + return ExampleModelOnlyKwargsTester(RunMode.TRAINING) + + +# ----- Tests ----- + + +def test_example_model_inference(inference_tester: ExampleModelOnlyKwargsTester): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_example_model_training(training_tester: ExampleModelOnlyKwargsTester): + training_tester.test() diff --git a/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py b/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py new file mode 100644 index 0000000..3b5ad6b --- /dev/null +++ b/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Sequence + +import jax +import pytest +from flax import linen as nn +from infra import ComparisonConfig, ModelTester, RunMode +from transformers import AutoTokenizer, FlaxDistilBertForMaskedLM + +MODEL = "distilbert/distilbert-base-uncased" + +# ----- Tester ----- + + +class FlaxDistilBertForMaskedLMTester(ModelTester): + """Tester for DistilBert model with a `language modeling` head on top.""" + + # @override + @staticmethod + def _get_model() -> nn.Module: + return FlaxDistilBertForMaskedLM.from_pretrained(MODEL) + + # @override + @staticmethod + def _get_forward_method_name() -> str: + return "__call__" + + # @override + @staticmethod + def _get_input_activations() -> Sequence[jax.Array]: + tokenizer = AutoTokenizer.from_pretrained(MODEL) + inputs = tokenizer("Hello [MASK].", return_tensors="np") + return [inputs["input_ids"]] + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + input_activations = self._get_input_activations() + + assert len(input_activations) == 1 + assert hasattr(self._model, "params") + + return {"input_ids": input_activations[0], "params": self._model.params} + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxDistilBertForMaskedLMTester: + return FlaxDistilBertForMaskedLMTester() + + +@pytest.fixture +def training_tester() -> FlaxDistilBertForMaskedLMTester: + return FlaxDistilBertForMaskedLMTester(RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'") +def test_flax_distil_bert_for_masked_lm_inference( + inference_tester: FlaxDistilBertForMaskedLMTester, +): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_distil_bert_for_masked_lm_training( + training_tester: FlaxDistilBertForMaskedLMTester, +): + training_tester.test() diff --git a/tests/jax/ops/test_add.py b/tests/jax/ops/test_add.py new file mode 100644 index 0000000..0e41353 --- /dev/null +++ b/tests/jax/ops/test_add.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import jax +import pytest +from infra import run_op_test_with_random_inputs + + +def add(x: jax.Array, y: jax.Array) -> jax.Array: + return jax.numpy.add(x, y) + + +@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(32, 32), (32, 32)], + [(64, 64), 
(64, 64)], + ], +) +def test_add(x_shape: tuple, y_shape: tuple): + run_op_test_with_random_inputs(add, [x_shape, y_shape]) diff --git a/venv/activate b/venv/activate index 81e6e1b..3641d1a 100644 --- a/venv/activate +++ b/venv/activate @@ -18,17 +18,14 @@ else pip install --upgrade pip pip install -r requirements.txt fi + export TTXLA_ENV_ACTIVATED=1 export TTMLIR_ENV_ACTIVATED=1 - export PATH=$TTMLIR_TOOLCHAIN_DIR/bin:$PATH - if [ -n "$PROJECT_ROOT" ]; then - export TT_METAL_HOME="$PROJECT_ROOT/third_party/tt-mlir/src/tt-mlir/third_party/tt-metal/src/tt-metal" - else - export TT_METAL_HOME="$(pwd)/third_party/tt-mlir/src/tt-mlir/third_party/tt-metal/src/tt-metal" - fi - export TT_MLIR_HOME="$(pwd)" - export PYTHONPATH="$(pwd)/build/python_packages:$(pwd)/.local/toolchain/python_packages/mlir_core:${TT_METAL_HOME}:${TT_METAL_HOME}/tt_eager:${TT_METAL_BUILD_HOME}/tools/profiler/bin" - export ARCH_NAME="${ARCH_NAME:-wormhole_b0}" export TT_METAL_LOGGER_LEVEL="ERROR" + export ARCH_NAME="${ARCH_NAME:-wormhole_b0}" + export PATH=$TTMLIR_TOOLCHAIN_DIR/bin:$PATH + export PYTHONPATH="$(pwd):$(pwd)/tests" + export TT_MLIR_HOME="$(pwd)/third_party/tt-mlir/src/tt-mlir/" + export TT_METAL_HOME="$TT_MLIR_HOME/third_party/tt-metal/src/tt-metal" fi