diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py new file mode 100644 index 0000000..2dd03d4 --- /dev/null +++ b/tests/TTIR/test_basic_ops.py @@ -0,0 +1,156 @@ +import pytest +import jax +import jax.numpy as jnp + +from infrastructure import verify_module + + +def test_abs_op(): + def module_abs(a): + return jnp.abs(a) + + verify_module(module_abs, [(3, 3)]) + verify_module(module_abs, [(3, 3, 3)]) + + +#Broadcasted values are incorrect +@pytest.mark.xfail +def test_broadcast_op(): + def module_broadcast(a): + return jnp.broadcast_to(a, (2, 4)) + + verify_module(module_broadcast, [(2, 1)]) + + +#error: 'ttir.constant' op failed to verify that all of {value, result} have same shape +@pytest.mark.xfail +def test_constant_op(): + def module_constant_zeros(a): + zeros = jnp.zeros(a.shape) + return zeros + + def module_constant_ones(a): + ones = jnp.ones(a.shape) + return ones + + def module_constant_multi(a): + multi = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32) + return multi + + verify_module(module_constant_zeros, [(3, 3)]) + verify_module(module_constant_ones, [(3, 3)]) + verify_module(module_constant_multi, [(3, 3)]) + + +def test_convert_op(): + def module_convert(a): + return jax.lax.convert_element_type(a, jnp.bfloat16) + + verify_module(module_convert, [(2, 2)]) + verify_module(module_convert, [(4, 4, 4)]) + + +def test_div_op(): + def module_div(a, b): + return a / b + + verify_module(module_div, [(3, 3), (3, 3)]) + verify_module(module_div, [(3, 3, 3), (3, 3, 3)], required_atol=35e-2) + + +def test_dot_general_op(): + def module_dot_general(a, b): + return jnp.dot(a, b) + + verify_module(module_dot_general, [(2, 1), (1, 2)]) + verify_module(module_dot_general, [(1, 2), (2, 1)]) + + +# Exponential generate slightly different values, so using higher ATOL value. +def test_exp_op(): + def module_exp(a): + return jnp.exp(a) + + verify_module(module_exp, [(3, 3)], required_atol=20e-2) + verify_module(module_exp, [(3, 3, 3)], required_atol=25e-2) + + +def test_maximum_op(): + def module_maximum(a, b): + return jnp.maximum(a, b) + + verify_module(module_maximum, [(3, 3), (3, 3)]) + verify_module(module_maximum, [(3, 3, 3), (3, 3, 3)]) + + +def test_multiply_op(): + def module_multiply(a, b): + return a * b + + verify_module(module_multiply, [(3, 3), (3, 3)]) + verify_module(module_multiply, [(3, 3, 3), (3, 3, 3)]) + + +def test_negate_op(): + def module_negate(a): + return -a + + verify_module(module_negate, [(3, 3)]) + verify_module(module_negate, [(3, 3, 3)]) + + +#Reduce is failing due to error in constant. +@pytest.mark.xfail +def test_reduce_op(): + def module_reduce_max(a): + return jnp.max(a) + + def module_reduce_sum(a): + return jnp.sum(a) + + verify_module(module_reduce_max, [(3, 3)]) + verify_module(module_reduce_max, [(3, 3, 3)]) + + verify_module(module_reduce_sum, [(3, 3)]) + verify_module(module_reduce_sum, [(3, 3, 3)]) + + +def test_rsqrt_op(): + def module_rsqrt(a): + return jax.lax.rsqrt(a) + + verify_module(module_rsqrt, [(3, 3)]) + verify_module(module_rsqrt, [(3, 3, 3)]) + + +def test_sqrt_op(): + def module_sqrt(a): + return jnp.sqrt(a) + + verify_module(module_sqrt, [(3, 3)]) + verify_module(module_sqrt, [(3, 3, 3)]) + + +def test_sub_op(): + def module_sub(a, b): + return a - b + + verify_module(module_sub, [(3, 3), (3, 3)]) + verify_module(module_sub, [(3, 3, 3), (3, 3, 3)]) + + +def test_transpose_op_2d(): + def module_transpose(a): + return jnp.transpose(a) + + verify_module(module_transpose, [(3, 3)]) + + +# Transpose op failing for higher ranks/dimensions. +@pytest.mark.xfail +def test_transpose_op_3d(): + def module_transpose(a): + return jnp.transpose(a) + + verify_module(module_transpose, [(3, 3, 3)]) + diff --git a/tests/TTIR/test_data_types.py b/tests/TTIR/test_data_types.py index f49ac28..2531525 100644 --- a/tests/TTIR/test_data_types.py +++ b/tests/TTIR/test_data_types.py @@ -11,7 +11,6 @@ # Currently, tt::runtime only support float32, bfloat16, uint16, and uint32 def test_data_types(capfd): - print("Starting test") a = jnp.array([[1., 2.], [3., 4.]], dtype=jnp.float32) b = jnp.array([[5., 6.], [7., 8.]], dtype=jnp.bfloat16) c = jnp.array([[1, 2], [3, 4]], dtype=jnp.uint32) @@ -19,24 +18,15 @@ def test_data_types(capfd): print(a) out, _ = capfd.readouterr() assert "[[1. 2.]\n [3. 4.]]" in out - # CHECK: [[1. 2.] - # CHECK: [3. 4.]] - print(b) out, _ = capfd.readouterr() assert "[[5 6]\n [7 8]]" in out - # CHECK: [[5 6] - # CHECK: [7 8]] print(c) out, _ = capfd.readouterr() assert "[[1 2]\n [3 4]]" in out - # CHECK: [[1 2] - # CHECK: [3 4]] print(d) out, _ = capfd.readouterr() assert "[[5 6]\n [7 8]]" in out - # CHECK: [[5 6] - # CHECK: [7 8]] diff --git a/tests/infrastructure.py b/tests/infrastructure.py index 9ae16a8..63dc691 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -34,12 +34,12 @@ def compare_tensor_to_golden(tensor, golden, required_pcc=0.99, required_atol=1e ret = ret and atol <= required_atol if assert_on_error: assert ret, f"ATOL is {atol} which is greater than {required_atol}" - + return ret def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2): tt_device = jax.devices()[0] - cpu_inputs = [random_input_tensor(shape, key) for shape in input_shapes] + cpu_inputs = [random_input_tensor(input_shapes[i], key + i) for i in range(len(input_shapes))] tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs] graph = jax.jit(module) res = graph(*tt_inputs) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index e3ed07b..95c1870 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -set(TT_MLIR_VERSION "5d92bf937bc76b521d39e0fa320b20773905bfc1") +set(TT_MLIR_VERSION "20a7ccd485a198fea14861e5a765dd51972e85f3") set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1") if (TOOLCHAIN STREQUAL "ON")