From 7756b4e4f9476c20fd8fdcd74bfb55d4fec76452 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Mon, 1 Jul 2024 18:44:26 -0400 Subject: [PATCH 1/3] Fix bug with dataset iterator processing. Always use the points-to analysis. --- .../python/ml/test/TestTensorflow2Model.java | 14 +++++++++++ .../ml/client/PythonTensorAnalysisEngine.java | 18 ++++---------- .../data/tf2_test_dataset34.py | 24 +++++++++++++++++++ .../data/tf2_test_dataset35.py | 19 +++++++++++++++ 4 files changed, 62 insertions(+), 13 deletions(-) create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_dataset34.py create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_dataset35.py diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 0706e9b5b..141794805 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -1107,6 +1107,20 @@ public void testDataset33() test("tf2_test_dataset33.py", "f", 1, 1, 2); } + /** Test a dataset that uses an iterator. */ + @Test + public void testDataset34() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_dataset34.py", "add", 2, 2, 2, 3); + } + + /** Test a dataset that uses an iterator. */ + @Test + public void testDataset35() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_dataset35.py", "add", 2, 2, 2, 3); + } + /** * Test enumerating a dataset (https://github.com/wala/ML/issues/140). The first element of the * tuple returned isn't a tensor. 
diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index f09513ea1..7bbf28c1c 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -155,19 +155,11 @@ private static Set getDataflowSources( if (reference.equals(NEXT.getDeclaringClass())) { // it's a call to `next()`. Look up the call to `iter()`. int iterator = ni.getUse(1); - SSAInstruction iteratorDef = du.getDef(iterator); - - // Let's see if the iterator is over a tensor dataset. - if (iteratorDef != null && iteratorDef.getNumberOfUses() > 1) { - // Get the argument. - int iterArg = iteratorDef.getUse(1); - processInstructionInterprocedurally( - iteratorDef, iterArg, localPointerKeyNode, src, sources, pointerAnalysis); - } else - // Use the original instruction. NOTE: We can only do this because `iter()` is - // currently just passing-through its argument. - processInstructionInterprocedurally( - ni, iterator, localPointerKeyNode, src, sources, pointerAnalysis); + + // Use the original instruction. NOTE: We can only do this because `iter()` is + // currently just passing-through its argument. 
+ processInstructionInterprocedurally( + ni, iterator, localPointerKeyNode, src, sources, pointerAnalysis); } } } diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset34.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset34.py new file mode 100644 index 000000000..200ee9894 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset34.py @@ -0,0 +1,24 @@ +import tensorflow as tf + + +class C: + + def __init__(self, some_iter): + self.some_iter = some_iter + + def __str__(self): + return str(self.some_iter) + + +def add(a, b): + return a + b + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) +my_iter = iter(dataset) +c = C(my_iter) +length = len(dataset) + +for _ in range(length): + element = next(c.some_iter) + add(element, element) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset35.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset35.py new file mode 100644 index 000000000..442fdde14 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset35.py @@ -0,0 +1,19 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +def gen_iter(ds): + return iter(ds) + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) + +my_iter = gen_iter(dataset) +length = len(dataset) + +for _ in range(length): + element = next(my_iter) + add(element, element) From 8ad59145685d11527aedc254be0c7f7909edda64 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Tue, 2 Jul 2024 13:13:19 -0400 Subject: [PATCH 2/3] Don't consider exceptions as data sources. 
--- .../com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java | 2 +- .../wala/cast/python/ml/client/PythonTensorAnalysisEngine.java | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 141794805..c4c0e4621 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -961,7 +961,7 @@ public void testDataset18() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { test("tf2_test_dataset18.py", "add", 2, 2, 2, 3); test("tf2_test_dataset18.py", "f", 1, 1, 2); - test("tf2_test_dataset18.py", "g", 0, 2); + test("tf2_test_dataset18.py", "g", 0, 1); } /** Test a dataset that uses an iterator. */ diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index 7bbf28c1c..a21aea6ae 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -132,6 +132,9 @@ private static Set getDataflowSources( // We potentially have a function call that generates a tensor. SSAAbstractInvokeInstruction ni = (SSAAbstractInvokeInstruction) inst; + // don't consider exceptions as a data source. 
+ if (ni.getException() == vn) continue; + if (ni.getCallSite() .getDeclaredTarget() .getName() From 02c2a6bd42521d24bf9552300ed8ce89a3f4fad8 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Tue, 2 Jul 2024 16:29:55 -0400 Subject: [PATCH 3/3] Actually fix the bug with dataset iterator processing. --- .../python/ml/test/TestTensorflow2Model.java | 15 ++ .../ml/client/PythonTensorAnalysisEngine.java | 138 +++++++++++++----- .../data/tf2_test_dataset36.py | 34 +++++ .../data/tf2_test_dataset37.py | 28 ++++ 4 files changed, 177 insertions(+), 38 deletions(-) create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_dataset36.py create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_dataset37.py diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index c4c0e4621..bf9d86cb0 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -1121,6 +1121,21 @@ public void testDataset35() test("tf2_test_dataset35.py", "add", 2, 2, 2, 3); } + /** Test a dataset that uses an iterator. */ + @Test + public void testDataset36() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_dataset36.py", "id1", 1, 1, 2); + // test("tf2_test_dataset36.py", "id2", 1, 1, 2); + } + + /** Test a dataset that uses an iterator. */ + @Test + public void testDataset37() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_dataset37.py", "add", 2, 2, 2, 3); + } + /** * Test enumerating a dataset (https://github.com/wala/ML/issues/140). The first element of the * tuple returned isn't a tensor. 
diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index a21aea6ae..597921afb 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -131,42 +131,7 @@ private static Set getDataflowSources( if (inst instanceof SSAAbstractInvokeInstruction) { // We potentially have a function call that generates a tensor. SSAAbstractInvokeInstruction ni = (SSAAbstractInvokeInstruction) inst; - - // don't consider exceptions as a data source. - if (ni.getException() == vn) continue; - - if (ni.getCallSite() - .getDeclaredTarget() - .getName() - .toString() - .equals(TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME) - && ni.getException() != vn) { - sources.add(src); - logger.info("Added dataflow source from tensor generator: " + src + "."); - } else if (ni.getNumberOfUses() > 1) { - // Get the invoked function from the PA. - int target = ni.getUse(0); - PointerKey targetKey = - pointerAnalysis.getHeapModel().getPointerKeyForLocal(localPointerKeyNode, target); - - for (InstanceKey ik : pointerAnalysis.getPointsToSet(targetKey)) { - if (ik instanceof ConcreteTypeKey) { - ConcreteTypeKey ctk = (ConcreteTypeKey) ik; - IClass type = ctk.getType(); - TypeReference reference = type.getReference(); - - if (reference.equals(NEXT.getDeclaringClass())) { - // it's a call to `next()`. Look up the call to `iter()`. - int iterator = ni.getUse(1); - - // Use the original instruction. NOTE: We can only do this because `iter()` is - // currently just passing-through its argument. 
- processInstructionInterprocedurally( - ni, iterator, localPointerKeyNode, src, sources, pointerAnalysis); - } - } - } - } + processInstruction(ni, du, localPointerKeyNode, src, vn, sources, pointerAnalysis); } else if (inst instanceof EachElementGetInstruction) { // We are potentially pulling a tensor out of a tensor iterable. EachElementGetInstruction eachElementGetInstruction = (EachElementGetInstruction) inst; @@ -210,8 +175,24 @@ private static Set getDataflowSources( } else if (def instanceof EachElementGetInstruction || def instanceof PythonPropertyRead || def instanceof PythonInvokeInstruction) { - processInstruction( - def, du, localPointerKeyNode, src, sources, callGraph, pointerAnalysis); + boolean added = false; + // we may be invoking `next()` on a dataset. + if (def instanceof SSAAbstractInvokeInstruction && def.getNumberOfUses() > 1) { + SSAAbstractInvokeInstruction invokeInstruction = (SSAAbstractInvokeInstruction) def; + added = + processInstruction( + invokeInstruction, + du, + localPointerKeyNode, + src, + vn, + sources, + pointerAnalysis); + } + + if (!added) + processInstruction( + def, du, localPointerKeyNode, src, sources, callGraph, pointerAnalysis); } } } @@ -219,6 +200,87 @@ private static Set getDataflowSources( return sources; } + /** + * Processes the given {@link SSAAbstractInvokeInstruction}, adding the given {@link PointsToSetVariable} to the given {@link Set} of {@link PointsToSetVariable}s as a dataflow source if the given {@link SSAAbstractInvokeInstruction} results in a tensor value. + * + * @param instruction The {@link SSAAbstractInvokeInstruction} to consider. + * @param du The {@link DefUse} for the given {@link SSAAbstractInvokeInstruction}. + * @param node The {@link CGNode} containing the given {@link SSAAbstractInvokeInstruction}. + * @param src The {@link PointsToSetVariable} to add to the given {@link Set} of {@link PointsToSetVariable}s if a tensor flows from the given {@link SSAAbstractInvokeInstruction}.
+ * @param vn The value number in the given {@link CGNode} corresponding to the given {@link PointsToSetVariable}. + * @param sources The {@link Set} of {@link PointsToSetVariable}s representing tensor dataflow sources. + * @param pointerAnalysis The {@link PointerAnalysis} for the given {@link CGNode}. + * @return True iff the given source was added to the set. + */ + private static boolean processInstruction( + SSAAbstractInvokeInstruction instruction, + DefUse du, + CGNode node, + PointsToSetVariable src, + int vn, + Set sources, + PointerAnalysis pointerAnalysis) { + boolean ret = false; + + // don't consider exceptions as a data source. + if (instruction.getException() != vn) { + if (instruction + .getCallSite() + .getDeclaredTarget() + .getName() + .toString() + .equals(TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME)) { + sources.add(src); + logger.info("Added dataflow source from tensor generator: " + src + "."); + ret = true; + } else if (instruction.getNumberOfUses() > 1) { + // Get the invoked function from the PA. + int target = instruction.getUse(0); + PointerKey targetKey = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, target); + + for (InstanceKey ik : pointerAnalysis.getPointsToSet(targetKey)) { + if (ik instanceof ConcreteTypeKey) { + ConcreteTypeKey ctk = (ConcreteTypeKey) ik; + IClass type = ctk.getType(); + TypeReference reference = type.getReference(); + + if (reference.equals(NEXT.getDeclaringClass())) { + // it's a call to `next()`. Look up the iterator definition. + int iterator = instruction.getUse(1); + SSAInstruction iteratorDef = du.getDef(iterator); + + // Let's see if the iterator is over a tensor dataset. First, check the iterator + // for a dataset source. NOTE: We can only do this because `iter()` is currently + // just passing-through its argument.
+ if (iteratorDef != null && iteratorDef.getNumberOfUses() > 1) { + boolean added = + processInstructionInterprocedurally( + iteratorDef, iteratorDef.getDef(), node, src, sources, pointerAnalysis); + + ret |= added; + + if (!added && iteratorDef instanceof SSAAbstractInvokeInstruction) { + // It may be a call to `iter()`. Get the argument. + int iterArg = iteratorDef.getUse(1); + ret |= + processInstructionInterprocedurally( + iteratorDef, iterArg, node, src, sources, pointerAnalysis); + } + } else + // Use the original instruction. NOTE: We can only do this because `iter()` is + // currently just passing-through its argument. + ret |= + processInstructionInterprocedurally( + instruction, iterator, node, src, sources, pointerAnalysis); + } + } + } + } + } + + return ret; + } + /** * Processes the given {@link SSAInstruction} to decide if the given {@link PointsToSetVariable} * is added to the given {@link Set} of {@link PointsToSetVariable}s as tensor dataflow sources. diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset36.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset36.py new file mode 100644 index 000000000..edd628de5 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset36.py @@ -0,0 +1,34 @@ +import tensorflow as tf + + +class C: + + def __init__(self, some_iter): + self.some_iter = some_iter + + def __str__(self): + return str(self.some_iter) + + +def id1(a): + return a + + +def id2(a): + return a + + +def gen(): + yield "42", tf.constant("43") + + +dataset = tf.data.Dataset.from_generator(gen, output_types=(tf.string, tf.string)) + +my_iter = iter(dataset) +c = C(my_iter) +length = 1 + +for _ in range(length): + x, y = next(c.some_iter) + id1(x) + id2(y) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset37.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset37.py new file mode 100644 index 000000000..436b107dc --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset37.py @@ -0,0 +1,28 @@ 
+import tensorflow as tf + + +class C: + + def __init__(self, some_iter): + self.some_iter = some_iter + + def __str__(self): + return str(self.some_iter) + + +def add(a, b): + return a + b + + +def gen_iter(dataset): + my_iter = iter(dataset) + return C(my_iter) + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) +c = gen_iter(dataset) +length = len(dataset) + +for _ in range(length): + element = next(c.some_iter) + add(element, element)