From 380aed5d3de4fbe3dd184c2e8032e1cfd87c4cb8 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Tue, 2 Jan 2024 12:48:53 -0500 Subject: [PATCH] Add dataset batch API. --- .../cast/python/ml/test/TestTensorflowModel.java | 1 + com.ibm.wala.cast.python.ml/data/tensorflow.xml | 12 ++++++++++++ .../data/tf2_test_dataset5.py | 11 +++++++++++ 3 files changed, 24 insertions(+) create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_dataset5.py diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java index 5b301a25e..a81193212 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java @@ -206,6 +206,7 @@ public void testTf2() testTf2("tf2_test_dataset2.py", "add", 2, 2, 2, 3); testTf2("tf2_test_dataset3.py", "add", 2, 2, 2, 3); testTf2("tf2_test_dataset4.py", "add", 2, 2, 2, 3); + testTf2("tf2_test_dataset5.py", "add", 2, 2, 2, 3); testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3); testTf2("tf2_test_tensor_list2.py", "add", 0, 2); testTf2("tf2_test_tensor_list3.py", "add", 0, 2); diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml index 3b068ca3c..30ae5c8fb 100644 --- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml +++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml @@ -796,6 +796,8 @@ + + @@ -813,6 +815,16 @@ + + + + + + + + + + diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset5.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset5.py new file mode 100644 index 000000000..ceeb54599 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset5.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2) + +for element in dataset: + c = add(element, element)