Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track tensors returned by tf.reshape() #105

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ public void testEx1Tensors() throws IllegalArgumentException, CancelException, I
checkTensorOps(
Ex1URL,
(PropagationCallGraphBuilder cgBuilder, CallGraph CG, TensorTypeAnalysis result) -> {
CAstCallGraphUtil.AVOID_DUMP = false;
CAstCallGraphUtil.dumpCG(
(SSAContextInterpreter) cgBuilder.getContextInterpreter(),
cgBuilder.getPointerAnalysis(),
CG);

String in = "[{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]";
String out = "[{[D:Symbolic,?, D:Constant,28, D:Constant,28, D:Constant,1] of pixel}]";
checkTensorOp(cgBuilder, CG, result, "reshape", in, out);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2759,6 +2759,21 @@ public void testDecoratedFunctions()
test("tf2_test_decorated_functions.py", "test_function4", 1, 1, 2);
}

@Test
public void testReshape() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_reshape.py", "f", 1, 1, 2);
}

@Test
public void testReshape2() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_reshape2.py", "f", 1, 1, 2);
}

@Test
public void testReshape3() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_reshape3.py", "f", 1, 1, 2);
}

private void test(
String filename,
String functionName,
Expand Down
4 changes: 3 additions & 1 deletion com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
<putfield class="LRoot" field="random" fieldType="LRoot" ref="Dataset" value="dsrandom" />
<new def="from_tensors" class="Ltensorflow/data/Dataset/from_tensors" />
<putfield class="LRoot" field="from_tensors" fieldType="LRoot" ref="Dataset" value="from_tensors" />
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/reshape -->
<new def="reshape" class="Ltensorflow/functions/reshape" />
<putfield class="LRoot" field="reshape" fieldType="LRoot" ref="x" value="reshape" />
<new def="conv2d" class="Ltensorflow/functions/conv2d" />
Expand Down Expand Up @@ -399,11 +400,12 @@
</method>
</class>
<class name="reshape" allocatable="true">
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/reshape -->
<method name="copy_data" descriptor="()LRoot;">
<new def="x" class="Ltensorflow/examples/tutorials/mnist/dataset" />
<return value="x" />
</method>
<method name="do" descriptor="()LRoot;" numArgs="3">
<method name="do" descriptor="()LRoot;" numArgs="3" paramNames="tensor shape name">
<call class="LRoot" name="copy_data" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
<return value="x" />
</method>
Expand Down
12 changes: 12 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/reshape

import tensorflow as tf


def f(a):
pass


t1 = tf.ones([2, 3])
t2 = tf.reshape(t1, [6])
f(t2)
145 changes: 145 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_reshape2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# https://raw.githubusercontent.com/aymericdamien/TensorFlow-Examples/dd2e6dcd9603d5de008d8c766453162d0204affa/examples/3_NeuralNetworks/convolutional_network.py
""" Convolutional Neural Network.

Build and train a convolutional neural network with TensorFlow.
This example is using the MNIST database of handwritten digits
(http://yann.lecun.com/exdb/mnist/)

This example is using TensorFlow layers API, see 'convolutional_network_raw'
example for a raw implementation with variables.

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""
from __future__ import division, print_function, absolute_import

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)

import tensorflow as tf

# Training Parameters
learning_rate = 0.001
num_steps = 2000
batch_size = 128

# Network Parameters
num_input = 784 # MNIST data input (img shape: 28*28)
num_classes = 10 # MNIST total classes (0-9 digits)
dropout = 0.75 # Dropout, probability to keep units


def f(a):
pass


# Create the neural network
def conv_net(x_dict, n_classes, dropout, reuse, is_training):
# Define a scope for reusing the variables
with tf.variable_scope("ConvNet", reuse=reuse):
# TF Estimator input is a dict, in case of multiple inputs
x = x_dict["images"]

# MNIST data input is a 1-D vector of 784 features (28*28 pixels)
# Reshape to match picture format [Height x Width x Channel]
# Tensor input become 4-D: [Batch Size, Height, Width, Channel]
x = tf.reshape(x, shape=[-1, 28, 28, 1])
f(x)

# Convolution Layer with 32 filters and a kernel size of 5
conv1 = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu)
# Max Pooling (down-sampling) with strides of 2 and kernel size of 2
conv1 = tf.layers.max_pooling2d(conv1, 2, 2)

# Convolution Layer with 64 filters and a kernel size of 3
conv2 = tf.layers.conv2d(conv1, 64, 3, activation=tf.nn.relu)
# Max Pooling (down-sampling) with strides of 2 and kernel size of 2
conv2 = tf.layers.max_pooling2d(conv2, 2, 2)

# Flatten the data to a 1-D vector for the fully connected layer
fc1 = tf.contrib.layers.flatten(conv2)

# Fully connected layer (in tf contrib folder for now)
fc1 = tf.layers.dense(fc1, 1024)
# Apply Dropout (if is_training is False, dropout is not applied)
fc1 = tf.layers.dropout(fc1, rate=dropout, training=is_training)

# Output layer, class prediction
out = tf.layers.dense(fc1, n_classes)

return out


# Define the model function (following TF Estimator Template)
def model_fn(features, labels, mode):
# Build the neural network
# Because Dropout have different behavior at training and prediction time, we
# need to create 2 distinct computation graphs that still share the same weights.
logits_train = conv_net(
features, num_classes, dropout, reuse=False, is_training=True
)
logits_test = conv_net(
features, num_classes, dropout, reuse=True, is_training=False
)

# Predictions
pred_classes = tf.argmax(logits_test, axis=1)
pred_probas = tf.nn.softmax(logits_test)

# If prediction mode, early return
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode, predictions=pred_classes)

# Define loss and optimizer
loss_op = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits_train, labels=tf.cast(labels, dtype=tf.int32)
)
)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op, global_step=tf.train.get_global_step())

# Evaluate the accuracy of the model
acc_op = tf.metrics.accuracy(labels=labels, predictions=pred_classes)

# TF Estimators requires to return a EstimatorSpec, that specify
# the different ops for training, evaluating, ...
estim_specs = tf.estimator.EstimatorSpec(
mode=mode,
predictions=pred_classes,
loss=loss_op,
train_op=train_op,
eval_metric_ops={"accuracy": acc_op},
)

return estim_specs


# Build the Estimator
model = tf.estimator.Estimator(model_fn)

# Define the input function for training
input_fn = tf.estimator.inputs.numpy_input_fn(
x={"images": mnist.train.images},
y=mnist.train.labels,
batch_size=batch_size,
num_epochs=None,
shuffle=True,
)
# Train the Model
model.train(input_fn, steps=num_steps)

# Evaluate the Model
# Define the input function for evaluating
input_fn = tf.estimator.inputs.numpy_input_fn(
x={"images": mnist.test.images},
y=mnist.test.labels,
batch_size=batch_size,
shuffle=False,
)
# Use the Estimator 'evaluate' method
e = model.evaluate(input_fn)

print("Testing Accuracy:", e["accuracy"])
13 changes: 13 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_reshape3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/reshape

import tensorflow as tf


def f(a):
pass


t1 = [[1, 2, 3], [4, 5, 6]]

t2 = tf.reshape(t1, [6])
f(t2)
Loading