Add support for conv_transpose2d operation #1540
base: main
Conversation
Force-pushed from 2b4f3e6 to b000588
Adding @LPanosTT.
Clang-Tidy found issue(s) with the introduced code (1/1)
Hey, thanks for adding this. I have something to say about this op, though. It seems as though some frontends reverse the order of the data in the kernel window for this op, and some do not; i.e., PyTorch does (and thus TTNN does) and JAX does not. You will see that there is an issue to add ...
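To make the convention gap concrete: a minimal sketch, in plain C++ and not taken from this PR, of the spatial flip a lowering would need when bridging JAX-convention weights to a PyTorch/TTNN-style conv_transpose2d; the row-major KH x KW layout and the helper name are illustrative assumptions.

#include <cstdint>
#include <vector>

// Reverse a KH x KW kernel window in both spatial dimensions (assumed
// row-major). PyTorch-style conv_transpose2d applies the window
// reversed; JAX-style applies it as stored, so a lowering between the
// two conventions needs an equivalent of this flip.
std::vector<float> flipKernelWindow(const std::vector<float> &kernel,
                                    std::int64_t kh, std::int64_t kw) {
  std::vector<float> flipped(kernel.size());
  for (std::int64_t h = 0; h < kh; ++h)
    for (std::int64_t w = 0; w < kw; ++w)
      flipped[(kh - 1 - h) * kw + (kw - 1 - w)] = kernel[h * kw + w];
  return flipped;
}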
Also, if you could add a pattern to lower ...
Force-pushed from b000588 to 7b36217
Force-pushed from 683fb3b to 4ddde58
I will merge this PR as-is, but I have opened an issue to track the things you mentioned: #1662
Force-pushed from 62b9199 to 2837812
Great change, Joco, thanks; a couple of comments inline.
@@ -32,6 +33,8 @@
using namespace mlir;
using namespace mlir::tt;

#include <iostream>
This is probably not needed?
llvm::ArrayRef<std::int64_t> output_shape = outputTy.getShape();

auto getLastDim = [](const RankedTensorType &ty, int offset = 1) {
Can you please substitute auto for a real type here?
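A lambda's closure type cannot be spelled out, so the closest fix here is an explicit return type; this is a sketch only, since the lambda body is cut off in the hunk and the return expression below is inferred from the name and the call sites (getLastDim(inputTy), getLastDim(inputTy, 4)):

// Sketch: offset counts dimensions from the back, so offset = 1 is the
// last dimension; the body is an assumption, not the PR's code.
auto getLastDim = [](const RankedTensorType &ty,
                     int offset = 1) -> std::int64_t {
  return ty.getShape()[ty.getRank() - offset];
};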
auto inChannels = rewriter.getI32IntegerAttr(getLastDim(inputTy));
auto outChannels = rewriter.getI32IntegerAttr(getLastDim(outputTy));
auto batchSize = rewriter.getI32IntegerAttr(getLastDim(inputTy, 4));
auto inputHeight = rewriter.getI32IntegerAttr(getLastDim(inputTy, 3));
auto inputWidth = rewriter.getI32IntegerAttr(getLastDim(inputTy, 2));

auto kernelSize = rewriter.getDenseI32ArrayAttr(
    {static_cast<int32_t>(getLastDim(kernelTy, 2)),
     static_cast<int32_t>(getLastDim(kernelTy, 1))});
auto stride = rewriter.getDenseI32ArrayAttr(
    ttmlir::utils::parseAttrToTwoElementVector(adaptor.getStride()));
auto padding = rewriter.getDenseI32ArrayAttr(
    ttmlir::utils::parseAttrToTwoElementVector(adaptor.getPaddingAttr()));
auto outputPadding = rewriter.getDenseI32ArrayAttr(
    ttmlir::utils::parseAttrToTwoElementVector(
        adaptor.getOutputPaddingAttr()));
auto dilation = rewriter.getDenseI32ArrayAttr(
    ttmlir::utils::parseAttrToTwoElementVector(adaptor.getDilationAttr()));
auto groups = rewriter.getI32IntegerAttr(adaptor.getGroups());
Please substitute auto for a real type here as well...
So, should we refrain from using auto completely? In this case I thought it made sense because on the RHS we already know which type we will get.
Yeah, you are right; I just saw a lot of autos, hence my head got it wrong :D
Instead of switching auto here in this example, can you rename the variables, for example inChannelsAttr, outChannelsAttr, etc.? Also kernelSizeArrayAttr, strideArrayAttr, etc.
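A sketch of that renaming applied to the attribute-building code above, with the deduced types written out as well; mlir::IntegerAttr and mlir::DenseI32ArrayAttr are what these builder calls return, so only the names change relative to the diff:

mlir::IntegerAttr inChannelsAttr =
    rewriter.getI32IntegerAttr(getLastDim(inputTy));
mlir::IntegerAttr outChannelsAttr =
    rewriter.getI32IntegerAttr(getLastDim(outputTy));
mlir::DenseI32ArrayAttr kernelSizeArrayAttr = rewriter.getDenseI32ArrayAttr(
    {static_cast<int32_t>(getLastDim(kernelTy, 2)),
     static_cast<int32_t>(getLastDim(kernelTy, 1))});
mlir::DenseI32ArrayAttr strideArrayAttr = rewriter.getDenseI32ArrayAttr(
    ttmlir::utils::parseAttrToTwoElementVector(adaptor.getStride()));
// ...and likewise paddingArrayAttr, outputPaddingArrayAttr,
// dilationArrayAttr, and groupsAttr.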
// Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
// attribute determination
auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
    adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
    outputTy.getElementType());

// Must set the type to the output type to maintain the layout attributes
convDPSOutput.getResult().setType(outputTy);

ttnn::ConvTranspose2dOp new_conv = rewriter.create<ttnn::ConvTranspose2dOp>(
    op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
    adaptor.getBias(), convDPSOutput, device, inChannels, outChannels,
    batchSize, inputHeight, inputWidth, kernelSize, stride, padding,
    outputPadding, dilation, groups);
Please sync with @azecevicTT, he had in mind an API for creating a DPS op, not sure if applicable here. :)
@@ -871,6 +874,77 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
  }
};

class ConvTranspose2dOpConversionPattern
Please add a brief description of the conversion, as it isn't a 1-1 mapping to ttnn conv2d.
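One possible shape for that description, grounded only in what the diff shows (the flattened tensor::EmptyOp DPS output, the setType retype, and the two-element attribute normalization); the wording is a suggestion, not the PR's text:

// Lowers ttir.conv_transpose2d to ttnn.conv_transpose2d. Not a 1-1
// mapping: the DPS output is recreated as a tensor::EmptyOp with a
// flattened shape and then retyped to keep the layout attributes, and
// stride/padding/output_padding/dilation are normalized to two-element
// i32 arrays before being attached to the TTNN op.
class ConvTranspose2dOpConversionPattern
    : public OpConversionPattern<ttir::ConvTranspose2dOp> {
  // ...
};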
  return emitOpError("Batch size of input and output tensors must match");
}

auto checkBiggerThan = [&](llvm::SmallVector<int32_t, 2> &values,
Substitute auto here as well.
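Since the closure type cannot be named, the usual substitute for auto on a local lambda is std::function (at the cost of type erasure) or an explicit return type; a sketch, with the parameters after values guessed from the ArrayRef variant further down, assuming <functional> is included:

// Sketch: `name` and `minValue` are assumptions; returns a failed
// LogicalResult through emitOpError when a value is below the minimum.
std::function<mlir::LogicalResult(llvm::SmallVector<int32_t, 2> &,
                                  const char *, int32_t)>
    checkBiggerThan = [&](llvm::SmallVector<int32_t, 2> &values,
                          const char *name,
                          int32_t minValue) -> mlir::LogicalResult {
      for (int32_t value : values) {
        if (value < minValue) {
          return emitOpError() << "expected " << name
                               << " to be at least " << minValue;
        }
      }
      return mlir::success();
    };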
  }
}

auto checkBiggerThan = [&](llvm::ArrayRef<int32_t> &values, const char *name,
Substitute auto here as well.
@@ -287,7 +287,9 @@ class TTNNLayoutDPSOperandsRewriter

   // TTNN Conv2d moves input, weight, and bias from host to device
   // itself. Inserting the ToLayoutOp on these operands is thus problematic.
-  if (mlir::isa<ttir::Conv2dOp>(op.getOperation()) && !isResult) {
+  if (!isResult &&
Can you please rename isResult to isDPSResult? :)
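Combining that rename with the truncated condition above, the guard might end up roughly like this; the variadic isa list is an assumption consistent with the PR's purpose, not the actual diff:

// Hypothetical final form: both conv ops move input, weight, and bias
// host-to-device themselves, so skip ToLayoutOp insertion for their
// non-result operands.
if (!isDPSResult &&
    mlir::isa<ttir::Conv2dOp, ttir::ConvTranspose2dOp>(op.getOperation())) {
  // ... leave the operand untouched ...
}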
auto kernelSize = toFlatbuffer(cache, op.getKernelSize());
auto stride = toFlatbuffer(cache, op.getStride());
auto padding = toFlatbuffer(cache, op.getPadding());
auto outputPadding = toFlatbuffer(cache, op.getOutputPadding());
auto dilation = toFlatbuffer(cache, op.getDilation());
At least substitute these autos for real types.
Force-pushed from 2837812 to b219b1f
Closes #1084