Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic end to end tests #15

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions tests/TTIR/test_basic_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
#

import pytest
import jax
import jax.numpy as jnp

from infrastructure import verify_module


def test_abs_op():
def module_abs(a):
return jnp.abs(a)

verify_module(module_abs, [(3, 3)])
verify_module(module_abs, [(3, 3, 3)])


#Broadcasted values are incorrect
@pytest.mark.xfail
def test_broadcast_op():
def module_broadcast(a):
return jnp.broadcast_to(a, (2, 4))

verify_module(module_broadcast, [(2, 1)])


#error: 'ttir.constant' op failed to verify that all of {value, result} have same shape
@pytest.mark.xfail
def test_constant_op():
def module_constant_zeros(a):
zeros = jnp.zeros(a.shape)
return zeros

def module_constant_ones(a):
ones = jnp.ones(a.shape)
return ones

def module_constant_multi(a):
multi = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
return multi

verify_module(module_constant_zeros, [(3, 3)])
verify_module(module_constant_ones, [(3, 3)])
verify_module(module_constant_multi, [(3, 3)])


def test_convert_op():
def module_convert(a):
return jax.lax.convert_element_type(a, jnp.bfloat16)

verify_module(module_convert, [(2, 2)])
verify_module(module_convert, [(4, 4, 4)])


def test_div_op():
def module_div(a, b):
return a / b

verify_module(module_div, [(3, 3), (3, 3)])
verify_module(module_div, [(3, 3, 3), (3, 3, 3)], required_atol=35e-2)


def test_dot_general_op():
def module_dot_general(a, b):
return jnp.dot(a, b)

verify_module(module_dot_general, [(2, 1), (1, 2)])
verify_module(module_dot_general, [(1, 2), (2, 1)])


# Exponential generate slightly different values, so using higher ATOL value.
def test_exp_op():
def module_exp(a):
return jnp.exp(a)

verify_module(module_exp, [(3, 3)], required_atol=20e-2)
verify_module(module_exp, [(3, 3, 3)], required_atol=25e-2)


def test_maximum_op():
def module_maximum(a, b):
return jnp.maximum(a, b)

verify_module(module_maximum, [(3, 3), (3, 3)])
verify_module(module_maximum, [(3, 3, 3), (3, 3, 3)])


def test_multiply_op():
def module_multiply(a, b):
return a * b

verify_module(module_multiply, [(3, 3), (3, 3)])
verify_module(module_multiply, [(3, 3, 3), (3, 3, 3)])


def test_negate_op():
def module_negate(a):
return -a

verify_module(module_negate, [(3, 3)])
verify_module(module_negate, [(3, 3, 3)])


#Reduce is failing due to error in constant.
@pytest.mark.xfail
def test_reduce_op():
def module_reduce_max(a):
return jnp.max(a)

def module_reduce_sum(a):
return jnp.sum(a)

verify_module(module_reduce_max, [(3, 3)])
verify_module(module_reduce_max, [(3, 3, 3)])

verify_module(module_reduce_sum, [(3, 3)])
verify_module(module_reduce_sum, [(3, 3, 3)])


def test_rsqrt_op():
def module_rsqrt(a):
return jax.lax.rsqrt(a)

verify_module(module_rsqrt, [(3, 3)])
verify_module(module_rsqrt, [(3, 3, 3)])


def test_sqrt_op():
def module_sqrt(a):
return jnp.sqrt(a)

verify_module(module_sqrt, [(3, 3)])
verify_module(module_sqrt, [(3, 3, 3)])


def test_sub_op():
def module_sub(a, b):
return a - b

verify_module(module_sub, [(3, 3), (3, 3)])
verify_module(module_sub, [(3, 3, 3), (3, 3, 3)])


def test_transpose_op_2d():
def module_transpose(a):
return jnp.transpose(a)

verify_module(module_transpose, [(3, 3)])


# Transpose op failing for higher ranks/dimensions.
@pytest.mark.xfail
def test_transpose_op_3d():
def module_transpose(a):
return jnp.transpose(a)

verify_module(module_transpose, [(3, 3, 3)])

10 changes: 0 additions & 10 deletions tests/TTIR/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,22 @@
# Currently, tt::runtime only support float32, bfloat16, uint16, and uint32

def test_data_types(capfd):
print("Starting test")
a = jnp.array([[1., 2.], [3., 4.]], dtype=jnp.float32)
b = jnp.array([[5., 6.], [7., 8.]], dtype=jnp.bfloat16)
c = jnp.array([[1, 2], [3, 4]], dtype=jnp.uint32)
d = jnp.array([[5, 6], [7, 8]], dtype=jnp.uint16)
print(a)
out, _ = capfd.readouterr()
assert "[[1. 2.]\n [3. 4.]]" in out
# CHECK: [[1. 2.]
# CHECK: [3. 4.]]


print(b)
out, _ = capfd.readouterr()
assert "[[5 6]\n [7 8]]" in out
# CHECK: [[5 6]
# CHECK: [7 8]]

print(c)
out, _ = capfd.readouterr()
assert "[[1 2]\n [3 4]]" in out
# CHECK: [[1 2]
# CHECK: [3 4]]

print(d)
out, _ = capfd.readouterr()
assert "[[5 6]\n [7 8]]" in out
# CHECK: [[5 6]
# CHECK: [7 8]]
4 changes: 2 additions & 2 deletions tests/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def compare_tensor_to_golden(tensor, golden, required_pcc=0.99, required_atol=1e
ret = ret and atol <= required_atol
if assert_on_error:
assert ret, f"ATOL is {atol} which is greater than {required_atol}"

return ret

def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2):
tt_device = jax.devices()[0]
cpu_inputs = [random_input_tensor(shape, key) for shape in input_shapes]
cpu_inputs = [random_input_tensor(input_shapes[i], key + i) for i in range(len(input_shapes))]
tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs]
graph = jax.jit(module)
res = graph(*tt_inputs)
Expand Down
2 changes: 1 addition & 1 deletion third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

set(TT_MLIR_VERSION "5d92bf937bc76b521d39e0fa320b20773905bfc1")
set(TT_MLIR_VERSION "20a7ccd485a198fea14861e5a765dd51972e85f3")
set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1")

if (TOOLCHAIN STREQUAL "ON")
Expand Down