diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h b/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h
index 2c307044a8..e7a13c6575 100644
--- a/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h
+++ b/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h
@@ -13,8 +13,6 @@
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/SmallVector.h"
 
-#include <vector>
-
 namespace mlir::tt::ttnn::workarounds::decomposition {
 
 // Extracts reduce dimensions' values from the dimArg attribute. In case when
@@ -24,8 +22,8 @@ getReduceDims(const std::optional<mlir::ArrayAttr> &dimArg);
 
 // Calculates the shape of the new Reduce op created in the workaround, based
 // on the input shape and reducing dimensions.
-std::vector<int64_t>
-calculateNewReduceShape(const RankedTensorType &inputType,
+llvm::SmallVector<int64_t>
+calculateNewReduceShape(RankedTensorType inputType,
                         const std::optional<mlir::ArrayAttr> &dimArg);
 
 // Creates the dimArg attribute of the new Reduce op created in the workaround.
@@ -34,7 +32,7 @@ calculateNewReduceShape(const RankedTensorType &inputType,
 // rank when reduce dimensions are not specified, but it doesn't support reduce
 // for tensors with rank larger than 2 when reduce dimensions are specified.
 mlir::ArrayAttr
-createNewReduceDimArg(const RankedTensorType &inputType,
+createNewReduceDimArg(RankedTensorType inputType,
                       const std::optional<mlir::ArrayAttr> &dimArg);
 
 // This workaround addresses next two Metal issues:
@@ -77,10 +75,10 @@ class ReduceOpsRewritePattern : public OpRewritePattern<ReduceOp> {
   }
 
 private:
-  ReduceOp createReduceOpWithKeepDim(ReduceOp &srcOp, PatternRewriter &rewriter,
-                                     const RankedTensorType &inputType,
-                                     const RankedTensorType &outputType) const {
-    std::vector<int64_t> outputShapeVec =
+  ReduceOp createReduceOpWithKeepDim(ReduceOp srcOp, PatternRewriter &rewriter,
+                                     RankedTensorType inputType,
+                                     RankedTensorType outputType) const {
+    llvm::SmallVector<int64_t> outputShapeVec =
         calculateNewReduceShape(inputType, srcOp.getDimArg());
 
     RankedTensorType newOutputType = RankedTensorType::get(
@@ -91,9 +89,9 @@ class ReduceOpsRewritePattern : public OpRewritePattern<ReduceOp> {
         createNewReduceDimArg(inputType, srcOp.getDimArg()));
   }
 
-  void replaceOpWithReshapeOp(ReduceOp &srcOp, ReduceOp &newReduceOp,
+  void replaceOpWithReshapeOp(ReduceOp srcOp, ReduceOp newReduceOp,
                               PatternRewriter &rewriter,
-                              RankedTensorType &outputType) const {
+                              RankedTensorType outputType) const {
     mlir::ArrayAttr shapeAttr = rewriter.getI32ArrayAttr(
         llvm::SmallVector<int32_t>(outputType.getShape()));
 
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index b06d58c360..160bdda690 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -340,10 +340,8 @@ class ReductionOpConversionPattern : public OpConversionPattern<TTIROpTy> {
 
     // TODO(mrakita): Only last two dimensions can be reduced, check for that
    // too. Should this check be in verifier?
-    if (adaptor.getDimArg().has_value() &&
-        adaptor.getDimArg().value().size() > 2 &&
-        static_cast<int64_t>(adaptor.getDimArg().value().size()) !=
-            inputTensorRank) {
+    if (adaptor.getDimArg() && adaptor.getDimArg()->size() > 2 &&
+        static_cast<int64_t>(adaptor.getDimArg()->size()) != inputTensorRank) {
      return rewriter.notifyMatchFailure(
          op, "Reduce on more than two dimensions "
              "is not currently supported by TTNN");
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 532960a5bd..9a0045ef0d 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -24,7 +24,6 @@
 
 #include
 #include
-#include
 
 #define GET_OP_CLASSES
 #include "ttmlir/Dialect/TTIR/IR/TTIROps.cpp.inc"
@@ -1658,15 +1657,13 @@ static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block,
 
 // Common verifier for all Reduce ops.
 static mlir::LogicalResult
-verifyReduceOp(mlir::Operation *reduceOp,
+verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
               const std::optional<mlir::ArrayAttr> &reduceDims) {
   if (!reduceDims) {
     return mlir::success();
   }
 
-  int64_t inputTensorRank =
-      mlir::cast<mlir::RankedTensorType>(*reduceOp->getOperandTypes().begin())
-          .getRank();
+  int64_t inputTensorRank = inputType.getRank();
 
   llvm::SmallSet<int64_t, 4> uniqueReduceDims;
   for (mlir::Attribute reduceDim : *reduceDims) {
@@ -1702,7 +1699,7 @@ void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
 
 // MaxOp verification
 ::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() {
-  return verifyReduceOp(getOperation(), getDimArg());
+  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1718,7 +1715,7 @@ void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
 
 // MeanOp verification
 ::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() {
-  return verifyReduceOp(getOperation(), getDimArg());
+  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1734,5 +1731,5 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
 
 // SumOp verification
 ::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
-  return verifyReduceOp(getOperation(), getDimArg());
+  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
 }
diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.cpp
index f02321b57e..740de3123f 100644
--- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.cpp
+++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.cpp
@@ -4,29 +4,30 @@
 
 #include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h"
 
+#include "llvm/ADT/SmallSet.h"
+
 #include <optional>
-#include <unordered_set>
 
 namespace mlir::tt::ttnn::workarounds::decomposition {
 
 llvm::SmallVector<int64_t>
 getReduceDims(const std::optional<mlir::ArrayAttr> &dimArg) {
   llvm::SmallVector<int64_t> reduceDims;
-  if (!dimArg.has_value()) {
+  if (!dimArg) {
     return reduceDims;
   }
 
-  for (const mlir::Attribute &reduceDim : dimArg.value()) {
+  for (const mlir::Attribute &reduceDim : *dimArg) {
     reduceDims.push_back(mlir::cast<mlir::IntegerAttr>(reduceDim).getInt());
   }
 
   return reduceDims;
 }
 
-std::vector<int64_t>
-calculateNewReduceShape(const RankedTensorType &inputType,
+llvm::SmallVector<int64_t>
+calculateNewReduceShape(RankedTensorType inputType,
                         const std::optional<mlir::ArrayAttr> &dimArg) {
-  std::vector<int64_t> outputShapeVec = inputType.getShape().vec();
+  llvm::SmallVector<int64_t> outputShapeVec(inputType.getShape());
   llvm::SmallVector<int64_t> reduceDims = getReduceDims(dimArg);
 
   if (reduceDims.empty()) {
@@ -49,15 +50,15 @@ calculateNewReduceShape(const RankedTensorType &inputType,
 }
 
 mlir::ArrayAttr
-createNewReduceDimArg(const RankedTensorType &inputType,
+createNewReduceDimArg(RankedTensorType inputType,
                       const std::optional<mlir::ArrayAttr> &dimArg) {
   llvm::SmallVector<int64_t> reduceDims = getReduceDims(dimArg);
   if (reduceDims.empty()) {
     return nullptr;
   }
 
-  std::unordered_set<int64_t> uniqueReduceDims(reduceDims.begin(),
-                                               reduceDims.end());
+  llvm::SmallSet<int64_t, 4> uniqueReduceDims(reduceDims.begin(),
+                                              reduceDims.end());
   if (uniqueReduceDims.size() == inputType.getShape().size()) {
     // In case when reduce is done over all dimensions of the input nullptr is
     // returned, because Metal supports reduce over all dimensions for any
@@ -67,7 +68,7 @@ createNewReduceDimArg(const RankedTensorType &inputType,
     return nullptr;
   }
 
-  return dimArg.value();
+  return *dimArg;
 }
 
 } // namespace mlir::tt::ttnn::workarounds::decomposition
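
Note: the workaround touched above decomposes a reduce into a rank-preserving ("keep dim") reduce followed by a reshape to the caller's expected output shape. The snippet below is a minimal standalone C++ sketch of that shape rule, assuming calculateNewReduceShape keeps the input rank and collapses each reduced dimension (or, when no dimArg is given, every dimension) to 1. The helper name newReduceShape, the use of std::vector, and the no-dimArg behaviour are illustrative assumptions and are not part of this patch.

    // Standalone illustration only; not the ttmlir implementation.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    static std::vector<int64_t>
    newReduceShape(const std::vector<int64_t> &inputShape,
                   const std::vector<int64_t> &reduceDims) {
      std::vector<int64_t> out(inputShape);
      const int64_t rank = static_cast<int64_t>(inputShape.size());
      if (reduceDims.empty()) {
        // No dimArg: assume reduce over all dimensions.
        for (int64_t &dim : out) {
          dim = 1;
        }
        return out;
      }
      for (int64_t dim : reduceDims) {
        // Normalize negative dimension indices (e.g. -1 means the last dim).
        out[(dim + rank) % rank] = 1;
      }
      return out;
    }

    int main() {
      // Reducing dim 1 of a 4x8x16 tensor: the keep-dim reduce produces 4x1x16,
      // and the trailing ReshapeOp in the workaround then restores the expected
      // 4x16 output shape.
      for (int64_t d : newReduceShape({4, 8, 16}, {1})) {
        std::printf("%lld ", static_cast<long long>(d));
      }
      std::printf("\n"); // prints: 4 1 16
      return 0;
    }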