diff --git a/include/anvill/Providers.h b/include/anvill/Providers.h index 58edb809a..b8251f021 100644 --- a/include/anvill/Providers.h +++ b/include/anvill/Providers.h @@ -45,11 +45,6 @@ class TypeProvider { std::optional TryGetFunctionTypeOrDefault(uint64_t address) const; - std::optional - TryGetCalledFunctionTypeOrDefault(uint64_t function_address, - const remill::Instruction &from_inst, - uint64_t to_address) const; - std::optional TryGetVariableTypeOrDefault(uint64_t address, llvm::Type *hinted_value_type = nullptr) const; @@ -60,19 +55,6 @@ class TypeProvider { virtual std::optional TryGetFunctionType(uint64_t address) const = 0; - // Try to return the type of a function that has been called from `from_isnt`. - virtual std::optional - TryGetCalledFunctionType(uint64_t function_address, - const remill::Instruction &from_inst) const; - - // Try to return the type of a function starting at address `to_address`. This - // type is the prototype of the function. The type can be call site specific, - // where the call site is `from_inst`. - virtual std::optional - TryGetCalledFunctionType(uint64_t function_address, - const remill::Instruction &from_inst, - uint64_t to_address) const; - // Try to return the variable at given address or containing the address virtual std::optional TryGetVariableType(uint64_t address, @@ -153,19 +135,6 @@ class ProxyTypeProvider : public TypeProvider { std::optional TryGetFunctionType(uint64_t address) const override; - // Try to return the type of a function that has been called from `from_isnt`. - std::optional - TryGetCalledFunctionType(uint64_t function_address, - const remill::Instruction &from_inst) const override; - - // Try to return the type of a function starting at address `to_address`. This - // type is the prototype of the function. The type can be call site specific, - // where the call site is `from_inst`. - std::optional - TryGetCalledFunctionType(uint64_t function_address, - const remill::Instruction &from_inst, - uint64_t to_address) const override; - // Try to return the variable at given address or containing the address std::optional TryGetVariableType(uint64_t address, @@ -206,11 +175,6 @@ class DefaultCallableTypeProvider : public ProxyTypeProvider { // Set `decl` to the default callable type for `arch`. void SetDefault(remill::ArchName arch, CallableDecl decl); - // Try to return the type of a function that has been called from `from_isnt`. - std::optional - TryGetCalledFunctionType(uint64_t function_address, - const remill::Instruction &from_inst) const override; - std::optional TryGetFunctionType(uint64_t address) const override; }; @@ -225,11 +189,6 @@ class SpecificationTypeProvider : public BaseTypeProvider { explicit SpecificationTypeProvider(const Specification &spec); - // Try to return the type of a function that has been called from `from_isnt`. - std::optional - TryGetCalledFunctionType(uint64_t function_address, - const remill::Instruction &from_inst) const override; - // Try to return the type of a function starting at address `address`. This // type is the prototype of the function. std::optional diff --git a/include/anvill/Specification.h b/include/anvill/Specification.h index 18d7fed65..393ac196b 100644 --- a/include/anvill/Specification.h +++ b/include/anvill/Specification.h @@ -146,6 +146,10 @@ class Specification { static anvill::Result DecodeFromPB(llvm::LLVMContext &context, std::istream &pb); + // Return the call site at a given function address, instruction address pair, or an empty `shared_ptr`. + std::shared_ptr + CallSiteAt(const std::pair &loc) const; + // Return the function beginning at `address`, or an empty `shared_ptr`. std::shared_ptr FunctionAt(std::uint64_t address) const; diff --git a/include/anvill/Utils.h b/include/anvill/Utils.h index ef27f3a85..05b54ad0e 100644 --- a/include/anvill/Utils.h +++ b/include/anvill/Utils.h @@ -49,6 +49,14 @@ std::string CreateFunctionName(std::uint64_t addr); // Creates a `data_
` name from an address std::string CreateVariableName(std::uint64_t addr); +// Get metadata for an instruction +std::optional GetMetadata(llvm::StringRef tag, + const llvm::Instruction &instr); + +// Set metadata for an instruction +void SetMetadata(llvm::StringRef tag, llvm::Instruction &insn, + std::uint64_t pc_val); + // Looks for any constant expressions in the operands of `inst` and unfolds // them into other instructions in the same block. void UnfoldConstantExpressions(llvm::Instruction *inst); diff --git a/lib/Lifters/BasicBlockLifter.cpp b/lib/Lifters/BasicBlockLifter.cpp index e55748679..e3b433898 100644 --- a/lib/Lifters/BasicBlockLifter.cpp +++ b/lib/Lifters/BasicBlockLifter.cpp @@ -28,6 +28,7 @@ #include "Lifters/FunctionLifter.h" #include "anvill/Declarations.h" #include "anvill/Optimize.h" +#include "anvill/Utils.h" namespace anvill { @@ -183,6 +184,7 @@ bool BasicBlockLifter::DoInterProceduralControlFlow( call = this->AddCallFromBasicBlockFunctionToLifted( block, this->intrinsics.function_call, this->intrinsics); } + SetMetadata(options.pc_metadata_name, *call, insn.pc); if (!cc.stop) { auto [_, raddr] = this->LoadFunctionReturnAddress(insn, block); auto npc = remill::LoadNextProgramCounterRef(block); diff --git a/lib/Lifters/FunctionLifter.cpp b/lib/Lifters/FunctionLifter.cpp index 6d74702ee..673149617 100644 --- a/lib/Lifters/FunctionLifter.cpp +++ b/lib/Lifters/FunctionLifter.cpp @@ -143,18 +143,6 @@ void FunctionLifter::InsertError(llvm::BasicBlock *block) { AnnotateInstruction(tail, pc_annotation_id, pc_annotation); } - -std::optional -FunctionLifter::TryGetTargetFunctionType(const remill::Instruction &from_inst, - std::uint64_t address) { - std::optional opt_callable_decl = - type_provider.TryGetCalledFunctionTypeOrDefault(func_address, from_inst, - address); - - return opt_callable_decl; -} - - // Get the annotation for the program counter `pc`, or `nullptr` if we're // not doing annotations. llvm::MDNode *FunctionLifter::GetPCAnnotation(uint64_t pc) const { diff --git a/lib/Lifters/FunctionLifter.h b/lib/Lifters/FunctionLifter.h index 2fe5ebc5e..9a2670f0e 100644 --- a/lib/Lifters/FunctionLifter.h +++ b/lib/Lifters/FunctionLifter.h @@ -222,13 +222,6 @@ class FunctionLifter : public CodeLifter { bool CallFunction(const remill::Instruction &inst, llvm::BasicBlock *block, std::optional target_pc); - // A wrapper around the type provider's TryGetFunctionType that makes use - // of the control flow provider to handle control flow redirections for - // thunks - std::optional - TryGetTargetFunctionType(const remill::Instruction &inst, - std::uint64_t address); - // Visit a direct function call control-flow instruction. The target is known // at decode time, and its realized address is stored in // `inst.branch_taken_pc`. In practice, what we do in this situation is try diff --git a/lib/Passes/RemoveCallIntrinsics.cpp b/lib/Passes/RemoveCallIntrinsics.cpp index 8a58943a1..ee73227dc 100644 --- a/lib/Passes/RemoveCallIntrinsics.cpp +++ b/lib/Passes/RemoveCallIntrinsics.cpp @@ -22,11 +22,6 @@ llvm::PreservedAnalyses RemoveCallIntrinsics::runOnIntrinsic(llvm::CallInst *remillFunctionCall, llvm::FunctionAnalysisManager &am, llvm::PreservedAnalyses prev) { - // remillFunctionCall->getFunction()->dump(); - // if (remillFunctionCall->getFunction()->getName().endswith( - // "basic_block_func4201200")) { - // LOG(FATAL) << "done"; - // } CHECK(remillFunctionCall->getNumOperands() == 4); auto target_func = remillFunctionCall->getArgOperand(1); auto state_ptr = remillFunctionCall->getArgOperand(0); @@ -43,23 +38,38 @@ RemoveCallIntrinsics::runOnIntrinsic(llvm::CallInst *remillFunctionCall, ra.references_global_value || // Related to a global var/func. ra.references_program_counter) { // Related to `__anvill_pc`. - // TODO(Ian): ignoring callsite decls for now - auto fdecl = spec.FunctionAt(ra.u.address); - auto entity = this->xref_resolver.EntityAtAddress(ra.u.address); - if (fdecl && entity) { + std::shared_ptr callable_decl = + spec.FunctionAt(ra.u.address); + + if (auto pc_val = + GetMetadata(lifter.Options().pc_metadata_name, *remillFunctionCall); + pc_val.has_value()) { + if (auto bb_addr = GetBasicBlockAddr(f); bb_addr.has_value()) { + auto block_contexts = spec.GetBlockContexts(); + const auto &bb_ctx = block_contexts.GetBasicBlockContextForAddr(*bb_addr)->get(); + auto func = bb_ctx.GetParentFunctionAddress(); + if (auto override_decl = spec.CallSiteAt({func, *pc_val})) { + DLOG(INFO) << "Overriding call site at " << std::hex << *pc_val + << " in " << std::hex << func; + callable_decl = std::move(override_decl); + } + } + } + + auto *entity = this->xref_resolver.EntityAtAddress(ra.u.address); + if (callable_decl && entity) { llvm::IRBuilder<> ir(remillFunctionCall->getParent()); ir.SetInsertPoint(remillFunctionCall); - const remill::IntrinsicTable table( - remillFunctionCall->getFunction()->getParent()); + const remill::IntrinsicTable table(f->getParent()); DLOG(INFO) << "Replacing call from: " << remill::LLVMThingToString(remillFunctionCall) << " with call to " << std::hex << ra.u.address << " d has: " << std::string(entity->getName()); - auto new_mem = - fdecl->CallFromLiftedBlock(entity, lifter.Options().TypeDictionary(), - table, ir, state_ptr, mem_ptr); + auto *new_mem = callable_decl->CallFromLiftedBlock( + entity, lifter.Options().TypeDictionary(), table, ir, state_ptr, + mem_ptr); remillFunctionCall->replaceAllUsesWith(new_mem); remillFunctionCall->eraseFromParent(); diff --git a/lib/Providers/TypeProvider.cpp b/lib/Providers/TypeProvider.cpp index 739958be6..ebe51be6f 100644 --- a/lib/Providers/TypeProvider.cpp +++ b/lib/Providers/TypeProvider.cpp @@ -56,29 +56,6 @@ NullTypeProvider::TryGetVariableType(uint64_t, llvm::Type *) const { return std::nullopt; } -// Try to return the type of a function starting at address `to_address`. This -// type is the prototype of the function. The type can be call site specific, -// where the call site is `from_inst`. -std::optional -TypeProvider::TryGetCalledFunctionType(uint64_t function_address, - const remill::Instruction &from_inst, - uint64_t to_address) const { - if (auto decl = TryGetCalledFunctionType(function_address, from_inst)) { - return decl; - } else if (auto func_decl = TryGetFunctionType(to_address)) { - return static_cast(func_decl.value()); - } else { - return std::nullopt; - } -} - -// Try to return the type of a function that has been called from `from_isnt`. -std::optional -TypeProvider::TryGetCalledFunctionType(uint64_t function_address, - const remill::Instruction &) const { - return std::nullopt; -} - BaseTypeProvider::~BaseTypeProvider() {} const ::anvill::TypeDictionary &BaseTypeProvider::Dictionary(void) const { @@ -104,19 +81,6 @@ SpecificationTypeProvider::SpecificationTypeProvider(const Specification &spec) : BaseTypeProvider(spec.impl->type_translator), impl(spec.impl) {} -// Try to return the type of a function that has been called from `from_isnt`. -std::optional SpecificationTypeProvider::TryGetCalledFunctionType( - uint64_t function_address, const remill::Instruction &from_inst) const { - std::pair loc{function_address, from_inst.pc}; - - auto cs_it = impl->loc_to_call_site.find(loc); - if (cs_it == impl->loc_to_call_site.end()) { - return std::nullopt; - } else { - return *(cs_it->second); - } -} - // Try to return the type of a function starting at address `address`. This // type is the prototype of the function. std::optional @@ -140,36 +104,6 @@ SpecificationTypeProvider::TryGetVariableType(uint64_t address, } } -// Try to return the type of a function that has been called from `from_isnt`. -std::optional -DefaultCallableTypeProvider::TryGetCalledFunctionType( - uint64_t function_address, const remill::Instruction &from_inst) const { - auto maybe_res = - ProxyTypeProvider::TryGetCalledFunctionType(function_address, from_inst); - if (maybe_res.has_value()) { - return maybe_res; - } - - - auto maybe_func_type = - ProxyTypeProvider::TryGetFunctionType(function_address); - if (maybe_func_type.has_value()) { - return maybe_func_type; - } - - if (auto arch_decl = impl->TryGetDeclForArch(from_inst.arch_name)) { - return *arch_decl; - } - - if (from_inst.arch_name != from_inst.sub_arch_name) { - if (auto sub_arch_decl = impl->TryGetDeclForArch(from_inst.sub_arch_name)) { - return *sub_arch_decl; - } - } - - return std::nullopt; -} - std::optional DefaultCallableTypeProvider::TryGetFunctionType(uint64_t address) const { auto maybe_res = ProxyTypeProvider::TryGetFunctionType(address); @@ -211,22 +145,6 @@ ProxyTypeProvider::TryGetFunctionType(uint64_t address) const { return this->deleg.TryGetFunctionType(address); } -// Try to return the type of a function that has been called from `from_isnt`. -std::optional ProxyTypeProvider::TryGetCalledFunctionType( - uint64_t function_address, const remill::Instruction &from_inst) const { - return this->deleg.TryGetCalledFunctionType(function_address, from_inst); -} - -// Try to return the type of a function starting at address `to_address`. This -// type is the prototype of the function. The type can be call site specific, -// where the call site is `from_inst`. -std::optional ProxyTypeProvider::TryGetCalledFunctionType( - uint64_t function_address, const remill::Instruction &from_inst, - uint64_t to_address) const { - return this->deleg.TryGetCalledFunctionType(function_address, from_inst, - to_address); -} - // Try to return the variable at given address or containing the address std::optional ProxyTypeProvider::TryGetVariableType(uint64_t address, @@ -274,19 +192,6 @@ TypeProvider::TryGetFunctionTypeOrDefault(uint64_t address) const { return this->GetDefaultFunctionType(address); } - -std::optional TypeProvider::TryGetCalledFunctionTypeOrDefault( - uint64_t function_address, const remill::Instruction &from_inst, - uint64_t to_address) const { - auto res = - this->TryGetCalledFunctionType(function_address, from_inst, to_address); - if (res.has_value()) { - return res; - } - - return this->GetDefaultFunctionType(to_address); -} - std::optional TypeProvider::TryGetVariableTypeOrDefault(uint64_t address, llvm::Type *hinted_value_type) const { diff --git a/lib/Specification.cpp b/lib/Specification.cpp index dbf8fc5c1..b4d20fe15 100644 --- a/lib/Specification.cpp +++ b/lib/Specification.cpp @@ -84,7 +84,8 @@ SpecificationImpl::ParseSpecification( continue; } auto cs_obj = maybe_cs.Value(); - std::pair loc{cs_obj.function_address, cs_obj.address}; + std::pair loc{cs_obj.function_address, + cs_obj.address}; if (loc_to_call_site.count(loc)) { std::stringstream ss; @@ -394,6 +395,16 @@ Specification::DecodeFromPB(llvm::LLVMContext &context, std::istream &pb) { return Specification(std::move(pimpl)); } +// Return the call site at a given function address, instruction address pair, or an empty `shared_ptr`. +std::shared_ptr Specification::CallSiteAt( + const std::pair &loc) const { + auto it = impl->loc_to_call_site.find(loc); + if (it != impl->loc_to_call_site.end()) { + return {impl, it->second}; + } + return {}; +} + // Return the function beginning at `address`, or an empty `shared_ptr`. std::shared_ptr Specification::FunctionAt(std::uint64_t address) const { diff --git a/lib/Utils.cpp b/lib/Utils.cpp index 5f33dced1..b36993afd 100644 --- a/lib/Utils.cpp +++ b/lib/Utils.cpp @@ -244,6 +244,34 @@ std::string CreateVariableName(std::uint64_t addr) { return ss.str(); } +std::optional GetMetadata(llvm::StringRef tag, + const llvm::Instruction &instr) { + if (auto *metadata = instr.getMetadata(tag)) { + for (const auto &op : metadata->operands()) { + if (auto *md = dyn_cast(op.get())) { + if (auto c = dyn_cast(md->getValue())) { + auto pc_val = c->getValue().getZExtValue(); + return pc_val; + } + } + } + } + + return {}; +} + +void SetMetadata(llvm::StringRef tag, llvm::Instruction &insn, + std::uint64_t pc_val) { + auto &context = insn.getContext(); + auto &dl = insn.getModule()->getDataLayout(); + auto *address_type = + llvm::Type::getIntNTy(context, dl.getPointerSizeInBits(0)); + auto *cam = llvm::ConstantAsMetadata::get( + llvm::ConstantInt::get(address_type, pc_val)); + auto *node = llvm::MDNode::get(insn.getContext(), cam); + insn.setMetadata(tag, node); +} + void CopyMetadataTo(llvm::Value *src, llvm::Value *dst) { if (src == dst) { return;