-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Bug] Solves issue when certain inputs/constants aren't properly decl…
…ared during MLIR emit Previously, MLIR emit was hiting edge cases when declaring constant inputs. More precisely, they were mostly skipped. This fix redefines how inputs are recognized (using kInput node type), and properly distinguish regular and constant inputs vs model parameters. Issue uncovered during #112 op bringup (reciprocal). At the same time, PR related to #112 is testing this case. Additionally, inference and training MNIST are also covering this feature for functionality. Fixes #201
- Loading branch information
1 parent
831b1b5
commit af4a5b7
Showing
3 changed files
with
64 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import pytest | ||
|
||
import pytest | ||
import torch | ||
from torch import nn | ||
|
||
import pybuda | ||
from pybuda.op.eval.common import compare_with_golden_pcc | ||
|
||
def test_multiple_inputs(): | ||
class MultipleInputs(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b, c): | ||
return a + b + c | ||
|
||
inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32), torch.rand(1, 32, 32)] | ||
|
||
framework_model = MultipleInputs() | ||
fw_out = framework_model(*inputs) | ||
|
||
compiled_model = pybuda.compile(framework_model, sample_inputs=inputs) | ||
co_out = compiled_model(*inputs) | ||
|
||
co_out = [co.to("cpu") for co in co_out] | ||
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters