diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java
index 8061d9c49..7478bf03f 100644
--- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java
+++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java
@@ -3691,6 +3691,17 @@ public void testDecoratedMethod13() throws ClassHierarchyException, CancelExcept
test("tf2_test_decorated_method13.py", "f", 0, 0);
}
+ @Test
+ public void testDecoratedFunctions()
+ throws ClassHierarchyException, CancelException, IOException {
+ test("tf2_test_decorated_functions.py", "dummy_fun", 1, 1, 2);
+ test("tf2_test_decorated_functions.py", "dummy_test", 1, 1, 2);
+ test("tf2_test_decorated_functions.py", "test_function", 1, 1, 2);
+ test("tf2_test_decorated_functions.py", "test_function2", 1, 1, 2);
+ test("tf2_test_decorated_functions.py", "test_function3", 1, 1, 2);
+ test("tf2_test_decorated_functions.py", "test_function4", 1, 1, 2);
+ }
+
private void test(
String filename,
String functionName,
diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml
index f49cf34cc..adf31dc52 100644
--- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml
+++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml
@@ -252,6 +252,12 @@
+
+
+
+
+
+
@@ -291,6 +297,22 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py b/com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py
new file mode 100644
index 000000000..57cfd06cf
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py
@@ -0,0 +1,41 @@
+import tensorflow
+import pytest
+import sys
+
+
+@tensorflow.autograph.experimental.do_not_convert
+def dummy_fun(a):
+ pass
+
+
+@pytest.mark.parametrize("test_input,expected", [("3+5", 8), ("2+4", 6), ("6*9", 42)])
+def dummy_test(x, test_input, expected):
+ pass
+
+
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
+def test_function(x):
+ pass
+
+
+@pytest.mark.skip(reason="requires python3.10 or higher")
+def test_function2(x):
+ pass
+
+
+@pytest.mark.skip("requires python3.10 or higher")
+def test_function3(x):
+ pass
+
+
+@pytest.mark.skip
+def test_function4(x):
+ pass
+
+
+dummy_fun(tensorflow.constant(1))
+dummy_test(tensorflow.constant(1), "1", "1")
+test_function(tensorflow.constant(1))
+test_function2(tensorflow.constant(1))
+test_function3(tensorflow.constant(1))
+test_function4(tensorflow.constant(1))
diff --git a/com.ibm.wala.cast.python/data/pytest.xml b/com.ibm.wala.cast.python/data/pytest.xml
index e75bb2814..60d5e27ad 100644
--- a/com.ibm.wala.cast.python/data/pytest.xml
+++ b/com.ibm.wala.cast.python/data/pytest.xml
@@ -12,6 +12,8 @@
+
+
@@ -39,10 +41,27 @@
-
+
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+