Skip to content

Commit

Permalink
Added conv1 test
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Dec 10, 2024
1 parent 32418e6 commit f27867c
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/TTIR/test_conv2d.py → tests/TTIR/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,41 @@ def module_conv(img, weights):
required_atol=float("inf"),
dtype=jnp.bfloat16,
)


@pytest.mark.parametrize(
[
"img_shape",
"weights_shape",
],
[
((1, 256, 512), (1024, 256, 1)),
((1, 256, 256), (512, 256, 1)),
((1, 512, 256), (512, 512, 1)),
((1, 512, 512), (1024, 512, 1)),
]
)
def test_conv1d(
img_shape,
weights_shape
):
def module_conv(img, weights):
return jax.lax.conv_general_dilated(
lhs=img,
rhs=weights,
window_strides=(1,),
padding=[(0, 0)],
lhs_dilation=None,
rhs_dilation=(1,),
dimension_numbers=("NCW", "OIW", "NCW"),
feature_group_count=1,
batch_group_count=1
)

verify_module(
module_conv,
[img_shape, weights_shape],
required_pcc=0.95,
required_atol=float("inf"),
dtype=jnp.bfloat16,
)

0 comments on commit f27867c

Please sign in to comment.