Skip to content

Commit

Permalink
Workaround missing reshape operations.
Browse files Browse the repository at this point in the history
Workaround for wala#195.
  • Loading branch information
khatchad committed May 9, 2024
1 parent 26a47e0 commit ce003e0
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.ibm.wala.util.collections.HashSetFactory;
import java.io.IOException;
import java.util.Set;
import org.junit.Ignore;
import org.junit.Test;

public class TestMNISTExamples extends TestPythonMLCallGraphShape {
Expand Down Expand Up @@ -61,10 +62,14 @@ public void testEx1Tensors() throws IllegalArgumentException, CancelException, I
CG);

String in = "[{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]";

@SuppressWarnings("unused")
String out = "[{[D:Symbolic,?, D:Constant,28, D:Constant,28, D:Constant,1] of pixel}]";
checkTensorOp(cgBuilder, CG, result, "reshape", in, out);

in = "[{[D:Symbolic,?, D:Constant,28, D:Constant,28, D:Constant,1] of pixel}]";
// No change due to the workaround of https://github.com/wala/ML/issues/195.
checkTensorOp(cgBuilder, CG, result, "reshape", in, in);

// No change due to the workaround of https://github.com/wala/ML/issues/195.
checkTensorOp(cgBuilder, CG, result, "conv2d", in, null);
});
}
Expand All @@ -79,6 +84,7 @@ public void testEx2CG()
}

@Test
@Ignore("Workaround https://github.com/wala/ML/issues/195")
public void testEx2Tensors() throws IllegalArgumentException, CancelException, IOException {
checkTensorOps(
Ex2URL,
Expand Down Expand Up @@ -202,6 +208,7 @@ public void testEx4CG()
"https://raw.githubusercontent.com/tensorflow/tensorflow/r1.12/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py";

@Test
@Ignore("Workaround https://github.com/wala/ML/issues/195")
public void testEx5CG()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
checkTensorOps(
Expand Down
9 changes: 3 additions & 6 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -401,12 +401,9 @@
</class>
<class name="reshape" allocatable="true">
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/reshape -->
<method name="copy_data" descriptor="()LRoot;">
<new def="x" class="Ltensorflow/examples/tutorials/mnist/dataset" />
<return value="x" />
</method>
<method name="do" descriptor="()LRoot;" numArgs="3" paramNames="tensor shape name">
<call class="LRoot" name="copy_data" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="input name">
<new def="z" class="Ltensorflow/functions/convert_to_tensor" />
<call class="Ltensorflow/functions/convert_to_tensor" name="do" descriptor="()LRoot;" type="virtual" arg0="z" arg1="input" def="x" />
<return value="x" />
</method>
</class>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,9 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
}

Map<PointsToSetVariable, TensorType> shapeOps = HashMapFactory.make();
shapeOps.putAll(handleShapeSourceOp(builder, dataflow, reshape, 2));

// Don't handle shape source operations for now to workaround
// https://github.com/wala/ML/issues/195.

Set<PointsToSetVariable> conv2ds = getKeysDefinedByCall(conv2d, builder);

Expand Down

0 comments on commit ce003e0

Please sign in to comment.