Skip to content

Commit

Permalink
Test Pytest entrypoints with TF code (#154)
Browse files Browse the repository at this point in the history
Test #153. Should work fine, and we should close that issue.

I am seeing in a different program being analyzed this combination not
working, but there must be some other problem.
  • Loading branch information
khatchad authored Feb 27, 2024
1 parent ce82d05 commit 40541db
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.ibm.wala.cast.python.ml.test;

import static com.ibm.wala.cast.python.util.Util.addPytestEntrypoints;
import static java.util.Collections.emptySet;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toSet;
Expand Down Expand Up @@ -1210,6 +1211,12 @@ public void testTFRange2()
test("tf2_test_tf_range2.py", "f", 1, 1, 2);
}

@Test
public void testTFRange3()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test("test_tf_range.py", "f", 1, 1, 2);
}

private void test(
String filename,
String functionName,
Expand All @@ -1220,6 +1227,8 @@ private void test(
PythonAnalysisEngine<TensorTypeAnalysis> E = makeEngine(filename);
PythonSSAPropagationCallGraphBuilder builder = E.defaultCallGraphBuilder();

addPytestEntrypoints(builder);

CallGraph CG = builder.makeCallGraph(builder.getOptions());
assertNotNull(CG);

Expand Down
17 changes: 17 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_tf_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# From: https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/range#for_example
import tensorflow as tf


def f(a):
pass


def test_tf_range():
start = 3
limit = 18
delta = 3

r = tf.range(start, limit, delta)

for i in r:
f(i)
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
package com.ibm.wala.cast.python.test;

import static com.google.common.collect.Iterables.concat;
import static com.ibm.wala.cast.python.util.Util.addPytestEntrypoints;
import static java.util.Collections.singleton;

import com.ibm.wala.cast.ipa.callgraph.CAstCallGraphUtil;
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ipa.callgraph.PytestEntrypointBuilder;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.Entrypoint;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import com.ibm.wala.ipa.callgraph.propagation.SSAContextInterpreter;
import com.ibm.wala.ipa.cha.ClassHierarchyException;
import com.ibm.wala.util.CancelException;
import java.io.IOException;
import java.util.Collections;
import java.util.logging.Logger;
import org.junit.Test;

public class TestCalls extends TestPythonCallGraphShape {

private static final Logger LOGGER = Logger.getLogger(TestCalls.class.getName());

protected static final Object[][] assertionsCalls1 =
new Object[][] {
new Object[] {ROOT, new String[] {"script calls1.py"}},
Expand Down Expand Up @@ -282,21 +277,6 @@ public Void performAnalysis(PropagationCallGraphBuilder builder) throws CancelEx
verifyGraphAssertions(callGraph, PYTEST_ASSERTIONS);
}

private static void addPytestEntrypoints(PropagationCallGraphBuilder callGraphBuilder) {
Iterable<? extends Entrypoint> defaultEntrypoints =
callGraphBuilder.getOptions().getEntrypoints();

Iterable<Entrypoint> pytestEntrypoints =
new PytestEntrypointBuilder().createEntrypoints(callGraphBuilder.getClassHierarchy());

Iterable<Entrypoint> entrypoints = concat(defaultEntrypoints, pytestEntrypoints);

callGraphBuilder.getOptions().setEntrypoints(entrypoints);

for (Entrypoint ep : callGraphBuilder.getOptions().getEntrypoints())
LOGGER.info(() -> "Using entrypoint: " + ep.getMethod().getDeclaringClass().getName() + ".");
}

protected static final Object[][] PYTEST_ASSERTIONS2 =
new Object[][] {
new Object[] {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.ibm.wala.cast.python.util;

import static com.google.common.collect.Iterables.concat;

import com.ibm.wala.cast.python.ipa.callgraph.PytestEntrypointBuilder;
import com.ibm.wala.ipa.callgraph.Entrypoint;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import java.util.logging.Logger;

public class Util {

private static final Logger LOGGER = Logger.getLogger(Util.class.getName());

/**
* Add Pytest entrypoints to the given {@link PropagationCallGraphBuilder}.
*
* @param callGraphBuilder The {@link PropagationCallGraphBuilder} for which to add Pytest
* entrypoints.
*/
public static void addPytestEntrypoints(PropagationCallGraphBuilder callGraphBuilder) {
Iterable<? extends Entrypoint> defaultEntrypoints =
callGraphBuilder.getOptions().getEntrypoints();

Iterable<Entrypoint> pytestEntrypoints =
new PytestEntrypointBuilder().createEntrypoints(callGraphBuilder.getClassHierarchy());

Iterable<Entrypoint> entrypoints = concat(defaultEntrypoints, pytestEntrypoints);

callGraphBuilder.getOptions().setEntrypoints(entrypoints);

for (Entrypoint ep : callGraphBuilder.getOptions().getEntrypoints())
LOGGER.info(() -> "Using entrypoint: " + ep.getMethod().getDeclaringClass().getName() + ".");
}

private Util() {}
}

0 comments on commit 40541db

Please sign in to comment.