From bfae725dfd0253595e487dd914f0be80b0b11090 Mon Sep 17 00:00:00 2001
From: Kristijan Mitrovic
Date: Thu, 19 Dec 2024 13:33:24 +0000
Subject: [PATCH] Code review changes

---
 src/common/api_impl.h                         |  3 +-
 tests/TTIR/test_device.py                     |  6 +-
 tests/infra/__init__.py                       |  1 +
 tests/infra/comparison.py                     |  5 ++
 tests/infra/device_runner.py                  | 21 ++++++-
 tests/infra/model_tester.py                   |  2 +-
 tests/infra/utils.py                          | 59 ++++++++++++-------
 ...rary_op_chain.py => test_example_graph.py} |  6 +-
 ...st_example_model_mixed_args_and_kwargs.py} | 31 ++++------
 tests/jax/models/example_model/model.py       | 12 ++--
 .../only_args/test_example_model_only_args.py | 29 ++++-----
 .../test_example_model_only_kwargs.py         | 29 ++++-----
 .../test_flax_distil_bert_for_masked_lm.py    | 26 +++-----
 13 files changed, 122 insertions(+), 108 deletions(-)
 rename tests/jax/graphs/{test_arbitrary_op_chain.py => test_example_graph.py} (68%)
 rename tests/jax/models/example_model/mixed_args_and_kwargs/{test_example_model_args_and_kwargs.py => test_example_model_mixed_args_and_kwargs.py} (71%)

diff --git a/src/common/api_impl.h b/src/common/api_impl.h
index 873744e..73de56b 100644
--- a/src/common/api_impl.h
+++ b/src/common/api_impl.h
@@ -161,8 +161,9 @@ class DeviceDescription {
 
 private:
   int client_id_;
+  // TODO We should understand better how these are used.
+  // See https://github.com/tenstorrent/tt-xla/issues/125
   std::string kind_string_ = "TTDevice";
-  // TODO should not be hardcoded
   std::string arch_string_ = "Wormhole";
   std::string user_string_ = "";
 };
diff --git a/tests/TTIR/test_device.py b/tests/TTIR/test_device.py
index 5921d25..72da8a9 100644
--- a/tests/TTIR/test_device.py
+++ b/tests/TTIR/test_device.py
@@ -2,10 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import pytest
 import jax
 import jax.numpy as jnp
-
 from infrastructure import random_input_tensor
 
 
@@ -16,9 +14,9 @@ def test_num_devices():
 
 def test_to_device():
     cpu_array = random_input_tensor((32, 32))
-    device = jax.devices()[0]
+    device = jax.devices("tt")[0]
     tt_array = jax.device_put(cpu_array, device)
-    assert tt_array.device.device_kind == "wormhole"
+    assert tt_array.device.device_kind == "TTDevice"
 
 
 def test_input_on_device():
diff --git a/tests/infra/__init__.py b/tests/infra/__init__.py
index e1cf44a..126f127 100644
--- a/tests/infra/__init__.py
+++ b/tests/infra/__init__.py
@@ -7,3 +7,4 @@
 from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
 from .model_tester import ModelTester, RunMode
 from .op_tester import run_op_test, run_op_test_with_random_inputs
+from .utils import random_tensor
diff --git a/tests/infra/comparison.py b/tests/infra/comparison.py
index e412828..cb3c88f 100644
--- a/tests/infra/comparison.py
+++ b/tests/infra/comparison.py
@@ -9,6 +9,7 @@
 import jax
 import jax.numpy as jnp
 
+from .device_runner import run_on_cpu
 from .utils import Tensor
 
 
@@ -68,6 +69,7 @@ def disable_all(self) -> None:
 # frameworks in the future.
 
 
+@run_on_cpu
 def compare_equal(device_output: Tensor, golden_output: Tensor) -> None:
     assert isinstance(device_output, jax.Array) and isinstance(
         golden_output, jax.Array
@@ -78,6 +80,7 @@ def compare_equal(device_output: Tensor, golden_output: Tensor) -> None:
     assert eq, f"Equal comparison failed"
 
 
+@run_on_cpu
 def compare_atol(
    device_output: Tensor, golden_output: Tensor, atol_config: AtolConfig
 ) -> None:
@@ -92,6 +95,7 @@ def compare_atol(
     ), f"Atol comparison failed. Calculated atol={atol}"
 
 
+@run_on_cpu
 def compare_pcc(
     device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
 ) -> None:
@@ -113,6 +117,7 @@
     ), f"PCC comparison failed. Calculated pcc={pcc}"
 
 
+@run_on_cpu
 def compare_allclose(
     device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
 ) -> None:
diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py
index 5d77952..45328e3 100644
--- a/tests/infra/device_runner.py
+++ b/tests/infra/device_runner.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Sequence
+from typing import Callable, Sequence
 
 import jax
 
@@ -124,3 +124,22 @@ def _safely_put_workload_on_device(
             kwargs_on_device[key] = value_on_device
 
         return Workload(workload.executable, args_on_device, kwargs_on_device)
+
+
+# --------------- Convenience decorators ---------------
+
+
+def run_on_cpu(f: Callable):
+    def wrapper(*args, **kwargs):
+        workload = Workload(f, args, kwargs)
+        return DeviceRunner.run_on_cpu(workload)
+
+    return wrapper
+
+
+def run_on_tt_device(f: Callable):
+    def wrapper(*args, **kwargs):
+        workload = Workload(f, args, kwargs)
+        return DeviceRunner.run_on_tt_device(workload)
+
+    return wrapper
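
[Editor's note - illustrative sketch, not part of the patch: how these new
decorators appear to be meant to be used. The wrapped callable is packed into a
Workload and dispatched through DeviceRunner, which is also why the comparison
functions above gained @run_on_cpu. The `infra` import path is an assumption
based on tests/infra/__init__.py.]

    from infra.device_runner import run_on_cpu

    @run_on_cpu
    def add(x, y):
        # Packed into Workload(add, args, kwargs) and executed via
        # DeviceRunner.run_on_cpu.
        return x + y
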
diff --git a/tests/infra/model_tester.py b/tests/infra/model_tester.py
index 9f4c20e..7faf601 100644
--- a/tests/infra/model_tester.py
+++ b/tests/infra/model_tester.py
@@ -39,7 +39,7 @@ class ModelTester(BaseTester, ABC):
 
     def __init__(
         self,
-        comparison_config: ComparisonConfig,
+        comparison_config: ComparisonConfig = ComparisonConfig(),
         run_mode: RunMode = RunMode.INFERENCE,
     ) -> None:
         super().__init__(comparison_config)
diff --git a/tests/infra/utils.py b/tests/infra/utils.py
index 0f9b8c8..0cfb44e 100644
--- a/tests/infra/utils.py
+++ b/tests/infra/utils.py
@@ -26,23 +26,6 @@ def __post_init__(self):
     def execute(self) -> Any:
         return self.executable(*self.args, **self.kwargs)
 
-    def as_mlir_module(self) -> str:
-        """
-        Returns workload as mlir module string.
-
-        Note that workload.executable must be the result of jit, otherwise empty string
-        will be returned.
-        """
-        try:
-            s = export.export(self.executable)(*self.args, **self.kwargs).mlir_module()
-            # Remove all lines that start with "#loc" for cleaner output.
-            return "\n".join(
-                line for line in s.splitlines() if not line.startswith("#loc")
-            )
-
-        except ValueError:
-            return ""
-
 
 class Framework(Enum):
     JAX = "jax"
@@ -71,11 +54,13 @@ def random_tensor(
     shape: tuple,
     dtype: str = "float32",
     random_seed: int = 0,
+    minval: float = 0.0,
+    maxval: float = 1.0,
     framework: Framework = Framework.JAX,
 ) -> Tensor:
     """
-    Generates a random tensor of `shape`, `dtype`, and `random_seed` for the desired
-    `framework`.
+    Generates a random tensor of `shape`, `dtype`, and `random_seed` in the range
+    [`minval`, `maxval`) for the desired `framework`.
     """
     # Convert dtype string to actual dtype for the selected framework.
     dtype_converted = _str_to_dtype(dtype, framework)
@@ -83,6 +68,40 @@
     # Generate random tensor based on framework type
     if framework == Framework.JAX:
         prng_key = jax.random.PRNGKey(random_seed)
-        return jax.random.uniform(key=prng_key, shape=shape, dtype=dtype_converted)
+
+        return jax.random.uniform(
+            key=prng_key,
+            shape=shape,
+            dtype=dtype_converted,
+            minval=minval,
+            maxval=maxval,
+        )
     else:
         raise ValueError(f"Unsupported framework: {framework.value}.")
+
+
+def workload_as_mlir_module(
+    workload: Workload, framework: Framework = Framework.JAX
+) -> str:
+    """
+    Returns workload as an MLIR module string.
+
+    Note that in the case of JAX, workload.executable must be the result of jit;
+    otherwise an empty string will be returned.
+    """
+
+    if framework == Framework.JAX:
+        try:
+            s = export.export(workload.executable)(
+                *workload.args, **workload.kwargs
+            ).mlir_module()
+
+            # Remove all lines that start with "#loc" for cleaner output.
+            return "\n".join(
+                line for line in s.splitlines() if not line.startswith("#loc")
+            )
+
+        except ValueError:
+            return ""
+    else:
+        raise ValueError(f"Unsupported framework: {framework.value}.")
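
[Editor's note - illustrative sketch, not part of the patch: rough usage of the
reworked helpers. `random_tensor` now takes a [minval, maxval) range, and the old
Workload.as_mlir_module method became the free function workload_as_mlir_module
taking the workload explicitly. Import paths are assumptions based on
tests/infra/__init__.py.]

    import jax
    import jax.numpy as jnp
    from infra import random_tensor
    from infra.utils import Framework, Workload, workload_as_mlir_module

    x = random_tensor((32, 32), minval=-0.01, maxval=0.01)  # uniform in [-0.01, 0.01)
    workload = Workload(jax.jit(jnp.abs), [x], {})
    print(workload_as_mlir_module(workload, Framework.JAX))
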
diff --git a/tests/jax/graphs/test_arbitrary_op_chain.py b/tests/jax/graphs/test_example_graph.py
similarity index 68%
rename from tests/jax/graphs/test_arbitrary_op_chain.py
rename to tests/jax/graphs/test_example_graph.py
index 317244e..b76d897 100644
--- a/tests/jax/graphs/test_arbitrary_op_chain.py
+++ b/tests/jax/graphs/test_example_graph.py
@@ -8,7 +8,7 @@
 from jax import numpy as jnp
 
 
-def arbitrary_op_chain(x: jax.Array, y: jax.Array) -> jax.Array:
+def example_graph(x: jax.Array, y: jax.Array) -> jax.Array:
     a = jnp.abs(x)
     b = jnp.add(a, y)
     c = jnp.divide(a, b)
@@ -23,5 +23,5 @@ def arbitrary_op_chain(x: jax.Array, y: jax.Array) -> jax.Array:
         [(64, 64), (64, 64)],
     ],
 )
-def test_arbitrary_op_chain(x_shape: tuple, y_shape: tuple):
-    run_graph_test_with_random_inputs(arbitrary_op_chain, [x_shape, y_shape])
+def test_example_graph(x_shape: tuple, y_shape: tuple):
+    run_graph_test_with_random_inputs(example_graph, [x_shape, y_shape])
diff --git a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_args_and_kwargs.py b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py
similarity index 71%
rename from tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_args_and_kwargs.py
rename to tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py
index 22d7893..3f4c0fb 100644
--- a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_args_and_kwargs.py
+++ b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py
@@ -14,7 +14,7 @@
 # ----- Tester -----
 
 
-class ExampleModelTester(ModelTester):
+class ExampleModelMixedArgsAndKwargsTester(ModelTester):
     """
     Example tester showcasing how to use both positional and keyword arguments for
     model's forward method.
@@ -43,10 +43,10 @@ def _get_forward_method_name() -> str:
     # @override
     def _get_forward_method_args(self) -> Sequence[jax.Array]:
         """Returns just input activations as positional arg."""
-        acts = self._get_input_activations()
-        assert len(acts) == 1
-        act = acts[0]
-        return [act]
+        input_activations = self._get_input_activations()
+        assert len(input_activations) == 1
+        input_activation = input_activations[0]
+        return [input_activation]
 
     # @override
     def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
@@ -69,29 +69,24 @@ def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
 
 
 @pytest.fixture
-def comparison_config() -> ComparisonConfig:
-    config = ComparisonConfig()
-    config.atol.disable()
-    return config
-
-
-@pytest.fixture
-def inference_tester(comparison_config: ComparisonConfig) -> ExampleModelTester:
-    return ExampleModelTester(comparison_config)
+def inference_tester() -> ExampleModelMixedArgsAndKwargsTester:
+    return ExampleModelMixedArgsAndKwargsTester()
 
 
 @pytest.fixture
-def training_tester(comparison_config: ComparisonConfig) -> ExampleModelTester:
-    return ExampleModelTester(comparison_config, RunMode.TRAINING)
+def training_tester() -> ExampleModelMixedArgsAndKwargsTester:
+    return ExampleModelMixedArgsAndKwargsTester(RunMode.TRAINING)
 
 
 # ----- Tests -----
 
 
-def test_example_model_inference(inference_tester: ExampleModelTester):
+def test_example_model_inference(
+    inference_tester: ExampleModelMixedArgsAndKwargsTester,
+):
     inference_tester.test()
 
 
 @pytest.mark.skip(reason="Support for training not implemented")
-def test_example_model_training(training_tester: ExampleModelTester):
+def test_example_model_training(training_tester: ExampleModelMixedArgsAndKwargsTester):
     training_tester.test()
diff --git a/tests/jax/models/example_model/model.py b/tests/jax/models/example_model/model.py
index ea5f4ea..f3a4f33 100644
--- a/tests/jax/models/example_model/model.py
+++ b/tests/jax/models/example_model/model.py
@@ -5,6 +5,7 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
+from infra import random_tensor
 
 
 class ExampleModel(nnx.Module):
@@ -16,16 +17,17 @@ def __init__(self) -> None:
             (1, 128),
         )
 
-        self.w0 = jax.numpy.ones(w0_shape)
-        self.w1 = jax.numpy.ones(w1_shape)
-        self.b0 = jax.numpy.ones(b0_shape)
-        self.b1 = jax.numpy.zeros(b1_shape)
+        self.w0 = random_tensor(w0_shape, minval=-0.01, maxval=0.01)
+        self.w1 = random_tensor(w1_shape, minval=-0.01, maxval=0.01)
+        self.b0 = random_tensor(b0_shape, minval=-0.01, maxval=0.01)
+        self.b1 = random_tensor(b1_shape, minval=-0.01, maxval=0.01)
 
     def __call__(
         self, act: jax.Array, w0: jax.Array, b0: jax.Array, w1: jax.Array, b1: jax.Array
     ) -> jax.Array:
         # Note how activations, weights and biases are directly passed to the forward
-        # method. `self` is not accessed.
+        # method as inputs; `self` is not accessed. Otherwise they would be embedded
+        # into the jitted graph as constants.
         x = jnp.matmul(act, w0) + b0
         x = jnp.matmul(x, w1) + b1
         return x
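
[Editor's note - illustrative sketch, not part of the patch: the comment above is
the key design point. Under jax.jit, values captured from `self` or a closure are
traced as compile-time constants and baked into the graph, while values passed as
arguments remain runtime inputs of the compiled graph.]

    import jax
    import jax.numpy as jnp

    w = jnp.ones((4, 4))

    @jax.jit
    def f_captured(x):
        return x @ w  # w is closed over: embedded in the jitted graph as a constant

    @jax.jit
    def f_as_input(x, w):
        return x @ w  # w is an argument: remains a runtime input to the graph
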
diff --git a/tests/jax/models/example_model/only_args/test_example_model_only_args.py b/tests/jax/models/example_model/only_args/test_example_model_only_args.py
index 44b492d..4beed15 100644
--- a/tests/jax/models/example_model/only_args/test_example_model_only_args.py
+++ b/tests/jax/models/example_model/only_args/test_example_model_only_args.py
@@ -14,7 +14,7 @@
 # ----- Tester -----
 
 
-class ExampleModelTester(ModelTester):
+class ExampleModelOnlyArgsTester(ModelTester):
     """
     Example tester showcasing how to use only positional arguments for model's forward
     method.
@@ -52,41 +52,34 @@ def _get_forward_method_args(self) -> Sequence[jax.Array]:
         b1 = self._model.b1
 
         # Fetch activations.
-        acts = self._get_input_activations()
-        assert len(acts) == 1
-        act = acts[0]
+        input_activations = self._get_input_activations()
+        assert len(input_activations) == 1
+        input_activation = input_activations[0]
 
         # Mix activations, weights and biases to match forward method signature.
-        return [act, w0, b0, w1, b1]
+        return [input_activation, w0, b0, w1, b1]
 
 
 # ----- Fixtures -----
 
 
 @pytest.fixture
-def comparison_config() -> ComparisonConfig:
-    config = ComparisonConfig()
-    config.atol.disable()
-    return config
+def inference_tester() -> ExampleModelOnlyArgsTester:
+    return ExampleModelOnlyArgsTester()
 
 
 @pytest.fixture
-def inference_tester(comparison_config: ComparisonConfig) -> ExampleModelTester:
-    return ExampleModelTester(comparison_config)
-
-
-@pytest.fixture
-def training_tester(comparison_config: ComparisonConfig) -> ExampleModelTester:
-    return ExampleModelTester(comparison_config, RunMode.TRAINING)
+def training_tester() -> ExampleModelOnlyArgsTester:
+    return ExampleModelOnlyArgsTester(RunMode.TRAINING)
 
 
 # ----- Tests -----
 
 
-def test_example_model_inference(inference_tester: ExampleModelTester):
+def test_example_model_inference(inference_tester: ExampleModelOnlyArgsTester):
     inference_tester.test()
 
 
 @pytest.mark.skip(reason="Support for training not implemented")
-def test_example_model_training(training_tester: ExampleModelTester):
+def test_example_model_training(training_tester: ExampleModelOnlyArgsTester):
     training_tester.test()
diff --git a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py
index 65442e4..99f2916 100644
--- a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py
+++ b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py
@@ -14,7 +14,7 @@
 # ----- Tester -----
 
 
-class ExampleModelTester(ModelTester):
+class ExampleModelOnlyKwargsTester(ModelTester):
     """
     Example tester showcasing how to use only keyword arguments for model's forward
     method.
@@ -52,41 +52,34 @@ def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
         b1 = self._model.b1
 
         # Fetch activations.
-        acts = self._get_input_activations()
-        assert len(acts) == 1
-        act = acts[0]
+        input_activations = self._get_input_activations()
+        assert len(input_activations) == 1
+        input_activation = input_activations[0]
 
         # Mix activations, weights and biases to match forward method signature.
- return {"act": act, "w0": w0, "b0": b0, "w1": w1, "b1": b1} + return {"act": input_activation, "w0": w0, "b0": b0, "w1": w1, "b1": b1} # ----- Fixtures ----- @pytest.fixture -def comparison_config() -> ComparisonConfig: - config = ComparisonConfig() - config.atol.disable() - return config +def inference_tester() -> ExampleModelOnlyKwargsTester: + return ExampleModelOnlyKwargsTester() @pytest.fixture -def inference_tester(comparison_config: ComparisonConfig) -> ExampleModelTester: - return ExampleModelTester(comparison_config) - - -@pytest.fixture -def training_tester(comparison_config: ComparisonConfig) -> ExampleModelTester: - return ExampleModelTester(comparison_config, RunMode.TRAINING) +def training_tester() -> ExampleModelOnlyKwargsTester: + return ExampleModelOnlyKwargsTester(RunMode.TRAINING) # ----- Tests ----- -def test_example_model_inference(inference_tester: ExampleModelTester): +def test_example_model_inference(inference_tester: ExampleModelOnlyKwargsTester): inference_tester.test() @pytest.mark.skip(reason="Support for training not implemented") -def test_example_model_training(training_tester: ExampleModelTester): +def test_example_model_training(training_tester: ExampleModelOnlyKwargsTester): training_tester.test() diff --git a/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py b/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py index 2854386..3b5ad6b 100644 --- a/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py +++ b/tests/jax/models/flax_distil_bert_for_masked_lm/test_flax_distil_bert_for_masked_lm.py @@ -37,37 +37,25 @@ def _get_input_activations() -> Sequence[jax.Array]: # @override def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: - activations = self._get_input_activations() + input_activations = self._get_input_activations() - assert len(activations) == 1 + assert len(input_activations) == 1 assert hasattr(self._model, "params") - return {"input_ids": activations[0], "params": self._model.params} + return {"input_ids": input_activations[0], "params": self._model.params} # ----- Fixtures ----- @pytest.fixture -def comparison_config() -> ComparisonConfig: - config = ComparisonConfig() - config.disable_all() - config.pcc.enable() - return config +def inference_tester() -> FlaxDistilBertForMaskedLMTester: + return FlaxDistilBertForMaskedLMTester() @pytest.fixture -def inference_tester( - comparison_config: ComparisonConfig, -) -> FlaxDistilBertForMaskedLMTester: - return FlaxDistilBertForMaskedLMTester(comparison_config) - - -@pytest.fixture -def training_tester( - comparison_config: ComparisonConfig, -) -> FlaxDistilBertForMaskedLMTester: - return FlaxDistilBertForMaskedLMTester(comparison_config, RunMode.TRAINING) +def training_tester() -> FlaxDistilBertForMaskedLMTester: + return FlaxDistilBertForMaskedLMTester(RunMode.TRAINING) # ----- Tests -----