From d40186b95c0b3f958b3052a4b7f82801fe2e223f Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Wed, 1 May 2024 11:37:50 -0400 Subject: [PATCH] Decorator fixes. --- .../python/ml/test/TestTensorflow2Model.java | 11 +++++ .../data/tensorflow.xml | 22 ++++++++++ .../data/tf2_test_decorated_functions.py | 41 +++++++++++++++++++ com.ibm.wala.cast.python/data/pytest.xml | 23 ++++++++++- 4 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 8061d9c49..7478bf03f 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -3691,6 +3691,17 @@ public void testDecoratedMethod13() throws ClassHierarchyException, CancelExcept test("tf2_test_decorated_method13.py", "f", 0, 0); } + @Test + public void testDecoratedFunctions() + throws ClassHierarchyException, CancelException, IOException { + test("tf2_test_decorated_functions.py", "dummy_fun", 1, 1, 2); + test("tf2_test_decorated_functions.py", "dummy_test", 1, 1, 2); + test("tf2_test_decorated_functions.py", "test_function", 1, 1, 2); + test("tf2_test_decorated_functions.py", "test_function2", 1, 1, 2); + test("tf2_test_decorated_functions.py", "test_function3", 1, 1, 2); + test("tf2_test_decorated_functions.py", "test_function4", 1, 1, 2); + } + private void test( String filename, String functionName, diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml index f49cf34cc..adf31dc52 100644 --- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml +++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml @@ -252,6 +252,12 @@ + + + + + + @@ -291,6 +297,22 @@ + + + + + + + + + + + + + + + + diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py b/com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py new file mode 100644 index 000000000..57cfd06cf --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_decorated_functions.py @@ -0,0 +1,41 @@ +import tensorflow +import pytest +import sys + + +@tensorflow.autograph.experimental.do_not_convert +def dummy_fun(a): + pass + + +@pytest.mark.parametrize("test_input,expected", [("3+5", 8), ("2+4", 6), ("6*9", 42)]) +def dummy_test(x, test_input, expected): + pass + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +def test_function(x): + pass + + +@pytest.mark.skip(reason="requires python3.10 or higher") +def test_function2(x): + pass + + +@pytest.mark.skip("requires python3.10 or higher") +def test_function3(x): + pass + + +@pytest.mark.skip +def test_function4(x): + pass + + +dummy_fun(tensorflow.constant(1)) +dummy_test(tensorflow.constant(1), "1", "1") +test_function(tensorflow.constant(1)) +test_function2(tensorflow.constant(1)) +test_function3(tensorflow.constant(1)) +test_function4(tensorflow.constant(1)) diff --git a/com.ibm.wala.cast.python/data/pytest.xml b/com.ibm.wala.cast.python/data/pytest.xml index e75bb2814..60d5e27ad 100644 --- a/com.ibm.wala.cast.python/data/pytest.xml +++ b/com.ibm.wala.cast.python/data/pytest.xml @@ -12,6 +12,8 @@ + + @@ -39,10 +41,27 @@ - + - + + + + + + + + + + + + + + + + + +