diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py index 6ec79c7..a282308 100644 --- a/tests/TTIR/test_basic_ops.py +++ b/tests/TTIR/test_basic_ops.py @@ -233,11 +233,12 @@ def module_transpose(a): verify_module(module_transpose, input_shapes) -def test_scalar_type(): +@pytest.mark.parametrize("input_shapes", [[(3, 3)]]) +def test_scalar_type(input_shapes): def module_scalar_type(a): return a.shape[0] - verify_module(module_scalar_type, [(3, 3)]) + verify_module(module_scalar_type, input_shapes) dim0_cases = []