diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 8061d9c49..7478bf03f 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -3691,6 +3691,17 @@ public void testDecoratedMethod13() throws ClassHierarchyException, CancelExcept test("tf2_test_decorated_method13.py", "f", 0, 0); } + @Test + public void testDecoratedFunctions() + throws ClassHierarchyException, CancelException, IOException { + test("tf2_test_decorated_functions.py", "dummy_fun", 1, 1, 2); + test("tf2_test_decorated_functions.py", "dummy_test", 1, 1, 2); + test("tf2_test_decorated_functions.py", "test_function", 1, 1, 2); + test("tf2_test_decorated_functions.py", "test_function2", 1, 1, 2); + test("tf2_test_decorated_functions.py", "test_function3", 1, 1, 2); + test("tf2_test_decorated_functions.py", "test_function4", 1, 1, 2); + } + private void test( String filename, String functionName, diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml index f49cf34cc..adf31dc52 100644 --- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml +++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml @@ -252,6 +252,12 @@ + + + + + + @@ -291,6 +297,22 @@ + + + + + + + + + + + + + + + + diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py b/com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py new file mode 100644 index 000000000..57cfd06cf --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py @@ -0,0 +1,41 @@ +import tensorflow +import pytest +import sys + + +@tensorflow.autograph.experimental.do_not_convert +def dummy_fun(a): + pass + + +@pytest.mark.parametrize("test_input,expected", [("3+5", 8), ("2+4", 6), ("6*9", 42)]) +def dummy_test(x, test_input, expected): + pass + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +def test_function(x): + pass + + +@pytest.mark.skip(reason="requires python3.10 or higher") +def test_function2(x): + pass + + +@pytest.mark.skip("requires python3.10 or higher") +def test_function3(x): + pass + + +@pytest.mark.skip +def test_function4(x): + pass + + +dummy_fun(tensorflow.constant(1)) +dummy_test(tensorflow.constant(1), "1", "1") +test_function(tensorflow.constant(1)) +test_function2(tensorflow.constant(1)) +test_function3(tensorflow.constant(1)) +test_function4(tensorflow.constant(1)) diff --git a/com.ibm.wala.cast.python/data/pytest.xml b/com.ibm.wala.cast.python/data/pytest.xml index e75bb2814..60d5e27ad 100644 --- a/com.ibm.wala.cast.python/data/pytest.xml +++ b/com.ibm.wala.cast.python/data/pytest.xml @@ -12,6 +12,8 @@ + + @@ -39,10 +41,27 @@ - + - + + + + + + + + + + + + + + + + + +