Skip to content

Commit

Permalink
Create new general pooling op (#1154)
Browse files Browse the repository at this point in the history
Add decomposition pattern that converts it to maxpool2d.

Use output memory config attribute in runtime, add silicon tests
  • Loading branch information
LPanosTT authored Nov 12, 2024
1 parent 385771d commit 14cd5d0
Show file tree
Hide file tree
Showing 22 changed files with 747 additions and 497 deletions.
6 changes: 6 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ mlir_tablegen(TTIROpsAttrs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TTIROpsAttrsIncGen)
add_dependencies(mlir-headers TTIROpsAttrsIncGen)

set(LLVM_TARGET_DEFINITIONS TTIROpsEnums.td)
mlir_tablegen(TTIROpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(TTIROpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTTIROpsEnumsIncGen)
add_dependencies(mlir-headers MLIRTTIROpsEnumsIncGen)

set(LLVM_TARGET_DEFINITIONS TTIROpsInterfaces.td)
mlir_tablegen(TTIROpsInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(TTIROpsInterfaces.cpp.inc -gen-op-interface-defs)
Expand Down
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "TTIROpsInterfaces.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.h.inc"

Expand Down
31 changes: 28 additions & 3 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,33 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
}];
}

def TTIR_PoolingOp : TTIR_DPSOp<"pooling", [AttrSizedOperandSegments]> {
let summary = "General pooling op";
let description = [{
General pooling op
}];

let arguments = (ins
Variadic<AnyRankedTensor>:$inputs,
Variadic<AnyRankedTensor>:$outputs,
TTIR_PoolingMethodAttr:$pooling_method,
DenseI64ArrayAttr:$window_dimensions,

// Default stride of 1 over every dimension
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size(), 1)">:$window_strides,
// Default dilation of 1 over every dimension
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size(), 1)">:$base_dilations,
// Default dilation of 1 over every dimension
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size(), 1)">:$window_dilations,
// Default padding of 0 over every dimension
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size() * 2, 0)">:$padding,
TT_OperandConstraintArrayAttr:$operand_constraints
);

let results = (outs Variadic<AnyRankedTensor>);

let hasVerifier = 1;
}

def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
Expand All @@ -747,9 +774,7 @@ def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
SI32Attr:$padding_right,
SI32Attr:$padding_top,
SI32Attr:$padding_bottom,
TT_OperandConstraintArrayAttr:$operand_constraints,
OptionalAttr<SI32Attr>:$original_height,
OptionalAttr<SI32Attr>:$original_width);
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

Expand Down
4 changes: 4 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
#define TTMLIR_TTIR_ATTRS_TD

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td"
include "ttmlir/Dialect/TTIR/IR/TTIRBase.td"

def TTIR_PoolingMethodAttr : EnumAttr<TTIR_Dialect, TTIR_PoolingMethod, "pooling_method">;

def TTIR_ConvolutionLayoutAttr : AttrDef<TTIR_Dialect, "ConvolutionLayout", [], "::mlir::Attribute"> {
let mnemonic = "convolution_layout";
let summary = "Structure of dimension information for convolution op";
Expand Down
21 changes: 21 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTIR_ENUMS_TD
#define TTMLIR_TTIR_ENUMS_TD

include "mlir/IR/EnumAttr.td"

def TTIR_PoolingMethodAverage : I32EnumAttrCase<"Average", 0>;
def TTIR_PoolingMethodMax : I32EnumAttrCase<"Max", 1>;

def TTIR_PoolingMethod : I32EnumAttr<"PoolingMethod", "TTIR PoolingMethod", [
TTIR_PoolingMethodAverage,
TTIR_PoolingMethodMax
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt::ttir";
}

#endif
7 changes: 0 additions & 7 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,6 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
];
}

def TTIRSlidingWindow2dFixShapes: Pass<"ttir-sliding-window-2d-fix-shapes", "::mlir::ModuleOp"> {
let summary = "Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N,H,W on the input: i.e (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e (1, 1, N*H*W, C) --> (N, H, W, C)";
let description = [{
Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N,H,W on the input: i.e (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e (1, 1, N*H*W, C) --> (N, H, W, C)
}];
}

def TTIRSplitCompoundLayout: Pass<"ttir-split-compound-layout", "::mlir::ModuleOp"> {
let summary = "Split compound layouts.";
let description = [{
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ struct ConvertStableHLOToTTIRPass
[&](func::CallOp op) { return typeConverter.isLegal(op); });

populateStableHLOToTTIRPatterns(&getContext(), patterns, typeConverter);

// Apply conversion.
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
Expand Down
Loading

0 comments on commit 14cd5d0

Please sign in to comment.