Skip to content

Commit

Permalink
Added TTIR Reverse op and stablehlo -> ttir conversion (#1346)
Browse files Browse the repository at this point in the history
Fixes #1330. 
A part of the solution for for
#1142 (need to add tt-torch
and tt-xla tests in separate PRs) .

Not implemented end to end since this OP should not exists e2e but
rather fitted inside transposed conv op.
  • Loading branch information
kmitrovicTT authored Dec 19, 2024
1 parent 305ea47 commit 8f326f4
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 0 deletions.
34 changes: 34 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,40 @@ def TTIR_OnesOp : TTIR_Op<"ones"> {
let results = (outs AnyRankedTensor:$result);
}

def TTIR_ReverseOp : TTIR_DPSOp<"reverse", [AllShapesMatch<["input", "result"]>]> {
let summary = "Reverse operation.";

let description = [{
Reverses the order of elements in the `operand` along the specified
`dimensions` and produces a `result` tensor.

Examples:
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "ttir.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "ttir.reverse"(%operand) {
dimensions = array<i64: 1, 0>
} : (tensor<3x2xi64>) -> tensor<3x2xi64>
// %result: [[6, 5], [4, 3], [2, 1]]
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
DenseI64ArrayAttr:$dimensions);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
AllShapesMatch<["value", "result"]>]> {
let summary = "Constant op.";
Expand Down
34 changes: 34 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1709,6 +1709,33 @@ class StableHLOToTTIRReturnOpConversionPattern
}
};

class StableHLOToTTIROpReverseOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::ReverseOp> {

using OpConversionPattern<mlir::stablehlo::ReverseOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::ReverseOp srcOp,
mlir::stablehlo::ReverseOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));

tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

rewriter.replaceOpWithNewOp<mlir::tt::ttir::ReverseOp>(
srcOp,
outputType, // result type
adaptor.getOperand(), // input
outputTensor, // output
adaptor.getDimensionsAttr() // dimensions
);
return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -1910,6 +1937,12 @@ void addReturnOpConversionPatterns(MLIRContext *ctx,
patterns.add<StableHLOToTTIRReturnOpConversionPattern>(typeConverter, ctx);
}

void addReverseOpConversionPattern(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIROpReverseOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
Expand Down Expand Up @@ -1938,6 +1971,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addIotaOpConversionPattern(ctx, patterns, typeConverter);
addScatterOpConversionPatterns(ctx, patterns, typeConverter);
addReturnOpConversionPatterns(ctx, patterns, typeConverter);
addReverseOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
34 changes: 34 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,40 @@ ::mlir::LogicalResult mlir::tt::ttir::FillCacheOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ReverseOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult mlir::tt::ttir::ReverseOp::verify() {
llvm::ArrayRef<int64_t> dimensions = getDimensions();

// Check that all given dimensions are unique/not repeating.
llvm::SmallDenseSet<int64_t> uniqueDims(dimensions.begin(), dimensions.end());

if (uniqueDims.size() != dimensions.size()) {
return emitOpError("dimensions should be unique. Got: ") << dimensions;
}

::mlir::RankedTensorType operandTy = getInput().getType();

// Check that each dimension is positive and within valid interval [0,
// operandRank).
for (int64_t dim : dimensions) {
if (dim < 0) {
return emitOpError(
"all dimensions should be non-negative. Got dimension: ")
<< dim;
}

if (dim >= operandTy.getRank()) {
return emitOpError("all dimensions should be in interval [0, ")
<< operandTy.getRank() << "). Got dimension: " << dim;
}
}

return success();
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/reverse_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s

module @jit_eltwise_reverse attributes {} {
func.func @reverse_op(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> {
%0 = "stablehlo.reverse"(%arg0) {dimensions = array<i64: 1, 0>} : (tensor<32x64xf32>) -> tensor<32x64xf32>
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : tensor<32x64xf32>
// CHECK: %[[REV:[0-9]+]] = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64: 1, 0>}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32>
return %0 : tensor<32x64xf32>
// CHECK: return %[[REV]] : tensor<32x64xf32>
}
}
34 changes: 34 additions & 0 deletions test/ttmlir/Dialect/TTIR/reverse/reverse_tests_negative.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
// Negative tests for reverse operation

// Verify that parsing fails if dimensions are not unique.
module attributes {} {
func.func @reverse_non_unique_dims(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> {
// CHECK: error: 'ttir.reverse' op dimensions should be unique. Got: 0, 0
%0 = tensor.empty() : tensor<32x64xf32>
%1 = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64: 0, 0>}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32>
return %1 : tensor<32x64xf32>
}
}

// Verify that parsing fails if any dimension is negative.
// -----
module attributes {} {
func.func @reverse_negative_dim(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> {
// CHECK: error: 'ttir.reverse' op all dimensions should be non-negative. Got dimension: -1
%0 = tensor.empty() : tensor<32x64xf32>
%1 = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64: 0, -1>}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32>
return %1 : tensor<32x64xf32>
}
}

// Verify that parsing fails if any dimension is out of range [0, operandRank).
// -----
module attributes {} {
func.func @reverse_out_of_bounds_dim(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> {
// CHECK: error: 'ttir.reverse' op all dimensions should be in interval [0, 2). Got dimension: 2
%0 = tensor.empty() : tensor<32x64xf32>
%1 = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64: 2>}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32>
return %1 : tensor<32x64xf32>
}
}
24 changes: 24 additions & 0 deletions test/ttmlir/Dialect/TTIR/reverse/reverse_tests_positive.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: ttmlir-opt %s | FileCheck %s

module attributes {} {
func.func @reverse_first_dim(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> {
%0 = tensor.empty() : tensor<32x64xf32>
// CHECK: %[[C:.*]] = "ttir.reverse"[[C:.*]]
%1 = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64: 0>}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32>
return %1 : tensor<32x64xf32>
}

func.func @reverse_second_dim(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> {
%0 = tensor.empty() : tensor<32x64xf32>
// CHECK: %[[C:.*]] = "ttir.reverse"[[C:.*]]
%1 = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64: 1>}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32>
return %1 : tensor<32x64xf32>
}

func.func @reverse_both_dims(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> {
%0 = tensor.empty() : tensor<32x64xf32>
// CHECK: %[[C:.*]] = "ttir.reverse"[[C:.*]]
%1 = "ttir.reverse"(%arg0, %0) <{dimensions = array<i64: 0, 1>}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32>
return %1 : tensor<32x64xf32>
}
}

0 comments on commit 8f326f4

Please sign in to comment.