-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TOSA to TTIR refactor: split pass and patterns into separate files (#…
…1418) * tosa to ttir refactor: split pass and patterns into separate files * removed unused #include directives * fixed the order of include directives, changed the name of the default dsp pattern * added virtual checkConversionLegality and separated multiply op conversion pattern * refactor for cleaner code: using the constructor of the base class for mulop pattern * fixed formatting with pre-commit run
- Loading branch information
Showing
5 changed files
with
207 additions
and
124 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/Dialect.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" | ||
#include "ttmlir/Dialect/TTIR/IR/TTIR.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::tt; | ||
|
||
namespace mlir::tt::ttir { | ||
|
||
#define GEN_PASS_DEF_CONVERTTOSATOTTIR | ||
#include "ttmlir/Conversion/Passes.h.inc" | ||
|
||
} // namespace mlir::tt::ttir | ||
|
||
namespace { | ||
|
||
struct ConvertTosaToTTIRPass | ||
: public ttir::impl::ConvertTosaToTTIRBase<ConvertTosaToTTIRPass> { | ||
void runOnOperation() override { | ||
mlir::ConversionTarget target(getContext()); | ||
|
||
target.addIllegalDialect<tosa::TosaDialect>(); | ||
|
||
target.addLegalDialect<ttir::TTIRDialect>(); | ||
target.addLegalOp<mlir::tensor::EmptyOp>(); | ||
target.addLegalOp<mlir::ModuleOp>(); | ||
target.addLegalOp<mlir::func::FuncOp>(); | ||
target.addLegalOp<mlir::func::ReturnOp>(); | ||
|
||
// For now keep the same type assuming tosa ops operate on builtin tensor. | ||
TypeConverter typeConverter; | ||
typeConverter.addConversion([](Type type) { | ||
assert(isa<RankedTensorType>(type) && | ||
"only ranked tensor type supported"); | ||
return type; | ||
}); | ||
RewritePatternSet patterns(&getContext()); | ||
|
||
// Add conversion patterns. | ||
populateTosaToTTIRPatterns(&getContext(), patterns, typeConverter); | ||
|
||
// Apply conversion. | ||
if (failed( | ||
applyFullConversion(getOperation(), target, std::move(patterns)))) { | ||
signalPassFailure(); | ||
return; | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace mlir::tt { | ||
|
||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTosaToTTIRPass() { | ||
return std::make_unique<ConvertTosaToTTIRPass>(); | ||
} | ||
|
||
} // namespace mlir::tt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/ValueRange.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" | ||
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" | ||
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::tt; | ||
|
||
namespace { | ||
|
||
// TODO(sdjukic): extract this pattern into separate file and use it for both | ||
// TOSA and StableHLO | ||
|
||
template <typename SrcOp, typename DestOp, | ||
typename Adaptor = typename SrcOp::Adaptor> | ||
class TosaToTTIRDefaultDPSOpConversionPattern | ||
: public OpConversionPattern<SrcOp> { | ||
using OpConversionPattern<SrcOp>::OpConversionPattern; | ||
|
||
public: | ||
LogicalResult | ||
matchAndRewrite(SrcOp srcOp, Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
LogicalResult legalityResult = | ||
checkConversionLegality(srcOp, adaptor, rewriter); | ||
if (!legalityResult.succeeded()) { | ||
return legalityResult; | ||
} | ||
|
||
RankedTensorType outputType = | ||
mlir::cast<RankedTensorType>(srcOp.getResult().getType()); | ||
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>( | ||
srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); | ||
rewriter.replaceOpWithNewOp<DestOp>( | ||
srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(), | ||
ValueRange(outputTensor), | ||
rewriter.getArrayAttr( | ||
SmallVector<Attribute>(adaptor.getOperands().size() + 1, | ||
rewriter.getAttr<OperandConstraintAttr>( | ||
OperandConstraint::AnyDeviceTile)))); | ||
return success(); | ||
} | ||
|
||
private: | ||
virtual LogicalResult | ||
checkConversionLegality(SrcOp srcOp, Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const { | ||
return success(); | ||
} | ||
}; | ||
|
||
class TosaToTTIRMultiplyOpConversionPattern | ||
: public TosaToTTIRDefaultDPSOpConversionPattern< | ||
tosa::MulOp, mlir::tt::ttir::MultiplyOp> { | ||
using TosaToTTIRDefaultDPSOpConversionPattern< | ||
tosa::MulOp, | ||
mlir::tt::ttir::MultiplyOp>::TosaToTTIRDefaultDPSOpConversionPattern; | ||
|
||
private: | ||
LogicalResult | ||
checkConversionLegality(tosa::MulOp srcOp, tosa::MulOp::Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
if (srcOp.getShift() != 0) { | ||
return rewriter.notifyMatchFailure( | ||
srcOp, "TTIR MultiplyOp doesn't support shifted multiply."); | ||
} | ||
return success(); | ||
} | ||
}; | ||
|
||
void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, | ||
RewritePatternSet &patterns, | ||
TypeConverter &typeConverter) { | ||
|
||
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<tosa::AbsOp, | ||
mlir::tt::ttir::AbsOp>>( | ||
typeConverter, ctx); | ||
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<tosa::NegateOp, | ||
mlir::tt::ttir::NegOp>>( | ||
typeConverter, ctx); | ||
} | ||
|
||
void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx, | ||
RewritePatternSet &patterns, | ||
TypeConverter &typeConverter) { | ||
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern<tosa::AddOp, | ||
mlir::tt::ttir::AddOp>>( | ||
typeConverter, ctx); | ||
patterns.add<TosaToTTIRMultiplyOpConversionPattern>(typeConverter, ctx); | ||
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern< | ||
tosa::SubOp, mlir::tt::ttir::SubtractOp>>(typeConverter, ctx); | ||
} | ||
|
||
void addCompareOpsConversionPatterns(MLIRContext *ctx, | ||
RewritePatternSet &patterns, | ||
TypeConverter &typeConverter) { | ||
patterns.add<TosaToTTIRDefaultDPSOpConversionPattern< | ||
tosa::GreaterEqualOp, mlir::tt::ttir::GreaterEqualOp>>(typeConverter, | ||
ctx); | ||
} | ||
|
||
} // namespace | ||
|
||
namespace mlir::tt { | ||
|
||
void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, | ||
TypeConverter &typeConverter) { | ||
addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter); | ||
addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter); | ||
addCompareOpsConversionPatterns(ctx, patterns, typeConverter); | ||
} | ||
|
||
} // namespace mlir::tt |