From 25da394e6643dd8aa5485c0863aac64867ede8fa Mon Sep 17 00:00:00 2001
From: Kamalraj Kannan <157608228+kamalrajkannan78@users.noreply.github.com>
Date: Tue, 5 Nov 2024 14:17:27 +0530
Subject: [PATCH] add tests for sin, cosine, tanh, leakyrelu, layernorm, gelu,
 clip, cumsum, batchnorm2d, convtranspose2d, where ops (#613)

---
 forge/test/mlir/test_ops.py | 252 +++++++++++++++++++++++++++++++++++-
 1 file changed, 250 insertions(+), 2 deletions(-)

diff --git a/forge/test/mlir/test_ops.py b/forge/test/mlir/test_ops.py
index ba14960f5..e035fe5ce 100644
--- a/forge/test/mlir/test_ops.py
+++ b/forge/test/mlir/test_ops.py
@@ -14,6 +14,254 @@ from forge.tensor import to_forge_tensors, to_pt_tensors
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (1, 7, 256),
+    ],
+)
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
+def test_sin(shape):
+    class sin(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.sin(x)
+
+    inputs = [torch.rand(shape)]
+
+    framework_model = sin()
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (1, 7, 256),
+    ],
+)
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
+def test_cosine(shape):
+    class cosine(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.cos(x)
+
+    inputs = [torch.rand(shape)]
+
+    framework_model = cosine()
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (1, 768),
+    ],
+)
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
+def test_tanh(shape):
+    class tanh(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.tanh(x)
+
+    inputs = [torch.rand(shape)]
+
+    framework_model = tanh()
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (1, 32, 512, 512),
+    ],
+)
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
+def test_leakyrelu(shape):
+
+    inputs = [torch.rand(shape)]
+
+    framework_model = nn.LeakyReLU(negative_slope=0.1)
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
+@pytest.mark.parametrize(
+    "batch_size, num_channels, height, width",
+    [
+        (1, 32, 56, 56),
+    ],
+)
+@pytest.mark.xfail(reason="shape mismatch: expected [1], got []")
+def test_layernorm(batch_size, num_channels, height, width):
+
+    # framework_model = nn.LayerNorm((num_channels, height, width))  # Normalization is supported only over the last dimension
+    framework_model = nn.LayerNorm((width))
+
+    inputs = [torch.rand(batch_size, num_channels, height, width)]
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (1, 128, 4096),
+    ],
+)
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
+def test_gelu(shape):
+
+    inputs = [torch.rand(shape)]
+
+    framework_model = nn.GELU()
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
+@pytest.mark.parametrize(
+    "shape, min_val, max_val",
+    [
+        ((1, 1, 256, 256), 0, 1),
+    ],
+)
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
+def test_clip(shape, min_val, max_val):
+    class Clip(nn.Module):
+        def __init__(self, min_val, max_val):
+            super().__init__()
+            self.min_val = min_val
+            self.max_val = max_val
+
+        def forward(self, x):
+            return torch.clamp(x, self.min_val, self.max_val)
+
+    framework_model = Clip(min_val, max_val)
+    inputs = [torch.rand(shape)]
+
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
+@pytest.mark.parametrize(
+    "shape, dim",
+    [
+        ((1, 128), 1),
+    ],
+)
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
+def test_cumsum(shape, dim):
+    class cumsum(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, x):
+            return torch.cumsum(x, dim=self.dim)
+
+    framework_model = cumsum(dim)
+    inputs = [torch.rand(shape)]
+
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
+@pytest.mark.parametrize(
+    "condition, input, other",
+    [
+        (
+            [[1, 0], [0, 1]],
+            [[1, 2], [3, 4]],
+            [[10, 20], [30, 40]],
+        ),
+    ],
+)
+@pytest.mark.xfail(reason="Unsupported data format during lowering from TTForge to TTIR: Bfp2_b")
+def test_where(condition, input, other):
+    class Where(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, condition, input1, input2):
+            return torch.where(condition, input1, input2)
+
+    condition = torch.tensor(condition, dtype=torch.bool)
+    input = torch.tensor(input)
+    other = torch.tensor(other)
+
+    framework_model = Where()
+
+    inputs = [condition, input, other]
+
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
 @pytest.mark.parametrize(
     "shape", [(1, 1, 256, 256), (1, 1, 1, 128), (1, 1, 1, 384), (1, 1, 32, 32), (1, 1, 6, 6), (1, 1, 29, 29)]
 )
@@ -228,7 +476,7 @@ def forward(self, x, y):
         (32, 256, 28, 28),  # pcc = 0.39200606381500713
     ],
 )
-def test_multidim_unsqueeze(batch_size, num_channels, height, width):
+def test_batchnorm2d(batch_size, num_channels, height, width):
 
     framework_model = nn.BatchNorm2d(num_features=num_channels)
 
@@ -975,7 +1223,7 @@ def forward(self, a):
     assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
 
 
-@pytest.mark.xfail(reason="Shape Mismatch")
+@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
 @pytest.mark.parametrize(
     "in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode",
     [
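
Every test added above repeats the same compile-and-compare sequence. For reference, here is a minimal standalone sketch of that shared verification flow, factored into a hypothetical helper; the helper name run_and_verify and the import path for compare_with_golden_pcc are assumptions for illustration, not part of this patch.

import torch
import forge
from forge.op.eval.common import compare_with_golden_pcc  # import path assumed, not verified

def run_and_verify(framework_model, inputs, pcc=0.99):
    # Run the PyTorch module directly to produce golden outputs.
    fw_out = framework_model(*inputs)

    # Compile the same module through Forge and execute the compiled graph.
    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    # Move device outputs back to host and normalize both sides to lists of tensors.
    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out

    # Each compiled output must correlate with its golden counterpart above the PCC threshold.
    assert all(compare_with_golden_pcc(golden=fo, calculated=co, pcc=pcc) for fo, co in zip(fw_out, co_out))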