diff --git a/requirements.txt b/requirements.txt
index 97ceaca..1ad3779 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 jaxlib==0.4.31
 jax
+flax
 cmake
 ninja
 clang-format
diff --git a/tests/TTIR/test_conv2d.py b/tests/TTIR/test_conv2d.py
new file mode 100644
index 0000000..461009a
--- /dev/null
+++ b/tests/TTIR/test_conv2d.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import jax
+import jax.numpy as jnp
+
+from infrastructure import verify_module
+
+
+@pytest.mark.parametrize(
+    "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, padding",
+    (
+        # RESNET
+        (1, 64, 3, 224, 224, 7, 7, 2, 2, 3),
+        (1, 256, 64, 56, 56, 1, 1, 1, 1, 0),
+        (1, 64, 64, 56, 56, 1, 1, 1, 1, 0),
+        (1, 64, 64, 56, 56, 3, 3, 1, 1, 1),
+        (1, 64, 256, 56, 56, 1, 1, 1, 1, 0),
+        (1, 512, 256, 56, 56, 1, 1, 2, 2, 0),
+        (1, 128, 256, 56, 56, 1, 1, 2, 2, 0),
+        (1, 128, 128, 28, 28, 3, 3, 1, 1, 1),
+        (1, 512, 128, 28, 28, 1, 1, 1, 1, 0),
+        (1, 128, 512, 28, 28, 1, 1, 1, 1, 0),
+        # (1, 1024, 512, 28, 28, 1, 1, 2, 2, 0),  # Requires block sharding
+        (1, 256, 512, 28, 28, 1, 1, 2, 2, 0),
+        (1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
+        (1, 1024, 256, 14, 14, 1, 1, 1, 1, 0),
+        (1, 256, 1024, 14, 14, 1, 1, 1, 1, 0),
+        # (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0),  # Requires block sharding
+        # (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0),  # Requires block sharding
+        # (1, 512, 512, 7, 7, 3, 3, 1, 1, 1),  # Requires block sharding
+        # (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0),  # Requires block sharding
+        # (1, 512, 2048, 7, 7, 1, 1, 1, 1, 0),  # Requires block sharding
+
+        # MISCELLANEOUS
+        (1, 64, 16, 115, 115, 4, 4, 1, 1, 0),
+        (1, 64, 64, 56, 56, 3, 3, 1, 1, 1),
+        (1, 128, 128, 28, 28, 3, 3, 1, 1, 1),
+        (1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
+        (1, 64, 64, 8, 8, 3, 3, 1, 1, 1),
+        (1, 64, 64, 16, 16, 3, 3, 1, 1, 1),
+        (1, 256, 256, 7, 7, 3, 3, 1, 1, 1),
+        (1, 256, 64, 56, 56, 1, 1, 2, 2, 0),
+    ),
+)
+def test_conv2d(
+    batch_size,
+    output_channels,
+    input_channels,
+    input_height,
+    input_width,
+    filter_height,
+    filter_width,
+    stride_h,
+    stride_w,
+    padding
+):
+    def module_conv(img, weights):
+        return jax.lax.conv_general_dilated(img, weights, [stride_h, stride_w], [[padding]*2]*2, dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
+
+
+    img_shape = (batch_size, input_height, input_width, input_channels)
+    weights_shape = (output_channels, input_channels, filter_height, filter_width)
+
+    # Some resnet convolutions seem to require bfloat16; ttnn throws at runtime otherwise.
+    # On another note, MaxPool2d is also only supported for bfloat16 in ttnn, so we have
+    # to run resnet in bfloat16 for the time being.
+    verify_module(module_conv, [img_shape, weights_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
diff --git a/tests/TTIR/test_maxpool2d.py b/tests/TTIR/test_maxpool2d.py
new file mode 100644
index 0000000..3f6499b
--- /dev/null
+++ b/tests/TTIR/test_maxpool2d.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import jax
+import jax.numpy as jnp
+import flax
+
+from infrastructure import verify_module
+
+
+@pytest.mark.parametrize(
+    "act_shape",  ## NHWC
+    [
+        (1, 32, 32, 32),
+        (1, 32, 32, 64),
+        (1, 32, 32, 128),
+        (1, 32, 64, 32),
+        (1, 32, 64, 64),
+        (1, 32, 64, 128),
+        (1, 32, 128, 32),
+        (1, 32, 128, 64),
+        (1, 32, 128, 128),
+        (1, 64, 32, 32),
+        (1, 64, 32, 64),
+        (1, 64, 32, 128),
+        (1, 64, 64, 32),
+        (1, 64, 64, 64),
+        (1, 64, 64, 128),
+        (1, 64, 128, 32),
+        (1, 64, 128, 64),
+        (1, 64, 128, 128),
+        (1, 128, 32, 32),
+        (1, 128, 32, 64),
+        (1, 128, 32, 128),
+        (1, 128, 64, 32),
+        (1, 128, 64, 64),
+        (1, 128, 64, 128),
+        (1, 128, 128, 32),
+        (1, 128, 128, 64),
+        (1, 128, 128, 128),
+    ],
+)
+def test_maxpool2d(
+    act_shape,
+):
+    def module_maxpool(img):
+        return flax.linen.max_pool(img, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
+
+    verify_module(module_maxpool, [act_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
+
+def test_resnet_maxpool2d():
+    # This maxpool doesn't work on its own because of the reshape that is inserted on its input.
+    # Issue: https://github.com/tenstorrent/tt-metal/issues/12866
+    # It works with the conv on top since the output is already flattened.
+    # In resnet, this is essentially the sequence that occurs. The only difference is that
+    # there are a few eltwise ops in between.
+    def module_resnet_maxpool(act, weights):
+        x = jax.lax.conv_general_dilated(act, weights, [2, 2], ((3, 3), (3, 3)), dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
+        x = flax.linen.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
+        return x
+
+    verify_module(module_resnet_maxpool, [(1, 224, 224, 3), (64, 3, 7, 7)], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
\ No newline at end of file
diff --git a/tests/TTIR/test_reshape.py b/tests/TTIR/test_reshape.py
new file mode 100644
index 0000000..1fe45cc
--- /dev/null
+++ b/tests/TTIR/test_reshape.py
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import jax
+import jax.numpy as jnp
+import flax
+
+from infrastructure import verify_module
+@pytest.mark.parametrize("source_and_target_shape",
+                         [((8, 32, 256), (2, 4, 32, 256)),
+                          ((8, 32, 32), (1, 2, 4, 32, 32)),
+                          ((8192, 128), (1, 256, 32, 128))
+                         ],
+                         ids=["1", "2", "3"])
+def test_reshape(source_and_target_shape):
+    act_shape, target_shape = source_and_target_shape
+    def module_reshape(act):
+        return jnp.reshape(act, target_shape)
+
+    verify_module(module_reshape, [act_shape])
\ No newline at end of file
diff --git a/tests/infrastructure.py b/tests/infrastructure.py
index 63dc691..c5a456f 100644
--- a/tests/infrastructure.py
+++ b/tests/infrastructure.py
@@ -5,9 +5,9 @@
 import jax
 import jax.numpy as jnp
 
-def random_input_tensor(shape, key=42, on_device=False):
+def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32):
     def random_input(shape, key):
-        return jax.random.uniform(jax.random.PRNGKey(key), shape=shape)
+        return jax.random.uniform(jax.random.PRNGKey(key), shape=shape, dtype=dtype)
 
     jitted_tensor_creator = jax.jit(random_input, static_argnums=[0,1], backend='cpu')
     tensor = jitted_tensor_creator(shape, key)
@@ -37,9 +37,9 @@ def compare_tensor_to_golden(tensor, golden, required_pcc=0.99, required_atol=1e
 
     return ret
 
-def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2):
+def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2, dtype=jnp.float32):
     tt_device = jax.devices()[0]
-    cpu_inputs = [random_input_tensor(input_shapes[i], key + i) for i in range(len(input_shapes))]
+    cpu_inputs = [random_input_tensor(input_shapes[i], key + i, dtype=dtype) for i in range(len(input_shapes))]
     tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs]
     graph = jax.jit(module)
     res = graph(*tt_inputs)
diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt
index 95c1870..bea7372 100644
--- a/third_party/CMakeLists.txt
+++ b/third_party/CMakeLists.txt
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
-set(TT_MLIR_VERSION "20a7ccd485a198fea14861e5a765dd51972e85f3")
+set(TT_MLIR_VERSION "3477f54b07fccd7f4c4f5828b97cd2c4cf581d07")
 set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1")
 
 if (TOOLCHAIN STREQUAL "ON")
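
For context, the conv-then-pool sequence that test_resnet_maxpool2d exercises can be reproduced standalone. The following is a minimal sketch, assuming only stock JAX and flax on CPU (no TT device and no verify_module), with inputs cast to jnp.bfloat16 as the tests do:

    import jax
    import jax.numpy as jnp
    import flax.linen

    # ResNet stem: 7x7 stride-2 conv (NHWC activations, OIHW weights),
    # then a 3x3 stride-2 max-pool, both in bfloat16 as the tests require.
    k1, k2 = jax.random.split(jax.random.PRNGKey(42))
    act = jax.random.uniform(k1, (1, 224, 224, 3), dtype=jnp.bfloat16)
    weights = jax.random.uniform(k2, (64, 3, 7, 7), dtype=jnp.bfloat16)

    x = jax.lax.conv_general_dilated(
        act,
        weights,
        window_strides=[2, 2],
        padding=((3, 3), (3, 3)),
        dimension_numbers=("NHWC", "OIHW", "NHWC"),
    )
    x = flax.linen.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
    print(x.shape)  # (1, 56, 56, 64): conv 224 -> 112, max-pool 112 -> 56

The shapes mirror the (1, 224, 224, 3) activation and (64, 3, 7, 7) weight tensors passed to verify_module above; the max-pool output is the flattened-conv-fed shape that sidesteps the reshape issue noted in the test comment.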