Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rerun tests to update their status. Updated each one to use parametrize. #82

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions tests/TTIR/test_array_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import jax
import jax.numpy as jnp

from infrastructure import verify_module

Expand Down
161 changes: 79 additions & 82 deletions tests/TTIR/test_basic_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
#

import pytest
import jax
Expand All @@ -11,31 +10,32 @@
from infrastructure import verify_module


def test_abs_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_abs_op(input_shapes):
def module_abs(a):
return jnp.abs(a)

verify_module(module_abs, [(3, 3)])
verify_module(module_abs, [(3, 3, 3)])
verify_module(module_abs, input_shapes)


# Broadcasted values are incorrect
@pytest.mark.skip("Broadcasted values are incorrect")
def test_broadcast_op():
@pytest.mark.parametrize("input_shapes", [[(2, 1)]])
@pytest.mark.skip(
"Broadcasted values are incorrect. "
"Fails with: AssertionError: PCC is 0.37796446681022644 which is less than 0.99"
)
def test_broadcast_op(input_shapes):
def module_broadcast(a):
return jnp.broadcast_to(a, (2, 4))

verify_module(module_broadcast, [(2, 1)])
verify_module(module_broadcast, input_shapes)


def test_cbrt_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_cbrt_op(input_shapes):
def module_cbrt(a):
return jax.lax.cbrt(a)

verify_module(
module_cbrt, [(3, 3)], required_atol=2e-2
) # ATOL is 0.010040640830993652
verify_module(module_cbrt, [(3, 3, 3)], required_atol=2e-2)
verify_module(module_cbrt, input_shapes, required_atol=2e-2)


def test_concat_op():
Expand Down Expand Up @@ -70,11 +70,9 @@ def module_concat_dim_3(x, y):
) # output shape: (32, 32, 32, 96)


# error: 'ttir.constant' op failed to verify that all of {value, result} have same shape
@pytest.mark.skip(
"Index is out of bounds for the rank, should be between 0 and 0 however is 18446744073709551615"
)
def test_constant_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)]])
@pytest.mark.skip("AssertionError: ATOL is 21574.4375 which is greater than 0.01")
def test_constant_op(input_shapes):
def module_constant_zeros(a):
zeros = jnp.zeros(a.shape)
return zeros
Expand All @@ -83,165 +81,165 @@ def module_constant_ones(a):
ones = jnp.ones(a.shape)
return ones

verify_module(module_constant_zeros, input_shapes)
verify_module(module_constant_ones, input_shapes)


@pytest.mark.parametrize("input_shapes", [[(3, 3)]])
@pytest.mark.skip("Fails due to: error: failed to legalize operation 'ttir.constant'")
def test_constant_op_multi_dim(input_shapes):
def module_constant_multi(a):
multi = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
return multi

verify_module(module_constant_zeros, [(3, 3)])
verify_module(module_constant_ones, [(3, 3)])
verify_module(module_constant_multi, [(3, 3)])
verify_module(module_constant_multi, input_shapes)


def test_convert_op():
@pytest.mark.parametrize("input_shapes", [[(2, 2)], [(4, 4, 4)]])
def test_convert_op(input_shapes):
def module_convert(a):
return jax.lax.convert_element_type(a, jnp.bfloat16)

verify_module(module_convert, [(2, 2)])
verify_module(module_convert, [(4, 4, 4)])
verify_module(module_convert, input_shapes)


def test_div_op():
@pytest.mark.parametrize(
["input_shapes", "required_atol"],
[([(3, 3), (3, 3)], 0.01), ([(3, 3, 3), (3, 3, 3)], 35e-2)],
)
def test_div_op(input_shapes, required_atol):
def module_div(a, b):
return a / b

verify_module(module_div, [(3, 3), (3, 3)])
verify_module(module_div, [(3, 3, 3), (3, 3, 3)], required_atol=35e-2)
verify_module(module_div, input_shapes, required_atol=required_atol)


def test_dot_general_op():
@pytest.mark.parametrize(
"input_shapes",
[[(2, 1), (1, 2)], [(1, 2), (2, 1)]],
)
def test_dot_general_op(input_shapes):
def module_dot_general(a, b):
return jnp.dot(a, b)

verify_module(module_dot_general, [(2, 1), (1, 2)])
verify_module(module_dot_general, [(1, 2), (2, 1)])
verify_module(module_dot_general, input_shapes)


# Exponential generate slightly different values, so using higher ATOL value.
# see tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199)
def test_exp_op():
@pytest.mark.parametrize(
["input_shapes", "required_atol"], [([(3, 3)], 20e-2), ([(3, 3, 3)], 25e-2)]
)
def test_exp_op(input_shapes, required_atol):
def module_exp(a):
return jnp.exp(a)

verify_module(module_exp, [(3, 3)], required_atol=20e-2)
verify_module(module_exp, [(3, 3, 3)], required_atol=25e-2)
verify_module(module_exp, input_shapes, required_atol=required_atol)


def test_maximum_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3), (3, 3)], [(3, 3, 3), (3, 3, 3)]])
def test_maximum_op(input_shapes):
def module_maximum(a, b):
return jnp.maximum(a, b)

verify_module(module_maximum, [(3, 3), (3, 3)])
verify_module(module_maximum, [(3, 3, 3), (3, 3, 3)])
verify_module(module_maximum, input_shapes)


def test_multiply_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3), (3, 3)], [(3, 3, 3), (3, 3, 3)]])
def test_multiply_op(input_shapes):
def module_multiply(a, b):
return a * b

verify_module(module_multiply, [(3, 3), (3, 3)])
verify_module(module_multiply, [(3, 3, 3), (3, 3, 3)])
verify_module(module_multiply, input_shapes)


def test_negate_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_negate_op(input_shapes):
def module_negate(a):
return -a

verify_module(module_negate, [(3, 3)])
verify_module(module_negate, [(3, 3, 3)])
verify_module(module_negate, input_shapes)


# Reduce is failing due to error in constant.
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
@pytest.mark.skip("keepdim=False is not supported")
def test_reduce_op():
def test_reduce_op(input_shapes):
def module_reduce_max(a):
return jnp.max(a)

def module_reduce_sum(a):
return jnp.sum(a)

verify_module(module_reduce_max, [(3, 3)])
verify_module(module_reduce_max, [(3, 3, 3)])
verify_module(module_reduce_max, input_shapes)
verify_module(module_reduce_sum, input_shapes)

verify_module(module_reduce_sum, [(3, 3)])
verify_module(module_reduce_sum, [(3, 3, 3)])


def test_rsqrt_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_rsqrt_op(input_shapes):
def module_rsqrt(a):
return jax.lax.rsqrt(a)

verify_module(module_rsqrt, [(3, 3)])
verify_module(module_rsqrt, [(3, 3, 3)])
verify_module(module_rsqrt, input_shapes)


# Needs to have a bigger atol due to inaccuracies in the exp op on tt-metal
# see tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199)
def test_expm1_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_expm1_op(input_shapes):
def module_expm1(a):
return jax.lax.expm1(a)

verify_module(module_expm1, [(3, 3)], required_atol=20e-2)
verify_module(module_expm1, [(3, 3, 3)], required_atol=20e-2)
verify_module(module_expm1, input_shapes, required_atol=20e-2)


def test_log1p_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_log1p_op(input_shapes):
def module_log1p(a):
return jax.lax.log1p(a)

verify_module(module_log1p, [(3, 3)], required_atol=2e-2)
verify_module(module_log1p, [(3, 3, 3)], required_atol=2e-2)
verify_module(module_log1p, input_shapes, required_atol=2e-2)


def test_sign_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_sign_op(input_shapes):
def module_sign(a):
return jax.lax.sign(a)

verify_module(module_sign, [(3, 3)])
verify_module(module_sign, [(3, 3, 3)])
verify_module(module_sign, input_shapes)


def test_sqrt_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_sqrt_op(input_shapes):
def module_sqrt(a):
return jnp.sqrt(a)

verify_module(module_sqrt, [(3, 3)])
verify_module(module_sqrt, [(3, 3, 3)])
verify_module(module_sqrt, input_shapes)


def test_sub_op():
@pytest.mark.parametrize("input_shapes", [[(3, 3), (3, 3)], [(3, 3, 3), (3, 3, 3)]])
def test_sub_op(input_shapes):
def module_sub(a, b):
return a - b

verify_module(module_sub, [(3, 3), (3, 3)])
verify_module(module_sub, [(3, 3, 3), (3, 3, 3)])
verify_module(module_sub, input_shapes)


def test_transpose_op_2d():
@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
def test_transpose_op(input_shapes):
def module_transpose(a):
return jnp.transpose(a)

verify_module(module_transpose, [(3, 3)])
verify_module(module_transpose, input_shapes)


@pytest.mark.skip(
"Scalars currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73"
)
def test_scalar_type():
def module_scalar_type(a):
return a.shape[0]

verify_module(module_scalar_type, [(3, 3)])


# Transpose op failing for higher ranks/dimensions.
@pytest.mark.skip("Transpose op failing for higher ranks/dimensions.")
def test_transpose_op_3d():
def module_transpose(a):
return jnp.transpose(a)

verify_module(module_transpose, [(3, 3, 3)])


dim0_cases = []
for begin in numpy.arange(10).tolist():
for end in numpy.arange(90, 100).tolist():
Expand All @@ -264,9 +262,8 @@ def module_transpose(a):


@pytest.mark.parametrize(
"begin, end, dim", [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases]
["begin", "end", "dim"], [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases]
)
@pytest.mark.skip("Requires tt-metal uplift.")
def test_slice(begin, end, dim):
def module_slice(a):
if dim == 0:
Expand Down
32 changes: 23 additions & 9 deletions tests/TTIR/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,19 @@


@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, padding",
(
[
"batch_size",
"output_channels",
"input_channels",
"input_height",
"input_width",
"filter_height",
"filter_width",
"stride_h",
"stride_w",
"padding",
],
[
# RESNET
(1, 64, 3, 224, 224, 7, 7, 2, 2, 3),
(1, 256, 64, 56, 56, 1, 1, 1, 1, 0),
Expand All @@ -23,23 +34,26 @@
(1, 128, 128, 28, 28, 3, 3, 1, 1, 1),
(1, 512, 128, 28, 28, 1, 1, 1, 1, 0),
(1, 128, 512, 28, 28, 1, 1, 1, 1, 0),
# (1, 1024, 512, 28, 28, 1, 1, 2, 2, 0), Requires block sharding
(1, 1024, 512, 28, 28, 1, 1, 2, 2, 0),
(1, 256, 512, 28, 28, 1, 1, 2, 2, 0),
(1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
(1, 1024, 256, 14, 14, 1, 1, 1, 1, 0),
(1, 256, 1024, 14, 14, 1, 1, 1, 1, 0),
# (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
# (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
# (1, 512, 512, 7, 7, 3, 3, 1, 1, 1), Requires block sharding
# (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
# (1, 512, 2048, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
(1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0),
(1, 512, 1024, 14, 14, 1, 1, 2, 2, 0),
(1, 512, 512, 7, 7, 3, 3, 1, 1, 1),
(1, 2048, 512, 7, 7, 1, 1, 1, 1, 0),
pytest.param(
*(1, 512, 2048, 7, 7, 1, 1, 1, 1, 0),
marks=pytest.mark.skip(reason="PCC is 0.8828 which is less than 0.95"),
),
# MISCELLANEOUS
(1, 64, 16, 115, 115, 4, 4, 1, 1, 0),
(1, 64, 64, 8, 8, 3, 3, 1, 1, 1),
(1, 64, 64, 16, 16, 3, 3, 1, 1, 1),
(1, 256, 256, 7, 7, 3, 3, 1, 1, 1),
(1, 256, 64, 56, 56, 1, 1, 2, 2, 0),
),
],
)
def test_conv2d(
batch_size,
Expand Down
Loading
Loading