Commit 8c17b74
Add conv2d and maxpool2d tests. Uplift MLIR to latest main + stablehlo --> TTIR for conv2d and maxpool2d
LPanosTT committed Oct 1, 2024
1 parent 4c1eda3 commit 8c17b74
Showing 5 changed files with 141 additions and 6 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 jaxlib==0.4.31
 jax
+flax
 cmake
 ninja
 clang-format
70 changes: 70 additions & 0 deletions tests/TTIR/test_conv2d.py
@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import jax
import jax.numpy as jnp

from infrastructure import verify_module


@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, padding",
(
# RESNET
(1, 64, 3, 224, 224, 7, 7, 2, 2, 3),
(1, 256, 64, 56, 56, 1, 1, 1, 1, 0),
(1, 64, 64, 56, 56, 1, 1, 1, 1, 0),
(1, 64, 64, 56, 56, 3, 3, 1, 1, 1),
(1, 64, 256, 56, 56, 1, 1, 1, 1, 0),
(1, 512, 256, 56, 56, 1, 1, 2, 2, 0),
(1, 128, 256, 56, 56, 1, 1, 2, 2, 0),
(1, 128, 128, 28, 28, 3, 3, 1, 1, 1),
(1, 512, 128, 28, 28, 1, 1, 1, 1, 0),
(1, 128, 512, 28, 28, 1, 1, 1, 1, 0),
# (1, 1024, 512, 28, 28, 1, 1, 2, 2, 0), Requires block sharding
(1, 256, 512, 28, 28, 1, 1, 2, 2, 0),
(1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
(1, 1024, 256, 14, 14, 1, 1, 1, 1, 0),
(1, 256, 1024, 14, 14, 1, 1, 1, 1, 0),
# (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
# (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
# (1, 512, 512, 7, 7, 3, 3, 1, 1, 1), Requires block sharding
# (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
# (1, 512, 2048, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
# MISCELLANEOUS
(1, 64, 16, 115, 115, 4, 4, 1, 1, 0),
(1, 64, 64, 56, 56, 3, 3, 1, 1, 1),
(1, 128, 128, 28, 28, 3, 3, 1, 1, 1),
(1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
(1, 64, 64, 8, 8, 3, 3, 1, 1, 1),
(1, 64, 64, 16, 16, 3, 3, 1, 1, 1),
(1, 256, 256, 7, 7, 3, 3, 1, 1, 1),
(1, 256, 64, 56, 56, 1, 1, 2, 2, 0),
),
)
def test_conv2d(
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
padding
):
def module_conv(img, weights):
return jax.lax.conv_general_dilated(img, weights, [stride_h, stride_w], [[padding]*2]*2, dimension_numbers=('NHWC', 'OIHW', 'NHWC'))


img_shape = (batch_size, input_height, input_width, input_channels)
weights_shape = (output_channels, input_channels, filter_height, filter_width)

    # Some resnet convolutions seem to require bfloat16; ttnn throws at runtime otherwise.
# On another note, MaxPool2d is also only supported for bfloat16 in ttnn, so we have
# to run resnet in bfloat16 for the time being.
verify_module(module_conv, [img_shape, weights_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
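
For readers unfamiliar with jax.lax.conv_general_dilated: the dimension_numbers tuple used above declares the input as NHWC, the weights as OIHW, and the output as NHWC. Below is a minimal CPU-only sketch, not part of the commit, that checks the resulting output shape for the first RESNET case in the parametrization:

# CPU-only illustration of the dimension_numbers mapping (not part of the commit).
# Shapes follow the first RESNET case: (1, 64, 3, 224, 224, 7, 7, 2, 2, 3).
import jax
import jax.numpy as jnp

img = jnp.zeros((1, 224, 224, 3), dtype=jnp.bfloat16)   # N, H, W, C_in
weights = jnp.zeros((64, 3, 7, 7), dtype=jnp.bfloat16)  # C_out, C_in, kH, kW
out = jax.lax.conv_general_dilated(
    img, weights, window_strides=[2, 2], padding=[[3, 3], [3, 3]],
    dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
assert out.shape == (1, 112, 112, 64)  # floor((224 + 2*3 - 7) / 2) + 1 == 112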
64 changes: 64 additions & 0 deletions tests/TTIR/test_maxpool2d.py
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import jax
import jax.numpy as jnp
import flax

from infrastructure import verify_module


@pytest.mark.parametrize(
"act_shape", ## NHWC
[
(1, 32, 32, 32),
(1, 32, 32, 64),
(1, 32, 32, 128),
(1, 32, 64, 32),
(1, 32, 64, 64),
(1, 32, 64, 128),
(1, 32, 128, 32),
(1, 32, 128, 64),
(1, 32, 128, 128),
(1, 64, 32, 32),
(1, 64, 32, 64),
(1, 64, 32, 128),
(1, 64, 64, 32),
(1, 64, 64, 64),
(1, 64, 64, 128),
(1, 64, 128, 32),
(1, 64, 128, 64),
(1, 64, 128, 128),
(1, 128, 32, 32),
(1, 128, 32, 64),
(1, 128, 32, 128),
(1, 128, 64, 32),
(1, 128, 64, 64),
(1, 128, 64, 128),
(1, 128, 128, 32),
(1, 128, 128, 64),
(1, 128, 128, 128),
],
)
def test_maxpool2d(
act_shape,
):
def module_maxpool(img):
return flax.linen.max_pool(img, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))

verify_module(module_maxpool, [act_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)

def test_resnet_maxpool2d():
    # This maxpool doesn't work on its own because of the reshape that is inserted on its input.
# Issue: https://github.com/tenstorrent/tt-metal/issues/12866
# It works with the conv on top since the output is already flattened.
# In resnet, this is essentially the sequence that occurs. The only difference is that
# there are a few eltwise ops in between.
def module_resnet_maxpool(act, weights):
x = jax.lax.conv_general_dilated(act, weights, [2, 2], ((3, 3), (3, 3)), dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
x = flax.linen.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
return x

verify_module(module_resnet_maxpool, [(1, 224, 224, 3), (64, 3, 7, 7)], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16)
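
To make the shape flow of module_resnet_maxpool concrete, here is a CPU-only sketch, not part of the commit, tracing the ResNet stem: the 7x7, stride-2, pad-3 conv maps (1, 224, 224, 3) to (1, 112, 112, 64), and the 3x3, stride-2, pad-1 max pool then halves the spatial dims to (1, 56, 56, 64):

# CPU-only shape trace of the ResNet stem sequence above (not part of the commit).
import jax
import jax.numpy as jnp
import flax

act = jnp.zeros((1, 224, 224, 3), dtype=jnp.bfloat16)
weights = jnp.zeros((64, 3, 7, 7), dtype=jnp.bfloat16)
x = jax.lax.conv_general_dilated(
    act, weights, [2, 2], ((3, 3), (3, 3)),
    dimension_numbers=('NHWC', 'OIHW', 'NHWC'))
assert x.shape == (1, 112, 112, 64)
x = flax.linen.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
assert x.shape == (1, 56, 56, 64)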
8 changes: 4 additions & 4 deletions tests/infrastructure.py
@@ -5,9 +5,9 @@
 import jax
 import jax.numpy as jnp
 
-def random_input_tensor(shape, key=42, on_device=False):
+def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32):
     def random_input(shape, key):
-        return jax.random.uniform(jax.random.PRNGKey(key), shape=shape)
+        return jax.random.uniform(jax.random.PRNGKey(key), shape=shape, dtype=dtype)
 
     jitted_tensor_creator = jax.jit(random_input, static_argnums=[0,1], backend='cpu')
     tensor = jitted_tensor_creator(shape, key)
@@ -37,9 +37,9 @@ def compare_tensor_to_golden(tensor, golden, required_pcc=0.99, required_atol=1e
 
     return ret
 
-def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2):
+def verify_module(module, input_shapes, key=42, required_pcc=0.99, required_atol=1e-2, dtype=jnp.float32):
     tt_device = jax.devices()[0]
-    cpu_inputs = [random_input_tensor(shape, key) for shape in input_shapes]
+    cpu_inputs = [random_input_tensor(shape, key, dtype=dtype) for shape in input_shapes]
     tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs]
     graph = jax.jit(module)
     res = graph(*tt_inputs)
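
The body of compare_tensor_to_golden is collapsed in this diff, so the exact check is not shown. As a rough sketch of what a PCC/atol comparison of this kind typically computes (the helper name and formulas below are assumptions, not the repository's code):

# Hypothetical sketch of a PCC/atol check; not the repository's implementation.
import jax.numpy as jnp

def pcc_and_max_atol(tensor, golden):
    a = jnp.ravel(tensor).astype(jnp.float32)
    b = jnp.ravel(golden).astype(jnp.float32)
    pcc = jnp.corrcoef(a, b)[0, 1]   # Pearson correlation coefficient
    atol = jnp.max(jnp.abs(a - b))   # worst-case elementwise deviation
    return pcc, atol

Note that the new tests pass required_atol=float("inf"), which effectively disables the absolute-tolerance check and leaves only the PCC threshold of 0.95 in force for the bfloat16 outputs.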
4 changes: 2 additions & 2 deletions third_party/CMakeLists.txt
@@ -3,9 +3,9 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
-set(TT_MLIR_VERSION "5d92bf937bc76b521d39e0fa320b20773905bfc1")
+set(TT_MLIR_VERSION "6dc351c2dc01dbb683617ff53d3cd623f5e87acc")
 set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1")
 
 set(CMAKE_BUILD_TYPE Debug)
 if (TOOLCHAIN STREQUAL "ON")
     cmake_minimum_required(VERSION 3.20)
     project(ttmlir-toolchain LANGUAGES CXX C)
