diff --git a/llvm/lib/Transforms/Yk/ModuleClone.cpp b/llvm/lib/Transforms/Yk/ModuleClone.cpp index de32f8653eea17..765091f18b0bd0 100644 --- a/llvm/lib/Transforms/Yk/ModuleClone.cpp +++ b/llvm/lib/Transforms/Yk/ModuleClone.cpp @@ -58,8 +58,8 @@ struct YkModuleClone : public ModulePass { YkModuleClone() : ModulePass(ID) { initializeYkModuleClonePass(*PassRegistry::getPassRegistry()); } - void updateClonedFunctions(Module &M) { - for (llvm::Function &F : M) { + void renameFunctions(Module &ClonedModule) { + for (llvm::Function &F : ClonedModule) { if (F.hasExternalLinkage() && F.isDeclaration()) { continue; } @@ -70,13 +70,54 @@ struct YkModuleClone : public ModulePass { } } + /** + * This function iterates over all functions in the `FinalModule`. + * If cloned function calls are identified within the original function + * instructions, they are redirected to the original function instead. + * + * **Example Scenario:** + * - Function `f` calls function `g`. + * - Function `g` is cloned as `__yk_clone_g`. + * - Function `f` is not cloned because its address is taken. + * - As a result, function `f` calls `__yk_clone_g` instead of `g`. + * + * **Reasoning:** + * In `YkIRWriter` we only serialise non-cloned functions. + * + * @param FinalModule The module containing both original and cloned + * functions. + */ + void updateFunctionCalls(Module &FinalModule) { + for (Function &F : FinalModule) { + if (F.getName().startswith(YK_CLONE_PREFIX)) { + continue; + } + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + if (CallInst *CI = dyn_cast(&I)) { + Function *CalledFunc = CI->getCalledFunction(); + if (CalledFunc && + CalledFunc->getName().startswith(YK_CLONE_PREFIX)) { + std::string OriginalName = + CalledFunc->getName().str().substr(strlen(YK_CLONE_PREFIX)); + Function *OriginalFunc = FinalModule.getFunction(OriginalName); + if (OriginalFunc) { + CI->setCalledFunction(OriginalFunc); + } + } + } + } + } + } + } + bool runOnModule(Module &M) override { std::unique_ptr Cloned = CloneModule(M); if (!Cloned) { llvm::report_fatal_error("Error cloning the module"); return false; } - updateClonedFunctions(*Cloned); + renameFunctions(*Cloned); // The `OverrideFromSrc` flag instructs the linker to prioritise // definitions from the source module (the second argument) when @@ -88,6 +129,7 @@ struct YkModuleClone : public ModulePass { llvm::report_fatal_error("Error linking the modules"); return false; } + updateFunctionCalls(M); if (verifyModule(M, &errs())) { errs() << "Module verification failed!";