Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for conv_transpose2d operation #1540

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jserbedzijaTT
Copy link
Contributor

closes (#1084)

@nsmithtt
Copy link
Contributor

nsmithtt commented Dec 9, 2024

Adding @LPanosTT

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

lib/Dialect/TTNN/Transforms/TTNNLayout.cpp Outdated Show resolved Hide resolved
@LPanosTT
Copy link
Contributor

LPanosTT commented Dec 9, 2024

Hey thanks for adding this. I have something to say about this op though. It seems as though some frontends reverse the order of the data in the kernel window for this op, and some do not. I.e PyTorch does (and thus TTNN does) and JAX does not. You will see that ttir.convolution has a window_reversal boolean attr as well. In order to model the cases in all frontends we need this attribute for conv_transpose2d in ttnn. Or for us to add ttir.reverse so we can consteval the window reversal away.

There is an issue to add window_reversal to ttnn: tenstorrent/tt-metal#15342

@LPanosTT
Copy link
Contributor

LPanosTT commented Dec 9, 2024

Also if you could add a pattern to lower ttir.convolution to ttir.conv_transpose2d that would be great. Check out the stablehlo spec for convolution, which ttir.convolution is meant to mimic to see how you can tell if a given convolution is a transposed convolution or not.

@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch from b000588 to 7b36217 Compare December 20, 2024 12:58
@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch 2 times, most recently from 683fb3b to 4ddde58 Compare December 23, 2024 11:06
@jserbedzijaTT
Copy link
Contributor Author

jserbedzijaTT commented Dec 24, 2024

Also if you could add a pattern to lower ttir.convolution to ttir.conv_transpose2d that would be great. Check out the stablehlo spec for convolution, which ttir.convolution is meant to mimic to see how you can tell if a given convolution is a transposed convolution or not.

I will merge this pr as is but I have opened an issue to track the things you mentioned: #1662

@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch 2 times, most recently from 62b9199 to 2837812 Compare December 24, 2024 10:32
@mtopalovicTT
Copy link
Contributor

Copy link
Contributor

@sdjordjevicTT sdjordjevicTT left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great change Joco, thanks, couple of comments inline.

@@ -32,6 +33,8 @@
using namespace mlir;
using namespace mlir::tt;

#include <iostream>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably not needed?


llvm::ArrayRef<std::int64_t> output_shape = outputTy.getShape();

auto getLastDim = [](const RankedTensorType &ty, int offset = 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please substitue for a real type here?

Comment on lines +898 to +909
auto inChannels = rewriter.getI32IntegerAttr(getLastDim(inputTy));
auto outChannels = rewriter.getI32IntegerAttr(getLastDim(outputTy));
auto batchSize = rewriter.getI32IntegerAttr(getLastDim(inputTy, 4));
auto inputHeight = rewriter.getI32IntegerAttr(getLastDim(inputTy, 3));
auto inputWidth = rewriter.getI32IntegerAttr(getLastDim(inputTy, 2));

auto kernelSize = rewriter.getDenseI32ArrayAttr(
{static_cast<int32_t>(getLastDim(kernelTy, 2)),
static_cast<int32_t>(getLastDim(kernelTy, 1))});
auto stride = rewriter.getDenseI32ArrayAttr(
ttmlir::utils::parseAttrToTwoElementVector(adaptor.getStride()));
auto padding = rewriter.getDenseI32ArrayAttr(
ttmlir::utils::parseAttrToTwoElementVector(adaptor.getPaddingAttr()));
auto outputPadding = rewriter.getDenseI32ArrayAttr(
ttmlir::utils::parseAttrToTwoElementVector(
adaptor.getOutputPaddingAttr()));
auto dilation = rewriter.getDenseI32ArrayAttr(
ttmlir::utils::parseAttrToTwoElementVector(adaptor.getDilationAttr()));
auto groups = rewriter.getI32IntegerAttr(adaptor.getGroups());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please substitute auto for a real type here as well...

Copy link
Contributor Author

@jserbedzijaTT jserbedzijaTT Dec 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, should we refrain from using auto completely? In this case I thought it made sense because on the RHS we already know which type we will get.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you are right, I just saw a lot of autos hence my head got it wrong :D

Instead of switching auto here in this example, can you rename variables, for example, inChannelsAttr, outChannelsAttr, etc... Also kernelSizeArrayAttr, strideArrayAttr, etc...

Comment on lines +925 to +931
// Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
// attribute determination
auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
outputTy.getElementType());

// Must set the type to the output type to maintain the layout attributes
convDPSOutput.getResult().setType(outputTy);

ttnn::ConvTranspose2dOp new_conv = rewriter.create<ttnn::ConvTranspose2dOp>(
op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
adaptor.getBias(), convDPSOutput, device, inChannels, outChannels,
batchSize, inputHeight, inputWidth, kernelSize, stride, padding,
outputPadding, dilation, groups);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sync with @azecevicTT, he had in mind an API for creating a DPS op, not sure if applicable here. :)

@@ -871,6 +874,77 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
}
};

class ConvTranspose2dOpConversionPattern
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a brief descrpipiton of conversion as it isn't 1-1 mapping to ttnn conv2d.

return emitOpError("Batch size of input and output tensors must match");
}

auto checkBiggerThan = [&](llvm::SmallVector<int32_t, 2> &values,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Substitute auto here as well.

}
}

auto checkBiggerThan = [&](llvm::ArrayRef<int32_t> &values, const char *name,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Substitute auto here as well.

@@ -287,7 +287,9 @@ class TTNNLayoutDPSOperandsRewriter

// TTNN Conv2d moves input, weight, and bias from host to device
// itself. Inserting the ToLayoutOp on these operands is thus problematic.
if (mlir::isa<ttir::Conv2dOp>(op.getOperation()) && !isResult) {
if (!isResult &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you plese rename isResult in isDPSResult? :)

Comment on lines +494 to +489
auto kernelSize = toFlatbuffer(cache, op.getKernelSize());
auto stride = toFlatbuffer(cache, op.getStride());
auto padding = toFlatbuffer(cache, op.getPadding());
auto outputPadding = toFlatbuffer(cache, op.getOutputPadding());
auto dilation = toFlatbuffer(cache, op.getDilation());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least substitute this autos for real types.

@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch from 2837812 to b219b1f Compare December 27, 2024 10:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants