Skip to content

Commit

Permalink
Add dataset generator and test.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Jan 26, 2024
1 parent 9e6ef77 commit 0a84872
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ public void testTf2()
testTf2("tf2_test_dataset7.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset8.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset9.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset10.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_tensor_list.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_tensor_list2.py", "add", 0, 0);
testTf2("tf2_test_tensor_list3.py", "add", 0, 0);
Expand Down
10 changes: 10 additions & 0 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
<putfield class="LRoot" field="numpy_input_fn" fieldType="LRoot" ref="inputs" value="numpy_input_fn" />
<new def="from_tensor_slices" class="Ltensorflow/data/Dataset/from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="Dataset" value="from_tensor_slices" />
<new def="from_generator" class="Ltensorflow/data/Dataset/from_generator" />
<putfield class="LRoot" field="from_generator" fieldType="LRoot" ref="Dataset" value="from_generator" />
<new def="reshape" class="Ltensorflow/functions/reshape" />
<putfield class="LRoot" field="reshape" fieldType="LRoot" ref="x" value="reshape" />
<new def="conv2d" class="Ltensorflow/functions/conv2d" />
Expand Down Expand Up @@ -790,6 +792,14 @@
<return value="xx" />
</method>
</class>
<class name="from_generator" allocatable="true">
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/data/Dataset#from_generator -->
<method name="do" descriptor="()LRoot;" numArgs="6" paramNames="generator output_types output_shapes args output_signature name">
<new def="x" class="Ltensorflow/data/Dataset" />
<call class="Ltensorflow/data/Dataset" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="x" def="xx" />
<return value="xx" />
</method>
</class>
</package>
<package name="tensorflow/estimator/train">
<class name="train" allocatable="true">
Expand Down
20 changes: 20 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import tensorflow as tf


def gen():
ragged_tensor = tf.ragged.constant([[1, 2], [3]])
yield 42, ragged_tensor


def add(a, b):
return a + b


dataset = tf.data.Dataset.from_generator(
gen,
output_signature=(
tf.TensorSpec(shape=(), dtype=tf.int32),
tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))

for element in dataset:
c = add(element, element)

0 comments on commit 0a84872

Please sign in to comment.