From 8ba7d3db2a87a8b49dc9e412180412939aa2e84f Mon Sep 17 00:00:00 2001 From: Lewis Panos Date: Wed, 20 Nov 2024 10:33:22 -0500 Subject: [PATCH] Fix legalization of stablehlo.dot_general to TTIR matmul (#1311) Add batch matmul silicon test --- .../StableHLOToTTIRPatterns.cpp | 75 ++++++++++++++++--- .../dot_general_2d.mlir} | 0 .../dot_general/dot_general_3d.mlir | 10 +++ .../dot_general_op_2d.mlir} | 4 +- .../dot_general_op_batch_matmul.mlir | 21 ++++++ 5 files changed, 96 insertions(+), 14 deletions(-) rename test/ttmlir/Conversion/StableHLOToTTIR/{dot_general_op.mlir => dot_general/dot_general_2d.mlir} (100%) create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir rename test/ttmlir/Silicon/StableHLO/{dot_general_op.mlir => dot_general/dot_general_op_2d.mlir} (82%) create mode 100644 test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 1ec8556cf..9120edc11 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -279,30 +279,81 @@ class StableHLOToTTIRDotGeneralOpConversionPattern ::mlir::stablehlo::DotDimensionNumbersAttr dimensions = adaptor.getDotDimensionNumbers(); - if (dimensions.getLhsContractingDimensions().empty() || - dimensions.getRhsContractingDimensions().empty()) { - return rewriter.notifyMatchFailure(srcOp, - "Contracting dimension is missing."); + if (dimensions.getLhsContractingDimensions().size() != 1 || + dimensions.getRhsContractingDimensions().size() != 1) { + return rewriter.notifyMatchFailure( + srcOp, + "LHS and RHS must have exactly 1 contracting dimension each. 
" + "Received LHS contracting dims: " + + std::to_string(dimensions.getLhsContractingDimensions().size()) + + ", RHS contracting dims: " + + std::to_string(dimensions.getRhsContractingDimensions().size())); + } + + // Use negative indexing to determine if this is a valid matmul since math + // is done over the final two dimensions. + int64_t lhsContractingDim = dimensions.getLhsContractingDimensions()[0] - + srcOp.getLhs().getType().getRank(); + int64_t rhsContractingDim = dimensions.getRhsContractingDimensions()[0] - + srcOp.getRhs().getType().getRank(); + + if (lhsContractingDim != -1) { + return rewriter.notifyMatchFailure( + srcOp, "Only support contracting dimensions that correspond to valid " + "matmuls. LHS contracting dimension must be " + + std::to_string(srcOp.getLhs().getType().getRank() - 1) + + ". Got " + std::to_string(lhsContractingDim)); } - if (dimensions.getLhsContractingDimensions()[0] != 1) { + if (rhsContractingDim != -2) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "Only support contracting dimensions that correspond to valid " + "matmuls. RHS contracting dimension must be " + + std::to_string(srcOp.getRhs().getType().getRank() - 2) + + ". Got " + std::to_string(rhsContractingDim)); } - if (dimensions.getRhsContractingDimensions()[0] != 0) { + if (dimensions.getLhsBatchingDimensions() != + dimensions.getRhsBatchingDimensions()) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "LHS and RHS must have same batching dimensions."); } - if (!dimensions.getLhsBatchingDimensions().empty()) { + // For the RHS, all dimensions which are not the row and column dimensions + // must be 1 OR they must be equal to the corresponding dimension in the + // LHS. If the RHS has fewer dimensions than the LHS we will assume that the + // missing dimensions are 1. 
+ + auto lhsShape = srcOp.getLhs().getType().getShape().vec(); + auto rhsShape = srcOp.getRhs().getType().getShape().vec(); + + if (rhsShape.size() > lhsShape.size()) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "RHS must not be a higher rank than LHS."); + } + + while (rhsShape.size() < lhsShape.size()) { + rhsShape.insert(rhsShape.begin(), 1); + } + + // Need only to check dims to the left of dim -2 on the RHS + bool allOnes = true; + bool mismatchedDims = false; + for (int32_t i = rhsShape.size() - 3; i >= 0; i--) { + if (rhsShape[i] != 1) { + allOnes = false; + } + + if (rhsShape[i] != lhsShape[i]) { + mismatchedDims = true; + } } - if (!dimensions.getRhsBatchingDimensions().empty()) { + if (mismatchedDims && !allOnes) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "All dimensions in the RHS that are not the row and column " + "dimensions must be 1 OR they must all be equal to the " + "corresponding dimensions in the LHS."); } return success(); diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_2d.mlir similarity index 100% rename from test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir rename to test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_2d.mlir diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir new file mode 100644 index 000000000..52e2d8001 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module { + func.func @main(%arg0: tensor<8x1x920xbf16>, %arg1: tensor<8x100x32xbf16>, %arg2: tensor<8x32x920xbf16>) -> tensor<8x100x920xbf16> { + %0 = 
stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2] : (tensor<8x32x920xbf16>) -> tensor<8x32x920xbf16> + // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] + %1 = stablehlo.dot_general %arg1, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<8x100x32xbf16>, tensor<8x32x920xbf16>) -> tensor<8x100x920xbf16> + return %1 : tensor<8x100x920xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/dot_general_op.mlir b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_2d.mlir similarity index 82% rename from test/ttmlir/Silicon/StableHLO/dot_general_op.mlir rename to test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_2d.mlir index 57a0bdcd8..179f112b4 100644 --- a/test/ttmlir/Silicon/StableHLO/dot_general_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_2d.mlir @@ -6,8 +6,8 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn // RUN: FileCheck --input-file=%t.mlir %s -module @jit_dot_general attributes {} { - func.func public @test_dot_general(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> { +module @jit_dot_general_2d attributes {} { + func.func public @test_dot_general_2d(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> { // CHECK-LABEL: func.func public @test_dot_general // CHECK: ttnn.empty // CHECK: ttnn.matmul diff --git a/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir new file mode 100644 index 000000000..f23ece73f --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir @@ -0,0 +1,21 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir 
%s + +module @jit_dot_general_4d attributes {} { + func.func public @test_dot_general_4d(%arg0 : tensor<1x128x16x32xf32>, %arg1 : tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32> { + // CHECK-LABEL: func.func public @test_dot_general + // CHECK: ttnn.empty + // CHECK: ttnn.matmul + // CHECK-SAME: tensor<1x128x16x32xf32, + // CHECK-SAME: tensor<1x128x32x8xf32, + // CHECK-SAME: tensor<1x128x16x8xf32, + // CHECK-SAME: -> tensor<1x128x16x8xf32 + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<1x128x16x32xf32>, tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32> + return %0 : tensor<1x128x16x8xf32> + } +}