Skip to content

Commit

Permalink
Refactored code and fixed other issues
Browse files Browse the repository at this point in the history
  • Loading branch information
umalesTT committed Jan 3, 2025
1 parent 2aa4a76 commit e651d14
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 343 deletions.
8 changes: 4 additions & 4 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ def TTIR_DotGeneralOp : TTIR_Op<"dot_general"> {

let arguments = (ins AnyRankedTensor:$a,
AnyRankedTensor:$b,
DenseI64ArrayAttr:$batchdims_a,
DenseI64ArrayAttr:$contractdims_a,
DenseI64ArrayAttr:$batchdims_b,
DenseI64ArrayAttr:$contractdims_b);
DenseI64ArrayAttr:$batch_dims_a,
DenseI64ArrayAttr:$contract_dims_a,
DenseI64ArrayAttr:$batch_dims_b,
DenseI64ArrayAttr:$contract_dims_b);

let results = (outs AnyRankedTensor:$result);
let hasVerifier = 1;
Expand Down
123 changes: 1 addition & 122 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ class StableHLOToTTIRDotGeneralOpConversionPattern
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

rewriter.replaceOpWithNewOp<mlir::tt::ttir::DotGeneralOp>(
srcOp, getTypeConverter()->convertType(outputTensor.getType()),
adaptor.getLhs(), adaptor.getRhs(),
srcOp, outputTensor.getType(), adaptor.getLhs(), adaptor.getRhs(),
adaptor.getDotDimensionNumbers().getLhsBatchingDimensions(),
adaptor.getDotDimensionNumbers().getLhsContractingDimensions(),
adaptor.getDotDimensionNumbers().getRhsBatchingDimensions(),
Expand Down Expand Up @@ -201,126 +200,6 @@ class StableHLOToTTIRReshapeOpConversionPattern
}
};

/*class StableHLOToTTIRDotGeneralOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::DotGeneralOp> {
using OpConversionPattern<mlir::stablehlo::DotGeneralOp>::OpConversionPattern;
public:
LogicalResult
matchAndRewrite(mlir::stablehlo::DotGeneralOp srcOp,
mlir::stablehlo::DotGeneralOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// This is a basic version that can only work for cases that can be directly
// converted to matmul. The op should be extended as other ops such as
// ttir.permute and ttir.broadcast_in_dim become available.
LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
if (!legalityResult.succeeded()) {
return legalityResult;
}
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
rewriter.replaceOpWithNewOp<mlir::tt::ttir::MatmulOp>(
srcOp, getTypeConverter()->convertType(outputTensor.getType()),
adaptor.getLhs(), adaptor.getRhs(), Value(outputTensor));
return success();
}
private:
LogicalResult
checkBasicLegality(mlir::stablehlo::DotGeneralOp &srcOp,
mlir::stablehlo::DotGeneralOp::Adaptor &adaptor,
ConversionPatternRewriter &rewriter) const {
::mlir::stablehlo::DotDimensionNumbersAttr dimensions =
adaptor.getDotDimensionNumbers();
if (dimensions.getLhsContractingDimensions().size() != 1 ||
dimensions.getRhsContractingDimensions().size() != 1) {
return rewriter.notifyMatchFailure(
srcOp,
"LHS and RHS must have exactly 1 contracting dimension each. "
"Received LHS contracting dims: " +
std::to_string(dimensions.getLhsContractingDimensions().size()) +
", RHS contracting dims: " +
std::to_string(dimensions.getRhsContractingDimensions().size()));
}
// Use negative indexing to determine if this is a valid matmul since math
// is done over the final two dimensions.
int64_t lhsContractingDim = dimensions.getLhsContractingDimensions()[0] -
srcOp.getLhs().getType().getRank();
int64_t rhsContractingDim = dimensions.getRhsContractingDimensions()[0] -
srcOp.getRhs().getType().getRank();
if (lhsContractingDim != -1) {
return rewriter.notifyMatchFailure(
srcOp, "Only support contracting dimensions that correspond to valid "
"matmuls. LHS contracting dimension must be " +
std::to_string(srcOp.getLhs().getType().getRank() - 1) +
". Got " + std::to_string(lhsContractingDim));
}
if (rhsContractingDim != -2) {
return rewriter.notifyMatchFailure(
srcOp, "Only support contracting dimensions that correspond to valid "
"matmuls. RHS contracting dimension must be " +
std::to_string(srcOp.getRhs().getType().getRank() - 2) +
". Got " + std::to_string(rhsContractingDim));
}
if (dimensions.getLhsBatchingDimensions() !=
dimensions.getRhsBatchingDimensions()) {
return rewriter.notifyMatchFailure(
srcOp, "LHS and RHS must have same batching dimensions.");
}
// For the RHS, all dimensions which are not the row and column dimensions
// must be 1 OR they must be equal to the corresponding dimension in the
// LHS. If the RHS has less dimensions than the LHS we will assume that the
// missing dimensions are 1.
auto lhsShape = srcOp.getLhs().getType().getShape().vec();
auto rhsShape = srcOp.getRhs().getType().getShape().vec();
if (rhsShape.size() > lhsShape.size()) {
return rewriter.notifyMatchFailure(
srcOp, "RHS must not be a higher rank than LHS.");
}
while (rhsShape.size() < lhsShape.size()) {
rhsShape.insert(rhsShape.begin(), 1);
}
// Need only to check dims to the left of dim -2 on the RHS
bool allOnes = true;
bool mismatchedDims = false;
for (int32_t i = rhsShape.size() - 3; i >= 0; i--) {
if (rhsShape[i] != 1) {
allOnes = false;
}
if (rhsShape[i] != lhsShape[i]) {
mismatchedDims = true;
}
}
if (mismatchedDims && !allOnes) {
return rewriter.notifyMatchFailure(
srcOp, "All dimensions in the RHS that are not the row and column "
"dimensions must be 1 OR they must all be equal to the "
"corresponding dimensions in the LHS.");
}
return success();
}
};
*/

class StableHLOToTTIRGetDimensionSizeOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::GetDimensionSizeOp> {

Expand Down
Loading

0 comments on commit e651d14

Please sign in to comment.