Skip to content

Commit

Permalink
Code review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Dec 19, 2024
1 parent 4d23305 commit bfae725
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 108 deletions.
3 changes: 2 additions & 1 deletion src/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ class DeviceDescription {

private:
int client_id_;
// TODO We should understand better how these are used.
// See https://github.com/tenstorrent/tt-xla/issues/125
std::string kind_string_ = "TTDevice";
// TODO should not be hardcoded
std::string arch_string_ = "Wormhole";
std::string user_string_ = "";
};
Expand Down
6 changes: 2 additions & 4 deletions tests/TTIR/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import jax
import jax.numpy as jnp

from infrastructure import random_input_tensor


Expand All @@ -16,9 +14,9 @@ def test_num_devices():

def test_to_device():
cpu_array = random_input_tensor((32, 32))
device = jax.devices()[0]
device = jax.devices("tt")[0]
tt_array = jax.device_put(cpu_array, device)
assert tt_array.device.device_kind == "wormhole"
assert tt_array.device.device_kind == "TTDevice"


def test_input_on_device():
Expand Down
1 change: 1 addition & 0 deletions tests/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
from .model_tester import ModelTester, RunMode
from .op_tester import run_op_test, run_op_test_with_random_inputs
from .utils import random_tensor
5 changes: 5 additions & 0 deletions tests/infra/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax
import jax.numpy as jnp

from .device_runner import run_on_cpu
from .utils import Tensor


Expand Down Expand Up @@ -68,6 +69,7 @@ def disable_all(self) -> None:
# frameworks in the future.


@run_on_cpu
def compare_equal(device_output: Tensor, golden_output: Tensor) -> None:
assert isinstance(device_output, jax.Array) and isinstance(
golden_output, jax.Array
Expand All @@ -78,6 +80,7 @@ def compare_equal(device_output: Tensor, golden_output: Tensor) -> None:
assert eq, f"Equal comparison failed"


@run_on_cpu
def compare_atol(
device_output: Tensor, golden_output: Tensor, atol_config: AtolConfig
) -> None:
Expand All @@ -92,6 +95,7 @@ def compare_atol(
), f"Atol comparison failed. Calculated atol={atol}"


@run_on_cpu
def compare_pcc(
device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
) -> None:
Expand All @@ -113,6 +117,7 @@ def compare_pcc(
), f"PCC comparison failed. Calculated pcc={pcc}"


@run_on_cpu
def compare_allclose(
device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
) -> None:
Expand Down
21 changes: 20 additions & 1 deletion tests/infra/device_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Sequence
from typing import Callable, Sequence

import jax

Expand Down Expand Up @@ -124,3 +124,22 @@ def _safely_put_workload_on_device(
kwargs_on_device[key] = value_on_device

return Workload(workload.executable, args_on_device, kwargs_on_device)


# --------------- Convenience decorators ---------------


def run_on_cpu(f: Callable):
def wrapper(*args, **kwargs):
workload = Workload(f, args, kwargs)
return DeviceRunner.run_on_cpu(workload)

return wrapper


def run_on_tt_device(f: Callable):
def wrapper(*args, **kwargs):
workload = Workload(f, args, kwargs)
return DeviceRunner.run_on_tt_device(workload)

return wrapper
2 changes: 1 addition & 1 deletion tests/infra/model_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ModelTester(BaseTester, ABC):

def __init__(
self,
comparison_config: ComparisonConfig,
comparison_config: ComparisonConfig = ComparisonConfig(),
run_mode: RunMode = RunMode.INFERENCE,
) -> None:
super().__init__(comparison_config)
Expand Down
59 changes: 39 additions & 20 deletions tests/infra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,6 @@ def __post_init__(self):
def execute(self) -> Any:
return self.executable(*self.args, **self.kwargs)

def as_mlir_module(self) -> str:
"""
Returns workload as mlir module string.
Note that workload.executable must be the result of jit, otherwise empty string
will be returned.
"""
try:
s = export.export(self.executable)(*self.args, **self.kwargs).mlir_module()
# Remove all lines that start with "#loc" for cleaner output.
return "\n".join(
line for line in s.splitlines() if not line.startswith("#loc")
)

except ValueError:
return ""


class Framework(Enum):
JAX = "jax"
Expand Down Expand Up @@ -71,18 +54,54 @@ def random_tensor(
shape: tuple,
dtype: str = "float32",
random_seed: int = 0,
minval: float = 0.0,
maxval: float = 1.0,
framework: Framework = Framework.JAX,
) -> Tensor:
"""
Generates a random tensor of `shape`, `dtype`, and `random_seed` for the desired
`framework`.
Generates a random tensor of `shape`, `dtype`, and `random_seed` in range
[`minval`, `maxval`) for the desired `framework`.
"""
# Convert dtype string to actual dtype for the selected framework.
dtype_converted = _str_to_dtype(dtype, framework)

# Generate random tensor based on framework type
if framework == Framework.JAX:
prng_key = jax.random.PRNGKey(random_seed)
return jax.random.uniform(key=prng_key, shape=shape, dtype=dtype_converted)

return jax.random.uniform(
key=prng_key,
shape=shape,
dtype=dtype_converted,
minval=minval,
maxval=maxval,
)
else:
raise ValueError(f"Unsupported framework: {framework.value}.")


def workload_as_mlir_module(
workload: Workload, framework: Framework = Framework.JAX
) -> str:
"""
Returns workload as mlir module string.
Note that in case of jax, workload.executable must be the result of jit, otherwise
empty string will be returned.
"""

if framework == Framework.JAX:
try:
s = export.export(workload.executable)(
*workload.args, **workload.kwargs
).mlir_module()

# Remove all lines that start with "#loc" for cleaner output.
return "\n".join(
line for line in s.splitlines() if not line.startswith("#loc")
)

except ValueError:
return ""
else:
raise ValueError(f"Unsupported framework: {framework.value}.")
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jax import numpy as jnp


def arbitrary_op_chain(x: jax.Array, y: jax.Array) -> jax.Array:
def example_graph(x: jax.Array, y: jax.Array) -> jax.Array:
a = jnp.abs(x)
b = jnp.add(a, y)
c = jnp.divide(a, b)
Expand All @@ -23,5 +23,5 @@ def arbitrary_op_chain(x: jax.Array, y: jax.Array) -> jax.Array:
[(64, 64), (64, 64)],
],
)
def test_arbitrary_op_chain(x_shape: tuple, y_shape: tuple):
run_graph_test_with_random_inputs(arbitrary_op_chain, [x_shape, y_shape])
def test_example_graph(x_shape: tuple, y_shape: tuple):
run_graph_test_with_random_inputs(example_graph, [x_shape, y_shape])
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ----- Tester -----


class ExampleModelTester(ModelTester):
class ExampleModelMixedArgsAndKwargsTester(ModelTester):
"""
Example tester showcasing how to use both positional and keyword arguments for
model's forward method.
Expand Down Expand Up @@ -43,10 +43,10 @@ def _get_forward_method_name() -> str:
# @override
def _get_forward_method_args(self) -> Sequence[jax.Array]:
"""Returns just input activations as positional arg."""
acts = self._get_input_activations()
assert len(acts) == 1
act = acts[0]
return [act]
input_activations = self._get_input_activations()
assert len(input_activations) == 1
input_activation = input_activations[0]
return [input_activation]

# @override
def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
Expand All @@ -69,29 +69,24 @@ def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:


@pytest.fixture
def comparison_config() -> ComparisonConfig:
config = ComparisonConfig()
config.atol.disable()
return config


@pytest.fixture
def inference_tester(comparison_config: ComparisonConfig) -> ExampleModelTester:
return ExampleModelTester(comparison_config)
def inference_tester() -> ExampleModelMixedArgsAndKwargsTester:
return ExampleModelMixedArgsAndKwargsTester()


@pytest.fixture
def training_tester(comparison_config: ComparisonConfig) -> ExampleModelTester:
return ExampleModelTester(comparison_config, RunMode.TRAINING)
def training_tester() -> ExampleModelMixedArgsAndKwargsTester:
return ExampleModelMixedArgsAndKwargsTester(RunMode.TRAINING)


# ----- Tests -----


def test_example_model_inference(inference_tester: ExampleModelTester):
def test_example_model_inference(
inference_tester: ExampleModelMixedArgsAndKwargsTester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_example_model_training(training_tester: ExampleModelTester):
def test_example_model_training(training_tester: ExampleModelMixedArgsAndKwargsTester):
training_tester.test()
12 changes: 7 additions & 5 deletions tests/jax/models/example_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp
from flax import nnx
from infra import random_tensor


class ExampleModel(nnx.Module):
Expand All @@ -16,16 +17,17 @@ def __init__(self) -> None:
(1, 128),
)

self.w0 = jax.numpy.ones(w0_shape)
self.w1 = jax.numpy.ones(w1_shape)
self.b0 = jax.numpy.ones(b0_shape)
self.b1 = jax.numpy.zeros(b1_shape)
self.w0 = random_tensor(w0_shape, minval=-0.01, maxval=0.01)
self.w1 = random_tensor(w1_shape, minval=-0.01, maxval=0.01)
self.b0 = random_tensor(b0_shape, minval=-0.01, maxval=0.01)
self.b1 = random_tensor(b1_shape, minval=-0.01, maxval=0.01)

def __call__(
self, act: jax.Array, w0: jax.Array, b0: jax.Array, w1: jax.Array, b1: jax.Array
) -> jax.Array:
# Note how activations, weights and biases are directly passed to the forward
# method. `self` is not accessed.
# method as inputs, `self` is not accessed. Otherwise they would be embedded
# into jitted graph as constants.
x = jnp.matmul(act, w0) + b0
x = jnp.matmul(x, w1) + b1
return x
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ----- Tester -----


class ExampleModelTester(ModelTester):
class ExampleModelOnlyArgsTester(ModelTester):
"""
Example tester showcasing how to use only positional arguments for model's forward
method.
Expand Down Expand Up @@ -52,41 +52,34 @@ def _get_forward_method_args(self) -> Sequence[jax.Array]:
b1 = self._model.b1

# Fetch activations.
acts = self._get_input_activations()
assert len(acts) == 1
act = acts[0]
input_activations = self._get_input_activations()
assert len(input_activations) == 1
input_activation = input_activations[0]

# Mix activations, weights and biases to match forward method signature.
return [act, w0, b0, w1, b1]
return [input_activation, w0, b0, w1, b1]


# ----- Fixtures -----


@pytest.fixture
def comparison_config() -> ComparisonConfig:
config = ComparisonConfig()
config.atol.disable()
return config
def inference_tester() -> ExampleModelOnlyArgsTester:
return ExampleModelOnlyArgsTester()


@pytest.fixture
def inference_tester(comparison_config: ComparisonConfig) -> ExampleModelTester:
return ExampleModelTester(comparison_config)


@pytest.fixture
def training_tester(comparison_config: ComparisonConfig) -> ExampleModelTester:
return ExampleModelTester(comparison_config, RunMode.TRAINING)
def training_tester() -> ExampleModelOnlyArgsTester:
return ExampleModelOnlyArgsTester(RunMode.TRAINING)


# ----- Tests -----


def test_example_model_inference(inference_tester: ExampleModelTester):
def test_example_model_inference(inference_tester: ExampleModelOnlyArgsTester):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_example_model_training(training_tester: ExampleModelTester):
def test_example_model_training(training_tester: ExampleModelOnlyArgsTester):
training_tester.test()
Loading

0 comments on commit bfae725

Please sign in to comment.