diff --git a/python/xmos_ai_tools/xformer/__init__.py b/python/xmos_ai_tools/xformer/__init__.py index 7c5e36421..832981ec4 100644 --- a/python/xmos_ai_tools/xformer/__init__.py +++ b/python/xmos_ai_tools/xformer/__init__.py @@ -5,7 +5,8 @@ from .flash import generate_flash import re -__compilation_output = "" +__compilation_stdout = "" +__compilation_stderr = "" __arena_size = 0 @@ -29,20 +30,22 @@ def convert( args.append(str(filename)) - process_call: subprocess.CompletedProcess = subprocess.run( - [arg for arg in args], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - check=True, - ) - - global __compilation_output, __arena_size - __compilation_output = process_call.stdout.decode("utf-8") - size_str = re.sub("((.|\n|\r)*)Tensor arena size :", "", __compilation_output) - size_str = re.sub("(\n|\r)((.|\n|\r)*)", "", size_str) - __arena_size = int(size_str.strip()) - - return process_call.returncode + try: + process_call: subprocess.CompletedProcess = subprocess.run( + [arg for arg in args], + check=True, text=True, capture_output=True, + ) + global __compilation_stdout, __compilation_stderr, __arena_size + __compilation_stdout = process_call.stdout + __compilation_stderr = process_call.stderr + size_str = re.sub("((.|\n|\r)*)Tensor arena size :", "", __compilation_stdout) + size_str = re.sub("(\n|\r)((.|\n|\r)*)", "", size_str) + __arena_size = int(size_str.strip()) + return process_call.returncode + except subprocess.CalledProcessError as e: + print(e) + print("Return code:", e.returncode) + print("Error output:", e.stderr) def tensor_arena_size() -> int: @@ -50,7 +53,8 @@ def tensor_arena_size() -> int: def print_optimization_report(): - print(__compilation_output) + print(__compilation_stderr) + print(__compilation_stdout) def print_help(show_hidden: Optional[bool] = False) -> int: diff --git a/python/xmos_ai_tools/xinterpreters/host_interpreter.py b/python/xmos_ai_tools/xinterpreters/host_interpreter.py index bee4a48a0..f9924df65 100644 --- a/python/xmos_ai_tools/xinterpreters/host_interpreter.py +++ b/python/xmos_ai_tools/xinterpreters/host_interpreter.py @@ -242,7 +242,6 @@ def close(self, model_index: int = 0) -> None: if self.obj: lib.delete_interpreter(self.obj) self.obj = None - print(self.obj) def tensor_arena_size(self) -> int: """! Read the size of the tensor arena required. diff --git a/xformer/Transforms/OptimizeConv2D.cpp b/xformer/Transforms/OptimizeConv2D.cpp index 62d637106..4a5dfcad7 100644 --- a/xformer/Transforms/OptimizeConv2D.cpp +++ b/xformer/Transforms/OptimizeConv2D.cpp @@ -86,6 +86,9 @@ struct ChannelwiseSplitConv2DOutputPattern return splitResultType; }; + if(!llvm::isa(op.getFilter().getDefiningOp())) + return failure(); + auto filterQConstOp = dyn_cast(op.getFilter().getDefiningOp()); auto filterType = op.getFilter().getType().cast(); diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index 8223cd3f8..8ac71be2e 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -1,9 +1,9 @@ // Copyright 2023 XMOS LIMITED. 
This Software is subject to the terms of the // XMOS Public License: Version 1 -#include "IR/XCoreOps.h" #include "Transforms/Options.h" +#include "Utils/Util.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -25,6 +25,488 @@ struct OptimizeTranspose void runOnOperation() override; }; +static SmallVector +computeInversePermutation(const SmallVector &permVec) { + SmallVector invPermVec(permVec.size()); + for (size_t i = 0; i < permVec.size(); ++i) { + invPermVec[permVec[i]] = i; + } + return invPermVec; +} + +struct FoldTrReFCPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::FullyConnectedOp fcOp, + PatternRewriter &rewriter) const override { + // Match the pattern: transpose -> reshape -> fully_connected + + // Check that the input to fully_connected is a reshape + auto reshapeOp = fcOp.getInput().getDefiningOp(); + if (!reshapeOp || !reshapeOp->hasOneUse()) + return failure(); + + // Check that the input to reshape is a transpose + auto transposeOp = reshapeOp.getInput().getDefiningOp(); + if (!transposeOp || !transposeOp->hasOneUse()) + return failure(); + + // Ensure that the fully_connected op has a single use + if (!fcOp->hasOneUse()) + return failure(); + + // Get permutation attribute from transpose + DenseIntElementsAttr permAttr; + if (!matchPattern(transposeOp.getPerm(), m_Constant(&permAttr))) + return failure(); + + SmallVector permVec = utils::denseToVector(permAttr); + + // Compute inverse permutation vector + SmallVector invPermVec = computeInversePermutation(permVec); + + // Get input shape + auto reshapeInputType = + reshapeOp.getInput().getType().dyn_cast(); + if (!reshapeInputType) + return failure(); + + // Transpose the filter weights + Value filter = fcOp.getFilter(); + Value bias = fcOp.getBias(); + + { + // Get filter type and shape + auto filterType = filter.getType().dyn_cast(); + if (!filterType) + return failure(); + + auto filterShape = filterType.getShape(); + SmallVector filterShapeVec(filterShape.begin(), + filterShape.end()); + + int64_t outputSize = filterShapeVec[0]; + + // Reshape the filter to reshapeInputShape[1:] + [outputSize] + auto reshapeInputShape = reshapeInputType.getShape(); + SmallVector reshapedFilterShapeVec( + reshapeInputShape.begin() + 1, reshapeInputShape.end()); + reshapedFilterShapeVec.push_back(outputSize); + + // Prepare permutation vector for the filter (excluding batch dimension) + SmallVector filterPermVec; + for (size_t i = 1; i < invPermVec.size(); ++i) { + filterPermVec.push_back(invPermVec[i] - 1); + } + filterPermVec.push_back(invPermVec.size() - 1); + + Value finalFilter; + if (failed(utils::reshapeTransposeReshape( + rewriter, filter, reshapedFilterShapeVec, filterPermVec, + filterShapeVec, finalFilter))) { + llvm::outs() << "Failed to reshape filter\n"; + return failure(); + } + + filter = finalFilter; + } + + // Bias remains the same; no need to adjust it + + // Create new reshape op (from transpose's input to reshape's output) + Value newInput = transposeOp.getInput(); + + auto newReshapeOp = rewriter.create( + reshapeOp.getLoc(), reshapeOp.getResult().getType(), newInput, + reshapeOp.getShape()); + + // Create new fully_connected op with adjusted filter + auto newFullyConnectedOp = rewriter.create( + fcOp.getLoc(), fcOp.getResult(0).getType(), newReshapeOp.getResult(), + filter, bias, fcOp.getFusedActivationFunctionAttr(), + fcOp.getWeightsFormatAttr(), 
fcOp.getKeepNumDimsAttr(), + fcOp.getAsymmetricQuantizeInputsAttr()); + + // Replace the original fully_connected op with the new one + rewriter.replaceOp(fcOp, newFullyConnectedOp.getResults()); + + // Erase the old reshape and transpose ops if they are no longer used + if (reshapeOp.use_empty()) + rewriter.eraseOp(reshapeOp); + if (transposeOp.use_empty()) + rewriter.eraseOp(transposeOp); + + return success(); + } +}; + +struct FoldFCReTrPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + // Match the pattern: fully_connected -> reshape -> transpose + auto reshapeOp = transposeOp.getInput().getDefiningOp(); + if (!reshapeOp || !reshapeOp->hasOneUse()) + return failure(); + + auto fullyConnectedOp = + reshapeOp.getInput().getDefiningOp(); + if (!fullyConnectedOp || !fullyConnectedOp->hasOneUse()) + return failure(); + + // Get types and shapes + auto fcInputType = + fullyConnectedOp.getInput().getType().dyn_cast(); + auto fcOutputType = + fullyConnectedOp.getResult(0).getType().dyn_cast(); + auto reshapeOutputType = + reshapeOp.getResult().getType().dyn_cast(); + auto transposeOutputType = + transposeOp.getResult().getType().dyn_cast(); + + if (!fcInputType || !fcOutputType || !reshapeOutputType || + !transposeOutputType) + return failure(); + + // Check if the batch dimension (assumed to be dimension 0) remains + // unchanged + auto fcOutputShape = fcOutputType.getShape(); + auto reshapeOutputShape = reshapeOutputType.getShape(); + auto transposeOutputShape = transposeOutputType.getShape(); + + if (fcOutputShape.empty() || reshapeOutputShape.empty()) + return failure(); // Expecting non-scalar tensors + + if (fcOutputShape[0] != reshapeOutputShape[0]) + return failure(); // Batch dimension changed in reshape + + // Get permutation attribute from transpose + DenseIntElementsAttr permAttr; + if (!matchPattern(transposeOp.getPerm(), m_Constant(&permAttr))) + return failure(); + + SmallVector permVec = utils::denseToVector(permAttr); + + // Compute inverse permutation vector + SmallVector invPermVec = computeInversePermutation(permVec); + + // Check if batch dimension remains at position 0 after transpose + if (invPermVec.empty() || invPermVec[0] != 0) + return failure(); + + // Prepare to transform the filter and bias + Value filter = fullyConnectedOp.getFilter(); + Value bias = fullyConnectedOp.getBias(); + + // Process bias + { + // Get bias type and shape + auto biasType = bias.getType().dyn_cast(); + if (!biasType) + return failure(); + auto biasShape = biasType.getShape(); + SmallVector biasShapeVec(biasShape.begin(), biasShape.end()); + + // Use transpose output shape as reshape shape for bias + SmallVector transposeOutputShapeVec( + transposeOutputShape.begin(), transposeOutputShape.end()); + + Value finalBias; + if (failed(utils::reshapeTransposeReshape( + rewriter, bias, transposeOutputShapeVec, invPermVec, biasShapeVec, + finalBias))) + return failure(); + + bias = finalBias; + } + + // Process filter + { + // Get filter type and shape + auto filterType = filter.getType().dyn_cast(); + if (!filterType) + return failure(); + + auto filterShape = filterType.getShape(); + SmallVector filterShapeVec(filterShape.begin(), + filterShape.end()); + + // Compute new filter reshape shape + SmallVector filterOutShapeVec = {filterShape[1]}; + filterOutShapeVec.insert(filterOutShapeVec.end(), + transposeOutputShape.begin() + 1, + 
transposeOutputShape.end()); + + Value finalFilter; + if (failed(utils::reshapeTransposeReshape(rewriter, filter, + filterOutShapeVec, invPermVec, + filterShapeVec, finalFilter))) + return failure(); + + filter = finalFilter; + } + + // Create new fully connected op with adjusted filter and bias + auto newFullyConnectedOp = rewriter.create( + fullyConnectedOp.getLoc(), fcOutputType, fullyConnectedOp.getInput(), + filter, bias, fullyConnectedOp.getFusedActivationFunctionAttr(), + fullyConnectedOp.getWeightsFormatAttr(), + fullyConnectedOp.getKeepNumDimsAttr(), + fullyConnectedOp.getAsymmetricQuantizeInputsAttr()); + + // Create new reshape op with the output type of the original transpose op + SmallVector originalShapeVec(transposeOutputShape.begin(), + transposeOutputShape.end()); + Value newShapeConstOp = utils::createShapeConstOp( + rewriter, transposeOp.getLoc(), originalShapeVec); + + // Create new reshape op + auto newReshapeOp = rewriter.create( + reshapeOp.getLoc(), transposeOutputType, + newFullyConnectedOp.getResult(0), newShapeConstOp); + + // Replace the original transpose op with the new reshape op + rewriter.replaceOp(transposeOp, newReshapeOp.getResult()); + + return success(); + } +}; + +struct MoveTransposeForwardOverUnaryOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + // Ensure the TransposeOp has a single use + if (!transposeOp->hasOneUse()) + return failure(); + + Operation *userOp = *transposeOp->getUsers().begin(); + + // Check if the user operation is a unary op that can commute with + // transpose + if (!isa(userOp)) + return failure(); + + // Get the types of the input and output tensors + auto transposeInputType = + transposeOp.getInput().getType().dyn_cast(); + auto transposeOutputType = + transposeOp.getType().dyn_cast(); + if (!transposeInputType || !transposeOutputType) + return failure(); + + // Get the permutation used in the transpose + Value perm = transposeOp.getPerm(); + + Value newUnaryOpResult; + auto loc = userOp->getLoc(); + auto input = transposeOp.getInput(); + + // Retrieve the original unary operation's output type + auto originalUnaryOutputType = + userOp->getResult(0).getType().dyn_cast(); + if (!originalUnaryOutputType) + return failure(); + + // Create a new output type for the unary op with the same shape as + // 'input' and the same element type as the original output type + auto newUnaryOutputType = RankedTensorType::get( + transposeInputType.getShape(), originalUnaryOutputType.getElementType(), + originalUnaryOutputType.getEncoding()); + + if (auto quantizeOp = dyn_cast(userOp)) { + // For QuantizeOp, create new QuantizeOp with input as 'input' and + // output type adjusted + + // Create new QuantizeOp with adjusted output type + newUnaryOpResult = rewriter.create( + loc, newUnaryOutputType, input, quantizeOp.getQtypeAttr()); + + } else if (auto absOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto negOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto reluOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto relu6Op = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto leakyReluOp = dyn_cast(userOp)) { + newUnaryOpResult = rewriter.create( + loc, 
newUnaryOutputType, input, leakyReluOp.getAlphaAttr()); + } else if (auto tanhOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else if (auto logisticOp = dyn_cast(userOp)) { + newUnaryOpResult = + rewriter.create(loc, newUnaryOutputType, input); + } else { + // This should not happen as we checked the op type earlier + return failure(); + } + + // Create a new Transpose operation after the unary operation + auto newTransposeOp = rewriter.create( + transposeOp.getLoc(), originalUnaryOutputType, newUnaryOpResult, perm); + + // Replace the original user operation's result with the new transpose + // result + rewriter.replaceOp(userOp, newTransposeOp.getResult()); + + // Remove the original TransposeOp + rewriter.eraseOp(transposeOp); + + return success(); + } +}; + +struct FoldCancellableTransposePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::TransposeOp op, + PatternRewriter &rewriter) const override { + + // Check for invalid types and return + // Defining op must be transpose + auto transposeOp = + dyn_cast_or_null(op.getInput().getDefiningOp()); + if (!transposeOp) { + return failure(); + } + + // Get transpose permutations + DenseIntElementsAttr perm0; + DenseIntElementsAttr perm1; + if (!matchPattern(op.getPerm(), m_Constant(&perm0)) || + !matchPattern(transposeOp.getPerm(), m_Constant(&perm1))) { + return failure(); + } + + // Do permutation indices cancel each other? + if (!TF::AreCancellablePermutations(perm0, perm1)) { + return failure(); + } + + rewriter.replaceOp(op, transposeOp.getInput()); + + return success(); + } +}; +struct MoveTransposeForwardOverConcatOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::ConcatenationOp concatOp, + PatternRewriter &rewriter) const override { + // Get all input operands + auto inputs = concatOp.getValues(); + + // Check that all inputs are TransposeOps with the same permutation + SmallVector newInputs; + DenseIntElementsAttr commonPermAttr; + for (auto input : inputs) { + auto transposeOp = input.getDefiningOp(); + if (!transposeOp) + return failure(); + + // Get permutation attribute + DenseIntElementsAttr permAttr; + if (!matchPattern(transposeOp.getPerm(), m_Constant(&permAttr))) + return failure(); + + // Check if the permutation is the same as others + if (commonPermAttr) { + if (permAttr != commonPermAttr) + return failure(); + } else { + commonPermAttr = permAttr; + } + + // Collect the inputs to the transpose ops + newInputs.push_back(transposeOp.getInput()); + } + + // Get the permutation vector + SmallVector permVec; + for (auto val : commonPermAttr.getValues()) { + permVec.push_back(val); + } + + // Compute the inverse permutation + SmallVector invPerm(permVec.size()); + for (size_t i = 0; i < permVec.size(); ++i) { + invPerm[permVec[i]] = i; + } + + // Adjust the axis according to the inverse permutation + int32_t oldAxis = concatOp.getAxis(); + int64_t rank = permVec.size(); + if (oldAxis < 0) { + oldAxis += rank; + } + if (oldAxis < 0 || oldAxis >= rank) { + return failure(); // Invalid axis + } + int32_t newAxis = permVec[oldAxis]; + + // Collect input types and compute the new result type + SmallVector inputTypes; + for (auto input : newInputs) { + auto inputType = input.getType().dyn_cast(); + if (!inputType) { + return failure(); + } + inputTypes.push_back(inputType); + } + + // Ensure all input types have the same rank + for (auto 
type : inputTypes) { + if (type.getRank() != rank) { + return failure(); + } + } + + // Compute the shape of the concatenated tensor + SmallVector resultShape(inputTypes[0].getShape().begin(), + inputTypes[0].getShape().end()); + for (size_t i = 1; i < inputTypes.size(); ++i) { + auto shape = inputTypes[i].getShape(); + resultShape[newAxis] += shape[newAxis]; + } + + // Create the new ConcatenationOp with the correct result type and axis + auto elementType = inputTypes[0].getElementType(); + auto newConcatType = RankedTensorType::get(resultShape, elementType); + auto newConcatOp = rewriter.create( + concatOp.getLoc(), newConcatType, newInputs, + rewriter.getI32IntegerAttr(newAxis), + concatOp.getFusedActivationFunctionAttr()); + + // Create the permutation constant with correct data types + auto permType = RankedTensorType::get( + {static_cast(permVec.size())}, rewriter.getIntegerType(32)); + auto permAttr = DenseIntElementsAttr::get(permType, permVec); + auto permConstOp = rewriter.create(concatOp.getLoc(), + permType, permAttr); + + // Create the new TransposeOp with the original output type + auto newTransposeOp = rewriter.create( + concatOp.getLoc(), concatOp.getType(), newConcatOp.getResult(), + permConstOp.getResult()); + + rewriter.replaceOp(concatOp, newTransposeOp.getResult()); + return success(); + } +}; + struct HoistTransposeWCHAbovePadPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -91,39 +573,7 @@ struct HoistTransposeWCHAbovePadPattern return success(); } }; -struct FoldCancellableTransposePattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TFL::TransposeOp op, - PatternRewriter &rewriter) const override { - // Check for invalid types and return - // Defining op must be transpose - auto transposeOp = - dyn_cast_or_null(op.getInput().getDefiningOp()); - if (!transposeOp) { - return failure(); - } - - // Get transpose permutations - DenseIntElementsAttr perm0; - DenseIntElementsAttr perm1; - if (!matchPattern(op.getPerm(), m_Constant(&perm0)) || - !matchPattern(transposeOp.getPerm(), m_Constant(&perm1))) { - return failure(); - } - - // Do permutation indices cancel each other? 
- if (!TF::AreCancellablePermutations(perm0, perm1)) { - return failure(); - } - - rewriter.replaceOp(op, transposeOp.getInput()); - - return success(); - } -}; struct FoldTransposeWCHToInput : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -151,8 +601,8 @@ struct FoldTransposeWCHToInput : public OpRewritePattern { if (blockArg.hasOneUse()) { auto funcOp = cast(blockArg.getOwner()->getParentOp()); - // Set function type to the transpose output type as we are changing the - // input + // Set function type to the transpose output type as we are changing + // the input FunctionType funcType = funcOp.getFunctionType(); llvm::SmallVector newInputTypes(funcType.getInputs().begin(), funcType.getInputs().end()); @@ -176,10 +626,24 @@ struct FoldTransposeWCHToInput : public OpRewritePattern { void OptimizeTranspose::runOnOperation() { auto *ctx = &getContext(); func::FuncOp func = getOperation(); + + // Try to merge transpose -> ops -> inverse transpose + RewritePatternSet mergePatterns(ctx); + mergePatterns.insert(ctx); + if (mergeTransposeOption) { + (void)applyPatternsAndFoldGreedily(func, std::move(mergePatterns)); + } + + // Other transpose optimizations RewritePatternSet patterns(ctx); patterns.insert(ctx); patterns.insert(ctx); + // TODO - enable after transpose permutation fix + // patterns.insert(ctx); + // patterns.insert(ctx); if (allowInputModificationOption) { patterns.insert(ctx); } diff --git a/xformer/Transforms/Options.h b/xformer/Transforms/Options.h index 3b82763c6..811f969ce 100644 --- a/xformer/Transforms/Options.h +++ b/xformer/Transforms/Options.h @@ -26,6 +26,7 @@ extern llvm::cl::list opSplitBottomOpsOption; extern llvm::cl::list opSplitTopOpsOption; extern llvm::cl::list opSplitNumSplitsOption; extern llvm::cl::opt allowInputModificationOption; +extern llvm::cl::opt mergeTransposeOption; extern llvm::cl::opt convDebugOption; extern llvm::cl::opt overlapConvOption; extern llvm::cl::opt offlineOffsetsOption; diff --git a/xformer/Transforms/Passes.cpp b/xformer/Transforms/Passes.cpp index 29d720de2..fc7e9e12d 100644 --- a/xformer/Transforms/Passes.cpp +++ b/xformer/Transforms/Passes.cpp @@ -22,6 +22,8 @@ void buildXCorePreOpSplitPassPipeline(OpPassManager &pm) { void buildXCoreRemainingPassPipeline(OpPassManager &pm) { // TFL passes pm.addPass(createOptimizeTransposePass()); + // Run canonicalization for constant folding Transpose, if any + pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(createReplaceAvgPoolWithConv2DPass()); pm.addPass(createReplaceFCWithConv2DPass()); if (opSplitTensorArenaOption) { diff --git a/xformer/Utils/Util.cpp b/xformer/Utils/Util.cpp index ac4cddea6..874c638bf 100644 --- a/xformer/Utils/Util.cpp +++ b/xformer/Utils/Util.cpp @@ -2,7 +2,6 @@ // XMOS Public License: Version 1 #include "Utils/Util.h" -#include #include "mlir/Dialect/Quant/QuantTypes.h" #include "llvm/ADT/ArrayRef.h" @@ -139,4 +138,83 @@ int mergeAxes(std::vector &begin, std::vector &size, return rank; } +// Converts int64_t vector to int32_t vector, returns failure if any value is +// out of int32_t range. +LogicalResult convertToI32Array(const SmallVectorImpl &input, + SmallVectorImpl &output) { + for (auto val : input) { + if (val > std::numeric_limits::max() || + val < std::numeric_limits::min()) + return failure(); + output.push_back(static_cast(val)); + } + return success(); +} + +// Creates a constant op for a shape vector. 
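+// Returns a null Value if any dimension is outside int32_t range, so callers
+// must check the result before use.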
+Value createShapeConstOp(PatternRewriter &rewriter, Location loc, + const SmallVectorImpl &shapeVec) { + SmallVector shapeVecI32; + if (failed(convertToI32Array(shapeVec, shapeVecI32))) + return nullptr; + auto shapeType = RankedTensorType::get( + {static_cast(shapeVecI32.size())}, rewriter.getI32Type()); + auto shapeAttr = DenseIntElementsAttr::get(shapeType, shapeVecI32); + return rewriter.create(loc, shapeType, shapeAttr); +} + +// Helper function for reshape-transpose-reshape pattern. +LogicalResult +reshapeTransposeReshape(PatternRewriter &rewriter, Value tensor, + const SmallVectorImpl &reshapeShape, + const SmallVectorImpl &permVec, + const SmallVectorImpl &origShape, + Value &result) { + auto loc = tensor.getLoc(); + auto tensorType = tensor.getType().cast(); + auto elementType = tensorType.getElementType(); + + // Reshape tensor to reshapeShapeExclBatch. + Value newShapeOp = createShapeConstOp(rewriter, loc, reshapeShape); + if (!newShapeOp) + return failure(); + auto reshapedType = RankedTensorType::get(reshapeShape, elementType); + auto reshapedTensor = + rewriter.create(loc, reshapedType, tensor, newShapeOp); + + // Convert permVecExclBatch to int32_t vector. + SmallVector permVecI32; + if (failed(convertToI32Array(permVec, permVecI32))) + return failure(); + + // Create perm op. + auto permType = RankedTensorType::get( + {static_cast(permVecI32.size())}, rewriter.getI32Type()); + auto permAttr = DenseIntElementsAttr::get(permType, permVecI32); + auto permOp = rewriter.create(loc, permType, permAttr); + + // Compute transposed shape. + SmallVector transposedShape; + for (auto idx : permVec) { + if (idx < 0 || idx >= reshapeShape.size()) + return failure(); + transposedShape.push_back(reshapeShape[idx]); + } + auto transposedType = RankedTensorType::get(transposedShape, elementType); + + // Transpose. + auto transposedTensor = rewriter.create( + loc, transposedType, reshapedTensor, permOp); + + // Reshape back to original shape. 
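+  // The final reshape restores the tensor's original type, so the net effect
+  // of this helper is a pure reordering of elements according to permVec.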
+ Value origShapeOp = createShapeConstOp(rewriter, loc, origShape); + if (!origShapeOp) + return failure(); + auto finalTensor = rewriter.create( + loc, tensorType, transposedTensor, origShapeOp); + + result = finalTensor.getResult(); + return success(); +} + } // namespace mlir::xcore::utils diff --git a/xformer/Utils/Util.h b/xformer/Utils/Util.h index 7e29f1f34..97de0b725 100644 --- a/xformer/Utils/Util.h +++ b/xformer/Utils/Util.h @@ -7,6 +7,8 @@ #include "mlir/Dialect/Quant/QuantTypes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include namespace mlir::xcore::utils { @@ -69,6 +71,27 @@ template bool checkBinaryCompatibility(T op) { int mergeAxes(std::vector &begin, std::vector &size, std::vector &inShape, std::vector &outShape, int rank); + +LogicalResult convertToI32Array(const SmallVectorImpl &input, + SmallVectorImpl &output); + +Value createShapeConstOp(PatternRewriter &rewriter, Location loc, + const SmallVectorImpl &shapeVec); + +LogicalResult +reshapeTransposeReshape(PatternRewriter &rewriter, Value tensor, + const SmallVectorImpl &reshapeShape, + const SmallVectorImpl &permVec, + const SmallVectorImpl &origShape, + Value &result); + +template +static SmallVector denseToVector(DenseIntElementsAttr permAttr) { + SmallVector permVec; + for (auto val : permAttr.getValues()) + permVec.push_back(static_cast(val)); + return permVec; +} } // namespace mlir::xcore::utils #endif // XFORMER_UTILS_UTIL_H diff --git a/xformer/XCoreOptMain.cpp b/xformer/XCoreOptMain.cpp index da9297a96..48663dbb6 100644 --- a/xformer/XCoreOptMain.cpp +++ b/xformer/XCoreOptMain.cpp @@ -159,6 +159,11 @@ cl::opt allowInputModificationOption( cl::desc("Allow the compiler to modify input tensor for optimizations."), cl::init(false), cl::cat(XformerCategory), cl::Hidden); +cl::opt mergeTransposeOption( + "xcore-merge-transpose", + cl::desc("Try to merge transpose and inverse transpose together."), + cl::init(true), cl::cat(XformerCategory), cl::Hidden); + cl::opt convDebugOption("xcore-conv-debug", cl::desc("Enable conv debug prints."), cl::init(false), cl::cat(XformerCategory),
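
A minimal usage sketch of the updated Python entry points (a sketch only: the model paths and the params key are illustrative and assume the existing convert(input, output, params) signature, not anything introduced by this diff). convert() still returns the xcore-opt return code on success; on a compilation failure it now prints the captured stderr instead of raising, and print_optimization_report() prints both the stderr and stdout captured from the last compilation.

    from xmos_ai_tools import xformer as xf

    # Illustrative paths and option; adjust to your model and target.
    ret = xf.convert("model.tflite", "model_xcore.tflite",
                     params={"xcore-thread-count": "5"})
    if ret == 0:
        print("Tensor arena size:", xf.tensor_arena_size())
        xf.print_optimization_report()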