Add support for clamp op. (#1093)
* Add end-to-end implementation of the ops.
* Add stablehlo to ttir conversion for clamp op.
mmanzoorTT authored Nov 14, 2024
1 parent 30d6fff commit 0432a34
Showing 17 changed files with 362 additions and 2 deletions.
29 changes: 29 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -933,6 +933,35 @@ def TTIR_UnsqueezeOp : TTIR_DPSOp<"unsqueeze"> {
    let hasVerifier = 1;
}

def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
    let summary = "Clamp op.";
    let description = [{
        Clamp tensor values to a specified range.

        Example:
            input: [[0, 1, 2, 3, 4, 5, 6, 7]]
            min: 2.000000e+00
            max: 5.000000e+00

            "ttir.clamp"(%arg0, %out) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
                -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
    }];

    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$output,
                         F32Attr:$min,
                         F32Attr:$max,
                         TT_OperandConstraintArrayAttr:$operand_constraints);

    let extraClassDeclaration = [{
        MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let results = (outs AnyRankedTensor:$result);

    let hasVerifier = 1;
}
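For reference, a complete generic-form instance of the op as defined above; shapes, SSA names, and the elided operand-constraint array are illustrative only:

    %out = tensor.empty() : tensor<1x8xf32>
    %0 = "ttir.clamp"(%arg0, %out) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [...]}> : (tensor<1x8xf32>, tensor<1x8xf32>) -> tensor<1x8xf32>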

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
                                           AllShapesMatch<["value", "result"]>]> {
    let summary = "Constant op.";
23 changes: 23 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -675,6 +675,29 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
    let hasVerifier = 1;
}

def TTNN_ClampOp : TTNN_Op<"clamp"> {
    let summary = "Clamp op.";
    let description = [{
        Clamp tensor values to a specified range.

        Example:
            input: [[0, 1, 2, 3, 4, 5, 6, 7]]
            min: 2.000000e+00
            max: 5.000000e+00

            "ttnn.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
                -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
    }];

    let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                         F32Attr:$min,
                         F32Attr:$max);

    let results = (outs Variadic<AnyRankedTensor>:$result);

    let hasVerifier = 1;
}

// Note: NoMemoryEffect is used to indicate that operation can be removed if it is not used.
// Removal of this operation is done by the dead code elimination pass (RemoveDeadValuesPass).
def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> {
8 changes: 7 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -3,6 +3,11 @@ include "Common/debug_info.fbs";

namespace tt.target.ttnn;

table ClampOpParams {
  min: float;
  max: float;
}

table GetDeviceOp {
  mesh: Dim2d;
  chip_ids: [uint32];
@@ -100,10 +105,11 @@ enum EltwiseOpType: uint32 {
  Where = 35,
  Gelu = 36,
  LogicalXor = 37,
  Clamp = 38,
}

union EltwiseOpParams {
  ClampOpParams,
}

table EltwiseOp {
74 changes: 74 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -970,6 +970,74 @@ class StableHLOToTTIRSliceOpConversionPattern
  }
};

class StableHLOToTTIROpClampOpConversionPattern
    : public OpConversionPattern<mlir::stablehlo::ClampOp> {

  using OpConversionPattern<mlir::stablehlo::ClampOp>::OpConversionPattern;

public:
  LogicalResult
  matchAndRewrite(mlir::stablehlo::ClampOp srcOp,
                  mlir::stablehlo::ClampOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    RankedTensorType outputType = mlir::cast<RankedTensorType>(
        this->getTypeConverter()->convertType(srcOp.getResult().getType()));
    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

    Value min = adaptor.getMin();
    Value max = adaptor.getMax();
    Operation *minDefiningOp = min.getDefiningOp();
    Operation *maxDefiningOp = max.getDefiningOp();
    if (minDefiningOp && maxDefiningOp &&
        isa<mlir::tt::ttir::ConstantOp>(minDefiningOp) &&
        isa<mlir::tt::ttir::ConstantOp>(maxDefiningOp)) {
      mlir::ElementsAttr minValAttr =
          mlir::cast<mlir::tt::ttir::ConstantOp>(minDefiningOp).getValueAttr();
      mlir::ElementsAttr maxValAttr =
          mlir::cast<mlir::tt::ttir::ConstantOp>(maxDefiningOp).getValueAttr();
      if (minValAttr.isSplat() && maxValAttr.isSplat()) {
        float minValue =
            minValAttr.getElementType().isInteger()
                ? static_cast<float>(minValAttr.getSplatValue<int>())
                : minValAttr.getSplatValue<float>();
        float maxValue =
            maxValAttr.getElementType().isInteger()
                ? static_cast<float>(maxValAttr.getSplatValue<int>())
                : maxValAttr.getSplatValue<float>();
        rewriter.replaceOpWithNewOp<mlir::tt::ttir::ClampOp>(
            srcOp,
            this->getTypeConverter()->convertType(outputTensor.getType()),
            Value(adaptor.getOperand()), Value(outputTensor),
            rewriter.getF32FloatAttr(minValue),
            rewriter.getF32FloatAttr(maxValue),
            rewriter.getArrayAttr(
                SmallVector<Attribute>(adaptor.getOperands().size() + 1,
                                       rewriter.getAttr<OperandConstraintAttr>(
                                           OperandConstraint::AnyDeviceTile))));

        return success();
      }
    }

    ttir::MaximumOp maximumOp = rewriter.create<mlir::tt::ttir::MaximumOp>(
        srcOp->getLoc(), min, adaptor.getOperand(), outputTensor,
        rewriter.getArrayAttr(
            SmallVector<Attribute>(adaptor.getOperands().size() + 1,
                                   rewriter.getAttr<OperandConstraintAttr>(
                                       OperandConstraint::AnyDeviceTile))));

    tensor::EmptyOp finalOutputTensor = rewriter.create<tensor::EmptyOp>(
        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
    rewriter.replaceOpWithNewOp<mlir::tt::ttir::MinimumOp>(
        srcOp, maximumOp->getResult(0), max, finalOutputTensor,
        rewriter.getArrayAttr(
            SmallVector<Attribute>(adaptor.getOperands().size() + 1,
                                   rewriter.getAttr<OperandConstraintAttr>(
                                       OperandConstraint::AnyDeviceTile))));
    return success();
  }
};
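To illustrate the two lowering paths above: when both bounds are splat ttir.constant values, the stablehlo.clamp collapses into a single ttir.clamp. A rough sketch, where SSA names, shapes, and the elided attributes are hypothetical:

    %0 = "stablehlo.clamp"(%min, %arg0, %max) : (tensor<1x8xf32>, tensor<1x8xf32>, tensor<1x8xf32>) -> tensor<1x8xf32>

becomes

    %out = tensor.empty() : tensor<1x8xf32>
    %1 = "ttir.clamp"(%arg0, %out) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [...]}> : (tensor<1x8xf32>, tensor<1x8xf32>) -> tensor<1x8xf32>

Non-splat or non-constant bounds instead fall through to the elementwise decomposition: ttir.maximum(%min, %arg0) followed by ttir.minimum of that result with %max.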

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
                                              RewritePatternSet &patterns,
                                              TypeConverter &typeConverter) {

@@ -1124,6 +1192,11 @@ void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
  patterns.add<StableHLOToTTIRSliceOpConversionPattern>(typeConverter, ctx);
}

void addClampOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
                                 TypeConverter &typeConverter) {
  patterns.add<StableHLOToTTIROpClampOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
@@ -1146,6 +1219,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
  addReshapeOpConversionPattern(ctx, patterns, typeConverter);
  addLogicalOpConversionPattern(ctx, patterns, typeConverter);
  addSliceOpConversionPattern(ctx, patterns, typeConverter);
  addClampOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
15 changes: 15 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -347,6 +347,20 @@ class TransposeOpConversionPattern
  }
};

class ClampOpConversionPattern : public OpConversionPattern<ttir::ClampOp> {
public:
  using OpConversionPattern<ttir::ClampOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::ClampOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ttnn::ClampOp>(
        op, this->getTypeConverter()->convertType(op.getType()),
        adaptor.getInput(), adaptor.getMin(), adaptor.getMax());
    return success();
  }
};
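In IR terms this is a one-to-one rewrite that drops the DPS output operand, since ttnn.clamp returns its result directly; a sketch with hypothetical types:

    %0 = "ttnn.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<1x8xf32>) -> tensor<1x8xf32>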

class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
public:
  using OpConversionPattern<ttir::ConcatOp>::OpConversionPattern;

@@ -933,6 +947,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
           SoftmaxOpConversionPattern,
           TransposeOpConversionPattern,
           TypecastOpConversionPattern,
           ClampOpConversionPattern,
           ConcatOpConversionPattern,
           ReshapeOpConversionPattern,
           SliceOpConversionPattern,
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -646,6 +646,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
  //
  patterns.add<DefaultOpConversionPattern<ttnn::AbsOp>,
               DefaultOpConversionPattern<ttnn::CbrtOp>,
               DefaultOpConversionPattern<ttnn::ClampOp>,
               DefaultOpConversionPattern<ttnn::FloorOp>,
               DefaultOpConversionPattern<ttnn::IsFiniteOp>,
               DefaultOpConversionPattern<ttnn::LogicalNotOp>,
18 changes: 18 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -27,6 +27,24 @@
#define GET_OP_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROps.cpp.inc"

//===----------------------------------------------------------------------===//
// ClampOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult mlir::tt::ttir::ClampOp::verify() {
  const RankedTensorType inputTensorType =
      mlir::cast<RankedTensorType>(getInput().getType());

  const RankedTensorType outputTensorType =
      mlir::cast<RankedTensorType>(getResult().getType());

  if (inputTensorType != outputTensorType) {
    return emitOpError("input and output must have the same type.");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
29 changes: 29 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -20,6 +20,35 @@

namespace mlir::tt::ttnn {

//===----------------------------------------------------------------------===//
// ClampOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult mlir::tt::ttnn::ClampOp::verify() {
  ::mlir::Operation::operand_range inputs = getInputs();
  ::mlir::Operation::result_range outputs = getResults();

  if (inputs.size() != 1) {
    return emitOpError("expects exactly one input tensor.");
  }

  if (outputs.size() != 1) {
    return emitOpError("expects exactly one output tensor.");
  }

  const RankedTensorType inputTensorType =
      mlir::cast<RankedTensorType>(inputs.front().getType());

  const RankedTensorType outputTensorType =
      mlir::cast<RankedTensorType>(outputs.front().getType());

  if (inputTensorType != outputTensorType) {
    return emitOpError("input and output must have the same type.");
  }

  return success();
}
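For instance, a hypothetical ttnn.clamp whose result type differs from its input type would be rejected here:

    %0 = "ttnn.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<1x8xf32>) -> tensor<1x8xbf16>  // error: input and output must have the same type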

//===----------------------------------------------------------------------===//
// Conv2dOp
//===----------------------------------------------------------------------===//
38 changes: 38 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -306,6 +306,40 @@ createOp(FlatbufferObjectCache &cache, AllGatherOp op) {
                      op.getDim(), op.getNumLinks());
}

::flatbuffers::Offset<::tt::target::ttnn::ClampOpParams>
createEltwiseOpParams(FlatbufferObjectCache &cache, ClampOp op) {
  auto min = op.getMin().convertToFloat();
  auto max = op.getMax().convertToFloat();
  return ::tt::target::ttnn::CreateClampOpParams(*cache.fbb, min, max);
}

// Serializes eltwise ops that are not destination-passing-style, i.e. ops
// whose output tensor is a plain result rather than an operand (e.g. ClampOp).
template <typename EltwiseOp>
::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
createNonDPSEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
  ::tt::target::ttnn::EltwiseOpType type;
  ::tt::target::ttnn::EltwiseOpParams paramsType =
      ::tt::target::ttnn::EltwiseOpParams::NONE;
  ::flatbuffers::Offset<void> params = 0;
  if constexpr (std::is_same_v<EltwiseOp, ClampOp>) {
    type = ::tt::target::ttnn::EltwiseOpType::Clamp;
    paramsType = ::tt::target::ttnn::EltwiseOpParams::ClampOpParams;
    params = createEltwiseOpParams(cache, op).Union();
  } else {
    llvm_unreachable("unhandled non-DPS EltwiseOp");
  }

  std::vector<::flatbuffers::Offset<::tt::target::TensorRef>> ins;
  for (auto input : op.getInputs()) {
    ins.push_back(
        cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(input)));
  }
  assert(op.getResults().size() == 1);
  auto out = cache.getOrCreate(op.getResults().front(), tensorValueToFlatbuffer,
                               kHostAllocatedAddress, kHostAllocatedSize);
  return ::tt::target::ttnn::CreateEltwiseOpDirect(*cache.fbb, type, &ins, out,
                                                   paramsType, params);
}

template <typename EltwiseOp>
::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {

@@ -696,6 +730,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
    return createOperation(cache, createTransposeOp(cache, transposeOp),
                           debugString);
  }
  if (auto clampOp = dyn_cast<ClampOp>(op); clampOp) {
    return createOperation(cache, createNonDPSEltwiseOp(cache, clampOp),
                           debugString);
  }
  if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
    return createOperation(cache, createOp(cache, conv2dOp), debugString);
  }
21 changes: 21 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
@@ -26,13 +26,34 @@ static void runEltwiseUnaryCompositeOp(
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryCompositeClampOp(
    const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
    std::function<::ttnn::Tensor(const ::ttnn::Tensor &, float, float,
                                 const ::tt::tt_metal::MemoryConfig &)>
        ttnnOp) {
  ::ttnn::Tensor *in = nullptr;
  getEltwiseUnaryOpInputTensor(op, tensorPool, &in);

  float min = op->params_as_ClampOpParams()->min();
  float max = op->params_as_ClampOpParams()->max();
  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
      utils::createMemoryConfig(op->out());
  ::ttnn::Tensor out = ttnnOp(*in, min, max, outputMemoryConfig);
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}
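(Here ::ttnn::clamp is assumed to be the TTNN composite clamp entry point matching the std::function signature above: input tensor, min, max, output memory config.)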

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  switch (op->type()) {
  case ::tt::target::ttnn::EltwiseOpType::Cbrt: {
    runEltwiseUnaryCompositeOp(op, tensorPool, ::ttnn::cbrt);
    break;
  }
  case ::tt::target::ttnn::EltwiseOpType::Clamp: {
    runEltwiseUnaryCompositeClampOp(op, tensorPool, ::ttnn::clamp);
    break;
  }
  case ::tt::target::ttnn::EltwiseOpType::Log1p: {
    runEltwiseUnaryCompositeOp(op, tensorPool, ::ttnn::log1p);
    break;
@@ -13,7 +13,7 @@ namespace tt::runtime::ttnn::operations::unary::composite {
inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
  switch (op->type()) {
  case ::tt::target::ttnn::EltwiseOpType::Cbrt:
  case ::tt::target::ttnn::EltwiseOpType::Clamp:
  case ::tt::target::ttnn::EltwiseOpType::Log1p:
    return true;
  default: