Skip to content

Commit

Permalink
TOSA to TTIR refactor: split pass and patterns into separate files (#…
Browse files Browse the repository at this point in the history
…1418)

* tosa to ttir refactor: split pass and patterns into separate files

* removed unused #include directives

* fixed the order of include directives, changed the name of the default dsp pattern

* added virtual checkConversionLegality and separated multiply op conversion pattern

* refactor for cleaner code: using the constructor of the base class for mulop pattern

* fixed formatting with pre-commit run
  • Loading branch information
sdjukicTT authored Nov 29, 2024
1 parent 5272015 commit 3eca3d8
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 124 deletions.
6 changes: 5 additions & 1 deletion include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::tt {

void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter);

std::unique_ptr<OperationPass<ModuleOp>> createConvertTosaToTTIRPass();

} // namespace mlir::tt

#endif
#endif // TTMLIR_CONVERSION_TOSATOTTIR_TOSATOTTIR_H
3 changes: 2 additions & 1 deletion lib/Conversion/TosaToTTIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_library(TTMLIRTosaToTTIR
TosaToTTIR.cpp
TosaToTTIRPass.cpp
TosaToTTIRPatterns.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/TosaToTTIR
Expand Down
122 changes: 0 additions & 122 deletions lib/Conversion/TosaToTTIR/TosaToTTIR.cpp

This file was deleted.

74 changes: 74 additions & 0 deletions lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"

using namespace mlir;
using namespace mlir::tt;

namespace mlir::tt::ttir {

#define GEN_PASS_DEF_CONVERTTOSATOTTIR
#include "ttmlir/Conversion/Passes.h.inc"

} // namespace mlir::tt::ttir

namespace {

struct ConvertTosaToTTIRPass
: public ttir::impl::ConvertTosaToTTIRBase<ConvertTosaToTTIRPass> {
void runOnOperation() override {
mlir::ConversionTarget target(getContext());

target.addIllegalDialect<tosa::TosaDialect>();

target.addLegalDialect<ttir::TTIRDialect>();
target.addLegalOp<mlir::tensor::EmptyOp>();
target.addLegalOp<mlir::ModuleOp>();
target.addLegalOp<mlir::func::FuncOp>();
target.addLegalOp<mlir::func::ReturnOp>();

// For now keep the same type assuming tosa ops operate on builtin tensor.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) {
assert(isa<RankedTensorType>(type) &&
"only ranked tensor type supported");
return type;
});
RewritePatternSet patterns(&getContext());

// Add conversion patterns.
populateTosaToTTIRPatterns(&getContext(), patterns, typeConverter);

// Apply conversion.
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
};

} // namespace

namespace mlir::tt {

std::unique_ptr<OperationPass<ModuleOp>> createConvertTosaToTTIRPass() {
return std::make_unique<ConvertTosaToTTIRPass>();
}

} // namespace mlir::tt
126 changes: 126 additions & 0 deletions lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

using namespace mlir;
using namespace mlir::tt;

namespace {

// TODO(sdjukic): extract this pattern into separate file and use it for both
// TOSA and StableHLO

template <typename SrcOp, typename DestOp,
typename Adaptor = typename SrcOp::Adaptor>
class TosaToTTIRDefaultDPSOpConversionPattern
: public OpConversionPattern<SrcOp> {
using OpConversionPattern<SrcOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

LogicalResult legalityResult =
checkConversionLegality(srcOp, adaptor, rewriter);
if (!legalityResult.succeeded()) {
return legalityResult;
}

RankedTensorType outputType =
mlir::cast<RankedTensorType>(srcOp.getResult().getType());
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
rewriter.replaceOpWithNewOp<DestOp>(
srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(),
ValueRange(outputTensor),
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
return success();
}

private:
virtual LogicalResult
checkConversionLegality(SrcOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
return success();
}
};

class TosaToTTIRMultiplyOpConversionPattern
: public TosaToTTIRDefaultDPSOpConversionPattern<
tosa::MulOp, mlir::tt::ttir::MultiplyOp> {
using TosaToTTIRDefaultDPSOpConversionPattern<
tosa::MulOp,
mlir::tt::ttir::MultiplyOp>::TosaToTTIRDefaultDPSOpConversionPattern;

private:
LogicalResult
checkConversionLegality(tosa::MulOp srcOp, tosa::MulOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (srcOp.getShift() != 0) {
return rewriter.notifyMatchFailure(
srcOp, "TTIR MultiplyOp doesn't support shifted multiply.");
}
return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {

patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<tosa::AbsOp,
mlir::tt::ttir::AbsOp>>(
typeConverter, ctx);
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<tosa::NegateOp,
mlir::tt::ttir::NegOp>>(
typeConverter, ctx);
}

void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<tosa::AddOp,
mlir::tt::ttir::AddOp>>(
typeConverter, ctx);
patterns.add<TosaToTTIRMultiplyOpConversionPattern>(typeConverter, ctx);
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<
tosa::SubOp, mlir::tt::ttir::SubtractOp>>(typeConverter, ctx);
}

void addCompareOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<
tosa::GreaterEqualOp, mlir::tt::ttir::GreaterEqualOp>>(typeConverter,
ctx);
}

} // namespace

namespace mlir::tt {

void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter) {
addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter);
addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter);
addCompareOpsConversionPatterns(ctx, patterns, typeConverter);
}

} // namespace mlir::tt

0 comments on commit 3eca3d8

Please sign in to comment.