Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pytest entrypoints #151

Merged
merged 4 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ gradle-app.setting

# Cache of project
.gradletasknamecache
/.pytest_cache/

# Eclipse Gradle plugin generated files
# Eclipse Core
Expand Down
13 changes: 13 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# From: https://docs.pytest.org/en/8.0.x/getting-started.html#group-multiple-tests-in-a-class.


# content of test_class.py
class TestClass:

def test_one(self):
x = "this"
assert "h" in x

def test_two(self):
x = "hello"
assert hasattr(x, "check")
16 changes: 16 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_class2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# From: https://docs.pytest.org/en/8.0.x/getting-started.html#group-multiple-tests-in-a-class.


# content of test_class.py
class TestClass:

def __init__(self):
pass

def test_one(self):
x = "this"
assert "h" in x

def test_two(self):
x = "hello"
assert hasattr(x, "check")
10 changes: 10 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# From https://docs.pytest.org/en/8.0.x/getting-started.html#create-your-first-test.


# content of test_sample.py
def func(x):
return x + 1


def test_answer():
assert func(3) == 5
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
package com.ibm.wala.cast.python.test;

import static com.google.common.collect.Iterables.concat;
import static java.util.Collections.singleton;

import com.ibm.wala.cast.ipa.callgraph.CAstCallGraphUtil;
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ipa.callgraph.PytestEntrypointBuilder;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.Entrypoint;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import com.ibm.wala.ipa.callgraph.propagation.SSAContextInterpreter;
import com.ibm.wala.ipa.cha.ClassHierarchyException;
import com.ibm.wala.util.CancelException;
import java.io.IOException;
import java.util.Collections;
import java.util.logging.Logger;
import org.junit.Test;

public class TestCalls extends TestPythonCallGraphShape {

private static final Logger LOGGER = Logger.getLogger(TestCalls.class.getName());

protected static final Object[][] assertionsCalls1 =
new Object[][] {
new Object[] {ROOT, new String[] {"script calls1.py"}},
Expand Down Expand Up @@ -232,4 +240,112 @@ public Void performAnalysis(PropagationCallGraphBuilder builder) throws CancelEx
CG);
verifyGraphAssertions(CG, assertionsDefaultValues);
}

protected static final Object[][] PYTEST_ASSERTIONS =
new Object[][] {
new Object[] {
ROOT, new String[] {"script test_sample.py", "script test_sample.py/test_answer"}
},
new Object[] {
"script test_sample.py/test_answer", new String[] {"script test_sample.py/func"}
},
};

@Test
public void testPytestCalls()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {

PythonAnalysisEngine<?> engine =
new PythonAnalysisEngine<Void>() {
@Override
public Void performAnalysis(PropagationCallGraphBuilder builder) throws CancelException {
assert false;
return null;
}
};

engine.setModuleFiles(singleton(getScript("test_sample.py")));

PropagationCallGraphBuilder callGraphBuilder =
(PropagationCallGraphBuilder) engine.defaultCallGraphBuilder();

addPytestEntrypoints(callGraphBuilder);

CallGraph callGraph = callGraphBuilder.makeCallGraph(callGraphBuilder.getOptions());

CAstCallGraphUtil.AVOID_DUMP = false;
CAstCallGraphUtil.dumpCG(
(SSAContextInterpreter) callGraphBuilder.getContextInterpreter(),
callGraphBuilder.getPointerAnalysis(),
callGraph);

verifyGraphAssertions(callGraph, PYTEST_ASSERTIONS);
}

private static void addPytestEntrypoints(PropagationCallGraphBuilder callGraphBuilder) {
Iterable<? extends Entrypoint> defaultEntrypoints =
callGraphBuilder.getOptions().getEntrypoints();

Iterable<Entrypoint> pytestEntrypoints =
new PytestEntrypointBuilder().createEntrypoints(callGraphBuilder.getClassHierarchy());

Iterable<Entrypoint> entrypoints = concat(defaultEntrypoints, pytestEntrypoints);

callGraphBuilder.getOptions().setEntrypoints(entrypoints);

for (Entrypoint ep : callGraphBuilder.getOptions().getEntrypoints())
LOGGER.info(() -> "Using entrypoint: " + ep.getMethod().getDeclaringClass().getName() + ".");
}

protected static final Object[][] PYTEST_ASSERTIONS2 =
new Object[][] {
new Object[] {
ROOT,
new String[] {
"script test_class.py",
"script test_class.py/TestClass",
"$script test_class.py/TestClass/test_one:trampoline2",
"$script test_class.py/TestClass/test_two:trampoline2"
}
},
};

@Test
public void testPytestCalls2()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
PythonAnalysisEngine<?> engine = this.makeEngine("test_class.py");
PropagationCallGraphBuilder callGraphBuilder = engine.defaultCallGraphBuilder();

addPytestEntrypoints(callGraphBuilder);

CallGraph callGraph = callGraphBuilder.makeCallGraph(callGraphBuilder.getOptions());

CAstCallGraphUtil.AVOID_DUMP = false;
CAstCallGraphUtil.dumpCG(
(SSAContextInterpreter) callGraphBuilder.getContextInterpreter(),
callGraphBuilder.getPointerAnalysis(),
callGraph);

verifyGraphAssertions(callGraph, PYTEST_ASSERTIONS2);
}

protected static final Object[][] PYTEST_ASSERTIONS3 =
new Object[][] {
new Object[] {ROOT, new String[] {"script test_class2.py"}},
};

@Test
public void testPytestCalls3()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
PythonAnalysisEngine<?> engine = this.makeEngine("test_class2.py");
PropagationCallGraphBuilder callGraphBuilder = engine.defaultCallGraphBuilder();
addPytestEntrypoints(callGraphBuilder);
CallGraph callGraph = callGraphBuilder.makeCallGraph(callGraphBuilder.getOptions());
CAstCallGraphUtil.AVOID_DUMP = false;
CAstCallGraphUtil.dumpCG(
(SSAContextInterpreter) callGraphBuilder.getContextInterpreter(),
callGraphBuilder.getPointerAnalysis(),
callGraph);
verifyGraphAssertions(callGraph, PYTEST_ASSERTIONS3);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package com.ibm.wala.cast.python.ipa.callgraph;

import static com.ibm.wala.cast.python.types.Util.getFilename;
import static java.util.Objects.requireNonNull;

import com.ibm.wala.cast.python.loader.PythonLoader.DynamicMethodBody;
import com.ibm.wala.cast.python.loader.PythonLoader.PythonClass;
import com.ibm.wala.cast.types.AstMethodReference;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.client.AbstractAnalysisEngine.EntrypointBuilder;
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.ipa.callgraph.Entrypoint;
import com.ibm.wala.ipa.cha.IClassHierarchy;
import com.ibm.wala.types.MethodReference;
import com.ibm.wala.types.TypeName;
import com.ibm.wala.util.collections.HashSetFactory;
import java.util.HashSet;
import java.util.logging.Logger;

/**
* This class represents entry points ({@link Entrypoint})s of Pytest test functions. Pytest test
* functions are those invoked by the pytest framework reflectively. The entry points can be used to
* specify entry points of a call graph.
*/
public class PytestEntrypointBuilder implements EntrypointBuilder {

private static final Logger logger = Logger.getLogger(PytestEntrypointBuilder.class.getName());

/**
* Construct pytest entrypoints for all the pytest test functions in the given scope.
*
* @throws NullPointerException If the given {@link IClassHierarchy} is null.
*/
@Override
public Iterable<Entrypoint> createEntrypoints(IClassHierarchy cha) {
requireNonNull(cha);

final HashSet<Entrypoint> result = HashSetFactory.make();

for (IClass klass : cha) {
// if the class is a pytest test case,
if (isPytestCase(klass)) {
logger.fine(() -> "Pytest case: " + klass + ".");

MethodReference methodReference =
MethodReference.findOrCreate(klass.getReference(), AstMethodReference.fnSelector);

result.add(new PytesttEntrypoint(methodReference, cha));

logger.fine(() -> "Adding test method as entry point: " + methodReference.getName() + ".");
}
}

return result::iterator;
}

/**
* Check if the given class is a Pytest test class according to: https://bit.ly/3wj8nPY.
*
* @throws NullPointerException If the given {@link IClass} is null.
* @see https://bit.ly/3wj8nPY.
*/
public static boolean isPytestCase(IClass klass) {
requireNonNull(klass);

final TypeName typeName = klass.getReference().getName();

if (typeName.toString().startsWith("Lscript ")) {
final String fileName = getFilename(typeName);
final Atom className = typeName.getClassName();

// In Ariadne, a script is an invokable entity like a function.
final boolean script = className.toString().endsWith(".py");

if (!script // it's not an invokable script.
&& (fileName.startsWith("test_")
|| fileName.endsWith("_test")) // we're inside of a "test" file,
&& !(klass instanceof PythonClass)) { // classes aren't entrypoints.
if (klass instanceof DynamicMethodBody) {
// It's a method. In Ariadne, functions are also classes.
DynamicMethodBody dmb = (DynamicMethodBody) klass;
IClass container = dmb.getContainer();
String containerName = container.getReference().getName().getClassName().toString();

if (containerName.startsWith("Test") && container instanceof PythonClass) {
// It's a test class.
PythonClass containerClass = (PythonClass) container;

final boolean hasCtor =
containerClass.getMethodReferences().stream()
.anyMatch(
mr -> {
return mr.getName().toString().equals("__init__");
});

// Test classes can't have constructors.
if (!hasCtor) {
// In Ariadne, methods are modeled as classes. Thus, a class name in this case is the
// method name.
String methodName = className.toString();

// If the method starts with "test."
if (methodName.startsWith("test")) return true;
}
}
} else if (className.toString().startsWith("test")) return true; // It's a function.
}
}

return false;
}
}
Loading
Loading