Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pytest entrypoints with PYTHONPATH #107

Merged
merged 5 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2770,6 +2770,68 @@ public void testDecoratedFunctions2()
test("test_decorated_functions.py", "test_dummy", 0, 0);
}

/**
* Test a pytest without decorators that needs a PYTHONPATH. This is a "control" case. We'll add a
* decorator in the next case.
*
* @see TestTensorflow2Model#testModule11().
*/
@Test
public void testDecoratedFunctions3()
throws ClassHierarchyException, CancelException, IOException {
test(
new String[] {
"proj48/src/__init__.py",
"proj48/src/tf2_test_module9a.py",
"proj48/src/tf2_test_module9b.py",
"proj48/src/test_module10.py"
},
"src/tf2_test_module9b.py",
"D.f",
"proj48",
1,
1,
new int[] {3});
}

/** Test a pytest without decorators. This is a "control." */
@Test
public void testDecoratedFunctions4()
throws ClassHierarchyException, CancelException, IOException {
test("test_decorated_functions2.py", "f", 1, 1, 2);
}

/** Test a pytest with a decorator. */
@Test
public void testDecoratedFunctions5()
throws ClassHierarchyException, CancelException, IOException {
test("test_decorated_functions3.py", "f", 1, 1, 2);
}

/**
* Test a pytest with a decorator that needs a PYTHONPATH.
*
* @see TestTensorflow2Model#testModule11().
*/
@Test
@Ignore("Blocked on https://github.com/wala/ML/issues/198.")
public void testDecoratedFunctions6()
throws ClassHierarchyException, CancelException, IOException {
test(
new String[] {
"proj49/src/__init__.py",
"proj49/src/tf2_test_module9a.py",
"proj49/src/tf2_test_module9b.py",
"proj49/src/test_module10.py"
},
"src/tf2_test_module9b.py",
"D.f",
"proj49",
1,
1,
new int[] {3});
}

@Test
public void testReshape() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_reshape.py", "f", 1, 1, 2);
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import ones
from src.tf2_test_module9b import D


def test_dummy(x, test_input, expected):
D().f(ones([1, 2]))
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor


class C:

def f(self, a):
assert isinstance(a, Tensor)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor
from src.tf2_test_module9a import C


class D(C):

def f(self, a):
assert isinstance(a, Tensor)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import ones
from src.tf2_test_module9b import D


@pytest.mark.parametrize("test_input,expected", [("3+5", 8), ("2+4", 6), ("6*9", 42)])
def test_dummy(x, test_input, expected):
D().f(ones([1, 2]))
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor


class C:

def f(self, a):
assert isinstance(a, Tensor)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor
from src.tf2_test_module9a import C


class D(C):

def f(self, a):
assert isinstance(a, Tensor)
10 changes: 10 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_decorated_functions2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import tensorflow as tf
import pytest


def f(a):
assert isinstance(a, tf.Tensor)


def test_dummy(x, test_input, expected):
f(tf.constant(1))
11 changes: 11 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_decorated_functions3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import tensorflow as tf
import pytest


def f(a):
assert isinstance(a, tf.Tensor)


@pytest.mark.parametrize("test_input,expected", [("3+5", 8), ("2+4", 6), ("6*9", 42)])
def test_dummy(x, test_input, expected):
f(tf.constant(1))
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.ibm.wala.cast.python.ipa.callgraph;

import static com.ibm.wala.cast.python.types.Util.getFilename;
import static com.ibm.wala.cast.python.types.Util.getRelativeFilename;
import static com.ibm.wala.cast.python.util.Util.PYTHON_FILE_EXTENSION;
import static java.util.Objects.requireNonNull;

import com.ibm.wala.cast.python.loader.PythonLoader.DynamicMethodBody;
Expand Down Expand Up @@ -47,7 +48,11 @@ public Iterable<Entrypoint> createEntrypoints(IClassHierarchy cha) {

result.add(new PytesttEntrypoint(methodReference, cha));

logger.fine(() -> "Adding test method as entry point: " + methodReference.getName() + ".");
logger.fine(
() ->
"Adding test method as entry point: "
+ methodReference.getDeclaringClass().getName()
+ ".");
}
}

Expand All @@ -66,11 +71,11 @@ public static boolean isPytestCase(IClass klass) {
final TypeName typeName = klass.getReference().getName();

if (typeName.toString().startsWith("Lscript ")) {
final String fileName = getFilename(typeName);
final String fileName = getRelativeFilename(typeName);
final Atom className = typeName.getClassName();

// In Ariadne, a script is an invokable entity like a function.
final boolean script = className.toString().endsWith(".py");
final boolean script = className.toString().endsWith(PYTHON_FILE_EXTENSION);

if (!script // it's not an invokable script.
&& (fileName.startsWith("test_")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,46 @@
package com.ibm.wala.cast.python.types;

import static com.ibm.wala.cast.python.util.Util.PYTHON_FILE_EXTENSION;

import com.ibm.wala.cast.types.AstTypeReference;
import com.ibm.wala.classLoader.IClassLoader;
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.types.FieldReference;
import com.ibm.wala.types.MethodReference;
import com.ibm.wala.types.TypeName;
import com.ibm.wala.types.TypeReference;
import java.util.Arrays;

public class Util {

private static final String GLOBAL_KEYWORD = "global";

/**
* Returns the filename portion of the given {@link TypeName} representing a Python type.
* Returns the relative filename portion of the given {@link TypeName} representing a Python type.
*
* @param typeName A {@link TypeName} of a Python type.
* @return The filename portion of the given {@link TypeName}.
* @return The relative filename portion of the given {@link TypeName}.
* @apiNote Python types include a file in their {@link TypeName}s in Ariadne.
*/
public static String getFilename(final TypeName typeName) {
String ret = typeName.toString();
ret = ret.substring("Lscript ".length());
public static String getRelativeFilename(final TypeName typeName) {
String typeNameString = typeName.toString();

// Remove the script prefix.
typeNameString = typeNameString.substring("Lscript ".length());

// Extract the filename.
String[] segments = typeNameString.split("/");

String filename =
Arrays.stream(segments)
.filter(s -> s.endsWith(PYTHON_FILE_EXTENSION))
.findFirst()
.orElseThrow();

if (ret.indexOf('/') != -1) ret = ret.substring(0, ret.indexOf('/'));
assert filename.endsWith("." + PYTHON_FILE_EXTENSION)
: "Python files must have a \"" + PYTHON_FILE_EXTENSION + "\" extension.";

return ret;
return filename;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ public class Util {
/** Name of the annotation (decorator) that marks methods as a class method. */
public static final String CLASS_METHOD_ANNOTATION_NAME = "classmethod";

/** All Python files must have this extension. */
public static final String PYTHON_FILE_EXTENSION = "py";

/**
* Add Pytest entrypoints to the given {@link PropagationCallGraphBuilder}.
*
Expand Down
Loading