Skip to content

Commit

Permalink
Add lowering of stablehlo reshape to ttir.unsqueeze to handle shape w…
Browse files Browse the repository at this point in the history
…ith broadcast.
  • Loading branch information
uazizTT committed Dec 21, 2024
1 parent 78c1612 commit 054b9e0
Showing 1 changed file with 51 additions and 7 deletions.
58 changes: 51 additions & 7 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,17 +750,61 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern
return legalityResult;
}

auto inputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getOperand().getType()));

auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

mlir::ArrayAttr dimArg =
rewriter.getI64ArrayAttr(adaptor.getBroadcastDimensions());
if (inputType.getRank() == outputType.getRank()) {
// unsqueeze is not needed, proceed to converting to broadcast

rewriter.replaceOpWithNewOp<mlir::tt::ttir::BroadcastOp>(
srcOp, getTypeConverter()->convertType(outputTensor.getType()),
Value(adaptor.getOperand()), Value(outputTensor), dimArg);
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

mlir::ArrayAttr dimArg =
rewriter.getI64ArrayAttr(adaptor.getBroadcastDimensions());

rewriter.replaceOpWithNewOp<mlir::tt::ttir::BroadcastOp>(
srcOp, getTypeConverter()->convertType(outputTensor.getType()),
Value(adaptor.getOperand()), Value(outputTensor), dimArg);
} else {

SmallVector<int64_t, 4> UnsqueezeShape;
UnsqueezeShape.push_back(1);
for (unsigned int i = 0; i < inputType.getRank(); i++) {
UnsqueezeShape.push_back(inputType.getDimSize(i));
}

RankedTensorType unsqueezeOutputType =
RankedTensorType::get(UnsqueezeShape, outputType.getElementType());

tensor::EmptyOp reshapeOutputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), unsqueezeOutputType.getShape(),
unsqueezeOutputType.getElementType());

std::vector<int32_t> new_shape_i32;
for (int64_t dim : outputType.getShape()) {
new_shape_i32.push_back(static_cast<int32_t>(dim));
}

mlir::tt::ttir::UnsqueezeOp reshape =
rewriter.create<mlir::tt::ttir::UnsqueezeOp>(
srcOp.getLoc(),
getTypeConverter()->convertType(unsqueezeOutputType),
adaptor.getOperand(), reshapeOutputTensor, 1);

tensor::EmptyOp broadcastOutputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

mlir::ArrayAttr dimArg =
rewriter.getI64ArrayAttr(adaptor.getBroadcastDimensions());

rewriter.replaceOpWithNewOp<mlir::tt::ttir::BroadcastOp>(
srcOp,
getTypeConverter()->convertType(broadcastOutputTensor.getType()),
Value(reshape.getResult()), Value(broadcastOutputTensor), dimArg);
}

return success();
}
Expand Down

0 comments on commit 054b9e0

Please sign in to comment.