diff --git a/forge/csrc/buda_passes.cpp b/forge/csrc/buda_passes.cpp
index 2ef56d169..205ccb362 100644
--- a/forge/csrc/buda_passes.cpp
+++ b/forge/csrc/buda_passes.cpp
@@ -230,6 +230,9 @@ graphlib::Graph* run_pre_lowering_passes(
     // Apply user overrides
     passes::configure_output_data_formats(graph, default_df_override);
 
+    // Recalculate shapes before lowering to MLIR
+    recalculate_shapes(graph);
+
     return graph;
 }
diff --git a/forge/csrc/passes/lower_to_mlir.cpp b/forge/csrc/passes/lower_to_mlir.cpp
index cca63f864..9ce670eb4 100644
--- a/forge/csrc/passes/lower_to_mlir.cpp
+++ b/forge/csrc/passes/lower_to_mlir.cpp
@@ -132,6 +132,8 @@ class MLIRGenerator
        {
            throw std::runtime_error("Variable " + node->name() + " already declared in the current scope.");
        }
 
+       log_trace(LogMLIRCompiler, "Declaring {} in the current scope.", node->name());
+
        symbolTable_[node->name()] = {value, node};
    }
@@ -173,10 +175,24 @@ class MLIRGenerator
        // Add the graph inputs to the argument list
        for (auto *input: graph->ordered_module_inputs()) //for (auto *input : graph->nodes_by_type(tt::graphlib::kInput))
        {
+           log_trace(LogMLIRCompiler, "Adding input {} to the argument list.", input->name());
+
            argument_nodes.push_back(input);
            argument_types.push_back(get_node_type(input));
        }
 
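+       // Note: arguments are assembled in the order inputs, then constants, then
+       // parameters. The runtime packs tensors in the same order when invoking
+       // the compiled binary (see compiled_graph_state.py below).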
+       // Add the graph constants to the argument list
+       for (auto *constant : graph->get_constant_nodes())
+       {
+           log_trace(LogMLIRCompiler, "Adding constant {} to the argument list.", constant->name());
+
+           argument_nodes.push_back(constant);
+           argument_types.push_back(get_node_type(constant));
+       }
+
        // Add the graph parameters to the argument list
        for(auto *parameter: graph->get_parameter_nodes())
        {
@@ -185,8 +198,10 @@ class MLIRGenerator
            // for forward and backward subgraphs (via GraphTraversalContext).
            if (graph->data_users(parameter).empty())
            {
+               log_trace(LogMLIRCompiler, "Skipping parameter {} as it is not used in the current graph context.", parameter->name());
                continue;
            }
+           log_trace(LogMLIRCompiler, "Adding parameter {} to the argument list.", parameter->name());
 
            argument_nodes.push_back(parameter);
            argument_types.push_back(get_node_type(parameter));
@@ -201,6 +216,7 @@ class MLIRGenerator
 
        for (auto *output : output_nodes)
        {
+           log_trace(LogMLIRCompiler, "Adding output {} to the return list.", output->name());
            returns.push_back(get_node_type(output));
        }
@@ -215,6 +231,7 @@ class MLIRGenerator
            llvm::SmallVector<mlir::NamedAttribute> named_attributes;
            named_attributes.push_back(builder_.getNamedAttr("ttir.name", builder_.getStringAttr(argument_node->name())));
            func.setArgAttrs(i, named_attributes);
+           log_trace(LogMLIRCompiler, "Set argument name {} for function argument {}.", argument_node->name(), i);
        }
 
        // Start the body of the function by creating an entry block.
@@ -241,9 +258,9 @@ class MLIRGenerator
            // Skip if the node isn't TTForge operation
            if (node->node_type() != tt::graphlib::NodeType::kPyOp)
            {
+               log_trace(LogMLIRCompiler, "Skipping node {} as it is not a TTForge operation.", node->name());
                continue;
            }
-
            log_trace(LogMLIRCompiler, "Emitting MLIR for node {}", node->name());
 
            tt::graphlib::OpNode *op_node = node->as<tt::graphlib::OpNode>();
@@ -353,9 +370,18 @@ class MLIRGenerator
    {
        llvm::SmallVector<mlir::Value> operands;
 
+#ifdef DEBUG
+       // Log all keys from symbolTable_
+       log_trace(LogMLIRCompiler, "Logging all keys from symbolTable_");
+       for (const auto& entry : symbolTable_)
+       {
+           log_trace(LogMLIRCompiler, "Key: {}", entry.first);
+       }
+#endif
+
        for (auto operand : graph->data_operands(op_node))
        {
-           TT_ASSERT(symbolTable_.find(operand->name()) != symbolTable_.end(), "Operand " + operand->name() + "not found in symbol table.");
+           TT_ASSERT(symbolTable_.find(operand->name()) != symbolTable_.end(), "Operand " + operand->name() + " not found in symbol table.");
            operands.push_back(symbolTable_.at(operand->name()).first);
        }
@@ -504,11 +530,13 @@ class MLIRGenerator
        lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op;
        lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op;
        lowering_handler_map["relu"] = &MLIRGenerator::emit_mlir_ttforge_op;
+       lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op;
        lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op;
        lowering_handler_map["sqrt"] = &MLIRGenerator::emit_mlir_ttforge_op;
+       lowering_handler_map["squeeze"] = &MLIRGenerator::emit_mlir_ttforge_op;
        lowering_handler_map["subtract"] = &MLIRGenerator::emit_mlir_ttforge_op;
        lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op;
-       lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op;
+       lowering_handler_map["unsqueeze"] = &MLIRGenerator::emit_mlir_ttforge_op;
    }
 };
 }
diff --git a/forge/forge/compiled_graph_state.py b/forge/forge/compiled_graph_state.py
index 97bb82dc1..d7ab5c4e2 100644
--- a/forge/forge/compiled_graph_state.py
+++ b/forge/forge/compiled_graph_state.py
@@ -252,6 +252,9 @@ def get_tensor(self, name_to_tensor, name):
 
     def get_constant_tensor(self, name):
         return self.get_tensor(self.post_const_eval_constants, name)
+
+    def get_ordered_constant_tensors(self):
+        return [self.get_constant_tensor(name) for name in self.ordered_constant_node_names]
 
     def get_parameter_tensor(self, name):
         return self.get_tensor(self.post_const_eval_parameters, name)
@@ -318,7 +321,9 @@ def __call__(self, *inputs: torch.Tensor) -> List[torch.Tensor]:
         self.inputs = [*inputs]
 
         logger.info(f"Running model {self.compiled_graph_state.graph.get_name()} on device...")
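+        # The argument order below must mirror the MLIR function signature built
+        # during lowering: inputs first, then constants, then parameters.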
-        inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_parameter_tensors()]
+        inputs_and_parameters = [*inputs, *self.compiled_graph_state.get_ordered_constant_tensors(), *self.compiled_graph_state.get_ordered_parameter_tensors()]
         outputs = run_binary(self.compiled_binary, int(ProgramId.FORWARD), inputs_and_parameters)
 
         if self.compiled_graph_state.graph.training():
diff --git a/forge/forge/op/eval/forge/eltwise_binary.py b/forge/forge/op/eval/forge/eltwise_binary.py
index bdbec7f43..3e5097f26 100644
--- a/forge/forge/op/eval/forge/eltwise_binary.py
+++ b/forge/forge/op/eval/forge/eltwise_binary.py
@@ -421,9 +421,9 @@ def decompose_post_autograd(op_type, attr, dc, inputs):
 
         max_operand_nd = max(len(op0_shape), len(op1_shape), 3)
         while len(operand0.shape) < max_operand_nd:
-            operand0 = dc.op("unsqueeze", [operand0], (0, len(operand0.shape)))
+            operand0 = dc.op_with_named_attrs("unsqueeze", [operand0], {"dim": 0}, (0, len(operand0.shape)))
         while len(operand1.shape) < max_operand_nd:
-            operand1 = dc.op("unsqueeze", [operand1], (0, len(operand1.shape)))
+            operand1 = dc.op_with_named_attrs("unsqueeze", [operand1], {"dim": 0}, (0, len(operand1.shape)))
 
         if (slice_factor != None):
             concat_z = dc.op("interleave", [operand0, operand1], (-3, 1))
diff --git a/forge/forge/op/eval/forge/tm.py b/forge/forge/op/eval/forge/tm.py
index e7c513cab..646d64dc5 100644
--- a/forge/forge/op/eval/forge/tm.py
+++ b/forge/forge/op/eval/forge/tm.py
@@ -1047,7 +1047,7 @@ def unsqueeze_input_for_reshape_decomp(dc, inp):
     current_shape = inp.shape.as_list()
     while len(current_shape) < 4:
         current_shape.insert(0, 1)
-        inp = dc.op("unsqueeze", (inp,), (0, len(inp.shape.as_list())))
+        inp = dc.op_with_named_attrs("unsqueeze", (inp,), {"dim": 0}, (0, len(inp.shape.as_list())))
 
     return inp
@@ -1057,7 +1057,7 @@ def squeeze_output_for_reshape_decomp(dc, output, orig_out_shape):
     while current_shape_len > len(orig_out_shape):
         current_shape_len -= 1
-        output = dc.op("squeeze", (output,), (0,))
+        output = dc.op_with_named_attrs("squeeze", [output], {"dim": 0}, (0,))
 
     return output
@@ -1315,7 +1315,6 @@ def decompose(type, attr, dc, inputs):
             result = dc.op(TransposeTM.create(-3, -1, result.shape[-3]), [result])
 
         else:
-            # import pdb; pdb.set_trace()
             orig_shape = result.shape
             if len(orig_shape) == 2:
                 result = dc.op("reshape", [result], (1, orig_shape[-2]*orig_shape[-1]))
@@ -1403,10 +1402,10 @@ def decompose(type, attr, dc, inputs):
     if is_rank_only_reshape and rank != 0:
         result = inputs[0]
         while rank < 0:
-            result = dc.op("squeeze", [result], (0,))
+            result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,))
             rank += 1
         while rank > 0:
-            result = dc.op("unsqueeze", [result], (0, len(result.shape.as_list())))
+            result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, len(result.shape.as_list())))
             rank -= 1
         dc.fuse(result)
         return
@@ -1542,10 +1541,10 @@ def decompose_xy_flatten_reshape(inputs, dc, orig_shape, attr):
     result = dc.op(TransposeTM.create(-2, -1), [result])
 
     while len(result.shape) > len(attr):
-        result = dc.op("squeeze", [result], (0,))
+        result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,))
 
     while len(result.shape) < len(attr):
-        result = dc.op("unsqueeze", [result], (0, len(result.shape.as_list())))
+        result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, len(result.shape.as_list())))
 
     if orig_shape[-3] > 1:
         s = create_flattened_padding_removal_sparse_picker_matrix(result.shape[-2], 0, 1, TILE_DIM)
@@ -1652,7 +1651,7 @@ def decompose_xy_unflatten(inputs, dc, orig_shape, attr):
     if orig_shape[-2] > 1:
         result = dc.op("vslice", [result], (orig_shape[-2], ))
     elif len(result.shape) == 2:
-        result = dc.op("unsqueeze", [result], (0, 2,))
+        result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, 2,))
     _orig_shape = result.shape
     slice_factor = attr[-2] if attr[-1] < TILE_DIM else (math.ceil(attr[-2] / TILE_DIM) * TILE_DIM)
     result = dc.op(TransposeTM.create(-2, -1), [result])
@@ -1776,7 +1775,7 @@ def decompose_post_optimize(type, attr, dc, inputs):
         )
 
         while len(result.shape) < 3:
-            result = dc.op("unsqueeze", [result,], (0, len(result.shape.as_list())))
+            result = dc.op_with_named_attrs("unsqueeze", [result,], {"dim": 0}, (0, len(result.shape.as_list())))
 
         spm = torch.stack([spm]*result.shape[-3], -3).unsqueeze(0)
         result = dc.op(TransposeTM.create(-2, -1), [result,])
@@ -1917,10 +1916,10 @@ def decompose_post_autograd(type, attr, dc, inputs):
     if is_rank_only_reshape and rank != 0:
         result = inputs[0]
         while rank < 0:
-            result = dc.op("squeeze", [result], (0,))
+            result = dc.op_with_named_attrs("squeeze", [result], {"dim": 0}, (0,))
             rank += 1
         while rank > 0:
-            result = dc.op("unsqueeze", [result], (0, len(result.shape.as_list())))
+            result = dc.op_with_named_attrs("unsqueeze", [result], {"dim": 0}, (0, len(result.shape.as_list())))
             rank -= 1
         dc.fuse(result)
         return
diff --git a/forge/test/mlir/test_features.py b/forge/test/mlir/test_features.py
new file mode 100644
index 000000000..b27d7aa22
--- /dev/null
+++ b/forge/test/mlir/test_features.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+from torch import nn
+
+import forge
+from forge.op.eval.common import compare_with_golden_pcc
+
+def test_multiple_inputs():
+    class MultipleInputs(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a, b, c):
+            return a + b + c
+
+    inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32), torch.rand(1, 32, 32)]
+
+    framework_model = MultipleInputs()
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    assert all(compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out))
+
+
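+# Constants are interleaved with matmuls, so this test fails if graph constants
+# are not passed to the compiled binary between the inputs and the parameters.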
+@pytest.mark.parametrize("a_shape, b_shape, c_shape", [
+    ((1, 1, 32, 64), (1, 1, 64, 128), (1, 1, 128, 32)),
+])
+def test_input_order(a_shape, b_shape, c_shape):
+    class InputOrderWithConstants(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.const1 = torch.rand(1, 1, 32, 32)
+            self.const2 = torch.rand(1, 1, 32, 32)
+
+        def forward(self, a, b, c):
+            x = torch.matmul(a, b)
+            x = torch.matmul(x, c)
+            x = x + self.const1
+            x = x * self.const2
+            return x
+
+    a = torch.rand(*a_shape)
+    b = torch.rand(*b_shape)
+    c = torch.rand(*c_shape)
+
+    framework_model = InputOrderWithConstants()
+    fw_out = framework_model(a, b, c)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=[a, b, c])
+    co_out = compiled_model(a, b, c)
+
+    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0][0], pcc=0.99)
diff --git a/pytest.ini b/pytest.ini
index 694708482..70e928e91 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -9,6 +9,9 @@ testpaths =
     # Ops
     forge/test/mlir/test_ops.py
 
+    # Features
+    forge/test/mlir/test_features.py
+
     # API
     forge/test/test_api.py
 
diff --git a/third_party/tt-mlir b/third_party/tt-mlir
index 1f92bc6f7..a75fcf32a 160000
--- a/third_party/tt-mlir
+++ b/third_party/tt-mlir
@@ -1 +1 @@
-Subproject commit 1f92bc6f7de8a85dd7d2b4a9ff659ba645fe7956
+Subproject commit a75fcf32aa142d3a3dbf86d307027a0706e9c32b