-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add conv2d, maxpool2d, and reshape tests. Uplift MLIR to latest main …
…+ stablehlo --> TTIR for conv2d, maxpool2d, and reshape Skip xfailing tests because runtime failures causing segfault on device closuer
- Loading branch information
Showing
9 changed files
with
167 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from infrastructure import verify_module | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, padding", | ||
( | ||
# RESNET | ||
(1, 64, 3, 224, 224, 7, 7, 2, 2, 3), | ||
(1, 256, 64, 56, 56, 1, 1, 1, 1, 0), | ||
(1, 64, 64, 56, 56, 1, 1, 1, 1, 0), | ||
(1, 64, 64, 56, 56, 3, 3, 1, 1, 1), | ||
(1, 64, 256, 56, 56, 1, 1, 1, 1, 0), | ||
(1, 512, 256, 56, 56, 1, 1, 2, 2, 0), | ||
(1, 128, 256, 56, 56, 1, 1, 2, 2, 0), | ||
(1, 128, 128, 28, 28, 3, 3, 1, 1, 1), | ||
(1, 512, 128, 28, 28, 1, 1, 1, 1, 0), | ||
(1, 128, 512, 28, 28, 1, 1, 1, 1, 0), | ||
# (1, 1024, 512, 28, 28, 1, 1, 2, 2, 0), Requires block sharding | ||
(1, 256, 512, 28, 28, 1, 1, 2, 2, 0), | ||
(1, 256, 256, 14, 14, 3, 3, 1, 1, 1), | ||
(1, 1024, 256, 14, 14, 1, 1, 1, 1, 0), | ||
(1, 256, 1024, 14, 14, 1, 1, 1, 1, 0), | ||
# (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding | ||
# (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding | ||
# (1, 512, 512, 7, 7, 3, 3, 1, 1, 1), Requires block sharding | ||
# (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0), Requires block sharding | ||
# (1, 512, 2048, 7, 7, 1, 1, 1, 1, 0), Requires block sharding | ||
# MISCELLANEOUS | ||
(1, 64, 16, 115, 115, 4, 4, 1, 1, 0), | ||
(1, 64, 64, 8, 8, 3, 3, 1, 1, 1), | ||
(1, 64, 64, 16, 16, 3, 3, 1, 1, 1), | ||
(1, 256, 256, 7, 7, 3, 3, 1, 1, 1), | ||
(1, 256, 64, 56, 56, 1, 1, 2, 2, 0), | ||
), | ||
) | ||
def test_conv2d( | ||
batch_size, | ||
output_channels, | ||
input_channels, | ||
input_height, | ||
input_width, | ||
filter_height, | ||
filter_width, | ||
stride_h, | ||
stride_w, | ||
padding | ||
): | ||
def module_conv(img, weights): | ||
return jax.lax.conv_general_dilated(img, weights, [stride_h, stride_w], [[padding]*2]*2, dimension_numbers=('NHWC', 'OIHW', 'NHWC')) | ||
|
||
|
||
img_shape = (batch_size, input_height, input_width, input_channels) | ||
weights_shape = (output_channels, input_channels, filter_height, filter_width) | ||
|
||
# Some resnet convolutions seem to require bfloat16, ttnn throws in runtime otherwise. | ||
# On another note, MaxPool2d is also only supported for bfloat16 in ttnn, so we have | ||
# to run resnet in bfloat16 for the time being. | ||
verify_module(module_conv, [img_shape, weights_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import jax | ||
import jax.numpy as jnp | ||
import flax | ||
|
||
from infrastructure import verify_module | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"act_shape", ## NHWC | ||
[ | ||
(1, 32, 32, 32), | ||
(1, 32, 32, 64), | ||
(1, 32, 32, 128), | ||
(1, 32, 64, 32), | ||
(1, 32, 64, 64), | ||
(1, 32, 64, 128), | ||
(1, 32, 128, 32), | ||
(1, 32, 128, 64), | ||
(1, 32, 128, 128), | ||
(1, 64, 32, 32), | ||
(1, 64, 32, 64), | ||
(1, 64, 32, 128), | ||
(1, 64, 64, 32), | ||
(1, 64, 64, 64), | ||
(1, 64, 64, 128), | ||
(1, 64, 128, 32), | ||
(1, 64, 128, 64), | ||
(1, 64, 128, 128), | ||
(1, 128, 32, 32), | ||
(1, 128, 32, 64), | ||
(1, 128, 32, 128), | ||
(1, 128, 64, 32), | ||
(1, 128, 64, 64), | ||
(1, 128, 64, 128), | ||
(1, 128, 128, 32), | ||
(1, 128, 128, 64), | ||
(1, 128, 128, 128), | ||
], | ||
) | ||
def test_maxpool2d( | ||
act_shape, | ||
): | ||
def module_maxpool(img): | ||
return flax.linen.max_pool(img, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0))) | ||
|
||
verify_module(module_maxpool, [act_shape], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16) | ||
|
||
def test_resnet_maxpool2d(): | ||
# This maxpool doesnt work on its own because of the reshape that is inserted on its input | ||
# Issue: https://github.com/tenstorrent/tt-metal/issues/12866 | ||
# It works with the conv on top since the output is already flattened. | ||
# In resnet, this is essentially the sequence that occurs. The only difference is that | ||
# there are a few eltwise ops in between. | ||
def module_resnet_maxpool(act, weights): | ||
x = jax.lax.conv_general_dilated(act, weights, [2, 2], ((3, 3), (3, 3)), dimension_numbers=('NHWC', 'OIHW', 'NHWC')) | ||
x = flax.linen.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1))) | ||
return x | ||
|
||
verify_module(module_resnet_maxpool, [(1, 224, 224, 3), (64, 3, 7, 7)], required_pcc=0.95, required_atol=float("inf"), dtype=jnp.bfloat16) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import jax | ||
import jax.numpy as jnp | ||
import flax | ||
|
||
from infrastructure import verify_module | ||
@pytest.mark.parametrize("source_and_target_shape", | ||
[((8, 32, 256), (2, 4, 32, 256)), | ||
((8, 32, 32), (1, 2, 4, 32, 32)), | ||
((8192, 128), (1, 256, 32, 128)) | ||
], | ||
ids=["1", "2", "3"]) | ||
def test_reshape(source_and_target_shape): | ||
act_shape, target_shape = source_and_target_shape | ||
def module_reshape(act): | ||
return jnp.reshape(act, target_shape) | ||
|
||
verify_module(module_reshape, [act_shape]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters