Move BroadcastOp folding to a separate TTIR pass. (#1353)
* Move BroadcastOp folding to a separate TTIR pass. Add more tests.
uazizTT authored Dec 2, 2024
1 parent cfbc6a1 commit e2c982c
Showing 9 changed files with 150 additions and 55 deletions.
13 changes: 13 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -112,4 +112,17 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
];
}

def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> {
let summary = "Broadcast operation is folded to all the consumers.";
let description = [{
This pass walks through the graph and folds all broadcast operations, since broadcasting is supported implicitly by backend ops.
Example:
%1 = "ttir.broadcast"(%arg0) (tensor<1xf32>) -> tensor<512xf32>
%2 = "ttir.maximum"(%1, %arg1) (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>

The above broadcast is folded as follows:
%1 = "ttir.maximum"(%arg0, %arg1) (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32>
}];
}

#endif
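To make the new pass's behavior concrete, here is a minimal FileCheck-style sketch (not a test from this commit): it assumes only the pass name ttir-broadcast-fold registered above, and it elides the attributes and exact operand signatures that real ttir ops carry, following the simplified notation used in the pass description.

// RUN: ttmlir-opt --ttir-broadcast-fold %s | FileCheck %s
module {
  func.func @fold_into_consumer(%arg0: tensor<1xf32>, %arg1: tensor<512xf32>) -> tensor<512xf32> {
    // After folding, the broadcast disappears and its consumer reads %arg0 directly.
    // CHECK-NOT: "ttir.broadcast"
    // CHECK: "ttir.maximum"(%arg0, %arg1
    %0 = tensor.empty() : tensor<512xf32>
    %1 = "ttir.broadcast"(%arg0, %0) : (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32>
    %2 = tensor.empty() : tensor<512xf32>
    %3 = "ttir.maximum"(%1, %arg1, %2) : (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>
    return %3 : tensor<512xf32>
  }
}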
41 changes: 0 additions & 41 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -812,46 +812,6 @@ class TypecastOpConversionPattern
}
};

class BroadcastOpConversionPattern
: public OpConversionPattern<ttir::BroadcastOp> {
using OpConversionPattern<ttir::BroadcastOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(ttir::BroadcastOp srcOp, ttir::BroadcastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Fold this operation into all consumer ops. It will only work with TTNN
// ops that support implicit broadcasting. We expect each Op's verify
// function to assert their arguments to verify that they can broadcast.

if (srcOp->getUsers().empty()) {
// This broadcast chain has already been replaced.
rewriter.eraseOp(srcOp);
return success();
}

mlir::Value input = srcOp.getOperand(0);

mlir::Operation *nextOp = srcOp;
while (isa<ttir::BroadcastOp>(*nextOp->getUsers().begin())) {
assert(nextOp->hasOneUse() &&
"Broadcast with multiple uses are not supported");
nextOp = *nextOp->getUsers().begin();
if (nextOp->getUsers().empty()) {
// This broadcast chain has already been replaced.
rewriter.eraseOp(srcOp);
return success();
}
}

rewriter.replaceAllOpUsesWith(nextOp, input);
rewriter.eraseOp(srcOp);

return success();
}
};

class SubtractOpConversionPattern
: public OpConversionPattern<ttir::SubtractOp> {
using OpConversionPattern<ttir::SubtractOp>::OpConversionPattern;
@@ -1019,7 +979,6 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
BroadcastOpConversionPattern,
EmbeddingOpConversionPattern,
SoftmaxOpConversionPattern,
TransposeOpConversionPattern,
68 changes: 68 additions & 0 deletions lib/Dialect/TTIR/Transforms/Broadcast.cpp
@@ -0,0 +1,68 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRBROADCASTFOLD
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Broadcast Folding pass
// Our backend supports implicit broadcast of operands, so explicit broadcast
// instructions are folded.
//
// For Example:
//
// %0 = tensor.empty() : tensor<512xf32>
// %1 = "ttir.broadcast"(%arg0, %0) (tensor<1xf32>, tensor<512xf32>) ->
// tensor<512xf32> %2 = tensor.empty() : tensor<512xf32> %3 = "ttir.maximum"(%1,
// %arg1, %2) (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) ->
// tensor<512xf32>
//
// After folding:
//
// %0 = tensor.empty() : tensor<512xf32>
// %1 = "ttir.maximum"(%arg0, %arg1, %0) (tensor<1xf32>, tensor<512xf32>,
// tensor<512xf32>) -> tensor<512xf32>
//===----------------------------------------------------------------------===//

class TTIRBroadcastFoldRewriter : public OpRewritePattern<BroadcastOp> {
public:
using OpRewritePattern<BroadcastOp>::OpRewritePattern;

LogicalResult matchAndRewrite(BroadcastOp op,
PatternRewriter &rewriter) const final {

rewriter.replaceOp(op, op->getOperand(0));
return success();
}
};

class TTIRBroadcastFold
: public impl::TTIRBroadcastFoldBase<TTIRBroadcastFold> {
public:
using impl::TTIRBroadcastFoldBase<TTIRBroadcastFold>::TTIRBroadcastFoldBase;

void runOnOperation() final {
RewritePatternSet patterns(&getContext());
patterns.add<TTIRBroadcastFoldRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
signalPassFailure();
return;
}
}

void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::tt::ttir::TTIRDialect>();
registry.insert<mlir::tt::TTDialect>();
}
};

} // namespace mlir::tt::ttir
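One practical consequence of expressing the fold as a single-op rewrite pattern under the greedy driver is that broadcast chains need no special handling: the chain-walking loop in the removed BroadcastOpConversionPattern becomes unnecessary, because the driver simply reapplies the pattern until no ttir.broadcast remains. A hedged before/after sketch, with attributes and DPS output operands elided:

// Before: a chain of broadcasts feeding a single consumer.
%1 = "ttir.broadcast"(%arg0) : (tensor<1x1xf32>) -> tensor<1x512xf32>
%2 = "ttir.broadcast"(%1) : (tensor<1x512xf32>) -> tensor<128x512xf32>
%3 = "ttir.maximum"(%2, %arg1) : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xf32>

// After: each application of TTIRBroadcastFoldRewriter replaces one broadcast with
// its input operand, so the chain collapses and the consumer reads %arg0 directly.
%3 = "ttir.maximum"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<128x512xf32>) -> tensor<128x512xf32>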
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRTTIRTransforms
Allocate.cpp
Broadcast.cpp
Constant.cpp
Generic.cpp
Layout.cpp
13 changes: 13 additions & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -107,9 +107,22 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

void createTTNNPipelineTTIRBroadcastFoldPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
pm.addPass(mlir::tt::ttir::createTTIRBroadcastFold());
}

void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
TTIRToTTNNBackendPipelineOptions::createFromString(options);
createTTNNPipelineTTIRBroadcastFoldPass(pm, *optionsStruct);
}

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
createTTNNPipelineTTIRPasses(pm, options);
createTTNNPipelineTTIRBroadcastFoldPass(pm, options);
createTTNNPipelineLoweringPasses(pm, options);
createTTNNPipelineAnalysisPasses(pm, options);
createTTNNPipelineLayoutDecompositionPass(pm, options);
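Because the fold hook is scheduled right after the initial TTIR passes and before the lowering passes, the TTIR-to-TTNN conversion never sees a ttir.broadcast, which is what allows the BroadcastOpConversionPattern above to be removed. A hedged pipeline-level sketch in the style of the existing tests (module body as in the standalone sketch after the Passes.td change):

// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
// CHECK-NOT: "ttir.broadcast"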
51 changes: 51 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir
@@ -8,3 +8,54 @@ module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replic
return %1 : tensor<512x512xf32>
}
}

module {
func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2, 3] : (tensor<1x23x40x1xf32>) -> tensor<1x23x40x128xf32>
%1 = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<128xf32>) -> tensor<1x23x40x128xf32>
// CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]]
%2 = stablehlo.divide %0, %1 : tensor<1x23x40x128xf32>
return %2 : tensor<1x23x40x128xf32>
}
}

module {
func.func @main(%arg0: tensor<32xi64>, %arg1: tensor<32x1xi64>) -> tensor<32x32xi1> {
%0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<32xi64>) -> tensor<32x32xi64>
%1 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<32x1xi64>) -> tensor<32x32xi64>
%2 = stablehlo.compare GT, %0, %1, SIGNED : (tensor<32x32xi64>, tensor<32x32xi64>) -> tensor<32x32xi1>
// CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]]
return %2 : tensor<32x32xi1>
}
}

module {
func.func @main(%arg0: tensor<16x1xf32>, %arg1: tensor<1x1x32xi64>) -> tensor<1x16x32xf32> {
%0 = stablehlo.convert %arg1 : (tensor<1x1x32xi64>) -> tensor<1x1x32xf32>
%1 = stablehlo.broadcast_in_dim %arg0, dims = [1, 2] : (tensor<16x1xf32>) -> tensor<1x16x32xf32>
%2 = stablehlo.broadcast_in_dim %0, dims = [0, 1, 2] : (tensor<1x1x32xf32>) -> tensor<1x16x32xf32>
%3 = stablehlo.multiply %1, %2 : tensor<1x16x32xf32>
// CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]]
return %3 : tensor<1x16x32xf32>
}
}

module {
func.func @main(%arg0: tensor<1x10xi64>, %arg1: tensor<10x1xi64>) -> tensor<10x10xi64> {
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x10xi64>) -> tensor<10x10xi64>
%1 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<10x1xi64>) -> tensor<10x10xi64>
%2 = stablehlo.subtract %0, %1 : tensor<10x10xi64>
// CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]]
return %2 : tensor<10x10xi64>
}
}

module {
func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<1xf32>) -> tensor<8xf32> {
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<8xf32>) -> tensor<8xf32>
%1 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<1xf32>) -> tensor<8xf32>
%2 = stablehlo.add %0, %1 : tensor<8xf32>
// CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]]
return %2 : tensor<8xf32>
}
}
2 changes: 2 additions & 0 deletions test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir
@@ -1,4 +1,6 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
// XFAIL: *
// https://github.com/tenstorrent/tt-mlir/issues/1448
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> {
@@ -1,6 +1,8 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// UNSUPPORTED: true
// https://github.com/tenstorrent/tt-mlir/issues/1448
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> {
14 changes: 0 additions & 14 deletions test/ttmlir/Silicon/TTNN/simple_broadcast.mlir

This file was deleted.
