From 8c17b7416de756dff7ab5c9cccca7c125cb29f79 Mon Sep 17 00:00:00 2001
From: Lewis Panos
Date: Thu, 26 Sep 2024 19:51:14 +0000
Subject: [PATCH] Add conv2d and maxpool2d tests. Uplift MLIR to latest main +
 stablehlo --> TTIR for conv2d and maxpool2d

---
 requirements.txt             |  1 +
 tests/TTIR/test_conv2d.py    | 72 +++++++++++++++++++++++++++++++++++++
 tests/TTIR/test_maxpool2d.py | 66 ++++++++++++++++++++++++++++++++++
 tests/infrastructure.py      |  8 ++++----
 third_party/CMakeLists.txt   |  4 ++--
 5 files changed, 145 insertions(+), 6 deletions(-)
 create mode 100644 tests/TTIR/test_conv2d.py
 create mode 100644 tests/TTIR/test_maxpool2d.py

diff --git a/requirements.txt b/requirements.txt
index 97ceaca..1ad3779 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 jaxlib==0.4.31
 jax
+flax
 cmake
 ninja
 clang-format
diff --git a/tests/TTIR/test_conv2d.py b/tests/TTIR/test_conv2d.py
new file mode 100644
index 0000000..461009a
--- /dev/null
+++ b/tests/TTIR/test_conv2d.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import jax
+import jax.numpy as jnp
+
+from infrastructure import verify_module
+
+
+@pytest.mark.parametrize(
+    "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, padding",
+    (
+        # RESNET
+        (1, 64, 3, 224, 224, 7, 7, 2, 2, 3),
+        (1, 256, 64, 56, 56, 1, 1, 1, 1, 0),
+        (1, 64, 64, 56, 56, 1, 1, 1, 1, 0),
+        (1, 64, 64, 56, 56, 3, 3, 1, 1, 1),
+        (1, 64, 256, 56, 56, 1, 1, 1, 1, 0),
+        (1, 512, 256, 56, 56, 1, 1, 2, 2, 0),
+        (1, 128, 256, 56, 56, 1, 1, 2, 2, 0),
+        (1, 128, 128, 28, 28, 3, 3, 1, 1, 1),
+        (1, 512, 128, 28, 28, 1, 1, 1, 1, 0),
+        (1, 128, 512, 28, 28, 1, 1, 1, 1, 0),
+        # (1, 1024, 512, 28, 28, 1, 1, 2, 2, 0), Requires block sharding
+        (1, 256, 512, 28, 28, 1, 1, 2, 2, 0),
+        (1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
+        (1, 1024, 256, 14, 14, 1, 1, 1, 1, 0),
+        (1, 256, 1024, 14, 14, 1, 1, 1, 1, 0),
+        # (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
+        # (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
+        # (1, 512, 512, 7, 7, 3, 3, 1, 1, 1), Requires block sharding
+        # (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
+        # (1, 512, 2048, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
+
+        # MISCELLANEOUS
+        (1, 64, 16, 115, 115, 4, 4, 1, 1, 0),
+        (1, 64, 64, 56, 56, 3, 3, 1, 1, 1),
+        (1, 128, 128, 28, 28, 3, 3, 1, 1, 1),
+        (1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
+        (1, 64, 64, 8, 8, 3, 3, 1, 1, 1),
+        (1, 64, 64, 16, 16, 3, 3, 1, 1, 1),
+        (1, 256, 256, 7, 7, 3, 3, 1, 1, 1),
+        (1, 256, 64, 56, 56, 1, 1, 2, 2, 0),
+    ),
+)
+def test_conv2d(
+    batch_size,
+    output_channels,
+    input_channels,
+    input_height,
+    input_width,
+    filter_height,
+    filter_width,
+    stride_h,
+    stride_w,
+    padding
+):
+    def module_conv(img, weights):
+        return jax.lax.conv_general_dilated(img, weights, [stride_h, stride_w], [[padding]*2]*2, dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
+
+
+    img_shape = (batch_size, input_height, input_width, input_channels)
+    weights_shape = (output_channels, input_channels, filter_height, filter_width)
+
+    # Some ResNet convolutions seem to require bfloat16; ttnn throws a runtime error otherwise.
+    # On another note, MaxPool2d is also only supported for bfloat16 in ttnn, so we have
+    # to run ResNet in bfloat16 for the time being.
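+    # required_atol=float("inf") disables the absolute-tolerance check in
+    # verify_module, so correctness here is judged on the PCC threshold alone.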
+    verify_module(module_conv, [img_shape, weights_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
diff --git a/tests/TTIR/test_maxpool2d.py b/tests/TTIR/test_maxpool2d.py
new file mode 100644
index 0000000..3f6499b
--- /dev/null
+++ b/tests/TTIR/test_maxpool2d.py
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import jax
+import jax.numpy as jnp
+import flax
+
+from infrastructure import verify_module
+
+
+@pytest.mark.parametrize(
+    "act_shape",  ## NHWC
+    [
+        (1, 32, 32, 32),
+        (1, 32, 32, 64),
+        (1, 32, 32, 128),
+        (1, 32, 64, 32),
+        (1, 32, 64, 64),
+        (1, 32, 64, 128),
+        (1, 32, 128, 32),
+        (1, 32, 128, 64),
+        (1, 32, 128, 128),
+        (1, 64, 32, 32),
+        (1, 64, 32, 64),
+        (1, 64, 32, 128),
+        (1, 64, 64, 32),
+        (1, 64, 64, 64),
+        (1, 64, 64, 128),
+        (1, 64, 128, 32),
+        (1, 64, 128, 64),
+        (1, 64, 128, 128),
+        (1, 128, 32, 32),
+        (1, 128, 32, 64),
+        (1, 128, 32, 128),
+        (1, 128, 64, 32),
+        (1, 128, 64, 64),
+        (1, 128, 64, 128),
+        (1, 128, 128, 32),
+        (1, 128, 128, 64),
+        (1, 128, 128, 128),
+    ],
+)
+def test_maxpool2d(
+    act_shape,
+):
+    def module_maxpool(img):
+        return flax.linen.max_pool(img, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
+
+    verify_module(module_maxpool, [act_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
+
+def test_resnet_maxpool2d():
+    # This maxpool doesn't work on its own because of the reshape that is inserted on its input.
+    # Issue: https://github.com/tenstorrent/tt-metal/issues/12866
+    # It works with the conv on top since the output is already flattened.
+    # In ResNet, this is essentially the sequence that occurs; the only difference is that
+    # there are a few eltwise ops in between.
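+    # The sequence below mirrors the ResNet stem: a stride-2 7x7 conv taking the
+    # 224x224x3 NHWC input to 64 channels, followed by a stride-2 3x3 maxpool.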
+    def module_resnet_maxpool(act, weights):
+        x = jax.lax.conv_general_dilated(act, weights, [2, 2], ((3, 3), (3, 3)), dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
+        x = flax.linen.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
+        return x
+
+    verify_module(module_resnet_maxpool, [(1, 224, 224, 3), (64, 3, 7, 7)], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
\ No newline at end of file
diff --git a/tests/infrastructure.py b/tests/infrastructure.py
index 9ae16a8..9666c1c 100644
--- a/tests/infrastructure.py
+++ b/tests/infrastructure.py
@@ -5,9 +5,9 @@
 import jax
 import jax.numpy as jnp
 
-def random_input_tensor(shape, key=42, on_device=False):
+def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32):
     def random_input(shape, key):
-        return jax.random.uniform(jax.random.PRNGKey(key), shape=shape)
+        return jax.random.uniform(jax.random.PRNGKey(key), shape=shape, dtype=dtype)
 
     jitted_tensor_creator = jax.jit(random_input, static_argnums=[0,1], backend='cpu')
     tensor = jitted_tensor_creator(shape, key)
@@ -37,9 +37,9 @@ def compare_tensor_to_golden(tensor, golden, required_pcc=0.99, required_atol=1e-2):
     return ret
 
 
-def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2):
+def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2, dtype=jnp.float32):
     tt_device = jax.devices()[0]
-    cpu_inputs = [random_input_tensor(shape, key) for shape in input_shapes]
+    cpu_inputs = [random_input_tensor(shape, key, dtype=dtype) for shape in input_shapes]
     tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs]
     graph = jax.jit(module)
     res = graph(*tt_inputs)
diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt
index e3ed07b..7fbfa0d 100644
--- a/third_party/CMakeLists.txt
+++ b/third_party/CMakeLists.txt
@@ -3,9 +3,9 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
-set(TT_MLIR_VERSION "5d92bf937bc76b521d39e0fa320b20773905bfc1")
+set(TT_MLIR_VERSION "6dc351c2dc01dbb683617ff53d3cd623f5e87acc")
 set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1")
-
+set(CMAKE_BUILD_TYPE Debug)
 if (TOOLCHAIN STREQUAL "ON")
   cmake_minimum_required(VERSION 3.20)
   project(ttmlir-toolchain LANGUAGES CXX C)
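
--
Note (illustrative, not part of the patch): the conv + maxpool sequence exercised
by test_resnet_maxpool2d can be sanity-checked on CPU with stock JAX and flax,
independent of the tt backend. A minimal sketch, reusing only the APIs the tests
above already call; the printed shape is a worked example, not output from the
patch:

    import jax
    import jax.numpy as jnp
    import flax

    key = jax.random.PRNGKey(42)
    act = jax.random.uniform(key, (1, 224, 224, 3), dtype=jnp.bfloat16)    # NHWC activations
    weights = jax.random.uniform(key, (64, 3, 7, 7), dtype=jnp.bfloat16)   # OIHW weights

    # Stride-2 7x7 convolution with padding 3, then stride-2 3x3 maxpool with
    # padding 1, i.e. the same sequence as module_resnet_maxpool above.
    x = jax.lax.conv_general_dilated(act, weights, [2, 2], ((3, 3), (3, 3)),
                                     dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
    x = flax.linen.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
    print(x.shape)  # (1, 56, 56, 64): 224 -> 112 after the conv, 112 -> 56 after the pool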