Skip to content

Commit

Permalink
89 losing tensors in datasets (#63)
Browse files Browse the repository at this point in the history
* Add `tf.data.Dataset.from_tensor_slices` to TensorFlow summary.

* Make constant.

* Initial support for `tf.data.Dataset`s.
  • Loading branch information
khatchad authored Dec 16, 2023
1 parent 243cae2 commit 872dbc2
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,10 @@ public void testTf2()
testTf2("tf2_testing_decorator8.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator9.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator10.py", "returned", 1, 3, 2);
testTf2(
"tf2_test_dataset.py",
"add",
0,
0); // NOTE: Change to testTf2("tf2_test_dataset.py", "add", 2, 3, 2, 3) once
// https://github.com/wala/ML/issues/89 is fixed.
// FIXME: The test tf2_test_dataset.py really has three tensors in its dataset, but we are
// currently treating them as one (see https://github.com/wala/ML/issues/89). In the literal
// case, it should be possible to model it like the list tests below.
testTf2("tf2_test_dataset.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
testTf2("tf2_test_tensor_list3.py", "add", 0, 2);
Expand Down
23 changes: 23 additions & 0 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@

<new def="nn" class="Lobject" />
<putfield class="LRoot" field="nn" fieldType="LRoot" ref="x" value="nn" />
<new def="data" class="Lobject" />
<putfield class="LRoot" field="data" fieldType="LRoot" ref="x" value="data" />
<new def="Dataset" class="Lobject" />
<putfield class="LRoot" field="Dataset" fieldType="LRoot" ref="data" value="Dataset" />
<new def="random" class="Lobject" />
<putfield class="LRoot" field="random" fieldType="LRoot" ref="x" value="random" />
<new def="sparse" class="Lobject" />
Expand Down Expand Up @@ -122,6 +126,9 @@
<new def="array_ops" class="Lobject" />
<putfield class="LRoot" field="array_ops" fieldType="LRoot" ref="ops" value="array_ops" />

<new def="data_ops" class="Lobject" />
<putfield class="LRoot" field="data_ops" fieldType="LRoot" ref="ops" value="data_ops" />

<new def="random_ops" class="Lobject" />
<putfield class="LRoot" field="random_ops" fieldType="LRoot" ref="ops" value="random_ops" />

Expand Down Expand Up @@ -167,6 +174,10 @@
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="x" value="ones" />
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="array_ops" value="ones" />

<new def="from_tensor_slices" class="Ltensorflow/functions/from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="Dataset" value="from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="data_ops" value="from_tensor_slices" />

<new def="zeros" class="Ltensorflow/functions/zeros" />
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="x" value="zeros" />
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="array_ops" value="zeros" />
Expand Down Expand Up @@ -399,6 +410,18 @@
</method>
</class>

<class name="from_tensor_slices" allocatable="true">
<!-- Summary model for `tf.data.Dataset.from_tensor_slices` (also reachable via the
     `data_ops` alias set up above). -->
<!-- "read_dataset" is a synthetic marker method: it indicates that this API produces a
     tensor iterable. The analysis engine (PythonTensorAnalysisEngine) looks for calls to
     this exact name to seed dataflow sources for dataset elements. -->
<method name="read_dataset" descriptor="()LRoot;">
<!-- Allocate the object standing in for the returned dataset. -->
<new def="x" class="Ltensorflow/python/ops/data_ops/from_tensor_slices" />
<return value="x" />
</method>
<!-- "do" is the summary body run for the actual API call; it delegates to the synthetic
     read_dataset marker and returns its result. -->
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="tensors name">
<call class="LRoot" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
<return value="x" />
</method>
</class>

<class name="Variable" allocatable="true">
<method name="read_data" descriptor="()LRoot;">
<new def="x" class="Ltensorflow/python/ops/variables/Variable" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
package com.ibm.wala.cast.python.ml.client;

import static com.ibm.wala.cast.types.AstMethodReference.fnReference;

import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
import com.ibm.wala.cast.lsp.AnalysisError;
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
import com.ibm.wala.cast.python.ml.types.TensorType;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.cast.types.AstMethodReference;
import com.ibm.wala.classLoader.CallSiteReference;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.classLoader.IMethod;
import com.ibm.wala.ipa.callgraph.AnalysisOptions;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
Expand All @@ -26,13 +35,22 @@
import com.ibm.wala.util.collections.HashSetFactory;
import com.ibm.wala.util.graph.Graph;
import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph;
import com.ibm.wala.util.intset.OrdinalSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeAnalysis> {

/** A "fake" function name in the summaries that indicates that an API produces a new tensor. */
private static final String TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME = "read_data";

/**
* A "fake" function name in the summaries that indicates that an API produces a tensor iterable.
*/
private static final String TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME = "read_dataset";

private static final Logger logger = Logger.getLogger(PythonTensorAnalysisEngine.class.getName());

private static final MethodReference conv2d =
Expand Down Expand Up @@ -69,7 +87,10 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeA

private final Map<PointerKey, AnalysisError> errorLog = HashMapFactory.make();

private static Set<PointsToSetVariable> getDataflowSources(Graph<PointsToSetVariable> dataflow) {
private static Set<PointsToSetVariable> getDataflowSources(
Graph<PointsToSetVariable> dataflow,
CallGraph callGraph,
PointerAnalysis<InstanceKey> pointerAnalysis) {
Set<PointsToSetVariable> sources = HashSetFactory.make();
for (PointsToSetVariable src : dataflow) {
PointerKey k = src.getPointerKey();
Expand All @@ -81,12 +102,63 @@ private static Set<PointsToSetVariable> getDataflowSources(Graph<PointsToSetVari
SSAInstruction inst = du.getDef(vn);

if (inst instanceof SSAAbstractInvokeInstruction) {
// We potentially have a function call that generates a tensor.
SSAAbstractInvokeInstruction ni = (SSAAbstractInvokeInstruction) inst;

if (ni.getCallSite().getDeclaredTarget().getName().toString().equals("read_data")
if (ni.getCallSite()
.getDeclaredTarget()
.getName()
.toString()
.equals(TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME)
&& ni.getException() != vn) {
sources.add(src);
logger.info("Added dataflow source " + src + ".");
logger.info("Added dataflow source from tensor generator: " + src + ".");
}
} else if (inst instanceof EachElementGetInstruction) {
// We are potentially pulling a tensor out of a tensor iterable.
EachElementGetInstruction eachElementGetInstruction = (EachElementGetInstruction) inst;

// Find the potential tensor iterable creation site.
SSAInstruction iterableDef = du.getDef(eachElementGetInstruction.getUse(0));

if (iterableDef instanceof SSAAbstractInvokeInstruction) {
SSAAbstractInvokeInstruction iterableGenInvocationInstruction =
(SSAAbstractInvokeInstruction) iterableDef;

// What function are we calling?
int use = iterableGenInvocationInstruction.getUse(0);
PointerKey pointerKeyForLocal =
pointerAnalysis.getHeapModel().getPointerKeyForLocal(kk.getNode(), use);
OrdinalSet<InstanceKey> pointsToSet =
pointerAnalysis.getPointsToSet(pointerKeyForLocal);

for (InstanceKey ik : pointsToSet) {
if (ik instanceof AllocationSiteInNode) {
AllocationSiteInNode asin = (AllocationSiteInNode) ik;
IClass concreteType = asin.getConcreteType();
TypeReference reference = concreteType.getReference();
MethodReference methodReference = fnReference(reference);

// Get the nodes this method calls.
Set<CGNode> iterableNodes = callGraph.getNodes(methodReference);

for (CGNode itNode : iterableNodes)
for (Iterator<CGNode> succNodes = callGraph.getSuccNodes(itNode);
succNodes.hasNext(); ) {
CGNode callee = succNodes.next();
IMethod calledMethod = callee.getMethod();

// Does this method call the synthetic "marker"?
if (calledMethod
.getName()
.toString()
.equals(TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME)) {
sources.add(src);
logger.info("Added dataflow source from tensor iterable: " + src + ".");
}
}
}
}
}
}
}
Expand Down Expand Up @@ -165,7 +237,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
SlowSparseNumberedGraph.duplicate(
builder.getPropagationSystem().getFlowGraphIncludingImplicitConstraints());

Set<PointsToSetVariable> sources = getDataflowSources(dataflow);
Set<PointsToSetVariable> sources =
getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis());

TensorType mnistData = TensorType.mnistInput();
Map<PointsToSetVariable, TensorType> init = HashMapFactory.make();
Expand Down

0 comments on commit 872dbc2

Please sign in to comment.