Skip to content

Commit

Permalink
Fix pytest entrypoints with PYTHONPATH (#107)
Browse files Browse the repository at this point in the history
Ignore test for now. Blocked on wala#198.
  • Loading branch information
khatchad authored Jun 14, 2024
1 parent 3750af1 commit 749aad2
Show file tree
Hide file tree
Showing 14 changed files with 172 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2770,6 +2770,68 @@ public void testDecoratedFunctions2()
test("test_decorated_functions.py", "test_dummy", 0, 0);
}

/**
* Test a pytest without decorators that needs a PYTHONPATH. This is a "control" case. We'll add a
* decorator in the next case.
*
* @see TestTensorflow2Model#testModule11().
*/
@Test
public void testDecoratedFunctions3()
throws ClassHierarchyException, CancelException, IOException {
test(
new String[] {
"proj48/src/__init__.py",
"proj48/src/tf2_test_module9a.py",
"proj48/src/tf2_test_module9b.py",
"proj48/src/test_module10.py"
},
"src/tf2_test_module9b.py",
"D.f",
"proj48",
1,
1,
new int[] {3});
}

/** Test a pytest without decorators. This is a "control." */
@Test
public void testDecoratedFunctions4()
throws ClassHierarchyException, CancelException, IOException {
test("test_decorated_functions2.py", "f", 1, 1, 2);
}

/** Test a pytest with a decorator. */
@Test
public void testDecoratedFunctions5()
throws ClassHierarchyException, CancelException, IOException {
test("test_decorated_functions3.py", "f", 1, 1, 2);
}

/**
* Test a pytest with a decorator that needs a PYTHONPATH.
*
* @see TestTensorflow2Model#testModule11().
*/
@Test
@Ignore("Blocked on https://github.com/wala/ML/issues/198.")
public void testDecoratedFunctions6()
throws ClassHierarchyException, CancelException, IOException {
test(
new String[] {
"proj49/src/__init__.py",
"proj49/src/tf2_test_module9a.py",
"proj49/src/tf2_test_module9b.py",
"proj49/src/test_module10.py"
},
"src/tf2_test_module9b.py",
"D.f",
"proj49",
1,
1,
new int[] {3});
}

@Test
public void testReshape() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_reshape.py", "f", 1, 1, 2);
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import ones
from src.tf2_test_module9b import D


def test_dummy(x, test_input, expected):
D().f(ones([1, 2]))
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor


class C:

def f(self, a):
assert isinstance(a, Tensor)
10 changes: 10 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj48/src/tf2_test_module9b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor
from src.tf2_test_module9a import C


class D(C):

def f(self, a):
assert isinstance(a, Tensor)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import ones
from src.tf2_test_module9b import D


@pytest.mark.parametrize("test_input,expected", [("3+5", 8), ("2+4", 6), ("6*9", 42)])
def test_dummy(x, test_input, expected):
D().f(ones([1, 2]))
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor


class C:

def f(self, a):
assert isinstance(a, Tensor)
10 changes: 10 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj49/src/tf2_test_module9b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor
from src.tf2_test_module9a import C


class D(C):

def f(self, a):
assert isinstance(a, Tensor)
10 changes: 10 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_decorated_functions2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import tensorflow as tf
import pytest


def f(a):
assert isinstance(a, tf.Tensor)


def test_dummy(x, test_input, expected):
f(tf.constant(1))
11 changes: 11 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_decorated_functions3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import tensorflow as tf
import pytest


def f(a):
assert isinstance(a, tf.Tensor)


@pytest.mark.parametrize("test_input,expected", [("3+5", 8), ("2+4", 6), ("6*9", 42)])
def test_dummy(x, test_input, expected):
f(tf.constant(1))
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.ibm.wala.cast.python.ipa.callgraph;

import static com.ibm.wala.cast.python.types.Util.getFilename;
import static com.ibm.wala.cast.python.types.Util.getRelativeFilename;
import static com.ibm.wala.cast.python.util.Util.PYTHON_FILE_EXTENSION;
import static java.util.Objects.requireNonNull;

import com.ibm.wala.cast.python.loader.PythonLoader.DynamicMethodBody;
Expand Down Expand Up @@ -47,7 +48,11 @@ public Iterable<Entrypoint> createEntrypoints(IClassHierarchy cha) {

result.add(new PytesttEntrypoint(methodReference, cha));

logger.fine(() -> "Adding test method as entry point: " + methodReference.getName() + ".");
logger.fine(
() ->
"Adding test method as entry point: "
+ methodReference.getDeclaringClass().getName()
+ ".");
}
}

Expand All @@ -66,11 +71,11 @@ public static boolean isPytestCase(IClass klass) {
final TypeName typeName = klass.getReference().getName();

if (typeName.toString().startsWith("Lscript ")) {
final String fileName = getFilename(typeName);
final String fileName = getRelativeFilename(typeName);
final Atom className = typeName.getClassName();

// In Ariadne, a script is an invokable entity like a function.
final boolean script = className.toString().endsWith(".py");
final boolean script = className.toString().endsWith(PYTHON_FILE_EXTENSION);

if (!script // it's not an invokable script.
&& (fileName.startsWith("test_")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,46 @@
package com.ibm.wala.cast.python.types;

import static com.ibm.wala.cast.python.util.Util.PYTHON_FILE_EXTENSION;

import com.ibm.wala.cast.types.AstTypeReference;
import com.ibm.wala.classLoader.IClassLoader;
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.types.FieldReference;
import com.ibm.wala.types.MethodReference;
import com.ibm.wala.types.TypeName;
import com.ibm.wala.types.TypeReference;
import java.util.Arrays;

public class Util {

private static final String GLOBAL_KEYWORD = "global";

/**
* Returns the filename portion of the given {@link TypeName} representing a Python type.
* Returns the relative filename portion of the given {@link TypeName} representing a Python type.
*
* @param typeName A {@link TypeName} of a Python type.
* @return The filename portion of the given {@link TypeName}.
* @return The relative filename portion of the given {@link TypeName}.
* @apiNote Python types include a file in their {@link TypeName}s in Ariadne.
*/
public static String getFilename(final TypeName typeName) {
String ret = typeName.toString();
ret = ret.substring("Lscript ".length());
public static String getRelativeFilename(final TypeName typeName) {
String typeNameString = typeName.toString();

// Remove the script prefix.
typeNameString = typeNameString.substring("Lscript ".length());

// Extract the filename.
String[] segments = typeNameString.split("/");

String filename =
Arrays.stream(segments)
.filter(s -> s.endsWith(PYTHON_FILE_EXTENSION))
.findFirst()
.orElseThrow();

if (ret.indexOf('/') != -1) ret = ret.substring(0, ret.indexOf('/'));
assert filename.endsWith("." + PYTHON_FILE_EXTENSION)
: "Python files must have a \"" + PYTHON_FILE_EXTENSION + "\" extension.";

return ret;
return filename;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ public class Util {
/** Name of the annotation (decorator) that marks methods as a class method. */
public static final String CLASS_METHOD_ANNOTATION_NAME = "classmethod";

/** All Python files must have this extension. */
public static final String PYTHON_FILE_EXTENSION = "py";

/**
* Add Pytest entrypoints to the given {@link PropagationCallGraphBuilder}.
*
Expand Down

0 comments on commit 749aad2

Please sign in to comment.