Commit

Fix PR comments
mrakitaTT committed Dec 19, 2024
1 parent 2dc61ef commit d46471f
Showing 4 changed files with 27 additions and 33 deletions.
@@ -13,8 +13,6 @@
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"

-#include <vector>
-
namespace mlir::tt::ttnn::workarounds::decomposition {

// Extracts reduce dimensions' values from the dimArg attribute. In case when
@@ -24,8 +22,8 @@ getReduceDims(const std::optional<mlir::ArrayAttr> &dimArg);

// Calculates the shape of the new Reduce op created in the workaround, based
// on the input shape and reducing dimensions.
-std::vector<int64_t>
-calculateNewReduceShape(const RankedTensorType &inputType,
+llvm::SmallVector<int64_t>
+calculateNewReduceShape(RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &dimArg);

// Creates the dimArg attribute of the new Reduce op created in the workaround.
Expand All @@ -34,7 +32,7 @@ calculateNewReduceShape(const RankedTensorType &inputType,
// rank when reduce dimensions are not specified, but it doesn't support reduce
// for tensors with rank larger than 2 when reduce dimensions are specified.
mlir::ArrayAttr
-createNewReduceDimArg(const RankedTensorType &inputType,
+createNewReduceDimArg(RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &dimArg);

// This workaround addresses next two Metal issues:
@@ -77,10 +75,10 @@ class ReduceOpsRewritePattern : public OpRewritePattern<ReduceOp> {
}

private:
-ReduceOp createReduceOpWithKeepDim(ReduceOp &srcOp, PatternRewriter &rewriter,
-const RankedTensorType &inputType,
-const RankedTensorType &outputType) const {
-std::vector<int64_t> outputShapeVec =
+ReduceOp createReduceOpWithKeepDim(ReduceOp srcOp, PatternRewriter &rewriter,
+RankedTensorType inputType,
+RankedTensorType outputType) const {
+llvm::SmallVector<int64_t> outputShapeVec =
calculateNewReduceShape(inputType, srcOp.getDimArg());

RankedTensorType newOutputType = RankedTensorType::get(
@@ -91,9 +89,9 @@ class ReduceOpsRewritePattern : public OpRewritePattern<ReduceOp> {
createNewReduceDimArg(inputType, srcOp.getDimArg()));
}

-void replaceOpWithReshapeOp(ReduceOp &srcOp, ReduceOp &newReduceOp,
+void replaceOpWithReshapeOp(ReduceOp srcOp, ReduceOp newReduceOp,
PatternRewriter &rewriter,
-RankedTensorType &outputType) const {
+RankedTensorType outputType) const {
mlir::ArrayAttr shapeAttr = rewriter.getI32ArrayAttr(
llvm::SmallVector<int32_t>(outputType.getShape()));

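To make the intent of the header change above concrete: the workaround builds a Reduce op that keeps the reduced dimensions (collapsed to size 1) and then reshapes the result to the original output type. A minimal standalone sketch of that shape calculation, assuming keep-dim semantics and non-negative, in-range reduce dimensions; the helper name keepDimShape is hypothetical and not part of this commit:

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Collapse every reduced dimension to 1 while keeping the tensor rank, so the
// result can later be reshaped to the rank-reduced output type.
llvm::SmallVector<int64_t> keepDimShape(llvm::ArrayRef<int64_t> inputShape,
                                        llvm::ArrayRef<int64_t> reduceDims) {
  llvm::SmallVector<int64_t> result(inputShape);
  for (int64_t dim : reduceDims)
    result[dim] = 1; // assumes non-negative, in-range reduce dimensions
  return result;
}

For example, keepDimShape({32, 64, 128}, {1}) yields {32, 1, 128}.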
6 changes: 2 additions & 4 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -340,10 +340,8 @@ class ReductionOpConversionPattern : public OpConversionPattern<TTIROpTy> {

// TODO(mrakita): Only last two dimensions can be reduced, check for that
// too. Should this check be in verifier?
-if (adaptor.getDimArg().has_value() &&
-adaptor.getDimArg().value().size() > 2 &&
-static_cast<int64_t>(adaptor.getDimArg().value().size()) !=
-inputTensorRank) {
+if (adaptor.getDimArg() && adaptor.getDimArg()->size() > 2 &&
+static_cast<int64_t>(adaptor.getDimArg()->size()) != inputTensorRank) {
return rewriter.notifyMatchFailure(op,
"Reduce on more than two dimensions "
"is not currently supported by TTNN");
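The TTIRToTTNN change above only tightens the std::optional access: operator bool and operator-> replace has_value() and value(). A minimal sketch of the equivalent predicate with standard containers; tooManyReduceDims is an illustrative name, not from this commit:

#include <cstdint>
#include <optional>
#include <vector>

// Same shape as the rewritten condition: reject a dimArg that names more than
// two dimensions unless it covers the full tensor rank.
bool tooManyReduceDims(const std::optional<std::vector<int64_t>> &dimArg,
                       int64_t inputTensorRank) {
  return dimArg && dimArg->size() > 2 &&
         static_cast<int64_t>(dimArg->size()) != inputTensorRank;
}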
13 changes: 5 additions & 8 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -24,7 +24,6 @@

#include <cstdint>
#include <string>
-#include <unordered_set>

#define GET_OP_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROps.cpp.inc"
@@ -1658,15 +1657,13 @@ static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block,

// Common verifier for all Reduce ops.
static mlir::LogicalResult
-verifyReduceOp(mlir::Operation *reduceOp,
+verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &reduceDims) {
if (!reduceDims) {
return mlir::success();
}

-int64_t inputTensorRank =
-mlir::cast<mlir::RankedTensorType>(*reduceOp->getOperandTypes().begin())
-.getRank();
+int64_t inputTensorRank = inputType.getRank();

llvm::SmallSet<int64_t, 4> uniqueReduceDims;
for (mlir::Attribute reduceDim : *reduceDims) {
@@ -1702,7 +1699,7 @@ void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

// MaxOp verification
::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() {
-return verifyReduceOp(getOperation(), getDimArg());
+return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
@@ -1718,7 +1715,7 @@ void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

// MeanOp verification
::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() {
-return verifyReduceOp(getOperation(), getDimArg());
+return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
@@ -1734,5 +1731,5 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

// SumOp verification
::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
-return verifyReduceOp(getOperation(), getDimArg());
+return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
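For context on verifyReduceOp: the visible part collects the reduce dimensions into an llvm::SmallSet keyed by the input rank, and the rest of the body is collapsed in this view. A plausible standalone sketch of such a check, assuming it rejects out-of-range and duplicate dimensions; validReduceDims and the std containers are illustrative stand-ins, not the actual implementation:

#include <cstdint>
#include <set>
#include <vector>

// Reject dimensions outside [-rank, rank) and duplicates after normalization.
bool validReduceDims(const std::vector<int64_t> &reduceDims, int64_t rank) {
  std::set<int64_t> seen;
  for (int64_t dim : reduceDims) {
    int64_t normalized = dim < 0 ? dim + rank : dim;
    if (normalized < 0 || normalized >= rank)
      return false; // dimension out of range
    if (!seen.insert(normalized).second)
      return false; // duplicate reduce dimension
  }
  return true;
}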
@@ -4,29 +4,30 @@

#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h"

#include "llvm/ADT/SmallSet.h"

#include <algorithm>
-#include <unordered_set>

namespace mlir::tt::ttnn::workarounds::decomposition {

llvm::SmallVector<int64_t>
getReduceDims(const std::optional<mlir::ArrayAttr> &dimArg) {
llvm::SmallVector<int64_t> reduceDims;
-if (!dimArg.has_value()) {
+if (!dimArg) {
return reduceDims;
}

-for (const mlir::Attribute &reduceDim : dimArg.value()) {
+for (const mlir::Attribute &reduceDim : *dimArg) {
reduceDims.push_back(mlir::cast<mlir::IntegerAttr>(reduceDim).getInt());
}

return reduceDims;
}

-std::vector<int64_t>
-calculateNewReduceShape(const RankedTensorType &inputType,
+llvm::SmallVector<int64_t>
+calculateNewReduceShape(RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &dimArg) {
-std::vector<int64_t> outputShapeVec = inputType.getShape().vec();
+llvm::SmallVector<int64_t> outputShapeVec(inputType.getShape());
llvm::SmallVector<int64_t> reduceDims = getReduceDims(dimArg);

if (reduceDims.empty()) {
@@ -49,15 +50,15 @@ calculateNewReduceShape(const RankedTensorType &inputType,
}

mlir::ArrayAttr
-createNewReduceDimArg(const RankedTensorType &inputType,
+createNewReduceDimArg(RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &dimArg) {
llvm::SmallVector<int64_t> reduceDims = getReduceDims(dimArg);
if (reduceDims.empty()) {
return nullptr;
}

-std::unordered_set<int64_t> uniqueReduceDims(reduceDims.begin(),
-reduceDims.end());
+llvm::SmallSet<int64_t, 4> uniqueReduceDims(reduceDims.begin(),
+reduceDims.end());
if (uniqueReduceDims.size() == inputType.getShape().size()) {
// In case when reduce is done over all dimensions of the input nullptr is
// returned, because Metal supports reduce over all dimensions for any
@@ -67,7 +68,7 @@ createNewReduceDimArg(const RankedTensorType &inputType,
return nullptr;
}

-return dimArg.value();
+return *dimArg;
}

} // namespace mlir::tt::ttnn::workarounds::decomposition
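The control flow of createNewReduceDimArg above can be summarized with plain containers: when every dimension is reduced, no dimArg is forwarded, because Metal handles the reduce-over-all-dimensions case for any rank; otherwise the original attribute is passed through. An illustrative sketch only; newReduceDims and the std types stand in for the MLIR attribute plumbing:

#include <cstdint>
#include <optional>
#include <set>
#include <vector>

// Mirror of the logic above with standard containers, for illustration only.
std::optional<std::vector<int64_t>>
newReduceDims(const std::vector<int64_t> &reduceDims, size_t inputRank) {
  if (reduceDims.empty())
    return std::nullopt;

  std::set<int64_t> uniqueReduceDims(reduceDims.begin(), reduceDims.end());
  if (uniqueReduceDims.size() == inputRank)
    return std::nullopt; // full reduce: drop dimArg, Metal supports any rank

  return reduceDims; // partial reduce: keep the original dimensions
}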
