Skip to content

Commit

Permalink
Add support and tests for callables (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad authored Dec 4, 2023
1 parent eec6f81 commit 3b9be89
Show file tree
Hide file tree
Showing 12 changed files with 252 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.ibm.wala.cast.python.client;

import com.ibm.wala.cast.python.ipa.callgraph.PythonSSAPropagationCallGraphBuilder;
import com.ibm.wala.cast.python.loader.PytestLoader;
import com.ibm.wala.cast.python.loader.PytestLoaderFactory;
import com.ibm.wala.classLoader.CallSiteReference;
Expand All @@ -10,7 +9,6 @@
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.ipa.callgraph.AnalysisOptions;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.IAnalysisCacheView;
import com.ibm.wala.ipa.callgraph.MethodTargetSelector;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
Expand All @@ -23,8 +21,6 @@

public class PytestAnalysisEngine<T> extends PythonAnalysisEngine<T> {

private PythonSSAPropagationCallGraphBuilder builder;

private class PytestTargetSelector implements MethodTargetSelector {
private final MethodTargetSelector base;

Expand Down Expand Up @@ -88,12 +84,6 @@ protected void addBypassLogic(IClassHierarchy cha, AnalysisOptions options) {
addSummaryBypassLogic(options, "pytest.xml");
}

@Override
protected PythonSSAPropagationCallGraphBuilder getCallGraphBuilder(
IClassHierarchy cha, AnalysisOptions options, IAnalysisCacheView cache) {
return builder = super.getCallGraphBuilder(cha, options, cache);
}

@Override
public T performAnalysis(PropagationCallGraphBuilder arg0) throws CancelException {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,20 +201,13 @@ public void testTf2()
0,
0); // NOTE: Change to testTf2("tf2_test_dataset.py", "add", 2, 3, 2, 3) once
// https://github.com/wala/ML/issues/89 is fixed.
testTf2(
"tf2_test_model_call.py",
"SequentialModel.__call__",
0,
2); // NOTE: Change to testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4,
// 2) once
// https://github.com/wala/ML/issues/24 is fixed.
testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4, 2);
testTf2(
"tf2_test_model_call2.py",
"SequentialModel.call",
0,
2); // NOTE: Change to testTf2("tf2_test_model_call2.py", "SequentialModel.call", 1, 4, 2)
// once
// https://github.com/wala/ML/issues/24 is fixed.
// once https://github.com/wala/ML/issues/106 is fixed.
testTf2("tf2_test_model_call3.py", "SequentialModel.call", 1, 4, 2);
testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 2);
}
Expand Down
9 changes: 9 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class C:

def __call__(self, x):
return x * x


c = C()
a = c.__call__(5)
assert a == 25
9 changes: 9 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class C:

def __call__(self, x):
return x * x


c = C()
a = c(5)
assert a == 25
9 changes: 9 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class C(object):

def __call__(self, x):
return x * x


c = C()
a = c.__call__(5)
assert a == 25
9 changes: 9 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class C(object):

def __call__(self, x):
return x * x


c = C()
a = c(5)
assert a == 25
13 changes: 13 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class D:
pass


class C(D):

def __call__(self, x):
return x * x


c = C()
a = c.__call__(5)
assert a == 25
13 changes: 13 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class D:
pass


class C(D):

def __call__(self, x):
return x * x


c = C()
a = c(5)
assert a == 25
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.ibm.wala.cast.python.test;

import static org.junit.Assert.assertTrue;

import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.cha.ClassHierarchyException;
import com.ibm.wala.util.CancelException;
import java.io.IOException;
import java.util.Iterator;
import java.util.logging.Logger;
import org.junit.Test;

public class TestCallables extends TestPythonCallGraphShape {

private static Logger logger = Logger.getLogger(TestCallables.class.getName());

@Test
public void testCallables()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
final String[] testFileNames = {
"callables.py",
"callables2.py",
"callables3.py",
"callables4.py",
"callables5.py",
"callables6.py"
};

for (String fileName : testFileNames) {
CallGraph CG = process(fileName);
boolean found = false;

for (CGNode node : CG) {
if (node.getMethod()
.getDeclaringClass()
.getName()
.toString()
.equals("Lscript " + fileName)) {

for (Iterator<CGNode> it = CG.getSuccNodes(node); it.hasNext(); ) {
CGNode callee = it.next();

logger.info("Found callee: " + callee.getMethod().getSignature());

if (callee
.getMethod()
.getDeclaringClass()
.getName()
.toString()
.equals("L$script " + fileName + "/C/__call__")) found = true;
}
}
}

assertTrue("Expecting to find __call__ method trampoline in: " + fileName + ".", found);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,15 @@
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

public abstract class PythonAnalysisEngine<T>
extends AbstractAnalysisEngine<InstanceKey, PythonSSAPropagationCallGraphBuilder, T> {

private static final Logger logger = Logger.getLogger(PythonAnalysisEngine.class.getName());

protected PythonSSAPropagationCallGraphBuilder builder;

static {
try {
Class<?> j3 = Class.forName("com.ibm.wala.cast.python.loader.Python3LoaderFactory");
Expand Down Expand Up @@ -274,9 +279,10 @@ public boolean isReferenceType() {

protected void addBypassLogic(IClassHierarchy cha, AnalysisOptions options) {
options.setSelector(
new PythonTrampolineTargetSelector(
new PythonTrampolineTargetSelector<T>(
new PythonConstructorTargetSelector(
new PythonComprehensionTrampolines(options.getMethodTargetSelector()))));
new PythonComprehensionTrampolines(options.getMethodTargetSelector())),
this));

BuiltinFunctions builtins = new BuiltinFunctions(cha);
options.setSelector(builtins.builtinClassTargetSelector(options.getClassTargetSelector()));
Expand Down Expand Up @@ -337,7 +343,11 @@ public int getDefaultValue(SymbolTable symtab, int valueNumber) {

new PythonSuper(cha).handleSuperCalls(builder, options);

return builder;
return this.builder = builder;
}

public PythonSSAPropagationCallGraphBuilder getCachedCallGraphBuilder() {
return this.builder;
}

protected PythonSSAPropagationCallGraphBuilder makeBuilder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor;
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.classLoader.IField;
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.fixpoint.AbstractOperator;
Expand All @@ -25,6 +26,7 @@
import com.ibm.wala.ipa.callgraph.IAnalysisCacheView;
import com.ibm.wala.ipa.callgraph.propagation.AbstractFieldPointerKey;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerKeyFactory;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
Expand All @@ -40,8 +42,11 @@
import com.ibm.wala.types.TypeReference;
import com.ibm.wala.util.collections.HashMapFactory;
import com.ibm.wala.util.collections.Pair;
import com.ibm.wala.util.intset.IntIterator;
import com.ibm.wala.util.intset.IntSet;
import com.ibm.wala.util.intset.IntSetUtil;
import com.ibm.wala.util.intset.MutableIntSet;
import com.ibm.wala.util.intset.OrdinalSet;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
Expand Down Expand Up @@ -211,7 +216,21 @@ protected void processCallingConstraints(
}
} else {
PointerKey rval = getPointerKeyForLocal(caller, call.getUse(i));
getSystem().newConstraint(lval, assignOperator, rval);

// If we are looking at the implicit parameter of a callable.
if (call.getCallSite().isDispatch() && i == 0 && refersToAnObject(rval)) {
// Ensure that lval's variable refers to the callable method instead of callable object.
IClass callable = target.getMethod().getDeclaringClass();
IntSet instanceKeysForCallable = this.getSystem().getInstanceKeysForClass(callable);

for (IntIterator it = instanceKeysForCallable.intIterator(); it.hasNext(); ) {
int instanceKeyIndex = it.next();
InstanceKey instanceKey = this.getSystem().getInstanceKey(instanceKeyIndex);
this.getSystem().newConstraint(lval, instanceKey);
}
} else {
getSystem().newConstraint(lval, assignOperator, rval);
}
}
}

Expand Down Expand Up @@ -271,6 +290,29 @@ protected void processCallingConstraints(
}
}

/**
* Returns true iff the given {@link PointerKey} points to at least one instance whose concrete
* type equals {@link PythonTypes#object}.
*
* @param pointerKey The {@link PointerKey} in question.
* @return True iff the given {@link PointerKey} points to at least one object whose concrete type
* equals {@link PythonTypes#object}.,
*/
protected boolean refersToAnObject(PointerKey pointerKey) {
PointerAnalysis<InstanceKey> pointerAnalysis = this.getPointerAnalysis();
OrdinalSet<InstanceKey> pointsToSet = pointerAnalysis.getPointsToSet(pointerKey);

for (InstanceKey instanceKey : pointsToSet) {
IClass concreteType = instanceKey.getConcreteType();
TypeReference reference = concreteType.getReference();

// If it's an "object" method.
if (reference.equals(PythonTypes.object)) return true;
}

return false;
}

@Override
public PythonConstraintVisitor makeVisitor(CGNode node) {
return new PythonConstraintVisitor(this, node);
Expand Down
Loading

0 comments on commit 3b9be89

Please sign in to comment.