-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix issue when certain inputs/constants aren't properly declared duri…
…ng MLIR emit Previously, MLIR emit was hiting edge cases when declaring constant inputs. More precisely, they were mostly skipped. This fix redefines how inputs are recognized (using kInput node type), and properly distinguish regular and constant inputs vs model parameters. Issue uncovered during #112 op bringup (reciprocal). At the same time, PR related to #112 is testing this case. Additionally, inference and training MNIST are also covering this feature for functionality. Additionally, this change includes: - Shape recalculation before lowering to MLIR; just to be certain that all shapes are correctly matched - Additional logs through MLIR emit logic - Uplifted MLIR version to the latest Fixes #201
- Loading branch information
1 parent
fc211a7
commit 446f03e
Showing
8 changed files
with
117 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import pytest | ||
|
||
import pytest | ||
import torch | ||
from torch import nn | ||
|
||
import forge | ||
from forge.op.eval.common import compare_with_golden_pcc | ||
|
||
def test_multiple_inputs(): | ||
class MultipleInputs(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b, c): | ||
return a + b + c | ||
|
||
inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32), torch.rand(1, 32, 32)] | ||
|
||
framework_model = MultipleInputs() | ||
fw_out = framework_model(*inputs) | ||
|
||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
co_out = compiled_model(*inputs) | ||
|
||
co_out = [co.to("cpu") for co in co_out] | ||
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] | ||
|
||
|
||
@pytest.mark.parametrize("a_shape, b_shape, c_shape", [ | ||
((1, 1, 32, 64), (1, 1, 64, 128), (1, 1, 128, 32)), | ||
]) | ||
def test_input_order(a_shape, b_shape, c_shape): | ||
class InputOrderWithConstants(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.const1 = torch.rand(1, 1, 32, 32) | ||
self.const2 = torch.rand(1, 1, 32, 32) | ||
|
||
def forward(self, a, b, c): | ||
x = torch.matmul(a, b) | ||
x = torch.matmul(x, c) | ||
x = x + self.const1 | ||
x = x * self.const2 | ||
return x | ||
|
||
a = torch.rand(*a_shape) | ||
b = torch.rand(*b_shape) | ||
c = torch.rand(*c_shape) | ||
|
||
framework_model = InputOrderWithConstants() | ||
fw_out = framework_model(a, b, c) | ||
|
||
compiled_model = forge.compile(framework_model, sample_inputs=[a, b, c]) | ||
co_out = compiled_model(a, b, c) | ||
|
||
assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0][0], pcc=0.99) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters