diff --git a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/CantComputeRecursionException.java b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/CantComputeRecursionException.java new file mode 100644 index 000000000..c66475f65 --- /dev/null +++ b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/CantComputeRecursionException.java @@ -0,0 +1,26 @@ +package edu.cuny.hunter.hybridize.core.analysis; + +public class CantComputeRecursionException extends Exception { + + private static final long serialVersionUID = 6647237480647044100L; + + public CantComputeRecursionException() { + } + + public CantComputeRecursionException(String message) { + super(message); + } + + public CantComputeRecursionException(Throwable cause) { + super(cause); + } + + public CantComputeRecursionException(String message, Throwable cause) { + super(message, cause); + } + + public CantComputeRecursionException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } + +} diff --git a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java index 1a75f48ab..3e579bc57 100644 --- a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java +++ b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java @@ -18,6 +18,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import java.util.stream.Collectors; import org.eclipse.core.resources.IProject; @@ -356,12 +357,34 @@ public boolean getReduceRetracingParamExists() { private RefactoringStatus status = new RefactoringStatus(); + private Boolean isRecursive; + private static Map>> creationsCache = Maps.newHashMap(); public Function(FunctionDefinition fd) { this.functionDefinition = fd; } + public void computeRecursion(CallGraph callGraph) throws CantComputeRecursionException { + // Get the nodes representing this function. + Set nodes = this.getNodes(callGraph); + + if (nodes.isEmpty()) + throw new CantComputeRecursionException("Can't compute if " + this + " is recusive without a call graph node."); + + CGNode cgNode = nodes.iterator().next(); + + if (Util.calls(cgNode, this.getMethodReference(), callGraph)) { + // it's recursive. + LOG.info(this + " is recursive."); + this.setIsRecursive(true); + } else { + // not recursive. + LOG.info(this + " is not recursive."); + this.setIsRecursive(false); + } + } + /** * Infer Python side-effects potentially produced by executing this {@link Function}. * @@ -879,16 +902,26 @@ public void check() { if (this.getLikelyHasTensorParameter()) { if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) { - this.addTransformation(Transformation.CONVERT_TO_HYBRID); - this.setPassingPrecondition(P1); - } else if (this.getHasPythonSideEffects() != null) // it has side-effects. + if (this.getIsRecursive() != null && !this.getIsRecursive()) { + this.addTransformation(Transformation.CONVERT_TO_HYBRID); + this.setPassingPrecondition(P1); + } else if (this.getIsRecursive() != null) // it's recursive. + this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function."); + } else if (this.getHasPythonSideEffects() != null) { // it has side-effects. this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS, "Can't hybridize a function with Python side-effects."); + + if (this.getIsRecursive() != null && this.getIsRecursive()) + this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function."); + } } else { // no tensor parameters. this.addFailure(PreconditionFailure.HAS_NO_TENSOR_PARAMETERS, "This function has no tensor parameters and may not benefit from hybridization."); if (this.getHasPythonSideEffects() != null && this.getHasPythonSideEffects()) this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS, "Can't hybridize a function with Python side-effects."); + + if (this.getIsRecursive() != null && this.getIsRecursive()) + this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function."); } } else { // Hybrid. Use table 2. this.setRefactoring(OPTIMIZE_HYBRID_FUNCTION); @@ -908,6 +941,11 @@ public void check() { this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS, "De-hybridizing a function with Python side-effects may alter semantics."); } + + // Here, we have a hybrid function with a tensor parameter. + if (this.getIsRecursive() != null && this.getIsRecursive()) // if it's recursive. + // issue a warning. + this.addWarning("Recursive tf.functions are not supported by TensorFlow."); } // Warn if the function has side-effects. @@ -1122,4 +1160,23 @@ public static void clearCaches() { creationsCache.clear(); } + public Boolean getIsRecursive() { + return this.isRecursive; + } + + protected void setIsRecursive(Boolean isRecursive) { + this.isRecursive = isRecursive; + } + + private Set getRefactoringStatusEntries(Predicate predicate) { + return Arrays.stream(this.getStatus().getEntries()).filter(predicate).collect(Collectors.toSet()); + } + + public Set getWarnings() { + return this.getRefactoringStatusEntries(RefactoringStatusEntry::isWarning); + } + + public Set getErrors() { + return this.getRefactoringStatusEntries(RefactoringStatusEntry::isError); + } } diff --git a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/PreconditionFailure.java b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/PreconditionFailure.java index cfe6eb72c..01032fee5 100644 --- a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/PreconditionFailure.java +++ b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/PreconditionFailure.java @@ -14,13 +14,22 @@ public enum PreconditionFailure { HAS_NO_TENSOR_PARAMETERS(6), - HAS_TENSOR_PARAMETERS(7); + HAS_TENSOR_PARAMETERS(7), + + /** + * Functions that are recursive can't be hybridized. Also, de-hybridizing hybrid recursive functions may alter semantics. + */ + IS_RECURSIVE(8), + + /** + * Can't find the CG node corresponding to the function. + */ + CANT_APPROXIMATE_RECURSION(9); static { // check that the codes are unique. - if (Arrays.stream(PreconditionFailure.values()).map(PreconditionFailure::getCode).distinct() - .count() != PreconditionFailure.values().length) - throw new IllegalStateException("Codes aren't unique."); + assert Arrays.stream(PreconditionFailure.values()).map(PreconditionFailure::getCode).distinct() + .count() == PreconditionFailure.values().length : "Codes must be unique."; } public static void main(String[] args) { diff --git a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Util.java b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Util.java index bf7093abf..a1aff1cf1 100644 --- a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Util.java +++ b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Util.java @@ -4,6 +4,7 @@ import java.io.File; import java.util.HashSet; +import java.util.Iterator; import java.util.Set; import org.eclipse.core.runtime.ILog; @@ -30,6 +31,11 @@ import org.python.pydev.parser.visitors.NodeUtils; import org.python.pydev.shared_core.string.CoreTextSelection; +import com.google.common.collect.Sets; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.CallGraph; +import com.ibm.wala.types.MethodReference; + public class Util { private static final ILog LOG = getLog(Util.class); @@ -248,4 +254,28 @@ public static boolean isBuiltIn(decoratorsType decorator) { String decoratorRepresentation = NodeUtils.getRepresentationString(decorator.func); return decoratorRepresentation.equals("property"); } + + public static boolean calls(CGNode node, MethodReference methodReference, CallGraph callGraph) { + Set seen = Sets.newHashSet(); + return callsInternal(node, methodReference, callGraph, seen); + } + + private static boolean callsInternal(CGNode node, MethodReference methodReference, CallGraph callGraph, Set seen) { + seen.add(node.getMethod().getReference()); + + // check the callees. + for (Iterator succNodes = callGraph.getSuccNodes(node); succNodes.hasNext();) { + CGNode next = succNodes.next(); + MethodReference reference = next.getMethod().getReference(); + + if (methodReference.equals(reference)) + return true; + + // otherwise, check its callees. + if (!seen.contains(reference) && calls(next, methodReference, callGraph)) + return true; + } + + return false; + } } diff --git a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/refactorings/HybridizeFunctionRefactoringProcessor.java b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/refactorings/HybridizeFunctionRefactoringProcessor.java index 2be327b47..6f8addff9 100644 --- a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/refactorings/HybridizeFunctionRefactoringProcessor.java +++ b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/refactorings/HybridizeFunctionRefactoringProcessor.java @@ -40,6 +40,7 @@ import edu.cuny.citytech.refactoring.common.core.RefactoringProcessor; import edu.cuny.citytech.refactoring.common.core.TimeCollector; +import edu.cuny.hunter.hybridize.core.analysis.CantComputeRecursionException; import edu.cuny.hunter.hybridize.core.analysis.Function; import edu.cuny.hunter.hybridize.core.analysis.FunctionDefinition; import edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure; @@ -249,6 +250,13 @@ private RefactoringStatus checkFunctions(IProgressMonitor monitor) throws Operat "Can't infer side-effects, most likely due to a call graph issue caused by a decorator or a missing function call."); } + try { + func.computeRecursion(callGraph); + } catch (CantComputeRecursionException e) { + LOG.warn("Unable to compute whether " + this + " is recursive.", e); + func.addFailure(PreconditionFailure.CANT_APPROXIMATE_RECURSION, "Can't compute whether this function is recursive."); + } + // check the function preconditions. func.check(); diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion/in/A.py new file mode 100644 index 000000000..26b016fe8 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion/in/A.py @@ -0,0 +1,14 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported. + +import tensorflow as tf + + +# @tf.function +def recursive_fn(n): + if n > 0: + return recursive_fn(n - 1) + else: + return 1 + + +recursive_fn(tf.constant(5)) # Bad - maximum recursion error. diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion10/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion10/in/A.py new file mode 100644 index 000000000..7a684b31f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion10/in/A.py @@ -0,0 +1,15 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported. + +import tensorflow as tf + + +@tf.function +def recursive_fn(n): + if n > 0: + tf.print('tracing') + return recursive_fn(n - 1) + else: + return 1 + + +recursive_fn(5) # Warning - multiple tracings diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion10/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion10/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion10/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion11/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion11/in/A.py new file mode 100644 index 000000000..8d01c2c34 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion11/in/A.py @@ -0,0 +1,22 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported. + +import tensorflow as tf + + +def recursive_fn(n): + if n > 0: + return recursive_fn(n - 1) + else: + return 1 + + +def not_recursive_fn(n): + if n > 0: + return abs(n - 1) + elif n <= 0: + return 1 + else: + return recursive_fn(n) + + +not_recursive_fn(tf.constant(5)) # Bad - maximum recursion error. diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion11/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion11/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion11/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion2/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion2/in/A.py new file mode 100644 index 000000000..25a12a873 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion2/in/A.py @@ -0,0 +1,14 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported. + +import tensorflow as tf + + +# @tf.function +def not_recursive_fn(n): + if n > 0: + return abs(n - 1) + else: + return 1 + + +not_recursive_fn(tf.constant(5)) # Bad - maximum recursion error. diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion2/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion2/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion2/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion3/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion3/in/A.py new file mode 100644 index 000000000..5996deb79 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion3/in/A.py @@ -0,0 +1,18 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported + +import tensorflow as tf + + +def recursive_fn2(n): + if n > 0: + return recursive_fn(n - 1) + else: + return 1 + + +# @tf.function +def recursive_fn(n): + return recursive_fn2(n) + + +recursive_fn(tf.constant(5)) # Bad - maximum recursion error. diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion3/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion3/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion3/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion4/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion4/in/A.py new file mode 100644 index 000000000..9fc76f733 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion4/in/A.py @@ -0,0 +1,16 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported. + +import tensorflow as tf +from nose.tools import assert_raises + + +@tf.function +def recursive_fn(n): + if n > 0: + return recursive_fn(n - 1) + else: + return 1 + + +with assert_raises(Exception): + recursive_fn(tf.constant(5)) # Bad - maximum recursion error. diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion4/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion4/in/requirements.txt new file mode 100644 index 000000000..56020dd04 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion4/in/requirements.txt @@ -0,0 +1,2 @@ +tensorflow==2.9.3 +nose==1.3.7 diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion5/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion5/in/A.py new file mode 100644 index 000000000..b36ffb7cf --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion5/in/A.py @@ -0,0 +1,14 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported. + +import tensorflow as tf + + +@tf.function +def not_recursive_fn(n): + if n > 0: + return abs(n - 1) + else: + return 1 + + +not_recursive_fn(tf.constant(5)) # Bad - maximum recursion error. diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion5/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion5/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion5/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion6/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion6/in/A.py new file mode 100644 index 000000000..8f2359e2e --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion6/in/A.py @@ -0,0 +1,20 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported + +import tensorflow as tf +from nose.tools import assert_raises + + +def recursive_fn2(n): + if n > 0: + return recursive_fn(n - 1) + else: + return 1 + + +@tf.function +def recursive_fn(n): + return recursive_fn2(n) + + +with assert_raises(Exception): + recursive_fn(tf.constant(5)) # Bad - maximum recursion error. diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion6/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion6/in/requirements.txt new file mode 100644 index 000000000..56020dd04 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion6/in/requirements.txt @@ -0,0 +1,2 @@ +tensorflow==2.9.3 +nose==1.3.7 diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion7/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion7/in/A.py new file mode 100644 index 000000000..56e828be5 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion7/in/A.py @@ -0,0 +1,12 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported. + + +def recursive_fn(n): + if n > 0: + print('tracing') + return recursive_fn(n - 1) + else: + return 1 + + +recursive_fn(5) # Warning - multiple tracings diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion7/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion7/in/requirements.txt new file mode 100644 index 000000000..e69de29bb diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion8/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion8/in/A.py new file mode 100644 index 000000000..3dee5da34 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion8/in/A.py @@ -0,0 +1,15 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported. + +import tensorflow as tf + + +@tf.function +def recursive_fn(n): + if n > 0: + print('tracing') + return recursive_fn(n - 1) + else: + return 1 + + +recursive_fn(5) # Warning - multiple tracings diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion8/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion8/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion8/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion9/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion9/in/A.py new file mode 100644 index 000000000..6b8daaa47 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion9/in/A.py @@ -0,0 +1,14 @@ +# From https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported. + +import tensorflow as tf + + +def recursive_fn(n): + if n > 0: + tf.print('tracing') + return recursive_fn(n - 1) + else: + return 1 + + +recursive_fn(5) # Warning - multiple tracings diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion9/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion9/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testRecursion9/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java b/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java index b55468297..08aa71308 100644 --- a/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java +++ b/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java @@ -1,9 +1,15 @@ package edu.cuny.hunter.hybridize.tests; +import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_NO_TENSOR_PARAMETERS; +import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS; +import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_TENSOR_PARAMETERS; +import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.IS_RECURSIVE; import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P1; import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P2; +import static edu.cuny.hunter.hybridize.core.analysis.Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID; import static edu.cuny.hunter.hybridize.core.analysis.Refactoring.OPTIMIZE_HYBRID_FUNCTION; import static edu.cuny.hunter.hybridize.core.analysis.Transformation.CONVERT_TO_EAGER; +import static edu.cuny.hunter.hybridize.core.analysis.Transformation.CONVERT_TO_HYBRID; import static org.eclipse.core.runtime.Platform.getLog; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -1882,7 +1888,7 @@ public void testHasLikelyTensorParameter10() throws Exception { String attributeName = NodeUtils.getFullRepresentationString(typeHint); assertEquals("tf.Tensor", attributeName); - // TODO: Set to assertFalse() when #111 is fixed. + // NOTE: Set to assertFalse() when #111 is fixed. assertTrue(function.getLikelyHasTensorParameter()); } @@ -4359,7 +4365,7 @@ public void testHasLikelyTensorParameter143() throws Exception { */ @Test public void testHasLikelyTensorParameter144() throws Exception { - // TODO: Change to false, true once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/265 is fixed. + // NOTE: Change to false, true once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/265 is fixed. testHasLikelyTensorParameterHelper(false, false); } @@ -5260,6 +5266,7 @@ public void testPythonSideEffects47() throws Exception { assertTrue(capturesLeakedTensor.getStatus().hasError()); assertFalse(capturesLeakedTensor.getStatus().hasFatalError()); RefactoringStatusEntry error = capturesLeakedTensor.getStatus().getEntryMatchingSeverity(RefactoringStatus.ERROR); + assertTrue(error.isError()); assertEquals(PreconditionFailure.HAS_TENSOR_PARAMETERS.getCode(), error.getCode()); // NOTE: Change to assertEquals(..., 1, ...) once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281 is fixed. @@ -5458,5 +5465,204 @@ public void testPythonSideEffects58() throws Exception { assertFalse("Decorated embedded functions aren't side-effects.", f.getHasPythonSideEffects()); } - // TODO: Left off at: https://www.tensorflow.org/guide/function#recursive_tffunctions_are_not_supported + @Test + public void testRecursion() throws Exception { + Function f = getFunction("recursive_fn"); + + assertTrue(f.getIsRecursive()); + + assertFalse(f.getIsHybrid()); + assertEquals(Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID, f.getRefactoring()); + + assertTrue("No recursive functions.", f.getStatus().hasError()); + assertTrue(f.getEntryMatchingFailure(IS_RECURSIVE).isError()); + } + + @Test + public void testRecursion2() throws Exception { + Function f = getFunction("not_recursive_fn"); + + assertFalse(f.getIsHybrid()); // eag. + assertEquals(Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID, f.getRefactoring()); + + assertTrue(f.getLikelyHasTensorParameter()); // T. + assertNull(f.getEntryMatchingFailure(PreconditionFailure.HAS_NO_TENSOR_PARAMETERS)); + + assertFalse(f.getHasPythonSideEffects()); // F. + assertNull(f.getEntryMatchingFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS)); + + assertFalse(f.getIsRecursive()); // F. + assertNull(f.getEntryMatchingFailure(PreconditionFailure.IS_RECURSIVE)); + + assertTrue(f.getStatus().isOK()); + assertEquals(P1, f.getPassingPrecondition()); + assertEquals(Collections.singleton(CONVERT_TO_HYBRID), f.getTransformations()); + } + + @Test + public void testRecursion3() throws Exception { + Function f = getFunction("recursive_fn"); + assertEquals(Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID, f.getRefactoring()); + + assertTrue(f.getIsRecursive()); + assertTrue("No (transitively) recursive functions.", f.getStatus().hasError()); + assertTrue(f.getEntryMatchingFailure(IS_RECURSIVE).isError()); + } + + @Test + public void testRecursion4() throws Exception { + Function f = getFunction("recursive_fn"); + + assertTrue(f.getIsHybrid()); // hyb. + assertEquals(Refactoring.OPTIMIZE_HYBRID_FUNCTION, f.getRefactoring()); + + assertTrue(f.getLikelyHasTensorParameter()); // T. + assertTrue(f.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS).isError()); + + assertFalse(f.getHasPythonSideEffects()); // F. + assertNull(f.getEntryMatchingFailure(HAS_PYTHON_SIDE_EFFECTS)); + + assertTrue(f.getIsRecursive()); // T. + assertNull(f.getEntryMatchingFailure(IS_RECURSIVE)); + + assertEquals("We have a recursive hybrid function with a tensor parameter. Warn.", 1, f.getWarnings().size()); + + } + + @Test + public void testRecursion5() throws Exception { + Function f = getFunction("not_recursive_fn"); + + assertTrue(f.getIsHybrid()); + assertEquals(OPTIMIZE_HYBRID_FUNCTION, f.getRefactoring()); + + assertTrue(f.getLikelyHasTensorParameter()); + assertFalse("Already optimal.", f.getStatus().isOK()); + assertTrue(f.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS).isError()); + + assertFalse(f.getHasPythonSideEffects()); // F. + assertNull(f.getEntryMatchingFailure(HAS_PYTHON_SIDE_EFFECTS)); + + assertFalse(f.getIsRecursive()); // F. + assertNull(f.getEntryMatchingFailure(IS_RECURSIVE)); + + assertTrue("We have a non-recursive hybrid function with a tensor parameter. No warning.", f.getWarnings().isEmpty()); + + } + + @Test + public void testRecursion6() throws Exception { + Function f = getFunction("recursive_fn"); + + assertTrue(f.getIsHybrid()); // hyb. + assertEquals(OPTIMIZE_HYBRID_FUNCTION, f.getRefactoring()); + + assertTrue(f.getLikelyHasTensorParameter()); // T. + assertFalse("Already optimal.", f.getStatus().isOK()); + assertTrue(f.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS).isError()); + + assertFalse(f.getHasPythonSideEffects()); // F. + assertNull(f.getEntryMatchingFailure(HAS_PYTHON_SIDE_EFFECTS)); + + assertTrue(f.getIsRecursive()); // T. + assertNull(f.getEntryMatchingFailure(IS_RECURSIVE)); + + assertEquals("We have a recursive hybrid function with a tensor parameter. Warn.", 1, f.getWarnings().size()); + } + + @Test + public void testRecursion7() throws Exception { + Function f = getFunction("recursive_fn"); + + assertFalse(f.getIsHybrid()); + assertEquals(CONVERT_EAGER_FUNCTION_TO_HYBRID, f.getRefactoring()); + + assertFalse(f.getLikelyHasTensorParameter()); + assertTrue(f.getEntryMatchingFailure(HAS_NO_TENSOR_PARAMETERS).isError()); + + assertTrue(f.getHasPythonSideEffects()); // T. + assertTrue(f.getEntryMatchingFailure(HAS_PYTHON_SIDE_EFFECTS).isError()); + + assertTrue(f.getIsRecursive()); + assertTrue("No recursive functions.", f.getStatus().hasError()); + assertTrue(f.getEntryMatchingFailure(IS_RECURSIVE).isError()); + + assertTrue(f.getWarnings().isEmpty()); + } + + @Test + public void testRecursion8() throws Exception { + Function f = getFunction("recursive_fn"); + + assertTrue(f.getIsHybrid()); // hyb. + assertEquals(OPTIMIZE_HYBRID_FUNCTION, f.getRefactoring()); + + assertFalse(f.getLikelyHasTensorParameter()); // F. + assertNull(f.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS)); + assertNull("Not having tensor parameters is not a failure for: " + OPTIMIZE_HYBRID_FUNCTION + ".", + f.getEntryMatchingFailure(HAS_NO_TENSOR_PARAMETERS)); + + assertTrue(f.getHasPythonSideEffects()); // T. + assertTrue(f.getEntryMatchingFailure(HAS_PYTHON_SIDE_EFFECTS).isError()); + + assertTrue(f.getIsRecursive()); // T. + assertNull("Because there is no tensor parameter, it doesn't matter if it's recursive or not.", + f.getEntryMatchingFailure(IS_RECURSIVE)); + + assertEquals("No tensor parameter. No warning. The warning currently is from side-effects", 1, f.getWarnings().size()); + } + + @Test + public void testRecursion9() throws Exception { + Function f = getFunction("recursive_fn"); + + assertFalse(f.getIsHybrid()); + assertEquals(CONVERT_EAGER_FUNCTION_TO_HYBRID, f.getRefactoring()); + + assertFalse(f.getLikelyHasTensorParameter()); + assertTrue(f.getEntryMatchingFailure(HAS_NO_TENSOR_PARAMETERS).isError()); + + assertFalse(f.getHasPythonSideEffects()); + assertNull(f.getEntryMatchingFailure(HAS_PYTHON_SIDE_EFFECTS)); + + assertTrue(f.getIsRecursive()); + assertTrue("No recursive functions.", f.getStatus().hasError()); + assertTrue(f.getEntryMatchingFailure(IS_RECURSIVE).isError()); + + assertTrue(f.getWarnings().isEmpty()); + } + + @Test + public void testRecursion10() throws Exception { + Function f = getFunction("recursive_fn"); + + assertTrue(f.getIsHybrid()); // hyb + assertEquals(OPTIMIZE_HYBRID_FUNCTION, f.getRefactoring()); + + assertFalse(f.getLikelyHasTensorParameter()); // F. + assertNull(f.getEntryMatchingFailure(HAS_NO_TENSOR_PARAMETERS)); + assertNull(f.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS)); + + assertFalse(f.getHasPythonSideEffects()); // F. + assertNull(f.getEntryMatchingFailure(HAS_PYTHON_SIDE_EFFECTS)); + + assertTrue(f.getIsRecursive()); + assertNull("Because there is no tensor parameter, it doesn't matter if it's recursive or not.", + f.getEntryMatchingFailure(IS_RECURSIVE)); + + assertTrue(f.getWarnings().isEmpty()); + + assertEquals(P2, f.getPassingPrecondition()); + assertEquals(Collections.singleton(CONVERT_TO_EAGER), f.getTransformations()); + } + + @Test + public void testRecursion11() throws Exception { + Function f = getFunction("recursive_fn"); + assertEquals(Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID, f.getRefactoring()); + + assertTrue(f.getIsRecursive()); + assertTrue("No (transitively) recursive functions.", f.getStatus().hasError()); + assertTrue(f.getEntryMatchingFailure(IS_RECURSIVE).isError()); + } }