Commit
Fix issue when certain inputs/constants aren't properly declared during MLIR emit

Previously, MLIR emit was hitting edge cases when declaring constant inputs; more precisely,
they were mostly skipped. This fix redefines how inputs are recognized (using the kInput node type)
and properly distinguishes regular and constant inputs from model parameters.

The issue was uncovered during the #112 op bringup (reciprocal), and the PR related to #112
tests this case. Additionally, the MNIST inference and training tests also cover this feature.

Additionally, this change includes:
- Shape recalculation before lowering to MLIR, to make sure all shapes are correctly matched
- Additional logging throughout the MLIR emit logic
- Uplift of the MLIR version to the latest

Fixes #201
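
For illustration, a minimal sketch of the calling convention this fix establishes (the helper name below is hypothetical; the two getter methods come from compiled_graph_state.py in this commit): at run time, activations are passed first, then constant inputs, then parameters, matching the order in which the MLIR generator declares function arguments.

def assemble_runtime_arguments(inputs, compiled_graph_state):
    # Hypothetical helper mirroring CompiledModel.__call__ after this change:
    # constants are no longer skipped; they are passed right after the activations
    # and before the parameters, in the order the generated function expects.
    return [
        *inputs,
        *compiled_graph_state.get_ordered_constant_tensors(),
        *compiled_graph_state.get_ordered_parameter_tensors(),
    ]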
nvukobratTT committed Sep 4, 2024
1 parent b81766f commit 54f1c16
Showing 8 changed files with 147 additions and 18 deletions.
3 changes: 3 additions & 0 deletions forge/csrc/buda_passes.cpp
@@ -230,6 +230,9 @@ graphlib::Graph* run_pre_lowering_passes(
// Apply user overrides
passes::configure_output_data_formats(graph, default_df_override);

// Recalculate shapes before lowering to MLIR
recalculate_shapes(graph);

return graph;
}

34 changes: 31 additions & 3 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -132,6 +132,8 @@ class MLIRGenerator
throw std::runtime_error("Variable " + node->name() + " already declared in the current scope.");
}

log_trace(LogMLIRCompiler, "Declaring {} in the current scope.", node->name());

symbolTable_[node->name()] = {value, node};
}

@@ -173,10 +175,21 @@ class MLIRGenerator
// Add the graph inputs to the argument list
for (auto *input: graph->ordered_module_inputs()) //for (auto *input : graph->nodes_by_type(tt::graphlib::kInput))
{
log_trace(LogMLIRCompiler, "Adding input {} to the argument list.", input->name());

argument_nodes.push_back(input);
argument_types.push_back(get_node_type(input));
}

// Add the graph constants to the argument list
for (auto *constant : graph->get_constant_nodes())
{
log_trace(LogMLIRCompiler, "Adding constant {} to the argument list.", constant->name());

argument_nodes.push_back(constant);
argument_types.push_back(get_node_type(constant));
}

// Add the graph parameters to the argument list
for(auto *parameter: graph->get_parameter_nodes())
{
@@ -185,8 +198,10 @@ class MLIRGenerator
// for forward and backward subgraphs (via GraphTraversalContext).
if (graph->data_users(parameter).empty())
{
log_trace(LogMLIRCompiler, "Skipping parameter {} as it is not used in the current graph context.", parameter->name());
continue;
}
log_trace(LogMLIRCompiler, "Adding parameter {} to the argument list.", parameter->name());

argument_nodes.push_back(parameter);
argument_types.push_back(get_node_type(parameter));
@@ -201,6 +216,7 @@

for (auto *output : output_nodes)
{
log_trace(LogMLIRCompiler, "Adding output {} to the return list.", output->name());
returns.push_back(get_node_type(output));
}

@@ -215,6 +231,7 @@
llvm::SmallVector<mlir::NamedAttribute, 1> named_attributes;
named_attributes.push_back(builder_.getNamedAttr("ttir.name", builder_.getStringAttr(argument_node->name())));
func.setArgAttrs(i, named_attributes);
log_trace(LogMLIRCompiler, "Set argument name {} for function argument {}.", argument_node->name(), i);
}

// Start the body of the function by creating an entry block.
@@ -241,9 +258,9 @@
// Skip if the node isn't TTForge operation
if (node->node_type() != tt::graphlib::NodeType::kPyOp)
{
log_trace(LogMLIRCompiler, "Skipping node {} as it is not a TTForge operation.", node->name());
continue;
}

log_trace(LogMLIRCompiler, "Emitting MLIR for node {}", node->name());

tt::graphlib::OpNode *op_node = node->as<tt::graphlib::OpNode>();
@@ -353,9 +370,18 @@ class MLIRGenerator
{
llvm::SmallVector<mlir::Value> operands;

#ifdef DEBUG
// Log all values from symbolTable_
log_trace(LogMLIRCompiler, "Logging all keys from symbolTable_");
for (const auto& entry : symbolTable_)
{
log_trace(LogMLIRCompiler, "Key: {}", entry.first);
}
#endif

for (auto operand : graph->data_operands(op_node))
{
TT_ASSERT(symbolTable_.find(operand->name()) != symbolTable_.end(), "Operand " + operand->name() + "not found in symbol table.");
TT_ASSERT(symbolTable_.find(operand->name()) != symbolTable_.end(), "Operand " + operand->name() + " not found in symbol table.");
operands.push_back(symbolTable_.at(operand->name()).first);
}

@@ -504,11 +530,13 @@ class MLIRGenerator
lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MeanOp>;
lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SumOp>;
lowering_handler_map["relu"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReluOp>;
lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReshapeOp>;
lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SoftmaxOp>;
lowering_handler_map["sqrt"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SqrtOp>;
lowering_handler_map["squeeze"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SqueezeOp>;
lowering_handler_map["subtract"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SubtractOp>;
lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TransposeOp>;
lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReshapeOp>;
lowering_handler_map["unsqueeze"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::UnsqueezeOp>;
}
};
}
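
Restating the argument-collection logic above as illustrative pseudo-Python (this is not the actual C++ API, just a summary of the new ordering in MLIRGenerator):

def collect_argument_nodes(graph):
    # Pseudo-code summary of the new ordering: graph inputs first, then constant
    # nodes (newly added in this commit), then parameters that are still used in
    # the current graph context.
    nodes = []
    nodes += graph.ordered_module_inputs()
    nodes += graph.get_constant_nodes()
    nodes += [p for p in graph.get_parameter_nodes() if graph.data_users(p)]
    return nodes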
5 changes: 4 additions & 1 deletion forge/forge/compiled_graph_state.py
@@ -252,6 +252,9 @@ def get_tensor(self, name_to_tensor, name):

def get_constant_tensor(self, name):
return self.get_tensor(self.post_const_eval_constants, name)

def get_ordered_constant_tensors(self):
return [self.get_constant_tensor(name) for name in self.ordered_constant_node_names]

def get_parameter_tensor(self, name):
return self.get_tensor(self.post_const_eval_parameters, name)
@@ -318,7 +321,7 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
self.inputs = [*inputs]

logger.info(f"Running model {self.compiled_graph_state.graph.get_name()} on device...")
inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_parameter_tensors()]
inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_constant_tensors(), *self.compiled_graph_state.get_ordered_parameter_tensors()]
outputs = run_binary(self.compiled_binary, int(ProgramId.FORWARD), inputs_and_parameters)

if self.compiled_graph_state.graph.training():
4 changes: 2 additions & 2 deletions forge/forge/op/eval/forge/eltwise_binary.py
@@ -421,9 +421,9 @@ def decompose_post_autograd(op_type, attr, dc, inputs):

max_operand_nd = max(len(op0_shape), len(op1_shape), 3)
while len(operand0.shape) < max_operand_nd:
operand0 = dc.op("unsqueeze", [operand0], (0, len(operand0.shape)))
operand0 = dc.op_with_named_attrs("unsqueeze", [operand0], {"dim": 0}, (0, len(operand0.shape)))
while len(operand1.shape) < max_operand_nd:
operand1 = dc.op("unsqueeze", [operand1], (0, len(operand1.shape)))
operand1 = dc.op_with_named_attrs("unsqueeze", [operand1], {"dim": 0}, (0, len(operand1.shape)))

if (slice_factor != None):
concat_z = dc.op("interleave", [operand0, operand1], (-3, 1))
22 changes: 11 additions & 11 deletions forge/forge/op/eval/forge/tm.py
@@ -1047,7 +1047,7 @@ def unsqueeze_input_for_reshape_decomp(dc, inp):
current_shape = inp.shape.as_list()
while len(current_shape) < 4:
current_shape.insert(0, 1)
inp = dc.op("unsqueeze", (inp,), (0, len(inp.shape.as_list())))
inp = dc.op_with_named_attrs("unsqueeze", (inp,), {"dim": 0}, (0, len(inp.shape.as_list())))

return inp

@@ -1057,7 +1057,7 @@ def squeeze_output_for_reshape_decomp(dc, output, orig_out_shape):

while current_shape_len > len(orig_out_shape):
current_shape_len -= 1
output = dc.op("squeeze", (output,), (0,))
output = dc.op_with_named_attrs("squeeze", [output], {"dim": 0}, (0,))

return output

@@ -1315,7 +1315,6 @@ def decompose(type, attr, dc, inputs):

result = dc.op(TransposeTM.create(-3, -1, result.shape[-3]), [result])
else:
# import pdb; pdb.set_trace()
orig_shape = result.shape
if len(orig_shape) == 2:
result = dc.op("reshape", [result], (1, orig_shape[-2]*orig_shape[-1]))
@@ -1403,10 +1402,10 @@ def decompose(type, attr, dc, inputs):
if is_rank_only_reshape and rank != 0:
result = inputs[0]
while rank < 0:
result = dc.op("squeeze", [result], (0,))
result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,))
rank += 1
while rank > 0:
result = dc.op("unsqueeze", [result], (0, len(result.shape.as_list())))
result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, len(result.shape.as_list())))
rank -= 1
dc.fuse(result)
return
@@ -1542,10 +1541,10 @@ def decompose_xy_flatten_reshape(inputs, dc, orig_shape, attr):
result = dc.op(TransposeTM.create(-2, -1), [result])

while len(result.shape) > len(attr):
result = dc.op("squeeze", [result], (0,))
result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,))

while len(result.shape) < len(attr):
result = dc.op("unsqueeze", [result], (0, len(result.shape.as_list())))
result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, len(result.shape.as_list())))

if orig_shape[-3] > 1:
s = create_flattened_padding_removal_sparse_picker_matrix(result.shape[-2], 0, 1, TILE_DIM)
@@ -1652,7 +1651,7 @@ def decompose_xy_unflatten(inputs, dc, orig_shape, attr):
if orig_shape[-2] > 1:
result = dc.op("vslice", [result], (orig_shape[-2], ))
elif len(result.shape) == 2:
result = dc.op("unsqueeze", [result], (0, 2,))
result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, 2,))
_orig_shape = result.shape
slice_factor = attr[-2] if attr[-1] < TILE_DIM else (math.ceil(attr[-2] / TILE_DIM) * TILE_DIM)
result = dc.op(TransposeTM.create(-2, -1), [result])
@@ -1776,7 +1775,7 @@ def decompose_post_optimize(type, attr, dc, inputs):
)

while len(result.shape) < 3:
result = dc.op("unsqueeze", [result,], (0, len(result.shape.as_list())))
result = dc.op_with_named_attrs("unsqueeze", [result,], {"dim": 0}, (0, len(result.shape.as_list())))

spm = torch.stack([spm]*result.shape[-3], -3).unsqueeze(0)
result = dc.op(TransposeTM.create(-2, -1), [result,])
@@ -1917,10 +1916,11 @@ def decompose_post_autograd(type, attr, dc, inputs):
if is_rank_only_reshape and rank != 0:
result = inputs[0]
while rank < 0:
result = dc.op("squeeze", [result], (0,))
result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,))
rank += 1
while rank > 0:
result = dc.op("unsqueeze", [result], (0, len(result.shape.as_list())))
result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, len(result.shape.as_list())))
rank -= 1
dc.fuse(result)
return
92 changes: 92 additions & 0 deletions forge/test/mlir/test_features.py
@@ -0,0 +1,92 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import os
import pytest

import torch
from torch import nn

import forge
from forge.op.eval.common import compare_with_golden_pcc

def test_multiple_inputs():
class MultipleInputs(nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c):
return a + b + c

inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32), torch.rand(1, 32, 32)]

framework_model = MultipleInputs()
fw_out = framework_model(*inputs)

compiled_model = forge.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
assert all(compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out))


@pytest.mark.parametrize("a_shape, b_shape, c_shape", [
((1, 1, 32, 64), (1, 1, 64, 128), (1, 1, 128, 32)),
((1, 1, 64, 32), (1, 1, 32, 128), (1, 1, 128, 64)),
((1, 1, 128, 64), (1, 1, 64, 256), (1, 1, 256, 128)),
((1, 1, 256, 128), (1, 1, 128, 512), (1, 1, 512, 256))
])
def test_input_order(a_shape, b_shape, c_shape):
class InputOrder(nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c):
x = torch.matmul(a, b)
x = torch.matmul(x, c)

return x

a = torch.rand(*a_shape)
b = torch.rand(*b_shape)
c = torch.rand(*c_shape)

framework_model = InputOrder()
fw_out = framework_model(a, b, c)

compiled_model = forge.compile(framework_model, sample_inputs=[a, b, c])
co_out = compiled_model(a, b, c)

assert compare_with_golden_pcc(golden=fw_out, calculated=co_out, pcc=0.99)


@pytest.mark.parametrize("a_shape, b_shape, c_shape", [
((1, 1, 32, 64), (1, 1, 64, 128), (1, 1, 128, 32)),
])
def test_input_order_with_constants(a_shape, b_shape, c_shape):
class InputOrderWithConstants(nn.Module):
def __init__(self):
super().__init__()
self.const1 = torch.rand(1, 1, 32, 32)
self.const2 = torch.rand(1, 1, 32, 32)

def forward(self, a, b, c):
x = torch.matmul(a, b)
x = torch.matmul(x, c)
x = x + self.const1
x = x * self.const2
return x

a = torch.rand(*a_shape)
b = torch.rand(*b_shape)
c = torch.rand(*c_shape)

framework_model = InputOrderWithConstants()
fw_out = framework_model(a, b, c)

compiled_model = forge.compile(framework_model, sample_inputs=[a, b, c])
co_out = compiled_model(a, b, c)

assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0][0], pcc=0.99)
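
As an additional sanity check (a sketch only, assuming the compiled model exposes compiled_graph_state the same way CompiledModel.__call__ does above), the constants lifted by the compiler can be listed by name:

# Sketch: print the names of the constant nodes the compiler lifted from the module.
# Assumes `compiled_model.compiled_graph_state` is accessible, as used in __call__ above.
for name in compiled_model.compiled_graph_state.ordered_constant_node_names:
    print(name)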
3 changes: 3 additions & 0 deletions pytest.ini
@@ -9,6 +9,9 @@ testpaths =
# Ops
forge/test/mlir/test_ops.py

# Features
forge/test/mlir/test_features.py

# API
forge/test/test_api.py

2 changes: 1 addition & 1 deletion third_party/tt-mlir
Submodule tt-mlir updated 109 files
