From a3c044149caf1c87e0155097bf2ee54b9c3fd337 Mon Sep 17 00:00:00 2001 From: Nikola Vukobrat <124874832+nvukobratTT@users.noreply.github.com> Date: Wed, 4 Sep 2024 16:15:01 +0300 Subject: [PATCH] Fix issue when certain inputs/constants aren't properly declared during MLIR emit (#203) Previously, MLIR emit was hiting edge cases when declaring constant inputs. More precisely, they were mostly skipped. This fix redefines how inputs are recognized (using kInput node type), and properly distinguish regular and constant inputs vs model parameters. Issue uncovered during #112 op bringup (reciprocal). At the same time, PR related to #112 is testing this case. Additionally, inference and training MNIST are also covering this feature for functionality. Additionally, this change includes: - Shape recalculation before lowering to MLIR; just to be certain that all shapes are correctly matched - Additional logs through MLIR emit logic - Uplifted MLIR version to the latest Fixes #201 --- forge/csrc/forge_passes.cpp | 3 + forge/csrc/passes/lower_to_mlir.cpp | 34 ++++++++++- forge/forge/compiled_graph_state.py | 5 +- forge/forge/op/eval/forge/eltwise_binary.py | 4 +- forge/forge/op/eval/forge/tm.py | 22 ++++---- forge/test/mlir/test_features.py | 62 +++++++++++++++++++++ pytest.ini | 3 + third_party/tt-mlir | 2 +- 8 files changed, 117 insertions(+), 18 deletions(-) create mode 100644 forge/test/mlir/test_features.py diff --git a/forge/csrc/forge_passes.cpp b/forge/csrc/forge_passes.cpp index 99d6bc073..188db7c0f 100644 --- a/forge/csrc/forge_passes.cpp +++ b/forge/csrc/forge_passes.cpp @@ -230,6 +230,9 @@ graphlib::Graph* run_pre_lowering_passes( // Apply user overrides passes::configure_output_data_formats(graph, default_df_override); + // Recalculate shapes before lowering to MLIR + recalculate_shapes(graph); + return graph; } diff --git a/forge/csrc/passes/lower_to_mlir.cpp b/forge/csrc/passes/lower_to_mlir.cpp index 6d39f8537..2048ef25c 100644 --- a/forge/csrc/passes/lower_to_mlir.cpp +++ b/forge/csrc/passes/lower_to_mlir.cpp @@ -132,6 +132,8 @@ class MLIRGenerator throw std::runtime_error("Variable " + node->name() + " already declared in the current scope."); } + log_trace(LogMLIRCompiler, "Declaring {} in the current scope.", node->name()); + symbolTable_[node->name()] = {value, node}; } @@ -173,10 +175,21 @@ class MLIRGenerator // Add the graph inputs to the argument list for (auto *input: graph->ordered_module_inputs()) //for (auto *input : graph->nodes_by_type(tt::graphlib::kInput)) { + log_trace(LogMLIRCompiler, "Adding input {} to the argument list.", input->name()); + argument_nodes.push_back(input); argument_types.push_back(get_node_type(input)); } + // Add the graph constants to the argument list + for (auto *constant : graph->get_constant_nodes()) + { + log_trace(LogMLIRCompiler, "Adding constant {} to the argument list.", constant->name()); + + argument_nodes.push_back(constant); + argument_types.push_back(get_node_type(constant)); + } + // Add the graph parameters to the argument list for(auto *parameter: graph->get_parameter_nodes()) { @@ -185,8 +198,10 @@ class MLIRGenerator // for forward and backward subgraphs (via GraphTraversalContext). if (graph->data_users(parameter).empty()) { + log_trace(LogMLIRCompiler, "Skipping parameter {} as it is not used in the current graph context.", parameter->name()); continue; } + log_trace(LogMLIRCompiler, "Adding parameter {} to the argument list.", parameter->name()); argument_nodes.push_back(parameter); argument_types.push_back(get_node_type(parameter)); @@ -201,6 +216,7 @@ class MLIRGenerator for (auto *output : output_nodes) { + log_trace(LogMLIRCompiler, "Adding output {} to the return list.", output->name()); returns.push_back(get_node_type(output)); } @@ -215,6 +231,7 @@ class MLIRGenerator llvm::SmallVector named_attributes; named_attributes.push_back(builder_.getNamedAttr("ttir.name", builder_.getStringAttr(argument_node->name()))); func.setArgAttrs(i, named_attributes); + log_trace(LogMLIRCompiler, "Set argument name {} for function argument {}.", argument_node->name(), i); } // Start the body of the function by creating an entry block. @@ -241,9 +258,9 @@ class MLIRGenerator // Skip if the node isn't TTForge operation if (node->node_type() != tt::graphlib::NodeType::kPyOp) { + log_trace(LogMLIRCompiler, "Skipping node {} as it is not a TTForge operation.", node->name()); continue; } - log_trace(LogMLIRCompiler, "Emitting MLIR for node {}", node->name()); tt::graphlib::OpNode *op_node = node->as(); @@ -353,9 +370,18 @@ class MLIRGenerator { llvm::SmallVector operands; +#ifdef DEBUG + // Log all values from symbolTable_ + log_trace(LogMLIRCompiler, "Logging all keys from symbolTable_"); + for (const auto& entry : symbolTable_) + { + log_trace(LogMLIRCompiler, "Key: {}", entry.first); + } +#endif + for (auto operand : graph->data_operands(op_node)) { - TT_ASSERT(symbolTable_.find(operand->name()) != symbolTable_.end(), "Operand " + operand->name() + "not found in symbol table."); + TT_ASSERT(symbolTable_.find(operand->name()) != symbolTable_.end(), "Operand " + operand->name() + " not found in symbol table."); operands.push_back(symbolTable_.at(operand->name()).first); } @@ -504,11 +530,13 @@ class MLIRGenerator lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op; lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op; lowering_handler_map["relu"] = &MLIRGenerator::emit_mlir_ttforge_op; + lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op; lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op; lowering_handler_map["sqrt"] = &MLIRGenerator::emit_mlir_ttforge_op; + lowering_handler_map["squeeze"] = &MLIRGenerator::emit_mlir_ttforge_op; lowering_handler_map["subtract"] = &MLIRGenerator::emit_mlir_ttforge_op; lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op; - lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op; + lowering_handler_map["unsqueeze"] = &MLIRGenerator::emit_mlir_ttforge_op; } }; } diff --git a/forge/forge/compiled_graph_state.py b/forge/forge/compiled_graph_state.py index f457a9816..25a57d321 100644 --- a/forge/forge/compiled_graph_state.py +++ b/forge/forge/compiled_graph_state.py @@ -252,6 +252,9 @@ def get_tensor(self, name_to_tensor, name): def get_constant_tensor(self, name): return self.get_tensor(self.post_const_eval_constants, name) + + def get_ordered_constant_tensors(self): + return [self.get_constant_tensor(name) for name in self.ordered_constant_node_names] def get_parameter_tensor(self, name): return self.get_tensor(self.post_const_eval_parameters, name) @@ -318,7 +321,7 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]: self.inputs = [*inputs] logger.info(f"Running model {self.compiled_graph_state.graph.get_name()} on device...") - inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_parameter_tensors()] + inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_constant_tensors(), *self.compiled_graph_state.get_ordered_parameter_tensors()] outputs = run_binary(self.compiled_binary, int(ProgramId.FORWARD), inputs_and_parameters) if self.compiled_graph_state.graph.training(): diff --git a/forge/forge/op/eval/forge/eltwise_binary.py b/forge/forge/op/eval/forge/eltwise_binary.py index 6c464b2b3..8eb122457 100644 --- a/forge/forge/op/eval/forge/eltwise_binary.py +++ b/forge/forge/op/eval/forge/eltwise_binary.py @@ -421,9 +421,9 @@ def decompose_post_autograd(op_type, attr, dc, inputs): max_operand_nd = max(len(op0_shape), len(op1_shape), 3) while len(operand0.shape) < max_operand_nd: - operand0 = dc.op("unsqueeze", [operand0], (0, len(operand0.shape))) + operand0 = dc.op_with_named_attrs("unsqueeze", [operand0], {"dim": 0} (0, len(operand0.shape))) while len(operand1.shape) < max_operand_nd: - operand1 = dc.op("unsqueeze", [operand1], (0, len(operand1.shape))) + operand1 = dc.op_with_named_attrs("unsqueeze", [operand1], {"dim": 0} (0, len(operand1.shape))) if (slice_factor != None): concat_z = dc.op("interleave", [operand0, operand1], (-3, 1)) diff --git a/forge/forge/op/eval/forge/tm.py b/forge/forge/op/eval/forge/tm.py index 09d445d93..852ead129 100644 --- a/forge/forge/op/eval/forge/tm.py +++ b/forge/forge/op/eval/forge/tm.py @@ -1047,7 +1047,7 @@ def unsqueeze_input_for_reshape_decomp(dc, inp): current_shape = inp.shape.as_list() while len(current_shape) < 4: current_shape.insert(0, 1) - inp = dc.op("unsqueeze", (inp,), (0, len(inp.shape.as_list()))) + inp = dc.op_with_named_attrs("unsqueeze", (inp,), {"dim": 0}, (0, len(inp.shape.as_list()))) return inp @@ -1057,7 +1057,7 @@ def squeeze_output_for_reshape_decomp(dc, output, orig_out_shape): while current_shape_len > len(orig_out_shape): current_shape_len -= 1 - output = dc.op("squeeze", (output,), (0,)) + result = dc.op_with_named_attrs("squeeze", [output], {"dim": 0}, (0,)) return output @@ -1315,7 +1315,6 @@ def decompose(type, attr, dc, inputs): result = dc.op(TransposeTM.create(-3, -1, result.shape[-3]), [result]) else: - # import pdb; pdb.set_trace() orig_shape = result.shape if len(orig_shape) == 2: result = dc.op("reshape", [result], (1, orig_shape[-2]*orig_shape[-1])) @@ -1403,10 +1402,10 @@ def decompose(type, attr, dc, inputs): if is_rank_only_reshape and rank != 0: result = inputs[0] while rank < 0: - result = dc.op("squeeze", [result], (0,)) + result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,)) rank += 1 while rank > 0: - result = dc.op("unsqueeze", [result], (0, len(result.shape.as_list()))) + result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, len(result.shape.as_list()))) rank -= 1 dc.fuse(result) return @@ -1542,10 +1541,10 @@ def decompose_xy_flatten_reshape(inputs, dc, orig_shape, attr): result = dc.op(TransposeTM.create(-2, -1), [result]) while len(result.shape) > len(attr): - result = dc.op("squeeze", [result], (0,)) + result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,)) while len(result.shape) < len(attr): - result = dc.op("unsqueeze", [result], (0, len(result.shape.as_list()))) + result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, len(result.shape.as_list()))) if orig_shape[-3] > 1: s = create_flattened_padding_removal_sparse_picker_matrix(result.shape[-2], 0, 1, TILE_DIM) @@ -1652,7 +1651,7 @@ def decompose_xy_unflatten(inputs, dc, orig_shape, attr): if orig_shape[-2] > 1: result = dc.op("vslice", [result], (orig_shape[-2], )) elif len(result.shape) == 2: - result = dc.op("unsqueeze", [result], (0, 2,)) + result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, 2,)) _orig_shape = result.shape slice_factor = attr[-2] if attr[-1] < TILE_DIM else (math.ceil(attr[-2] / TILE_DIM) * TILE_DIM) result = dc.op(TransposeTM.create(-2, -1), [result]) @@ -1776,7 +1775,7 @@ def decompose_post_optimize(type, attr, dc, inputs): ) while len(result.shape) < 3: - result = dc.op("unsqueeze", [result,], (0, len(result.shape.as_list()))) + result = dc.op_with_named_attrs("unsqueeze", [result,], {"dim": 0}, (0, len(result.shape.as_list()))) spm = torch.stack([spm]*result.shape[-3], -3).unsqueeze(0) result = dc.op(TransposeTM.create(-2, -1), [result,]) @@ -1917,10 +1916,11 @@ def decompose_post_autograd(type, attr, dc, inputs): if is_rank_only_reshape and rank != 0: result = inputs[0] while rank < 0: - result = dc.op("squeeze", [result], (0,)) + result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,)) rank += 1 while rank > 0: - result = dc.op("unsqueeze", [result], (0, len(result.shape.as_list()))) + import pdb; pdb.set_trace + result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, len(result.shape.as_list()))) rank -= 1 dc.fuse(result) return diff --git a/forge/test/mlir/test_features.py b/forge/test/mlir/test_features.py new file mode 100644 index 000000000..b27d7aa22 --- /dev/null +++ b/forge/test/mlir/test_features.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import os +import pytest + +import pytest +import torch +from torch import nn + +import forge +from forge.op.eval.common import compare_with_golden_pcc + +def test_multiple_inputs(): + class MultipleInputs(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c): + return a + b + c + + inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32), torch.rand(1, 32, 32)] + + framework_model = MultipleInputs() + fw_out = framework_model(*inputs) + + compiled_model = forge.compile(framework_model, sample_inputs=inputs) + co_out = compiled_model(*inputs) + + co_out = [co.to("cpu") for co in co_out] + assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] + + +@pytest.mark.parametrize("a_shape, b_shape, c_shape", [ + ((1, 1, 32, 64), (1, 1, 64, 128), (1, 1, 128, 32)), +]) +def test_input_order(a_shape, b_shape, c_shape): + class InputOrderWithConstants(nn.Module): + def __init__(self): + super().__init__() + self.const1 = torch.rand(1, 1, 32, 32) + self.const2 = torch.rand(1, 1, 32, 32) + + def forward(self, a, b, c): + x = torch.matmul(a, b) + x = torch.matmul(x, c) + x = x + self.const1 + x = x * self.const2 + return x + + a = torch.rand(*a_shape) + b = torch.rand(*b_shape) + c = torch.rand(*c_shape) + + framework_model = InputOrderWithConstants() + fw_out = framework_model(a, b, c) + + compiled_model = forge.compile(framework_model, sample_inputs=[a, b, c]) + co_out = compiled_model(a, b, c) + + assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0][0], pcc=0.99) diff --git a/pytest.ini b/pytest.ini index 694708482..70e928e91 100644 --- a/pytest.ini +++ b/pytest.ini @@ -9,6 +9,9 @@ testpaths = # Ops forge/test/mlir/test_ops.py + # Features + forge/test/mlir/test_features.py + # API forge/test/test_api.py diff --git a/third_party/tt-mlir b/third_party/tt-mlir index 1f92bc6f7..a75fcf32a 160000 --- a/third_party/tt-mlir +++ b/third_party/tt-mlir @@ -1 +1 @@ -Subproject commit 1f92bc6f7de8a85dd7d2b4a9ff659ba645fe7956 +Subproject commit a75fcf32aa142d3a3dbf86d307027a0706e9c32b