Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Nov 15, 2023
2 parents 9906e0f + 2767043 commit 34623b4
Show file tree
Hide file tree
Showing 28 changed files with 532 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package edu.cuny.hunter.hybridize.core.analysis;

public class CantComputeRecursionException extends Exception {

private static final long serialVersionUID = 6647237480647044100L;

public CantComputeRecursionException() {
}

public CantComputeRecursionException(String message) {
super(message);
}

public CantComputeRecursionException(Throwable cause) {
super(cause);
}

public CantComputeRecursionException(String message, Throwable cause) {
super(message, cause);
}

public CantComputeRecursionException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
super(message, cause, enableSuppression, writableStackTrace);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.eclipse.core.resources.IProject;
Expand Down Expand Up @@ -356,12 +357,34 @@ public boolean getReduceRetracingParamExists() {

private RefactoringStatus status = new RefactoringStatus();

private Boolean isRecursive;

private static Map<MethodReference, Map<InstanceKey, Map<CallGraph, Boolean>>> creationsCache = Maps.newHashMap();

public Function(FunctionDefinition fd) {
this.functionDefinition = fd;
}

public void computeRecursion(CallGraph callGraph) throws CantComputeRecursionException {
// Get the nodes representing this function.
Set<CGNode> nodes = this.getNodes(callGraph);

if (nodes.isEmpty())
throw new CantComputeRecursionException("Can't compute if " + this + " is recusive without a call graph node.");

CGNode cgNode = nodes.iterator().next();

if (Util.calls(cgNode, this.getMethodReference(), callGraph)) {
// it's recursive.
LOG.info(this + " is recursive.");
this.setIsRecursive(true);
} else {
// not recursive.
LOG.info(this + " is not recursive.");
this.setIsRecursive(false);
}
}

/**
* Infer Python side-effects potentially produced by executing this {@link Function}.
*
Expand Down Expand Up @@ -879,16 +902,26 @@ public void check() {

if (this.getLikelyHasTensorParameter()) {
if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) {
this.addTransformation(Transformation.CONVERT_TO_HYBRID);
this.setPassingPrecondition(P1);
} else if (this.getHasPythonSideEffects() != null) // it has side-effects.
if (this.getIsRecursive() != null && !this.getIsRecursive()) {
this.addTransformation(Transformation.CONVERT_TO_HYBRID);
this.setPassingPrecondition(P1);
} else if (this.getIsRecursive() != null) // it's recursive.
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
} else if (this.getHasPythonSideEffects() != null) { // it has side-effects.
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS, "Can't hybridize a function with Python side-effects.");

if (this.getIsRecursive() != null && this.getIsRecursive())
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
}
} else { // no tensor parameters.
this.addFailure(PreconditionFailure.HAS_NO_TENSOR_PARAMETERS,
"This function has no tensor parameters and may not benefit from hybridization.");

if (this.getHasPythonSideEffects() != null && this.getHasPythonSideEffects())
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS, "Can't hybridize a function with Python side-effects.");

if (this.getIsRecursive() != null && this.getIsRecursive())
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
}
} else { // Hybrid. Use table 2.
this.setRefactoring(OPTIMIZE_HYBRID_FUNCTION);
Expand All @@ -908,6 +941,11 @@ public void check() {
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS,
"De-hybridizing a function with Python side-effects may alter semantics.");
}

// Here, we have a hybrid function with a tensor parameter.
if (this.getIsRecursive() != null && this.getIsRecursive()) // if it's recursive.
// issue a warning.
this.addWarning("Recursive tf.functions are not supported by TensorFlow.");
}

// Warn if the function has side-effects.
Expand Down Expand Up @@ -1122,4 +1160,23 @@ public static void clearCaches() {
creationsCache.clear();
}

public Boolean getIsRecursive() {
return this.isRecursive;
}

protected void setIsRecursive(Boolean isRecursive) {
this.isRecursive = isRecursive;
}

private Set<RefactoringStatusEntry> getRefactoringStatusEntries(Predicate<? super RefactoringStatusEntry> predicate) {
return Arrays.stream(this.getStatus().getEntries()).filter(predicate).collect(Collectors.toSet());
}

public Set<RefactoringStatusEntry> getWarnings() {
return this.getRefactoringStatusEntries(RefactoringStatusEntry::isWarning);
}

public Set<RefactoringStatusEntry> getErrors() {
return this.getRefactoringStatusEntries(RefactoringStatusEntry::isError);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@ public enum PreconditionFailure {

HAS_NO_TENSOR_PARAMETERS(6),

HAS_TENSOR_PARAMETERS(7);
HAS_TENSOR_PARAMETERS(7),

/**
* Functions that are recursive can't be hybridized. Also, de-hybridizing hybrid recursive functions may alter semantics.
*/
IS_RECURSIVE(8),

/**
* Can't find the CG node corresponding to the function.
*/
CANT_APPROXIMATE_RECURSION(9);

static {
// check that the codes are unique.
if (Arrays.stream(PreconditionFailure.values()).map(PreconditionFailure::getCode).distinct()
.count() != PreconditionFailure.values().length)
throw new IllegalStateException("Codes aren't unique.");
assert Arrays.stream(PreconditionFailure.values()).map(PreconditionFailure::getCode).distinct()
.count() == PreconditionFailure.values().length : "Codes must be unique.";
}

public static void main(String[] args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import java.io.File;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

import org.eclipse.core.runtime.ILog;
Expand All @@ -30,6 +31,11 @@
import org.python.pydev.parser.visitors.NodeUtils;
import org.python.pydev.shared_core.string.CoreTextSelection;

import com.google.common.collect.Sets;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.types.MethodReference;

public class Util {

private static final ILog LOG = getLog(Util.class);
Expand Down Expand Up @@ -248,4 +254,28 @@ public static boolean isBuiltIn(decoratorsType decorator) {
String decoratorRepresentation = NodeUtils.getRepresentationString(decorator.func);
return decoratorRepresentation.equals("property");
}

public static boolean calls(CGNode node, MethodReference methodReference, CallGraph callGraph) {
Set<MethodReference> seen = Sets.newHashSet();
return callsInternal(node, methodReference, callGraph, seen);
}

private static boolean callsInternal(CGNode node, MethodReference methodReference, CallGraph callGraph, Set<MethodReference> seen) {
seen.add(node.getMethod().getReference());

// check the callees.
for (Iterator<CGNode> succNodes = callGraph.getSuccNodes(node); succNodes.hasNext();) {
CGNode next = succNodes.next();
MethodReference reference = next.getMethod().getReference();

if (methodReference.equals(reference))
return true;

// otherwise, check its callees.
if (!seen.contains(reference) && calls(next, methodReference, callGraph))
return true;
}

return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

import edu.cuny.citytech.refactoring.common.core.RefactoringProcessor;
import edu.cuny.citytech.refactoring.common.core.TimeCollector;
import edu.cuny.hunter.hybridize.core.analysis.CantComputeRecursionException;
import edu.cuny.hunter.hybridize.core.analysis.Function;
import edu.cuny.hunter.hybridize.core.analysis.FunctionDefinition;
import edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure;
Expand Down Expand Up @@ -249,6 +250,13 @@ private RefactoringStatus checkFunctions(IProgressMonitor monitor) throws Operat
"Can't infer side-effects, most likely due to a call graph issue caused by a decorator or a missing function call.");
}

try {
func.computeRecursion(callGraph);
} catch (CantComputeRecursionException e) {
LOG.warn("Unable to compute whether " + this + " is recursive.", e);
func.addFailure(PreconditionFailure.CANT_APPROXIMATE_RECURSION, "Can't compute whether this function is recursive.");
}

// check the function preconditions.
func.check();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported.

import tensorflow as tf


# @tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1


recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported.

import tensorflow as tf


@tf.function
def recursive_fn(n):
if n > 0:
tf.print('tracing')
return recursive_fn(n - 1)
else:
return 1


recursive_fn(5) # Warning - multiple tracings
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported.

import tensorflow as tf


def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1


def not_recursive_fn(n):
if n > 0:
return abs(n - 1)
elif n <= 0:
return 1
else:
return recursive_fn(n)


not_recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported.

import tensorflow as tf


# @tf.function
def not_recursive_fn(n):
if n > 0:
return abs(n - 1)
else:
return 1


not_recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported

import tensorflow as tf


def recursive_fn2(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1


# @tf.function
def recursive_fn(n):
return recursive_fn2(n)


recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported.

import tensorflow as tf
from nose.tools import assert_raises


@tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1


with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tensorflow==2.9.3
nose==1.3.7
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported.

import tensorflow as tf


@tf.function
def not_recursive_fn(n):
if n > 0:
return abs(n - 1)
else:
return 1


not_recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported

import tensorflow as tf
from nose.tools import assert_raises


def recursive_fn2(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1


@tf.function
def recursive_fn(n):
return recursive_fn2(n)


with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tensorflow==2.9.3
nose==1.3.7
Loading

0 comments on commit 34623b4

Please sign in to comment.