Commit
add tests for sin, cosine, tanh, leakyrelu, layernorm, gelu, clip, cumsum, batchnorm2d, convtranspose2d, where ops (#613)
kamalrajkannan78 authored Nov 5, 2024
1 parent 3a4dc32 commit 25da394
Showing 1 changed file with 250 additions and 2 deletions.
252 changes: 250 additions & 2 deletions forge/test/mlir/test_ops.py
@@ -14,6 +14,254 @@
from forge.tensor import to_forge_tensors, to_pt_tensors


@pytest.mark.parametrize(
    "shape",
    [
        (1, 7, 256),
    ],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
def test_sin(shape):
    class Sin(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.sin(x)

    inputs = [torch.rand(shape)]

    framework_model = Sin()
    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
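

# All tests in this file share the same pattern: run the eager PyTorch module to
# get a golden output, compile the module with forge, then compare the outputs
# tensor-by-tensor against a 0.99 PCC threshold. Hedged sketch (not part of this
# commit): compare_with_golden_pcc presumably computes a Pearson correlation
# coefficient roughly along these lines:
def _pcc_sketch(golden: torch.Tensor, calculated: torch.Tensor) -> float:
    # Flatten, center, and take the normalized dot product of the two tensors.
    g = golden.flatten().double()
    c = calculated.flatten().double()
    g, c = g - g.mean(), c - c.mean()
    return float((g @ c) / (g.norm() * c.norm()))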


@pytest.mark.parametrize(
    "shape",
    [
        (1, 7, 256),
    ],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
def test_cosine(shape):
    class Cosine(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.cos(x)

    inputs = [torch.rand(shape)]

    framework_model = Cosine()
    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])


@pytest.mark.parametrize(
    "shape",
    [
        (1, 768),
    ],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
def test_tanh(shape):
    class Tanh(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.tanh(x)

    inputs = [torch.rand(shape)]

    framework_model = Tanh()
    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])


@pytest.mark.parametrize(
    "shape",
    [
        (1, 32, 512, 512),
    ],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
def test_leakyrelu(shape):
    inputs = [torch.rand(shape)]

    framework_model = nn.LeakyReLU(negative_slope=0.1)
    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
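

# Note (illustrative, not part of this commit): torch.rand samples uniformly
# from [0, 1), so the input above never exercises the negative_slope branch of
# LeakyReLU. Shifting the input would cover negative activations as well:
#
#   inputs = [torch.rand(shape) * 2 - 1]  # uniform in [-1, 1)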


@pytest.mark.parametrize(
    "batch_size, num_channels, height, width",
    [
        (1, 32, 56, 56),
    ],
)
@pytest.mark.xfail(reason="shape mismatch: expected [1], got []")
def test_layernorm(batch_size, num_channels, height, width):
    # nn.LayerNorm((num_channels, height, width)) is not used here because only
    # normalization over the last dimension is currently supported.
    framework_model = nn.LayerNorm(width)

    inputs = [torch.rand(batch_size, num_channels, height, width)]
    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out

    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
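

# Hedged sketch (not part of this commit): PyTorch's own LayerNorm accepts
# multiple trailing dimensions, so the multi-dimension variant mentioned above
# is valid eager code; only the compile path reportedly lacks support for it.
def _layernorm_multi_dim_sketch():
    # Normalize over the last three dimensions (num_channels, height, width).
    ln = nn.LayerNorm((32, 56, 56))
    return ln(torch.rand(1, 32, 56, 56))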


@pytest.mark.parametrize(
    "shape",
    [
        (1, 128, 4096),
    ],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
def test_gelu(shape):
    inputs = [torch.rand(shape)]

    framework_model = nn.GELU()
    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])


@pytest.mark.parametrize(
    "shape, min_val, max_val",
    [
        ((1, 1, 256, 256), 0, 1),
    ],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
def test_clip(shape, min_val, max_val):
    class Clip(nn.Module):
        def __init__(self, min_val, max_val):
            super().__init__()
            self.min_val = min_val
            self.max_val = max_val

        def forward(self, x):
            return torch.clamp(x, self.min_val, self.max_val)

    framework_model = Clip(min_val, max_val)
    inputs = [torch.rand(shape)]

    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])


@pytest.mark.parametrize(
    "shape, dim",
    [
        ((1, 128), 1),
    ],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
def test_cumsum(shape, dim):
    class CumSum(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            return torch.cumsum(x, dim=self.dim)

    framework_model = CumSum(dim)
    inputs = [torch.rand(shape)]

    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])


@pytest.mark.parametrize(
    "condition, input, other",
    [
        (
            [[1, 0], [0, 1]],
            [[1, 2], [3, 4]],
            [[10, 20], [30, 40]],
        ),
    ],
)
@pytest.mark.xfail(reason="Unsupported data format during lowering from TTForge to TTIR: Bfp2_b")
def test_where(condition, input, other):
    class Where(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, condition, input1, input2):
            return torch.where(condition, input1, input2)

    condition = torch.tensor(condition, dtype=torch.bool)
    input = torch.tensor(input)
    other = torch.tensor(other)

    framework_model = Where()

    inputs = [condition, input, other]

    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
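

# Hedged sketch (not part of this commit): the boolean condition tensor is
# presumably what lowers to the unsupported Bfp2_b format. A float-mask select
# avoids boolean inputs entirely and matches torch.where for finite values:
def _where_via_mask_sketch(condition, input1, input2):
    # 1.0 where condition is True, 0.0 where it is False.
    mask = condition.to(torch.float32)
    return mask * input1 + (1.0 - mask) * input2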


@pytest.mark.parametrize(
    "shape", [(1, 1, 256, 256), (1, 1, 1, 128), (1, 1, 1, 384), (1, 1, 32, 32), (1, 1, 6, 6), (1, 1, 29, 29)]
)
@@ -228,7 +476,7 @@ def forward(self, x, y):
        (32, 256, 28, 28),  # pcc = 0.39200606381500713
    ],
)
-def test_multidim_unsqueeze(batch_size, num_channels, height, width):
+def test_batchnorm2d(batch_size, num_channels, height, width):
    framework_model = nn.BatchNorm2d(num_features=num_channels)
@@ -975,7 +1223,7 @@ def forward(self, a):
    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)


-@pytest.mark.xfail(reason="Shape Mismatch")
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
@pytest.mark.parametrize(
    "in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode",
    [
