diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 63ccb0d28..b6269f715 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -112,4 +112,17 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { ]; } +def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> { + let summary = "Broadcast operation is folded to all the consumers."; + let description = [{ + This pass walks through the graph and folds all broadcast instructions since broadcast is supported implicitly by backend ops. + Example: + %1 = "ttir.broadcast"(%arg0) (tensor<1xf32>) -> tensor<512xf32> + %2 = "ttir.maximum"(%1, %arg1) (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32> + + This above broadcast is folded as: + %1 = "ttir.maximum"(%arg0, %arg1) (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32> + }]; +} + #endif diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 18efb982e..789485eac 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -812,46 +812,6 @@ class TypecastOpConversionPattern } }; -class BroadcastOpConversionPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - -public: - LogicalResult - matchAndRewrite(ttir::BroadcastOp srcOp, ttir::BroadcastOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - // Fold this operation into all consumer ops. It will only work with TTNN - // ops that support implicit broadcasting. We expect each Op's verify - // function to assert their arguments to verify that they can broadcast. - - if (srcOp->getUsers().empty()) { - // This broadcast chain has already been replaced. - rewriter.eraseOp(srcOp); - return success(); - } - - mlir::Value input = srcOp.getOperand(0); - - mlir::Operation *nextOp = srcOp; - while (isa(*nextOp->getUsers().begin())) { - assert(nextOp->hasOneUse() && - "Broadcast with multiple uses are not supported"); - nextOp = *nextOp->getUsers().begin(); - if (nextOp->getUsers().empty()) { - // This broadcast chain has already been replaced. - rewriter.eraseOp(srcOp); - return success(); - } - } - - rewriter.replaceAllOpUsesWith(nextOp, input); - rewriter.eraseOp(srcOp); - - return success(); - } -}; - class SubtractOpConversionPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1019,7 +979,6 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, ReductionOpConversionPattern, ReductionOpConversionPattern, ReductionOpConversionPattern, - BroadcastOpConversionPattern, EmbeddingOpConversionPattern, SoftmaxOpConversionPattern, TransposeOpConversionPattern, diff --git a/lib/Dialect/TTIR/Transforms/Broadcast.cpp b/lib/Dialect/TTIR/Transforms/Broadcast.cpp new file mode 100644 index 000000000..7823b021e --- /dev/null +++ b/lib/Dialect/TTIR/Transforms/Broadcast.cpp @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +namespace mlir::tt::ttir { +#define GEN_PASS_DEF_TTIRBROADCASTFOLD +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" + +//===----------------------------------------------------------------------===// +// Broadcast Folding pass +// Our backend supports implicit broadcast of operands, so explicit broadcast +// instructions are folded. +// +// For Example: +// +// %0 = tensor.empty() : tensor<512xf32> +// %1 = "ttir.broadcast"(%arg0, %0) (tensor<1xf32>, tensor<512xf32>) -> +// tensor<512xf32> %2 = tensor.empty() : tensor<512xf32> %3 = "ttir.maximum"(%1, +// %arg1, %2) (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) -> +// tensor<512xf32> +// +// After folding: +// +// %0 = tensor.empty() : tensor<512xf32> +// %1 = "ttir.maximum"(%arg0, %arg1, %0) (tensor<1xf32>, tensor<512xf32>, +// tensor<512xf32>) -> tensor<512xf32> +//===----------------------------------------------------------------------===// + +class TTIRBroadcastFoldRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BroadcastOp op, + PatternRewriter &rewriter) const final { + + rewriter.replaceOp(op, op->getOperand(0)); + return success(); + } +}; + +class TTIRBroadcastFold + : public impl::TTIRBroadcastFoldBase { +public: + using impl::TTIRBroadcastFoldBase::TTIRBroadcastFoldBase; + + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { + signalPassFailure(); + return; + } + } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } +}; + +} // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt index f5fec45a8..597c55e3c 100644 --- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTTIRTransforms Allocate.cpp + Broadcast.cpp Constant.cpp Generic.cpp Layout.cpp diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 24980fb7c..3ade96bf8 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -107,9 +107,22 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm, createTTNNPipelineDeallocPass(pm, *optionsStruct); } +void createTTNNPipelineTTIRBroadcastFoldPass( + OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { + pm.addPass(mlir::tt::ttir::createTTIRBroadcastFold()); +} + +void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm, + std::string options) { + auto optionsStruct = + TTIRToTTNNBackendPipelineOptions::createFromString(options); + createTTNNPipelineTTIRBroadcastFoldPass(pm, *optionsStruct); +} + void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { createTTNNPipelineTTIRPasses(pm, options); + createTTNNPipelineTTIRBroadcastFoldPass(pm, options); createTTNNPipelineLoweringPasses(pm, options); createTTNNPipelineAnalysisPasses(pm, options); createTTNNPipelineLayoutDecompositionPass(pm, options); diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir index fa6cbb423..42a26ad15 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir @@ -8,3 +8,54 @@ module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replic return %1 : tensor<512x512xf32> } } + +module { + func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2, 3] : (tensor<1x23x40x1xf32>) -> tensor<1x23x40x128xf32> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<128xf32>) -> tensor<1x23x40x128xf32> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + %2 = stablehlo.divide %0, %1 : tensor<1x23x40x128xf32> + return %2 : tensor<1x23x40x128xf32> + } +} + +module { + func.func @main(%arg0: tensor<32xi64>, %arg1: tensor<32x1xi64>) -> tensor<32x32xi1> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<32xi64>) -> tensor<32x32xi64> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<32x1xi64>) -> tensor<32x32xi64> + %2 = stablehlo.compare GT, %0, %1, SIGNED : (tensor<32x32xi64>, tensor<32x32xi64>) -> tensor<32x32xi1> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %2 : tensor<32x32xi1> + } +} + +module { + func.func @main(%arg0: tensor<16x1xf32>, %arg1: tensor<1x1x32xi64>) -> tensor<1x16x32xf32> { + %0 = stablehlo.convert %arg1 : (tensor<1x1x32xi64>) -> tensor<1x1x32xf32> + %1 = stablehlo.broadcast_in_dim %arg0, dims = [1, 2] : (tensor<16x1xf32>) -> tensor<1x16x32xf32> + %2 = stablehlo.broadcast_in_dim %0, dims = [0, 1, 2] : (tensor<1x1x32xf32>) -> tensor<1x16x32xf32> + %3 = stablehlo.multiply %1, %2 : tensor<1x16x32xf32> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %3 : tensor<1x16x32xf32> + } +} + +module { + func.func @main(%arg0: tensor<1x10xi64>, %arg1: tensor<10x1xi64>) -> tensor<10x10xi64> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x10xi64>) -> tensor<10x10xi64> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<10x1xi64>) -> tensor<10x10xi64> + %2 = stablehlo.subtract %0, %1 : tensor<10x10xi64> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %2 : tensor<10x10xi64> + } +} + +module { + func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<1xf32>) -> tensor<8xf32> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<8xf32>) -> tensor<8xf32> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<1xf32>) -> tensor<8xf32> + %2 = stablehlo.add %0, %1 : tensor<8xf32> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %2 : tensor<8xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir index 4c04e138b..16c396c00 100644 --- a/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir +++ b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir @@ -1,4 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +// XFAIL: * +// https://github.com/tenstorrent/tt-mlir/issues/1448 #any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { diff --git a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir index ec509a1b6..f3affc69d 100644 --- a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir +++ b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir @@ -1,6 +1,8 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// UNSUPPORTED: true +// https://github.com/tenstorrent/tt-mlir/issues/1448 #any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { diff --git a/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir b/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir deleted file mode 100644 index 1d88725d1..000000000 --- a/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint - -func.func public @broadcast() -> (tensor<32xf32>) { - %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> - %1 = tensor.empty() : tensor<32xf32> - %2 = "ttir.broadcast"(%0, %1) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1xf32>, tensor<32xf32>) -> tensor<32xf32> - %3 = tensor.empty() : tensor<32xf32> - %4 = "ttir.broadcast"(%2, %3) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> - // CHECK-NOT: %[[C:.*]] = "ttir.broadcast"[[C:.*]] - return %4 : tensor<32xf32> -}