Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NN test. #318

Merged
merged 6 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,13 @@ public boolean getReduceRetracingParamExists() {

private Boolean isRecursive;

private boolean ignoreBooleans;

private static Map<MethodReference, Map<InstanceKey, Map<CallGraph, Boolean>>> creationsCache = Maps.newHashMap();

public Function(FunctionDefinition fd) {
public Function(FunctionDefinition fd, boolean ignoreBooleans) {
this.functionDefinition = fd;
this.ignoreBooleans = ignoreBooleans;
}

public void computeRecursion(CallGraph callGraph) throws CantComputeRecursionException {
Expand Down Expand Up @@ -440,7 +443,8 @@ public void inferPrimitiveParameters(CallGraph callGraph, PointerAnalysis<Instan
for (InstanceKey instanceKey : pointsToSet) {
LOG.info("Parameter of: " + this + " with index: " + paramInx + " points to: " + instanceKey + ".");

allInstancesArePrimitive &= containsPrimitive(instanceKey, pointerAnalysis, subMonitor.split(1));
allInstancesArePrimitive &= containsPrimitive(instanceKey, this.getIgnoreBooleans(), pointerAnalysis,
subMonitor.split(1));
subMonitor.worked(1);
}

Expand All @@ -462,17 +466,34 @@ public void inferPrimitiveParameters(CallGraph callGraph, PointerAnalysis<Instan
subMonitor.done();
}

private boolean containsPrimitive(InstanceKey instanceKey, PointerAnalysis<InstanceKey> pointerAnalysis, IProgressMonitor monitor) {
/**
* Returns true iff the given {@link InstanceKey} takes on primitive values.
*
* @param instanceKey The {@link InstanceKey} in question.
* @param ignoreBooleans True iff boolean values should not be considered.
* @param pointerAnalysis The {@link PointerAnalysis} corresponding to the given {@link InstanceKey}.
* @param monitor To monitor progress.
* @return True iff the given {@link InstanceKey} takes on primitive values according to the given {@link PointerAnalysis}.
*/
private static boolean containsPrimitive(InstanceKey instanceKey, boolean ignoreBooleans, PointerAnalysis<InstanceKey> pointerAnalysis,
IProgressMonitor monitor) {
SubMonitor subMonitor = SubMonitor.convert(monitor, "Examining instance...", 1);

if (instanceKey instanceof ConstantKey<?>) {
ConstantKey<?> constantKey = (ConstantKey<?>) instanceKey;
Object constantValue = constantKey.getValue();

if (constantValue != null) {
LOG.info("Found constant value: " + constantValue + " for parameter of: " + this + ".");
subMonitor.done();
return true;
LOG.info("Found constant value: " + constantValue + ".");

boolean foundBooleanValue = constantValue.equals(TRUE) || constantValue.equals(FALSE);

// If it's not the case that we found a boolean value and we are ignoring booleans.
if (!(foundBooleanValue && ignoreBooleans)) {
// We have found a primitive.
subMonitor.done();
return true;
}
}
} else if (instanceKey instanceof AllocationSiteInNode) {
AllocationSiteInNode asin = (AllocationSiteInNode) instanceKey;
Expand All @@ -489,7 +510,7 @@ private boolean containsPrimitive(InstanceKey instanceKey, PointerAnalysis<Insta
subMonitor.beginTask("Examining instance field instances...", instanceFieldPointsToSet.size());

for (InstanceKey key : instanceFieldPointsToSet)
if (containsPrimitive(key, pointerAnalysis, subMonitor.split(1))) {
if (containsPrimitive(key, ignoreBooleans, pointerAnalysis, subMonitor.split(1))) {
subMonitor.done();
return true;
}
Expand Down Expand Up @@ -1515,4 +1536,13 @@ public boolean isMethod() {

return false;
}

/**
* True iff booleans shouldn't be considered primitives.
*
* @return True iff boolean values shouldn't be considered primitives.
*/
protected boolean getIgnoreBooleans() {
return ignoreBooleans;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ private static RefactoringStatus checkParameters(Function func) {

private boolean alwaysCheckRecursion;

private boolean processFunctionsInParallel = true;
private boolean ignoreBooleansInLiteralCheck = true;

private boolean processFunctionsInParallel;

public HybridizeFunctionRefactoringProcessor() {
// Force the use of typeshed. It's an experimental feature of PyDev.
Expand All @@ -132,6 +134,12 @@ public HybridizeFunctionRefactoringProcessor(boolean alwaysCheckPythonSideEffect
this.alwaysCheckRecursion = alwaysCheckRecusion;
}

public HybridizeFunctionRefactoringProcessor(boolean alwaysCheckPythonSideEffects, boolean processFunctionsInParallel,
boolean alwaysCheckRecusion, boolean ignoreBooleansInLiteralCheck) {
this(alwaysCheckPythonSideEffects, processFunctionsInParallel, alwaysCheckRecusion);
this.ignoreBooleansInLiteralCheck = ignoreBooleansInLiteralCheck;
}

public HybridizeFunctionRefactoringProcessor(Set<FunctionDefinition> functionDefinitionSet)
throws TooManyMatchesException /* FIXME: This exception sounds too low-level. */ {
this();
Expand All @@ -141,7 +149,7 @@ public HybridizeFunctionRefactoringProcessor(Set<FunctionDefinition> functionDef
Set<Function> functionSet = this.getFunctions();

for (FunctionDefinition fd : functionDefinitionSet) {
Function function = new Function(fd);
Function function = new Function(fd, this.ignoreBooleansInLiteralCheck);

// Add the Function to the Function set.
functionSet.add(function);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def f(a):
pass


f(True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def f(a):
pass


f(False)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def f(a, b):
pass


f(5, True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def f(a, b):
pass


f(6, False)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def f(a):
pass


f(7)
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# From https://github.com/aymericdamien/TensorFlow-Examples/blob/6dcbe14649163814e72a22a999f20c5e247ce988/tensorflow_v2/notebooks/3_NeuralNetworks/neural_network.ipynb.

# %%
# # Neural Network Example

# Build a 2-hidden layers fully connected neural network (a.k.a multilayer perceptron) with TensorFlow v2.

# This example is using a low-level approach to better understand all mechanics behind building neural networks and the training process.

# - Author: Aymeric Damien
# - Project: https://github.com/aymericdamien/TensorFlow-Examples/
# """

# %%
# ## Neural Network Overview

# <img src="http://cs231n.github.io/assets/nn1/neural_net2.jpeg" alt="nn" style="width: 400px;"/>

# ## MNIST Dataset Overview

# This example is using MNIST handwritten digits. The dataset contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in a fixed-size image (28x28 pixels) with values from 0 to 255.

# In this example, each image will be converted to float32, normalized to [0, 1] and flattened to a 1-D array of 784 features (28*28).

# ![MNIST Dataset](http://neuralnetworksanddeeplearning.com/images/mnist_100_digits.png)

# More info: http://yann.lecun.com/exdb/mnist/

# %%
from __future__ import absolute_import, division, print_function

import tensorflow as tf
print("TensorFlow version:", tf.__version__)
assert(tf.__version__ == "2.9.3")
from tensorflow.keras import Model, layers
import numpy as np
import timeit

start_time = timeit.default_timer()
skipped_time = 0

# %%
# MNIST dataset parameters.
num_classes = 10 # total classes (0-9 digits).
num_features = 784 # data features (img shape: 28*28).

# Training parameters.
learning_rate = 0.1
training_steps = 1
batch_size = 256
display_step = 100

# Network parameters.
n_hidden_1 = 128 # 1st layer number of neurons.
n_hidden_2 = 256 # 2nd layer number of neurons.

# %%
# Prepare MNIST data.
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Convert to float32.
x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
# Flatten images to 1-D vector of 784 features (28*28).
x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])
# Normalize images value from [0, 255] to [0, 1].
x_train, x_test = x_train / 255., x_test / 255.

# %%
# Use tf.data API to shuffle and batch data.
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)


# %%
# Create TF Model.
class NeuralNet(Model):

# Set layers.
def __init__(self):
super(NeuralNet, self).__init__()
# First fully-connected hidden layer.
self.fc1 = layers.Dense(n_hidden_1, activation=tf.nn.relu)
# First fully-connected hidden layer.
self.fc2 = layers.Dense(n_hidden_2, activation=tf.nn.relu)
# Second fully-connecter hidden layer.
self.out = layers.Dense(num_classes)

# Set forward pass.
def call(self, x, is_training=False):
x = self.fc1(x)
x = self.fc2(x)
x = self.out(x)
if not is_training:
# tf cross entropy expect logits without softmax, so only
# apply softmax when not training.
x = tf.nn.softmax(x)
return x


# Build neural network model.
neural_net = NeuralNet()


# %%
# Cross-Entropy Loss.
# Note that this will apply 'softmax' to the logits.
def cross_entropy_loss(x, y):
# Convert labels to int 64 for tf cross-entropy function.
y = tf.cast(y, tf.int64)
# Apply softmax to logits and compute cross-entropy.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)
# Average loss across the batch.
return tf.reduce_mean(loss)


# Accuracy metric.
def accuracy(y_pred, y_true):
# Predicted class is the index of highest score in prediction vector (i.e. argmax).
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))
return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)


# Stochastic gradient descent optimizer.
optimizer = tf.optimizers.SGD(learning_rate)


# %%
# Optimization process.
def run_optimization(x, y):
# Wrap computation inside a GradientTape for automatic differentiation.
with tf.GradientTape() as g:
# Forward pass.
pred = neural_net(x, is_training=True)
# Compute loss.
loss = cross_entropy_loss(pred, y)

# Variables to update, i.e. trainable variables.
trainable_variables = neural_net.trainable_variables

# Compute gradients.
gradients = g.gradient(loss, trainable_variables)

# Update W and b following gradients.
optimizer.apply_gradients(zip(gradients, trainable_variables))


# %%
# Run training for the given number of steps.
for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):
# Run the optimization to update W and b values.
run_optimization(batch_x, batch_y)

if step % display_step == 0:
pred = neural_net(batch_x, is_training=True)
loss = cross_entropy_loss(pred, batch_y)
acc = accuracy(pred, batch_y)
print_time = timeit.default_timer()
print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))
skipped_time += timeit.default_timer() - print_time

# %%
# Test model on validation set.
pred = neural_net(x_test, is_training=False)
print_time = timeit.default_timer()
print("Test Accuracy: %f" % accuracy(pred, y_test))
skipped_time += timeit.default_timer() - print_time

# %%
# Visualize predictions.
import matplotlib.pyplot as plt

# %%
# Predict 5 images from validation set.
n_images = 5
test_images = x_test[:n_images]
predictions = neural_net(test_images)

print("Elapsed time: ", timeit.default_timer() - start_time - skipped_time)

# Display image and model prediction.
for i in range(n_images):
# plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray')
# plt.show()
print("Model prediction: %i" % np.argmax(predictions.numpy()[i]))
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Loading