From 14cd5d05bfaa55cae1e8cf68ddaf5e1fea35b304 Mon Sep 17 00:00:00 2001 From: Lewis Panos Date: Tue, 12 Nov 2024 10:28:50 -0500 Subject: [PATCH] Create new general pooling op (#1154) Add decomposition pattern that converts it to maxpool2d. Use output memory config attribute in runtime, add silicon tests --- include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt | 6 + include/ttmlir/Dialect/TTIR/IR/TTIROps.h | 2 + include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 31 +- .../ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td | 4 + .../ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td | 21 + .../ttmlir/Dialect/TTIR/Transforms/Passes.td | 7 - .../StableHLOToTTIR/StableHLOToTTIRPass.cpp | 1 - .../StableHLOToTTIRPatterns.cpp | 244 +++++--- .../TTIRToTTIRDecomposition.cpp | 563 ++++++++++++------ .../TTIRToTTIRDecompositionPass.cpp | 1 + lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 94 +-- lib/Dialect/TTIR/IR/TTIRDialect.cpp | 2 + lib/Dialect/TTIR/IR/TTIROps.cpp | 54 +- lib/Dialect/TTIR/Transforms/Transforms.cpp | 137 ----- lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 1 - .../lib/ttnn/operations/pool/maxpool2d.cpp | 5 +- .../StableHLOToTTIR/maxpool2d_op.mlir | 2 +- .../Dialect/TTNN/pooling/complex_pooling.mlir | 19 + .../TTNN/{ => pooling}/simple_maxpool2d.mlir | 0 .../Dialect/TTNN/pooling/simple_pooling.mlir | 15 + .../Silicon/TTNN/pooling/complex_pooling.mlir | 18 + .../Silicon/TTNN/pooling/simple_pooling.mlir | 17 + 22 files changed, 747 insertions(+), 497 deletions(-) create mode 100644 include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td create mode 100644 test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir rename test/ttmlir/Dialect/TTNN/{ => pooling}/simple_maxpool2d.mlir (100%) create mode 100644 test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir create mode 100644 test/ttmlir/Silicon/TTNN/pooling/complex_pooling.mlir create mode 100644 test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt b/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt index e04bf3f3e..15bf10223 100644 --- a/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt +++ b/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt @@ -8,6 +8,12 @@ mlir_tablegen(TTIROpsAttrs.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(TTIROpsAttrsIncGen) add_dependencies(mlir-headers TTIROpsAttrsIncGen) +set(LLVM_TARGET_DEFINITIONS TTIROpsEnums.td) +mlir_tablegen(TTIROpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(TTIROpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRTTIROpsEnumsIncGen) +add_dependencies(mlir-headers MLIRTTIROpsEnumsIncGen) + set(LLVM_TARGET_DEFINITIONS TTIROpsInterfaces.td) mlir_tablegen(TTIROpsInterfaces.h.inc -gen-op-interface-decls) mlir_tablegen(TTIROpsInterfaces.cpp.inc -gen-op-interface-defs) diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h index cd102dbb9..f23fd6d88 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h @@ -16,6 +16,8 @@ #include "TTIROpsInterfaces.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.h.inc" + #define GET_ATTRDEF_CLASSES #include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.h.inc" diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index ce9b69c63..cd62b8289 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -727,6 +727,33 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> { }]; } +def TTIR_PoolingOp : TTIR_DPSOp<"pooling", [AttrSizedOperandSegments]> { + let summary = "General pooling 
op"; + let description = [{ + General pooling op + }]; + + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs, + TTIR_PoolingMethodAttr:$pooling_method, + DenseI64ArrayAttr:$window_dimensions, + + // Default stride of 1 over every dimension + DefaultValuedOptionalAttr(getWindowDimensions().size(), 1)">:$window_strides, + // Default dilation of 1 over every dimension + DefaultValuedOptionalAttr(getWindowDimensions().size(), 1)">:$base_dilations, + // Default dilation of 1 over every dimension + DefaultValuedOptionalAttr(getWindowDimensions().size(), 1)">:$window_dilations, + // Default padding of 0 over every dimension + DefaultValuedOptionalAttr(getWindowDimensions().size() * 2, 0)">:$padding, + TT_OperandConstraintArrayAttr:$operand_constraints + ); + + let results = (outs Variadic); + + let hasVerifier = 1; +} def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> { let summary = "Applies a 2D max pooling over an input signal composed of several input planes."; @@ -747,9 +774,7 @@ def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> { SI32Attr:$padding_right, SI32Attr:$padding_top, SI32Attr:$padding_bottom, - TT_OperandConstraintArrayAttr:$operand_constraints, - OptionalAttr:$original_height, - OptionalAttr:$original_width); + TT_OperandConstraintArrayAttr:$operand_constraints); let results = (outs AnyRankedTensor:$result); diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td index 60943af26..0c0e306bc 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td @@ -6,8 +6,12 @@ #define TTMLIR_TTIR_ATTRS_TD include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" +include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td" include "ttmlir/Dialect/TTIR/IR/TTIRBase.td" +def TTIR_PoolingMethodAttr : EnumAttr; + def TTIR_ConvolutionLayoutAttr : AttrDef { let mnemonic = "convolution_layout"; let summary = "Structure of dimension information for convolution op"; diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td new file mode 100644 index 000000000..35cce50d4 --- /dev/null +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_TTIR_ENUMS_TD +#define TTMLIR_TTIR_ENUMS_TD + +include "mlir/IR/EnumAttr.td" + +def TTIR_PoolingMethodAverage : I32EnumAttrCase<"Average", 0>; +def TTIR_PoolingMethodMax : I32EnumAttrCase<"Max", 1>; + +def TTIR_PoolingMethod : I32EnumAttr<"PoolingMethod", "TTIR PoolingMethod", [ + TTIR_PoolingMethodAverage, + TTIR_PoolingMethodMax + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tt::ttir"; +} + +#endif diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 1cee4cbb5..63ccb0d28 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -70,13 +70,6 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> { ]; } -def TTIRSlidingWindow2dFixShapes: Pass<"ttir-sliding-window-2d-fix-shapes", "::mlir::ModuleOp"> { - let summary = "Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N,H,W on the input: i.e (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e (1, 1, N*H*W, C) --> (N, H, W, C)"; - let description = [{ - Insert reshapes on the input and output of 2-dimensional sliding window ops that 
collapse N,H,W on the input: i.e (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e (1, 1, N*H*W, C) --> (N, H, W, C) - }]; -} - def TTIRSplitCompoundLayout: Pass<"ttir-split-compound-layout", "::mlir::ModuleOp"> { let summary = "Split compound layouts."; let description = [{ diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp index b6c0b4988..587d63bc4 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp @@ -101,7 +101,6 @@ struct ConvertStableHLOToTTIRPass [&](func::CallOp op) { return typeConverter.isLegal(op); }); populateStableHLOToTTIRPatterns(&getContext(), patterns, typeConverter); - // Apply conversion. if (failed( applyFullConversion(getOperation(), target, std::move(patterns)))) { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index e6c448aaa..9d9fbee39 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 +#include +#include #include #include "mlir/Dialect/Traits.h" @@ -497,24 +499,43 @@ class StableHLOToTTIRReduceWindowOpConversionPattern rewriter.eraseOp(op); } - bool isMaxPool2d(mlir::stablehlo::ReduceWindowOp &srcOp) const { + bool isMaxPool(mlir::stablehlo::ReduceWindowOp &srcOp) const { if (srcOp.getBody().getBlocks().size() != 1) { return false; } + // Find constant input(s) + Operation *initValue; + for (uint64_t i = 0; i < srcOp.getInitValues().size(); i++) { + initValue = srcOp.getInitValues()[i].getDefiningOp(); + while (initValue->getOpOperands().size() == 1) { + initValue = initValue->getOpOperand(0).get().getDefiningOp(); + } + if (!isa(initValue)) { + return false; + } + + stablehlo::ConstantOp initValueOp = + mlir::cast(initValue); + + if (!checkInitValue(initValueOp, TypicalInitReductionValue::NEG_INF)) { + return false; + } + } + Block &block = *srcOp.getBody().getBlocks().begin(); - uint32_t op_idx = 0; + uint32_t opIdx = 0; for (Operation &op : block) { - if (op_idx == 0 && !isa(op)) { + if (opIdx == 0 && !isa(op)) { return false; } - if (op_idx == 1 && !isa(op)) { + if (opIdx == 1 && !isa(op)) { return false; } - if (op_idx >= 2) { + if (opIdx >= 2) { return false; // More than two ops in the block } - op_idx++; + opIdx++; } return true; @@ -525,105 +546,136 @@ class StableHLOToTTIRReduceWindowOpConversionPattern mlir::stablehlo::ReduceWindowOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (isMaxPool2d(srcOp)) { - - RankedTensorType outputType = mlir::cast( - getTypeConverter()->convertType(srcOp.getResult(0).getType())); + RankedTensorType outputType = mlir::cast( + getTypeConverter()->convertType(srcOp.getResult(0).getType())); + SmallVector outputsVec; + for (uint32_t i = 0; i < srcOp.getResults().size(); i++) { tensor::EmptyOp outputTensor = rewriter.create( srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + outputsVec.push_back(outputTensor); + } + ValueRange outputs = outputsVec; + + auto windowDimensions = adaptor.getWindowDimensionsAttr(); + auto windowStrides = adaptor.getWindowStridesAttr(); + auto baseDilations = adaptor.getBaseDilationsAttr(); + auto window_dilations = adaptor.getWindowDilationsAttr(); + auto padding_ = adaptor.getPaddingAttr(); + + // Generate defaults if they dont exist + windowStrides = 
windowStrides
+                        ? windowStrides
+                        : rewriter.getDenseI64ArrayAttr(
+                              SmallVector<int64_t>(windowDimensions.size(), 1));
+    baseDilations = baseDilations
+                        ? baseDilations
+                        : rewriter.getDenseI64ArrayAttr(
+                              SmallVector<int64_t>(windowDimensions.size(), 1));
+    window_dilations = window_dilations
+                           ? window_dilations
+                           : rewriter.getDenseI64ArrayAttr(SmallVector<int64_t>(
+                                 windowDimensions.size(), 1));
+    auto padding =
+        padding_ ? rewriter.getDenseI64ArrayAttr(
+                       SmallVector<int64_t>(padding_.getValues<int64_t>()))
+                 : rewriter.getDenseI64ArrayAttr(
+                       SmallVector<int64_t>(windowDimensions.size() * 2, 0));
+
+    auto operandConstraints = rewriter.getArrayAttr(SmallVector<Attribute>(
+        adaptor.getOperands().size(), rewriter.getAttr<OperandConstraintAttr>(
+                                          OperandConstraint::AnyDeviceTile)));
+
+    mlir::tt::ttir::PoolingMethod poolingMethod;
+    if (isMaxPool(srcOp)) {
+      poolingMethod = mlir::tt::ttir::PoolingMethod::Max;
+    } else {
+      return rewriter.notifyMatchFailure(srcOp, "Unsupported pooling method");
+    }
+
+    rewriter.replaceOpWithNewOp<ttir::PoolingOp>(
+        srcOp, outputType, adaptor.getInputs(), outputs, poolingMethod,
+        windowDimensions, windowStrides, baseDilations, window_dilations,
+        padding, operandConstraints);
+
+    return success();
+  }
+
+private:
+  // Just to make the code more readable
+  enum TypicalInitReductionValue {
+    NEG_INF, // used for max pooling
+    ZERO,    // used for sum pooling
+  };
+
+  // Using the value enum rather than actual values because of the different
+  // data types the init value could be
+  bool checkInitValue(stablehlo::ConstantOp initValueOp,
+                      TypicalInitReductionValue desired) const {
+    if (initValueOp.getValueAttr().size() != 1) {
+      return false;
+    }
 
-      // The generalized ReduceWindow allows for kernel_size, strides, dilation,
-      // and padding to act on all 4 input dimensions. Since we only support
-      // channel-last pooling, we select the middle two values for H and W.
-      // And fail if the others are not 1 (or 0 in the case of padding).
- std::vector window_dimensions = adaptor.getWindowDimensions(); - if (window_dimensions[0] != 1 || window_dimensions[3] != 1) { - return failure(); + float desiredF32; + double desiredF64; + uint16_t desiredBF16; + int32_t desiredI32; + int64_t desiredI64; + if (desired == TypicalInitReductionValue::NEG_INF) { + desiredF32 = -std::numeric_limits::infinity(); + desiredF64 = -std::numeric_limits::infinity(); + desiredBF16 = 0xff80; // This is -inf in bfloat16 raw bits + desiredI32 = std::numeric_limits::min(); + desiredI64 = std::numeric_limits::min(); + } else if (desired == TypicalInitReductionValue::ZERO) { + desiredF32 = 0.0; + desiredF64 = 0.0; + desiredBF16 = 0x0000; // This is 0 in bfloat16 raw bits + desiredI32 = 0; + desiredI64 = 0; + } else { + return false; + } + + // Constant operand must be -inf if this is to be a max pool + // since bfloat16 is not a type we acually have I must compare the raw + // bits + if (initValueOp.getResult().getType().getElementType().isBF16()) { + // Collect the values into a vector + std::vector values; + for (int64_t i = 0; i < initValueOp.getValueAttr().size(); ++i) { + values.push_back( + initValueOp.getValueAttr().getValues()[i]); + } + + auto denseValues = ::mlir::DenseElementsAttr::get( + initValueOp.getValueAttr().getShapedType(), values); + uint16_t bfloat_bits = + static_cast(*denseValues.getRawData().data()); + if (bfloat_bits != desiredBF16) { // This is -inf in bfloat16 + return false; } - IntegerAttr kernel_height_attr = rewriter.getSI32IntegerAttr( - static_cast(window_dimensions[1])); - IntegerAttr kernel_width_attr = rewriter.getSI32IntegerAttr( - static_cast(window_dimensions[2])); - - std::vector strides = - adaptor.getWindowStrides() - .value_or(ArrayRef({1, 1, 1, 1})) - .vec(); - - if (strides[0] != 1 || strides[3] != 1) { - return failure(); + } else if (initValueOp.getValue().getType().isF32()) { + if (*initValueOp.getValue().value_begin() != desiredF32) { + return false; } - IntegerAttr stride_height_attr = - rewriter.getSI32IntegerAttr(static_cast(strides[1])); - IntegerAttr stride_width_attr = - rewriter.getSI32IntegerAttr(static_cast(strides[2])); - - std::vector dilation = - adaptor.getBaseDilations() - .value_or(ArrayRef({1, 1, 1, 1})) - .vec(); - - if (dilation[0] != 1 || dilation[3] != 1) { - return failure(); + } else if (initValueOp.getValue().getType().isF64()) { + if (*initValueOp.getValue().value_begin() != desiredF64) { + return false; } - IntegerAttr dilation_height_attr = - rewriter.getSI32IntegerAttr(static_cast(dilation[1])); - IntegerAttr dilation_width_attr = - rewriter.getSI32IntegerAttr(static_cast(dilation[2])); - - // Padding here is in the form ((., .), (top, bottom), (left, right), (., - // .)) one for each of (N, H, W, C). Since we only support maxpool2d, the - // first and last padding tuples must be zero to be valid. This list is - // flattened so we can use a single iterator to get the values. - std::vector padding = {0, 0, 0, 0}; - if (adaptor.getPadding().has_value()) { - uint32_t pad_idx = 0; - for (auto iter = adaptor.getPadding()->value_begin(); - iter < adaptor.getPadding()->value_end(); iter++) { - - // TTIR requires left, right, top, bottom - if (pad_idx == 2) { - padding[2] = *iter; - } else if (pad_idx == 3) { - padding[3] = *iter; - } else if (pad_idx == 4) { - padding[0] = *iter; - } else if (pad_idx == 5) { - padding[1] = *iter; - } else if (*iter != 0) { - // Padding on the channel or batch is > 1. TTIR/TTNN does not - // support this. 
- return failure(); - } - pad_idx++; - } + } else if (initValueOp.getValue().getType().isInteger(32)) { + if (*initValueOp.getValue().value_begin() != desiredI32) { + return false; + } + } else if (initValueOp.getValue().getType().isInteger(64)) { + if (*initValueOp.getValue().value_begin() != desiredI64) { + return false; } - ::llvm::ArrayRef input_shape = - mlir::cast(adaptor.getInputs()[0].getType()) - .getShape(); - - // Dead ttir.constant sticks around and fails verification. Removing it - // like so since its behind another op - recursiveErase(rewriter, adaptor.getInitValues()[0].getDefiningOp()); - rewriter.replaceOpWithNewOp( - srcOp, outputType, srcOp.getInputs()[0], outputTensor, - kernel_height_attr, kernel_width_attr, stride_height_attr, - stride_width_attr, dilation_height_attr, dilation_width_attr, - rewriter.getBoolAttr(false), rewriter.getSI32IntegerAttr(padding[0]), - rewriter.getSI32IntegerAttr(padding[1]), - rewriter.getSI32IntegerAttr(padding[2]), - rewriter.getSI32IntegerAttr(padding[3]), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile))), - rewriter.getSI32IntegerAttr(input_shape[1]), - rewriter.getSI32IntegerAttr(input_shape[2])); - - return success(); + } else { + return false; } - return failure(); + + return true; } }; diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index f3ecd8902..909420702 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -4,26 +4,18 @@ #include "ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h" -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" -#include "ttmlir/Dialect/TTNN/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Traits.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/LogicalResult.h" -#include -#include + +#include using namespace mlir; using namespace mlir::tt; @@ -115,160 +107,120 @@ enum ConvolutionKernelDimension { INVALID_KERNEL_DIM = -3 }; -static tensor::EmptyOp generateTransposeDPSOutput(Value input, int64_t dim0, - int64_t dim1, - PatternRewriter &rewriter) { - auto input_type = mlir::cast(input.getType()); - auto output_shape = input_type.getShape().vec(); - std::swap(output_shape[dim0], output_shape[dim1]); - - auto output_type = RankedTensorType::get( - output_shape, input_type.getElementType(), input_type.getEncoding()); - - return rewriter.create(input.getLoc(), output_shape, - output_type.getElementType()); -} - -static ttir::TransposeOp -generateTranspose(Value input, int64_t dim0, int64_t dim1, - PatternRewriter &rewriter, - ::mlir::ArrayAttr operandConstraints) { - auto input_type = mlir::cast(input.getType()); - auto output_shape = input_type.getShape().vec(); - std::swap(output_shape[dim0], output_shape[dim1]); - - auto dim0_attr = rewriter.getSI32IntegerAttr(dim0); - 
auto dim1_attr = rewriter.getSI32IntegerAttr(dim1); - - auto dps_output = generateTransposeDPSOutput(input, dim0, dim1, rewriter); - return rewriter.create( - input.getLoc(), dps_output.getType(), input, dps_output, dim0_attr, - dim1_attr, operandConstraints); -} - -static std::vector generateKernelTransposeIndices( - ttir::ConvolutionOp op, - const std::vector ttnn_convolution_kernel_layout) { - std::vector transpose_indices; - - std::vector kernel_layout( - ttnn_convolution_kernel_layout.size(), - ConvolutionKernelDimension::INVALID_KERNEL_DIM); - kernel_layout[op.getConvolutionLayout().getKernelOutputFeatureDimension()] = - ConvolutionKernelDimension::OUTPUT_FEATURES; - kernel_layout[op.getConvolutionLayout().getKernelInputFeatureDimension()] = - ConvolutionKernelDimension::INPUT_FEATURES; - - int64_t spatial_count = 0; - for (int64_t spatial_dim : - op.getConvolutionLayout().getKernelSpatialDimensions()) { - kernel_layout[spatial_dim] = spatial_count; - spatial_count++; - } - - const std::vector desired_kernel_layout = - ttnn_convolution_kernel_layout; - for (int64_t i = 0; i < static_cast(kernel_layout.size()); i++) { - if (kernel_layout[i] != desired_kernel_layout[i]) { +/* + * Generates a sequence of dims in which to transpose to make currentLayout + * match desiredLayout + * + * Ex: if currentLayout = [0, 1, 2, 3] and desiredLayout = [0, 2, 3, 1] + * then the function will return [(1, 2), (2, 3)] because when we swap + * currentLayout[1] with currentLayout[2] we get [0, 2, 1, 3], and then when + * we swap currentLayout[2] with currentLayout[3] we get [0, 2, 3, 1], which + * is the desired layout + */ +static std::vector +generateTransposeIndices(std::vector currentLayout, + const std::vector desiredLayout) { + std::vector transposeIndices; + for (int64_t i = 0; i < static_cast(currentLayout.size()); i++) { + if (currentLayout[i] != desiredLayout[i]) { int64_t dim0 = i; - int64_t dim1 = std::find(kernel_layout.begin(), kernel_layout.end(), - desired_kernel_layout[i]) - - kernel_layout.begin(); - transpose_indices.push_back(std::make_tuple(dim0, dim1)); - std::swap(kernel_layout[dim0], kernel_layout[dim1]); + int64_t dim1 = std::find(currentLayout.begin(), currentLayout.end(), + desiredLayout[i]) - + currentLayout.begin(); + transposeIndices.push_back(std::make_tuple(dim0, dim1)); + std::swap(currentLayout[dim0], currentLayout[dim1]); } } - return transpose_indices; + return transposeIndices; } -static std::vector generateInputTransposeIndices( - ttir::ConvolutionOp op, - const std::vector ttnn_convolution_layout) { - std::vector transpose_indices; - - std::vector input_layout(ttnn_convolution_layout.size(), - ConvolutionDimension::INVALID_DIM); - input_layout[op.getConvolutionLayout().getInputBatchDimension()] = - ConvolutionDimension::BATCH; - input_layout[op.getConvolutionLayout().getInputFeatureDimension()] = - ConvolutionDimension::FEATURE; - - int64_t spatial_count = 0; - for (int64_t spatial_dim : - op.getConvolutionLayout().getInputSpatialDimensions()) { - input_layout[spatial_dim] = spatial_count; - spatial_count++; - } - - const std::vector desired_input_layout = ttnn_convolution_layout; - for (int64_t i = 0; i < static_cast(input_layout.size()); i++) { - if (input_layout[i] != desired_input_layout[i]) { - int64_t dim0 = i; - int64_t dim1 = std::find(input_layout.begin(), input_layout.end(), - desired_input_layout[i]) - - input_layout.begin(); - transpose_indices.push_back(std::make_tuple(dim0, dim1)); - std::swap(input_layout[dim0], input_layout[dim1]); - } +/* + * This 
function will use a sequence of transpose indices to + * generate the actual transpose operations descrbibed by them. + * + * It takes an input to apply these transposes to and returns the + * result at the end of the sequence + */ +static Value generateTransposeOps(Value input, PatternRewriter &rewriter, + std::vector transposeIndices, + ::mlir::ArrayAttr operandConstraints) { + for (auto [dim0, dim1] : transposeIndices) { + + auto inputType = mlir::cast(input.getType()); + auto outputShape = inputType.getShape().vec(); + std::swap(outputShape[dim0], outputShape[dim1]); + + auto dim0Attr = rewriter.getSI32IntegerAttr(dim0); + auto dim1Attr = rewriter.getSI32IntegerAttr(dim1); + + auto outputType = RankedTensorType::get( + outputShape, inputType.getElementType(), inputType.getEncoding()); + + auto dpsOutput = rewriter.create( + input.getLoc(), outputShape, outputType.getElementType()); + input = rewriter + .create(input.getLoc(), outputType, input, + dpsOutput, dim0Attr, dim1Attr, + operandConstraints) + .getResult(); } - return transpose_indices; + return input; } -/** - * Although this function is mostly a clone of generateInputTransposeIndices, - * its slightly different in that if the original Convolution op had the same - * input and output layout, this function will generate the same transposes, - * that were applied to the input but in reverse order. This makes optimizing - * away the inserted transposes easier. +/* + * This function will generate the transpose indices needed to convert a + * convolution input to a desired layout. The reason for the separate + * function is to encapsulate the logic for constructuring the inputLayout */ -static std::vector generateOutputTransposeIndices( - ttir::ConvolutionOp op, - const std::vector ttnn_convolution_layout) { - std::vector transpose_indices; +static std::vector +generateConvTransposeIndices(ttir::ConvolutionOp op, + const std::vector ttnnConvolutionLayout) { - std::vector desired_output_layout(ttnn_convolution_layout.size(), - ConvolutionDimension::INVALID_DIM); - desired_output_layout[op.getConvolutionLayout().getOutputBatchDimension()] = + std::vector inputLayout(ttnnConvolutionLayout.size(), + ConvolutionDimension::INVALID_DIM); + inputLayout[op.getConvolutionLayout().getInputBatchDimension()] = ConvolutionDimension::BATCH; - desired_output_layout[op.getConvolutionLayout().getOutputFeatureDimension()] = + inputLayout[op.getConvolutionLayout().getInputFeatureDimension()] = ConvolutionDimension::FEATURE; - int64_t spatial_count = 0; - for (int64_t spatial_dim : - op.getConvolutionLayout().getOutputSpatialDimensions()) { - desired_output_layout[spatial_dim] = spatial_count; - spatial_count++; + int64_t spatialCount = 0; + for (int64_t spatialDim : + op.getConvolutionLayout().getInputSpatialDimensions()) { + inputLayout[spatialDim] = spatialCount; + spatialCount++; } - std::vector output_layout = ttnn_convolution_layout; + return generateTransposeIndices(inputLayout, ttnnConvolutionLayout); +} - for (int64_t i = static_cast(desired_output_layout.size()) - 1; - i >= 0; i--) { - if (desired_output_layout[i] != output_layout[i]) { - int64_t dim0 = i; - int64_t dim1 = std::find(output_layout.begin(), output_layout.end(), - desired_output_layout[i]) - - output_layout.begin(); - transpose_indices.push_back(std::make_tuple(dim0, dim1)); - std::swap(output_layout[dim0], output_layout[dim1]); - } - } +/* + * This function will generate the transpose indices needed to convert a + * convolution input to a desired layout. 
The reason for the separate + * function is to encapsulate the logic for constructuring the kernelLayout + */ +static std::vector generateConvKernelTransposeIndices( + ttir::ConvolutionOp op, + const std::vector ttnnConvolutionKernelLayout) { + std::vector transposeIndices; - return transpose_indices; -} + std::vector kernelLayout( + ttnnConvolutionKernelLayout.size(), + ConvolutionKernelDimension::INVALID_KERNEL_DIM); + kernelLayout[op.getConvolutionLayout().getKernelOutputFeatureDimension()] = + ConvolutionKernelDimension::OUTPUT_FEATURES; + kernelLayout[op.getConvolutionLayout().getKernelInputFeatureDimension()] = + ConvolutionKernelDimension::INPUT_FEATURES; -static Value -generateTransposeSequence(Value input, PatternRewriter &rewriter, - std::vector transpose_indices, - ::mlir::ArrayAttr operandConstraints) { - for (auto [dim0, dim1] : transpose_indices) { - input = generateTranspose(input, dim0, dim1, rewriter, operandConstraints) - .getResult(); + int64_t spatialCount = 0; + for (int64_t spatialDim : + op.getConvolutionLayout().getKernelSpatialDimensions()) { + kernelLayout[spatialDim] = spatialCount; + spatialCount++; } - return input; + return generateTransposeIndices(kernelLayout, ttnnConvolutionKernelLayout); } struct ConvolutionToConv2dPattern @@ -281,11 +233,11 @@ struct ConvolutionToConv2dPattern constexpr static uint32_t SPATIAL_DIM_WIDTH = 1; // NHWC - const std::vector conv2d_layout = { + const std::vector conv2dLayout = { ConvolutionDimension::BATCH, SPATIAL_DIM_HEIGHT, SPATIAL_DIM_WIDTH, ConvolutionDimension::FEATURE}; // OIHW - const std::vector conv2d_kernel_layout = { + const std::vector conv2dKernelLayout = { ConvolutionKernelDimension::OUTPUT_FEATURES, ConvolutionKernelDimension::INPUT_FEATURES, SPATIAL_DIM_HEIGHT, SPATIAL_DIM_WIDTH}; @@ -308,9 +260,9 @@ struct ConvolutionToConv2dPattern } // Not currently supporting window reversal - std::vector window_reversal(op.getWindowReversal().begin(), - op.getWindowReversal().end()); - for (bool reversed : window_reversal) { + std::vector windowReversal(op.getWindowReversal().begin(), + op.getWindowReversal().end()); + for (bool reversed : windowReversal) { if (reversed) { return failure(); } @@ -332,73 +284,71 @@ struct ConvolutionToConv2dPattern return failure(); } - auto stride_height_attr = rewriter.getSI32IntegerAttr( + auto strideHeightAttr = rewriter.getSI32IntegerAttr( adaptor.getWindowStrides()[SPATIAL_DIM_HEIGHT]); - auto stride_width_attr = rewriter.getSI32IntegerAttr( + auto strideWidthAttr = rewriter.getSI32IntegerAttr( adaptor.getWindowStrides()[SPATIAL_DIM_WIDTH]); - auto dilation_height_attr = rewriter.getSI32IntegerAttr( + auto dilationHeightAttr = rewriter.getSI32IntegerAttr( adaptor.getWeightDilation()[SPATIAL_DIM_HEIGHT]); - auto dilation_width_attr = rewriter.getSI32IntegerAttr( + auto dilationWidthAttr = rewriter.getSI32IntegerAttr( adaptor.getWeightDilation()[SPATIAL_DIM_WIDTH]); // Padding is a list of 2-tuples, the order of the 2-tuples is in // most-significant spatial dimension first order For Conv2d the most // significant spatial dimension is the height, followed by the width. 
- auto padding_matrix = - getPaddingMatrix(adaptor.getPadding()); - auto padding_top_attr = - rewriter.getSI32IntegerAttr(padding_matrix[SPATIAL_DIM_HEIGHT][0]); - auto padding_bottom_attr = - rewriter.getSI32IntegerAttr(padding_matrix[SPATIAL_DIM_HEIGHT][1]); - auto padding_left_attr = - rewriter.getSI32IntegerAttr(padding_matrix[SPATIAL_DIM_WIDTH][0]); - auto padding_right_attr = - rewriter.getSI32IntegerAttr(padding_matrix[SPATIAL_DIM_WIDTH][1]); - - auto groups_attr = + auto paddingMatrix = getPaddingMatrix(adaptor.getPadding()); + auto paddingTopAttr = + rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_HEIGHT][0]); + auto paddingBottomAttr = + rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_HEIGHT][1]); + auto paddingLeftAttr = + rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_WIDTH][0]); + auto paddingRightAttr = + rewriter.getSI32IntegerAttr(paddingMatrix[SPATIAL_DIM_WIDTH][1]); + + auto groupsAttr = rewriter.getSI32IntegerAttr(adaptor.getFeatureGroupCount()); - auto output_shape = op.getResult().getType().getShape().vec(); - std::vector new_output_shape = { - output_shape[adaptor.getConvolutionLayout().getOutputBatchDimension()], - output_shape[adaptor.getConvolutionLayout() - .getOutputSpatialDimensions()[SPATIAL_DIM_HEIGHT]], - output_shape[adaptor.getConvolutionLayout() - .getOutputSpatialDimensions()[SPATIAL_DIM_WIDTH]], - output_shape[adaptor.getConvolutionLayout() - .getOutputFeatureDimension()]}; + auto outputShape = op.getResult().getType().getShape().vec(); + std::vector newOutputShape = { + outputShape[adaptor.getConvolutionLayout().getOutputBatchDimension()], + outputShape[adaptor.getConvolutionLayout() + .getOutputSpatialDimensions()[SPATIAL_DIM_HEIGHT]], + outputShape[adaptor.getConvolutionLayout() + .getOutputSpatialDimensions()[SPATIAL_DIM_WIDTH]], + outputShape[adaptor.getConvolutionLayout() + .getOutputFeatureDimension()]}; auto inputType = mlir::cast(adaptor.getInput().getType()); auto outputType = - inputType.cloneWith(new_output_shape, inputType.getElementType()); + inputType.cloneWith(newOutputShape, inputType.getElementType()); auto convDPSOutput = rewriter.create( - adaptor.getInput().getLoc(), new_output_shape, + adaptor.getInput().getLoc(), newOutputShape, outputType.getElementType()); - auto input_transpose_indices = - generateInputTransposeIndices(op, conv2d_layout); - Value input = generateTransposeSequence(adaptor.getInput(), rewriter, - input_transpose_indices, - adaptor.getOperandConstraints()); - - auto kernel_transpose_indices = - generateKernelTransposeIndices(op, conv2d_kernel_layout); - Value weight = generateTransposeSequence(adaptor.getWeight(), rewriter, - kernel_transpose_indices, - adaptor.getOperandConstraints()); - ttir::Conv2dOp new_conv = rewriter.create( + auto transposeIndices = generateConvTransposeIndices(op, conv2dLayout); + Value input = + generateTransposeOps(adaptor.getInput(), rewriter, transposeIndices, + adaptor.getOperandConstraints()); + + auto kernelTransposeIndices = + generateConvKernelTransposeIndices(op, conv2dKernelLayout); + Value weight = generateTransposeOps(adaptor.getWeight(), rewriter, + kernelTransposeIndices, + adaptor.getOperandConstraints()); + ttir::Conv2dOp newConv = rewriter.create( op.getLoc(), outputType, input, weight, adaptor.getBias(), - convDPSOutput, stride_height_attr, stride_width_attr, - dilation_height_attr, dilation_width_attr, groups_attr, - padding_left_attr, padding_right_attr, padding_top_attr, - padding_bottom_attr, adaptor.getOperandConstraints()); + convDPSOutput, 
strideHeightAttr, strideWidthAttr, dilationHeightAttr, + dilationWidthAttr, groupsAttr, paddingLeftAttr, paddingRightAttr, + paddingTopAttr, paddingBottomAttr, adaptor.getOperandConstraints()); - auto output_transpose_indices = - generateOutputTransposeIndices(op, conv2d_layout); - Value output = generateTransposeSequence(new_conv.getResult(), rewriter, - output_transpose_indices, - adaptor.getOperandConstraints()); + // Applying the transposes in reverse order to the output will restore the + // tensor to the original layout + std::reverse(transposeIndices.begin(), transposeIndices.end()); + Value output = + generateTransposeOps(newConv.getResult(), rewriter, transposeIndices, + adaptor.getOperandConstraints()); rewriter.replaceOp(op, output); @@ -406,6 +356,224 @@ struct ConvolutionToConv2dPattern } }; +struct PoolingToPool2dPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + std::vector getIndicesOfSpatialDims(ttir::PoolingOp op) const { + std::vector spatialDims; + for (int64_t i = 0; + i < static_cast(op.getWindowDimensions().size()); i++) { + if (op.getWindowDimensions()[i] > 1) { + spatialDims.push_back(i); + } + } + return spatialDims; + } + + LogicalResult canDecompose2DPoolingOp(ttir::PoolingOp op) const { + + // Window dimensions must be 4 in length + if (op.getWindowDimensions().size() != 4) { + return failure(); + } + + // Window strides must be 4 in length + if (op.getWindowStrides().size() != 4) { + return failure(); + } + + // Operand rank(s) must be 4 + for (Value operand : op.getInputs()) { + auto operandType = mlir::cast(operand.getType()); + if (operandType.getRank() != 4) { + return failure(); + } + } + + // Exactly two of the window dimensions must be greater than 1 + std::vector trueWindowDimensionsIndices = + getIndicesOfSpatialDims(op); + + if (trueWindowDimensionsIndices.size() != 2) { + return failure(); + } + + // Exactly two of the window strides must be greater than 1 + std::vector trueWindowStrideIndices; + for (int64_t i = 0; i < static_cast(op.getWindowStrides().size()); + i++) { + if (op.getWindowStrides()[i] > 1) { + trueWindowStrideIndices.push_back(i); + } + } + + if (trueWindowStrideIndices.size() != 2) { + return failure(); + } + + // The indices of the true window dimensions and strides must be the same + if ((trueWindowDimensionsIndices[0] != trueWindowStrideIndices[0] || + trueWindowDimensionsIndices[1] != trueWindowStrideIndices[1]) && + (trueWindowDimensionsIndices[0] != trueWindowStrideIndices[1] || + trueWindowDimensionsIndices[1] != trueWindowStrideIndices[0])) { + return failure(); + } + + // Padding must be 8 in length + if (op.getPadding().size() != 8) { + return failure(); + } + + return success(); + } + + template + void rewritePool2d(ttir::PoolingOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + const int64_t SPATIAL_H = -3; + const int64_t SPATIAL_W = -2; + const int64_t NON_SPATIAL = -1; + + auto inputType = + mlir::cast(adaptor.getInputs()[0].getType()); + assert(inputType.getRank() == 4 && "Input must be 4D tensor"); + std::vector desiredLayout(inputType.getRank(), NON_SPATIAL); + desiredLayout[inputType.getRank() - 3] = SPATIAL_H; + desiredLayout[inputType.getRank() - 2] = SPATIAL_W; + + int64_t nonSpatialCount = 0; + for (int64_t i = 0; i < static_cast(desiredLayout.size()); i++) { + if (desiredLayout[i] == NON_SPATIAL) { + desiredLayout[i] = nonSpatialCount; + nonSpatialCount++; + } + } + + std::vector spatialDims = getIndicesOfSpatialDims(op); + + 
std::vector currentLayout(inputType.getRank(), NON_SPATIAL); + currentLayout[spatialDims[0]] = SPATIAL_H; + currentLayout[spatialDims[1]] = SPATIAL_W; + + nonSpatialCount = 0; + for (int64_t i = 0; i < static_cast(currentLayout.size()); i++) { + if (currentLayout[i] == NON_SPATIAL) { + currentLayout[i] = nonSpatialCount; + nonSpatialCount++; + } + } + + auto transposeIndices = + generateTransposeIndices(currentLayout, desiredLayout); + + auto kernelHeightAttr = rewriter.getSI32IntegerAttr( + static_cast(op.getWindowDimensions()[spatialDims[0]])); + auto kernelWidthAttr = rewriter.getSI32IntegerAttr( + static_cast(op.getWindowDimensions()[spatialDims[1]])); + + auto strideHeightAttr = rewriter.getSI32IntegerAttr( + static_cast(op.getWindowStrides()[spatialDims[0]])); + + auto strideWidthAttr = rewriter.getSI32IntegerAttr( + static_cast(op.getWindowStrides()[spatialDims[1]])); + + auto dilationHeightAttr = rewriter.getSI32IntegerAttr( + adaptor.getWindowDilations()[spatialDims[0]]); + auto dilationWidthAttr = rewriter.getSI32IntegerAttr( + adaptor.getWindowDilations()[spatialDims[1]]); + auto ceilModeAttr = rewriter.getBoolAttr(false); + + auto paddingTopAttr = + rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[0]]); + auto paddingBottomAttr = + rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[0] + 1]); + auto paddingLeftAttr = + rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1]]); + auto paddingRightAttr = + rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1] + 1]); + auto operandConstraints = adaptor.getOperandConstraints(); + + std::vector outputs; + for (Value input : adaptor.getInputs()) { + input = generateTransposeOps(input, rewriter, transposeIndices, + operandConstraints); + + auto outputType = mlir::cast(op.getResult(0).getType()); + auto newOutputShape = outputType.getShape().vec(); + for (TransposeDims dims : transposeIndices) { + std::swap(newOutputShape[std::get<0>(dims)], + newOutputShape[std::get<1>(dims)]); + } + auto newOutputType = + outputType.cloneWith(newOutputShape, outputType.getElementType()); + auto outputTensor = rewriter.create( + op.getLoc(), newOutputType.getShape(), + newOutputType.getElementType()); + + auto newPool = rewriter.create( + op.getLoc(), newOutputType, input, outputTensor, kernelHeightAttr, + kernelWidthAttr, strideHeightAttr, strideWidthAttr, + dilationHeightAttr, dilationWidthAttr, ceilModeAttr, paddingTopAttr, + paddingBottomAttr, paddingLeftAttr, paddingRightAttr, + operandConstraints); + + // Applying the transposes in reverse order to the output will restore the + // tensor to the original layout + std::reverse(transposeIndices.begin(), transposeIndices.end()); + Value output = generateTransposeOps(newPool.getResult(), rewriter, + transposeIndices, operandConstraints); + + // Reverse back so the proper input transposes are generated for the next + // pool + std::reverse(transposeIndices.begin(), transposeIndices.end()); + outputs.push_back(output); + } + + rewriter.replaceOp(op, outputs); + } + + uint32_t getNumSpatialDims(ttir::PoolingOp op) const { + uint32_t numSpatialDims = 0; + for (int64_t dim : op.getWindowDimensions()) { + if (dim > 1) { + numSpatialDims++; + } + } + return numSpatialDims; + } + + LogicalResult + matchAndRewrite(ttir::PoolingOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + uint32_t numSpatialDims = getNumSpatialDims(op); + if (numSpatialDims == 2) { + if (failed(canDecompose2DPoolingOp(op))) { + return rewriter.notifyMatchFailure( + op, 
"2D pooling op with the given attributes is not supported " + "currently"); + } + + switch (op.getPoolingMethod()) { + case ttir::PoolingMethod::Max: { + rewritePool2d(op, adaptor, rewriter); + return success(); + } + default: { + return rewriter.notifyMatchFailure( + op, "Failed to match pooling method: " + + stringifyPoolingMethod(op.getPoolingMethod())); + } + } + } + return rewriter.notifyMatchFailure( + op, "No decompositions for a pooling op with " + + std::to_string(numSpatialDims) + " spatial dimensions"); + } +}; + class GetDimensionSizeToConstantConversionPattern : public OpConversionPattern { public: @@ -437,6 +605,7 @@ class GetDimensionSizeToConstantConversionPattern void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp index 1932a80ae..18b59ede9 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp @@ -49,6 +49,7 @@ struct TTIRToTTIRDecompositionPass target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); TypeConverter typeConverter; // All types map 1:1. diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 8b622736f..981045e90 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -8,21 +8,20 @@ #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" #include "ttmlir/Dialect/TTNN/Types/Types.h" #include "ttmlir/Dialect/TTNN/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/LogicalResult.h" using namespace mlir; using namespace mlir::tt; @@ -570,33 +569,31 @@ class MatmulOpConversionPattern : public OpConversionPattern { }; // ANCHOR_END: adding_an_op_matmul_op_rewriter -class Conv2dOpConversionPattern : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +static ttnn::ReshapeOp generateReshape(Value input, ArrayRef newShape, + PatternRewriter &rewriter) { + auto inputType = mlir::cast(input.getType()); + auto outputType = inputType.cloneWith(newShape, inputType.getElementType()); - ttnn::ReshapeOp generateReshape(ttir::Conv2dOp op, Value input, - ArrayRef newShape, - PatternRewriter &rewriter) const { - auto inputType = mlir::cast(input.getType()); - auto outputType = inputType.cloneWith(newShape, inputType.getElementType()); + std::vector newShapeI32(newShape.begin(), newShape.end()); + return rewriter.create( + input.getLoc(), outputType, input, rewriter.getI32ArrayAttr(newShapeI32)); +} - std::vector newShapeI32(newShape.begin(), newShape.end()); - return 
rewriter.create( - input.getLoc(), outputType, input, - rewriter.getI32ArrayAttr(newShapeI32)); - } +static ttnn::ReshapeOp generateNHWFlatten(Value input, + PatternRewriter &rewriter) { + std::vector shape = + mlir::cast(input.getType()).getShape().vec(); - ttnn::ReshapeOp generateNHWFlatten(ttir::Conv2dOp op, Value input, - PatternRewriter &rewriter) const { - std::vector shape = - mlir::cast(input.getType()).getShape().vec(); + assert(shape.size() == 4 && "Must have 4-dim tensor as conv2d input"); - assert(shape.size() == 4 && "Must have 4-dim tensor as conv2d input"); + std::vector newShape = {1, 1, shape[0] * shape[1] * shape[2], + shape[3]}; + return generateReshape(input, newShape, rewriter); +} - std::vector newShape = {1, 1, shape[0] * shape[1] * shape[2], - shape[3]}; - return generateReshape(op, input, newShape, rewriter); - } +class Conv2dOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ttir::Conv2dOp op, OpAdaptor adaptor, @@ -652,7 +649,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern { std::vector flattenedInputShape = { 1, 1, input_shape[0] * input_shape[1] * input_shape[2], input_shape[3]}; - Value flattenedInput = generateNHWFlatten(op, adaptor.getInput(), rewriter); + Value flattenedInput = generateNHWFlatten(adaptor.getInput(), rewriter); std::vector flattenedOutputShape = { 1, 1, output_shape[0] * output_shape[1] * output_shape[2], @@ -677,7 +674,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern { stride_height, stride_width, padding_height, padding_width, dilation_height, dilation_width, groups); - Value output = generateReshape(op, new_conv, output_shape, rewriter); + Value output = generateReshape(new_conv, output_shape, rewriter); rewriter.replaceOp(op, output); return success(); @@ -709,22 +706,43 @@ class MaxPool2dOpConversionPattern auto channels = rewriter.getSI32IntegerAttr(input_shape[input_shape.size() - 1]); - assert(adaptor.getOriginalHeight().has_value() && - "ttir::MaxPool2dOp must have original_height set before translating " - "to TTNN dialect."); - assert(adaptor.getOriginalWidth().has_value() && - "ttir::MaxPool2dOp must have original_width set before translating " - "to TTNN dialect."); + Value flattenedInput = generateNHWFlatten(adaptor.getInput(), rewriter); - rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), adaptor.getOutput(), device, batch_size, - adaptor.getOriginalHeightAttr(), adaptor.getOriginalWidthAttr(), + auto output_ty = + mlir::cast(adaptor.getOutput().getType()); + llvm::ArrayRef output_shape = output_ty.getShape(); + + std::vector flattenedOutputShape = { + 1, 1, output_shape[0] * output_shape[1] * output_shape[2], + output_shape[3]}; + + output_ty = mlir::cast(getTypeConverter()->convertType( + output_ty.cloneWith(flattenedOutputShape, output_ty.getElementType()))); + + // Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the + // attribute determination + auto poolDPSOutput = rewriter.replaceOpWithNewOp( + adaptor.getOutput().getDefiningOp(), flattenedOutputShape, + output_ty.getElementType()); + + // Must set the type to the output type to maintain the layout attributes + poolDPSOutput.getResult().setType(output_ty); + + auto new_pool = rewriter.create( + op.getLoc(), output_ty, flattenedInput, poolDPSOutput, device, + batch_size, + rewriter.getSI32IntegerAttr(input_shape[input_shape.size() - 3]), + 
rewriter.getSI32IntegerAttr(input_shape[input_shape.size() - 2]), channels, adaptor.getKernelHeightAttr(), adaptor.getKernelWidthAttr(), adaptor.getStrideHeightAttr(), adaptor.getStrideWidthAttr(), adaptor.getDilationHeightAttr(), adaptor.getDilationWidthAttr(), adaptor.getCeilModeAttr(), adaptor.getPaddingTopAttr(), adaptor.getPaddingRightAttr()); + + Value output = generateReshape(new_pool, output_shape, rewriter); + + rewriter.replaceOp(op, output); + return success(); } }; diff --git a/lib/Dialect/TTIR/IR/TTIRDialect.cpp b/lib/Dialect/TTIR/IR/TTIRDialect.cpp index 73d259ea3..b935b613b 100644 --- a/lib/Dialect/TTIR/IR/TTIRDialect.cpp +++ b/lib/Dialect/TTIR/IR/TTIRDialect.cpp @@ -11,6 +11,8 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.cpp.inc" + #define GET_ATTRDEF_CLASSES #include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.cpp.inc" diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index f0decc33c..f19d10550 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -150,6 +150,46 @@ ::mlir::LogicalResult mlir::tt::ttir::ConvolutionOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// PoolingOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttir::PoolingOp::verify() { + + uint32_t inputRank = + mlir::cast(getInputs()[0].getType()).getRank(); + + for (auto input : getInputs()) { + auto inputType = mlir::cast(input.getType()); + if (inputType.getRank() != inputRank) { + return emitOpError("All input tensors must have the same rank"); + } + } + + if (getWindowStrides().size() != inputRank) { + return emitOpError("Window strides must have the same number of elements " + "as the rank of the input tensor"); + } + + if (getWindowDilations().size() != inputRank) { + return emitOpError("Window dilations must have the same number of elements " + "as the rank of the input tensor"); + } + + if (getWindowDimensions().size() != inputRank) { + return emitOpError( + "Window dimensions must have the same number of elements " + "as the rank of the input tensor"); + } + + if (getPadding().size() != 2 * inputRank) { + return emitOpError("Padding must have the same number of elements as twice " + "the rank of the input tensor"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // MaxPool2dOp //===----------------------------------------------------------------------===// @@ -165,20 +205,6 @@ ::mlir::LogicalResult mlir::tt::ttir::MaxPool2dOp::verify() { << inputType.getRank() << ". Shape: (" << inputShape << ")."; } - if (getOriginalHeight().has_value() != getOriginalWidth().has_value()) { - std::string with_value = - getOriginalHeight().has_value() ? "original_height" : "original_width"; - return emitOpError() - << "If providing the original height and width as attributes, both " - "original_height and original_width must be set. 
However, only " - << with_value << " was provided."; - } - - if (getOriginalHeight().has_value() && getOriginalWidth().has_value()) { - inputShape[1] = getOriginalHeight().value(); - inputShape[2] = getOriginalWidth().value(); - } - if (getKernelHeight() > inputShape[1]) { return emitOpError() << "Kernel height " << getKernelHeight() << " is greater than input height " << inputShape[1] diff --git a/lib/Dialect/TTIR/Transforms/Transforms.cpp b/lib/Dialect/TTIR/Transforms/Transforms.cpp index 084f1a90d..0a34de9c4 100644 --- a/lib/Dialect/TTIR/Transforms/Transforms.cpp +++ b/lib/Dialect/TTIR/Transforms/Transforms.cpp @@ -14,141 +14,4 @@ namespace mlir::tt::ttir { #define GEN_PASS_DEF_TTIRSLIDINGWINDOW2DFIXSHAPES #include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" -//===----------------------------------------------------------------------===// -// Helper methods -//===----------------------------------------------------------------------===// - -std::vector collapseNHW(std::vector shape) { - std::vector collapsed(shape.size(), 1); - - int64_t NHW = 1; - for (uint32_t i = 0; i < shape.size() - 1; i++) { - NHW *= shape[i]; - } - collapsed[collapsed.size() - 2] = NHW; - collapsed[collapsed.size() - 1] = shape[shape.size() - 1]; - return collapsed; -} - -//===----------------------------------------------------------------------===// -// Sliding window pass -//===----------------------------------------------------------------------===// - -template -class UncollapsedSlidingWindow2dPatternRewriter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - ReshapeOp createReshapeOp(PatternRewriter &rewriter, Location loc, - Value input, ::llvm::ArrayRef shapei64, - ::mlir::ArrayAttr operandConstraints) const { - auto ty = mlir::cast(input.getType()); - auto output = - rewriter.create(loc, shapei64, ty.getElementType()); - - auto shape_attr = rewriter.getI32ArrayAttr( - {static_cast(shapei64[0]), static_cast(shapei64[1]), - static_cast(shapei64[2]), static_cast(shapei64[3])}); - return rewriter.create( - loc, output.getType(), input, output, shape_attr, operandConstraints); - } - - MaxPool2dOp createMaxPool2dOp(PatternRewriter &rewriter, MaxPool2dOp op, - Value input, int32_t input_height, - int32_t input_width, - RankedTensorType new_result_type) const { - auto output = rewriter.create( - op->getLoc(), new_result_type.getShape(), - new_result_type.getElementType()); - - auto input_height_attr = rewriter.getSI32IntegerAttr(input_height); - auto input_width_attr = rewriter.getSI32IntegerAttr(input_width); - - MaxPool2dOp new_maxpool = rewriter.create( - op.getLoc(), new_result_type, input, output, op.getKernelHeightAttr(), - op.getKernelWidthAttr(), op.getStrideHeightAttr(), - op.getStrideWidthAttr(), op.getDilationHeightAttr(), - op.getDilationWidthAttr(), op.getCeilModeAttr(), - op.getPaddingLeftAttr(), op.getPaddingRightAttr(), - op.getPaddingTopAttr(), op.getPaddingBottomAttr(), - op.getOperandConstraints(), input_height_attr, input_width_attr); - - return new_maxpool; - } - - LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final { - ::llvm::ArrayRef input_shape = - mlir::cast(op.getInput().getType()).getShape(); - - if (input_shape.size() != 4) { - return failure(); - } - - if (input_shape[0] == 1 && input_shape[1] == 1) { - return failure(); - } - - if (!llvm::isa(op)) { - return failure(); - } - - // By this point we are certain that the input tensor is not in the form (1, - // 1, N*H*W, C) And so we must insert reshapes on the input/output - 
- std::vector new_input_shape = collapseNHW(input_shape); - ::llvm::ArrayRef new_input_shape_array(new_input_shape); - - ReshapeOp input_reshape = - createReshapeOp(rewriter, op.getLoc(), op.getInput(), - new_input_shape_array, op.getOperandConstraints()); - - std::vector new_result_shape = - collapseNHW(op.getResult().getType().getShape().vec()); - ::llvm::ArrayRef new_result_shape_array(new_result_shape); - - RankedTensorType new_result_type = RankedTensorType::get( - new_result_shape_array, op.getResult().getType().getElementType(), - op.getResult().getType().getEncoding()); - - Operation *new_op = createMaxPool2dOp( - rewriter, mlir::cast(op), input_reshape, - static_cast(input_shape[1]), - static_cast(input_shape[2]), new_result_type); - - ReshapeOp output_reshape = createReshapeOp( - rewriter, op.getLoc(), new_op->getResult(0), - op.getResult().getType().getShape().vec(), op.getOperandConstraints()); - - rewriter.replaceOp(op, output_reshape); - return success(); - } -}; - -class TTIRSlidingWindow2dFixShapes - : public impl::TTIRSlidingWindow2dFixShapesBase< - TTIRSlidingWindow2dFixShapes> { -public: - using impl::TTIRSlidingWindow2dFixShapesBase< - TTIRSlidingWindow2dFixShapes>::TTIRSlidingWindow2dFixShapesBase; - - void runOnOperation() final { - { - RewritePatternSet patterns(&getContext()); - patterns.add>( - &getContext()); - FrozenRewritePatternSet patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { - signalPassFailure(); - return; - } - } - } - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - registry.insert(); - } -}; - } // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 772b51b04..a84259eba 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -27,7 +27,6 @@ void createTTNNPipelineTTIRPasses( // function. Removes all private functions. 
pm.addPass(mlir::createInlinerPass()); - pm.addPass(mlir::tt::ttir::createTTIRSlidingWindow2dFixShapes()); pm.addPass(mlir::tt::ttir::createTTIRLoadSystemDesc(systemDescOptions)); ttir::TTIRImplicitDeviceOptions implicitDeviceOptions; diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp index a81984307..a7bc95eee 100644 --- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp +++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp @@ -8,6 +8,7 @@ #include "tt/runtime/detail/workarounds.h" #include "tt/runtime/ttnn/operations/utils.h" #include "tt/runtime/ttnn/utils.h" +#include "ttnn/types.hpp" #include namespace tt::runtime::ttnn::operations::pool { @@ -57,13 +58,13 @@ void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) { }, targetDevice); } - + ::ttnn::MemoryConfig outMemConfig = utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = operation.invoke( 0, input, op->batch_size(), op->input_height(), op->input_width(), op->channels(), {op->kernel_height(), op->kernel_width()}, {op->stride_height(), op->stride_width()}, {op->padding_height(), op->padding_width()}, - {op->dilation_height(), op->dilation_width()}, std::nullopt, + {op->dilation_height(), op->dilation_width()}, outMemConfig, std::nullopt); tensorPool.insert_or_assign(op->out()->global_id(), out); diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/maxpool2d_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/maxpool2d_op.mlir index 8d983326f..30e60ac49 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/maxpool2d_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/maxpool2d_op.mlir @@ -9,6 +9,6 @@ func.func public @test_maxpool2d(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x6 stablehlo.return %3 : tensor }) : (tensor<1x128x128x32xbf16>, tensor) -> tensor<1x64x64x32xbf16> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.max_pool2d"[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.pooling"[[C:.*]] return %2 : tensor<1x64x64x32xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir b/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir new file mode 100644 index 000000000..8a7eb509d --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir @@ -0,0 +1,19 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xbf16>, %arg1: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> { + %0 = tensor.empty() : tensor<1x32x64x64xbf16> + %1 = tensor.empty() : tensor<1x32x64x64xbf16> + // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]] + %2, %3 = "ttir.pooling"(%arg0, %arg1, %0, %1) <{ + operandSegmentSizes = array, + pooling_method = #ttir, + window_dimensions = array, + window_strides = array, + operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) + + %4 = tensor.empty() : tensor<1x32x64x64xbf16> + %6 = "ttir.add"(%2, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16> + return %6 : tensor<1x32x64x64xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/simple_maxpool2d.mlir b/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir similarity index 100% rename from 
test/ttmlir/Dialect/TTNN/simple_maxpool2d.mlir
rename to test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir
diff --git a/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir b/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir
new file mode 100644
index 000000000..a2d6141de
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir
@@ -0,0 +1,15 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<any_device>
+module attributes {} {
+  func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> {
+    %0 = tensor.empty() : tensor<1x32x64x64xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
+    %1 = "ttir.pooling"(%arg0, %0) <{
+        operandSegmentSizes = array<i32: 1, 1>,
+        pooling_method = #ttir<pooling_method Max>,
+        window_dimensions = array<i64: 1, 1, 2, 2>,
+        window_strides = array<i64: 1, 1, 2, 2>,
+        operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16>
+    return %1 : tensor<1x32x64x64xbf16>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/pooling/complex_pooling.mlir b/test/ttmlir/Silicon/TTNN/pooling/complex_pooling.mlir
new file mode 100644
index 000000000..0196dd106
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/pooling/complex_pooling.mlir
@@ -0,0 +1,18 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+#any_device = #tt.operand_constraint<any_device>
+module attributes {} {
+  func.func @forward(%arg0: tensor<1x32x128x128xbf16>, %arg1: tensor<1x32x128x128xbf16>) -> (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) {
+    %0 = tensor.empty() : tensor<1x32x64x64xbf16>
+    %1 = tensor.empty() : tensor<1x32x64x64xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
+    %2, %3 = "ttir.pooling"(%arg0, %arg1, %0, %1) <{
+        operandSegmentSizes = array<i32: 2, 2>,
+        pooling_method = #ttir<pooling_method Max>,
+        window_dimensions = array<i64: 1, 1, 2, 2>,
+        window_strides = array<i64: 1, 1, 2, 2>,
+        operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>)
+    return %2, %3 : tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir b/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir
new file mode 100644
index 000000000..d61ee781a
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir
@@ -0,0 +1,17 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+#any_device = #tt.operand_constraint<any_device>
+module attributes {} {
+  func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> {
+    %0 = tensor.empty() : tensor<1x32x64x64xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
+    %1 = "ttir.pooling"(%arg0, %0) <{
+        operandSegmentSizes = array<i32: 1, 1>,
+        pooling_method = #ttir<pooling_method Max>,
+        window_dimensions = array<i64: 1, 1, 2, 2>,
+        window_strides = array<i64: 1, 1, 2, 2>,
+        operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16>
+    return %1 : tensor<1x32x64x64xbf16>
+  }
+}