Skip to content

Commit

Permalink
Add lambda support to CHACallGraph (#1476)
Browse files Browse the repository at this point in the history
Fixes #1459 

To support lambdas, we re-use logic from the
`LambdaMethodTargetSelector` to handle calls that use `invokedynamic`.
`CHACallGraph` only includes edges for methods reachable from given
entrypoints, and our handling of lambdas adds new synthetic classes to
the class hierarchy (representing the classes generated for lambdas). So
to ensure a complete call graph, we do a second pass after adding these
synthetic classes to recompute call targets. This is probably not the
most efficient way to do things, but I'm not sure it's the bottleneck;
if we need to optimize in the future we should do so based on profiling.

Note that the handling of lambdas here is very imprecise; every call
site of a functional interface method will have an edge to every lambda
of that functional interface type. Still, it could be practically
useful, e.g., when combined with other heuristic filtering of irrelevant
code.

We add new tests, and also fix a couple other minor bugs (e.g., there
was no edge before to the synthetic method that invokes class
initializers, so it wasn't being processed properly).

---------

Co-authored-by: Ben Liblit <[email protected]>
  • Loading branch information
msridhar and liblit authored Nov 30, 2024
1 parent 6d62392 commit 9f71651
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -276,5 +276,11 @@ public static CrossLanguageClassHierarchy make(AnalysisScope scope, ClassLoaderF
// Returns a new set obtained from HashSetFactory.make() on every call; nothing is added to it
// here, so this hierarchy does not report any particular class as unresolved.
public Set<TypeReference> getUnresolvedClasses() {
return HashSetFactory.make();
}

/* END Custom change: unresolved classes */

/** Clears the caches of each class hierarchy stored in {@code hierarchies}. */
@Override
public void clearCaches() {
  for (IClassHierarchy hierarchy : hierarchies.values()) {
    hierarchy.clearCaches();
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
import com.ibm.wala.ipa.callgraph.impl.ExplicitPredecessorsEdgeManager;
import com.ibm.wala.ipa.callgraph.impl.FakeWorldClinitMethod;
import com.ibm.wala.ipa.cha.IClassHierarchy;
import com.ibm.wala.ipa.summaries.LambdaMethodTargetSelector;
import com.ibm.wala.ipa.summaries.LambdaSummaryClass;
import com.ibm.wala.shrike.shrikeBT.IInvokeInstruction;
import com.ibm.wala.ssa.DefUse;
import com.ibm.wala.ssa.IR;
import com.ibm.wala.types.TypeReference;
import com.ibm.wala.util.CancelException;
import com.ibm.wala.util.collections.ComposedIterator;
import com.ibm.wala.util.collections.FilterIterator;
Expand Down Expand Up @@ -60,6 +63,10 @@ public class CHACallGraph extends BasicCallGraph<CHAContextInterpreter> {
*/
private final boolean applicationOnly;

/**
 * To handle lambdas. We pass a selector that always returns null as the base selector, so a
 * {@code null} result from this selector means the call site was not resolved as a lambda and
 * the regular class-hierarchy-based target computation should be used instead.
 */
private final LambdaMethodTargetSelector lambdaMethodTargetSelector =
new LambdaMethodTargetSelector((caller, site, receiver) -> null);

/** Starts out {@code false}; set to {@code true} as the final step of {@code init}. */
private boolean isInitialized = false;

private class CHANode extends NodeImpl {
Expand Down Expand Up @@ -138,6 +145,25 @@ public void init(Iterable<Entrypoint> entrypoints) throws CancelException {
}
newNodes.push(root);
closure();
// classes simulating lambdas may have been added to the CHA via the previous closure() call.
// to update call targets to include lambdas, we clear all call target caches, iterate through
// all call sites, and re-compute the targets.
// TODO optimize if needed
targetCache.clear();
cha.clearCaches();
for (CGNode n : this) {
for (CallSiteReference site : Iterator2Iterable.make(n.iterateCallSites())) {
for (IMethod target : Iterator2Iterable.make(getOrUpdatePossibleTargets(n, site))) {
if (isRelevantMethod(target)) {
CGNode callee = getNode(target, Everywhere.EVERYWHERE);
if (callee == null) {
throw new RuntimeException("should have already created CGNode for " + target);
}
edgeManager.addEdge(n, callee);
}
}
}
}
isInitialized = true;
}

Expand All @@ -146,6 +172,13 @@ public IClassHierarchy getClassHierarchy() {
return cha;
}

/**
* Cache of possible targets for call sites.
*
* <p>In the future, this cache could be keyed on ({@link com.ibm.wala.types.MethodReference},
* {@code isDispatch}) pairs to save space and possibly time, where {@code isDispatch} indicates
* whether the call site is a virtual dispatch.
*/
private final Map<CallSiteReference, Set<IMethod>> targetCache = HashMapFactory.make();

/**
Expand All @@ -154,8 +187,29 @@ public IClassHierarchy getClassHierarchy() {
* @param site the call site
* @return an iterator of possible targets
*/
private Iterator<IMethod> getOrUpdatePossibleTargets(CallSiteReference site) {
Set<IMethod> result = targetCache.get(site);
private Iterator<IMethod> getOrUpdatePossibleTargets(CGNode caller, CallSiteReference site)
throws CancelException {
Set<IMethod> result = null;
if (isCallToLambdaMetafactoryMethod(site)) {
IMethod calleeTarget = lambdaMethodTargetSelector.getCalleeTarget(caller, site, null);
if (calleeTarget != null) {
// It's for a lambda. The result method is a synthetic method that allocates an object of
// the synthetic class generated for the lambda.
result = Collections.singleton(calleeTarget);
// we eagerly create a CGNode for the "trampoline" method that invokes the body of the
// lambda itself. This way, the new node gets added to the worklist, so we process all
// methods reachable from the lambda body immediately and don't need to do an outer fixed
// point. This does not do any wasted work assuming the call graph has at least one
// invocation of the lambda.
LambdaSummaryClass lambdaSummaryClass =
lambdaMethodTargetSelector.getLambdaSummaryClass(caller, site);
IMethod trampoline = lambdaSummaryClass.getDeclaredMethods().iterator().next();
CGNode callee = getNode(trampoline, Everywhere.EVERYWHERE);
if (callee == null) {
callee = findOrCreateNode(trampoline, Everywhere.EVERYWHERE);
}
}
}
if (result == null) {
if (site.isDispatch()) {
result = cha.getPossibleTargets(site.getDeclaredTarget());
Expand All @@ -164,7 +218,12 @@ private Iterator<IMethod> getOrUpdatePossibleTargets(CallSiteReference site) {
if (m != null) {
result = Collections.singleton(m);
} else {
result = Collections.emptySet();
IMethod fakeWorldClinitMethod = getFakeWorldClinitNode().getMethod();
if (site.getDeclaredTarget().equals(fakeWorldClinitMethod.getReference())) {
result = Collections.singleton(fakeWorldClinitMethod);
} else {
result = Collections.emptySet();
}
}
}
targetCache.put(site, result);
Expand All @@ -178,7 +237,14 @@ private Iterator<IMethod> getOrUpdatePossibleTargets(CallSiteReference site) {
* @param site the call site
* @return an iterator of possible targets
*/
private Iterator<IMethod> getPossibleTargetsFromCache(CallSiteReference site) {
private Iterator<IMethod> getPossibleTargetsFromCache(CGNode caller, CallSiteReference site) {
if (isCallToLambdaMetafactoryMethod(site)) {
IMethod calleeTarget = lambdaMethodTargetSelector.getCalleeTarget(caller, site, null);
if (calleeTarget != null) {
// it's for a lambda
return Collections.singleton(calleeTarget).iterator();
}
}
Set<IMethod> result = targetCache.get(site);
if (result == null) {
return Collections.emptyIterator();
Expand All @@ -190,7 +256,7 @@ private Iterator<IMethod> getPossibleTargetsFromCache(CallSiteReference site) {
public Set<CGNode> getPossibleTargets(CGNode node, CallSiteReference site) {
return Iterator2Collection.toSet(
new MapIterator<>(
new FilterIterator<>(getPossibleTargetsFromCache(site), this::isRelevantMethod),
new FilterIterator<>(getPossibleTargetsFromCache(node, site), this::isRelevantMethod),
object -> {
try {
return findOrCreateNode(object, Everywhere.EVERYWHERE);
Expand All @@ -203,7 +269,7 @@ public Set<CGNode> getPossibleTargets(CGNode node, CallSiteReference site) {

/**
 * Counts the possible targets currently cached for a call site.
 *
 * <p>The caller node is passed along to the cache lookup because lambda metafactory call sites
 * are resolved per caller rather than through the shared target cache.
 *
 * <p>NOTE(review): unlike {@code getPossibleTargets}, the count is not filtered through {@code
 * isRelevantMethod}; confirm that this difference is intended.
 *
 * @param node the node containing the call site
 * @param site the call site
 * @return the number of methods yielded by the cached target lookup for {@code site}
 */
@Override
public int getNumberOfTargets(CGNode node, CallSiteReference site) {
  return IteratorUtil.count(getPossibleTargetsFromCache(node, site));
}

@Override
Expand Down Expand Up @@ -279,9 +345,7 @@ private void closure() throws CancelException {
while (!newNodes.isEmpty()) {
CGNode n = newNodes.pop();
for (CallSiteReference site : Iterator2Iterable.make(n.iterateCallSites())) {
Iterator<IMethod> methods = getOrUpdatePossibleTargets(site);
while (methods.hasNext()) {
IMethod target = methods.next();
for (IMethod target : Iterator2Iterable.make(getOrUpdatePossibleTargets(n, site))) {
if (isRelevantMethod(target)) {
CGNode callee = getNode(target, Everywhere.EVERYWHERE);
if (callee == null) {
Expand All @@ -297,6 +361,13 @@ private void closure() throws CancelException {
}
}

/**
 * Tests whether the declared target of a call site is declared by the lambda metafactory class.
 *
 * @param site the call site to inspect
 * @return {@code true} if the name of the declaring class of the site's declared target equals
 *     the name of {@link TypeReference#LambdaMetaFactory}
 */
private boolean isCallToLambdaMetafactoryMethod(CallSiteReference site) {
  TypeReference declaringClass = site.getDeclaredTarget().getDeclaringClass();
  return declaringClass.getName().equals(TypeReference.LambdaMetaFactory.getName());
}

private boolean isRelevantMethod(IMethod target) {
return !target.isAbstract()
&& (!applicationOnly
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/java/com/ibm/wala/ipa/cha/ClassHierarchy.java
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public enum MissingSuperClassHandling {
/** A mapping from IClass -&gt; Selector -&gt; Set of IMethod */
private final HashMap<IClass, Object> targetCache = HashMapFactory.make();

/** Discards every entry of {@code targetCache}. */
@Override
public void clearCaches() {
targetCache.clear();
}

/** Governing analysis scope */
private final AnalysisScope scope;

Expand Down
6 changes: 6 additions & 0 deletions core/src/main/java/com/ibm/wala/ipa/cha/IClassHierarchy.java
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,10 @@ public interface IClassHierarchy extends Iterable<IClass> {
* @throws IllegalArgumentException if c2 is null
*/
boolean isAssignableFrom(IClass c1, IClass c2);

/**
 * Clear internal caches that may be invalidated by addition of new classes, e.g., a cache of the
 * results of {@link #getPossibleTargets(MethodReference)}.
 *
 * <p>Intended to be called after classes have been added to the hierarchy, so that subsequent
 * queries are not answered from stale cached results.
 */
void clearCaches();
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,26 @@ public IMethod getCalleeTarget(CGNode caller, CallSiteReference site, IClass rec
return base.getCalleeTarget(caller, site, receiver);
}

/**
 * Looks up the summary class already recorded for the lambda factory invoked at a call site.
 *
 * <p>This method only reads {@code classSummaries}; it never creates a summary class.
 *
 * @param caller the node containing the call site
 * @param site the call site to inspect
 * @return the summary class stored under the bootstrap method of the site's first call
 *     instruction, or {@code null} if the caller's IR has no call instruction indices for the
 *     site, the first call is not an {@code invokedynamic}, or no summary class is stored yet
 */
public LambdaSummaryClass getLambdaSummaryClass(CGNode caller, CallSiteReference site) {
  IR callerIR = caller.getIR();
  if (callerIR.getCallInstructionIndices(site) == null) {
    return null;
  }
  SSAAbstractInvokeInstruction firstCall = callerIR.getCalls(site)[0];
  if (!(firstCall instanceof SSAInvokeDynamicInstruction)) {
    return null;
  }
  return classSummaries.get(((SSAInvokeDynamicInstruction) firstCall).getBootstrap());
}

/**
* Create a summary for a lambda factory, as it would be generated by the lambda metafactory. The
* lambda factory summary returns an instance of the summary anonymous class for the lambda (see
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ public IClass getSuperclass() {
*/
@Override
public Collection<? extends IClass> getDirectInterfaces() {
  // The declared result type of the invokedynamic is looked up in the class hierarchy. The lookup
  // can yield null, so fall back to an empty collection rather than a singleton holding null.
  IClass resultType = getClassHierarchy().lookupClass(invoke.getDeclaredResultType());
  return resultType != null ? Collections.singleton(resultType) : Collections.emptySet();
}

/**
Expand All @@ -126,6 +127,9 @@ public Collection<? extends IClass> getDirectInterfaces() {
@Override
public Collection<IClass> getAllImplementedInterfaces() {
IClass iface = getClassHierarchy().lookupClass(invoke.getDeclaredResultType());
if (iface == null) {
return Collections.emptySet();
}
Set<IClass> result = HashSetFactory.make(iface.getAllImplementedInterfaces());
result.add(iface);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,36 @@
*/
package com.ibm.wala.core.tests.callGraph;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.ibm.wala.classLoader.CallSiteReference;
import com.ibm.wala.core.tests.util.TestConstants;
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.ipa.callgraph.AnalysisScope;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.CallGraphStats;
import com.ibm.wala.ipa.callgraph.Entrypoint;
import com.ibm.wala.ipa.callgraph.cha.CHACallGraph;
import com.ibm.wala.ipa.callgraph.impl.Util;
import com.ibm.wala.ipa.callgraph.util.CallGraphSearchUtil;
import com.ibm.wala.ipa.cha.ClassHierarchyException;
import com.ibm.wala.ipa.cha.ClassHierarchyFactory;
import com.ibm.wala.ipa.cha.IClassHierarchy;
import com.ibm.wala.types.ClassLoaderReference;
import com.ibm.wala.types.Descriptor;
import com.ibm.wala.types.MethodReference;
import com.ibm.wala.types.TypeReference;
import com.ibm.wala.util.CancelException;
import com.ibm.wala.util.collections.Iterator2Collection;
import com.ibm.wala.util.collections.Iterator2Iterable;
import com.ibm.wala.util.intset.IntSet;
import java.io.IOException;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import org.junit.jupiter.api.Test;

Expand All @@ -41,6 +54,93 @@ public void testJava_cup()
CallGraphTestUtil.REGRESSION_EXCLUSIONS);
}

/**
 * Builds the CHA call graph for {@code lambda.LambdaAndAnonymous} and checks that every interface
 * call site in the main method whose declared target is named {@code target} has exactly two
 * targets.
 */
@Test
public void testLambdaAndAnonymous()
    throws ClassHierarchyException, CancelException, IOException {
  CallGraph callGraph =
      testCHA(
          TestConstants.WALA_TESTDATA,
          "Llambda/LambdaAndAnonymous",
          CallGraphTestUtil.REGRESSION_EXCLUSIONS);
  CGNode main = CallGraphSearchUtil.findMainMethod(callGraph);
  for (CallSiteReference callSite : Iterator2Iterable.make(main.iterateCallSites())) {
    if (!callSite.isInterface()) {
      continue;
    }
    if (!callSite.getDeclaredTarget().getName().toString().equals("target")) {
      continue;
    }
    assertEquals(2, callGraph.getNumberOfTargets(main, callSite));
  }
}

/**
 * Builds the CHA call graph for {@code lambda.ParamsAndCapture} and checks, for each nested class
 * {@code C1} through {@code C5}, that its {@code target()} method has exactly one node, exactly
 * one predecessor node, and five call sites in that predecessor.
 */
@Test
public void testLambdaParamsAndCapture()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  CallGraph cg =
      testCHA(
          TestConstants.WALA_TESTDATA,
          "Llambda/ParamsAndCapture",
          CallGraphTestUtil.REGRESSION_EXCLUSIONS);

  // Maps a nested class name to the reference of its no-argument, void target() method.
  Function<String, MethodReference> getTargetRef =
      (klass) -> {
        TypeReference declaringType =
            TypeReference.findOrCreate(
                ClassLoaderReference.Application, "Llambda/ParamsAndCapture$" + klass);
        return MethodReference.findOrCreate(
            declaringType,
            Atom.findOrCreateUnicodeAtom("target"),
            Descriptor.findOrCreateUTF8("()V"));
      };

  Consumer<String> checkCalledFromFiveSites =
      (klassName) -> {
        Set<CGNode> targetNodes = cg.getNodes(getTargetRef.apply(klassName));
        assertEquals(
            1, targetNodes.size(), "expected " + klassName + ".target() to be reachable");
        CGNode targetNode = targetNodes.iterator().next();

        List<CGNode> callers = Iterator2Collection.toList(cg.getPredNodes(targetNode));
        assertEquals(
            1,
            callers.size(),
            "expected " + klassName + ".target() to be invoked from one calling method");
        CGNode onlyCaller = callers.get(0);

        List<CallSiteReference> callSites =
            Iterator2Collection.toList(cg.getPossibleSites(onlyCaller, targetNode));
        assertEquals(
            5,
            callSites.size(),
            "expected " + klassName + ".target() to be invoked from five call sites");
      };

  for (String nestedClass : new String[] {"C1", "C2", "C3", "C4", "C5"}) {
    checkCalledFromFiveSites.accept(nestedClass);
  }
}

/**
 * Builds the CHA call graph for {@code lambda.MethodRefs} and checks that the {@code target()}
 * method of each nested class {@code C1} through {@code C5} has exactly one node in the graph.
 */
@Test
public void testMethodRefs()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {

  CallGraph cg =
      testCHA(
          TestConstants.WALA_TESTDATA,
          "Llambda/MethodRefs",
          CallGraphTestUtil.REGRESSION_EXCLUSIONS);

  // Maps a nested class name to the reference of its no-argument, void target() method.
  Function<String, MethodReference> getTargetRef =
      (klass) -> {
        TypeReference declaringType =
            TypeReference.findOrCreate(
                ClassLoaderReference.Application, "Llambda/MethodRefs$" + klass);
        return MethodReference.findOrCreate(
            declaringType,
            Atom.findOrCreateUnicodeAtom("target"),
            Descriptor.findOrCreateUTF8("()V"));
      };

  for (String klass : new String[] {"C1", "C2", "C3", "C4", "C5"}) {
    assertEquals(
        1,
        cg.getNodes(getTargetRef.apply(klass)).size(),
        "expected " + klass + ".target() to be reachable");
  }
}

public static CallGraph testCHA(
String scopeFile, final String mainClass, final String exclusionsFile)
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
Expand Down Expand Up @@ -69,7 +169,9 @@ public static CallGraph testCHA(
assertNotNull(
predNodeNumbers,
"no predecessors for " + succNode + " which is called by " + node);
assertTrue(predNodeNumbers.contains(nodeNum));
assertTrue(
predNodeNumbers.contains(nodeNum),
"missing predecessor " + node + " for " + succNode);
});
}
return CG;
Expand Down
Loading

0 comments on commit 9f71651

Please sign in to comment.