Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lambda support to CHACallGraph #1476

Merged
merged 27 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -276,5 +276,11 @@ public static CrossLanguageClassHierarchy make(AnalysisScope scope, ClassLoaderF
public Set<TypeReference> getUnresolvedClasses() {
return HashSetFactory.make();
}

/* END Custom change: unresolved classes */

@Override
public void clearCaches() {
hierarchies.values().forEach(IClassHierarchy::clearCaches);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
import com.ibm.wala.ipa.callgraph.impl.ExplicitPredecessorsEdgeManager;
import com.ibm.wala.ipa.callgraph.impl.FakeWorldClinitMethod;
import com.ibm.wala.ipa.cha.IClassHierarchy;
import com.ibm.wala.ipa.summaries.LambdaMethodTargetSelector;
import com.ibm.wala.ipa.summaries.LambdaSummaryClass;
import com.ibm.wala.shrike.shrikeBT.IInvokeInstruction;
import com.ibm.wala.ssa.DefUse;
import com.ibm.wala.ssa.IR;
import com.ibm.wala.types.TypeReference;
import com.ibm.wala.util.CancelException;
import com.ibm.wala.util.collections.ComposedIterator;
import com.ibm.wala.util.collections.FilterIterator;
Expand Down Expand Up @@ -60,6 +63,10 @@ public class CHACallGraph extends BasicCallGraph<CHAContextInterpreter> {
*/
private final boolean applicationOnly;

/** To handle lambdas. We pass a selector that always returns null as the base selector. */
private final LambdaMethodTargetSelector lambdaMethodTargetSelector =
new LambdaMethodTargetSelector((caller, site, receiver) -> null);

private boolean isInitialized = false;

private class CHANode extends NodeImpl {
Expand Down Expand Up @@ -138,6 +145,25 @@ public void init(Iterable<Entrypoint> entrypoints) throws CancelException {
}
newNodes.push(root);
closure();
// classes simulating lambdas may have been added to the CHA via the previous closure() call.
// to update call targets to include lambdas, we clear all call target caches, iterate through
// all call sites, and re-compute the targets.
// TODO optimize if needed
targetCache.clear();
cha.clearCaches();
for (CGNode n : this) {
for (CallSiteReference site : Iterator2Iterable.make(n.iterateCallSites())) {
for (IMethod target : Iterator2Iterable.make(getOrUpdatePossibleTargets(n, site))) {
if (isRelevantMethod(target)) {
CGNode callee = getNode(target, Everywhere.EVERYWHERE);
if (callee == null) {
throw new RuntimeException("should have already created CGNode for " + target);
}
edgeManager.addEdge(n, callee);
}
}
}
}
isInitialized = true;
}

Expand All @@ -146,6 +172,13 @@ public IClassHierarchy getClassHierarchy() {
return cha;
}

/**
* Cache of possible targets for call sites.
*
* <p>In the future, this cache could be keyed on ({@link com.ibm.wala.types.MethodReference},
* {@code isDispatch}) pairs to save space and possibly time, where {@code isDispatch} indicates
* whether the call site is a virtual dispatch.
*/
private final Map<CallSiteReference, Set<IMethod>> targetCache = HashMapFactory.make();

/**
Expand All @@ -154,8 +187,29 @@ public IClassHierarchy getClassHierarchy() {
* @param site the call site
* @return an iterator of possible targets
*/
private Iterator<IMethod> getOrUpdatePossibleTargets(CallSiteReference site) {
Set<IMethod> result = targetCache.get(site);
private Iterator<IMethod> getOrUpdatePossibleTargets(CGNode caller, CallSiteReference site)
throws CancelException {
Set<IMethod> result = null;
if (isCallToLambdaMetafactoryMethod(site)) {
IMethod calleeTarget = lambdaMethodTargetSelector.getCalleeTarget(caller, site, null);
if (calleeTarget != null) {
// It's for a lambda. The result method is a synthetic method that allocates an object of
// the synthetic class generate for the lambda.
result = Collections.singleton(calleeTarget);
// we eagerly create a CGNode for the "trampoline" method that invokes the body of the
// lambda itself. This way, the new node gets added to the worklist, so we process all
// methods reachable from the lambda body immediately and don't need to do an outer fixed
// point. This does not do any wasted work assuming the call graph has at least one
// invocation of the lambda.
LambdaSummaryClass lambdaSummaryClass =
lambdaMethodTargetSelector.getLambdaSummaryClass(caller, site);
IMethod trampoline = lambdaSummaryClass.getDeclaredMethods().iterator().next();
CGNode callee = getNode(trampoline, Everywhere.EVERYWHERE);
if (callee == null) {
callee = findOrCreateNode(trampoline, Everywhere.EVERYWHERE);
}
}
}
if (result == null) {
if (site.isDispatch()) {
result = cha.getPossibleTargets(site.getDeclaredTarget());
Expand All @@ -164,7 +218,12 @@ private Iterator<IMethod> getOrUpdatePossibleTargets(CallSiteReference site) {
if (m != null) {
result = Collections.singleton(m);
} else {
result = Collections.emptySet();
IMethod fakeWorldClinitMethod = getFakeWorldClinitNode().getMethod();
if (site.getDeclaredTarget().equals(fakeWorldClinitMethod.getReference())) {
result = Collections.singleton(fakeWorldClinitMethod);
} else {
result = Collections.emptySet();
}
}
}
targetCache.put(site, result);
Expand All @@ -178,7 +237,14 @@ private Iterator<IMethod> getOrUpdatePossibleTargets(CallSiteReference site) {
* @param site the call site
* @return an iterator of possible targets
*/
private Iterator<IMethod> getPossibleTargetsFromCache(CallSiteReference site) {
private Iterator<IMethod> getPossibleTargetsFromCache(CGNode caller, CallSiteReference site) {
if (isCallToLambdaMetafactoryMethod(site)) {
IMethod calleeTarget = lambdaMethodTargetSelector.getCalleeTarget(caller, site, null);
if (calleeTarget != null) {
// it's for a lambda
return Collections.singleton(calleeTarget).iterator();
}
}
Set<IMethod> result = targetCache.get(site);
if (result == null) {
return Collections.emptyIterator();
Expand All @@ -190,7 +256,7 @@ private Iterator<IMethod> getPossibleTargetsFromCache(CallSiteReference site) {
public Set<CGNode> getPossibleTargets(CGNode node, CallSiteReference site) {
return Iterator2Collection.toSet(
new MapIterator<>(
new FilterIterator<>(getPossibleTargetsFromCache(site), this::isRelevantMethod),
new FilterIterator<>(getPossibleTargetsFromCache(node, site), this::isRelevantMethod),
object -> {
try {
return findOrCreateNode(object, Everywhere.EVERYWHERE);
Expand All @@ -203,7 +269,7 @@ public Set<CGNode> getPossibleTargets(CGNode node, CallSiteReference site) {

@Override
public int getNumberOfTargets(CGNode node, CallSiteReference site) {
return IteratorUtil.count(getPossibleTargetsFromCache(site));
return IteratorUtil.count(getPossibleTargetsFromCache(node, site));
}

@Override
Expand Down Expand Up @@ -279,9 +345,7 @@ private void closure() throws CancelException {
while (!newNodes.isEmpty()) {
CGNode n = newNodes.pop();
for (CallSiteReference site : Iterator2Iterable.make(n.iterateCallSites())) {
Iterator<IMethod> methods = getOrUpdatePossibleTargets(site);
while (methods.hasNext()) {
IMethod target = methods.next();
for (IMethod target : Iterator2Iterable.make(getOrUpdatePossibleTargets(n, site))) {
if (isRelevantMethod(target)) {
CGNode callee = getNode(target, Everywhere.EVERYWHERE);
if (callee == null) {
Expand All @@ -297,6 +361,13 @@ private void closure() throws CancelException {
}
}

private boolean isCallToLambdaMetafactoryMethod(CallSiteReference site) {
return site.getDeclaredTarget()
.getDeclaringClass()
.getName()
.equals(TypeReference.LambdaMetaFactory.getName());
}

private boolean isRelevantMethod(IMethod target) {
return !target.isAbstract()
&& (!applicationOnly
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/java/com/ibm/wala/ipa/cha/ClassHierarchy.java
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public enum MissingSuperClassHandling {
/** A mapping from IClass -&gt; Selector -&gt; Set of IMethod */
private final HashMap<IClass, Object> targetCache = HashMapFactory.make();

@Override
public void clearCaches() {
targetCache.clear();
}

/** Governing analysis scope */
private final AnalysisScope scope;

Expand Down
6 changes: 6 additions & 0 deletions core/src/main/java/com/ibm/wala/ipa/cha/IClassHierarchy.java
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,10 @@ public interface IClassHierarchy extends Iterable<IClass> {
* @throws IllegalArgumentException if c2 is null
*/
boolean isAssignableFrom(IClass c1, IClass c2);

/**
* Clear internal caches that may be invalidated by addition of new classes, e.g., a cache of the
* results of {@link #getPossibleTargets(MethodReference)}.
*/
void clearCaches();
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,26 @@ public IMethod getCalleeTarget(CGNode caller, CallSiteReference site, IClass rec
return base.getCalleeTarget(caller, site, receiver);
}

/**
* Gets the summary class for a lambda factory, if it has already been created.
*
* @param caller the caller node
* @param site the call site reference
* @return the summary class for the lambda factory, or {@code null} if it has not been created
*/
public LambdaSummaryClass getLambdaSummaryClass(CGNode caller, CallSiteReference site) {
IR ir = caller.getIR();
if (ir.getCallInstructionIndices(site) != null) {
SSAAbstractInvokeInstruction call = ir.getCalls(site)[0];
if (call instanceof SSAInvokeDynamicInstruction) {
SSAInvokeDynamicInstruction invoke = (SSAInvokeDynamicInstruction) call;
BootstrapMethod bootstrap = invoke.getBootstrap();
return classSummaries.get(bootstrap);
}
}
return null;
}

/**
* Create a summary for a lambda factory, as it would be generated by the lambda metafactory. The
* lambda factory summary returns an instance of the summary anonymous class for the lambda (see
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ public IClass getSuperclass() {
*/
@Override
public Collection<? extends IClass> getDirectInterfaces() {
return Collections.singleton(getClassHierarchy().lookupClass(invoke.getDeclaredResultType()));
IClass resultType = getClassHierarchy().lookupClass(invoke.getDeclaredResultType());
return resultType != null ? Collections.singleton(resultType) : Collections.emptySet();
}

/**
Expand All @@ -126,6 +127,9 @@ public Collection<? extends IClass> getDirectInterfaces() {
@Override
public Collection<IClass> getAllImplementedInterfaces() {
IClass iface = getClassHierarchy().lookupClass(invoke.getDeclaredResultType());
if (iface == null) {
return Collections.emptySet();
}
Set<IClass> result = HashSetFactory.make(iface.getAllImplementedInterfaces());
result.add(iface);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,36 @@
*/
package com.ibm.wala.core.tests.callGraph;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.ibm.wala.classLoader.CallSiteReference;
import com.ibm.wala.core.tests.util.TestConstants;
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.ipa.callgraph.AnalysisScope;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.CallGraphStats;
import com.ibm.wala.ipa.callgraph.Entrypoint;
import com.ibm.wala.ipa.callgraph.cha.CHACallGraph;
import com.ibm.wala.ipa.callgraph.impl.Util;
import com.ibm.wala.ipa.callgraph.util.CallGraphSearchUtil;
import com.ibm.wala.ipa.cha.ClassHierarchyException;
import com.ibm.wala.ipa.cha.ClassHierarchyFactory;
import com.ibm.wala.ipa.cha.IClassHierarchy;
import com.ibm.wala.types.ClassLoaderReference;
import com.ibm.wala.types.Descriptor;
import com.ibm.wala.types.MethodReference;
import com.ibm.wala.types.TypeReference;
import com.ibm.wala.util.CancelException;
import com.ibm.wala.util.collections.Iterator2Collection;
import com.ibm.wala.util.collections.Iterator2Iterable;
import com.ibm.wala.util.intset.IntSet;
import java.io.IOException;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import org.junit.jupiter.api.Test;

Expand All @@ -41,6 +54,93 @@ public void testJava_cup()
CallGraphTestUtil.REGRESSION_EXCLUSIONS);
}

@Test
public void testLambdaAndAnonymous()
throws ClassHierarchyException, CancelException, IOException {
CallGraph cg =
testCHA(
TestConstants.WALA_TESTDATA,
"Llambda/LambdaAndAnonymous",
CallGraphTestUtil.REGRESSION_EXCLUSIONS);
CGNode mainMethod = CallGraphSearchUtil.findMainMethod(cg);
for (CallSiteReference site : Iterator2Iterable.make(mainMethod.iterateCallSites())) {
if (site.isInterface() && site.getDeclaredTarget().getName().toString().equals("target")) {
assertEquals(2, cg.getNumberOfTargets(mainMethod, site));
}
}
}

@Test
public void testLambdaParamsAndCapture()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
CallGraph cg =
testCHA(
TestConstants.WALA_TESTDATA,
"Llambda/ParamsAndCapture",
CallGraphTestUtil.REGRESSION_EXCLUSIONS);
Function<String, MethodReference> getTargetRef =
(klass) ->
MethodReference.findOrCreate(
TypeReference.findOrCreate(
ClassLoaderReference.Application, "Llambda/ParamsAndCapture$" + klass),
Atom.findOrCreateUnicodeAtom("target"),
Descriptor.findOrCreateUTF8("()V"));

Consumer<String> checkCalledFromFiveSites =
(klassName) -> {
Set<CGNode> nodes = cg.getNodes(getTargetRef.apply(klassName));
assertEquals(1, nodes.size(), "expected " + klassName + ".target() to be reachable");
CGNode node = nodes.iterator().next();
List<CGNode> predNodes = Iterator2Collection.toList(cg.getPredNodes(node));
assertEquals(
1,
predNodes.size(),
"expected " + klassName + ".target() to be invoked from one calling method");
CGNode pred = predNodes.get(0);
List<CallSiteReference> sites =
Iterator2Collection.toList(cg.getPossibleSites(pred, node));
assertEquals(
5,
sites.size(),
"expected " + klassName + ".target() to be invoked from five call sites");
};

checkCalledFromFiveSites.accept("C1");
checkCalledFromFiveSites.accept("C2");
checkCalledFromFiveSites.accept("C3");
checkCalledFromFiveSites.accept("C4");
checkCalledFromFiveSites.accept("C5");
}

@Test
public void testMethodRefs()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {

CallGraph cg =
testCHA(
TestConstants.WALA_TESTDATA,
"Llambda/MethodRefs",
CallGraphTestUtil.REGRESSION_EXCLUSIONS);

Function<String, MethodReference> getTargetRef =
(klass) ->
MethodReference.findOrCreate(
TypeReference.findOrCreate(
ClassLoaderReference.Application, "Llambda/MethodRefs$" + klass),
Atom.findOrCreateUnicodeAtom("target"),
Descriptor.findOrCreateUTF8("()V"));
assertEquals(
1, cg.getNodes(getTargetRef.apply("C1")).size(), "expected C1.target() to be reachable");
assertEquals(
1, cg.getNodes(getTargetRef.apply("C2")).size(), "expected C2.target() to be reachable");
assertEquals(
1, cg.getNodes(getTargetRef.apply("C3")).size(), "expected C3.target() to be reachable");
assertEquals(
1, cg.getNodes(getTargetRef.apply("C4")).size(), "expected C4.target() to be reachable");
assertEquals(
1, cg.getNodes(getTargetRef.apply("C5")).size(), "expected C5.target() to be reachable");
}

public static CallGraph testCHA(
String scopeFile, final String mainClass, final String exclusionsFile)
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
Expand Down Expand Up @@ -69,7 +169,9 @@ public static CallGraph testCHA(
assertNotNull(
predNodeNumbers,
"no predecessors for " + succNode + " which is called by " + node);
assertTrue(predNodeNumbers.contains(nodeNum));
assertTrue(
predNodeNumbers.contains(nodeNum),
"missing predecessor " + node + " for " + succNode);
});
}
return CG;
Expand Down
Loading
Loading