diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestMNISTExamples.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestMNISTExamples.java index 54967bcfa..24d71352b 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestMNISTExamples.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestMNISTExamples.java @@ -27,6 +27,7 @@ import com.ibm.wala.util.collections.HashSetFactory; import java.io.IOException; import java.util.Set; +import org.junit.Ignore; import org.junit.Test; public class TestMNISTExamples extends TestPythonMLCallGraphShape { @@ -61,10 +62,14 @@ public void testEx1Tensors() throws IllegalArgumentException, CancelException, I CG); String in = "[{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]"; + + @SuppressWarnings("unused") String out = "[{[D:Symbolic,?, D:Constant,28, D:Constant,28, D:Constant,1] of pixel}]"; - checkTensorOp(cgBuilder, CG, result, "reshape", in, out); - in = "[{[D:Symbolic,?, D:Constant,28, D:Constant,28, D:Constant,1] of pixel}]"; + // No change due to the workaround of https://github.com/wala/ML/issues/195. + checkTensorOp(cgBuilder, CG, result, "reshape", in, in); + + // No change due to the workaround of https://github.com/wala/ML/issues/195. checkTensorOp(cgBuilder, CG, result, "conv2d", in, null); }); } @@ -79,6 +84,7 @@ public void testEx2CG() } @Test + @Ignore("Workaround https://github.com/wala/ML/issues/195") public void testEx2Tensors() throws IllegalArgumentException, CancelException, IOException { checkTensorOps( Ex2URL, @@ -202,6 +208,7 @@ public void testEx4CG() "https://raw.githubusercontent.com/tensorflow/tensorflow/r1.12/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py"; @Test + @Ignore("Workaround https://github.com/wala/ML/issues/195") public void testEx5CG() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { checkTensorOps( diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml index 57f46b185..86d23d74f 100644 --- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml +++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml @@ -401,12 +401,9 @@ - - - - - - + + + diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index f09513ea1..4ad164e98 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -632,7 +632,9 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) } Map shapeOps = HashMapFactory.make(); - shapeOps.putAll(handleShapeSourceOp(builder, dataflow, reshape, 2)); + + // Don't handle shape source operations for now to workaround + // https://github.com/wala/ML/issues/195. Set conv2ds = getKeysDefinedByCall(conv2d, builder);