Skip to content

Commit

Permalink
Rerun tests to update their status. Updated every one to use parametr…
Browse files Browse the repository at this point in the history
…ize.
  • Loading branch information
kmitrovicTT committed Nov 28, 2024
1 parent 39812db commit 6f1e3e5
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 188 deletions.
38 changes: 0 additions & 38 deletions tests/TTIR/test_array_add.py

This file was deleted.

198 changes: 116 additions & 82 deletions tests/TTIR/test_basic_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
#

import pytest
import jax
Expand All @@ -11,31 +10,69 @@
from infrastructure import verify_module


def test_abs_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_abs_op(input_shapes):
def module_abs(a):
return jnp.abs(a)

verify_module(module_abs, [(3, 3)])
verify_module(module_abs, [(3, 3, 3)])
verify_module(module_abs, input_shapes)


@pytest.mark.parametrize("input_shapes", [[(2, 2), (2, 2)], [(3, 2), (3, 2)]])
def test_array_add(input_shapes):
def module_add(a, b):
return a + b

verify_module(module_add, input_shapes)


@pytest.mark.parametrize("rank", [1, 2, 3, 4, 5, 6])
def test_module_add(rank):
def module_add(a, b):
c = a + a
d = b + b
return c + d

input_shape = []
for i in range(rank):
input_shape.insert(0, 32 if i < 2 else 1)

input_shape = tuple(input_shape)
verify_module(module_add, [input_shape, input_shape])


# TODO might be outdated.
def test_scalar_add():
def module_add(a, b):
return a + b

# Broadcasted values are incorrect
@pytest.mark.skip("Broadcasted values are incorrect")
def test_broadcast_op():
a = jnp.float32(5.0)
b = jnp.array(6.0)
tt_graph = jax.jit(module_add)
res = tt_graph(a, b)
cpu_graph = jax.jit(module_add, backend="cpu")
res_cpu = cpu_graph(a, b)
assert jnp.allclose(res, res_cpu)


@pytest.mark.parametrize("input_shapes", [[(2, 1)]])
@pytest.mark.skip(
"Broadcasted values are incorrect. "
"Fails with: AssertionError: PCC is 0.37796446681022644 which is less than 0.99"
)
def test_broadcast_op(input_shapes):
def module_broadcast(a):
return jnp.broadcast_to(a, (2, 4))

verify_module(module_broadcast, [(2, 1)])
verify_module(module_broadcast, input_shapes)


def test_cbrt_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_cbrt_op(input_shapes):
def module_cbrt(a):
return jax.lax.cbrt(a)

verify_module(
module_cbrt, [(3, 3)], required_atol=2e-2
) # ATOL is 0.010040640830993652
verify_module(module_cbrt, [(3, 3, 3)], required_atol=2e-2)
verify_module(module_cbrt, input_shapes, required_atol=2e-2)


def test_concat_op():
Expand Down Expand Up @@ -70,11 +107,9 @@ def module_concat_dim_3(x, y):
) # output shape: (32, 32, 32, 96)


# error: 'ttir.constant' op failed to verify that all of {value, result} have same shape
@pytest.mark.skip(
"Index is out of bounds for the rank, should be between 0 and 0 however is 18446744073709551615"
)
def test_constant_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)]])
@pytest.mark.skip("AssertionError: ATOL is 21574.4375 which is greater than 0.01")
def test_constant_op(input_shapes):
def module_constant_zeros(a):
zeros = jnp.zeros(a.shape)
return zeros
Expand All @@ -83,165 +118,165 @@ def module_constant_ones(a):
ones = jnp.ones(a.shape)
return ones

verify_module(module_constant_zeros, input_shapes)
verify_module(module_constant_ones, input_shapes)


@pytest.mark.parametrize("input_shapes", [[(3, 3)]])
@pytest.mark.skip("Fails due to: error: failed to legalize operation 'ttir.constant'")
def test_constant_op_multi_dim(input_shapes):
def module_constant_multi(a):
multi = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
return multi

verify_module(module_constant_zeros, [(3, 3)])
verify_module(module_constant_ones, [(3, 3)])
verify_module(module_constant_multi, [(3, 3)])
verify_module(module_constant_multi, input_shapes)


def test_convert_op():
@pytest.mark.parametrize("input_shapes", [[(2, 2)], [(4, 4, 4)]])
def test_convert_op(input_shapes):
def module_convert(a):
return jax.lax.convert_element_type(a, jnp.bfloat16)

verify_module(module_convert, [(2, 2)])
verify_module(module_convert, [(4, 4, 4)])
verify_module(module_convert, input_shapes)


def test_div_op():
@pytest.mark.parametrize(
["input_shapes", "required_atol"],
[([(3, 3), (3, 3)], 0.01), ([(3, 3, 3), (3, 3, 3)], 35e-2)],
)
def test_div_op(input_shapes, required_atol):
def module_div(a, b):
return a / b

verify_module(module_div, [(3, 3), (3, 3)])
verify_module(module_div, [(3, 3, 3), (3, 3, 3)], required_atol=35e-2)
verify_module(module_div, input_shapes, required_atol=required_atol)


def test_dot_general_op():
@pytest.mark.parametrize(
"input_shapes",
[[(2, 1), (1, 2)], [(1, 2), (2, 1)]],
)
def test_dot_general_op(input_shapes):
def module_dot_general(a, b):
return jnp.dot(a, b)

verify_module(module_dot_general, [(2, 1), (1, 2)])
verify_module(module_dot_general, [(1, 2), (2, 1)])
verify_module(module_dot_general, input_shapes)


# Exponential generate slightly different values, so using higher ATOL value.
# see tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199)
def test_exp_op():
@pytest.mark.parametrize(
["input_shapes", "required_atol"], [([(3, 3)], 20e-2), ([(3, 3, 3)], 25e-2)]
)
def test_exp_op(input_shapes, required_atol):
def module_exp(a):
return jnp.exp(a)

verify_module(module_exp, [(3, 3)], required_atol=20e-2)
verify_module(module_exp, [(3, 3, 3)], required_atol=25e-2)
verify_module(module_exp, input_shapes, required_atol=required_atol)


def test_maximum_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3), (3, 3)], [(3, 3, 3), (3, 3, 3)]])
def test_maximum_op(input_shapes):
def module_maximum(a, b):
return jnp.maximum(a, b)

verify_module(module_maximum, [(3, 3), (3, 3)])
verify_module(module_maximum, [(3, 3, 3), (3, 3, 3)])
verify_module(module_maximum, input_shapes)


def test_multiply_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3), (3, 3)], [(3, 3, 3), (3, 3, 3)]])
def test_multiply_op(input_shapes):
def module_multiply(a, b):
return a * b

verify_module(module_multiply, [(3, 3), (3, 3)])
verify_module(module_multiply, [(3, 3, 3), (3, 3, 3)])
verify_module(module_multiply, input_shapes)


def test_negate_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_negate_op(input_shapes):
def module_negate(a):
return -a

verify_module(module_negate, [(3, 3)])
verify_module(module_negate, [(3, 3, 3)])
verify_module(module_negate, input_shapes)


# Reduce is failing due to error in constant.
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
@pytest.mark.skip("keepdim=False is not supported")
def test_reduce_op():
def test_reduce_op(input_shapes):
def module_reduce_max(a):
return jnp.max(a)

def module_reduce_sum(a):
return jnp.sum(a)

verify_module(module_reduce_max, [(3, 3)])
verify_module(module_reduce_max, [(3, 3, 3)])
verify_module(module_reduce_max, input_shapes)
verify_module(module_reduce_sum, input_shapes)

verify_module(module_reduce_sum, [(3, 3)])
verify_module(module_reduce_sum, [(3, 3, 3)])


def test_rsqrt_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_rsqrt_op(input_shapes):
def module_rsqrt(a):
return jax.lax.rsqrt(a)

verify_module(module_rsqrt, [(3, 3)])
verify_module(module_rsqrt, [(3, 3, 3)])
verify_module(module_rsqrt, input_shapes)


# Needs to have a bigger atol due to inaccuracies in the exp op on tt-metal
# see tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199)
def test_expm1_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_expm1_op(input_shapes):
def module_expm1(a):
return jax.lax.expm1(a)

verify_module(module_expm1, [(3, 3)], required_atol=20e-2)
verify_module(module_expm1, [(3, 3, 3)], required_atol=20e-2)
verify_module(module_expm1, input_shapes, required_atol=20e-2)


def test_log1p_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_log1p_op(input_shapes):
def module_log1p(a):
return jax.lax.log1p(a)

verify_module(module_log1p, [(3, 3)], required_atol=2e-2)
verify_module(module_log1p, [(3, 3, 3)], required_atol=2e-2)
verify_module(module_log1p, input_shapes, required_atol=2e-2)


def test_sign_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_sign_op(input_shapes):
def module_sign(a):
return jax.lax.sign(a)

verify_module(module_sign, [(3, 3)])
verify_module(module_sign, [(3, 3, 3)])
verify_module(module_sign, input_shapes)


def test_sqrt_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_sqrt_op(input_shapes):
def module_sqrt(a):
return jnp.sqrt(a)

verify_module(module_sqrt, [(3, 3)])
verify_module(module_sqrt, [(3, 3, 3)])
verify_module(module_sqrt, input_shapes)


def test_sub_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3), (3, 3)], [(3, 3, 3), (3, 3, 3)]])
def test_sub_op(input_shapes):
def module_sub(a, b):
return a - b

verify_module(module_sub, [(3, 3), (3, 3)])
verify_module(module_sub, [(3, 3, 3), (3, 3, 3)])
verify_module(module_sub, input_shapes)


def test_transpose_op_2d():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_transpose_op(input_shapes):
def module_transpose(a):
return jnp.transpose(a)

verify_module(module_transpose, [(3, 3)])
verify_module(module_transpose, input_shapes)


@pytest.mark.skip(
"Scalars currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73"
)
def test_scalar_type():
def module_scalar_type(a):
return a.shape[0]

verify_module(module_scalar_type, [(3, 3)])


# Transpose op failing for higher ranks/dimensions.
@pytest.mark.skip("Transpose op failing for higher ranks/dimensions.")
def test_transpose_op_3d():
def module_transpose(a):
return jnp.transpose(a)

verify_module(module_transpose, [(3, 3, 3)])


dim0_cases = []
for begin in numpy.arange(10).tolist():
for end in numpy.arange(90, 100).tolist():
Expand All @@ -264,9 +299,8 @@ def module_transpose(a):


@pytest.mark.parametrize(
"begin, end, dim", [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases]
["begin", "end", "dim"], [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases]
)
@pytest.mark.skip("Requires tt-metal uplift.")
def test_slice(begin, end, dim):
def module_slice(a):
if dim == 0:
Expand Down
Loading

0 comments on commit 6f1e3e5

Please sign in to comment.