From 40541dbabefe743aaee16c3f3f650c3b20845375 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Tue, 27 Feb 2024 17:33:39 -0500 Subject: [PATCH] Test Pytest entrypoints with TF code (#154) Test #153. Should work fine, and we should close that issue. I am seeing in a different program being analyzed this combination not working, but there must be some other problem. --- .../python/ml/test/TestTensorflow2Model.java | 9 +++++ .../data/test_tf_range.py | 17 +++++++++ .../ibm/wala/cast/python/test/TestCalls.java | 22 +----------- .../com/ibm/wala/cast/python/util/Util.java | 36 +++++++++++++++++++ 4 files changed, 63 insertions(+), 21 deletions(-) create mode 100644 com.ibm.wala.cast.python.test/data/test_tf_range.py create mode 100644 com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/util/Util.java diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 191c19335..f87dbdbcb 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -1,5 +1,6 @@ package com.ibm.wala.cast.python.ml.test; +import static com.ibm.wala.cast.python.util.Util.addPytestEntrypoints; import static java.util.Collections.emptySet; import static java.util.stream.Collectors.groupingBy; import static java.util.stream.Collectors.toSet; @@ -1210,6 +1211,12 @@ public void testTFRange2() test("tf2_test_tf_range2.py", "f", 1, 1, 2); } + @Test + public void testTFRange3() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("test_tf_range.py", "f", 1, 1, 2); + } + private void test( String filename, String functionName, @@ -1220,6 +1227,8 @@ private void test( PythonAnalysisEngine E = makeEngine(filename); PythonSSAPropagationCallGraphBuilder builder = E.defaultCallGraphBuilder(); + addPytestEntrypoints(builder); + CallGraph CG = builder.makeCallGraph(builder.getOptions()); assertNotNull(CG); diff --git a/com.ibm.wala.cast.python.test/data/test_tf_range.py b/com.ibm.wala.cast.python.test/data/test_tf_range.py new file mode 100644 index 000000000..0e5651f62 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/test_tf_range.py @@ -0,0 +1,17 @@ +# From: https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/range#for_example +import tensorflow as tf + + +def f(a): + pass + + +def test_tf_range(): + start = 3 + limit = 18 + delta = 3 + + r = tf.range(start, limit, delta) + + for i in r: + f(i) diff --git a/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestCalls.java b/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestCalls.java index 81cc441cb..c6e6b3645 100644 --- a/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestCalls.java +++ b/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestCalls.java @@ -1,26 +1,21 @@ package com.ibm.wala.cast.python.test; -import static com.google.common.collect.Iterables.concat; +import static com.ibm.wala.cast.python.util.Util.addPytestEntrypoints; import static java.util.Collections.singleton; import com.ibm.wala.cast.ipa.callgraph.CAstCallGraphUtil; import com.ibm.wala.cast.python.client.PythonAnalysisEngine; -import com.ibm.wala.cast.python.ipa.callgraph.PytestEntrypointBuilder; import com.ibm.wala.ipa.callgraph.CallGraph; -import com.ibm.wala.ipa.callgraph.Entrypoint; import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; import com.ibm.wala.ipa.callgraph.propagation.SSAContextInterpreter; import com.ibm.wala.ipa.cha.ClassHierarchyException; import com.ibm.wala.util.CancelException; import java.io.IOException; import java.util.Collections; -import java.util.logging.Logger; import org.junit.Test; public class TestCalls extends TestPythonCallGraphShape { - private static final Logger LOGGER = Logger.getLogger(TestCalls.class.getName()); - protected static final Object[][] assertionsCalls1 = new Object[][] { new Object[] {ROOT, new String[] {"script calls1.py"}}, @@ -282,21 +277,6 @@ public Void performAnalysis(PropagationCallGraphBuilder builder) throws CancelEx verifyGraphAssertions(callGraph, PYTEST_ASSERTIONS); } - private static void addPytestEntrypoints(PropagationCallGraphBuilder callGraphBuilder) { - Iterable defaultEntrypoints = - callGraphBuilder.getOptions().getEntrypoints(); - - Iterable pytestEntrypoints = - new PytestEntrypointBuilder().createEntrypoints(callGraphBuilder.getClassHierarchy()); - - Iterable entrypoints = concat(defaultEntrypoints, pytestEntrypoints); - - callGraphBuilder.getOptions().setEntrypoints(entrypoints); - - for (Entrypoint ep : callGraphBuilder.getOptions().getEntrypoints()) - LOGGER.info(() -> "Using entrypoint: " + ep.getMethod().getDeclaringClass().getName() + "."); - } - protected static final Object[][] PYTEST_ASSERTIONS2 = new Object[][] { new Object[] { diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/util/Util.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/util/Util.java new file mode 100644 index 000000000..bcf51d77e --- /dev/null +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/util/Util.java @@ -0,0 +1,36 @@ +package com.ibm.wala.cast.python.util; + +import static com.google.common.collect.Iterables.concat; + +import com.ibm.wala.cast.python.ipa.callgraph.PytestEntrypointBuilder; +import com.ibm.wala.ipa.callgraph.Entrypoint; +import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import java.util.logging.Logger; + +public class Util { + + private static final Logger LOGGER = Logger.getLogger(Util.class.getName()); + + /** + * Add Pytest entrypoints to the given {@link PropagationCallGraphBuilder}. + * + * @param callGraphBuilder The {@link PropagationCallGraphBuilder} for which to add Pytest + * entrypoints. + */ + public static void addPytestEntrypoints(PropagationCallGraphBuilder callGraphBuilder) { + Iterable defaultEntrypoints = + callGraphBuilder.getOptions().getEntrypoints(); + + Iterable pytestEntrypoints = + new PytestEntrypointBuilder().createEntrypoints(callGraphBuilder.getClassHierarchy()); + + Iterable entrypoints = concat(defaultEntrypoints, pytestEntrypoints); + + callGraphBuilder.getOptions().setEntrypoints(entrypoints); + + for (Entrypoint ep : callGraphBuilder.getOptions().getEntrypoints()) + LOGGER.info(() -> "Using entrypoint: " + ep.getMethod().getDeclaringClass().getName() + "."); + } + + private Util() {} +}