Fix legalization of stablehlo.dot_general to TTIR matmul (#1311)
Add batch matmul silicon test
LPanosTT authored Nov 20, 2024
1 parent 8434117 commit 8ba7d3d
Showing 5 changed files with 96 additions and 14 deletions.
75 changes: 63 additions & 12 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -279,30 +279,81 @@ class StableHLOToTTIRDotGeneralOpConversionPattern
     ::mlir::stablehlo::DotDimensionNumbersAttr dimensions =
         adaptor.getDotDimensionNumbers();
 
-    if (dimensions.getLhsContractingDimensions().empty() ||
-        dimensions.getRhsContractingDimensions().empty()) {
-      return rewriter.notifyMatchFailure(srcOp,
-                                         "Contracting dimension is missing.");
+    if (dimensions.getLhsContractingDimensions().size() != 1 ||
+        dimensions.getRhsContractingDimensions().size() != 1) {
+      return rewriter.notifyMatchFailure(
+          srcOp,
+          "LHS and RHS must have exactly 1 contracting dimension each. "
+          "Received LHS contracting dims: " +
+              std::to_string(dimensions.getLhsContractingDimensions().size()) +
+              ", RHS contracting dims: " +
+              std::to_string(dimensions.getRhsContractingDimensions().size()));
     }
 
+    // Use negative indexing to determine if this is a valid matmul since math
+    // is done over the final two dimensions.
+    int64_t lhsContractingDim = dimensions.getLhsContractingDimensions()[0] -
+                                srcOp.getLhs().getType().getRank();
+    int64_t rhsContractingDim = dimensions.getRhsContractingDimensions()[0] -
+                                srcOp.getRhs().getType().getRank();
+
+    if (lhsContractingDim != -1) {
+      return rewriter.notifyMatchFailure(
+          srcOp, "Only support contracting dimensions that correspond to valid "
+                 "matmuls. LHS contracting dimension must be " +
+                     std::to_string(srcOp.getLhs().getType().getRank() - 1) +
+                     ". Got " + std::to_string(lhsContractingDim));
+    }
+
-    if (dimensions.getLhsContractingDimensions()[0] != 1) {
+    if (rhsContractingDim != -2) {
       return rewriter.notifyMatchFailure(
-          srcOp, "Only non-transposed matmul is currently supported in TTIR.");
+          srcOp, "Only support contracting dimensions that correspond to valid "
+                 "matmuls. RHS contracting dimension must be " +
+                     std::to_string(srcOp.getRhs().getType().getRank() - 2) +
+                     ". Got " + std::to_string(rhsContractingDim));
     }
 
-    if (dimensions.getRhsContractingDimensions()[0] != 0) {
+    if (dimensions.getLhsBatchingDimensions() !=
+        dimensions.getRhsBatchingDimensions()) {
       return rewriter.notifyMatchFailure(
-          srcOp, "Only non-transposed matmul is currently supported in TTIR.");
+          srcOp, "LHS and RHS must have same batching dimensions.");
     }
 
-    if (!dimensions.getLhsBatchingDimensions().empty()) {
+    // For the RHS, all dimensions which are not the row and column dimensions
+    // must be 1 OR they must be equal to the corresponding dimension in the
+    // LHS. If the RHS has less dimensions than the LHS we will assume that the
+    // missing dimensions are 1.
+
+    auto lhsShape = srcOp.getLhs().getType().getShape().vec();
+    auto rhsShape = srcOp.getRhs().getType().getShape().vec();
+
+    if (rhsShape.size() > lhsShape.size()) {
       return rewriter.notifyMatchFailure(
-          srcOp, "Only non-transposed matmul is currently supported in TTIR.");
+          srcOp, "RHS must not be a higher rank than LHS.");
     }
 
+    while (rhsShape.size() < lhsShape.size()) {
+      rhsShape.insert(rhsShape.begin(), 1);
+    }
+
+    // Need only to check dims to the left of dim -2 on the RHS
+    bool allOnes = true;
+    bool mismatchedDims = false;
+    for (int32_t i = rhsShape.size() - 3; i >= 0; i--) {
+      if (rhsShape[i] != 1) {
+        allOnes = false;
+      }
+
+      if (rhsShape[i] != lhsShape[i]) {
+        mismatchedDims = true;
+      }
+    }
+
-    if (!dimensions.getRhsBatchingDimensions().empty()) {
+    if (mismatchedDims && !allOnes) {
       return rewriter.notifyMatchFailure(
-          srcOp, "Only non-transposed matmul is currently supported in TTIR.");
+          srcOp, "All dimensions in the RHS that are not the row and column "
+                 "dimensions must be 1 OR they must all be equal to the "
+                 "corresponding dimensions in the LHS.");
     }
 
     return success();
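Taken together, the new checks reduce to a shape-level predicate: the LHS must contract over its last dimension, the RHS over its second-to-last, and every RHS dimension to the left of its final two must either be 1 or equal the corresponding LHS dimension. The sketch below restates that predicate as a standalone C++ function over plain std::vector shapes; it is not the commit's code, and the function name, signature, and the shapes in main() are illustrative only.

// A minimal standalone sketch (not the commit's code) of the legality
// predicate implemented above, using plain std::vector shapes in place of
// MLIR ranked tensor types.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Returns an empty string when a dot_general with these shapes and single
// contracting dimensions can lower to ttir.matmul, else a failure reason.
std::string checkDotGeneralToMatmul(std::vector<int64_t> lhsShape,
                                    std::vector<int64_t> rhsShape,
                                    int64_t lhsContracting,
                                    int64_t rhsContracting) {
  const int64_t lhsRank = static_cast<int64_t>(lhsShape.size());
  const int64_t rhsRank = static_cast<int64_t>(rhsShape.size());

  // Negative indexing: matmul always contracts the last LHS dimension
  // against the second-to-last RHS dimension, whatever the ranks are.
  if (lhsContracting - lhsRank != -1)
    return "LHS contracting dim must be " + std::to_string(lhsRank - 1);
  if (rhsContracting - rhsRank != -2)
    return "RHS contracting dim must be " + std::to_string(rhsRank - 2);

  if (rhsShape.size() > lhsShape.size())
    return "RHS must not be a higher rank than LHS";

  // Left-pad the RHS with 1s so the batch dimensions line up.
  while (rhsShape.size() < lhsShape.size())
    rhsShape.insert(rhsShape.begin(), 1);

  // Every RHS dim left of the row/column dims must be 1, or all of them
  // must match the corresponding LHS dims.
  bool allOnes = true;
  bool mismatchedDims = false;
  for (int32_t i = static_cast<int32_t>(rhsShape.size()) - 3; i >= 0; i--) {
    if (rhsShape[i] != 1)
      allOnes = false;
    if (rhsShape[i] != lhsShape[i])
      mismatchedDims = true;
  }
  if (mismatchedDims && !allOnes)
    return "RHS batch dims must be all 1 or equal to the LHS batch dims";

  return "";
}

int main() {
  // Shapes from the batch-matmul test below: legal, prints "ok".
  std::cout << (checkDotGeneralToMatmul({8, 100, 32}, {8, 32, 920}, 2, 1)
                        .empty()
                    ? "ok"
                    : "fail")
            << "\n";
  // A transposed contraction (LHS contracts dim 0) is still rejected.
  std::cout << checkDotGeneralToMatmul({32, 16}, {32, 8}, 0, 0) << "\n";
}

The negative-index formulation is what makes the check rank-agnostic, which is why the same logic covers the batched 3D and 4D tests below.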
@@ -0,0 +1,10 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+module {
+  func.func @main(%arg0: tensor<8x1x920xbf16>, %arg1: tensor<8x100x32xbf16>, %arg2: tensor<8x32x920xbf16>) -> tensor<8x100x920xbf16> {
+    %0 = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2] : (tensor<8x32x920xbf16>) -> tensor<8x32x920xbf16>
+    // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]]
+    %1 = stablehlo.dot_general %arg1, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<8x100x32xbf16>, tensor<8x32x920xbf16>) -> tensor<8x100x920xbf16>
+    return %1 : tensor<8x100x920xbf16>
+  }
+}
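As an informal trace of the new checks against this test (not text from the commit): the LHS tensor<8x100x32xbf16> has rank 3, so contracting dimension 2 maps to 2 - 3 = -1; the RHS tensor<8x32x920xbf16> also has rank 3, so contracting dimension 1 maps to 1 - 3 = -2; the batching dimensions [0] x [0] are equal; and the RHS batch dimension (8) matches the LHS's. The op therefore legalizes to ttir.matmul, where the old checks rejected any dot_general with batching dimensions.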
@@ -6,8 +6,8 @@
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 // RUN: FileCheck --input-file=%t.mlir %s
 
-module @jit_dot_general attributes {} {
-  func.func public @test_dot_general(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> {
+module @jit_dot_general_2d attributes {} {
+  func.func public @test_dot_general_2d(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> {
     // CHECK-LABEL: func.func public @test_dot_general
     // CHECK: ttnn.empty
     // CHECK: ttnn.matmul
@@ -0,0 +1,21 @@
+// REQUIRES: stablehlo
+// RUN: rm -rf %t.ttnn
+// RUN: rm -rf %t.mlir
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+// RUN: FileCheck --input-file=%t.mlir %s
+
+module @jit_dot_general_4d attributes {} {
+  func.func public @test_dot_general_4d(%arg0 : tensor<1x128x16x32xf32>, %arg1 : tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32> {
+    // CHECK-LABEL: func.func public @test_dot_general
+    // CHECK: ttnn.empty
+    // CHECK: ttnn.matmul
+    // CHECK-SAME: tensor<1x128x16x32xf32,
+    // CHECK-SAME: tensor<1x128x32x8xf32,
+    // CHECK-SAME: tensor<1x128x16x8xf32,
+    // CHECK-SAME: -> tensor<1x128x16x8xf32
+    %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<1x128x16x32xf32>, tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32>
+    return %0 : tensor<1x128x16x8xf32>
+  }
+}
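This 4D variant exercises the same predicate with two batching dimensions (again an informal check, not commit text): contracting dims 3 x 2 map to -1 and -2 under rank 4, batching dims [0, 1] x [0, 1] are equal, and the RHS batch dims (1x128) match the LHS's exactly, so the broadcast check passes trivially.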
