Skip to content

Commit

Permalink
Merge pull request #936 from soot-oss/feature/InvokesCG
Browse files Browse the repository at this point in the history
Rework the call graph algorithm to annotate call edges with the corresponding invoke statement
  • Loading branch information
JonasKlauke authored Jul 25, 2024
2 parents cdff778 + 8028c0e commit 3a03fa0
Show file tree
Hide file tree
Showing 74 changed files with 2,202 additions and 703 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public abstract class AbstractJimpleBasedICFG implements BiDiInterproceduralCFG<
protected LoadingCache<Body, StmtGraph<?>> bodyToStmtGraph =
IDESolver.DEFAULT_CACHE_BUILDER.build(
new CacheLoader<Body, StmtGraph<?>>() {
@Nonnull
@Override
public StmtGraph<?> load(@Nonnull Body body) {
return makeGraph(body);
Expand All @@ -59,6 +60,7 @@ public StmtGraph<?> load(@Nonnull Body body) {
protected LoadingCache<SootMethod, List<Value>> methodToParameterRefs =
IDESolver.DEFAULT_CACHE_BUILDER.build(
new CacheLoader<SootMethod, List<Value>>() {
@Nonnull
@Override
public List<Value> load(@Nonnull SootMethod m) {
return new ArrayList<>(m.getBody().getParameterLocals());
Expand All @@ -69,6 +71,7 @@ public List<Value> load(@Nonnull SootMethod m) {
protected LoadingCache<SootMethod, Set<Stmt>> methodToCallsFromWithin =
IDESolver.DEFAULT_CACHE_BUILDER.build(
new CacheLoader<SootMethod, Set<Stmt>>() {
@Nonnull
@Override
public Set<Stmt> load(@Nonnull SootMethod m) {
return getCallsFromWithinMethod(m);
Expand Down Expand Up @@ -185,7 +188,7 @@ public boolean setOwnerStatement(Stmt u, Body b) {

@Override
public boolean isCallStmt(Stmt stmt) {
return stmt.containsInvokeExpr();
return stmt.isInvokableStmt() && stmt.asInvokableStmt().containsInvokeExpr();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,11 @@ public static Set<Pair<MethodSignature, CalleeMethodSignature>> getCallEdges(
SootMethod method = view.getMethod(caller).orElse(null);
if (method != null && method.hasBody()) {
for (Stmt s : method.getBody().getStmtGraph().getNodes()) {
if (s.containsInvokeExpr()) {
if (s.isInvokableStmt() && s.asInvokableStmt().containsInvokeExpr()) {
AbstractInvokeExpr invokeExpr = s.asInvokableStmt().getInvokeExpr().get();
CalleeMethodSignature callee =
new CalleeMethodSignature(
s.getInvokeExpr().getMethodSignature(),
findCallGraphEdgeType(s.getInvokeExpr()),
s);
invokeExpr.getMethodSignature(), findCallGraphEdgeType(invokeExpr), s);
callEdges.add(new ImmutablePair<>(caller, callee));
}
}
Expand Down Expand Up @@ -116,7 +115,7 @@ public enum CallGraphEdgeType {
/** Due to call to Class.newInstance(..) when reflection log is enabled. */
REFL_CLASS_NEWINSTANCE("REFL_CLASS_NEWINSTANCE");

private String name;
private final String name;

CallGraphEdgeType(String name) {
this.name = name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ public static Map<Integer, MethodSignature> computeCalls(
for (BasicBlock<?> block : blocks) {
List<Stmt> stmts = block.getStmts();
for (Stmt stmt : stmts) {
if (stmt.containsInvokeExpr()) {
MethodSignature target = stmt.getInvokeExpr().getMethodSignature();
if (stmt.isInvokableStmt() && stmt.asInvokableStmt().containsInvokeExpr()) {
MethodSignature target =
stmt.asInvokableStmt().getInvokeExpr().get().getMethodSignature();
int hashCode = stmt.hashCode();
calls.put(hashCode, target);
// compute all the classes that are made to the subclasses as well
Expand Down Expand Up @@ -112,7 +113,7 @@ public static Set<MethodSignature> getMethodSignatureInSubClass(
if (!callGraph.containsMethod(source) || !callGraph.containsMethod(target)) {
return Collections.emptySet();
}
return callGraph.callsFrom(source).stream()
return callGraph.callTargetsFrom(source).stream()
.filter(
methodSignature ->
!methodSignature.equals(target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import sootup.callgraph.CallGraphAlgorithm;
import sootup.callgraph.ClassHierarchyAnalysisAlgorithm;
import sootup.core.graph.StmtGraph;
import sootup.core.jimple.common.expr.AbstractInvokeExpr;
import sootup.core.jimple.common.stmt.InvokableStmt;
import sootup.core.jimple.common.stmt.Stmt;
import sootup.core.model.SootMethod;
import sootup.core.signatures.MethodSignature;
Expand Down Expand Up @@ -71,7 +73,9 @@ public class JimpleBasedInterproceduralCFG extends AbstractJimpleBasedICFG {
@Override
public Collection<SootMethod> load(Stmt stmt) {
ArrayList<SootMethod> res = new ArrayList<>();
MethodSignature methodSignature = stmt.getInvokeExpr().getMethodSignature();
// Return the empty result unless the statement is invokable AND carries an invoke expression.
// '||' short-circuits, so asInvokableStmt() is never reached for a non-invokable statement,
// and getInvokeExpr().get() below is only reached when an invoke expression is present.
if (!stmt.isInvokableStmt() || !stmt.asInvokableStmt().containsInvokeExpr()) {
return res;
}
MethodSignature methodSignature =
stmt.asInvokableStmt().getInvokeExpr().get().getMethodSignature();
Optional<? extends SootMethod> smOpt = view.getMethod(methodSignature);
if (smOpt.isPresent()) {
SootMethod sm = smOpt.get();
Expand Down Expand Up @@ -99,7 +103,7 @@ public Collection<Stmt> load(SootMethod method) {
ArrayList<Stmt> res = new ArrayList<>();
// only retain callers that are explicit call sites or
// Thread.start()
Set<MethodSignature> callsToMethod = cg.callsTo(method.getSignature());
Set<MethodSignature> callsToMethod = cg.callSourcesTo(method.getSignature());
for (MethodSignature methodSignature : callsToMethod) {
Stmt stmt = filterEdgeAndGetCallerStmt(methodSignature);
if (stmt != null) {
Expand Down Expand Up @@ -179,7 +183,7 @@ private void computeAllCalls(
signatureToStmtGraph.put(methodSignature, stmtGraph);
}
}
callGraph.callsFrom(methodSignature).stream()
callGraph.callTargetsFrom(methodSignature).stream()
.filter(methodSignature1 -> !visitedMethods.contains(methodSignature1))
.forEach(
nextMethodSignature ->
Expand Down Expand Up @@ -219,12 +223,13 @@ public static Set<Pair<MethodSignature, CalleeMethodSignature>> getCallEdges(
final SootMethod method = methodOpt.get();
if (method.hasBody()) {
for (Stmt s : method.getBody().getStmtGraph().getNodes()) {
if (s.containsInvokeExpr()) {
// TODO: Consider calls to clinit methods caused by static fields
// Assignment statements without invokeExpressions
if (s instanceof InvokableStmt && ((InvokableStmt) s).containsInvokeExpr()) {
AbstractInvokeExpr expr = ((InvokableStmt) s).getInvokeExpr().get();
CalleeMethodSignature callee =
new CalleeMethodSignature(
s.getInvokeExpr().getMethodSignature(),
CGEdgeUtil.findCallGraphEdgeType(s.getInvokeExpr()),
s);
expr.getMethodSignature(), CGEdgeUtil.findCallGraphEdgeType(expr), s);
callEdges.add(new ImmutablePair<>(caller, callee));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import sootup.core.jimple.common.expr.JSpecialInvokeExpr;
import sootup.core.jimple.common.expr.JStaticInvokeExpr;
import sootup.core.jimple.common.expr.JVirtualInvokeExpr;
import sootup.core.jimple.common.stmt.InvokableStmt;
import sootup.core.jimple.common.stmt.Stmt;
import sootup.core.jimple.visitor.ExprVisitor;
import sootup.core.model.SootClass;
Expand Down Expand Up @@ -163,7 +164,9 @@ public void testGetCallEdges() {

List<Stmt> invokesStmts =
m.getBody().getStmts().stream()
.filter(Stmt::containsInvokeExpr)
.filter(Stmt::isInvokableStmt)
.map(Stmt::asInvokableStmt)
.filter(InvokableStmt::containsInvokeExpr)
.collect(Collectors.toList());
assertEquals(invokesStmts.size(), 3);
MethodSignature constructorMethodSignature =
Expand Down Expand Up @@ -222,7 +225,10 @@ private void checkPair(
assertNotNull(virtualCall);
Stmt virtualStmt =
invokesStmts.stream()
.filter(stmt -> stmt.getInvokeExpr().getClass() == invokeClass)
.filter(Stmt::isInvokableStmt)
.map(Stmt::asInvokableStmt)
.filter(InvokableStmt::containsInvokeExpr)
.filter(stmt -> stmt.getInvokeExpr().get().getClass() == invokeClass)
.findAny()
.orElse(null);
assertNotNull(virtualStmt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
public class IFDSTaintAnalysisProblem
extends DefaultJimpleIFDSTabulationProblem<Value, InterproceduralCFG<Stmt, SootMethod>> {

private SootMethod entryMethod;
private final SootMethod entryMethod;

protected InterproceduralCFG<Stmt, SootMethod> icfg;

Expand Down Expand Up @@ -74,18 +74,18 @@ public FlowFunction<Value> getNormalFlowFunction(Stmt curr, Stmt succ) {

@Override
public FlowFunction<Value> getCallFlowFunction(Stmt callStmt, SootMethod destinationMethod) {
return getCallFlow(callStmt, destinationMethod);
return getCallFlow(callStmt.asInvokableStmt(), destinationMethod);
}

@Override
public FlowFunction<Value> getReturnFlowFunction(
Stmt callSite, SootMethod calleeMethod, Stmt exitStmt, Stmt returnSite) {
return getReturnFlow(callSite, calleeMethod, exitStmt, returnSite);
return getReturnFlow(callSite.asInvokableStmt(), calleeMethod, exitStmt, returnSite);
}

@Override
public FlowFunction<Value> getCallToReturnFlowFunction(Stmt callSite, Stmt returnSite) {
return getCallToReturnFlow(callSite, returnSite);
return getCallToReturnFlow(callSite.asInvokableStmt(), returnSite);
}
};
}
Expand All @@ -107,36 +107,33 @@ FlowFunction<Value> getNormalFlow(Stmt curr, Stmt succ) {
return new Gen<>(leftOp, zeroValue());
}
}
return new FlowFunction<Value>() {
@Override
public Set<Value> computeTargets(Value source) {
// source = {v.f*} some local and all its fields
// Kill T = ...
if (source == leftOp) {
return Collections.emptySet();
}
Set<Value> res = new HashSet<Value>();
res.add(source);
// x = T
if (source == rightOp) {
res.add(leftOp);
}
return res;
return source -> {
// source = {v.f*} some local and all its fields
// Kill T = ...
if (source == leftOp) {
return Collections.emptySet();
}
Set<Value> res = new HashSet<>();
res.add(source);
// x = T
if (source == rightOp) {
res.add(leftOp);
}
return res;
};
}
return Identity.v();
}

FlowFunction<Value> getCallFlow(Stmt callStmt, final SootMethod destinationMethod) {
FlowFunction<Value> getCallFlow(InvokableStmt callStmt, final SootMethod destinationMethod) {
if ("<clinit>".equals(destinationMethod.getName())) {
return KillAll.v();
}

AbstractInvokeExpr ie = callStmt.getInvokeExpr();
AbstractInvokeExpr ie = callStmt.getInvokeExpr().get();

final List<Immediate> callArgs = ie.getArgs();
final List<Value> paramLocals = new ArrayList<Value>();
final List<Value> paramLocals = new ArrayList<>();
for (int i = 0; i < destinationMethod.getParameterCount(); i++) {
paramLocals.add(destinationMethod.getBody().getParameterLocal(i));
}
Expand All @@ -154,28 +151,25 @@ FlowFunction<Value> getCallFlow(Stmt callStmt, final SootMethod destinationMetho
}
final Value baseF = base;

return new FlowFunction<Value>() {
@Override
public Set<Value> computeTargets(Value source) {
Set<Value> ret = new HashSet<>();
if (source instanceof JStaticFieldRef) {
ret.add(source);
}
// Tainted func parameters
for (int i = 0; i < callArgs.size(); i++) {
if (callArgs.get(i).equivTo(source) && i < paramLocals.size()) {
ret.add(paramLocals.get(i));
}
return source -> {
Set<Value> ret = new HashSet<>();
if (source instanceof JStaticFieldRef) {
ret.add(source);
}
// Tainted func parameters
for (int i = 0; i < callArgs.size(); i++) {
if (callArgs.get(i).equivTo(source) && i < paramLocals.size()) {
ret.add(paramLocals.get(i));
}
return ret;
}
return ret;
};
}

FlowFunction<Value> getReturnFlow(
final Stmt callSite, final SootMethod calleeMethod, Stmt exitStmt, Stmt returnSite) {
final InvokableStmt callSite, final SootMethod calleeMethod, Stmt exitStmt, Stmt returnSite) {

AbstractInvokeExpr ie = callSite.getInvokeExpr();
AbstractInvokeExpr ie = callSite.getInvokeExpr().get();

Value base = null;
if (ie instanceof JVirtualInvokeExpr) {
Expand Down Expand Up @@ -203,44 +197,38 @@ FlowFunction<Value> getReturnFlow(
}
}
}
return new FlowFunction<Value>() {
@Override
public Set<Value> computeTargets(Value source) {
Set<Value> ret = new HashSet<>();
if (source instanceof JStaticFieldRef) {
ret.add(source);
}
if (callSite instanceof AbstractDefinitionStmt && source == retOp) {
AbstractDefinitionStmt defnStmt = (AbstractDefinitionStmt) callSite;
ret.add(defnStmt.getLeftOp());
}
if (baseF != null && source.equals(calleeMethod.getBody().getThisLocal())) {
ret.add(baseF);
}
return ret;
return source -> {
Set<Value> ret = new HashSet<>();
if (source instanceof JStaticFieldRef) {
ret.add(source);
}
if (callSite instanceof AbstractDefinitionStmt && source == retOp) {
AbstractDefinitionStmt defnStmt = (AbstractDefinitionStmt) callSite;
ret.add(defnStmt.getLeftOp());
}
if (baseF != null && source.equals(calleeMethod.getBody().getThisLocal())) {
ret.add(baseF);
}
return ret;
};
}
if (exitStmt instanceof JReturnVoidStmt) {
return new FlowFunction<Value>() {
@Override
public Set<Value> computeTargets(Value source) {
Set<Value> ret = new HashSet<Value>();
if (source instanceof JStaticFieldRef) {
ret.add(source);
}
if (baseF != null && source.equals(calleeMethod.getBody().getThisLocal())) {
ret.add(baseF);
}
return ret;
return source -> {
Set<Value> ret = new HashSet<>();
if (source instanceof JStaticFieldRef) {
ret.add(source);
}
if (baseF != null && source.equals(calleeMethod.getBody().getThisLocal())) {
ret.add(baseF);
}
return ret;
};
}
return KillAll.v();
}

FlowFunction<Value> getCallToReturnFlow(final Stmt callSite, Stmt returnSite) {
AbstractInvokeExpr ie = callSite.getInvokeExpr();
FlowFunction<Value> getCallToReturnFlow(final InvokableStmt callSite, Stmt returnSite) {
AbstractInvokeExpr ie = callSite.getInvokeExpr().get();
final List<Immediate> callArgs = ie.getArgs();

Value base = null;
Expand All @@ -267,25 +255,22 @@ FlowFunction<Value> getCallToReturnFlow(final Stmt callSite, Stmt returnSite) {

// use assumption if no callees to analyze
if (icfg.getCalleesOfCallAt(callSite).isEmpty()) {
return new FlowFunction<Value>() {
@Override
public Set<Value> computeTargets(Value source) {
Set<Value> ret = new HashSet<Value>();
ret.add(source);
// taint leftOp if base is tainted
if (baseF != null && leftOpF != null && source == baseF) {
ret.add(leftOpF);
}
// taint leftOp if one of the args is tainted
if (leftOpF != null && callArgs.contains(source)) {
ret.add(leftOpF);
}
// taint base if one of the args is tainted and has no callee in known methods
if (baseF != null && callArgs.contains(source)) {
ret.add(baseF);
}
return ret;
return source -> {
Set<Value> ret = new HashSet<>();
ret.add(source);
// taint leftOp if base is tainted
if (baseF != null && leftOpF != null && source == baseF) {
ret.add(leftOpF);
}
// taint leftOp if one of the args is tainted
if (leftOpF != null && callArgs.contains(source)) {
ret.add(leftOpF);
}
// taint base if one of the args is tainted and has no callee in known methods
if (baseF != null && callArgs.contains(source)) {
ret.add(baseF);
}
return ret;
};
}
return Identity.v();
Expand Down
Loading

0 comments on commit 3a03fa0

Please sign in to comment.