From 872dbc2bf220ff510a636c68a7a9a4cd930852f4 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Fri, 15 Dec 2023 20:28:56 -0500 Subject: [PATCH] 89 losing tensors in datasets (#63) * Add `tf.data.Dataset.from_tensor_slices` to TensorFlow summary. * Make constant. * Initial support for `tf.data.Dataset`s. --- .../python/ml/test/TestTensorflowModel.java | 10 +-- .../data/tensorflow.xml | 23 ++++++ .../ml/client/PythonTensorAnalysisEngine.java | 81 ++++++++++++++++++- 3 files changed, 104 insertions(+), 10 deletions(-) diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java index 95ed16e89..c30a8e189 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java @@ -199,12 +199,10 @@ public void testTf2() testTf2("tf2_testing_decorator8.py", "returned", 1, 3, 2); testTf2("tf2_testing_decorator9.py", "returned", 1, 3, 2); testTf2("tf2_testing_decorator10.py", "returned", 1, 3, 2); - testTf2( - "tf2_test_dataset.py", - "add", - 0, - 0); // NOTE: Change to testTf2("tf2_test_dataset.py", "add", 2, 3, 2, 3) once - // https://github.com/wala/ML/issues/89 is fixed. + // FIXME: Test tf2_test_dataset.py really has three tensors in its dataset. We are currently + // treating it as one. But, in the literal case, it should be possible to model it like the list + // tests below. 
+ testTf2("tf2_test_dataset.py", "add", 2, 2, 2, 3); testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3); testTf2("tf2_test_tensor_list2.py", "add", 0, 2); testTf2("tf2_test_tensor_list3.py", "add", 0, 2); diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml index 1d2ccbc0d..b5860ff3a 100644 --- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml +++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml @@ -41,6 +41,10 @@ + + + + @@ -122,6 +126,9 @@ + + + @@ -167,6 +174,10 @@ + + + + @@ -399,6 +410,18 @@ + + + + + + + + + + + + diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index 97901bcef..bbca9c2e5 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -1,5 +1,8 @@ package com.ibm.wala.cast.python.ml.client; +import static com.ibm.wala.cast.types.AstMethodReference.fnReference; + +import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction; import com.ibm.wala.cast.lsp.AnalysisError; import com.ibm.wala.cast.python.client.PythonAnalysisEngine; import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis; @@ -7,9 +10,15 @@ import com.ibm.wala.cast.python.types.PythonTypes; import com.ibm.wala.cast.types.AstMethodReference; import com.ibm.wala.classLoader.CallSiteReference; +import com.ibm.wala.classLoader.IClass; +import com.ibm.wala.classLoader.IMethod; import com.ibm.wala.ipa.callgraph.AnalysisOptions; import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.CallGraph; +import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode; +import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; import 
com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis; import com.ibm.wala.ipa.callgraph.propagation.PointerKey; import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; @@ -26,6 +35,7 @@ import com.ibm.wala.util.collections.HashSetFactory; import com.ibm.wala.util.graph.Graph; import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph; +import com.ibm.wala.util.intset.OrdinalSet; import java.util.Iterator; import java.util.Map; import java.util.Set; @@ -33,6 +43,14 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine { + /** A "fake" function name in the summaries that indicates that an API produces a new tensor. */ + private static final String TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME = "read_data"; + + /** + * A "fake" function name in the summaries that indicates that an API produces a tensor iterable. + */ + private static final String TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME = "read_dataset"; + private static final Logger logger = Logger.getLogger(PythonTensorAnalysisEngine.class.getName()); private static final MethodReference conv2d = @@ -69,7 +87,10 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine errorLog = HashMapFactory.make(); - private static Set getDataflowSources(Graph dataflow) { + private static Set getDataflowSources( + Graph dataflow, + CallGraph callGraph, + PointerAnalysis pointerAnalysis) { Set sources = HashSetFactory.make(); for (PointsToSetVariable src : dataflow) { PointerKey k = src.getPointerKey(); @@ -81,12 +102,63 @@ private static Set getDataflowSources(Graph pointsToSet = + pointerAnalysis.getPointsToSet(pointerKeyForLocal); + + for (InstanceKey ik : pointsToSet) { + if (ik instanceof AllocationSiteInNode) { + AllocationSiteInNode asin = (AllocationSiteInNode) ik; + IClass concreteType = asin.getConcreteType(); + TypeReference reference = 
concreteType.getReference(); + MethodReference methodReference = fnReference(reference); + + // Get the nodes this method calls. + Set iterableNodes = callGraph.getNodes(methodReference); + + for (CGNode itNode : iterableNodes) + for (Iterator succNodes = callGraph.getSuccNodes(itNode); + succNodes.hasNext(); ) { + CGNode callee = succNodes.next(); + IMethod calledMethod = callee.getMethod(); + + // Does this method call the synthetic "marker?" + if (calledMethod + .getName() + .toString() + .equals(TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME)) { + sources.add(src); + logger.info("Added dataflow source from tensor iterable: " + src + "."); + } + } + } + } } } } @@ -165,7 +237,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) SlowSparseNumberedGraph.duplicate( builder.getPropagationSystem().getFlowGraphIncludingImplicitConstraints()); - Set sources = getDataflowSources(dataflow); + Set sources = + getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis()); TensorType mnistData = TensorType.mnistInput(); Map init = HashMapFactory.make();