Skip to content

Commit

Permalink
Add support for conv_transpose2d operation
Browse files Browse the repository at this point in the history
  • Loading branch information
jserbedzijaTT committed Dec 27, 2024
1 parent 6d04d25 commit 1a7caeb
Show file tree
Hide file tree
Showing 18 changed files with 1,177 additions and 2 deletions.
52 changes: 52 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,58 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
let hasVerifier = 1;
}

def TTIR_ConvTranspose2dOp : TTIR_DPSOp<"conv_transpose2d"> {
  let summary = "ConvTranspose2d operation.";
  let description = [{
    Applies a 2D transposed convolution operator over an input image composed of several input planes.

    Inputs:
    - `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
    - `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width)
    - `bias` Optional<AnyRankedTensor>: (1 x 1 x 1 x output_channels)
    - `output` AnyRankedTensor: NHWC format (batch_size x height x width x channels)

    Attributes:
    - `stride` (i32 | array<i32>): Controls the stride for the cross-correlation.
    - `padding` (i32 | array<i32>): Controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points.
    - `output_padding` (i32 | array<i32>): Controls the additional size added to one side of the output shape.
    - `dilation` (i32 | array<i32>): Controls the spacing between the kernel points.
    - `groups` i32: Controls the connections between inputs and outputs. Must be divisible by input and output channels.

    Example:
      // Operand shapes must match the types of the call below:
      // input is NHWC (1 x 8 x 8 x 256), weight is OIHW (256 x 256 x 3 x 3).
      %input = tensor.empty() : () -> tensor<1x8x8x256xbf16>
      %weight = tensor.empty() : () -> tensor<256x256x3x3xbf16>
      %bias = tensor.empty() : () -> tensor<1x1x1x256xbf16>
      %output = tensor.empty() : () -> tensor<1x10x10x256xbf16>
      %0 = "ttir.conv_transpose2d"(%input, %weight, %bias, %output)
            <{
              stride = array<i32: 1, 1>,
              padding = 0: i32,
              output_padding = 0: i32,
              dilation = 1: i32,
              groups = 1: i32
            }> : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
  }];

  let arguments = (ins AnyRankedTensor:$input,
                       AnyRankedTensor:$weight,
                       Optional<AnyRankedTensor>:$bias,
                       AnyRankedTensor:$output,
                       AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$stride,
                       AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$padding,
                       AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$output_padding,
                       AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$dilation,
                       I32Attr:$groups);

  let results = (outs AnyRankedTensor:$result);

  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];

  let hasVerifier = 1;
}

def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
let summary = "Generalized convolution op.";
let description = [{
Expand Down
51 changes: 51 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,57 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
let hasVerifier = 1;
}

// TTNN-dialect counterpart of the TTIR conv_transpose2d op. Unlike the TTIR
// op, the output tensor here is in the flattened (1 x 1 x N*H*W x C) form,
// and all configuration is fully normalized: geometry (channels, batch,
// input height/width, kernel size) is carried as explicit attributes, and
// the spatial attributes are always two-element arrays (no scalar form).
def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
  let summary = "ConvTranspose2d operation.";
  let description = [{
    Applies a 2D transposed convolution operator over an input image composed of several input planes.

    Inputs:
    - `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
    - `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width)
    - `bias` Optional<AnyRankedTensor>: (1 x 1 x 1 x output_channels)
    - `output` AnyRankedTensor: (1 x 1 x (batch_size * height * width) x channels)

    Attributes:
    - `in_channels` i32: The number of input channels.
    - `out_channels` i32: The number of output channels.
    - `batch_size` i32: The batch size.
    - `input_height` i32: The input height.
    - `input_width` i32: The input width.
    - `kernel_size` array<i32>: The kernel size.
    - `stride` array<i32>: Controls the stride for the cross-correlation.
    - `padding` array<i32>: Controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points.
    - `output_padding` array<i32>: Controls the additional size added to one side of the output shape.
    - `dilation` array<i32>: Controls the spacing between the kernel points.
    - `groups` i32: Controls the connections between inputs and outputs. Must be divisible by input and output channels.
  }];

  let arguments = (ins AnyRankedTensor:$input,
                       AnyRankedTensor:$weight,
                       Optional<AnyRankedTensor>:$bias,
                       AnyRankedTensor:$output,
                       TT_Device:$device,
                       I32Attr:$in_channels,
                       I32Attr:$out_channels,
                       I32Attr:$batch_size,
                       I32Attr:$input_height,
                       I32Attr:$input_width,
                       DenseI32ArrayAttr:$kernel_size,
                       DenseI32ArrayAttr:$stride,
                       DenseI32ArrayAttr:$padding,
                       DenseI32ArrayAttr:$output_padding,
                       DenseI32ArrayAttr:$dilation,
                       I32Attr:$groups);

  let results = (outs AnyRankedTensor:$result);

  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];

  let hasVerifier = 1;
}

def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
let description = [{
Expand Down
20 changes: 20 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,25 @@ table Conv2dOp {
groups: uint32;
}

// Serialized form of the TTNN conv_transpose2d op; fields mirror
// TTNN_ConvTranspose2dOp one-to-one. NOTE(review): flatbuffer field order is
// part of the binary wire format — append new fields at the end, never
// reorder or remove existing ones.
table ConvTranspose2dOp {
  input: tt.target.TensorRef;   // activation tensor
  weight: tt.target.TensorRef;  // kernel tensor
  bias: tt.target.TensorRef;    // optional in the dialect op; may be absent
  out: tt.target.TensorRef;     // pre-allocated output tensor
  device: tt.target.DeviceRef;  // device the op executes on
  in_channels: uint32;
  out_channels: uint32;
  batch_size: uint32;
  input_height: uint32;
  input_width: uint32;
  kernel_size: [int32];     // [kernel_height, kernel_width]
  stride: [int32];          // two elements, one per spatial dimension
  padding: [int32];         // two elements, one per spatial dimension
  output_padding: [int32];  // two elements, one per spatial dimension
  dilation: [int32];        // two elements, one per spatial dimension
  groups: uint32;
}

table MaxPool2dOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
Expand Down Expand Up @@ -340,6 +359,7 @@ union OpType {
SoftmaxOp,
TransposeOp,
Conv2dOp,
ConvTranspose2dOp,
ConcatOp,
ReshapeOp,
SliceOp,
Expand Down
21 changes: 21 additions & 0 deletions include/ttmlir/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,27 @@ inversePermutation(llvm::ArrayRef<int64_t> permutation) {
return inversePermutation;
}

// Normalizes an attribute encoding a 2D spatial configuration (e.g. stride,
// padding, output_padding, dilation) into a two-element vector.
//
// Accepted forms:
//   - IntegerAttr:       the scalar is broadcast to both dimensions.
//   - DenseI32ArrayAttr: must have exactly two elements, copied verbatim.
//
// Any other input (including a null attribute or an array of the wrong
// length) yields an EMPTY vector; callers must treat that as a parse failure.
inline llvm::SmallVector<int32_t, 2>
parseAttrToTwoElementVector(mlir::Attribute baseAttr) {
  llvm::SmallVector<int32_t, 2> result;

  if (const auto intAttr =
          mlir::dyn_cast_if_present<mlir::IntegerAttr>(baseAttr)) {
    // Scalar form: apply the same value to both spatial dimensions.
    result.assign(2, intAttr.getInt());
  } else if (const auto arrayAttr =
                 mlir::dyn_cast_if_present<mlir::DenseI32ArrayAttr>(baseAttr);
             arrayAttr && arrayAttr.size() == 2) {
    // Explicit [dim0, dim1] form. dyn_cast_if_present (rather than the plain
    // dyn_cast originally used here) keeps a null attribute from asserting;
    // the branches are also mutually exclusive, hence the else-if.
    result.append({arrayAttr[0], arrayAttr[1]});
  }

  return result;
}

} // namespace ttmlir::utils

#endif
79 changes: 79 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "ttmlir/Dialect/TTNN/Types/Types.h"
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include "ttmlir/Utils.h"

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -864,6 +865,83 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
}
};

// Lowers ttir.conv_transpose2d to ttnn.conv_transpose2d. Since the transposed
// convolution in ttnn returns a tensor in a flattened shape
// (1 x 1 x N * H * W x C), we need to reshape it afterwards to restore the
// normal shape (N x H x W x C).
class ConvTranspose2dOpConversionPattern
    : public OpConversionPattern<ttir::ConvTranspose2dOp> {
public:
  using OpConversionPattern<ttir::ConvTranspose2dOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::ConvTranspose2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);

    auto inputTy = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
    auto kernelTy = mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
    auto outputTy = mlir::cast<RankedTensorType>(adaptor.getOutput().getType());

    // Dimension `offset` positions from the end of the shape (offset == 1 is
    // the innermost dimension). A plain lambda replaces the original
    // std::function wrapper, whose default argument was unusable through the
    // std::function call signature and which added needless indirection.
    auto getLastDim = [](const RankedTensorType &ty, int offset) {
      return ty.getShape()[ty.getRank() - offset];
    };

    // Input/output are NHWC and the weight is OIHW (see the TTIR op
    // definition), so counting from the back: C=1, W=2, H=3, N=4 on the
    // activations and W=1, H=2 on the kernel.
    auto inChannelsAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 1));
    auto outChannelsAttr = rewriter.getI32IntegerAttr(getLastDim(outputTy, 1));
    auto batchSizeAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 4));
    auto inputHeightAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 3));
    auto inputWidthAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 2));

    auto kernelSizeAttr = rewriter.getDenseI32ArrayAttr(
        {static_cast<int32_t>(getLastDim(kernelTy, 2)),
         static_cast<int32_t>(getLastDim(kernelTy, 1))});

    // The TTIR op accepts each of these as either a scalar or a 2-element
    // array; normalize all of them to explicit 2-element arrays for TTNN.
    auto strideAttr = rewriter.getDenseI32ArrayAttr(
        ttmlir::utils::parseAttrToTwoElementVector(adaptor.getStrideAttr()));
    auto paddingAttr = rewriter.getDenseI32ArrayAttr(
        ttmlir::utils::parseAttrToTwoElementVector(adaptor.getPaddingAttr()));
    auto outputPaddingAttr = rewriter.getDenseI32ArrayAttr(
        ttmlir::utils::parseAttrToTwoElementVector(
            adaptor.getOutputPaddingAttr()));
    auto dilationAttr = rewriter.getDenseI32ArrayAttr(
        ttmlir::utils::parseAttrToTwoElementVector(adaptor.getDilationAttr()));
    auto groupsAttr = rewriter.getI32IntegerAttr(adaptor.getGroups());

    // ttnn's conv_transpose2d produces (1 x 1 x N*H*W x C); derive that
    // flattened shape from the op's declared NHWC output shape.
    llvm::ArrayRef<std::int64_t> outputShape = outputTy.getShape();
    llvm::SmallVector<std::int64_t, 4> flattenedOutputShape = {
        1, 1, outputShape[0] * outputShape[1] * outputShape[2],
        outputShape[3]};

    // Account for the flattened shape that the ttnn conv transpose returns.
    outputTy = mlir::cast<RankedTensorType>(getTypeConverter()->convertType(
        outputTy.cloneWith(flattenedOutputShape, outputTy.getElementType())));

    // Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
    // attribute determination.
    auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
        adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
        outputTy.getElementType());

    // Must set the type to the output type to maintain the layout attributes.
    convDPSOutput.getResult().setType(outputTy);

    ttnn::ConvTranspose2dOp newConv = rewriter.create<ttnn::ConvTranspose2dOp>(
        op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
        adaptor.getBias(), convDPSOutput, device, inChannelsAttr,
        outChannelsAttr, batchSizeAttr, inputHeightAttr, inputWidthAttr,
        kernelSizeAttr, strideAttr, paddingAttr, outputPaddingAttr,
        dilationAttr, groupsAttr);

    // Restore the normal (N x H x W x C) shape for downstream consumers.
    Value output =
        ttir_to_ttnn::utils::generateReshape(newConv, outputShape, rewriter);

    rewriter.replaceOp(op, output);
    return success();
  }
};

class MaxPool2dOpConversionPattern
: public OpConversionPattern<ttir::MaxPool2dOp> {
public:
Expand Down Expand Up @@ -1203,6 +1281,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
LinearOpConversionPattern,
MatmulOpConversionPattern,
Conv2dOpConversionPattern,
ConvTranspose2dOpConversionPattern,
MaxPool2dOpConversionPattern,
SubtractOpConversionPattern,
MeshShardOpConversionPattern,
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Conv ops
//
patterns.add<DefaultOpConversionPattern<ttnn::Conv2dOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ConvTranspose2dOp>>(
typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MaxPool2dOp>>(typeConverter,
ctx);

Expand Down
Loading

0 comments on commit 1a7caeb

Please sign in to comment.