Skip to content

Commit

Permalink
Merge branch 'master' into contrib_upgrade_java
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad authored May 24, 2024
2 parents 2e967d0 + 859e413 commit b3bc66e
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3609,6 +3609,18 @@ public void testClassMethod5() throws ClassHierarchyException, CancelException,
expectedTensorParameterValueNumbers);
}

@Test
public void testAbstractMethod() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_abstract_method.py", "D.f", 1, 1, 3);
test("tf2_test_abstract_method.py", "C.f", 1, 1, 3);
}

@Test
public void testAbstractMethod2() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_abstract_method2.py", "D.f", 1, 1, 3);
test("tf2_test_abstract_method2.py", "C.f", 1, 1, 3);
}

private void test(
String filename,
String functionName,
Expand Down
20 changes: 20 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_abstract_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# From https://blog.teclado.com/python-abc-abstract-base-classes/#introducing-abstract-classes.
import tensorflow as tf
from abc import ABC, abstractmethod


class C:

def f(self, x):
assert isinstance(x, tf.Tensor)


class D(C):

def f(self, x):
super(D, self).f(x)


c = D()
c.f(tf.constant(1))

21 changes: 21 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_abstract_method2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# From https://blog.teclado.com/python-abc-abstract-base-classes/#introducing-abstract-classes.
import tensorflow as tf
from abc import ABC, abstractmethod


class C(ABC):

@abstractmethod
def f(self, x):
assert isinstance(x, tf.Tensor)


class D(C):

def f(self, x):
super(D, self).f(x)


c = D()
c.f(tf.constant(1))

2 changes: 1 addition & 1 deletion com.ibm.wala.cast.python/data/flask.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" ?>
<!DOCTYPE summary-spec>
<!-- Pandas model -->
<!-- Flask model -->
<summary-spec>
<classloader name="PythonLoader">
<class name="flask" allocatable="true">
Expand Down
2 changes: 1 addition & 1 deletion com.ibm.wala.cast.python/data/pytest.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" ?>
<!DOCTYPE summary-spec>
<!-- Pandas model -->
<!-- Pytest model -->
<summary-spec>
<classloader name="PythonLoader">
<class name="pytest" allocatable="true">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
public class PythonInstanceMethodTrampolineTargetSelector<T>
extends PythonMethodTrampolineTargetSelector<T> {

private static final Logger logger =
private static final Logger LOGGER =
Logger.getLogger(PythonInstanceMethodTrampolineTargetSelector.class.getName());

/**
Expand All @@ -73,9 +73,9 @@ public class PythonInstanceMethodTrampolineTargetSelector<T>
private PythonAnalysisEngine<T> engine;

public PythonInstanceMethodTrampolineTargetSelector(
MethodTargetSelector base, PythonAnalysisEngine<T> pythonAnalysisEngine) {
MethodTargetSelector base, PythonAnalysisEngine<T> engine) {
super(base);
this.engine = pythonAnalysisEngine;
this.engine = engine;
}

@Override
Expand All @@ -88,15 +88,15 @@ protected boolean shouldProcess(CGNode caller, CallSiteReference site, IClass re
@Override
public IMethod getCalleeTarget(CGNode caller, CallSiteReference site, IClass receiver) {
if (isCallable(receiver)) {
logger.fine("Encountered callable.");
LOGGER.fine("Encountered callable.");

PythonInvokeInstruction call = this.getCall(caller, site);

// It's a callable. Change the receiver.
receiver = getCallable(caller, receiver.getClassHierarchy(), call);

if (receiver == null) return null; // not found.
else logger.fine("Substituting the receiver with one derived from a callable.");
else LOGGER.fine("Substituting the receiver with one derived from a callable.");
}

return super.getCalleeTarget(caller, site, receiver);
Expand Down Expand Up @@ -243,15 +243,15 @@ private IClass getCallable(CGNode caller, IClassHierarchy cha, PythonInvokeInstr
if (callable == null) {
// try the workaround for https://github.com/wala/ML/issues/106. NOTE: We cannot verify
// that the super class is tf.keras.Model due to https://github.com/wala/ML/issues/118.
logger.fine("Attempting callable workaround for https://github.com/wala/ML/issues/118.");
LOGGER.fine("Attempting callable workaround for https://github.com/wala/ML/issues/118.");

callable =
cha.lookupClass(
TypeReference.findOrCreateClass(
classLoaderReference, packageName, CALLABLE_METHOD_NAME_FOR_KERAS_MODELS));

if (callable != null)
logger.info("Applying callable workaround for https://github.com/wala/ML/issues/118.");
LOGGER.info("Applying callable workaround for https://github.com/wala/ML/issues/118.");
}

if (callable != null) return callable;
Expand Down Expand Up @@ -315,7 +315,7 @@ private static AllocationSiteInNode getAllocationSiteInNode(ConstantKey<?> const
Object value = constantKey.getValue();

if (value == null) {
logger.warning("Can't extract AllocationSiteInNode from: " + constantKey + ".");
LOGGER.warning("Can't extract AllocationSiteInNode from: " + constantKey + ".");
return null;
} else
throw new IllegalArgumentException(
Expand All @@ -332,7 +332,7 @@ public PythonAnalysisEngine<T> getEngine() {

@Override
protected Logger getLogger() {
return logger;
return LOGGER;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ public abstract class PythonMethodTrampolineTargetSelector<T> implements MethodT
protected final Map<Pair<IClass, Integer>, IMethod> codeBodies = HashMapFactory.make();

public PythonMethodTrampolineTargetSelector(MethodTargetSelector base) {
super();
this.base = base;
}

Expand Down

0 comments on commit b3bc66e

Please sign in to comment.