Add support for conv_transpose2d operation #1540

Open · wants to merge 1 commit into main
52 changes: 52 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -858,6 +858,58 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
let hasVerifier = 1;
}

def TTIR_ConvTranspose2dOp : TTIR_DPSOp<"conv_transpose2d"> {
let summary = "ConvTranspose2d operation.";
let description = [{
Applies a 2D transposed convolution operator over an input image composed of several input planes.

Inputs:
- `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
- `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width)
- `bias` Optional<AnyRankedTensor>: (1 x 1 x 1 x output_channels)
- `output` AnyRankedTensor: NHWC format (batch_size x height x width x channels)

Attributes:
- `stride` (i32 | array<i32>): Controls the stride for the cross-correlation.
- `padding` (i32 | array<i32>): Controls implicit zero padding: both sides of the input are padded by dilation * (kernel_size - 1) - padding points.
- `output_padding` (i32 | array<i32>): Controls the additional size added to one side of the output shape.
- `dilation` (i32 | array<i32>): Controls the spacing between the kernel points.
- `groups` i32: Controls the connections between inputs and outputs. Both the input and output channel counts must be divisible by groups.

Example:
%input = tensor.empty() : () -> tensor<1x8x8x256xbf16>
%weight = tensor.empty() : () -> tensor<256x256x3x3xbf16>
%bias = tensor.empty() : () -> tensor<1x1x1x256xbf16>
%output = tensor.empty() : () -> tensor<1x10x10x256xbf16>
%0 = "ttir.conv_transpose2d"(%input, %weight, %bias, %output)
<{
stride = array<i32: 1, 1>,
padding = 0: i32,
output_padding = 0: i32,
dilation = 1: i32,
groups = 1: i32
}> : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output,
AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$stride,
AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$padding,
AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$output_padding,
AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$dilation,
I32Attr:$groups);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}
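For reference, the example above maps an 8x8 input to a 10x10 output with a 3x3 kernel, stride 1, zero padding, zero output padding, and dilation 1. A minimal sketch of the standard transposed-convolution output-size arithmetic this relies on, assuming PyTorch-style semantics (the helper name is illustrative and not part of this change):

  #include <cstdint>

  // Output spatial size of a 2D transposed convolution along one dimension.
  // For the example above: (8 - 1) * 1 - 2 * 0 + 1 * (3 - 1) + 0 + 1 = 10.
  inline int64_t convTransposeOutDim(int64_t in, int64_t kernel, int64_t stride,
                                     int64_t padding, int64_t outputPadding,
                                     int64_t dilation) {
    return (in - 1) * stride - 2 * padding + dilation * (kernel - 1) +
           outputPadding + 1;
  }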

def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
let summary = "Generalized convolution op.";
let description = [{
51 changes: 51 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -839,6 +839,57 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
let hasVerifier = 1;
}

def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
let summary = "ConvTranspose2d operation.";
let description = [{
Applies a 2D transposed convolution operator over an input image composed of several input planes.

Inputs:
- `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
- `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width)
- `bias` Optional<AnyRankedTensor>: (1 x 1 x 1 x output_channels)
- `output` AnyRankedTensor: (1 x 1 x (batch_size * height * width) x channels)

Attributes:
- `in_channels` i32: The number of input channels.
- `out_channels` i32: The number of output channels.
- `batch_size` i32: The batch size.
- `input_height` i32: The input height.
- `input_width` i32: The input width.
- `kernel_size` array<i32>: The kernel size.
- `stride` array<i32>: Controls the stride for the cross-correlation.
- `padding` array<i32>: Controls implicit zero padding: both sides of the input are padded by dilation * (kernel_size - 1) - padding points.
- `output_padding` array<i32>: Controls the additional size added to one side of the output shape.
- `dilation` array<i32>: Controls the spacing between the kernel points.
- `groups` i32: Controls the connections between inputs and outputs. Both in_channels and out_channels must be divisible by groups.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output,
TT_Device:$device,
I32Attr:$in_channels,
I32Attr:$out_channels,
I32Attr:$batch_size,
I32Attr:$input_height,
I32Attr:$input_width,
DenseI32ArrayAttr:$kernel_size,
DenseI32ArrayAttr:$stride,
DenseI32ArrayAttr:$padding,
DenseI32ArrayAttr:$output_padding,
DenseI32ArrayAttr:$dilation,
I32Attr:$groups);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}
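Because ttnn's conv_transpose2d produces its result in the flattened (1 x 1 x batch_size*height*width x channels) layout noted above, the conversion pattern later in this change reshapes it back to NHWC. A minimal sketch of that flattening arithmetic, assuming the usual NHWC ordering (the helper is illustrative, not part of this change):

  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/ADT/SmallVector.h"
  #include <cassert>
  #include <cstdint>

  // Flatten an NHWC shape to the (1, 1, N*H*W, C) layout returned by the op,
  // e.g. {1, 10, 10, 256} -> {1, 1, 100, 256}.
  inline llvm::SmallVector<int64_t, 4>
  flattenNHWC(llvm::ArrayRef<int64_t> shape) {
    assert(shape.size() == 4 && "expected an NHWC shape");
    return {1, 1, shape[0] * shape[1] * shape[2], shape[3]};
  }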

def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
let description = [{
20 changes: 20 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -263,6 +263,25 @@ table Conv2dOp {
groups: uint32;
}

table ConvTranspose2dOp {
input: tt.target.TensorRef;
weight: tt.target.TensorRef;
bias: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
in_channels: uint32;
out_channels: uint32;
batch_size: uint32;
input_height: uint32;
input_width: uint32;
kernel_size: [int32];
stride: [int32];
padding: [int32];
output_padding: [int32];
dilation: [int32];
groups: uint32;
}

table MaxPool2dOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
@@ -340,6 +359,7 @@ union OpType {
SoftmaxOp,
TransposeOp,
Conv2dOp,
ConvTranspose2dOp,
ConcatOp,
ReshapeOp,
SliceOp,
21 changes: 21 additions & 0 deletions include/ttmlir/Utils.h
@@ -183,6 +183,27 @@ inversePermutation(llvm::ArrayRef<int64_t> permutation) {
return inversePermutation;
}

// Parses an attribute into a two-element vector. This is commonly used for
// attributes representing spatial configurations such as padding, stride, or
// dilation, where either a single integer applies to both dimensions or an
// explicit two-element configuration is provided. Returns an empty vector if
// the attribute is neither an IntegerAttr nor a two-element DenseI32ArrayAttr.
inline llvm::SmallVector<int32_t, 2>
parseAttrToTwoElementVector(mlir::Attribute baseAttr) {
llvm::SmallVector<int32_t, 2> result;

if (const auto attr =
mlir::dyn_cast_if_present<mlir::IntegerAttr>(baseAttr)) {
result.assign(2, attr.getInt());
}

if (const auto attr = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(baseAttr);
attr && attr.size() == 2) {
result.append({attr[0], attr[1]});
}

return result;
}

} // namespace ttmlir::utils

#endif
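A usage sketch for the helper above; a scalar attribute is broadcast to both spatial dimensions, while a two-element array passes through unchanged (the wrapper function and builder setup are illustrative):

  #include "mlir/IR/Builders.h"
  #include "mlir/IR/MLIRContext.h"
  #include "ttmlir/Utils.h"

  void exampleUsage(mlir::MLIRContext *ctx) {
    mlir::Builder builder(ctx);
    // A scalar 2 expands to {2, 2}.
    auto stride = ttmlir::utils::parseAttrToTwoElementVector(
        builder.getI32IntegerAttr(2));
    // A two-element array {1, 2} is returned as-is.
    auto padding = ttmlir::utils::parseAttrToTwoElementVector(
        builder.getDenseI32ArrayAttr({1, 2}));
    (void)stride;
    (void)padding;
  }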
79 changes: 79 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -12,6 +12,7 @@
#include "ttmlir/Dialect/TTNN/Types/Types.h"
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include "ttmlir/Utils.h"

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
@@ -864,6 +865,83 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
}
};

// Since the transposed convolution in ttnn returns a tensor in a flattened
// shape (1 x 1 x N * H * W x C), we need to reshape it to restore the normal
// shape (N x H x W x C)
class ConvTranspose2dOpConversionPattern
: public OpConversionPattern<ttir::ConvTranspose2dOp> {
public:
using OpConversionPattern<ttir::ConvTranspose2dOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::ConvTranspose2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);

auto inputTy = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
auto kernelTy = mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
auto outputTy = mlir::cast<RankedTensorType>(adaptor.getOutput().getType());

auto getLastDim = [](const RankedTensorType &ty, int offset) {
return ty.getShape()[ty.getRank() - offset];
};
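// getLastDim(ty, offset) returns the offset-th dimension counting from the
// end, so for the NHWC input offset 1 is channels and offset 4 is batch size.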

auto inChannelsAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 1));
auto outChannelsAttr = rewriter.getI32IntegerAttr(getLastDim(outputTy, 1));
auto batchSizeAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 4));
auto inputHeightAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 3));
auto inputWidthAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 2));

auto kernelSizeAttr = rewriter.getDenseI32ArrayAttr(
{static_cast<int32_t>(getLastDim(kernelTy, 2)),
static_cast<int32_t>(getLastDim(kernelTy, 1))});
auto strideAttr = rewriter.getDenseI32ArrayAttr(
ttmlir::utils::parseAttrToTwoElementVector(adaptor.getStride()));
auto paddingAttr = rewriter.getDenseI32ArrayAttr(
ttmlir::utils::parseAttrToTwoElementVector(adaptor.getPaddingAttr()));
auto outputPaddingAttr = rewriter.getDenseI32ArrayAttr(
ttmlir::utils::parseAttrToTwoElementVector(
adaptor.getOutputPaddingAttr()));
auto dilationAttr = rewriter.getDenseI32ArrayAttr(
ttmlir::utils::parseAttrToTwoElementVector(adaptor.getDilationAttr()));
auto groupsAttr = rewriter.getI32IntegerAttr(adaptor.getGroups());

llvm::ArrayRef<std::int64_t> output_shape = outputTy.getShape();
llvm::SmallVector<std::int64_t, 4> flattenedOutputShape = {
1, 1, output_shape[0] * output_shape[1] * output_shape[2],
output_shape[3]};

// Account for the flattened shape that the ttnn conv_transpose2d op returns
outputTy = mlir::cast<RankedTensorType>(getTypeConverter()->convertType(
outputTy.cloneWith(flattenedOutputShape, outputTy.getElementType())));

// Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
// attribute determination
auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
outputTy.getElementType());

// Must set the type to the output type to maintain the layout attributes
convDPSOutput.getResult().setType(outputTy);

ttnn::ConvTranspose2dOp new_conv = rewriter.create<ttnn::ConvTranspose2dOp>(
op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
adaptor.getBias(), convDPSOutput, device, inChannelsAttr,
outChannelsAttr, batchSizeAttr, inputHeightAttr, inputWidthAttr,
kernelSizeAttr, strideAttr, paddingAttr, outputPaddingAttr,
dilationAttr, groupsAttr);

// Restore the normal shape (N x H x W x C)
Value output =
ttir_to_ttnn::utils::generateReshape(new_conv, output_shape, rewriter);

rewriter.replaceOp(op, output);
return success();
}
};

class MaxPool2dOpConversionPattern
: public OpConversionPattern<ttir::MaxPool2dOp> {
public:
@@ -1203,6 +1281,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
LinearOpConversionPattern,
MatmulOpConversionPattern,
Conv2dOpConversionPattern,
ConvTranspose2dOpConversionPattern,
MaxPool2dOpConversionPattern,
SubtractOpConversionPattern,
MeshShardOpConversionPattern,
2 changes: 2 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -850,6 +850,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Conv ops
//
patterns.add<DefaultOpConversionPattern<ttnn::Conv2dOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ConvTranspose2dOp>>(
typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MaxPool2dOp>>(typeConverter,
ctx);
