Skip to content

Commit

Permalink
Add precondition checking for eager case for primitives.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Dec 13, 2023
1 parent 32f682d commit f9441ef
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package edu.cuny.hunter.hybridize.core.analysis;

import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_PRIMITIVE_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P1;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P2;
import static edu.cuny.hunter.hybridize.core.analysis.Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID;
Expand Down Expand Up @@ -1171,14 +1172,26 @@ public void check() {
this.setRefactoring(CONVERT_EAGER_FUNCTION_TO_HYBRID);

if (this.getLikelyHasTensorParameter() != null && this.getLikelyHasTensorParameter()) {
if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) {
if (this.getIsRecursive() != null && !this.getIsRecursive()) {
this.addTransformation(Transformation.CONVERT_TO_HYBRID);
this.setPassingPrecondition(P1);
} else if (this.getIsRecursive() != null) // it's recursive.
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
} else if (this.getHasPythonSideEffects() != null) { // it has side-effects.
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS, "Can't hybridize a function with Python side-effects.");
if (this.getLikelyHasPrimitiveParameters() != null && !this.getLikelyHasPrimitiveParameters()) {
if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) {
if (this.getIsRecursive() != null && !this.getIsRecursive()) {
this.addTransformation(Transformation.CONVERT_TO_HYBRID);
this.setPassingPrecondition(P1);
} else if (this.getIsRecursive() != null) // it's recursive.
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
} else if (this.getHasPythonSideEffects() != null) { // it has side-effects.
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS,
"Can't hybridize a function with Python side-effects.");

if (this.getIsRecursive() != null && this.getIsRecursive())
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
}
} else if (this.getLikelyHasPrimitiveParameters() != null) { // it has primitive parameters.
this.addFailure(HAS_PRIMITIVE_PARAMETERS, "Hybridizing a function with primitive parameters may induce retracing.");

if (this.getHasPythonSideEffects() != null && this.getHasPythonSideEffects())
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS,
"Can't hybridize a function with Python side-effects.");

if (this.getIsRecursive() != null && this.getIsRecursive())
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
Expand All @@ -1187,6 +1200,9 @@ public void check() {
this.addFailure(PreconditionFailure.HAS_NO_TENSOR_PARAMETERS,
"This function has no tensor parameters and may not benefit from hybridization.");

if (this.getLikelyHasPrimitiveParameters() != null && this.getLikelyHasPrimitiveParameters())
this.addFailure(HAS_PRIMITIVE_PARAMETERS, "Hybridizing a function with primitive parameters may induce retracing.");

if (this.getHasPythonSideEffects() != null && this.getHasPythonSideEffects())
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS, "Can't hybridize a function with Python side-effects.");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ public enum PreconditionFailure {
*/
UNDETERMINABLE_TENSOR_PARAMETER(10),

UNDETERMINABLE_PRIMITIVE_PARAMETER(11);
/**
* We need a call graph node.
*/
UNDETERMINABLE_PRIMITIVE_PARAMETER(11),

HAS_PRIMITIVE_PARAMETERS(12);

static {
// check that the codes are unique.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import tensorflow as tf


@tf.function
# @tf.function
def f(x):
return x # tf.abs(x)

# print(f(1))
# print(f(2)) # Slow - compiles new graph

print(f(1))
print(f(2)) # Slow - compiles new graph

print(f(tf.constant(1)))
print(f(tf.constant(2))) # Fast - reuses f1
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.CANT_APPROXIMATE_RECURSION;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_NO_TENSOR_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_PRIMITIVE_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_TENSOR_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.IS_RECURSIVE;
Expand Down Expand Up @@ -5955,7 +5956,12 @@ public void testLikelyHasNonTensorParameter16() throws Exception {
public void testRetracing() throws Exception {
Function f = getFunction("f");
assertTrue(f.getLikelyHasTensorParameter());
// TODO.
assertTrue(f.getLikelyHasPrimitiveParameters());
assertFalse(f.getIsHybrid());
assertEquals(CONVERT_EAGER_FUNCTION_TO_HYBRID, f.getRefactoring());
assertNull(f.getPassingPrecondition());
assertNotNull(f.getEntryMatchingFailure(HAS_PRIMITIVE_PARAMETERS));
assertTrue(f.getTransformations().isEmpty());
}

/**
Expand All @@ -5965,6 +5971,13 @@ public void testRetracing() throws Exception {
public void testRetracing2() throws Exception {
Function f = getFunction("f");
assertTrue(f.getLikelyHasTensorParameter());
// TODO.
assertFalse(f.getLikelyHasPrimitiveParameters());
assertFalse(f.getIsHybrid());
assertEquals(CONVERT_EAGER_FUNCTION_TO_HYBRID, f.getRefactoring());
assertNotNull(f.getPassingPrecondition());
assertEquals(P1, f.getPassingPrecondition());
assertNull(f.getEntryMatchingFailure(HAS_PRIMITIVE_PARAMETERS));
assertFalse(f.getTransformations().isEmpty());
assertEquals(Collections.singleton(CONVERT_TO_HYBRID), f.getTransformations());
}
}

0 comments on commit f9441ef

Please sign in to comment.