Skip to content

Commit

Permalink
Add support for logical xor op. (#1141)
Browse files Browse the repository at this point in the history
* Add end-to-end implementation of the logical xor op.
* Add stablehlo to ttir conversion
  • Loading branch information
mmanzoorTT authored Nov 13, 2024
1 parent 9d08234 commit 9cc313f
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 29 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,13 @@ def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or"> {
}];
}

def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> {
let summary = "Eltwise logical xor.";
let description = [{
Eltwise logical xor operation.
}];
}

def TTIR_MaximumOp : TTIR_ElementwiseBinaryOp<"maximum"> {
let summary = "Eltwise maximum OP.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,13 @@ def TTNN_LogicalOrOp : TTNN_ElementwiseBinaryOp<"logical_or"> {
}];
}

def TTNN_LogicalXorOp : TTNN_ElementwiseBinaryOp<"logical_xor"> {
let summary = "Eltwise logical xor.";
let description = [{
Eltwise logical xor operation.
}];
}

def TTNN_MaximumOp : TTNN_ElementwiseBinaryOp<"maximum"> {
let summary = "Eltwise maximum OP.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ enum EltwiseOpType: uint32 {
Floor = 34,
Where = 35,
Gelu = 36,
LogicalXor = 37,
}

union EltwiseOpParams {
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,9 @@ void addLogicalOpConversionPattern(MLIRContext *ctx,
ctx);
patterns.add<StableHLOToTTIROpLogicalOpConversionPattern<
mlir::stablehlo::OrOp, mlir::tt::ttir::LogicalOrOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpLogicalOpConversionPattern<
mlir::stablehlo::XorOp, mlir::tt::ttir::LogicalXorOp>>(typeConverter,
ctx);
}

void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::LogicalAndOp, ttnn::LogicalAndOp>,
ElementwiseOpConversionPattern<ttir::LogicalOrOp, ttnn::LogicalOrOp>,
ElementwiseOpConversionPattern<ttir::LogicalNotOp, ttnn::LogicalNotOp>,
ElementwiseOpConversionPattern<ttir::LogicalXorOp, ttnn::LogicalXorOp>,
ElementwiseOpConversionPattern<ttir::MultiplyOp, ttnn::MultiplyOp>,
ElementwiseOpConversionPattern<ttir::EqualOp, ttnn::EqualOp>,
ElementwiseOpConversionPattern<ttir::NotEqualOp, ttnn::NotEqualOp>,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
patterns.add<EltwiseBinaryOpConversionPattern<ttnn::AddOp>,
EltwiseBinaryOpConversionPattern<ttnn::LogicalAndOp>,
EltwiseBinaryOpConversionPattern<ttnn::LogicalOrOp>,
EltwiseBinaryOpConversionPattern<ttnn::LogicalXorOp>,
EltwiseBinaryOpConversionPattern<ttnn::SubtractOp>,
EltwiseBinaryOpConversionPattern<ttnn::MultiplyOp>,
DefaultOpConversionPattern<ttnn::EqualOp>,
Expand Down
5 changes: 5 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::LogicalNot;
} else if constexpr (std::is_same_v<EltwiseOp, LogicalOrOp>) {
type = ::tt::target::ttnn::EltwiseOpType::LogicalOr;
} else if constexpr (std::is_same_v<EltwiseOp, LogicalXorOp>) {
type = ::tt::target::ttnn::EltwiseOpType::LogicalXor;
} else if constexpr (std::is_same_v<EltwiseOp, MultiplyOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Multiply;
} else if constexpr (std::is_same_v<EltwiseOp, NegOp>) {
Expand Down Expand Up @@ -590,6 +592,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto orOp = dyn_cast<LogicalOrOp>(op); orOp) {
return createOperation(cache, createEltwiseOp(cache, orOp), debugString);
}
if (auto xorOp = dyn_cast<LogicalXorOp>(op); xorOp) {
return createOperation(cache, createEltwiseOp(cache, xorOp), debugString);
}
if (auto multiplyOp = dyn_cast<MultiplyOp>(op); multiplyOp) {
return createOperation(cache, createEltwiseOp(cache, multiplyOp),
debugString);
Expand Down
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseBinaryOP(op, tensorPool, ::ttnn::logical_or);
break;
}
case ::tt::target::ttnn::EltwiseOpType::LogicalXor: {
runEltwiseBinaryOP(op, tensorPool, ::ttnn::logical_xor);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Multiply: {
runEltwiseBinaryOP(op, tensorPool, ::ttnn::multiply);
break;
Expand Down
49 changes: 20 additions & 29 deletions test/ttmlir/Conversion/StableHLOToTTIR/binary/logical_op.mlir
Original file line number Diff line number Diff line change
@@ -1,39 +1,30 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_eltwise_compare attributes {} {
func.func public @logical_and(%arg0: tensor<13x31xi1>, %arg1: tensor<13x31xi1>) -> tensor<13x31xi1> {
%0 = stablehlo.and %arg0, %arg1 : tensor<13x31xi1>
// CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16>
module @jit_eltwise_logical attributes {} {
func.func public @logical_and(%arg0: tensor<32x32xi1>, %arg1: tensor<32x32xi1>) -> tensor<32x32xi1> {
// CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<32x32xbf16>]]
// CHECK: = "ttir.logical_and"(%arg0, %arg1, %[[E]])
// CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16>
return %0 : tensor<13x31xi1>
// CHECK: return %1 : tensor<13x31xbf16>
// CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%0 = stablehlo.and %arg0, %arg1 : tensor<32x32xi1>
// CHECK: return %1 : [[TENSOR]]
return %0 : tensor<32x32xi1>
}

func.func public @logical_or(%arg0: tensor<13x31xi1>, %arg1: tensor<13x31xi1>) -> tensor<13x31xi1> {
%0 = stablehlo.or %arg0, %arg1 : tensor<13x31xi1>
// CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16>
func.func public @logical_or(%arg0: tensor<32x32xi1>, %arg1: tensor<32x32xi1>) -> tensor<32x32xi1> {
// CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<32x32xbf16>]]
// CHECK: = "ttir.logical_or"(%arg0, %arg1, %[[E]])
// CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16>
return %0 : tensor<13x31xi1>
// CHECK: return %1 : tensor<13x31xbf16>
// CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%0 = stablehlo.or %arg0, %arg1 : tensor<32x32xi1>
// CHECK: return %1 : [[TENSOR]]
return %0 : tensor<32x32xi1>
}

func.func public @logical_not(%arg0: tensor<13x31xi1>) -> tensor<13x31xi1> {
%0 = stablehlo.not %arg0 : tensor<13x31xi1>
// CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16>
// CHECK: = "ttir.logical_not"(%arg0, %[[E]])
// CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16>
return %0 : tensor<13x31xi1>
// CHECK: return %1 : tensor<13x31xbf16>
}

func.func public @logical_not_scalar(%arg0: tensor<i1>) -> tensor<i1> {
%0 = stablehlo.not %arg0 : tensor<i1>
// CHECK: %[[E:.*]] = tensor.empty() : tensor<1xbf16>
// CHECK: = "ttir.logical_not"(%arg0, %[[E]])
// CHECK-SAME: (tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
return %0 : tensor<i1>
// CHECK: return %1 : tensor<1xbf16>
func.func public @logical_xor(%arg0: tensor<32x32xi1>, %arg1: tensor<32x32xi1>) -> tensor<32x32xi1> {
// CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<32x32xbf16>]]
// CHECK: = "ttir.logical_xor"(%arg0, %arg1, %[[E]])
// CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%0 = stablehlo.xor %arg0, %arg1 : tensor<32x32xi1>
// CHECK: return %1 : [[TENSOR]]
return %0 : tensor<32x32xi1>
}
}
21 changes: 21 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/unary/logical_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_eltwise_logical attributes {} {
func.func public @logical_not(%arg0: tensor<32x32xi1>) -> tensor<32x32xi1> {
// CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<32x32xbf16>]]
// CHECK: = "ttir.logical_not"(%arg0, %[[E]])
// CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%0 = stablehlo.not %arg0 : tensor<32x32xi1>
// CHECK: return %1 : [[TENSOR]]
return %0 : tensor<32x32xi1>
}

func.func public @logical_not_scalar(%arg0: tensor<i1>) -> tensor<i1> {
// CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<1xbf16>]]
// CHECK: = "ttir.logical_not"(%arg0, %[[E]])
// CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%0 = stablehlo.not %arg0 : tensor<i1>
// CHECK: return %1 : [[TENSOR]]
return %0 : tensor<i1>
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s

#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
// CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]]
%0 = tensor.empty() : tensor<64x128xbf16>
// CHECK: %{{[0-9]+}} = "ttnn.logical_xor"
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: -> [[TENSOR]]
%1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
return %1 : tensor<64x128xbf16>
}
}
18 changes: 18 additions & 0 deletions test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn

#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>

func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
// CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]]
%0 = tensor.empty() : tensor<64x128xbf16>
// CHECK: %{{[0-9]+}} = "ttnn.logical_xor"
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: -> [[TENSOR]]
%1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
return %1 : tensor<64x128xbf16>
}
12 changes: 12 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_logical.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,16 @@ module attributes {} {
// CHECK-SAME: tensor<64x128xf32,
return %1 : tensor<64x128xf32>
}

func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
// CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]]
%0 = tensor.empty() : tensor<64x128xbf16>
// CHECK: %{{[0-9]+}} = "ttnn.logical_xor"
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: -> [[TENSOR]]
%1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
return %1 : tensor<64x128xbf16>
}
}

0 comments on commit 9cc313f

Please sign in to comment.