Skip to content

Commit

Permalink
Add TOSA to TTIR conversions for some simple ops (#1445)
Browse files Browse the repository at this point in the history
* Add Tosa conversion for sin

* Add Tosa conversion for sigmoid

* Add Tosa conversion for reciprocal

* Add Tosa conversion for rsqrt

* Add Tosa conversion for where

* Add Tosa conversion for maximum

* Add Tosa conversion for minimum

* Add neg and sub tests

* Refactor tests

* Improve tests; fix whitespace

* Add return checks to tests
  • Loading branch information
sgligorijevicTT authored Dec 3, 2024
1 parent 0640c7c commit 37e10f3
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 0 deletions.
20 changes: 20 additions & 0 deletions lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<tosa::NegateOp,
mlir::tt::ttir::NegOp>>(
typeConverter, ctx);
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<tosa::SinOp,
mlir::tt::ttir::SinOp>>(
typeConverter, ctx);
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<
tosa::SigmoidOp, mlir::tt::ttir::SigmoidOp>>(typeConverter, ctx);
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<
tosa::ReciprocalOp, mlir::tt::ttir::ReciprocalOp>>(typeConverter, ctx);
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<
tosa::RsqrtOp, mlir::tt::ttir::RsqrtOp>>(typeConverter, ctx);
}

void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
Expand All @@ -102,6 +111,10 @@ void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
patterns.add<TosaToTTIRMultiplyOpConversionPattern>(typeConverter, ctx);
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<
tosa::SubOp, mlir::tt::ttir::SubtractOp>>(typeConverter, ctx);
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<
tosa::MaximumOp, mlir::tt::ttir::MaximumOp>>(typeConverter, ctx);
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<
tosa::MinimumOp, mlir::tt::ttir::MinimumOp>>(typeConverter, ctx);
}

void addCompareOpsConversionPatterns(MLIRContext *ctx,
Expand All @@ -112,6 +125,12 @@ void addCompareOpsConversionPatterns(MLIRContext *ctx,
ctx);
}

void addElementwiseTernaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<
tosa::SelectOp, mlir::tt::ttir::WhereOp>>(typeConverter, ctx);
}
} // namespace

namespace mlir::tt {
Expand All @@ -120,6 +139,7 @@ void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter) {
addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter);
addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter);
addElementwiseTernaryOpsConversionPatterns(ctx, patterns, typeConverter);
addCompareOpsConversionPatterns(ctx, patterns, typeConverter);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_maximum(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.maximum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.maximum"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
// CHECK: return %[[VAL]] : [[TENSOR_SIZE]]
return %0 : tensor<13x21x3xf32>
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_minimum(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.minimum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.minimum"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
// CHECK: return %[[VAL]] : [[TENSOR_SIZE]]
return %0 : tensor<13x21x3xf32>
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_sub(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.sub %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.subtract"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
// CHECK: return %[[VAL]] : [[TENSOR_SIZE]]
return %0 : tensor<13x21x3xf32>
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_select(%arg0: tensor<32x128xi1>, %arg1: tensor<32x128xf32>, %arg2: tensor<32x128xf32>) -> tensor<32x128xf32> {
// CHECK: func.func {{.+}} [[SELECTOR:tensor<[0-9]+x[0-9]+xi1>]]
%0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x128xi1>, tensor<32x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.where"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, %arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[SELECTOR]], [[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
// CHECK: return %[[VAL]] : [[TENSOR_SIZE]]
return %0 : tensor<32x128xf32>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/negate_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.neg"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
// CHECK: return %[[VAL]] : [[TENSOR_SIZE]]
return %0 : tensor<13x21x3xf32>
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_reciprocal(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.reciprocal %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.reciprocal"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
// CHECK: return %[[VAL]] : [[TENSOR_SIZE]]
return %0 : tensor<13x21x3xf32>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/rsqrt_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.rsqrt %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.rsqrt"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
// CHECK: return %[[VAL]] : [[TENSOR_SIZE]]
return %0 : tensor<13x21x3xf32>
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.sigmoid"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
// CHECK: return %[[VAL]] : [[TENSOR_SIZE]]
return %0 : tensor<13x21x3xf32>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/sin_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_sin(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.sin %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.sin"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
// CHECK: return %[[VAL]] : [[TENSOR_SIZE]]
return %0 : tensor<13x21x3xf32>
}
}

0 comments on commit 37e10f3

Please sign in to comment.