Skip to content

Commit

Permalink
Take a crack at the hyb prim case.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Dec 13, 2023
1 parent c225696 commit 8bb5c8c
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package edu.cuny.hunter.hybridize.core.analysis;

import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_PRIMITIVE_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P1;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P2;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P3;
import static edu.cuny.hunter.hybridize.core.analysis.Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID;
import static edu.cuny.hunter.hybridize.core.analysis.Refactoring.OPTIMIZE_HYBRID_FUNCTION;
import static edu.cuny.hunter.hybridize.core.analysis.Transformation.CONVERT_TO_EAGER;
Expand Down Expand Up @@ -1220,12 +1222,22 @@ public void check() {
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS,
"De-hybridizing a function with Python side-effects may alter semantics.");
} else if (this.getLikelyHasTensorParameter() != null) { // it has a tensor parameter.
this.addFailure(PreconditionFailure.HAS_TENSOR_PARAMETERS,
"Functions with tensor parameters may benefit from hybreidization.");

if (this.getHasPythonSideEffects() != null && this.getHasPythonSideEffects()) {
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS,
"De-hybridizing a function with Python side-effects may alter semantics.");
// if it has primitive parameters.
if (this.getLikelyHasPrimitiveParameters() != null && this.getLikelyHasPrimitiveParameters()) {
// if it does not have side-effects.
if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) {
this.addTransformation(CONVERT_TO_EAGER);
this.setPassingPrecondition(P3);
} else if (this.getHasPythonSideEffects() != null) // it has side-effects.
this.addFailure(HAS_PYTHON_SIDE_EFFECTS, "De-hybridizing a function with Python side-effects may alter semantics.");
} else if (this.getLikelyHasPrimitiveParameters() != null) { // no primitive parameters.
this.addFailure(PreconditionFailure.HAS_NO_PRIMITIVE_PARAMETERS,
"Functions with no Python literal arguments may benefit from hybridization.");

if (this.getHasPythonSideEffects() != null && this.getHasPythonSideEffects()) {
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS,
"De-hybridizing a function with Python side-effects may alter semantics.");
}
}

// Here, we have a hybrid function with a tensor parameter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ public enum PreconditionFailure {
*/
UNDETERMINABLE_PRIMITIVE_PARAMETER(11),

HAS_PRIMITIVE_PARAMETERS(12);
HAS_PRIMITIVE_PARAMETERS(12),

/**
* P3 failure.
*/
HAS_NO_PRIMITIVE_PARAMETERS(13);

static {
// check that the codes are unique.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package edu.cuny.hunter.hybridize.core.analysis;

public enum PreconditionSuccess {
P1, P2,
// P3,
P1, P2, P3,
// P4,
// P5
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/function#retracing.

import tensorflow as tf


@tf.function
def f(x):
print(x)
return tf.abs(x)


print(f(1))
print(f(2)) # Slow - compiles new graph

print(f(tf.constant(1)))
print(f(tf.constant(2))) # Fast - reuses f1
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/function#retracing.

import tensorflow as tf


@tf.function
def f(x):
return tf.abs(x)


# print(f(1))
# print(f(2)) # Slow - compiles new graph

print(f(tf.constant(1)))
print(f(tf.constant(2))) # Fast - reuses f1
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package edu.cuny.hunter.hybridize.tests;

import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.CANT_APPROXIMATE_RECURSION;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_NO_PRIMITIVE_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_NO_TENSOR_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_PRIMITIVE_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS;
Expand All @@ -10,10 +11,12 @@
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.UNDETERMINABLE_TENSOR_PARAMETER;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P1;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P2;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P3;
import static edu.cuny.hunter.hybridize.core.analysis.Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID;
import static edu.cuny.hunter.hybridize.core.analysis.Refactoring.OPTIMIZE_HYBRID_FUNCTION;
import static edu.cuny.hunter.hybridize.core.analysis.Transformation.CONVERT_TO_EAGER;
import static edu.cuny.hunter.hybridize.core.analysis.Transformation.CONVERT_TO_HYBRID;
import static java.util.Collections.singleton;
import static org.eclipse.core.runtime.Platform.getLog;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
Expand Down Expand Up @@ -5987,12 +5990,58 @@ public void testRetracing2() throws Exception {
@Test
public void testRetracing3() throws Exception {
Function f = getFunction("f");
assertTrue(f.getLikelyHasTensorParameter());
assertTrue(f.getLikelyHasPrimitiveParameters());
assertTrue(f.getIsHybrid());

assertTrue(f.getIsHybrid()); // hyb
assertTrue(f.getLikelyHasTensorParameter()); // T
assertTrue(f.getLikelyHasPrimitiveParameters()); // T
assertFalse(f.getHasPythonSideEffects()); // F

assertEquals(OPTIMIZE_HYBRID_FUNCTION, f.getRefactoring());
assertNotNull(f.getPassingPrecondition());
assertEquals(P3, f.getPassingPrecondition());
assertFalse(f.getStatus().hasError());
assertNull(f.getEntryMatchingFailure(HAS_PRIMITIVE_PARAMETERS));
assertFalse(f.getTransformations().isEmpty());
assertEquals(singleton(CONVERT_TO_EAGER), f.getTransformations());
}

/**
* Test https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/function#retracing,
*/
@Test
public void testRetracing4() throws Exception {
Function f = getFunction("f");

assertTrue(f.getIsHybrid()); // hyb
assertTrue(f.getLikelyHasTensorParameter()); // T
assertTrue(f.getLikelyHasPrimitiveParameters()); // T
assertTrue(f.getHasPythonSideEffects()); // T

assertEquals(OPTIMIZE_HYBRID_FUNCTION, f.getRefactoring());
assertNull(f.getPassingPrecondition());
assertTrue(f.getStatus().hasError());
assertNull(f.getEntryMatchingFailure(HAS_PRIMITIVE_PARAMETERS));
assertNotNull(f.getEntryMatchingFailure(HAS_PYTHON_SIDE_EFFECTS));
assertTrue(f.getTransformations().isEmpty());
}

/**
* Test https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/function#retracing,
*/
@Test
public void testRetracing5() throws Exception {
Function f = getFunction("f");

assertTrue(f.getIsHybrid()); // hyb
assertTrue(f.getLikelyHasTensorParameter()); // T
assertFalse(f.getLikelyHasPrimitiveParameters()); // F
assertFalse(f.getHasPythonSideEffects()); // F

assertEquals(OPTIMIZE_HYBRID_FUNCTION, f.getRefactoring());
assertNull(f.getPassingPrecondition());
assertTrue(f.getStatus().hasError());
assertNotNull(f.getEntryMatchingFailure(HAS_NO_PRIMITIVE_PARAMETERS));
assertNull(f.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS));
assertTrue(f.getTransformations().isEmpty());
}
}

0 comments on commit 8bb5c8c

Please sign in to comment.