Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Reduce ops workaround for keepDim=false #1625

Merged
merged 7 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,8 @@ class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
return {builder.getAffineMapArrayAttr(indexingMaps),
builder.getArrayAttr(iteratorTypes)};}
}];

let hasVerifier = 1;
}

def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
Expand Down
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,8 @@ class TTNN_ReductionOp<string mnemonic, list<Trait> traits = []> : TTNN_Op<mnemo
OptionalAttr<I32ArrayAttr>:$dim_arg);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
TTNNLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayoutAttr memLayoutAttr);
TTNNLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout);
TTNNLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector<int64_t> shardShape);
TTNNLayoutAttr withTensorShape(::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape);

bool isSystemBufferType() const { return ::mlir::tt::ttnn::isSystemBufferType(getBufferType()); }
bool isDeviceBufferType() const { return ::mlir::tt::ttnn::isDeviceBufferType(getBufferType()); }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H
#define TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir::tt::ttnn::workarounds::decomposition {

// Extracts reduce dimensions' values from the dimArg attribute. In case when
// dimArg is not specified, returns empty vector.
llvm::SmallVector<int64_t>
getReduceDims(const std::optional<mlir::ArrayAttr> &dimArg);

// Calculates the shape of the new Reduce op created in the workaround, based
// on the input shape and reducing dimensions.
llvm::SmallVector<int64_t>
calculateNewReduceShape(RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &dimArg);

// This workaround addresses the following Metal issue:
// https://github.com/tenstorrent/tt-metal/issues/13361
//
// TODO(mrakita): Remove this workaround once these Metal issues are fixed
// (tracked by https://github.com/tenstorrent/tt-mlir/issues/1624).
//
template <typename ReduceOp>
class ReduceOpsKeepDimRewritePattern : public OpRewritePattern<ReduceOp> {
public:
  using OpRewritePattern<ReduceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReduceOp srcOp,
                                PatternRewriter &rewriter) const override {
    // Nothing to do when the op already keeps the reduced dimensions.
    if (srcOp.getKeepDim()) {
      return failure();
    }

    RankedTensorType inputType = srcOp.getInput().getType();
    RankedTensorType outputType = srcOp.getResult().getType();

    ReduceOp newReduceOp =
        createReduceOpWithKeepDim(srcOp, rewriter, inputType, outputType);

    // Metal TTNN implementation of Reduce ops doesn't yet support
    // keepDim=false. As a workaround, we convert Reduce ops to combination of
    // Reduce op with keepDim=true + Reshape op to remove the reduce dims so
    // that the rest of the graph is not affected. In case when this is not
    // needed (for example because type converters already promoted rank of the
    // op result) then we avoid adding unnecessary Reshape op.
    if (outputType.getShape().size() < inputType.getShape().size()) {
      replaceOpWithReshapeOp(srcOp, newReduceOp, rewriter, outputType);
    } else {
      rewriter.replaceOp(srcOp, newReduceOp);
    }

    return success();
  }

private:
  // Creates the replacement Reduce op with keepDim=true. Its output shape is
  // derived from the input shape by collapsing each reduced dimension to 1,
  // and the layout encoding is updated to match the new tensor shape.
  ReduceOp createReduceOpWithKeepDim(ReduceOp srcOp, PatternRewriter &rewriter,
                                     RankedTensorType inputType,
                                     RankedTensorType outputType) const {
    llvm::SmallVector<int64_t> outputShapeVec =
        calculateNewReduceShape(inputType, srcOp.getDimArg());

    TTNNLayoutAttr newOutputLayoutAttr =
        mlir::cast<TTNNLayoutAttr>(outputType.getEncoding())
            .withTensorShape(rewriter.getContext(), outputShapeVec);

    RankedTensorType newOutputType = RankedTensorType::get(
        outputShapeVec, outputType.getElementType(), newOutputLayoutAttr);

    return rewriter.create<ReduceOp>(srcOp.getLoc(), newOutputType,
                                     srcOp.getInput(), true /*keep_dim*/,
                                     srcOp.getDimArg().value_or(nullptr));
  }

  // Replaces the original op with a Reshape op that squeezes the kept
  // (size-1) reduce dimensions so the result matches the original output
  // type expected by the rest of the graph.
  void replaceOpWithReshapeOp(ReduceOp srcOp, ReduceOp newReduceOp,
                              PatternRewriter &rewriter,
                              RankedTensorType outputType) const {
    mlir::ArrayAttr shapeAttr = rewriter.getI32ArrayAttr(
        llvm::SmallVector<int32_t>(outputType.getShape()));

    rewriter.replaceOpWithNewOp<mlir::tt::ttnn::ReshapeOp>(
        srcOp, outputType, newReduceOp, shapeAttr);
  }
};

// This workaround addresses the following Metal issue:
// https://github.com/tenstorrent/tt-metal/issues/16118
//
// TODO(mrakita): Remove this workaround once these Metal issues are fixed
// (tracked by https://github.com/tenstorrent/tt-mlir/issues/1624).
//
template <typename ReduceOp>
class ReduceOpsAllDimsRewritePattern : public OpRewritePattern<ReduceOp> {
public:
  using OpRewritePattern<ReduceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReduceOp srcOp,
                                PatternRewriter &rewriter) const override {
    // The pattern only applies when an explicit, non-empty dimension list is
    // present on the op.
    std::optional<mlir::ArrayAttr> dimArg = srcOp.getDimArg();
    if (!dimArg || dimArg->empty()) {
      return failure();
    }

    llvm::SmallVector<int64_t> dims = getReduceDims(dimArg);
    llvm::SmallSet<int64_t, 4> distinctDims(dims.begin(), dims.end());

    // Bail out unless every dimension of the input tensor is being reduced.
    if (distinctDims.size() != srcOp.getInput().getType().getShape().size()) {
      return failure();
    }

    // When reduce is performed over all dimensions of the input we must drop
    // the dimensions attribute: Metal supports an all-dims reduce of any
    // tensor rank only when the reduce dimensions are left unspecified, but
    // rejects explicit dimension lists for tensors with rank larger than 2.
    rewriter.replaceOpWithNewOp<ReduceOp>(srcOp, srcOp.getResult().getType(),
                                          srcOp.getInput(), srcOp.getKeepDim(),
                                          nullptr);

    return success();
  }
};

} // namespace mlir::tt::ttnn::workarounds::decomposition

#endif // TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H
7 changes: 3 additions & 4 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,9 @@ class StableHLOToTTIRReduceOpConversionPattern
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

mlir::ArrayAttr dimArg = rewriter.getArrayAttr(SmallVector<Attribute>(
1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr().size() > 0
? adaptor.getDimensionsAttr()[0]
: 1)));
// Can't reuse the original dimensions attribute because it uses i64 type.
mlir::ArrayAttr dimArg = rewriter.getI32ArrayAttr(
llvm::SmallVector<int32_t>(srcOp.getDimensions()));

rewriter.replaceOpWithNewOp<DestOp>(
srcOp, outputType, adaptor.getInputs().front(), outputTensor,
Expand Down
87 changes: 73 additions & 14 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"

Expand Down Expand Up @@ -1672,32 +1673,32 @@ static void buildGenericEltwiseUnaryRegion(::mlir::Location loc,
opBuilder.create<mlir::tt::ttir::YieldOp>(loc, mlir::ValueRange({result}));
}

// AddOp generic region builder
// AddOp generic region builder.
void mlir::tt::ttir::AddOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
buildGenericEltwiseBinaryRegion<arith::AddFOp>(getLoc(), opBuilder, block);
}

// MultiplyOp generic region builder
// MultiplyOp generic region builder.
void mlir::tt::ttir::MultiplyOp::buildGenericRegion(
::mlir::OpBuilder &opBuilder, ::mlir::Block *block) {
buildGenericEltwiseBinaryRegion<arith::MulFOp>(getLoc(), opBuilder, block);
}

// ExpOp generic region builder
// ExpOp generic region builder.
void mlir::tt::ttir::ExpOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
buildGenericEltwiseUnaryRegion<math::ExpOp>(getLoc(), opBuilder, block);
}

// DivOp generic region builder
// DivOp generic region builder.
void mlir::tt::ttir::DivOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
return buildGenericEltwiseBinaryRegion<arith::DivFOp>(getLoc(), opBuilder,
block);
}

// MaximumOp generic region builder
// MaximumOp generic region builder.
void mlir::tt::ttir::MaximumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
buildGenericEltwiseBinaryRegion<arith::MaximumFOp>(getLoc(), opBuilder,
Expand All @@ -1708,7 +1709,7 @@ void mlir::tt::ttir::MaximumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
// KernelOp
//===----------------------------------------------------------------------===//

// KernelOp builders
// KernelOp builders.
static mlir::tt::ttir::KernelOp
buildKernelOp(::mlir::OpBuilder &opBuilder, ::mlir::Location loc,
::mlir::StringRef kernelName, ::mlir::StringRef kernelKind,
Expand All @@ -1717,31 +1718,89 @@ buildKernelOp(::mlir::OpBuilder &opBuilder, ::mlir::Location loc,
loc, outputs.getTypes(), kernelName, kernelKind, inputs, outputs);
}

// Reduce op kernel builder
// Builds the body of a generic Reduce op region: a single "reduce" KernelOp
// of the requested kind, whose results are yielded back to the region.
static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block,
                           mlir::Location loc, ::mlir::StringRef kernelKind) {
  mlir::tt::ttir::KernelOp reduceKernel =
      buildKernelOp(opBuilder, loc, "reduce", kernelKind,
                    block->getArgument(0), block->getArgument(1));
  opBuilder.create<mlir::tt::ttir::YieldOp>(loc, reduceKernel->getResults());
}

// Sum op kernel builder
void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
// Common verifier for all Reduce ops.
//
// When the reduce dimensions attribute is present, verifies that every listed
// dimension lies within [-rank, rank) of the input tensor and that no
// dimension is referenced twice (including via its negative alias).
static mlir::LogicalResult
verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
               const std::optional<mlir::ArrayAttr> &reduceDims) {
  // Missing dimArg means reduce over all dimensions; nothing to validate.
  if (!reduceDims) {
    return mlir::success();
  }

  int64_t inputTensorRank = inputType.getRank();

  llvm::SmallSet<int64_t, 4> uniqueReduceDims;
  for (mlir::Attribute reduceDim : *reduceDims) {
    int64_t reduceDimInt = mlir::cast<mlir::IntegerAttr>(reduceDim).getInt();
    if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) {
      return reduceOp->emitOpError("Reduce dimensions are out of range");
    }
    // Normalize negative indices so that two aliases of the same dimension
    // (e.g. -1 and rank-1) are caught by the uniqueness check below.
    uniqueReduceDims.insert(reduceDimInt < 0 ? reduceDimInt + inputTensorRank
                                             : reduceDimInt);
  }

  if (uniqueReduceDims.size() != reduceDims->size()) {
    return reduceOp->emitOpError("Reduce dimensions are not unique");
  }

  // TODO(mrakita): Add a check that depending on inputShape, reduceDims and
  // keepDim computes the expected output shape and checks if it matches the
  // actual output shape. Tracked by:
  // https://github.com/tenstorrent/tt-mlir/issues/1639

  return mlir::success();
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

// MaxOp kernel builder: emits a "max" reduce kernel into the generic region.
void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
                                               ::mlir::Block *block) {
  // NOLINTNEXTLINE
  createReduceOp(opBuilder, block, getLoc(), "max");
}

// MaxOp verification.
// Delegates to the common reduce-op verifier, which validates the optional
// dimension list against the input tensor's rank.
::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

// Mean op kernel builder
//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//

// MeanOp kernel builder.
// Builds the generic-region body: a single "mean" reduce kernel.
void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
                                                ::mlir::Block *block) {
  // NOLINTNEXTLINE
  createReduceOp(opBuilder, block, getLoc(), "mean");
}

// Max op kernel builder
void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
// MeanOp verification.
// Delegates to the common reduce-op verifier, which validates the optional
// dimension list against the input tensor's rank.
::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

// SumOp kernel builder: emits a "sum" reduce kernel into the generic region.
void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
                                               ::mlir::Block *block) {
  // NOLINTNEXTLINE
  createReduceOp(opBuilder, block, getLoc(), "sum");
}

// SumOp verification.
// Delegates to the common reduce-op verifier, which validates the optional
// dimension list against the input tensor's rank.
::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
48 changes: 48 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1310,4 +1310,52 @@ ::mlir::LogicalResult mlir::tt::ttnn::PermuteOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// Reduction ops
//===----------------------------------------------------------------------===//

// Common verifier for all Reduction ops.
// Accepts a reduce over at most two dimensions, or over exactly all
// dimensions of the input; anything in between is rejected.
static mlir::LogicalResult
verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
               const std::optional<mlir::ArrayAttr> &reduceDims) {
  // No dimension list means reduce over everything, which TTNN supports for
  // any tensor rank.
  if (!reduceDims) {
    return mlir::success();
  }

  int64_t numReduceDims = static_cast<int64_t>(reduceDims->size());

  // TODO(mrakita): Only last two dimensions can be reduced, check for that
  // too.
  if (numReduceDims <= 2 || numReduceDims == inputType.getRank()) {
    return mlir::success();
  }

  return reduceOp->emitOpError("Reduce on more than two dimensions is not "
                               "currently supported by TTNN");
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

// MaxOp verification.
// Delegates to the shared TTNN reduce-op verifier (dimension-count check).
::mlir::LogicalResult MaxOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//

// MeanOp verification.
// Delegates to the shared TTNN reduce-op verifier (dimension-count check).
::mlir::LogicalResult MeanOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

// SumOp verification.
// Delegates to the shared TTNN reduce-op verifier (dimension-count check).
::mlir::LogicalResult SumOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

} // namespace mlir::tt::ttnn
Loading
Loading