Skip to content

Commit

Permalink
Add basic end to end tests
Browse files Browse the repository at this point in the history
* Uplift tt-mlir to support conversion of additional stable HLO ops
* Add basic end to end tests for the ops
  • Loading branch information
mmanzoorTT committed Sep 30, 2024
1 parent 4c1eda3 commit 16638b0
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 13 deletions.
156 changes: 156 additions & 0 deletions tests/TTIR/test_basic_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import pytest
import jax
import jax.numpy as jnp

from infrastructure import verify_module


def test_abs_op():
def module_abs(a):
return jnp.abs(a)

verify_module(module_abs, [(3, 3)])
verify_module(module_abs, [(3, 3, 3)])


#Broadcasted values are incorrect
@pytest.mark.xfail
def test_broadcast_op():
def module_broadcast(a):
return jnp.broadcast_to(a, (2, 4))

verify_module(module_broadcast, [(2, 1)])


#error: 'ttir.constant' op failed to verify that all of {value, result} have same shape
@pytest.mark.xfail
def test_constant_op():
def module_constant_zeros(a):
zeros = jnp.zeros(a.shape)
return zeros

def module_constant_ones(a):
ones = jnp.ones(a.shape)
return ones

def module_constant_multi(a):
multi = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
return multi

verify_module(module_constant_zeros, [(3, 3)])
verify_module(module_constant_ones, [(3, 3)])
verify_module(module_constant_multi, [(3, 3)])


def test_convert_op():
def module_convert(a):
return jax.lax.convert_element_type(a, jnp.bfloat16)

verify_module(module_convert, [(2, 2)])
verify_module(module_convert, [(4, 4, 4)])


def test_div_op():
def module_div(a, b):
return a / b

verify_module(module_div, [(3, 3), (3, 3)])
verify_module(module_div, [(3, 3, 3), (3, 3, 3)], required_atol=35e-2)


def test_dot_general_op():
def module_dot_general(a, b):
return jnp.dot(a, b)

verify_module(module_dot_general, [(2, 1), (1, 2)])
verify_module(module_dot_general, [(1, 2), (2, 1)])


# Exponential generate slightly different values, so using higher ATOL value.
def test_exp_op():
def module_exp(a):
return jnp.exp(a)

verify_module(module_exp, [(3, 3)], required_atol=20e-2)
verify_module(module_exp, [(3, 3, 3)], required_atol=25e-2)


def test_maximum_op():
def module_maximum(a, b):
return jnp.maximum(a, b)

verify_module(module_maximum, [(3, 3), (3, 3)])
verify_module(module_maximum, [(3, 3, 3), (3, 3, 3)])


def test_multiply_op():
def module_multiply(a, b):
return a * b

verify_module(module_multiply, [(3, 3), (3, 3)])
verify_module(module_multiply, [(3, 3, 3), (3, 3, 3)])


def test_negate_op():
def module_negate(a):
return -a

verify_module(module_negate, [(3, 3)])
verify_module(module_negate, [(3, 3, 3)])


#Reduce is failing due to error in constant.
@pytest.mark.xfail
def test_reduce_op():
def module_reduce_max(a):
return jnp.max(a)

def module_reduce_sum(a):
return jnp.sum(a)

verify_module(module_reduce_max, [(3, 3)])
verify_module(module_reduce_max, [(3, 3, 3)])

verify_module(module_reduce_sum, [(3, 3)])
verify_module(module_reduce_sum, [(3, 3, 3)])


def test_rsqrt_op():
def module_rsqrt(a):
return jax.lax.rsqrt(a)

verify_module(module_rsqrt, [(3, 3)])
verify_module(module_rsqrt, [(3, 3, 3)])


def test_sqrt_op():
def module_sqrt(a):
return jnp.sqrt(a)

verify_module(module_sqrt, [(3, 3)])
verify_module(module_sqrt, [(3, 3, 3)])


def test_sub_op():
def module_sub(a, b):
return a - b

verify_module(module_sub, [(3, 3), (3, 3)])
verify_module(module_sub, [(3, 3, 3), (3, 3, 3)])


def test_transpose_op_2d():
def module_transpose(a):
return jnp.transpose(a)

verify_module(module_transpose, [(3, 3)])


# Transpose op failing for higher ranks/dimensions.
@pytest.mark.xfail
def test_transpose_op_3d():
def module_transpose(a):
return jnp.transpose(a)

verify_module(module_transpose, [(3, 3, 3)])

10 changes: 0 additions & 10 deletions tests/TTIR/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,22 @@
# Currently, tt::runtime only support float32, bfloat16, uint16, and uint32

def test_data_types(capfd):
print("Starting test")
a = jnp.array([[1., 2.], [3., 4.]], dtype=jnp.float32)
b = jnp.array([[5., 6.], [7., 8.]], dtype=jnp.bfloat16)
c = jnp.array([[1, 2], [3, 4]], dtype=jnp.uint32)
d = jnp.array([[5, 6], [7, 8]], dtype=jnp.uint16)
print(a)
out, _ = capfd.readouterr()
assert "[[1. 2.]\n [3. 4.]]" in out
# CHECK: [[1. 2.]
# CHECK: [3. 4.]]


print(b)
out, _ = capfd.readouterr()
assert "[[5 6]\n [7 8]]" in out
# CHECK: [[5 6]
# CHECK: [7 8]]

print(c)
out, _ = capfd.readouterr()
assert "[[1 2]\n [3 4]]" in out
# CHECK: [[1 2]
# CHECK: [3 4]]

print(d)
out, _ = capfd.readouterr()
assert "[[5 6]\n [7 8]]" in out
# CHECK: [[5 6]
# CHECK: [7 8]]
4 changes: 2 additions & 2 deletions tests/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def compare_tensor_to_golden(tensor, golden, required_pcc=0.99, required_atol=1e
ret = ret and atol <= required_atol
if assert_on_error:
assert ret, f"ATOL is {atol} which is greater than {required_atol}"

return ret

def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2):
tt_device = jax.devices()[0]
cpu_inputs = [random_input_tensor(shape, key) for shape in input_shapes]
cpu_inputs = [random_input_tensor(input_shapes[i], key + i) for i in range(len(input_shapes))]
tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs]
graph = jax.jit(module)
res = graph(*tt_inputs)
Expand Down
2 changes: 1 addition & 1 deletion third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

set(TT_MLIR_VERSION "5d92bf937bc76b521d39e0fa320b20773905bfc1")
set(TT_MLIR_VERSION "20a7ccd485a198fea14861e5a765dd51972e85f3")
set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1")

if (TOOLCHAIN STREQUAL "ON")
Expand Down

0 comments on commit 16638b0

Please sign in to comment.