From 50fd849e3f4529d6e68cb57db78dce5bbb88faca Mon Sep 17 00:00:00 2001
From: Lewis Panos
Date: Thu, 26 Sep 2024 19:51:14 +0000
Subject: [PATCH] Add conv2d, maxpool2d, and reshape tests. Uplift MLIR to
 latest main + stablehlo --> TTIR for conv2d, maxpool2d, and reshape

Skip xfailing tests because runtime failures cause a segfault on device
closure.
---
 requirements.txt                     |  1 +
 tests/TTIR/test_basic_ops.py         |  8 ++--
 tests/TTIR/test_conv2d.py            | 67 ++++++++++++++++++++++++++++
 tests/TTIR/test_maxpool2d.py         | 64 ++++++++++++++++++++++++++
 tests/TTIR/test_mnist.py             |  4 +-
 tests/TTIR/test_reshape.py           | 22 +++++++++
 tests/TTIR/test_simple_regression.py |  4 +-
 tests/infrastructure.py              |  8 ++--
 third_party/CMakeLists.txt           |  2 +-
 9 files changed, 167 insertions(+), 13 deletions(-)
 create mode 100644 tests/TTIR/test_conv2d.py
 create mode 100644 tests/TTIR/test_maxpool2d.py
 create mode 100644 tests/TTIR/test_reshape.py

diff --git a/requirements.txt b/requirements.txt
index 97ceaca..1ad3779 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 jaxlib==0.4.31
 jax
+flax
 cmake
 ninja
 clang-format
diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py
index cf259b7..5c36604 100644
--- a/tests/TTIR/test_basic_ops.py
+++ b/tests/TTIR/test_basic_ops.py
@@ -19,7 +19,7 @@ def module_abs(a):
 
 
 #Broadcasted values are incorrect
-@pytest.mark.xfail
+@pytest.mark.skip("Broadcasted values are incorrect")
 def test_broadcast_op():
     def module_broadcast(a):
         return jnp.broadcast_to(a, (2, 4))
@@ -28,7 +28,7 @@ def module_broadcast(a):
 
 
 #error: 'ttir.constant' op failed to verify that all of {value, result} have same shape
-@pytest.mark.xfail
+@pytest.mark.skip("Index is out of bounds for the rank, should be between 0 and 0 however is 18446744073709551615")
 def test_constant_op():
     def module_constant_zeros(a):
         zeros = jnp.zeros(a.shape)
@@ -105,7 +105,7 @@ def module_negate(a):
 
 
 #Reduce is failing due to error in constant.
-@pytest.mark.xfail
+@pytest.mark.skip("keepdim=False is not supported")
 def test_reduce_op():
     def module_reduce_max(a):
         return jnp.max(a)
@@ -152,7 +152,7 @@ def module_transpose(a):
 
 
 # Transpose op failing for higher ranks/dimensions.
-@pytest.mark.xfail
+@pytest.mark.skip("Transpose op failing for higher ranks/dimensions.")
 def test_transpose_op_3d():
     def module_transpose(a):
         return jnp.transpose(a)
diff --git a/tests/TTIR/test_conv2d.py b/tests/TTIR/test_conv2d.py
new file mode 100644
index 0000000..9bc3c0d
--- /dev/null
+++ b/tests/TTIR/test_conv2d.py
@@ -0,0 +1,67 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import jax
+import jax.numpy as jnp
+
+from infrastructure import verify_module
+
+
+@pytest.mark.parametrize(
+    "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, padding",
+    (
+        # RESNET
+        (1, 64, 3, 224, 224, 7, 7, 2, 2, 3),
+        (1, 256, 64, 56, 56, 1, 1, 1, 1, 0),
+        (1, 64, 64, 56, 56, 1, 1, 1, 1, 0),
+        (1, 64, 64, 56, 56, 3, 3, 1, 1, 1),
+        (1, 64, 256, 56, 56, 1, 1, 1, 1, 0),
+        (1, 512, 256, 56, 56, 1, 1, 2, 2, 0),
+        (1, 128, 256, 56, 56, 1, 1, 2, 2, 0),
+        (1, 128, 128, 28, 28, 3, 3, 1, 1, 1),
+        (1, 512, 128, 28, 28, 1, 1, 1, 1, 0),
+        (1, 128, 512, 28, 28, 1, 1, 1, 1, 0),
+        # (1, 1024, 512, 28, 28, 1, 1, 2, 2, 0), Requires block sharding
+        (1, 256, 512, 28, 28, 1, 1, 2, 2, 0),
+        (1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
+        (1, 1024, 256, 14, 14, 1, 1, 1, 1, 0),
+        (1, 256, 1024, 14, 14, 1, 1, 1, 1, 0),
+        # (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
+        # (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
+        # (1, 512, 512, 7, 7, 3, 3, 1, 1, 1), Requires block sharding
+        # (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
+        # (1, 512, 2048, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
+
+        # MISCELLANEOUS
+        (1, 64, 16, 115, 115, 4, 4, 1, 1, 0),
+        (1, 64, 64, 8, 8, 3, 3, 1, 1, 1),
+        (1, 64, 64, 16, 16, 3, 3, 1, 1, 1),
+        (1, 256, 256, 7, 7, 3, 3, 1, 1, 1),
+        (1, 256, 64, 56, 56, 1, 1, 2, 2, 0),
+    ),
+)
+def test_conv2d(
+    batch_size,
+    output_channels,
+    input_channels,
+    input_height,
+    input_width,
+    filter_height,
+    filter_width,
+    stride_h,
+    stride_w,
+    padding
+):
+    def module_conv(img, weights):
+        return jax.lax.conv_general_dilated(img, weights, [stride_h, stride_w], [[padding]*2]*2, dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
+
+
+    img_shape = (batch_size, input_height, input_width, input_channels)
+    weights_shape = (output_channels, input_channels, filter_height, filter_width)
+
+    # Some resnet convolutions seem to require bfloat16; ttnn throws at runtime otherwise.
+    # On another note, MaxPool2d is also only supported for bfloat16 in ttnn, so we have
+    # to run resnet in bfloat16 for the time being.
+    verify_module(module_conv, [img_shape, weights_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
diff --git a/tests/TTIR/test_maxpool2d.py b/tests/TTIR/test_maxpool2d.py
new file mode 100644
index 0000000..79dfd2e
--- /dev/null
+++ b/tests/TTIR/test_maxpool2d.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import jax
+import jax.numpy as jnp
+import flax
+
+from infrastructure import verify_module
+
+
+@pytest.mark.parametrize(
+    "act_shape",  ## NHWC
+    [
+        (1, 32, 32, 32),
+        (1, 32, 32, 64),
+        (1, 32, 32, 128),
+        (1, 32, 64, 32),
+        (1, 32, 64, 64),
+        (1, 32, 64, 128),
+        (1, 32, 128, 32),
+        (1, 32, 128, 64),
+        (1, 32, 128, 128),
+        (1, 64, 32, 32),
+        (1, 64, 32, 64),
+        (1, 64, 32, 128),
+        (1, 64, 64, 32),
+        (1, 64, 64, 64),
+        (1, 64, 64, 128),
+        (1, 64, 128, 32),
+        (1, 64, 128, 64),
+        (1, 64, 128, 128),
+        (1, 128, 32, 32),
+        (1, 128, 32, 64),
+        (1, 128, 32, 128),
+        (1, 128, 64, 32),
+        (1, 128, 64, 64),
+        (1, 128, 64, 128),
+        (1, 128, 128, 32),
+        (1, 128, 128, 64),
+        (1, 128, 128, 128),
+    ],
+)
+def test_maxpool2d(
+    act_shape,
+):
+    def module_maxpool(img):
+        return flax.linen.max_pool(img, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
+
+    verify_module(module_maxpool, [act_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
+
+def test_resnet_maxpool2d():
+    # This maxpool doesn't work on its own because of the reshape that is inserted on its input.
+    # Issue: https://github.com/tenstorrent/tt-metal/issues/12866
+    # It works with the conv on top since the output is already flattened.
+    # In resnet, this is essentially the sequence that occurs. The only difference is that
+    # there are a few eltwise ops in between.
+    def module_resnet_maxpool(act, weights):
+        x = jax.lax.conv_general_dilated(act, weights, [2, 2], ((3, 3), (3, 3)), dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
+        x = flax.linen.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
+        return x
+
+    verify_module(module_resnet_maxpool, [(1, 224, 224, 3), (64, 3, 7, 7)], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
diff --git a/tests/TTIR/test_mnist.py b/tests/TTIR/test_mnist.py
index 94c076a..031b029 100644
--- a/tests/TTIR/test_mnist.py
+++ b/tests/TTIR/test_mnist.py
@@ -34,14 +34,14 @@ def module_relu(a):
     verify_module(module_relu, [(32, 32)])
 
 
-@pytest.mark.xfail
+@pytest.mark.skip("keepdims=False is not supported in runtime")
 def test_softmax():
     def module_softmax(a):
         return jax.nn.softmax(a)
     verify_module(module_softmax, [(32, 32)])
 
 
-@pytest.mark.xfail
+@pytest.mark.skip("Index is out of bounds for the rank, should be between 0 and 0 however is 18446744073709551615")
 def test_mnist():
     def module_mnist(act, w0, b0, w1, b1, w2, b2):
         x = jnp.matmul(act, w0) + b0
diff --git a/tests/TTIR/test_reshape.py b/tests/TTIR/test_reshape.py
new file mode 100644
index 0000000..cbf14e5
--- /dev/null
+++ b/tests/TTIR/test_reshape.py
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import jax
+import jax.numpy as jnp
+import flax
+
+from infrastructure import verify_module
+@pytest.mark.parametrize("source_and_target_shape",
+                         [((8, 32, 256), (2, 4, 32, 256)),
+                          ((8, 32, 32), (1, 2, 4, 32, 32)),
+                          ((8192, 128), (1, 256, 32, 128))
+                          ],
+                         ids=["1", "2", "3"])
+def test_reshape(source_and_target_shape):
+    act_shape, target_shape = source_and_target_shape
+    def module_reshape(act):
+        return jnp.reshape(act, target_shape)
+
+    verify_module(module_reshape, [act_shape])
diff --git a/tests/TTIR/test_simple_regression.py b/tests/TTIR/test_simple_regression.py
index 12049b3..854b19a 100644
--- a/tests/TTIR/test_simple_regression.py
+++ b/tests/TTIR/test_simple_regression.py
@@ -9,7 +9,7 @@
 from infrastructure import verify_module
 
 
-@pytest.mark.xfail
+@pytest.mark.skip("Module contains function used inside the main function. Cannot compile Flatbuffer.")
Cannot compile Flatbuffer.") def test_gradient(): def simple_gradient(a): def gradient(a): @@ -20,7 +20,7 @@ def gradient(a): verify_module(simple_gradient, [(2, 2)]) -@pytest.mark.xfail +@pytest.mark.skip("TT_METAL_HOME is not set.") def test_simple_regression(): def simple_regression(weights, bias, X, y): def loss(weights, bias, X, y): diff --git a/tests/infrastructure.py b/tests/infrastructure.py index 63dc691..c5a456f 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -5,9 +5,9 @@ import jax import jax.numpy as jnp -def random_input_tensor(shape, key=42, on_device=False): +def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32): def random_input(shape, key): - return jax.random.uniform(jax.random.PRNGKey(key), shape=shape) + return jax.random.uniform(jax.random.PRNGKey(key), shape=shape, dtype=dtype) jitted_tensor_creator = jax.jit(random_input, static_argnums=[0,1], backend='cpu') tensor = jitted_tensor_creator(shape, key) @@ -37,9 +37,9 @@ def compare_tensor_to_golden(tensor, golden, required_pcc=0.99, required_atol=1e return ret -def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2): +def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2, dtype=jnp.float32): tt_device = jax.devices()[0] - cpu_inputs = [random_input_tensor(input_shapes[i], key + i) for i in range(len(input_shapes))] + cpu_inputs = [random_input_tensor(input_shapes[i], key + i, dtype=dtype) for i in range(len(input_shapes))] tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs] graph = jax.jit(module) res = graph(*tt_inputs) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 7a0d7f4..6c2a314 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -set(TT_MLIR_VERSION "20a7ccd485a198fea14861e5a765dd51972e85f3") +set(TT_MLIR_VERSION "8c6494cb1f4fed060073f735b1f88c5da4d187f6") set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1") if (TOOLCHAIN STREQUAL "ON")