From 59ad175df7ac80009eae69cb6b0d3069ebec025d Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 16 Oct 2024 18:42:56 +0100 Subject: [PATCH 01/15] initial commit: broken --- xformer/Transforms/OptimizeTranspose.cpp | 347 +++++++++++++++++++---- xformer/Transforms/Options.h | 1 + xformer/XCoreOptMain.cpp | 5 + 3 files changed, 293 insertions(+), 60 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index 8223cd3f8..7bda60e45 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -1,7 +1,6 @@ // Copyright 2023 XMOS LIMITED. This Software is subject to the terms of the // XMOS Public License: Version 1 -#include "IR/XCoreOps.h" #include "Transforms/Options.h" #include "mlir/Pass/Pass.h" @@ -25,105 +24,322 @@ struct OptimizeTranspose void runOnOperation() override; }; -struct HoistTransposeWCHAbovePadPattern +struct MoveTransposeForwardOverUnaryOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + // Ensure the TransposeOp has a single use + if (!transposeOp->hasOneUse()) + return failure(); + + Operation *userOp = *transposeOp->getUsers().begin(); + + // Check if the user operation is a unary op that can commute with transpose + if (!isa(userOp)) + return failure(); + + // Get the types of the input and output tensors + auto transposeInputType = + transposeOp.getInput().getType().dyn_cast(); + auto transposeOutputType = + transposeOp.getType().dyn_cast(); + if (!transposeInputType || !transposeOutputType) + return failure(); + + // Get the permutation used in the transpose + Value perm = transposeOp.getPerm(); + + Value newUnaryOpResult; + auto loc = userOp->getLoc(); + auto input = transposeOp.getInput(); + + // Retrieve the original unary operation's output type + auto originalUnaryOutputType = + userOp->getResult(0).getType().dyn_cast(); + if (!originalUnaryOutputType) + return failure(); + + // Create a new output type for the unary op with the same shape as 'input' + // and the same element type as the original output type + auto newUnaryOutputType = RankedTensorType::get( + transposeInputType.getShape(), originalUnaryOutputType.getElementType(), + originalUnaryOutputType.getEncoding()); + + if (auto quantizeOp = dyn_cast(userOp)) { + // For QuantizeOp, create new QuantizeOp with input as 'input' and output + // type adjusted + + // Create new QuantizeOp with adjusted output type + newUnaryOpResult = rewriter.create( + loc, newUnaryOutputType, input, quantizeOp.getQtypeAttr()); + + } else if (auto absOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto negOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto reluOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto relu6Op = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto leakyReluOp = dyn_cast(userOp)) { + newUnaryOpResult = rewriter.create( + loc, newUnaryOutputType, input, leakyReluOp.getAlphaAttr()); + } else if (auto tanhOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto logisticOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else { + // This should not happen as we checked the op type earlier + return failure(); + } + + // Create a new Transpose operation after the unary operation + auto newTransposeOp = rewriter.create( + transposeOp.getLoc(), transposeOutputType, newUnaryOpResult, perm); + + // Replace the original user operation's result with the new transpose + // result + rewriter.replaceOp(userOp, newTransposeOp.getResult()); + + // Remove the original TransposeOp + rewriter.eraseOp(transposeOp); + + return success(); + } +}; + +struct MoveTransposeForwardOverConcatOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::ConcatenationOp concatOp, + PatternRewriter &rewriter) const override { + // Get all input operands + auto inputs = concatOp.getValues(); + + // Check that all inputs are TransposeOps with the same permutation + SmallVector newInputs; + DenseIntElementsAttr commonPermAttr; + for (auto input : inputs) { + auto transposeOp = input.getDefiningOp(); + if (!transposeOp) + return failure(); + + // Get permutation attribute + DenseIntElementsAttr permAttr; + if (!matchPattern(transposeOp.getPerm(), m_Constant(&permAttr))) + return failure(); + + // Check if the permutation is the same as others + if (commonPermAttr) { + if (permAttr != commonPermAttr) + return failure(); + } else { + commonPermAttr = permAttr; + } + + // Collect the inputs to the transpose ops + newInputs.push_back(transposeOp.getInput()); + } + + // Get the permutation vector + SmallVector permVec; + for (auto val : commonPermAttr.getValues()) { + permVec.push_back(val); + } + + // Compute the inverse permutation + SmallVector invPerm(permVec.size()); + for (size_t i = 0; i < permVec.size(); ++i) { + invPerm[permVec[i]] = i; + } + + // Adjust the axis according to the inverse permutation + int32_t oldAxis = concatOp.getAxis(); + int64_t rank = permVec.size(); + if (oldAxis < 0) { + oldAxis += rank; + } + if (oldAxis < 0 || oldAxis >= rank) { + return failure(); // Invalid axis + } + int32_t newAxis = invPerm[oldAxis]; + + // Collect input types and compute the new result type + SmallVector inputTypes; + for (auto input : newInputs) { + auto inputType = input.getType().dyn_cast(); + if (!inputType) { + return failure(); + } + inputTypes.push_back(inputType); + } + + // Ensure all input types have the same rank + for (auto type : inputTypes) { + if (type.getRank() != rank) { + return failure(); + } + } + + // Compute the shape of the concatenated tensor + SmallVector resultShape(inputTypes[0].getShape().begin(), + inputTypes[0].getShape().end()); + for (size_t i = 1; i < inputTypes.size(); ++i) { + auto shape = inputTypes[i].getShape(); + for (int64_t dim = 0; dim < rank; ++dim) { + if (dim == newAxis) { + resultShape[dim] += shape[dim]; + } else if (resultShape[dim] != shape[dim]) { + // Dimensions must be equal except for the concatenation axis + return failure(); + } + } + } + + // Create the new ConcatenationOp with the correct result type and axis + auto elementType = inputTypes[0].getElementType(); + auto newConcatType = RankedTensorType::get(resultShape, elementType); + auto newConcatOp = rewriter.create( + concatOp.getLoc(), newConcatType, newInputs, + rewriter.getI32IntegerAttr(newAxis), + concatOp.getFusedActivationFunctionAttr()); + + // Create the permutation constant with correct data types + auto permType = RankedTensorType::get( + {static_cast(permVec.size())}, rewriter.getIntegerType(32)); + auto permAttr = DenseIntElementsAttr::get(permType, permVec); + auto permConstOp = rewriter.create(concatOp.getLoc(), + permType, permAttr); + + // Create the new TransposeOp with the original output type + auto newTransposeOp = rewriter.create( + concatOp.getLoc(), concatOp.getType(), newConcatOp.getResult(), + permConstOp.getResult()); + + rewriter.replaceOp(concatOp, newTransposeOp.getResult()); + return success(); + } +}; + +struct HoistTransposeAbovePadPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::TransposeOp op, PatternRewriter &rewriter) const override { - - // Check for invalid types and return - // Defining op must be pad + // Check if the input to TransposeOp is a PadOp auto padOp = dyn_cast_or_null(op.getInput().getDefiningOp()); if (!padOp) { return failure(); } - // Get transpose permutation - DenseIntElementsAttr perm; - if (!matchPattern(op.getPerm(), m_Constant(&perm))) { + // Get the permutation attribute + DenseIntElementsAttr permAttr; + if (!matchPattern(op.getPerm(), m_Constant(&permAttr))) { return failure(); } + auto perm = permAttr.getValues(); - // Confirm transpose permutation is 0,2,3,1 i.e., NWCH - // Remnants of Pytorch to TFlite conversion - auto permVal = perm.getValues(); - if (perm.size() != 4 || permVal[0] != 0 || permVal[1] != 2 || - permVal[2] != 3 || permVal[3] != 1) { + // Get the padding attribute + DenseIntElementsAttr padAttr; + if (!matchPattern(padOp.getPadding(), m_Constant(&padAttr))) { return failure(); } + auto padValues = padAttr.getValues(); - // Get padding val - DenseIntElementsAttr pad; - if (!matchPattern(padOp.getPadding(), m_Constant(&pad))) { + // Get the rank of the tensor + auto padInputType = padOp.getInput().getType().dyn_cast(); + if (!padInputType) { return failure(); } + int64_t rank = padInputType.getRank(); - // Confirm padding is only in last two dimensions - auto padVal = pad.getValues(); - if (padVal[{0, 0}] != 0 || padVal[{0, 1}] != 0 || padVal[{1, 0}] != 0 || - padVal[{1, 1}] != 0 || padVal[{2, 0}] != 1 || padVal[{2, 1}] != 1 || - padVal[{3, 0}] != 1 || padVal[{3, 1}] != 1) { - return failure(); + // Reshape the padding values into a matrix of shape [rank, 2] + SmallVector paddingMatrix; + paddingMatrix.reserve(padValues.size()); + for (int64_t i = 0; i < padValues.size(); ++i) { + paddingMatrix.push_back(padValues[i]); + } + + // Create a mapping from old dimensions to new dimensions after transpose + SmallVector inversePerm(rank); + for (int64_t i = 0; i < rank; ++i) { + inversePerm[perm[i]] = i; + } + + // Permute the padding according to the inverse permutation + SmallVector newPaddingValues; + newPaddingValues.reserve(paddingMatrix.size()); + for (int64_t i = 0; i < rank; ++i) { + int32_t dim = inversePerm[i]; + newPaddingValues.push_back(paddingMatrix[dim * 2]); + newPaddingValues.push_back(paddingMatrix[dim * 2 + 1]); } - // Create new TransposeOp - auto padInputShape = - padOp.getInput().getType().cast().getShape(); - auto tranposeResultType = RankedTensorType::get( - {padInputShape[0], padInputShape[2], padInputShape[3], - padInputShape[1]}, - padOp.getInput().getType().cast().getElementType()); + // Create new TransposeOp before PadOp's input + auto newTransposeType = padOp.getInput().getType(); auto newTranspose = rewriter.create( - padOp.getLoc(), tranposeResultType, padOp.getInput(), op.getPerm()); - - // Create new padding attr with spatial dimensions - std::vector paddingValues{0, 0, 1, 1, 1, 1, 0, 0}; - auto paddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({4, 2}, rewriter.getI32Type()), paddingValues); - auto paddingOp = rewriter.create( - padOp->getLoc(), RankedTensorType::get({4, 2}, rewriter.getI32Type()), - paddingAttr); - auto newPad = rewriter.create( - padOp.getLoc(), op.getOutput().getType(), newTranspose, paddingOp); - - rewriter.replaceOp(op, newPad.getOutput()); + padOp.getLoc(), newTransposeType, padOp.getInput(), op.getPerm()); + + // Create new padding constant + auto newPaddingAttr = DenseIntElementsAttr::get( + RankedTensorType::get({rank, 2}, rewriter.getI32Type()), + newPaddingValues); + auto newPaddingConst = rewriter.create( + padOp.getLoc(), newPaddingAttr.getType(), newPaddingAttr); + + // Create new PadOp after TransposeOp + auto newPadType = op.getType(); + auto newPad = rewriter.create(padOp.getLoc(), newPadType, + newTranspose, newPaddingConst); + + rewriter.replaceOp(op, newPad.getResult()); return success(); } }; + struct FoldCancellableTransposePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TFL::TransposeOp op, + LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp, PatternRewriter &rewriter) const override { - // Check for invalid types and return - // Defining op must be transpose - auto transposeOp = - dyn_cast_or_null(op.getInput().getDefiningOp()); - if (!transposeOp) { + auto inputTransposeOp = + transposeOp.getInput().getDefiningOp(); + if (!inputTransposeOp) return failure(); - } - // Get transpose permutations - DenseIntElementsAttr perm0; - DenseIntElementsAttr perm1; - if (!matchPattern(op.getPerm(), m_Constant(&perm0)) || - !matchPattern(transposeOp.getPerm(), m_Constant(&perm1))) { + // Check if permutations are inverses + DenseIntElementsAttr perm1, perm2; + if (!matchPattern(transposeOp.getPerm(), m_Constant(&perm1)) || + !matchPattern(inputTransposeOp.getPerm(), m_Constant(&perm2))) return failure(); - } - // Do permutation indices cancel each other? - if (!TF::AreCancellablePermutations(perm0, perm1)) { + if (!TF::AreCancellablePermutations(perm1, perm2)) return failure(); - } - rewriter.replaceOp(op, transposeOp.getInput()); + // Replace the outer transpose with the input of the inner transpose + rewriter.replaceOp(transposeOp, inputTransposeOp.getInput()); + + // Erase the inner transpose if it has no more uses + if (inputTransposeOp.use_empty()) + rewriter.eraseOp(inputTransposeOp); return success(); } }; + struct FoldTransposeWCHToInput : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -176,9 +392,20 @@ struct FoldTransposeWCHToInput : public OpRewritePattern { void OptimizeTranspose::runOnOperation() { auto *ctx = &getContext(); func::FuncOp func = getOperation(); + + // Try to merge transpose -> ops -> inverse transpose + RewritePatternSet mergePatterns(ctx); + mergePatterns.insert(ctx); + if (mergeTransposeOption) { + (void)applyPatternsAndFoldGreedily(func, std::move(mergePatterns)); + } + + // Other transpose optimizations RewritePatternSet patterns(ctx); - patterns.insert(ctx); + patterns.insert(ctx); patterns.insert(ctx); if (allowInputModificationOption) { patterns.insert(ctx); diff --git a/xformer/Transforms/Options.h b/xformer/Transforms/Options.h index 3b82763c6..811f969ce 100644 --- a/xformer/Transforms/Options.h +++ b/xformer/Transforms/Options.h @@ -26,6 +26,7 @@ extern llvm::cl::list opSplitBottomOpsOption; extern llvm::cl::list opSplitTopOpsOption; extern llvm::cl::list opSplitNumSplitsOption; extern llvm::cl::opt allowInputModificationOption; +extern llvm::cl::opt mergeTransposeOption; extern llvm::cl::opt convDebugOption; extern llvm::cl::opt overlapConvOption; extern llvm::cl::opt offlineOffsetsOption; diff --git a/xformer/XCoreOptMain.cpp b/xformer/XCoreOptMain.cpp index da9297a96..48663dbb6 100644 --- a/xformer/XCoreOptMain.cpp +++ b/xformer/XCoreOptMain.cpp @@ -159,6 +159,11 @@ cl::opt allowInputModificationOption( cl::desc("Allow the compiler to modify input tensor for optimizations."), cl::init(false), cl::cat(XformerCategory), cl::Hidden); +cl::opt mergeTransposeOption( + "xcore-merge-transpose", + cl::desc("Try to merge transpose and inverse transpose together."), + cl::init(true), cl::cat(XformerCategory), cl::Hidden); + cl::opt convDebugOption("xcore-conv-debug", cl::desc("Enable conv debug prints."), cl::init(false), cl::cat(XformerCategory), From 1c50e793b3c2249ab68f757577daf85e05170370 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 16 Oct 2024 19:18:46 +0100 Subject: [PATCH 02/15] fix 1 --- xformer/Transforms/OptimizeTranspose.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index 7bda60e45..a3d752485 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -104,7 +104,7 @@ struct MoveTransposeForwardOverUnaryOpPattern // Create a new Transpose operation after the unary operation auto newTransposeOp = rewriter.create( - transposeOp.getLoc(), transposeOutputType, newUnaryOpResult, perm); + transposeOp.getLoc(), newUnaryOutputType, newUnaryOpResult, perm); // Replace the original user operation's result with the new transpose // result From f2499facb15b2c3c87c81e570adf165b2e2a566d Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 17 Oct 2024 09:05:31 +0100 Subject: [PATCH 03/15] fix? --- xformer/Transforms/OptimizeTranspose.cpp | 65 ++++++++++++------------ 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index a3d752485..6e62a077e 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -117,6 +117,39 @@ struct MoveTransposeForwardOverUnaryOpPattern } }; +struct FoldCancellableTransposePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::TransposeOp op, + PatternRewriter &rewriter) const override { + + // Check for invalid types and return + // Defining op must be transpose + auto transposeOp = + dyn_cast_or_null(op.getInput().getDefiningOp()); + if (!transposeOp) { + return failure(); + } + + // Get transpose permutations + DenseIntElementsAttr perm0; + DenseIntElementsAttr perm1; + if (!matchPattern(op.getPerm(), m_Constant(&perm0)) || + !matchPattern(transposeOp.getPerm(), m_Constant(&perm1))) { + return failure(); + } + + // Do permutation indices cancel each other? + if (!TF::AreCancellablePermutations(perm0, perm1)) { + return failure(); + } + + rewriter.replaceOp(op, transposeOp.getInput()); + + return success(); + } +}; struct MoveTransposeForwardOverConcatOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -308,38 +341,6 @@ struct HoistTransposeAbovePadPattern } }; -struct FoldCancellableTransposePattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp, - PatternRewriter &rewriter) const override { - - auto inputTransposeOp = - transposeOp.getInput().getDefiningOp(); - if (!inputTransposeOp) - return failure(); - - // Check if permutations are inverses - DenseIntElementsAttr perm1, perm2; - if (!matchPattern(transposeOp.getPerm(), m_Constant(&perm1)) || - !matchPattern(inputTransposeOp.getPerm(), m_Constant(&perm2))) - return failure(); - - if (!TF::AreCancellablePermutations(perm1, perm2)) - return failure(); - - // Replace the outer transpose with the input of the inner transpose - rewriter.replaceOp(transposeOp, inputTransposeOp.getInput()); - - // Erase the inner transpose if it has no more uses - if (inputTransposeOp.use_empty()) - rewriter.eraseOp(inputTransposeOp); - - return success(); - } -}; - struct FoldTransposeWCHToInput : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; From 0c07cc12502ec050f24233ae70925237131d1e13 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 17 Oct 2024 09:33:57 +0100 Subject: [PATCH 04/15] fix --- xformer/Transforms/OptimizeTranspose.cpp | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index 6e62a077e..c4bc0cac2 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -104,7 +104,7 @@ struct MoveTransposeForwardOverUnaryOpPattern // Create a new Transpose operation after the unary operation auto newTransposeOp = rewriter.create( - transposeOp.getLoc(), newUnaryOutputType, newUnaryOpResult, perm); + transposeOp.getLoc(), originalUnaryOutputType, newUnaryOpResult, perm); // Replace the original user operation's result with the new transpose // result @@ -205,7 +205,7 @@ struct MoveTransposeForwardOverConcatOpPattern if (oldAxis < 0 || oldAxis >= rank) { return failure(); // Invalid axis } - int32_t newAxis = invPerm[oldAxis]; + int32_t newAxis = permVec[oldAxis]; // Collect input types and compute the new result type SmallVector inputTypes; @@ -229,14 +229,7 @@ struct MoveTransposeForwardOverConcatOpPattern inputTypes[0].getShape().end()); for (size_t i = 1; i < inputTypes.size(); ++i) { auto shape = inputTypes[i].getShape(); - for (int64_t dim = 0; dim < rank; ++dim) { - if (dim == newAxis) { - resultShape[dim] += shape[dim]; - } else if (resultShape[dim] != shape[dim]) { - // Dimensions must be equal except for the concatenation axis - return failure(); - } - } + resultShape[newAxis] += shape[newAxis]; } // Create the new ConcatenationOp with the correct result type and axis From aa7fd0aea2d4b71b28c52ac921b184d6e6038de6 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 17 Oct 2024 10:32:34 +0100 Subject: [PATCH 05/15] undo pad --- xformer/Transforms/OptimizeTranspose.cpp | 94 +++++++++++------------- 1 file changed, 42 insertions(+), 52 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index c4bc0cac2..f1a7141e3 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -257,79 +257,69 @@ struct MoveTransposeForwardOverConcatOpPattern } }; -struct HoistTransposeAbovePadPattern +struct HoistTransposeWCHAbovePadPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::TransposeOp op, PatternRewriter &rewriter) const override { - // Check if the input to TransposeOp is a PadOp + + // Check for invalid types and return + // Defining op must be pad auto padOp = dyn_cast_or_null(op.getInput().getDefiningOp()); if (!padOp) { return failure(); } - // Get the permutation attribute - DenseIntElementsAttr permAttr; - if (!matchPattern(op.getPerm(), m_Constant(&permAttr))) { + // Get transpose permutation + DenseIntElementsAttr perm; + if (!matchPattern(op.getPerm(), m_Constant(&perm))) { return failure(); } - auto perm = permAttr.getValues(); - // Get the padding attribute - DenseIntElementsAttr padAttr; - if (!matchPattern(padOp.getPadding(), m_Constant(&padAttr))) { + // Confirm transpose permutation is 0,2,3,1 i.e., NWCH + // Remnants of Pytorch to TFlite conversion + auto permVal = perm.getValues(); + if (perm.size() != 4 || permVal[0] != 0 || permVal[1] != 2 || + permVal[2] != 3 || permVal[3] != 1) { return failure(); } - auto padValues = padAttr.getValues(); - // Get the rank of the tensor - auto padInputType = padOp.getInput().getType().dyn_cast(); - if (!padInputType) { + // Get padding val + DenseIntElementsAttr pad; + if (!matchPattern(padOp.getPadding(), m_Constant(&pad))) { return failure(); } - int64_t rank = padInputType.getRank(); - - // Reshape the padding values into a matrix of shape [rank, 2] - SmallVector paddingMatrix; - paddingMatrix.reserve(padValues.size()); - for (int64_t i = 0; i < padValues.size(); ++i) { - paddingMatrix.push_back(padValues[i]); - } - - // Create a mapping from old dimensions to new dimensions after transpose - SmallVector inversePerm(rank); - for (int64_t i = 0; i < rank; ++i) { - inversePerm[perm[i]] = i; - } - // Permute the padding according to the inverse permutation - SmallVector newPaddingValues; - newPaddingValues.reserve(paddingMatrix.size()); - for (int64_t i = 0; i < rank; ++i) { - int32_t dim = inversePerm[i]; - newPaddingValues.push_back(paddingMatrix[dim * 2]); - newPaddingValues.push_back(paddingMatrix[dim * 2 + 1]); + // Confirm padding is only in last two dimensions + auto padVal = pad.getValues(); + if (padVal[{0, 0}] != 0 || padVal[{0, 1}] != 0 || padVal[{1, 0}] != 0 || + padVal[{1, 1}] != 0 || padVal[{2, 0}] != 1 || padVal[{2, 1}] != 1 || + padVal[{3, 0}] != 1 || padVal[{3, 1}] != 1) { + return failure(); } - // Create new TransposeOp before PadOp's input - auto newTransposeType = padOp.getInput().getType(); + // Create new TransposeOp + auto padInputShape = + padOp.getInput().getType().cast().getShape(); + auto tranposeResultType = RankedTensorType::get( + {padInputShape[0], padInputShape[2], padInputShape[3], + padInputShape[1]}, + padOp.getInput().getType().cast().getElementType()); auto newTranspose = rewriter.create( - padOp.getLoc(), newTransposeType, padOp.getInput(), op.getPerm()); - - // Create new padding constant - auto newPaddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({rank, 2}, rewriter.getI32Type()), - newPaddingValues); - auto newPaddingConst = rewriter.create( - padOp.getLoc(), newPaddingAttr.getType(), newPaddingAttr); - - // Create new PadOp after TransposeOp - auto newPadType = op.getType(); - auto newPad = rewriter.create(padOp.getLoc(), newPadType, - newTranspose, newPaddingConst); - - rewriter.replaceOp(op, newPad.getResult()); + padOp.getLoc(), tranposeResultType, padOp.getInput(), op.getPerm()); + + // Create new padding attr with spatial dimensions + std::vector paddingValues{0, 0, 1, 1, 1, 1, 0, 0}; + auto paddingAttr = DenseIntElementsAttr::get( + RankedTensorType::get({4, 2}, rewriter.getI32Type()), paddingValues); + auto paddingOp = rewriter.create( + padOp->getLoc(), RankedTensorType::get({4, 2}, rewriter.getI32Type()), + paddingAttr); + auto newPad = rewriter.create( + padOp.getLoc(), op.getOutput().getType(), newTranspose, paddingOp); + + rewriter.replaceOp(op, newPad.getOutput()); return success(); } }; @@ -399,7 +389,7 @@ void OptimizeTranspose::runOnOperation() { // Other transpose optimizations RewritePatternSet patterns(ctx); - patterns.insert(ctx); + patterns.insert(ctx); patterns.insert(ctx); if (allowInputModificationOption) { patterns.insert(ctx); From 7040f840df19dbf0b972d0910cf6994f28ec932a Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 23 Oct 2024 16:05:34 +0100 Subject: [PATCH 06/15] wip --- xformer/Transforms/OptimizeTranspose.cpp | 106 +++++++++++++++++++++-- xformer/Utils/Util.cpp | 1 - 2 files changed, 99 insertions(+), 8 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index f1a7141e3..b5aff23c9 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -24,6 +24,96 @@ struct OptimizeTranspose void runOnOperation() override; }; +struct FoldTransposeIntoFullyConnectedPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + // Match the pattern: fully_connected -> reshape -> transpose + auto reshapeOp = transposeOp.getInput().getDefiningOp(); + if (!reshapeOp || !reshapeOp->getResult(0).hasOneUse()) + return failure(); + + auto fullyConnectedOp = + reshapeOp.getInput().getDefiningOp(); + if (!fullyConnectedOp || !fullyConnectedOp->getResult(0).hasOneUse()) + return failure(); + + // Get types and shapes + auto fcInputType = + fullyConnectedOp.getInput().getType().dyn_cast(); + auto fcOutputType = + fullyConnectedOp.getResult(0).getType().dyn_cast(); + auto reshapeOutputType = + reshapeOp.getResult().getType().dyn_cast(); + auto transposeOutputType = + transposeOp.getResult().getType().dyn_cast(); + + if (!fcInputType || !fcOutputType || !reshapeOutputType || + !transposeOutputType) + return failure(); + + // Check if the batch dimension (assumed to be dimension 0) remains + // unchanged + auto fcOutputShape = fcOutputType.getShape(); + auto reshapeOutputShape = reshapeOutputType.getShape(); + + if (fcOutputShape.empty() || reshapeOutputShape.empty()) + return failure(); // Expecting non-scalar tensors + + if (fcOutputShape[0] != reshapeOutputShape[0]) + return failure(); // Batch dimension changed in reshape + + // Check if transpose does not affect the batch dimension + DenseIntElementsAttr permAttr; + if (!matchPattern(transposeOp.getPerm(), m_Constant(&permAttr))) + return failure(); + + SmallVector permVec; + for (auto val : permAttr.getValues()) { + permVec.push_back(static_cast(val)); + } + + // Check if batch dimension remains at position 0 after transpose + if (permVec.empty() || permVec[0] != 0) + return failure(); + + // Placeholder for transforming the filter and bias + // TODO: Adjust the filter and bias to account for the reshape and transpose + Value newFilter = fullyConnectedOp.getFilter(); + Value bias = fullyConnectedOp.getBias(); + + // match bias m_Constant and get its value + RankedTensorType biasAttr; + if (!matchPattern(bias, m_Constant(&biasAttr))) + return failure(); + llvm::outs() << "Bias: " << biasAttr << "\n"; + + Value newBias = fullyConnectedOp.getBias(); + + // Create new fully connected op with adjusted filter and bias + auto newFullyConnectedOp = rewriter.create( + fullyConnectedOp.getLoc(), + fcOutputType, // Adjusted output type if necessary + fullyConnectedOp.getInput(), newFilter, newBias, + fullyConnectedOp.getFusedActivationFunctionAttr(), + fullyConnectedOp.getWeightsFormatAttr(), + fullyConnectedOp.getKeepNumDimsAttr(), + fullyConnectedOp.getAsymmetricQuantizeInputsAttr()); + + // Create new reshape op with the output type of the original transpose op + auto newReshapeOp = rewriter.create( + reshapeOp.getLoc(), transposeOutputType, + newFullyConnectedOp.getResult(0), reshapeOp.getShape()); + + // Replace the original transpose op with the new reshape op + rewriter.replaceOp(transposeOp, newReshapeOp.getResult()); + + return success(); + } +}; + struct MoveTransposeForwardOverUnaryOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -36,7 +126,8 @@ struct MoveTransposeForwardOverUnaryOpPattern Operation *userOp = *transposeOp->getUsers().begin(); - // Check if the user operation is a unary op that can commute with transpose + // Check if the user operation is a unary op that can commute with + // transpose if (!isa(userOp)) return failure(); @@ -62,15 +153,15 @@ struct MoveTransposeForwardOverUnaryOpPattern if (!originalUnaryOutputType) return failure(); - // Create a new output type for the unary op with the same shape as 'input' - // and the same element type as the original output type + // Create a new output type for the unary op with the same shape as + // 'input' and the same element type as the original output type auto newUnaryOutputType = RankedTensorType::get( transposeInputType.getShape(), originalUnaryOutputType.getElementType(), originalUnaryOutputType.getEncoding()); if (auto quantizeOp = dyn_cast(userOp)) { - // For QuantizeOp, create new QuantizeOp with input as 'input' and output - // type adjusted + // For QuantizeOp, create new QuantizeOp with input as 'input' and + // output type adjusted // Create new QuantizeOp with adjusted output type newUnaryOpResult = rewriter.create( @@ -351,8 +442,8 @@ struct FoldTransposeWCHToInput : public OpRewritePattern { if (blockArg.hasOneUse()) { auto funcOp = cast(blockArg.getOwner()->getParentOp()); - // Set function type to the transpose output type as we are changing the - // input + // Set function type to the transpose output type as we are changing + // the input FunctionType funcType = funcOp.getFunctionType(); llvm::SmallVector newInputTypes(funcType.getInputs().begin(), funcType.getInputs().end()); @@ -391,6 +482,7 @@ void OptimizeTranspose::runOnOperation() { patterns.insert(ctx); patterns.insert(ctx); + patterns.insert(ctx); if (allowInputModificationOption) { patterns.insert(ctx); } diff --git a/xformer/Utils/Util.cpp b/xformer/Utils/Util.cpp index ac4cddea6..a1e3ae34a 100644 --- a/xformer/Utils/Util.cpp +++ b/xformer/Utils/Util.cpp @@ -2,7 +2,6 @@ // XMOS Public License: Version 1 #include "Utils/Util.h" -#include #include "mlir/Dialect/Quant/QuantTypes.h" #include "llvm/ADT/ArrayRef.h" From 5956b84453a534c0185f389e7a85421b8c8e7ad6 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 24 Oct 2024 12:08:03 +0100 Subject: [PATCH 07/15] wip --- xformer/Transforms/OptimizeTranspose.cpp | 205 +++++++++++++++++++++-- 1 file changed, 192 insertions(+), 13 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index b5aff23c9..2a862dfc4 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -79,25 +79,204 @@ struct FoldTransposeIntoFullyConnectedPattern if (permVec.empty() || permVec[0] != 0) return failure(); - // Placeholder for transforming the filter and bias - // TODO: Adjust the filter and bias to account for the reshape and transpose - Value newFilter = fullyConnectedOp.getFilter(); + // Exclude the batch dimension from permVec + SmallVector permVecExclBatch(permVec.begin() + 1, + permVec.end()); + + // Prepare to transform the filter and bias + Value filter = fullyConnectedOp.getFilter(); Value bias = fullyConnectedOp.getBias(); - // match bias m_Constant and get its value - RankedTensorType biasAttr; - if (!matchPattern(bias, m_Constant(&biasAttr))) - return failure(); - llvm::outs() << "Bias: " << biasAttr << "\n"; + // Process bias + { + // Ensure bias is produced by a TFL::QConstOp + auto biasQConstOp = bias.getDefiningOp(); + if (!biasQConstOp) + return failure(); + + // Get bias type and shape + auto biasType = bias.getType().dyn_cast(); + if (!biasType) + return failure(); + auto biasShape = biasType.getShape(); + + // Get reshape output shape excluding batch dimension + SmallVector reshapeShapeExclBatch( + reshapeOutputShape.begin() + 1, reshapeOutputShape.end()); + + // Compute total number of elements + int64_t biasNumElements = biasType.getNumElements(); + int64_t reshapeNumElements = std::accumulate( + reshapeShapeExclBatch.begin(), reshapeShapeExclBatch.end(), 1, + std::multiplies()); + + if (biasNumElements != reshapeNumElements) + return failure(); + + // Reshape bias to match reshape output shape (excluding batch dimension) + auto newShapeType = RankedTensorType::get( + {static_cast(reshapeShapeExclBatch.size())}, + rewriter.getI32Type()); + auto newShapeAttr = DenseIntElementsAttr::get( + newShapeType, llvm::makeArrayRef(reshapeShapeExclBatch)); + auto newShapeOp = rewriter.create( + bias.getLoc(), newShapeType, newShapeAttr); + + auto reshapedBiasType = RankedTensorType::get(reshapeShapeExclBatch, + biasType.getElementType()); + auto reshapedBias = rewriter.create( + bias.getLoc(), reshapedBiasType, bias, newShapeOp); + + // Create perm vector excluding batch dimension + auto permType = + RankedTensorType::get({static_cast(permVecExclBatch.size())}, + rewriter.getI32Type()); + auto permAttr = DenseIntElementsAttr::get( + permType, llvm::makeArrayRef(permVecExclBatch)); + auto permOp = rewriter.create(transposeOp.getLoc(), + permType, permAttr); + + // Compute transposed shape + SmallVector transposedShape; + for (auto idx : permVecExclBatch) { + transposedShape.push_back(reshapeShapeExclBatch[idx]); + } + auto transposedBiasType = + RankedTensorType::get(transposedShape, biasType.getElementType()); + + // Transpose the reshaped bias + auto transposedBias = rewriter.create( + bias.getLoc(), transposedBiasType, reshapedBias, permOp); + + // Reshape back to original bias shape + auto origBiasShapeType = RankedTensorType::get( + {static_cast(biasShape.size())}, rewriter.getI32Type()); + auto origBiasShapeAttr = DenseIntElementsAttr::get( + origBiasShapeType, llvm::makeArrayRef(biasShape)); + auto origBiasShapeConst = rewriter.create( + bias.getLoc(), origBiasShapeType, origBiasShapeAttr); + + auto finalBias = rewriter.create( + bias.getLoc(), biasType, transposedBias, origBiasShapeConst); + + // Update bias + bias = finalBias.getResult(); + } - Value newBias = fullyConnectedOp.getBias(); + // Process filter + { + // Ensure filter is produced by a TFL::QConstOp + auto filterQConstOp = filter.getDefiningOp(); + if (!filterQConstOp) + return failure(); + + // Get filter type and shape + auto filterType = filter.getType().dyn_cast(); + if (!filterType) + return failure(); + auto filterShape = filterType.getShape(); + + // Treat columns (first axis) as batch + SmallVector filterShapeExclBatch(filterShape.begin() + 1, + filterShape.end()); + + // Get reshape output shape excluding batch dimension + SmallVector reshapeShapeExclBatch( + reshapeOutputShape.begin() + 1, reshapeOutputShape.end()); + + // Compute total number of elements excluding batch dimension + int64_t filterNumElements = std::accumulate(filterShapeExclBatch.begin(), + filterShapeExclBatch.end(), 1, + std::multiplies()); + int64_t reshapeNumElements = std::accumulate( + reshapeShapeExclBatch.begin(), reshapeShapeExclBatch.end(), 1, + std::multiplies()); + + if (filterNumElements != reshapeNumElements) + return failure(); + + // Reshape filter to match reshape output shape (excluding batch + // dimension) + auto newShapeType = RankedTensorType::get( + {static_cast(reshapeShapeExclBatch.size())}, + rewriter.getI32Type()); + auto newShapeAttr = DenseIntElementsAttr::get( + newShapeType, llvm::makeArrayRef(reshapeShapeExclBatch)); + auto newShapeOp = rewriter.create( + filter.getLoc(), newShapeType, newShapeAttr); + + auto reshapedFilterType = RankedTensorType::get( + reshapeShapeExclBatch, filterType.getElementType()); + auto reshapedFilter = rewriter.create( + filter.getLoc(), reshapedFilterType, filter, newShapeOp); + + // Create perm vector excluding batch dimension + auto permType = + RankedTensorType::get({static_cast(permVecExclBatch.size())}, + rewriter.getI32Type()); + auto permAttr = DenseIntElementsAttr::get( + permType, llvm::makeArrayRef(permVecExclBatch)); + auto permOp = rewriter.create(transposeOp.getLoc(), + permType, permAttr); + + // Compute transposed shape + SmallVector transposedShape; + for (auto idx : permVecExclBatch) { + transposedShape.push_back(reshapeShapeExclBatch[idx]); + } + auto transposedFilterType = + RankedTensorType::get(transposedShape, filterType.getElementType()); + + // Transpose the reshaped filter + auto transposedFilter = rewriter.create( + filter.getLoc(), transposedFilterType, reshapedFilter, permOp); + + // Reshape back to original filter shape + SmallVector origFilterShapeExclBatch(filterShape.begin() + 1, + filterShape.end()); + auto origFilterShapeType = RankedTensorType::get( + {static_cast(origFilterShapeExclBatch.size())}, + rewriter.getI32Type()); + auto origFilterShapeAttr = DenseIntElementsAttr::get( + origFilterShapeType, llvm::makeArrayRef(origFilterShapeExclBatch)); + auto origFilterShapeConst = rewriter.create( + filter.getLoc(), origFilterShapeType, origFilterShapeAttr); + + // Reshape back to original filter shape (excluding batch dimension) + auto finalFilterType = RankedTensorType::get(filterShapeExclBatch, + filterType.getElementType()); + auto finalFilter = rewriter.create( + filter.getLoc(), finalFilterType, transposedFilter, + origFilterShapeConst); + + // Prepend the batch dimension back to the filter shape + SmallVector newFilterShape; + newFilterShape.push_back(filterShape[0]); // Batch dimension + newFilterShape.append(filterShapeExclBatch.begin(), + filterShapeExclBatch.end()); + + auto newFilterType = + RankedTensorType::get(newFilterShape, filterType.getElementType()); + + // Create a reshape to get back to the original filter shape + auto finalFilterShapeType = RankedTensorType::get( + {static_cast(newFilterShape.size())}, rewriter.getI32Type()); + auto finalFilterShapeAttr = DenseIntElementsAttr::get( + finalFilterShapeType, llvm::makeArrayRef(newFilterShape)); + auto finalFilterShapeConst = rewriter.create( + filter.getLoc(), finalFilterShapeType, finalFilterShapeAttr); + + auto finalFilterReshaped = rewriter.create( + filter.getLoc(), filterType, finalFilter, finalFilterShapeConst); + + // Update filter + filter = finalFilterReshaped.getResult(); + } // Create new fully connected op with adjusted filter and bias auto newFullyConnectedOp = rewriter.create( - fullyConnectedOp.getLoc(), - fcOutputType, // Adjusted output type if necessary - fullyConnectedOp.getInput(), newFilter, newBias, - fullyConnectedOp.getFusedActivationFunctionAttr(), + fullyConnectedOp.getLoc(), fcOutputType, fullyConnectedOp.getInput(), + filter, bias, fullyConnectedOp.getFusedActivationFunctionAttr(), fullyConnectedOp.getWeightsFormatAttr(), fullyConnectedOp.getKeepNumDimsAttr(), fullyConnectedOp.getAsymmetricQuantizeInputsAttr()); From c2e6e5c07cc9513207246bd1921b2f815831f2be Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 24 Oct 2024 17:08:18 +0100 Subject: [PATCH 08/15] pass works but no canonicalisation and segfault --- xformer/Transforms/OptimizeTranspose.cpp | 195 +++++------------------ xformer/Utils/Util.cpp | 78 +++++++++ xformer/Utils/Util.h | 14 ++ 3 files changed, 129 insertions(+), 158 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index 2a862dfc4..07b594033 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -3,6 +3,7 @@ #include "Transforms/Options.h" +#include "Utils/Util.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -58,6 +59,12 @@ struct FoldTransposeIntoFullyConnectedPattern // unchanged auto fcOutputShape = fcOutputType.getShape(); auto reshapeOutputShape = reshapeOutputType.getShape(); + SmallVector reshapeOutputShapeVec(reshapeOutputShape.begin(), + reshapeOutputShape.end()); + + if (reshapeOutputShape[0] != 1) { + return failure(); + } if (fcOutputShape.empty() || reshapeOutputShape.empty()) return failure(); // Expecting non-scalar tensors @@ -79,10 +86,6 @@ struct FoldTransposeIntoFullyConnectedPattern if (permVec.empty() || permVec[0] != 0) return failure(); - // Exclude the batch dimension from permVec - SmallVector permVecExclBatch(permVec.begin() + 1, - permVec.end()); - // Prepare to transform the filter and bias Value filter = fullyConnectedOp.getFilter(); Value bias = fullyConnectedOp.getBias(); @@ -100,67 +103,15 @@ struct FoldTransposeIntoFullyConnectedPattern return failure(); auto biasShape = biasType.getShape(); - // Get reshape output shape excluding batch dimension - SmallVector reshapeShapeExclBatch( - reshapeOutputShape.begin() + 1, reshapeOutputShape.end()); - - // Compute total number of elements - int64_t biasNumElements = biasType.getNumElements(); - int64_t reshapeNumElements = std::accumulate( - reshapeShapeExclBatch.begin(), reshapeShapeExclBatch.end(), 1, - std::multiplies()); - - if (biasNumElements != reshapeNumElements) + SmallVector biasShapeVec(biasShape.begin(), biasShape.end()); + Value finalBias; + if (failed(utils::reshapeTransposeReshape(rewriter, bias, + reshapeOutputShapeVec, permVec, + biasShapeVec, finalBias))) return failure(); - // Reshape bias to match reshape output shape (excluding batch dimension) - auto newShapeType = RankedTensorType::get( - {static_cast(reshapeShapeExclBatch.size())}, - rewriter.getI32Type()); - auto newShapeAttr = DenseIntElementsAttr::get( - newShapeType, llvm::makeArrayRef(reshapeShapeExclBatch)); - auto newShapeOp = rewriter.create( - bias.getLoc(), newShapeType, newShapeAttr); - - auto reshapedBiasType = RankedTensorType::get(reshapeShapeExclBatch, - biasType.getElementType()); - auto reshapedBias = rewriter.create( - bias.getLoc(), reshapedBiasType, bias, newShapeOp); - - // Create perm vector excluding batch dimension - auto permType = - RankedTensorType::get({static_cast(permVecExclBatch.size())}, - rewriter.getI32Type()); - auto permAttr = DenseIntElementsAttr::get( - permType, llvm::makeArrayRef(permVecExclBatch)); - auto permOp = rewriter.create(transposeOp.getLoc(), - permType, permAttr); - - // Compute transposed shape - SmallVector transposedShape; - for (auto idx : permVecExclBatch) { - transposedShape.push_back(reshapeShapeExclBatch[idx]); - } - auto transposedBiasType = - RankedTensorType::get(transposedShape, biasType.getElementType()); - - // Transpose the reshaped bias - auto transposedBias = rewriter.create( - bias.getLoc(), transposedBiasType, reshapedBias, permOp); - - // Reshape back to original bias shape - auto origBiasShapeType = RankedTensorType::get( - {static_cast(biasShape.size())}, rewriter.getI32Type()); - auto origBiasShapeAttr = DenseIntElementsAttr::get( - origBiasShapeType, llvm::makeArrayRef(biasShape)); - auto origBiasShapeConst = rewriter.create( - bias.getLoc(), origBiasShapeType, origBiasShapeAttr); - - auto finalBias = rewriter.create( - bias.getLoc(), biasType, transposedBias, origBiasShapeConst); - // Update bias - bias = finalBias.getResult(); + bias = finalBias; } // Process filter @@ -175,102 +126,22 @@ struct FoldTransposeIntoFullyConnectedPattern if (!filterType) return failure(); auto filterShape = filterType.getShape(); - - // Treat columns (first axis) as batch - SmallVector filterShapeExclBatch(filterShape.begin() + 1, - filterShape.end()); - - // Get reshape output shape excluding batch dimension - SmallVector reshapeShapeExclBatch( - reshapeOutputShape.begin() + 1, reshapeOutputShape.end()); - - // Compute total number of elements excluding batch dimension - int64_t filterNumElements = std::accumulate(filterShapeExclBatch.begin(), - filterShapeExclBatch.end(), 1, - std::multiplies()); - int64_t reshapeNumElements = std::accumulate( - reshapeShapeExclBatch.begin(), reshapeShapeExclBatch.end(), 1, - std::multiplies()); - - if (filterNumElements != reshapeNumElements) + SmallVector filterShapeVec(filterShape.begin(), + filterShape.end()); + + // same as the shape of the reshape, except for first dimension which + // should be the first dimension of the filterShape + SmallVector filterOutShapeVec = {filterShape[1]}; + filterOutShapeVec.insert(filterOutShapeVec.end(), + reshapeOutputShapeVec.begin() + 1, + reshapeOutputShapeVec.end()); + + Value finalFilter; + if (failed(utils::reshapeTransposeReshape(rewriter, filter, + filterOutShapeVec, permVec, + filterShapeVec, finalFilter))) return failure(); - - // Reshape filter to match reshape output shape (excluding batch - // dimension) - auto newShapeType = RankedTensorType::get( - {static_cast(reshapeShapeExclBatch.size())}, - rewriter.getI32Type()); - auto newShapeAttr = DenseIntElementsAttr::get( - newShapeType, llvm::makeArrayRef(reshapeShapeExclBatch)); - auto newShapeOp = rewriter.create( - filter.getLoc(), newShapeType, newShapeAttr); - - auto reshapedFilterType = RankedTensorType::get( - reshapeShapeExclBatch, filterType.getElementType()); - auto reshapedFilter = rewriter.create( - filter.getLoc(), reshapedFilterType, filter, newShapeOp); - - // Create perm vector excluding batch dimension - auto permType = - RankedTensorType::get({static_cast(permVecExclBatch.size())}, - rewriter.getI32Type()); - auto permAttr = DenseIntElementsAttr::get( - permType, llvm::makeArrayRef(permVecExclBatch)); - auto permOp = rewriter.create(transposeOp.getLoc(), - permType, permAttr); - - // Compute transposed shape - SmallVector transposedShape; - for (auto idx : permVecExclBatch) { - transposedShape.push_back(reshapeShapeExclBatch[idx]); - } - auto transposedFilterType = - RankedTensorType::get(transposedShape, filterType.getElementType()); - - // Transpose the reshaped filter - auto transposedFilter = rewriter.create( - filter.getLoc(), transposedFilterType, reshapedFilter, permOp); - - // Reshape back to original filter shape - SmallVector origFilterShapeExclBatch(filterShape.begin() + 1, - filterShape.end()); - auto origFilterShapeType = RankedTensorType::get( - {static_cast(origFilterShapeExclBatch.size())}, - rewriter.getI32Type()); - auto origFilterShapeAttr = DenseIntElementsAttr::get( - origFilterShapeType, llvm::makeArrayRef(origFilterShapeExclBatch)); - auto origFilterShapeConst = rewriter.create( - filter.getLoc(), origFilterShapeType, origFilterShapeAttr); - - // Reshape back to original filter shape (excluding batch dimension) - auto finalFilterType = RankedTensorType::get(filterShapeExclBatch, - filterType.getElementType()); - auto finalFilter = rewriter.create( - filter.getLoc(), finalFilterType, transposedFilter, - origFilterShapeConst); - - // Prepend the batch dimension back to the filter shape - SmallVector newFilterShape; - newFilterShape.push_back(filterShape[0]); // Batch dimension - newFilterShape.append(filterShapeExclBatch.begin(), - filterShapeExclBatch.end()); - - auto newFilterType = - RankedTensorType::get(newFilterShape, filterType.getElementType()); - - // Create a reshape to get back to the original filter shape - auto finalFilterShapeType = RankedTensorType::get( - {static_cast(newFilterShape.size())}, rewriter.getI32Type()); - auto finalFilterShapeAttr = DenseIntElementsAttr::get( - finalFilterShapeType, llvm::makeArrayRef(newFilterShape)); - auto finalFilterShapeConst = rewriter.create( - filter.getLoc(), finalFilterShapeType, finalFilterShapeAttr); - - auto finalFilterReshaped = rewriter.create( - filter.getLoc(), filterType, finalFilter, finalFilterShapeConst); - - // Update filter - filter = finalFilterReshaped.getResult(); + filter = finalFilter; } // Create new fully connected op with adjusted filter and bias @@ -281,10 +152,18 @@ struct FoldTransposeIntoFullyConnectedPattern fullyConnectedOp.getKeepNumDimsAttr(), fullyConnectedOp.getAsymmetricQuantizeInputsAttr()); + // create new shape from the shape of the original transpose op + auto originalShape = + transposeOp.getResult().getType().cast().getShape(); + SmallVector originalShapeVec(originalShape.begin(), + originalShape.end()); + Value newShapeConstOp = utils::createShapeConstOp( + rewriter, transposeOp.getLoc(), originalShapeVec); + // Create new reshape op with the output type of the original transpose op auto newReshapeOp = rewriter.create( reshapeOp.getLoc(), transposeOutputType, - newFullyConnectedOp.getResult(0), reshapeOp.getShape()); + newFullyConnectedOp.getResult(0), newShapeConstOp); // Replace the original transpose op with the new reshape op rewriter.replaceOp(transposeOp, newReshapeOp.getResult()); diff --git a/xformer/Utils/Util.cpp b/xformer/Utils/Util.cpp index a1e3ae34a..39981a761 100644 --- a/xformer/Utils/Util.cpp +++ b/xformer/Utils/Util.cpp @@ -138,4 +138,82 @@ int mergeAxes(std::vector &begin, std::vector &size, return rank; } +// Converts int64_t vector to int32_t vector, returns failure if any value is +// out of int32_t range. +LogicalResult convertToI32Array(const SmallVectorImpl &input, + SmallVectorImpl &output) { + for (auto val : input) { + if (val > std::numeric_limits::max() || + val < std::numeric_limits::min()) + return failure(); + output.push_back(static_cast(val)); + } + return success(); +} + +// Creates a constant op for a shape vector. +Value createShapeConstOp(PatternRewriter &rewriter, Location loc, + const SmallVectorImpl &shapeVec) { + SmallVector shapeVecI32; + if (failed(convertToI32Array(shapeVec, shapeVecI32))) + return nullptr; + auto shapeType = RankedTensorType::get( + {static_cast(shapeVecI32.size())}, rewriter.getI32Type()); + auto shapeAttr = DenseIntElementsAttr::get(shapeType, shapeVecI32); + return rewriter.create(loc, shapeType, shapeAttr); +} + +// Helper function for reshape-transpose-reshape pattern. +LogicalResult +reshapeTransposeReshape(PatternRewriter &rewriter, Value tensor, + const SmallVectorImpl &reshapeShape, + const SmallVectorImpl &permVec, + const SmallVectorImpl &origShape, + Value &result) { + auto loc = tensor.getLoc(); + auto tensorType = tensor.getType().cast(); + auto elementType = tensorType.getElementType(); + + // Reshape tensor to reshapeShapeExclBatch. + Value newShapeOp = createShapeConstOp(rewriter, loc, reshapeShape); + if (!newShapeOp) + return failure(); + auto reshapedType = RankedTensorType::get(reshapeShape, elementType); + auto reshapedTensor = + rewriter.create(loc, reshapedType, tensor, newShapeOp); + + // Convert permVecExclBatch to int32_t vector. + SmallVector permVecI32; + if (failed(convertToI32Array(permVec, permVecI32))) + return failure(); + + // Create perm op. + auto permType = RankedTensorType::get( + {static_cast(permVecI32.size())}, rewriter.getI32Type()); + auto permAttr = DenseIntElementsAttr::get(permType, permVecI32); + auto permOp = rewriter.create(loc, permType, permAttr); + + // Compute transposed shape. + SmallVector transposedShape; + for (auto idx : permVec) { + if (idx < 0 || idx >= reshapeShape.size()) + return failure(); + transposedShape.push_back(reshapeShape[idx]); + } + auto transposedType = RankedTensorType::get(transposedShape, elementType); + + // Transpose. + auto transposedTensor = rewriter.create( + loc, transposedType, reshapedTensor, permOp); + + // Reshape back to original shape. + Value origShapeOp = createShapeConstOp(rewriter, loc, origShape); + if (!origShapeOp) + return failure(); + auto finalTensor = rewriter.create( + loc, tensorType, transposedTensor, origShapeOp); + + result = finalTensor.getResult(); + return success(); +} } // namespace mlir::xcore::utils diff --git a/xformer/Utils/Util.h b/xformer/Utils/Util.h index 7e29f1f34..d2296616c 100644 --- a/xformer/Utils/Util.h +++ b/xformer/Utils/Util.h @@ -7,6 +7,7 @@ #include "mlir/Dialect/Quant/QuantTypes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir::xcore::utils { @@ -69,6 +70,19 @@ template bool checkBinaryCompatibility(T op) { int mergeAxes(std::vector &begin, std::vector &size, std::vector &inShape, std::vector &outShape, int rank); + +LogicalResult convertToI32Array(const SmallVectorImpl &input, + SmallVectorImpl &output); + +Value createShapeConstOp(PatternRewriter &rewriter, Location loc, + const SmallVectorImpl &shapeVec); + +LogicalResult +reshapeTransposeReshape(PatternRewriter &rewriter, Value tensor, + const SmallVectorImpl &reshapeShape, + const SmallVectorImpl &permVec, + const SmallVectorImpl &origShape, + Value &result); } // namespace mlir::xcore::utils #endif // XFORMER_UTILS_UTIL_H From b88f814d2a6ef6a564c60b80a2ff0157c920fc93 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 24 Oct 2024 18:27:10 +0100 Subject: [PATCH 09/15] transpose -> reshape -> fc pass --- xformer/Transforms/OptimizeTranspose.cpp | 139 ++++++++++++++++++++++- 1 file changed, 136 insertions(+), 3 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index 07b594033..b301e47b1 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -25,8 +25,140 @@ struct OptimizeTranspose void runOnOperation() override; }; -struct FoldTransposeIntoFullyConnectedPattern - : public OpRewritePattern { +struct FoldTrReFCPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::FullyConnectedOp fcOp, + PatternRewriter &rewriter) const override { + // Match the pattern: transpose -> reshape -> fully_connected + + // Check that the input to fully_connected is a reshape + auto reshapeOp = fcOp.getInput().getDefiningOp(); + if (!reshapeOp || !reshapeOp->hasOneUse()) + return failure(); + + // Check that the input to reshape is a transpose + auto transposeOp = reshapeOp.getInput().getDefiningOp(); + if (!transposeOp || !transposeOp->hasOneUse()) + return failure(); + + // Ensure that the fully_connected op has a single use + if (!fcOp->hasOneUse()) + return failure(); + + // Get permutation attribute from transpose + DenseIntElementsAttr permAttr; + if (!matchPattern(transposeOp.getPerm(), m_Constant(&permAttr))) + return failure(); + + SmallVector permVec; + for (auto val : permAttr.getValues()) + permVec.push_back(static_cast(val)); + + // Get shapes + auto reshapeInputType = + reshapeOp.getInput().getType().cast(); + auto reshapeOutputType = + reshapeOp.getResult().getType().cast(); + if (!reshapeInputType || !reshapeOutputType) + return failure(); + + SmallVector reshapeInputShapeVec( + reshapeInputType.getShape().begin(), reshapeInputType.getShape().end()); + SmallVector reshapeOutputShapeVec( + reshapeOutputType.getShape().begin(), + reshapeOutputType.getShape().end()); + + // Transpose the filter weights + Value filter = fcOp.getFilter(); + Value bias = fcOp.getBias(); + + { + // Ensure filter is produced by a TFL::QConstOp + auto filterQConstOp = filter.getDefiningOp(); + if (!filterQConstOp) { + return failure(); + } + // Get filter type and shape + auto filterType = filter.getType().dyn_cast(); + if (!filterType) { + + return failure(); + } + auto filterShape = filterType.getShape(); + SmallVector filterShapeVec(filterShape.begin(), + filterShape.end()); + + // Compute the reshape shape for the filter + // The filter is typically of shape [output_size, input_size] + int64_t outputSize = filterShapeVec[0]; + int64_t inputSize = filterShapeVec[1]; + + // Reshape the filter to [output_size] + reshapeInputShapeVec + SmallVector reshapedFilterShapeVec( + reshapeInputShapeVec.begin() + 1, reshapeInputShapeVec.end()); + reshapedFilterShapeVec.push_back(outputSize); + + // Prepare permutation vector for the filter + // Since the data is transposed before reshape, we need to adjust the + // filter accordingly. The first dimension (output_size) remains + // unchanged. + // + SmallVector filterPermVec; + for (size_t i = 1; i < permVec.size(); i++) { + filterPermVec.push_back(permVec[i] - 1); + } + filterPermVec.push_back(permVec.size() - 1); + + // Original filter shape vector + SmallVector origFilterShapeVec(filterShapeVec.begin(), + filterShapeVec.end()); + + Value finalFilter; + if (failed(utils::reshapeTransposeReshape( + rewriter, filter, reshapedFilterShapeVec, filterPermVec, + origFilterShapeVec, finalFilter))) { + + return failure(); + } + + filter = finalFilter; + } + + // Bias remains the same; no need to adjust it + + // Create new reshape op (from transpose's input to reshape's output) + Value newInput = transposeOp.getInput(); + + // Create new shape const op for reshape + Value newShapeConstOp = utils::createShapeConstOp( + rewriter, reshapeOp.getLoc(), reshapeOutputShapeVec); + + auto newReshapeOp = rewriter.create( + reshapeOp.getLoc(), reshapeOp.getResult().getType(), newInput, + newShapeConstOp); + + // Create new fully_connected op with adjusted filter + auto newFullyConnectedOp = rewriter.create( + fcOp.getLoc(), fcOp.getResult(0).getType(), newReshapeOp.getResult(), + filter, bias, fcOp.getFusedActivationFunctionAttr(), + fcOp.getWeightsFormatAttr(), fcOp.getKeepNumDimsAttr(), + fcOp.getAsymmetricQuantizeInputsAttr()); + + // Replace the original fully_connected op with the new one + rewriter.replaceOp(fcOp, newFullyConnectedOp.getResults()); + + // Erase the old reshape and transpose ops if they are no longer used + if (reshapeOp.use_empty()) + rewriter.eraseOp(reshapeOp); + if (transposeOp.use_empty()) + rewriter.eraseOp(transposeOp); + + return success(); + } +}; + +struct FoldFCReTrPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp, @@ -540,7 +672,8 @@ void OptimizeTranspose::runOnOperation() { patterns.insert(ctx); patterns.insert(ctx); - patterns.insert(ctx); + patterns.insert(ctx); + patterns.insert(ctx); if (allowInputModificationOption) { patterns.insert(ctx); } From 6a580b8710d8b2c4cb399dc6666d10d8bbb3e697 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 28 Oct 2024 08:05:12 +0000 Subject: [PATCH 10/15] fix TrReFC --- xformer/Transforms/OptimizeTranspose.cpp | 45 +++++++++--------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index b301e47b1..e2cdb08cf 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -55,20 +55,18 @@ struct FoldTrReFCPattern : public OpRewritePattern { for (auto val : permAttr.getValues()) permVec.push_back(static_cast(val)); + // Compute inverse permutation vector + SmallVector invPermVec(permVec.size()); + for (size_t i = 0; i < permVec.size(); ++i) { + invPermVec[permVec[i]] = i; + } + // Get shapes auto reshapeInputType = reshapeOp.getInput().getType().cast(); - auto reshapeOutputType = - reshapeOp.getResult().getType().cast(); - if (!reshapeInputType || !reshapeOutputType) + if (!reshapeInputType) return failure(); - SmallVector reshapeInputShapeVec( - reshapeInputType.getShape().begin(), reshapeInputType.getShape().end()); - SmallVector reshapeOutputShapeVec( - reshapeOutputType.getShape().begin(), - reshapeOutputType.getShape().end()); - // Transpose the filter weights Value filter = fcOp.getFilter(); Value bias = fcOp.getBias(); @@ -82,7 +80,6 @@ struct FoldTrReFCPattern : public OpRewritePattern { // Get filter type and shape auto filterType = filter.getType().dyn_cast(); if (!filterType) { - return failure(); } auto filterShape = filterType.getShape(); @@ -90,25 +87,21 @@ struct FoldTrReFCPattern : public OpRewritePattern { filterShape.end()); // Compute the reshape shape for the filter - // The filter is typically of shape [output_size, input_size] int64_t outputSize = filterShapeVec[0]; - int64_t inputSize = filterShapeVec[1]; - // Reshape the filter to [output_size] + reshapeInputShapeVec + // Reshape the filter to reshapeInputShapeVec[1:] + [outputSize] SmallVector reshapedFilterShapeVec( - reshapeInputShapeVec.begin() + 1, reshapeInputShapeVec.end()); + reshapeInputType.getShape().begin() + 1, + reshapeInputType.getShape().end()); reshapedFilterShapeVec.push_back(outputSize); - // Prepare permutation vector for the filter - // Since the data is transposed before reshape, we need to adjust the - // filter accordingly. The first dimension (output_size) remains - // unchanged. - // + // Prepare inverse permutation vector for the filter + // Exclude the batch dimension SmallVector filterPermVec; - for (size_t i = 1; i < permVec.size(); i++) { - filterPermVec.push_back(permVec[i] - 1); + for (size_t i = 1; i < invPermVec.size(); ++i) { + filterPermVec.push_back(invPermVec[i] - 1); } - filterPermVec.push_back(permVec.size() - 1); + filterPermVec.push_back(invPermVec.size() - 1); // Original filter shape vector SmallVector origFilterShapeVec(filterShapeVec.begin(), @@ -118,7 +111,7 @@ struct FoldTrReFCPattern : public OpRewritePattern { if (failed(utils::reshapeTransposeReshape( rewriter, filter, reshapedFilterShapeVec, filterPermVec, origFilterShapeVec, finalFilter))) { - + llvm::outs() << "Failed to reshape filter\n"; return failure(); } @@ -130,13 +123,9 @@ struct FoldTrReFCPattern : public OpRewritePattern { // Create new reshape op (from transpose's input to reshape's output) Value newInput = transposeOp.getInput(); - // Create new shape const op for reshape - Value newShapeConstOp = utils::createShapeConstOp( - rewriter, reshapeOp.getLoc(), reshapeOutputShapeVec); - auto newReshapeOp = rewriter.create( reshapeOp.getLoc(), reshapeOp.getResult().getType(), newInput, - newShapeConstOp); + reshapeOp.getShape()); // Create new fully_connected op with adjusted filter auto newFullyConnectedOp = rewriter.create( From 54e3a113dc0d14af4d92f21638df9fe1bace9b96 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 28 Oct 2024 08:48:11 +0000 Subject: [PATCH 11/15] fix the other pass --- xformer/Transforms/OptimizeTranspose.cpp | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index e2cdb08cf..58551aba9 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -180,8 +180,9 @@ struct FoldFCReTrPattern : public OpRewritePattern { // unchanged auto fcOutputShape = fcOutputType.getShape(); auto reshapeOutputShape = reshapeOutputType.getShape(); - SmallVector reshapeOutputShapeVec(reshapeOutputShape.begin(), - reshapeOutputShape.end()); + auto transposeOutputShape = transposeOutputType.getShape(); + SmallVector transposeOutputShapeVec( + transposeOutputShape.begin(), transposeOutputShape.end()); if (reshapeOutputShape[0] != 1) { return failure(); @@ -198,13 +199,14 @@ struct FoldFCReTrPattern : public OpRewritePattern { if (!matchPattern(transposeOp.getPerm(), m_Constant(&permAttr))) return failure(); - SmallVector permVec; - for (auto val : permAttr.getValues()) { - permVec.push_back(static_cast(val)); + SmallVector invPermVec; + for (int32_t val : permAttr.getValues()) { + invPermVec.push_back( + static_cast(permAttr.getValues()[val])); } // Check if batch dimension remains at position 0 after transpose - if (permVec.empty() || permVec[0] != 0) + if (invPermVec.empty() || invPermVec[0] != 0) return failure(); // Prepare to transform the filter and bias @@ -226,9 +228,9 @@ struct FoldFCReTrPattern : public OpRewritePattern { SmallVector biasShapeVec(biasShape.begin(), biasShape.end()); Value finalBias; - if (failed(utils::reshapeTransposeReshape(rewriter, bias, - reshapeOutputShapeVec, permVec, - biasShapeVec, finalBias))) + if (failed(utils::reshapeTransposeReshape( + rewriter, bias, transposeOutputShapeVec, invPermVec, biasShapeVec, + finalBias))) return failure(); // Update bias @@ -254,12 +256,12 @@ struct FoldFCReTrPattern : public OpRewritePattern { // should be the first dimension of the filterShape SmallVector filterOutShapeVec = {filterShape[1]}; filterOutShapeVec.insert(filterOutShapeVec.end(), - reshapeOutputShapeVec.begin() + 1, - reshapeOutputShapeVec.end()); + transposeOutputShapeVec.begin() + 1, + transposeOutputShapeVec.end()); Value finalFilter; if (failed(utils::reshapeTransposeReshape(rewriter, filter, - filterOutShapeVec, permVec, + filterOutShapeVec, invPermVec, filterShapeVec, finalFilter))) return failure(); filter = finalFilter; From 76035de1c04b039b4633e5e59cf33b3b71003cfa Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 28 Oct 2024 09:19:37 +0000 Subject: [PATCH 12/15] cleanup --- xformer/Transforms/OptimizeTranspose.cpp | 104 +++++++++-------------- xformer/Utils/Util.cpp | 1 + xformer/Utils/Util.h | 9 ++ 3 files changed, 51 insertions(+), 63 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index 58551aba9..3907e7880 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -25,6 +25,15 @@ struct OptimizeTranspose void runOnOperation() override; }; +static SmallVector +computeInversePermutation(const SmallVector &permVec) { + SmallVector invPermVec(permVec.size()); + for (size_t i = 0; i < permVec.size(); ++i) { + invPermVec[permVec[i]] = i; + } + return invPermVec; +} + struct FoldTrReFCPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -51,19 +60,14 @@ struct FoldTrReFCPattern : public OpRewritePattern { if (!matchPattern(transposeOp.getPerm(), m_Constant(&permAttr))) return failure(); - SmallVector permVec; - for (auto val : permAttr.getValues()) - permVec.push_back(static_cast(val)); + SmallVector permVec = utils::denseToVector(permAttr); // Compute inverse permutation vector - SmallVector invPermVec(permVec.size()); - for (size_t i = 0; i < permVec.size(); ++i) { - invPermVec[permVec[i]] = i; - } + SmallVector invPermVec = computeInversePermutation(permVec); - // Get shapes + // Get input shape auto reshapeInputType = - reshapeOp.getInput().getType().cast(); + reshapeOp.getInput().getType().dyn_cast(); if (!reshapeInputType) return failure(); @@ -72,45 +76,34 @@ struct FoldTrReFCPattern : public OpRewritePattern { Value bias = fcOp.getBias(); { - // Ensure filter is produced by a TFL::QConstOp - auto filterQConstOp = filter.getDefiningOp(); - if (!filterQConstOp) { - return failure(); - } // Get filter type and shape auto filterType = filter.getType().dyn_cast(); - if (!filterType) { + if (!filterType) return failure(); - } + auto filterShape = filterType.getShape(); SmallVector filterShapeVec(filterShape.begin(), filterShape.end()); - // Compute the reshape shape for the filter int64_t outputSize = filterShapeVec[0]; - // Reshape the filter to reshapeInputShapeVec[1:] + [outputSize] + // Reshape the filter to reshapeInputShape[1:] + [outputSize] + auto reshapeInputShape = reshapeInputType.getShape(); SmallVector reshapedFilterShapeVec( - reshapeInputType.getShape().begin() + 1, - reshapeInputType.getShape().end()); + reshapeInputShape.begin() + 1, reshapeInputShape.end()); reshapedFilterShapeVec.push_back(outputSize); - // Prepare inverse permutation vector for the filter - // Exclude the batch dimension + // Prepare permutation vector for the filter (excluding batch dimension) SmallVector filterPermVec; for (size_t i = 1; i < invPermVec.size(); ++i) { filterPermVec.push_back(invPermVec[i] - 1); } filterPermVec.push_back(invPermVec.size() - 1); - // Original filter shape vector - SmallVector origFilterShapeVec(filterShapeVec.begin(), - filterShapeVec.end()); - Value finalFilter; if (failed(utils::reshapeTransposeReshape( rewriter, filter, reshapedFilterShapeVec, filterPermVec, - origFilterShapeVec, finalFilter))) { + filterShapeVec, finalFilter))) { llvm::outs() << "Failed to reshape filter\n"; return failure(); } @@ -154,12 +147,12 @@ struct FoldFCReTrPattern : public OpRewritePattern { PatternRewriter &rewriter) const override { // Match the pattern: fully_connected -> reshape -> transpose auto reshapeOp = transposeOp.getInput().getDefiningOp(); - if (!reshapeOp || !reshapeOp->getResult(0).hasOneUse()) + if (!reshapeOp || !reshapeOp->hasOneUse()) return failure(); auto fullyConnectedOp = reshapeOp.getInput().getDefiningOp(); - if (!fullyConnectedOp || !fullyConnectedOp->getResult(0).hasOneUse()) + if (!fullyConnectedOp || !fullyConnectedOp->hasOneUse()) return failure(); // Get types and shapes @@ -181,12 +174,6 @@ struct FoldFCReTrPattern : public OpRewritePattern { auto fcOutputShape = fcOutputType.getShape(); auto reshapeOutputShape = reshapeOutputType.getShape(); auto transposeOutputShape = transposeOutputType.getShape(); - SmallVector transposeOutputShapeVec( - transposeOutputShape.begin(), transposeOutputShape.end()); - - if (reshapeOutputShape[0] != 1) { - return failure(); - } if (fcOutputShape.empty() || reshapeOutputShape.empty()) return failure(); // Expecting non-scalar tensors @@ -194,16 +181,15 @@ struct FoldFCReTrPattern : public OpRewritePattern { if (fcOutputShape[0] != reshapeOutputShape[0]) return failure(); // Batch dimension changed in reshape - // Check if transpose does not affect the batch dimension + // Get permutation attribute from transpose DenseIntElementsAttr permAttr; if (!matchPattern(transposeOp.getPerm(), m_Constant(&permAttr))) return failure(); - SmallVector invPermVec; - for (int32_t val : permAttr.getValues()) { - invPermVec.push_back( - static_cast(permAttr.getValues()[val])); - } + SmallVector permVec = utils::denseToVector(permAttr); + + // Compute inverse permutation vector + SmallVector invPermVec = computeInversePermutation(permVec); // Check if batch dimension remains at position 0 after transpose if (invPermVec.empty() || invPermVec[0] != 0) @@ -215,55 +201,49 @@ struct FoldFCReTrPattern : public OpRewritePattern { // Process bias { - // Ensure bias is produced by a TFL::QConstOp - auto biasQConstOp = bias.getDefiningOp(); - if (!biasQConstOp) - return failure(); - // Get bias type and shape auto biasType = bias.getType().dyn_cast(); if (!biasType) return failure(); auto biasShape = biasType.getShape(); - SmallVector biasShapeVec(biasShape.begin(), biasShape.end()); + + // Use transpose output shape as reshape shape for bias + SmallVector transposeOutputShapeVec( + transposeOutputShape.begin(), transposeOutputShape.end()); + Value finalBias; if (failed(utils::reshapeTransposeReshape( rewriter, bias, transposeOutputShapeVec, invPermVec, biasShapeVec, finalBias))) return failure(); - // Update bias bias = finalBias; } // Process filter { - // Ensure filter is produced by a TFL::QConstOp - auto filterQConstOp = filter.getDefiningOp(); - if (!filterQConstOp) - return failure(); - // Get filter type and shape auto filterType = filter.getType().dyn_cast(); if (!filterType) return failure(); + auto filterShape = filterType.getShape(); SmallVector filterShapeVec(filterShape.begin(), filterShape.end()); - // same as the shape of the reshape, except for first dimension which - // should be the first dimension of the filterShape + // Compute new filter reshape shape SmallVector filterOutShapeVec = {filterShape[1]}; filterOutShapeVec.insert(filterOutShapeVec.end(), - transposeOutputShapeVec.begin() + 1, - transposeOutputShapeVec.end()); + transposeOutputShape.begin() + 1, + transposeOutputShape.end()); Value finalFilter; if (failed(utils::reshapeTransposeReshape(rewriter, filter, filterOutShapeVec, invPermVec, filterShapeVec, finalFilter))) return failure(); + filter = finalFilter; } @@ -275,15 +255,13 @@ struct FoldFCReTrPattern : public OpRewritePattern { fullyConnectedOp.getKeepNumDimsAttr(), fullyConnectedOp.getAsymmetricQuantizeInputsAttr()); - // create new shape from the shape of the original transpose op - auto originalShape = - transposeOp.getResult().getType().cast().getShape(); - SmallVector originalShapeVec(originalShape.begin(), - originalShape.end()); + // Create new reshape op with the output type of the original transpose op + SmallVector originalShapeVec(transposeOutputShape.begin(), + transposeOutputShape.end()); Value newShapeConstOp = utils::createShapeConstOp( rewriter, transposeOp.getLoc(), originalShapeVec); - // Create new reshape op with the output type of the original transpose op + // Create new reshape op auto newReshapeOp = rewriter.create( reshapeOp.getLoc(), transposeOutputType, newFullyConnectedOp.getResult(0), newShapeConstOp); diff --git a/xformer/Utils/Util.cpp b/xformer/Utils/Util.cpp index 39981a761..874c638bf 100644 --- a/xformer/Utils/Util.cpp +++ b/xformer/Utils/Util.cpp @@ -216,4 +216,5 @@ reshapeTransposeReshape(PatternRewriter &rewriter, Value tensor, result = finalTensor.getResult(); return success(); } + } // namespace mlir::xcore::utils diff --git a/xformer/Utils/Util.h b/xformer/Utils/Util.h index d2296616c..97de0b725 100644 --- a/xformer/Utils/Util.h +++ b/xformer/Utils/Util.h @@ -8,6 +8,7 @@ #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include namespace mlir::xcore::utils { @@ -83,6 +84,14 @@ reshapeTransposeReshape(PatternRewriter &rewriter, Value tensor, const SmallVectorImpl &permVec, const SmallVectorImpl &origShape, Value &result); + +template +static SmallVector denseToVector(DenseIntElementsAttr permAttr) { + SmallVector permVec; + for (auto val : permAttr.getValues()) + permVec.push_back(static_cast(val)); + return permVec; +} } // namespace mlir::xcore::utils #endif // XFORMER_UTILS_UTIL_H From b368cee96040b25ffda11016fff552d8338b3013 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Wed, 6 Nov 2024 13:31:50 +0000 Subject: [PATCH 13/15] Add canonicalization and disable transpose transformation --- xformer/Transforms/OptimizeConv2D.cpp | 3 +++ xformer/Transforms/OptimizeTranspose.cpp | 5 +++-- xformer/Transforms/Passes.cpp | 2 ++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/xformer/Transforms/OptimizeConv2D.cpp b/xformer/Transforms/OptimizeConv2D.cpp index 62d637106..4a5dfcad7 100644 --- a/xformer/Transforms/OptimizeConv2D.cpp +++ b/xformer/Transforms/OptimizeConv2D.cpp @@ -86,6 +86,9 @@ struct ChannelwiseSplitConv2DOutputPattern return splitResultType; }; + if(!llvm::isa(op.getFilter().getDefiningOp())) + return failure(); + auto filterQConstOp = dyn_cast(op.getFilter().getDefiningOp()); auto filterType = op.getFilter().getType().cast(); diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index 3907e7880..8ac71be2e 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -641,8 +641,9 @@ void OptimizeTranspose::runOnOperation() { patterns.insert(ctx); patterns.insert(ctx); - patterns.insert(ctx); - patterns.insert(ctx); + // TODO - enable after transpose permutation fix + // patterns.insert(ctx); + // patterns.insert(ctx); if (allowInputModificationOption) { patterns.insert(ctx); } diff --git a/xformer/Transforms/Passes.cpp b/xformer/Transforms/Passes.cpp index 29d720de2..fc7e9e12d 100644 --- a/xformer/Transforms/Passes.cpp +++ b/xformer/Transforms/Passes.cpp @@ -22,6 +22,8 @@ void buildXCorePreOpSplitPassPipeline(OpPassManager &pm) { void buildXCoreRemainingPassPipeline(OpPassManager &pm) { // TFL passes pm.addPass(createOptimizeTransposePass()); + // Run canonicalization for constant folding Transpose, if any + pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(createReplaceAvgPoolWithConv2DPass()); pm.addPass(createReplaceFCWithConv2DPass()); if (opSplitTensorArenaOption) { From 654dd754344802290a348af74ade37d547ba9914 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Wed, 6 Nov 2024 13:33:35 +0000 Subject: [PATCH 14/15] Remove unnecessary printf --- python/xmos_ai_tools/xinterpreters/host_interpreter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/xmos_ai_tools/xinterpreters/host_interpreter.py b/python/xmos_ai_tools/xinterpreters/host_interpreter.py index bee4a48a0..f9924df65 100644 --- a/python/xmos_ai_tools/xinterpreters/host_interpreter.py +++ b/python/xmos_ai_tools/xinterpreters/host_interpreter.py @@ -242,7 +242,6 @@ def close(self, model_index: int = 0) -> None: if self.obj: lib.delete_interpreter(self.obj) self.obj = None - print(self.obj) def tensor_arena_size(self) -> int: """! Read the size of the tensor arena required. From 63a2f79e7db66adc9c6226d4e545497c8fa87aeb Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Wed, 6 Nov 2024 21:24:16 +0000 Subject: [PATCH 15/15] Update error printing --- python/xmos_ai_tools/xformer/__init__.py | 36 +++++++++++++----------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/python/xmos_ai_tools/xformer/__init__.py b/python/xmos_ai_tools/xformer/__init__.py index 7c5e36421..832981ec4 100644 --- a/python/xmos_ai_tools/xformer/__init__.py +++ b/python/xmos_ai_tools/xformer/__init__.py @@ -5,7 +5,8 @@ from .flash import generate_flash import re -__compilation_output = "" +__compilation_stdout = "" +__compilation_stderr = "" __arena_size = 0 @@ -29,20 +30,22 @@ def convert( args.append(str(filename)) - process_call: subprocess.CompletedProcess = subprocess.run( - [arg for arg in args], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - check=True, - ) - - global __compilation_output, __arena_size - __compilation_output = process_call.stdout.decode("utf-8") - size_str = re.sub("((.|\n|\r)*)Tensor arena size :", "", __compilation_output) - size_str = re.sub("(\n|\r)((.|\n|\r)*)", "", size_str) - __arena_size = int(size_str.strip()) - - return process_call.returncode + try: + process_call: subprocess.CompletedProcess = subprocess.run( + [arg for arg in args], + check=True, text=True, capture_output=True, + ) + global __compilation_stdout, __compilation_stderr, __arena_size + __compilation_stdout = process_call.stdout + __compilation_stderr = process_call.stderr + size_str = re.sub("((.|\n|\r)*)Tensor arena size :", "", __compilation_stdout) + size_str = re.sub("(\n|\r)((.|\n|\r)*)", "", size_str) + __arena_size = int(size_str.strip()) + return process_call.returncode + except subprocess.CalledProcessError as e: + print(e) + print("Return code:", e.returncode) + print("Error output:", e.stderr) def tensor_arena_size() -> int: @@ -50,7 +53,8 @@ def tensor_arena_size() -> int: def print_optimization_report(): - print(__compilation_output) + print(__compilation_stderr) + print(__compilation_stdout) def print_help(show_hidden: Optional[bool] = False) -> int: