Skip to content

Commit

Permalink
Callsite override (#398)
Browse files Browse the repository at this point in the history
* implement call site override

* simplify override logic

* address c++ nits

* add missing const

* a bb context should always exist given a valid bb address
  • Loading branch information
Ninja3047 authored Oct 4, 2023
1 parent c3f8be4 commit 9722e02
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 170 deletions.
41 changes: 0 additions & 41 deletions include/anvill/Providers.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ class TypeProvider {
std::optional<FunctionDecl>
TryGetFunctionTypeOrDefault(uint64_t address) const;

std::optional<CallableDecl>
TryGetCalledFunctionTypeOrDefault(uint64_t function_address,
const remill::Instruction &from_inst,
uint64_t to_address) const;

std::optional<VariableDecl>
TryGetVariableTypeOrDefault(uint64_t address,
llvm::Type *hinted_value_type = nullptr) const;
Expand All @@ -60,19 +55,6 @@ class TypeProvider {
virtual std::optional<FunctionDecl>
TryGetFunctionType(uint64_t address) const = 0;

// Try to return the type of a function that has been called from `from_isnt`.
virtual std::optional<CallableDecl>
TryGetCalledFunctionType(uint64_t function_address,
const remill::Instruction &from_inst) const;

// Try to return the type of a function starting at address `to_address`. This
// type is the prototype of the function. The type can be call site specific,
// where the call site is `from_inst`.
virtual std::optional<CallableDecl>
TryGetCalledFunctionType(uint64_t function_address,
const remill::Instruction &from_inst,
uint64_t to_address) const;

// Try to return the variable at given address or containing the address
virtual std::optional<VariableDecl>
TryGetVariableType(uint64_t address,
Expand Down Expand Up @@ -153,19 +135,6 @@ class ProxyTypeProvider : public TypeProvider {
std::optional<FunctionDecl>
TryGetFunctionType(uint64_t address) const override;

// Try to return the type of a function that has been called from `from_isnt`.
std::optional<CallableDecl>
TryGetCalledFunctionType(uint64_t function_address,
const remill::Instruction &from_inst) const override;

// Try to return the type of a function starting at address `to_address`. This
// type is the prototype of the function. The type can be call site specific,
// where the call site is `from_inst`.
std::optional<CallableDecl>
TryGetCalledFunctionType(uint64_t function_address,
const remill::Instruction &from_inst,
uint64_t to_address) const override;

// Try to return the variable at given address or containing the address
std::optional<VariableDecl>
TryGetVariableType(uint64_t address,
Expand Down Expand Up @@ -206,11 +175,6 @@ class DefaultCallableTypeProvider : public ProxyTypeProvider {
// Set `decl` to the default callable type for `arch`.
void SetDefault(remill::ArchName arch, CallableDecl decl);

// Try to return the type of a function that has been called from `from_isnt`.
std::optional<CallableDecl>
TryGetCalledFunctionType(uint64_t function_address,
const remill::Instruction &from_inst) const override;

std::optional<anvill::FunctionDecl>
TryGetFunctionType(uint64_t address) const override;
};
Expand All @@ -225,11 +189,6 @@ class SpecificationTypeProvider : public BaseTypeProvider {

explicit SpecificationTypeProvider(const Specification &spec);

// Try to return the type of a function that has been called from `from_isnt`.
std::optional<CallableDecl>
TryGetCalledFunctionType(uint64_t function_address,
const remill::Instruction &from_inst) const override;

// Try to return the type of a function starting at address `address`. This
// type is the prototype of the function.
std::optional<anvill::FunctionDecl>
Expand Down
4 changes: 4 additions & 0 deletions include/anvill/Specification.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ class Specification {
static anvill::Result<Specification, std::string>
DecodeFromPB(llvm::LLVMContext &context, std::istream &pb);

// Return the call site at a given function address, instruction address pair, or an empty `shared_ptr`.
std::shared_ptr<const CallSiteDecl>
CallSiteAt(const std::pair<std::uint64_t, std::uint64_t> &loc) const;

// Return the function beginning at `address`, or an empty `shared_ptr`.
std::shared_ptr<const FunctionDecl> FunctionAt(std::uint64_t address) const;

Expand Down
8 changes: 8 additions & 0 deletions include/anvill/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ std::string CreateFunctionName(std::uint64_t addr);
// Creates a `data_<address>` name from an address
std::string CreateVariableName(std::uint64_t addr);

// Get metadata for an instruction
std::optional<std::uint64_t> GetMetadata(llvm::StringRef tag,
const llvm::Instruction &instr);

// Set metadata for an instruction
void SetMetadata(llvm::StringRef tag, llvm::Instruction &insn,
std::uint64_t pc_val);

// Looks for any constant expressions in the operands of `inst` and unfolds
// them into other instructions in the same block.
void UnfoldConstantExpressions(llvm::Instruction *inst);
Expand Down
2 changes: 2 additions & 0 deletions lib/Lifters/BasicBlockLifter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "Lifters/FunctionLifter.h"
#include "anvill/Declarations.h"
#include "anvill/Optimize.h"
#include "anvill/Utils.h"

namespace anvill {

Expand Down Expand Up @@ -183,6 +184,7 @@ bool BasicBlockLifter::DoInterProceduralControlFlow(
call = this->AddCallFromBasicBlockFunctionToLifted(
block, this->intrinsics.function_call, this->intrinsics);
}
SetMetadata(options.pc_metadata_name, *call, insn.pc);
if (!cc.stop) {
auto [_, raddr] = this->LoadFunctionReturnAddress(insn, block);
auto npc = remill::LoadNextProgramCounterRef(block);
Expand Down
12 changes: 0 additions & 12 deletions lib/Lifters/FunctionLifter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,6 @@ void FunctionLifter::InsertError(llvm::BasicBlock *block) {
AnnotateInstruction(tail, pc_annotation_id, pc_annotation);
}


std::optional<CallableDecl>
FunctionLifter::TryGetTargetFunctionType(const remill::Instruction &from_inst,
std::uint64_t address) {
std::optional<CallableDecl> opt_callable_decl =
type_provider.TryGetCalledFunctionTypeOrDefault(func_address, from_inst,
address);

return opt_callable_decl;
}


// Get the annotation for the program counter `pc`, or `nullptr` if we're
// not doing annotations.
llvm::MDNode *FunctionLifter::GetPCAnnotation(uint64_t pc) const {
Expand Down
7 changes: 0 additions & 7 deletions lib/Lifters/FunctionLifter.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,6 @@ class FunctionLifter : public CodeLifter {
bool CallFunction(const remill::Instruction &inst, llvm::BasicBlock *block,
std::optional<std::uint64_t> target_pc);

// A wrapper around the type provider's TryGetFunctionType that makes use
// of the control flow provider to handle control flow redirections for
// thunks
std::optional<CallableDecl>
TryGetTargetFunctionType(const remill::Instruction &inst,
std::uint64_t address);

// Visit a direct function call control-flow instruction. The target is known
// at decode time, and its realized address is stored in
// `inst.branch_taken_pc`. In practice, what we do in this situation is try
Expand Down
38 changes: 24 additions & 14 deletions lib/Passes/RemoveCallIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ llvm::PreservedAnalyses
RemoveCallIntrinsics::runOnIntrinsic(llvm::CallInst *remillFunctionCall,
llvm::FunctionAnalysisManager &am,
llvm::PreservedAnalyses prev) {
// remillFunctionCall->getFunction()->dump();
// if (remillFunctionCall->getFunction()->getName().endswith(
// "basic_block_func4201200")) {
// LOG(FATAL) << "done";
// }
CHECK(remillFunctionCall->getNumOperands() == 4);
auto target_func = remillFunctionCall->getArgOperand(1);
auto state_ptr = remillFunctionCall->getArgOperand(0);
Expand All @@ -43,23 +38,38 @@ RemoveCallIntrinsics::runOnIntrinsic(llvm::CallInst *remillFunctionCall,
ra.references_global_value || // Related to a global var/func.
ra.references_program_counter) { // Related to `__anvill_pc`.

// TODO(Ian): ignoring callsite decls for now
auto fdecl = spec.FunctionAt(ra.u.address);
auto entity = this->xref_resolver.EntityAtAddress(ra.u.address);
if (fdecl && entity) {
std::shared_ptr<const CallableDecl> callable_decl =
spec.FunctionAt(ra.u.address);

if (auto pc_val =
GetMetadata(lifter.Options().pc_metadata_name, *remillFunctionCall);
pc_val.has_value()) {
if (auto bb_addr = GetBasicBlockAddr(f); bb_addr.has_value()) {
auto block_contexts = spec.GetBlockContexts();
const auto &bb_ctx = block_contexts.GetBasicBlockContextForAddr(*bb_addr)->get();
auto func = bb_ctx.GetParentFunctionAddress();
if (auto override_decl = spec.CallSiteAt({func, *pc_val})) {
DLOG(INFO) << "Overriding call site at " << std::hex << *pc_val
<< " in " << std::hex << func;
callable_decl = std::move(override_decl);
}
}
}

auto *entity = this->xref_resolver.EntityAtAddress(ra.u.address);
if (callable_decl && entity) {
llvm::IRBuilder<> ir(remillFunctionCall->getParent());
ir.SetInsertPoint(remillFunctionCall);


const remill::IntrinsicTable table(
remillFunctionCall->getFunction()->getParent());
const remill::IntrinsicTable table(f->getParent());
DLOG(INFO) << "Replacing call from: "
<< remill::LLVMThingToString(remillFunctionCall)
<< " with call to " << std::hex << ra.u.address
<< " d has: " << std::string(entity->getName());
auto new_mem =
fdecl->CallFromLiftedBlock(entity, lifter.Options().TypeDictionary(),
table, ir, state_ptr, mem_ptr);
auto *new_mem = callable_decl->CallFromLiftedBlock(
entity, lifter.Options().TypeDictionary(), table, ir, state_ptr,
mem_ptr);

remillFunctionCall->replaceAllUsesWith(new_mem);
remillFunctionCall->eraseFromParent();
Expand Down
95 changes: 0 additions & 95 deletions lib/Providers/TypeProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,29 +56,6 @@ NullTypeProvider::TryGetVariableType(uint64_t, llvm::Type *) const {
return std::nullopt;
}

// Try to return the type of a function starting at address `to_address`. This
// type is the prototype of the function. The type can be call site specific,
// where the call site is `from_inst`.
std::optional<CallableDecl>
TypeProvider::TryGetCalledFunctionType(uint64_t function_address,
const remill::Instruction &from_inst,
uint64_t to_address) const {
if (auto decl = TryGetCalledFunctionType(function_address, from_inst)) {
return decl;
} else if (auto func_decl = TryGetFunctionType(to_address)) {
return static_cast<CallableDecl &>(func_decl.value());
} else {
return std::nullopt;
}
}

// Try to return the type of a function that has been called from `from_isnt`.
std::optional<CallableDecl>
TypeProvider::TryGetCalledFunctionType(uint64_t function_address,
const remill::Instruction &) const {
return std::nullopt;
}

BaseTypeProvider::~BaseTypeProvider() {}

const ::anvill::TypeDictionary &BaseTypeProvider::Dictionary(void) const {
Expand All @@ -104,19 +81,6 @@ SpecificationTypeProvider::SpecificationTypeProvider(const Specification &spec)
: BaseTypeProvider(spec.impl->type_translator),
impl(spec.impl) {}

// Try to return the type of a function that has been called from `from_isnt`.
std::optional<CallableDecl> SpecificationTypeProvider::TryGetCalledFunctionType(
uint64_t function_address, const remill::Instruction &from_inst) const {
std::pair<std::uint64_t, std::uint64_t> loc{function_address, from_inst.pc};

auto cs_it = impl->loc_to_call_site.find(loc);
if (cs_it == impl->loc_to_call_site.end()) {
return std::nullopt;
} else {
return *(cs_it->second);
}
}

// Try to return the type of a function starting at address `address`. This
// type is the prototype of the function.
std::optional<anvill::FunctionDecl>
Expand All @@ -140,36 +104,6 @@ SpecificationTypeProvider::TryGetVariableType(uint64_t address,
}
}

// Try to return the type of a function that has been called from `from_isnt`.
std::optional<CallableDecl>
DefaultCallableTypeProvider::TryGetCalledFunctionType(
uint64_t function_address, const remill::Instruction &from_inst) const {
auto maybe_res =
ProxyTypeProvider::TryGetCalledFunctionType(function_address, from_inst);
if (maybe_res.has_value()) {
return maybe_res;
}


auto maybe_func_type =
ProxyTypeProvider::TryGetFunctionType(function_address);
if (maybe_func_type.has_value()) {
return maybe_func_type;
}

if (auto arch_decl = impl->TryGetDeclForArch(from_inst.arch_name)) {
return *arch_decl;
}

if (from_inst.arch_name != from_inst.sub_arch_name) {
if (auto sub_arch_decl = impl->TryGetDeclForArch(from_inst.sub_arch_name)) {
return *sub_arch_decl;
}
}

return std::nullopt;
}

std::optional<anvill::FunctionDecl>
DefaultCallableTypeProvider::TryGetFunctionType(uint64_t address) const {
auto maybe_res = ProxyTypeProvider::TryGetFunctionType(address);
Expand Down Expand Up @@ -211,22 +145,6 @@ ProxyTypeProvider::TryGetFunctionType(uint64_t address) const {
return this->deleg.TryGetFunctionType(address);
}

// Try to return the type of a function that has been called from `from_isnt`.
std::optional<CallableDecl> ProxyTypeProvider::TryGetCalledFunctionType(
uint64_t function_address, const remill::Instruction &from_inst) const {
return this->deleg.TryGetCalledFunctionType(function_address, from_inst);
}

// Try to return the type of a function starting at address `to_address`. This
// type is the prototype of the function. The type can be call site specific,
// where the call site is `from_inst`.
std::optional<CallableDecl> ProxyTypeProvider::TryGetCalledFunctionType(
uint64_t function_address, const remill::Instruction &from_inst,
uint64_t to_address) const {
return this->deleg.TryGetCalledFunctionType(function_address, from_inst,
to_address);
}

// Try to return the variable at given address or containing the address
std::optional<VariableDecl>
ProxyTypeProvider::TryGetVariableType(uint64_t address,
Expand Down Expand Up @@ -274,19 +192,6 @@ TypeProvider::TryGetFunctionTypeOrDefault(uint64_t address) const {
return this->GetDefaultFunctionType(address);
}


std::optional<CallableDecl> TypeProvider::TryGetCalledFunctionTypeOrDefault(
uint64_t function_address, const remill::Instruction &from_inst,
uint64_t to_address) const {
auto res =
this->TryGetCalledFunctionType(function_address, from_inst, to_address);
if (res.has_value()) {
return res;
}

return this->GetDefaultFunctionType(to_address);
}

std::optional<VariableDecl>
TypeProvider::TryGetVariableTypeOrDefault(uint64_t address,
llvm::Type *hinted_value_type) const {
Expand Down
13 changes: 12 additions & 1 deletion lib/Specification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ SpecificationImpl::ParseSpecification(
continue;
}
auto cs_obj = maybe_cs.Value();
std::pair<uint64_t, uint64_t> loc{cs_obj.function_address, cs_obj.address};
std::pair<std::uint64_t, std::uint64_t> loc{cs_obj.function_address,
cs_obj.address};

if (loc_to_call_site.count(loc)) {
std::stringstream ss;
Expand Down Expand Up @@ -394,6 +395,16 @@ Specification::DecodeFromPB(llvm::LLVMContext &context, std::istream &pb) {
return Specification(std::move(pimpl));
}

// Return the call site at a given function address, instruction address pair, or an empty `shared_ptr`.
std::shared_ptr<const CallSiteDecl> Specification::CallSiteAt(
const std::pair<std::uint64_t, std::uint64_t> &loc) const {
auto it = impl->loc_to_call_site.find(loc);
if (it != impl->loc_to_call_site.end()) {
return {impl, it->second};
}
return {};
}

// Return the function beginning at `address`, or an empty `shared_ptr`.
std::shared_ptr<const FunctionDecl>
Specification::FunctionAt(std::uint64_t address) const {
Expand Down
Loading

0 comments on commit 9722e02

Please sign in to comment.