Fix PR comments and add negative tests for TTNN verifier
mrakitaTT committed Dec 20, 2024
1 parent d46471f commit 3d6c7db
Showing 11 changed files with 176 additions and 51 deletions.
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -537,6 +537,8 @@ class TTNN_ReductionOp<string mnemonic, list<Trait> traits = []> : TTNN_Op<mnemo
OptionalAttr<I32ArrayAttr>:$dim_arg);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
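A note on the `hasVerifier = 1` line added above: in MLIR ODS this makes TableGen declare a `::mlir::LogicalResult verify()` method on every op class derived from `TTNN_ReductionOp`, and the dialect must supply the definitions. This commit supplies them in lib/Dialect/TTNN/IR/TTNNOps.cpp (shown later in this diff); a minimal sketch of the resulting shape, with the shared helper's body elided:

```cpp
// Sketch of the verifier wiring this commit adds; mirrors the TTNNOps.cpp
// hunk later in this diff rather than introducing anything new.
#include <optional>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir::tt::ttnn {

// Shared verification logic for all reduction ops (full version below).
static mlir::LogicalResult
verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
               const std::optional<mlir::ArrayAttr> &reduceDims);

// Declared by ODS because of hasVerifier = 1; each reduction op simply
// forwards to the shared helper.
::mlir::LogicalResult SumOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

} // namespace mlir::tt::ttnn
```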
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -145,6 +145,7 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
TTNNLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayoutAttr memLayoutAttr);
TTNNLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout);
TTNNLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector<int64_t> shardShape);
TTNNLayoutAttr withTensorShape(::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape);

bool isSystemBufferType() const { return ::mlir::tt::ttnn::isSystemBufferType(getBufferType()); }
bool isDeviceBufferType() const { return ::mlir::tt::ttnn::isDeviceBufferType(getBufferType()); }
include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h
@@ -35,15 +35,14 @@ mlir::ArrayAttr
createNewReduceDimArg(RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &dimArg);

// This workaround addresses next two Metal issues:
// - https://github.com/tenstorrent/tt-metal/issues/13361
// - https://github.com/tenstorrent/tt-metal/issues/16118
// This workaround addresses the following Metal issue:
// https://github.com/tenstorrent/tt-metal/issues/13361
//
// TODO(mrakita): Remove this workaround once these Metal issues are fixed
// (tracked by https://github.com/tenstorrent/tt-mlir/issues/1624).
//
template <typename ReduceOp>
class ReduceOpsRewritePattern : public OpRewritePattern<ReduceOp> {
class ReduceOpsKeepDimRewritePattern : public OpRewritePattern<ReduceOp> {
public:
using OpRewritePattern<ReduceOp>::OpRewritePattern;

@@ -81,8 +80,12 @@ class ReduceOpsRewritePattern : public OpRewritePattern<ReduceOp> {
llvm::SmallVector<int64_t> outputShapeVec =
calculateNewReduceShape(inputType, srcOp.getDimArg());

TTNNLayoutAttr newOutputLayoutAttr =
mlir::cast<TTNNLayoutAttr>(outputType.getEncoding())
.withTensorShape(rewriter.getContext(), outputShapeVec);

RankedTensorType newOutputType = RankedTensorType::get(
outputShapeVec, inputType.getElementType(), inputType.getEncoding());
outputShapeVec, outputType.getElementType(), newOutputLayoutAttr);

return rewriter.create<ReduceOp>(
srcOp.getLoc(), newOutputType, srcOp.getInput(), true /*keep_dim*/,
@@ -100,6 +103,32 @@ class ReduceOpsRewritePattern : public OpRewritePattern<ReduceOp> {
}
};
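The hunk above folds away the middle of `ReduceOpsKeepDimRewritePattern`. From the visible lines, the pattern recreates the reduce op with `keep_dim = true` and an output layout re-derived via `withTensorShape`; it must then still yield a value of the op's original (collapsed) result type, presumably via a reshape. A hedged sketch of that last step, where `createReshapeOp` is a hypothetical stand-in for whatever the elided lines actually call:

```cpp
// Hedged sketch, not the commit's verbatim code. After rebuilding the
// reduce with keep_dim = true (so e.g. 128x32x10x4 reduced over dim 3
// yields 128x32x10x1 instead of 128x32x10), the original result type has
// to be restored. `createReshapeOp` is a hypothetical helper.
template <typename ReduceOp>
mlir::Value restoreOriginalShape(mlir::PatternRewriter &rewriter,
                                 ReduceOp srcOp, mlir::Value keepDimResult) {
  auto originalType =
      mlir::cast<mlir::RankedTensorType>(srcOp.getResult().getType());
  // Collapse the kept unit dimensions back to the original output shape.
  return createReshapeOp(rewriter, srcOp.getLoc(), keepDimResult,
                         originalType);
}
```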

// This workaround addresses the following Metal issue:
// https://github.com/tenstorrent/tt-metal/issues/16118
//
// TODO(mrakita): Remove this workaround once these Metal issues are fixed
// (tracked by https://github.com/tenstorrent/tt-mlir/issues/1624).
//
template <typename ReduceOp>
class ReduceOpsAllDimsRewritePattern : public OpRewritePattern<ReduceOp> {
public:
using OpRewritePattern<ReduceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ReduceOp srcOp,
PatternRewriter &rewriter) const override {
if (!srcOp.getDimArg() || srcOp.getDimArg()->empty()) {
return failure();
}

rewriter.replaceOpWithNewOp<ReduceOp>(
srcOp, srcOp.getResult().getType(), srcOp.getInput(),
srcOp.getKeepDim(),
createNewReduceDimArg(srcOp.getInput().getType(), srcOp.getDimArg()));

return success();
}
};
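`createNewReduceDimArg` (declared near the top of this header) is the core of this second workaround, but its definition is not part of this diff. A hedged sketch of its plausible behavior, under the assumption that it drops `dim_arg` entirely when every input dimension is being reduced so that TTNN takes its implicit all-dims path; the name and details here are illustrative, not the commit's code:

```cpp
// Hedged sketch under the stated assumption; the real definition lives in
// the corresponding .cpp file and is not shown in this diff.
#include <optional>

#include "llvm/ADT/DenseSet.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

mlir::ArrayAttr
createNewReduceDimArgSketch(mlir::RankedTensorType inputType,
                            const std::optional<mlir::ArrayAttr> &dimArg) {
  if (!dimArg) {
    return nullptr;
  }
  llvm::SmallDenseSet<int64_t> uniqueDims;
  for (mlir::Attribute attr : *dimArg) {
    int64_t dim = mlir::cast<mlir::IntegerAttr>(attr).getInt();
    // Normalize negative dims (e.g. -1 refers to the last dimension).
    uniqueDims.insert(dim < 0 ? dim + inputType.getRank() : dim);
  }
  if (static_cast<int64_t>(uniqueDims.size()) == inputType.getRank()) {
    // Reducing over all dimensions: drop dim_arg (tt-metal #16118).
    return nullptr;
  }
  return *dimArg;
}
```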

} // namespace mlir::tt::ttnn::workarounds::decomposition

#endif // TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H
13 changes: 0 additions & 13 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -334,19 +334,6 @@ class ReductionOpConversionPattern : public OpConversionPattern<TTIROpTy> {
LogicalResult
matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int64_t inputTensorRank =
mlir::cast<::mlir::RankedTensorType>(adaptor.getInput().getType())
.getRank();

// TODO(mrakita): Only last two dimensions can be reduced, check for that
// too. Should this check be in verifier?
if (adaptor.getDimArg() && adaptor.getDimArg()->size() > 2 &&
static_cast<int64_t>(adaptor.getDimArg()->size()) != inputTensorRank) {
return rewriter.notifyMatchFailure(op,
"Reduce on more than two dimensions "
"is not currently supported by TTNN");
}

rewriter.replaceOpWithNewOp<TTNNOpTy>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getKeepDim(),
26 changes: 13 additions & 13 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -1602,32 +1602,32 @@ static void buildGenericEltwiseUnaryRegion(::mlir::Location loc,
opBuilder.create<mlir::tt::ttir::YieldOp>(loc, mlir::ValueRange({result}));
}

// AddOp generic region builder
// AddOp generic region builder.
void mlir::tt::ttir::AddOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
buildGenericEltwiseBinaryRegion<arith::AddFOp>(getLoc(), opBuilder, block);
}

// MultiplyOp generic region builder
// MultiplyOp generic region builder.
void mlir::tt::ttir::MultiplyOp::buildGenericRegion(
::mlir::OpBuilder &opBuilder, ::mlir::Block *block) {
buildGenericEltwiseBinaryRegion<arith::MulFOp>(getLoc(), opBuilder, block);
}

// ExpOp generic region builder
// ExpOp generic region builder.
void mlir::tt::ttir::ExpOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
buildGenericEltwiseUnaryRegion<math::ExpOp>(getLoc(), opBuilder, block);
}

// DivOp generic region builder
// DivOp generic region builder.
void mlir::tt::ttir::DivOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
return buildGenericEltwiseBinaryRegion<arith::DivFOp>(getLoc(), opBuilder,
block);
}

// MaximumOp generic region builder
// MaximumOp generic region builder.
void mlir::tt::ttir::MaximumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
buildGenericEltwiseBinaryRegion<arith::MaximumFOp>(getLoc(), opBuilder,
@@ -1638,7 +1638,7 @@ void mlir::tt::ttir::MaximumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
// KernelOp
//===----------------------------------------------------------------------===//

// KernelOp builders
// KernelOp builders.
static mlir::tt::ttir::KernelOp
buildKernelOp(::mlir::OpBuilder &opBuilder, ::mlir::Location loc,
::mlir::StringRef kernelName, ::mlir::StringRef kernelKind,
@@ -1647,7 +1647,7 @@ buildKernelOp(::mlir::OpBuilder &opBuilder, ::mlir::Location loc,
loc, outputs.getTypes(), kernelName, kernelKind, inputs, outputs);
}

// Reduce op kernel builder
// Reduce op kernel builder.
static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block,
mlir::Location loc, ::mlir::StringRef kernelKind) {
auto kernelOp = buildKernelOp(opBuilder, loc, "reduce", kernelKind,
@@ -1690,14 +1690,14 @@ verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
// MaxOp
//===----------------------------------------------------------------------===//

// MaxOp kernel builder
// MaxOp kernel builder.
void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
// NOLINTNEXTLINE
createReduceOp(opBuilder, block, getLoc(), "max");
}

// MaxOp verification
// MaxOp verification.
::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
@@ -1706,14 +1706,14 @@ ::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() {
// MeanOp
//===----------------------------------------------------------------------===//

// MeanOp kernel builder
// MeanOp kernel builder.
void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
// NOLINTNEXTLINE
createReduceOp(opBuilder, block, getLoc(), "mean");
}

// MeanOp verification
// MeanOp verification.
::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
@@ -1722,14 +1722,14 @@ ::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() {
// SumOp
//===----------------------------------------------------------------------===//

// SumOp kernel builder
// SumOp kernel builder.
void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
// NOLINTNEXTLINE
createReduceOp(opBuilder, block, getLoc(), "sum");
}

// SumOp verification
// SumOp verification.
::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
48 changes: 48 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -1271,4 +1271,52 @@ ::mlir::LogicalResult FillCacheOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// Reduction ops
//===----------------------------------------------------------------------===//

// Common verifier for all Reduction ops.
static mlir::LogicalResult
verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &reduceDims) {
int64_t inputTensorRank = inputType.getRank();

// TODO(mrakita): Only last two dimensions can be reduced, check for that
// too.
if (reduceDims && reduceDims->size() > 2 &&
static_cast<int64_t>(reduceDims->size()) != inputTensorRank) {
return reduceOp->emitOpError("Reduce on more than two dimensions is not "
"currently supported by TTNN");
}

return mlir::success();
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

// MaxOp verification.
::mlir::LogicalResult MaxOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//

// MeanOp verification.
::mlir::LogicalResult MeanOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

// SumOp verification.
::mlir::LogicalResult SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

} // namespace mlir::tt::ttnn
18 changes: 18 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
@@ -436,6 +436,24 @@ TTNNLayoutAttr::withShardShape(::mlir::MLIRContext *context,
getMemLayout());
}

// Construct a new TTNNLayoutAttr
//
// This function creates a deep copy of the current TTNNLayoutAttr and
// applies the changes necessary to fit the new tensor shape.
//
// param context The MLIR context.
// param tensorShape The new tensor shape.
// return The new TTNNLayoutAttr with the given tensor shape.
TTNNLayoutAttr TTNNLayoutAttr::withTensorShape(::mlir::MLIRContext *context,
ArrayRef<int64_t> tensorShape) {
// TODO(mrakita): This leaves the default value of the collapseIntervals
// parameter, which might differ from the original value used to create the
// layout attribute. This works for now since we always use the default
// value, but in the future we will need to take this into account.
return TTNNLayoutAttr::get(context, tensorShape, getElementType(),
getBufferType(), getGrid(), getMemLayout());
}
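For how this new method is consumed, see `ReduceOpsKeepDimRewritePattern` earlier in this diff; condensed into a small helper below (variable names follow that pattern's code, and nothing here goes beyond what the hunk shows):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/BuiltinTypes.h"

// Derive a layout for the post-reduction shape, then rebuild the tensor
// type around it, exactly as ReduceOpsKeepDimRewritePattern does above.
static mlir::RankedTensorType
buildReducedType(mlir::MLIRContext *ctx, mlir::RankedTensorType outputType,
                 llvm::ArrayRef<int64_t> outputShapeVec) {
  auto newLayout =
      mlir::cast<mlir::tt::ttnn::TTNNLayoutAttr>(outputType.getEncoding())
          .withTensorShape(ctx, outputShapeVec);
  return mlir::RankedTensorType::get(outputShapeVec,
                                     outputType.getElementType(), newLayout);
}
```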

// Construct a new TTNNLayoutAttr
//
// This function constructs a new TTNNLayoutAttr with the given parameters.
50 changes: 30 additions & 20 deletions lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
@@ -480,40 +480,50 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
void runOnOperation() final {
if (decompositionWorkaroundsEnabled) {
RewritePatternSet patterns(&getContext());
patterns.add<
TTNNAllReduceWorkarounds,
workarounds::decomposition::ReduceOpsRewritePattern<ttnn::SumOp>,
workarounds::decomposition::ReduceOpsRewritePattern<ttnn::MaxOp>,
workarounds::decomposition::ReduceOpsRewritePattern<ttnn::MeanOp>>(
&getContext());

runRewritePatterns(std::move(patterns));
patterns.add<TTNNAllReduceWorkarounds,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::SumOp>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::MaxOp>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::MeanOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::SumOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MaxOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MeanOp>>(&getContext());

runRewritePatterns(std::move(patterns),
GreedyRewriteConfig::kNoLimit /*maxIterations*/);
}
if (layouotWorkaroundsEnabled) {
RewritePatternSet patterns(&getContext());
patterns.add<TTNNOperandsWorkaroundsRewriter>(&getContext());

runRewritePatterns(std::move(patterns));
// All layout workarounds should be applied during the first iteration. If
// the workarounds are not applied in the first iteration, it indicates a
// bug in the workarounds implementation. Although the workarounds are
// applied in the first iteration, the rewriter must iterate through the
// IR once more to confirm that the fixpoint is reached. If the fixpoint
// is not reached in the second iteration, it indicates a bug in the
// workarounds implementation.
const int64_t maxIterations = 2;
runRewritePatterns(std::move(patterns), maxIterations);
}
}

private:
void runRewritePatterns(RewritePatternSet &&patterns) {
// Runs rewrite patterns with the specified maximum number of iterations
// that the rewriter will perform on the IR. The rewriter iterates through
// the IR until a fixpoint is reached.
void runRewritePatterns(RewritePatternSet &&patterns, int64_t maxIterations) {
FrozenRewritePatternSet patternSet(std::move(patterns));
GreedyRewriteConfig config = GreedyRewriteConfig();
config.maxIterations = maxIterations;
// This configuration specifies that the rewriter should traverse the IR
// in a top-down order.
config.useTopDownTraversal = true;
// This configuration specifies the maximum number of iterations the
// rewriter will perform on the IR. The rewriter will iterate through the
// IR until a fixpoint is reached. All workarounds should be applied
// during the first iteration. If the workarounds are not applied in the
// first iteration, it indicates a bug in the workarounds implementation.
// Although the workarounds are applied in the first iteration, the
// rewriter must iterate through the IR once more to confirm that the
// fixpoint is reached. If the fixpoint is not reached in the second
// iteration, it indicates a bug in the workarounds implementation.
config.maxIterations = 2;
if (failed(
applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) {
signalPassFailure();
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/reduction/max_op_negative.mlir
@@ -0,0 +1,10 @@
// RUN: not ttmlir-opt --split-input-file --ttir-to-ttnn-backend-pipeline="system-desc-path=ttrt-artifacts/system_desc.ttsys" %s 2>&1 | FileCheck %s
// Negative tests for Max op.
module {
func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x1x1x1xbf16> {
%0 = tensor.empty() : tensor<128x1x1x1xbf16>
// CHECK: error: 'ttnn.max' op Reduce on more than two dimensions is not currently supported by TTNN
%1 = "ttir.max"(%arg0, %0) <{dim_arg = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16>
return %1 : tensor<128x1x1x1xbf16>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/reduction/mean_op_negative.mlir
@@ -0,0 +1,10 @@
// RUN: not ttmlir-opt --split-input-file --ttir-to-ttnn-backend-pipeline="system-desc-path=ttrt-artifacts/system_desc.ttsys" %s 2>&1 | FileCheck %s
// Negative tests for Mean op.
module {
func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x1x1x1xbf16> {
%0 = tensor.empty() : tensor<128x1x1x1xbf16>
// CHECK: error: 'ttnn.mean' op Reduce on more than two dimensions is not currently supported by TTNN
%1 = "ttir.mean"(%arg0, %0) <{dim_arg = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16>
return %1 : tensor<128x1x1x1xbf16>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/reduction/sum_op_negative.mlir
@@ -0,0 +1,10 @@
// RUN: not ttmlir-opt --split-input-file --ttir-to-ttnn-backend-pipeline="system-desc-path=ttrt-artifacts/system_desc.ttsys" %s 2>&1 | FileCheck %s
// Negative tests for Sum op.
module {
func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x1x1x1xbf16> {
%0 = tensor.empty() : tensor<128x1x1x1xbf16>
// CHECK: error: 'ttnn.sum' op Reduce on more than two dimensions is not currently supported by TTNN
%1 = "ttir.sum"(%arg0, %0) <{dim_arg = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16>
return %1 : tensor<128x1x1x1xbf16>
}
}
