From 2b4f3e6e3c8d5e0686f54f5d7dddd1c042aa66a8 Mon Sep 17 00:00:00 2001
From: jserbedzija
Date: Tue, 19 Nov 2024 17:09:21 +0000
Subject: [PATCH] Add support for conv_transpose2d operation

---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td     |  53 +++
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td     |  51 +++
 include/ttmlir/Target/TTNN/program.fbs        |  20 +
 include/ttmlir/Utils.h                        |  21 +
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp      |  74 ++++
 lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp    |   1 +
 lib/Dialect/TTIR/IR/TTIROps.cpp               | 160 +++++++
 lib/Dialect/TTNN/IR/TTNNOps.cpp               | 162 +++++++
 lib/Dialect/TTNN/Transforms/TTNNLayout.cpp    |   4 +
 lib/Target/TTNN/TTNNToFlatbuffer.cpp          |  34 ++
 runtime/lib/ttnn/operations/CMakeLists.txt    |   1 +
 .../ttnn/operations/conv/conv_transpose2d.cpp |  60 +++
 .../ttnn/operations/conv/conv_transpose2d.h   |  16 +
 runtime/lib/ttnn/program.cpp                  |   4 +
 .../conv_transpose2d_tests_negative.mlir      | 405 ++++++++++++++++++
 .../conv_transpose2d_tests_positive.mlir      | 109 +++++
 .../perf_unit/test_perf_conv_transpose2d.mlir |  21 +
 .../Silicon/TTNN/simple_conv_transpose2d.mlir |  21 +
 18 files changed, 1217 insertions(+)
 create mode 100644 runtime/lib/ttnn/operations/conv/conv_transpose2d.cpp
 create mode 100644 runtime/lib/ttnn/operations/conv/conv_transpose2d.h
 create mode 100644 test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_negative.mlir
 create mode 100644 test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_positive.mlir
 create mode 100644 test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv_transpose2d.mlir
 create mode 100644 test/ttmlir/Silicon/TTNN/simple_conv_transpose2d.mlir

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 69510f93a4..394402c5a9 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -796,6 +796,59 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
   let hasVerifier = 1;
 }
 
+def TTIR_ConvTranspose2dOp : TTIR_DPSOp<"conv_transpose2d"> {
+  let summary = "ConvTranspose2d operation.";
+  let description = [{
+    Applies a 2D transposed convolution operator over an input image composed of several input planes.
+
+    Inputs:
+      - `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
+      - `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width)
+      - `bias` Optional: (1 x 1 x 1 x output_channels)
+      - `output` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
+
+    Attributes:
+      - `stride` (i32 | array): Controls the stride for the cross-correlation.
+      - `padding` (i32 | array): Controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points.
+      - `output_padding` (i32 | array): Controls the additional size added to one side of the output shape.
+      - `dilation` (i32 | array): Controls the spacing between the kernel points.
+      - `groups` i32: Controls the connections between inputs and outputs. The number of input channels and the number of output channels must both be divisible by `groups`.
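+
+    The verifiers introduced in this patch check the output spatial size against
+    the usual transposed-convolution rule:
+
+      H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_height - 1) + output_padding[0] + 1
+      W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_width - 1) + output_padding[1] + 1
+
+    For the example below: H_out = (8 - 1) * 1 - 0 + 1 * (3 - 1) + 0 + 1 = 10,
+    and W_out = 10 by the same computation.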
+
+    Example:
+      %input = tensor.empty() : () -> tensor<1x8x8x256xbf16>
+      %weight = tensor.empty() : () -> tensor<256x256x3x3xbf16>
+      %bias = tensor.empty() : () -> tensor<1x1x1x256xbf16>
+      %output = tensor.empty() : () -> tensor<1x10x10x256xbf16>
+      %0 = "ttir.conv_transpose2d"(%input, %weight, %bias, %output)
+            <{
+              stride = array<i32: 1, 1>,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32
+            }> : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$weight,
+                       Optional<AnyRankedTensor>:$bias,
+                       AnyRankedTensor:$output,
+                       AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$stride,
+                       AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$padding,
+                       AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$output_padding,
+                       AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$dilation,
+                       I32Attr:$groups,
+                       TT_OperandConstraintArrayAttr:$operand_constraints);
+
+  let results = (outs AnyRankedTensor:$result);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let hasVerifier = 1;
+}
+
 def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
   let summary = "Generalized convolution op.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index ed914cb555..3723a05c2c 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -766,6 +766,57 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
   let hasVerifier = 1;
 }
 
+def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
+  let summary = "ConvTranspose2d operation.";
+  let description = [{
+    Applies a 2D transposed convolution operator over an input image composed of several input planes.
+
+    Inputs:
+      - `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
+      - `weight` AnyRankedTensor: OIHW format (output_channels x input_channels x height x width)
+      - `bias` Optional: (1 x 1 x 1 x output_channels)
+      - `output` AnyRankedTensor: (1 x 1 x (batch_size * height * width) x channels)
+
+    Attributes:
+      - `in_channels` i32: The number of input channels.
+      - `out_channels` i32: The number of output channels.
+      - `batch_size` i32: The batch size.
+      - `input_height` i32: The input height.
+      - `input_width` i32: The input width.
+      - `kernel_size` array: The kernel size.
+      - `stride` array: Controls the stride for the cross-correlation.
+      - `padding` array: Controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points.
+      - `output_padding` array: Controls the additional size added to one side of the output shape.
+      - `dilation` array: Controls the spacing between the kernel points.
+      - `groups` i32: Controls the connections between inputs and outputs. The number of input channels and the number of output channels must both be divisible by `groups`.
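+
+    Unlike the TTIR op, this op carries its activations in a flattened layout:
+    an (N x H x W x C) tensor travels as (1 x 1 x N*H*W x C). For example, a
+    1x10x10x256 result is produced as 1x1x100x256 and is reshaped back to NHWC
+    by the TTIR-to-TTNN lowering.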
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$weight,
+                       Optional<AnyRankedTensor>:$bias,
+                       AnyRankedTensor:$output,
+                       TT_Device:$device,
+                       I32Attr:$in_channels,
+                       I32Attr:$out_channels,
+                       I32Attr:$batch_size,
+                       I32Attr:$input_height,
+                       I32Attr:$input_width,
+                       DenseI32ArrayAttr:$kernel_size,
+                       DenseI32ArrayAttr:$stride,
+                       DenseI32ArrayAttr:$padding,
+                       DenseI32ArrayAttr:$output_padding,
+                       DenseI32ArrayAttr:$dilation,
+                       I32Attr:$groups);
+
+  let results = (outs AnyRankedTensor:$result);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let hasVerifier = 1;
+}
+
 def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
   let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 5644c970d8..2112b915fa 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -241,6 +241,25 @@ table Conv2dOp {
   groups: uint32;
 }
 
+table ConvTranspose2dOp {
+  input: tt.target.TensorRef;
+  weight: tt.target.TensorRef;
+  bias: tt.target.TensorRef;
+  out: tt.target.TensorRef;
+  device: tt.target.DeviceRef;
+  in_channels: uint32;
+  out_channels: uint32;
+  batch_size: uint32;
+  input_height: uint32;
+  input_width: uint32;
+  kernel_size: [int32];
+  stride: [int32];
+  padding: [int32];
+  output_padding: [int32];
+  dilation: [int32];
+  groups: uint32;
+}
+
 table MaxPool2dOp {
   in: tt.target.TensorRef;
   out: tt.target.TensorRef;
@@ -289,6 +308,7 @@ union OpType {
   SoftmaxOp,
   TransposeOp,
   Conv2dOp,
+  ConvTranspose2dOp,
   ConcatOp,
   ReshapeOp,
   SliceOp,
diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h
index 49dad79e5e..77a1a28a40 100644
--- a/include/ttmlir/Utils.h
+++ b/include/ttmlir/Utils.h
@@ -132,6 +132,27 @@ inline bool isRankedTensor(mlir::Value v) {
   return mlir::isa<mlir::RankedTensorType>(v.getType());
 }
 
+// Parses an attribute into a two-element vector, commonly used for attributes
+// representing spatial configurations like padding, strides, or dilation,
+// where a single integer can apply to both dimensions or a specific 2D
+// configuration can be provided.
+inline llvm::SmallVector<int32_t>
+parseAttrToTwoElementVector(mlir::Attribute baseAttr) {
+  llvm::SmallVector<int32_t> result;
+
+  if (const auto attr =
+          mlir::dyn_cast_if_present<mlir::IntegerAttr>(baseAttr)) {
+    result.assign(2, attr.getInt());
+  }
+
+  if (const auto attr = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(baseAttr);
+      attr && attr.size() == 2) {
+    result.append({attr[0], attr[1]});
+  }
+
+  return result;
+}
+
 } // namespace ttmlir::utils
 
 #endif
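As a reviewer aid (illustrative only, not part of the patch), here is how the
helper above behaves for the three attribute shapes it accepts; `builder` is a
hypothetical mlir::Builder:

    // A scalar i32 attribute is broadcast to both spatial dimensions.
    auto s = ttmlir::utils::parseAttrToTwoElementVector(builder.getI32IntegerAttr(2));
    // s == {2, 2}
    // A two-element dense array is taken as-is.
    auto hw = ttmlir::utils::parseAttrToTwoElementVector(builder.getDenseI32ArrayAttr({1, 2}));
    // hw == {1, 2}
    // Anything else (wrong element count, wrong attribute kind) yields an empty
    // vector, which the op verifiers turn into "Expected int or pair of ints ...".
    auto bad = ttmlir::utils::parseAttrToTwoElementVector(builder.getDenseI32ArrayAttr({1, 2, 3}));
    // bad.empty() == true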
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index d77d095acc..a54259ca27 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -11,6 +11,7 @@
 #include "ttmlir/Dialect/TTNN/Types/Types.h"
 #include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"
 #include "ttmlir/Dialect/TTNN/Utils/Utils.h"
+#include "ttmlir/Utils.h"
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
@@ -761,6 +764,76 @@ class Conv2dOpConversionPattern
   }
 };
 
+class ConvTranspose2dOpConversionPattern
+    : public OpConversionPattern<ttir::ConvTranspose2dOp> {
+public:
+  using OpConversionPattern<ttir::ConvTranspose2dOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::ConvTranspose2dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
+
+    auto inputTy = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
+    auto kernelTy = mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
+    auto outputTy = mlir::cast<RankedTensorType>(adaptor.getOutput().getType());
+
+    llvm::ArrayRef<std::int64_t> output_shape = outputTy.getShape();
+
+    auto getLastDim = [](const RankedTensorType &ty, int offset = 1) {
+      return ty.getShape()[ty.getRank() - offset];
+    };
+
+    auto inChannels = rewriter.getI32IntegerAttr(getLastDim(inputTy));
+    auto outChannels = rewriter.getI32IntegerAttr(getLastDim(outputTy));
+    auto batchSize = rewriter.getI32IntegerAttr(getLastDim(inputTy, 4));
+    auto inputHeight = rewriter.getI32IntegerAttr(getLastDim(inputTy, 3));
+    auto inputWidth = rewriter.getI32IntegerAttr(getLastDim(inputTy, 2));
+
+    auto kernelSize = rewriter.getDenseI32ArrayAttr(
+        {static_cast<int32_t>(getLastDim(kernelTy, 2)),
+         static_cast<int32_t>(getLastDim(kernelTy, 1))});
+    auto stride = rewriter.getDenseI32ArrayAttr(
+        ttmlir::utils::parseAttrToTwoElementVector(adaptor.getStride()));
+    auto padding = rewriter.getDenseI32ArrayAttr(
+        ttmlir::utils::parseAttrToTwoElementVector(adaptor.getPaddingAttr()));
+    auto outputPadding = rewriter.getDenseI32ArrayAttr(
+        ttmlir::utils::parseAttrToTwoElementVector(
+            adaptor.getOutputPaddingAttr()));
+    auto dilation = rewriter.getDenseI32ArrayAttr(
+        ttmlir::utils::parseAttrToTwoElementVector(adaptor.getDilationAttr()));
+    auto groups = rewriter.getI32IntegerAttr(adaptor.getGroups());
+
+    // TTNN expects the activation flattened to (1, 1, N*H*W, C).
+    std::vector<int64_t> flattenedOutputShape = {
+        1, 1, output_shape[0] * output_shape[1] * output_shape[2],
+        output_shape[3]};
+
+    outputTy = mlir::cast<RankedTensorType>(getTypeConverter()->convertType(
+        outputTy.cloneWith(flattenedOutputShape, outputTy.getElementType())));
+
+    // Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
+    // attribute determination
+    auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
+        adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
+        outputTy.getElementType());
+
+    // Must set the type to the output type to maintain the layout attributes
+    convDPSOutput.getResult().setType(outputTy);
+
+    ttnn::ConvTranspose2dOp new_conv = rewriter.create<ttnn::ConvTranspose2dOp>(
+        op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
+        adaptor.getBias(), convDPSOutput, device, inChannels, outChannels,
+        batchSize, inputHeight, inputWidth, kernelSize, stride, padding,
+        outputPadding, dilation, groups);
+
+    // Reshape the flattened result back to the original NHWC output shape.
+    Value output = generateReshape(new_conv, output_shape, rewriter);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+
 class MaxPool2dOpConversionPattern
     : public OpConversionPattern<ttir::MaxPool2dOp> {
 public:
@@ -1033,6 +1106,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            LinearOpConversionPattern,
            MatmulOpConversionPattern,
            Conv2dOpConversionPattern,
+           ConvTranspose2dOpConversionPattern,
            MaxPool2dOpConversionPattern,
            SubtractOpConversionPattern,
            AllGatherOpConversionPattern,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 3986438e64..0b0ca1190e 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -738,6 +738,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   // Conv ops
   //
   patterns.add<DefaultOpConversionPattern<ttnn::Conv2dOp>>(typeConverter, ctx);
+  patterns.add<DefaultOpConversionPattern<ttnn::ConvTranspose2dOp>>(typeConverter, ctx);
   patterns.add<DefaultOpConversionPattern<ttnn::MaxPool2dOp>>(typeConverter, ctx);
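A quick illustration (not part of the patch) of the flattening performed by
ConvTranspose2dOpConversionPattern above, using the shapes from the op
documentation:

    // NHWC output 1x10x10x256 is carried through TTNN as (1, 1, N*H*W, C).
    std::vector<int64_t> outputShape = {1, 10, 10, 256};
    std::vector<int64_t> flattened = {
        1, 1, outputShape[0] * outputShape[1] * outputShape[2], outputShape[3]};
    // flattened == {1, 1, 100, 256}; generateReshape() restores 1x10x10x256.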
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 44af2f2c4b..850d94e13a 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -151,6 +151,166 @@ ::mlir::LogicalResult mlir::tt::ttir::Conv2dOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvTranspose2dOp
+//===----------------------------------------------------------------------===//
+
+// ConvTranspose2dOp verification
+mlir::LogicalResult mlir::tt::ttir::ConvTranspose2dOp::verify() {
+  mlir::RankedTensorType inputType = getInput().getType();
+  mlir::RankedTensorType weightType = getWeight().getType();
+  mlir::RankedTensorType outputType = getOutput().getType();
+  std::optional<mlir::RankedTensorType> bias =
+      getBias().getImpl() ? std::make_optional(getBias().getType())
+                          : std::nullopt;
+
+  if (inputType.getRank() != 4) {
+    return emitOpError("Input must be a 4D tensor");
+  }
+
+  if (outputType.getRank() != 4) {
+    return emitOpError("Output must be a 4D tensor");
+  }
+
+  if (weightType.getRank() != 4) {
+    return emitOpError("Weight must be a 4D tensor");
+  }
+
+  if (bias.has_value()) {
+    if (bias->getRank() != 4) {
+      return emitOpError("Bias must be a 4D tensor");
+    }
+  }
+
+  if (inputType.getShape()[0] != outputType.getShape()[0]) {
+    return emitOpError("Batch size of input and output tensors must match");
+  }
+
+  auto checkBiggerThan = [&](llvm::SmallVector<int32_t> &values,
+                             const char *name,
+                             int32_t minValue) -> mlir::LogicalResult {
+    for (int32_t value : values) {
+      if (value < minValue) {
+        return emitOpError() << "Attribute '" << name
+                             << "' contains a value less than: " << minValue;
+      }
+    }
+    return mlir::success();
+  };
+
+  llvm::SmallVector<int32_t> stride =
+      ttmlir::utils::parseAttrToTwoElementVector(getStrideAttr());
+  if (stride.empty()) {
+    return emitOpError("Expected int or pair of ints for stride");
+  }
+  if (failed(checkBiggerThan(stride, "stride", 1))) {
+    return mlir::failure();
+  }
+
+  llvm::SmallVector<int32_t> padding =
+      ttmlir::utils::parseAttrToTwoElementVector(getPaddingAttr());
+  if (padding.empty()) {
+    return emitOpError("Expected int or pair of ints for padding");
+  }
+  if (failed(checkBiggerThan(padding, "padding", 0))) {
+    return mlir::failure();
+  }
+
+  llvm::SmallVector<int32_t> outputPadding =
+      ttmlir::utils::parseAttrToTwoElementVector(getOutputPaddingAttr());
+  if (outputPadding.empty()) {
+    return emitOpError("Expected int or pair of ints for output padding");
+  }
+  if (failed(checkBiggerThan(outputPadding, "output padding", 0))) {
+    return mlir::failure();
+  }
+
+  llvm::SmallVector<int32_t> dilation =
+      ttmlir::utils::parseAttrToTwoElementVector(getDilationAttr());
+  if (dilation.empty()) {
+    return emitOpError("Expected int or pair of ints for dilation");
+  }
+  if (failed(checkBiggerThan(dilation, "dilation", 1))) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<int64_t> kernelShape = weightType.getShape();
+
+  int32_t inputChannels = inputType.getDimSize(inputType.getRank() - 1);
+  int32_t outputChannels = outputType.getDimSize(outputType.getRank() - 1);
+  uint32_t groups = getGroups();
+
+  if (inputChannels % groups != 0) {
+    return emitOpError() << "Number of input channels from input tensor must "
+                            "be divisible by the number of groups. "
+                         << "Got " << inputChannels << " input channels and "
+                         << groups << " groups.";
+  }
" + << "Got " << outputChannels << " output channels and " + << groups << " groups."; + } + + if (inputChannels != kernelShape[0]) { + return emitOpError() << "Number of input channels from input tensor must " + "match the first dimension of the weight tensor. " + << "Got " << inputChannels << " input channels and " + << kernelShape[0] << " in the weight tensor."; + } + + if (outputChannels / groups != kernelShape[1]) { + return emitOpError() << "Number of output channels per group must match " + "the second dimension of the weight tensor. " + << "Got " << (outputChannels / groups) + << " output channels per group and " << kernelShape[1] + << " in the weight tensor."; + } + + if (bias) { + if (bias->getDimSize(bias->getRank() - 1) != outputChannels) { + return emitOpError() << "Mismatch in bias tensor dimensions. " + << "Bias tensor has " + << bias->getDimSize(bias->getRank() - 1) + << " channels, " + << "but the output tensor has " << outputChannels + << " channels."; + } + } + + int32_t kernelHeight = kernelShape[2]; + int32_t kernelWidth = kernelShape[3]; + + int32_t Hin = inputType.getDimSize(inputType.getRank() - 3); + int32_t Win = inputType.getDimSize(inputType.getRank() - 2); + + int32_t expectedHOut = (Hin - 1) * stride[0] - 2 * padding[0] + + dilation[0] * (kernelHeight - 1) + outputPadding[0] + 1; + int32_t expectedWOut = (Win - 1) * stride[1] - 2 * padding[1] + + dilation[1] * (kernelWidth - 1) + outputPadding[1] + 1; + if (expectedHOut < 0 || expectedWOut < 0) { + return emitOpError() << "Given input size per channel: (" << Hin << " x " + << Win << "). " + << "Calculated output size per channel: (" << expectedHOut + << " x " << expectedWOut << "). " + << "Output size is too small"; + } + + int32_t HOut = outputType.getDimSize(outputType.getRank() - 3); + int32_t WOut = outputType.getDimSize(outputType.getRank() - 2); + if (HOut != expectedHOut || WOut != expectedWOut) { + return emitOpError() << "Mismatch between expected output size per channel " + "and got output tensor dimensions. " + << "Expected: (" << expectedHOut << " x " << expectedWOut << "), " + << "got: (" << HOut << " x " << WOut + << ")."; + } + + return success(); +} + //===----------------------------------------------------------------------===// // ConvolutionOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index cca75a7b26..372dc3064e 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -9,7 +9,9 @@ #include "ttmlir/Dialect/TTNN/Utils/Utils.h" #include "ttmlir/Utils.h" +#include #include +#include #include "mlir/Dialect/Traits.h" #include "mlir/IR/BuiltinTypes.h" @@ -79,6 +81,166 @@ ::mlir::LogicalResult mlir::tt::ttnn::Conv2dOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ConvTranspose2dOp +//===----------------------------------------------------------------------===// + +// ConvTranspose2dOp verification +::mlir::LogicalResult mlir::tt::ttnn::ConvTranspose2dOp::verify() { + mlir::RankedTensorType inputType = getInput().getType(); + mlir::RankedTensorType weightType = getWeight().getType(); + mlir::RankedTensorType outputType = getOutput().getType(); + std::optional bias = + getBias().getImpl() ? 
+
+  if (inputType.getRank() != 4) {
+    return emitOpError("Input must be a 4D tensor");
+  }
+
+  if (outputType.getRank() != 4) {
+    return emitOpError("Output must be a 4D tensor");
+  }
+
+  if (weightType.getRank() != 4) {
+    return emitOpError("Weight must be a 4D tensor");
+  }
+
+  if (bias.has_value()) {
+    if (bias->getRank() != 4) {
+      return emitOpError("Bias must be a 4D tensor");
+    }
+  }
+
+  auto checkBiggerThan = [&](llvm::ArrayRef<int32_t> &values,
+                             const char *name,
+                             int32_t minValue) -> mlir::LogicalResult {
+    for (int32_t value : values) {
+      if (value < minValue) {
+        return emitOpError() << "Attribute '" << name
+                             << "' contains a value less than: " << minValue;
+      }
+    }
+    return mlir::success();
+  };
+
+  uint32_t inChannels = getInChannels();
+  if (inChannels != inputType.getDimSize(inputType.getRank() - 1)) {
+    return emitOpError("Input channels attribute must match "
+                       "the last dimension of the input tensor");
+  }
+
+  uint32_t outChannels = getOutChannels();
+  if (outChannels != outputType.getDimSize(outputType.getRank() - 1)) {
+    return emitOpError("Output channels attribute must match "
+                       "the last dimension of the output tensor");
+  }
+
+  uint32_t batchSize = getBatchSize();
+  if (batchSize != inputType.getDimSize(0)) {
+    return emitOpError("Batch size attribute must match the first "
+                       "dimension of the input tensor");
+  }
+
+  uint32_t inputHeight = getInputHeight();
+  if (inputHeight != inputType.getDimSize(inputType.getRank() - 3)) {
+    return emitOpError("Input height attribute must match the second "
+                       "dimension of the input tensor");
+  }
+
+  uint32_t inputWidth = getInputWidth();
+  if (inputWidth != inputType.getDimSize(inputType.getRank() - 2)) {
+    return emitOpError("Input width attribute must match the third "
+                       "dimension of the input tensor");
+  }
+
+  llvm::ArrayRef<int32_t> stride = getStride();
+  if (failed(checkBiggerThan(stride, "stride", 1))) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<int32_t> padding = getPadding();
+  if (failed(checkBiggerThan(padding, "padding", 0))) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<int32_t> outputPadding = getOutputPadding();
+  if (failed(checkBiggerThan(outputPadding, "output padding", 0))) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<int32_t> dilation = getDilation();
+  if (failed(checkBiggerThan(dilation, "dilation", 1))) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<int64_t> kernelShape = weightType.getShape();
+
+  int32_t inputChannels = inputType.getDimSize(inputType.getRank() - 1);
+  int32_t outputChannels = outputType.getDimSize(outputType.getRank() - 1);
+  uint32_t groups = getGroups();
+
+  if (inputChannels % groups != 0) {
+    return emitOpError() << "Number of input channels from input tensor must "
+                            "be divisible by the number of groups. "
+                         << "Got " << inputChannels << " input channels and "
+                         << groups << " groups.";
+  }
+
+  if (outputChannels % groups != 0) {
+    return emitOpError() << "Number of output channels from output tensor "
+                            "must be divisible by the number of groups. "
+                         << "Got " << outputChannels << " output channels and "
+                         << groups << " groups.";
+  }
+
+  if (inputChannels != kernelShape[0]) {
+    return emitOpError() << "Number of input channels from input tensor must "
+                            "match the first dimension of the weight tensor. "
+                         << "Got " << inputChannels << " input channels and "
+                         << kernelShape[0] << " in the weight tensor.";
+  }
" + << "Got " << inputChannels << " input channels and " + << kernelShape[0] << " in the weight tensor."; + } + + if (outputChannels / groups != kernelShape[1]) { + return emitOpError() << "Number of output channels per group must match " + "the second dimension of the weight tensor. " + << "Got " << (outputChannels / groups) + << " output channels per group and " << kernelShape[1] + << " in the weight tensor."; + } + + if (bias) { + if (bias->getDimSize(bias->getRank() - 1) != outputChannels) { + return emitOpError() << "Mismatch in bias tensor dimensions. " + << "Bias tensor has " + << bias->getDimSize(bias->getRank() - 1) + << " channels, " + << "but the output tensor has " << outputChannels + << " channels."; + } + } + + int32_t kernelHeight = kernelShape[2]; + int32_t kernelWidth = kernelShape[3]; + + int32_t Hin = inputType.getDimSize(inputType.getRank() - 3); + int32_t Win = inputType.getDimSize(inputType.getRank() - 2); + + int32_t expectedHOut = (Hin - 1) * stride[0] - 2 * padding[0] + + dilation[0] * (kernelHeight - 1) + outputPadding[0] + 1; + int32_t expectedWOut = (Win - 1) * stride[1] - 2 * padding[1] + + dilation[1] * (kernelWidth - 1) + outputPadding[1] + 1; + if (expectedHOut < 0 || expectedWOut < 0) { + return emitOpError() << "Given input size per channel: (" << Hin << " x " + << Win << "). " + << "Calculated output size per channel: (" << expectedHOut + << " x " << expectedWOut << "). " + << "Output size is too small"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // MaxPool2dOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index 9036346a43..21ab001eeb 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Dialect/TT/Utils/OperandConstraints.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" #include "ttmlir/Dialect/TTNN/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -306,6 +307,9 @@ class TTNNLayoutDPSOperandsRewriter if (mlir::isa(op.getOperation()) && !isResult) { continue; } + else if (mlir::isa(op.getOperation()) && !isResult) { + continue; + } // If the operand is a BroadcastOp or a ToLayout op do not put a // ToLayoutOp on its output diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index d0d65ad874..781ab84ad3 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -414,6 +414,35 @@ createOp(FlatbufferObjectCache &cache, Conv2dOp op) { op.getGroups()); } +::flatbuffers::Offset<::tt::target::ttnn::ConvTranspose2dOp> +createOp(FlatbufferObjectCache &cache, ConvTranspose2dOp op) { + auto in0 = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); + auto in1 = cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getWeight())); + auto in2 = op.getODSOperands(2).empty() + ? 
+  auto in2 = op.getODSOperands(2).empty()
+                 ? flatbuffers::Offset<::tt::target::TensorRef>()
+                 : cache.at<::tt::target::TensorRef>(
+                       getOperandThroughDPSOps(op.getBias()));
+  auto output = cache.at<::tt::target::TensorRef>(
+      getOperandThroughDPSOps(op.getResult()));
+
+  auto device = getOperandThroughDPSOps(op.getDevice());
+
+  auto kernelSize = toFlatbuffer(cache, op.getKernelSize());
+  auto stride = toFlatbuffer(cache, op.getStride());
+  auto padding = toFlatbuffer(cache, op.getPadding());
+  auto outputPadding = toFlatbuffer(cache, op.getOutputPadding());
+  auto dilation = toFlatbuffer(cache, op.getDilation());
+
+  return ::tt::target::ttnn::CreateConvTranspose2dOp(
+      *cache.fbb, in0, in1, in2, output,
+      cache.at<::tt::target::DeviceRef>(device), op.getInChannels(),
+      op.getOutChannels(), op.getBatchSize(), op.getInputHeight(),
+      op.getInputWidth(), kernelSize, stride, padding, outputPadding, dilation,
+      op.getGroups());
+}
+
 ::flatbuffers::Offset<::tt::target::ttnn::AllGatherOp>
 createOp(FlatbufferObjectCache &cache, AllGatherOp op) {
   auto input =
@@ -942,6 +971,11 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createOp(cache, conv2dOp), debugString,
                            locInfo);
   }
+  if (auto convTranspose2dOp = dyn_cast<ConvTranspose2dOp>(op);
+      convTranspose2dOp) {
+    return createOperation(cache, createOp(cache, convTranspose2dOp),
+                           debugString, locInfo);
+  }
   if (auto allGatherOp = dyn_cast<AllGatherOp>(op); allGatherOp) {
     return createOperation(cache, createOp(cache, allGatherOp), debugString,
                            locInfo);
diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt
index d7d9357b5f..8d0d3f68e0 100644
--- a/runtime/lib/ttnn/operations/CMakeLists.txt
+++ b/runtime/lib/ttnn/operations/CMakeLists.txt
@@ -6,6 +6,7 @@ set(TTNN_OPS_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv_transpose2d.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/concat.cpp
diff --git a/runtime/lib/ttnn/operations/conv/conv_transpose2d.cpp b/runtime/lib/ttnn/operations/conv/conv_transpose2d.cpp
new file mode 100644
index 0000000000..e3341534d8
--- /dev/null
+++ b/runtime/lib/ttnn/operations/conv/conv_transpose2d.cpp
@@ -0,0 +1,60 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "conv_transpose2d.h"
+#include "tt/runtime/detail/logger.h"
+#include "tt/runtime/detail/ttnn.h"
+#include "tt/runtime/ttnn/operations/utils.h"
+#include "tt/runtime/ttnn/utils.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+#include "ttnn/types.hpp"
+#include <optional>
+
+namespace tt::runtime::ttnn::operations::conv {
+void run(const ::tt::target::ttnn::ConvTranspose2dOp *op,
+         ProgramContext &context) {
+  ProgramTensorPool &tensorPool = context.getTensorPool();
+  const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id());
+  const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id());
+  DEBUG_ASSERT(input.is_allocated());
+  DEBUG_ASSERT(weight.is_allocated());
+
+  std::optional<::ttnn::Tensor> bias =
+      op->bias() ? std::make_optional(tensorPool.at(op->bias()->global_id()))
+                 : std::nullopt;
+
+  auto copyArray = [](auto source, auto &destination) {
+    std::copy(source->begin(), source->end(), destination.begin());
+  };
+
+  std::array<uint32_t, 2> kernelSize, stride, padding, outputPadding, dilation;
+  copyArray(op->kernel_size(), kernelSize);
+  copyArray(op->stride(), stride);
+  copyArray(op->padding(), padding);
+  copyArray(op->output_padding(), outputPadding);
+  copyArray(op->dilation(), dilation);
+
+  auto config = ::ttnn::operations::conv::Conv2dConfig();
+  config.dtype = utils::getDataType(op->input());
+  config.weights_dtype = utils::getDataType(op->weight());
+  config.shard_layout = ::ttnn::TensorMemoryLayout::WIDTH_SHARDED;
+  ::ttnn::MemoryConfig outMemConfig =
+      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
+
+  DeviceVariant targetDevice =
+      context.getTargetDevice(op->device()->global_id());
+  ::ttnn::Tensor out = std::visit(
+      [&](auto &&targetDevice) -> ::ttnn::Tensor {
+        return std::get<0>(::ttnn::conv_transpose2d(
+            ::ttnn::DefaultQueueId, input, weight, &(targetDevice.get()),
+            op->in_channels(), op->out_channels(), op->batch_size(),
+            op->input_height(), op->input_width(), kernelSize, stride, padding,
+            outputPadding, dilation, op->groups(), bias, config));
+      },
+      targetDevice);
+
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+}
+
+} // namespace tt::runtime::ttnn::operations::conv
diff --git a/runtime/lib/ttnn/operations/conv/conv_transpose2d.h b/runtime/lib/ttnn/operations/conv/conv_transpose2d.h
new file mode 100644
index 0000000000..c98f6b0142
--- /dev/null
+++ b/runtime/lib/ttnn/operations/conv/conv_transpose2d.h
@@ -0,0 +1,16 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CONV_CONVTRANSPOSE2D_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_CONV_CONVTRANSPOSE2D_H
+
+#include "tt/runtime/ttnn/types.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+
+namespace tt::runtime::ttnn::operations::conv {
+void run(const ::tt::target::ttnn::ConvTranspose2dOp *op,
+         ProgramContext &context);
+
+} // namespace tt::runtime::ttnn::operations::conv
+
+#endif
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index f38bfe83ce..d8de861a93 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -4,6 +4,7 @@
 #include "operations/ccl/all_gather.h"
 #include "operations/context/get_device.h"
 #include "operations/conv/conv2d.h"
+#include "operations/conv/conv_transpose2d.h"
 #include "operations/creation/arange.h"
 #include "operations/creation/empty.h"
 #include "operations/creation/full.h"
@@ -202,6 +203,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) {
   case ::tt::target::ttnn::OpType::Conv2dOp: {
     return operations::conv::run(op->type_as_Conv2dOp(), context);
   }
+  case ::tt::target::ttnn::OpType::ConvTranspose2dOp: {
+    return operations::conv::run(op->type_as_ConvTranspose2dOp(), context);
+  }
   case ::tt::target::ttnn::OpType::DeallocateOp: {
     return operations::deletion::run(op->type_as_DeallocateOp(), context);
   }
diff --git a/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_negative.mlir
new file mode 100644
index 0000000000..0ddfad2a94
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_negative.mlir
@@ -0,0 +1,405 @@
+// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
+// Negative
tests for conv_transpose2d operation + +// Verify that the parsing fails if tensors don't have four dimensions +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_input_shape(%arg0: tensor<8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Input must be a 4D tensor + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_weight_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x8x8x256xbf16> { + %0 = tensor.empty() : tensor<1x8x8x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Weight must be a 4D tensor + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x8x8x256xbf16>) -> tensor<1x8x8x256xbf16> + return %1 : tensor<1x8x8x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_bias_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<256xbf16>) -> tensor<1x8x8x256xbf16> { + %0 = tensor.empty() : tensor<1x8x8x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Bias must be a 4D tensor + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<256xbf16>, tensor<1x8x8x256xbf16>) -> tensor<1x8x8x256xbf16> + return %1 : tensor<1x8x8x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_output_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<10x10x256xbf16> { + %0 = tensor.empty() : tensor<10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Output must be a 4D tensor + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<10x10x256xbf16>) -> tensor<10x10x256xbf16> + return %1 : tensor<10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_output_shape(%arg0: tensor<4x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<2x10x10x256xbf16> { + %0 = tensor.empty() : tensor<2x10x10x256xbf16> + // CHECK: error: 
'ttir.conv_transpose2d' op Batch size of input and output tensors must match + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<4x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<2x10x10x256xbf16>) -> tensor<2x10x10x256xbf16> + return %1 : tensor<2x10x10x256xbf16> + } +} + +// Verify that the parsing fails if attributes are not integers or pair of integers +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_stride_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Expected int or pair of ints for stride + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = array, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_padding_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Expected int or pair of ints for padding + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = array, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_output_padding_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Expected int or pair of ints for output padding + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = array, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_dilation_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Expected int or pair of ints for dilation + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + 
dilation = array, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// Verify that the parsing fails if attributes have invalid values +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_stride_values(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Attribute 'stride' contains a value less than: 1 + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = array, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_padding_values(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Attribute 'padding' contains a value less than: 0 + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = array, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_output_padding_values(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Attribute 'output padding' contains a value less than: 0 + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = -6: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_invalid_dilation_values(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Attribute 'dilation' contains a value less than: 1 + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = array, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, 
tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// Verify the parsing fails if number of channels are incorrect +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_input_channels_not_divisible_by_groups(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Number of input channels from input tensor must be divisible by the number of groups. Got 256 input channels and 3 groups + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 3: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_output_channels_not_divisible_by_groups(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x350x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x350xbf16> { + %0 = tensor.empty() : tensor<1x10x10x350xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Number of output channels from output tensor must be divisible by the number of groups. Got 350 output channels and 4 groups. + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 4: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<256x350x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x350xbf16>) -> tensor<1x10x10x350xbf16> + return %1 : tensor<1x10x10x350xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_input_channels_missmatch_with_weight(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<128x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Number of input channels from input tensor must match the first dimension of the weight tensor. Got 256 input channels and 128 in the weight tensor. + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<1x8x8x256xbf16>, tensor<128x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16> + return %1 : tensor<1x10x10x256xbf16> + } +} + +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @conv_transpose2d_output_channels_missmatch_with_weight(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> { + %0 = tensor.empty() : tensor<1x10x10x256xbf16> + // CHECK: error: 'ttir.conv_transpose2d' op Number of output channels per group must match the second dimension of the weight tensor. Got 64 output channels per group and 256 in the weight tensor. 
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = 1: i32,
+            padding = 0: i32,
+            output_padding = 0: i32,
+            dilation = 1: i32,
+            groups = 4: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @conv_transpose2d_output_channels_missmatch_with_bias(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x128xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Mismatch in bias tensor dimensions. Bias tensor has 128 channels, but the output tensor has 256 channels.
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = 1: i32,
+            padding = 0: i32,
+            output_padding = 0: i32,
+            dilation = 1: i32,
+            groups = 1: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x128xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// Verify the parsing fails if calculated output size per channel is below zero or different from the output tensor
+// -----
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @conv_transpose2d_calculated_output_size_per_channel_below_zero(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Given input size per channel: (8 x 8). Calculated output size per channel: (-2 x -4). Output size is too small
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = 1: i32,
+            padding = array<i32: 6, 7>,
+            output_padding = 0: i32,
+            dilation = 1: i32,
+            groups = 1: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @conv_transpose2d_calculated_output_size_per_channel_missmatch_with_output_tensor(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x2x2x256xbf16> {
+    %0 = tensor.empty() : tensor<1x2x2x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Mismatch between expected output size per channel and the output tensor dimensions. Expected: (10 x 10), got: (2 x 2).
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = 1: i32,
+            padding = 0: i32,
+            output_padding = 0: i32,
+            dilation = 1: i32,
+            groups = 1: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x2x2x256xbf16>) -> tensor<1x2x2x256xbf16>
+    return %1 : tensor<1x2x2x256xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_positive.mlir
new file mode 100644
index 0000000000..3736518ebe
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_positive.mlir
@@ -0,0 +1,109 @@
+// RUN: ttmlir-opt %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+
+module attributes {} {
+  func.func @conv_transpose2d_simple(%arg0: tensor<4x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<4x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<4x10x10x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = 1: i32,
+            padding = 0: i32,
+            output_padding = 0: i32,
+            dilation = 1: i32,
+            groups = 1: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<4x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<4x10x10x256xbf16>) -> tensor<4x10x10x256xbf16>
+    return %1 : tensor<4x10x10x256xbf16>
+  }
+
+  func.func @conv_transpose2d_stride(%arg0: tensor<1x16x32x256xbf16>, %arg1: tensor<256x256x8x8xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x38x132x256xbf16> {
+    %0 = tensor.empty() : tensor<1x38x132x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = array<i32: 2, 4>,
+            padding = 0: i32,
+            output_padding = 0: i32,
+            dilation = 1: i32,
+            groups = 1: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x16x32x256xbf16>, tensor<256x256x8x8xbf16>, tensor<1x1x1x256xbf16>, tensor<1x38x132x256xbf16>) -> tensor<1x38x132x256xbf16>
+    return %1 : tensor<1x38x132x256xbf16>
+  }
+
+  func.func @conv_transpose2d_padding(%arg0: tensor<1x64x64x256xbf16>, %arg1: tensor<256x256x16x16xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x73x67x256xbf16> {
+    %0 = tensor.empty() : tensor<1x73x67x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = 1: i32,
+            padding = array<i32: 3, 6>,
+            output_padding = 0: i32,
+            dilation = 1: i32,
+            groups = 1: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x64x64x256xbf16>, tensor<256x256x16x16xbf16>, tensor<1x1x1x256xbf16>, tensor<1x73x67x256xbf16>) -> tensor<1x73x67x256xbf16>
+    return %1 : tensor<1x73x67x256xbf16>
+  }
+
+  func.func @conv_transpose2d_output_padding(%arg0: tensor<1x32x32x128xbf16>, %arg1: tensor<128x256x8x8xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x45x47x256xbf16> {
+    %0 = tensor.empty() : tensor<1x45x47x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = 1: i32,
+            padding = 0: i32,
+            output_padding = array<i32: 6, 8>,
+            dilation = 1: i32,
+            groups = 1: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x32x32x128xbf16>, tensor<128x256x8x8xbf16>, tensor<1x1x1x256xbf16>, tensor<1x45x47x256xbf16>) -> tensor<1x45x47x256xbf16>
+    return %1 : tensor<1x45x47x256xbf16>
+  }
+
+  func.func @conv_transpose2d_dilation(%arg0: tensor<1x32x32x128xbf16>, %arg1: tensor<128x256x16x32xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x77x94x256xbf16> {
+    %0 = tensor.empty() : tensor<1x77x94x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = 1: i32,
+            padding = 0: i32,
+            output_padding = 0: i32,
+            dilation = array<i32: 3, 2>,
+            groups = 1: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x32x32x128xbf16>, tensor<128x256x16x32xbf16>, tensor<1x1x1x256xbf16>, tensor<1x77x94x256xbf16>) -> tensor<1x77x94x256xbf16>
+    return %1 : tensor<1x77x94x256xbf16>
+  }
+
+  func.func @conv_transpose2d_groups(%arg0: tensor<1x16x32x192xbf16>, %arg1: tensor<192x126x8x8xbf16>, %arg2: tensor<1x1x1x252xbf16>) -> tensor<1x23x39x252xbf16> {
+    %0 = tensor.empty() : tensor<1x23x39x252xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = 1: i32,
+            padding = 0: i32,
+            output_padding = 0: i32,
+            dilation = 1: i32,
+            groups = 2: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x16x32x192xbf16>, tensor<192x126x8x8xbf16>, tensor<1x1x1x252xbf16>, tensor<1x23x39x252xbf16>) -> tensor<1x23x39x252xbf16>
+    return %1 : tensor<1x23x39x252xbf16>
+  }
+
+  func.func @conv_transpose2d(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x64x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x21x38x256xbf16> {
+    %0 = tensor.empty() : tensor<1x21x38x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+          <{
+            stride = array<i32: 3, 5>,
+            padding = array<i32: 2, 1>,
+            output_padding = array<i32: 1, 0>,
+            dilation = array<i32: 1, 2>,
+            groups = 4: i32,
+            operand_constraints = [#any_device, #any_device, #any_device, #any_device]}
+          > : (tensor<1x8x8x256xbf16>, tensor<256x64x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x21x38x256xbf16>) -> tensor<1x21x38x256xbf16>
+    return %1 : tensor<1x21x38x256xbf16>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv_transpose2d.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv_transpose2d.mlir
new file mode 100644
index 0000000000..5df3c6b631
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv_transpose2d.mlir
@@ -0,0 +1,21 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN:
FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint + +module attributes {} { + func.func @forward(%arg0: tensor<3x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<3x10x10x256xbf16> { + %0 = tensor.empty() : tensor<3x10x10x256xbf16> + // CHECK: %[[C:.*]] = "ttnn.conv_transpose2d"[[C:.*]] + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<3x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<3x10x10x256xbf16>) -> tensor<3x10x10x256xbf16> + return %1 : tensor<3x10x10x256xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/simple_conv_transpose2d.mlir b/test/ttmlir/Silicon/TTNN/simple_conv_transpose2d.mlir new file mode 100644 index 0000000000..5df3c6b631 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/simple_conv_transpose2d.mlir @@ -0,0 +1,21 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint + +module attributes {} { + func.func @forward(%arg0: tensor<3x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<3x10x10x256xbf16> { + %0 = tensor.empty() : tensor<3x10x10x256xbf16> + // CHECK: %[[C:.*]] = "ttnn.conv_transpose2d"[[C:.*]] + %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0) + <{ + stride = 1: i32, + padding = 0: i32, + output_padding = 0: i32, + dilation = 1: i32, + groups = 1: i32, + operand_constraints = [#any_device, #any_device, #any_device, #any_device]} + > : (tensor<3x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<3x10x10x256xbf16>) -> tensor<3x10x10x256xbf16> + return %1 : tensor<3x10x10x256xbf16> + } +}
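
A short standalone sketch (illustrative, not part of the patch) of the
output-size rule that the verifiers and the tests above rely on; any C++17
compiler will do:

    #include <array>
    #include <cassert>
    #include <cstdint>

    // Mirrors the expectedHOut/expectedWOut computation in the verifiers above.
    static std::array<int64_t, 2>
    convTranspose2dOutputDims(std::array<int64_t, 2> in,
                              std::array<int64_t, 2> kernel,
                              std::array<int64_t, 2> stride,
                              std::array<int64_t, 2> padding,
                              std::array<int64_t, 2> outputPadding,
                              std::array<int64_t, 2> dilation) {
      std::array<int64_t, 2> out;
      for (int i = 0; i < 2; ++i) {
        out[i] = (in[i] - 1) * stride[i] - 2 * padding[i] +
                 dilation[i] * (kernel[i] - 1) + outputPadding[i] + 1;
      }
      return out;
    }

    int main() {
      // The silicon test above: an 8x8 input with a 3x3 kernel and default
      // attributes must produce a 10x10 output.
      auto out = convTranspose2dOutputDims({8, 8}, {3, 3}, {1, 1}, {0, 0},
                                           {0, 0}, {1, 1});
      assert(out[0] == 10 && out[1] == 10);
      return 0;
    }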