Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't compute tensor parameters if the node is not in the call graph #295

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package edu.cuny.hunter.hybridize.core.analysis;

public class CantInferTensorParametersException extends Exception {

public CantInferTensorParametersException(String message) {
super(message);
}

private static final long serialVersionUID = 4856240199804486169L;

}
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,9 @@ public TypeReference getDeclaringClass() {
return TypeReference.findOrCreate(PythonTypes.pythonLoader, typeName);
}

public void inferTensorTensorParameters(TensorTypeAnalysis analysis, IProgressMonitor monitor) throws BadLocationException {
public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph callGraph, IProgressMonitor monitor)
throws BadLocationException, CantInferTensorParametersException {
monitor.beginTask("Analyzing whether function has a tensor parameter.", IProgressMonitor.UNKNOWN);
// TODO: What if there are no current calls to the function? How will we determine its type?
// TODO: Use cast/assert statements?
FunctionDef functionDef = this.getFunctionDefinition().getFunctionDef();
argumentsType params = functionDef.args;
Expand Down Expand Up @@ -686,6 +686,14 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, IProgressMo
}
}

// Is this function in the call graph?
Set<CGNode> nodes = this.getNodes(callGraph);

if (nodes.isEmpty())
// if there are no nodes representing this function, then it most likely isn't called.
throw new CantInferTensorParametersException(
"Can't infer tensor parameters for " + this + " without a call graph node.");

// Check the tensor type analysis. Check that the methods are the same, the parameters, and so on. If we match the
// pointer key, then we know it's a tensor if the TensorType is not null.
for (Pair<PointerKey, TensorVariable> pair : analysis) {
Expand Down Expand Up @@ -900,7 +908,7 @@ public void check() {
if (!this.getIsHybrid()) { // Eager. Table 1.
this.setRefactoring(CONVERT_EAGER_FUNCTION_TO_HYBRID);

if (this.getLikelyHasTensorParameter()) {
if (this.getLikelyHasTensorParameter() != null && this.getLikelyHasTensorParameter()) {
if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) {
if (this.getIsRecursive() != null && !this.getIsRecursive()) {
this.addTransformation(Transformation.CONVERT_TO_HYBRID);
Expand All @@ -926,7 +934,7 @@ public void check() {
} else { // Hybrid. Use table 2.
this.setRefactoring(OPTIMIZE_HYBRID_FUNCTION);

if (!this.getLikelyHasTensorParameter()) {
if (this.getLikelyHasTensorParameter() != null && !this.getLikelyHasTensorParameter()) {
if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) {
this.addTransformation(CONVERT_TO_EAGER);
this.setPassingPrecondition(P2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ public enum PreconditionFailure {
/**
* Can't find the CG node corresponding to the function.
*/
CANT_APPROXIMATE_RECURSION(9);
CANT_APPROXIMATE_RECURSION(9),

/**
* Either there is no call to the function, there is a call but don't handle it, or something about decorators?.
*/
UNDETERMINABLE_TENSOR_PARAMETER(10);

static {
// check that the codes are unique.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import edu.cuny.citytech.refactoring.common.core.RefactoringProcessor;
import edu.cuny.citytech.refactoring.common.core.TimeCollector;
import edu.cuny.hunter.hybridize.core.analysis.CantComputeRecursionException;
import edu.cuny.hunter.hybridize.core.analysis.CantInferTensorParametersException;
import edu.cuny.hunter.hybridize.core.analysis.Function;
import edu.cuny.hunter.hybridize.core.analysis.FunctionDefinition;
import edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure;
Expand Down Expand Up @@ -250,14 +251,19 @@ private RefactoringStatus checkFunctions(IProgressMonitor monitor) throws Operat
}

try {
func.inferTensorTensorParameters(analysis, subMonitor.split(IProgressMonitor.UNKNOWN));
func.inferTensorTensorParameters(analysis, callGraph, subMonitor.split(IProgressMonitor.UNKNOWN));
} catch (BadLocationException e) {
throw new RuntimeException("Could not infer tensor parameters for: " + func + ".", e);
} catch (CantInferTensorParametersException e) {
LOG.warn("Unable to compute whether " + func + " has tensor parameters.", e);
func.addFailure(PreconditionFailure.UNDETERMINABLE_TENSOR_PARAMETER,
"Can't infer tensor parameters for this function.");
}

// Check Python side-effects.
try {
if (this.getAlwaysCheckPythonSideEffects() || func.getIsHybrid() || func.getLikelyHasTensorParameter())
if (this.getAlwaysCheckPythonSideEffects() || func.getIsHybrid()
|| func.getLikelyHasTensorParameter() != null && func.getLikelyHasTensorParameter())
func.inferPythonSideEffects(callGraph, builder.getPointerAnalysis());
} catch (UndeterminablePythonSideEffectsException e) {
LOG.warn("Unable to infer side-effects of: " + func + ".", e);
Expand All @@ -270,7 +276,7 @@ private RefactoringStatus checkFunctions(IProgressMonitor monitor) throws Operat
// NOTE: Whether a hybrid function is recursive is irrelevant; if the function has no tensor parameter, de-hybridizing
// it does not violate semantics preservation as potential retracing happens regardless. We do, however, issue a
// refactoring warning when a hybrid function with a tensor parameter is recursive.
if (this.getAlwaysCheckRecursion() || func.getLikelyHasTensorParameter())
if (this.getAlwaysCheckRecursion() || func.getLikelyHasTensorParameter() != null && func.getLikelyHasTensorParameter())
func.computeRecursion(callGraph);
} catch (CantComputeRecursionException e) {
LOG.warn("Unable to compute whether " + this + " is recursive.", e);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# From https://www.tensorflow.org/guide/function#usage.

import tensorflow as tf


@tf.function # The decorator converts `add` into a `PolymorphicFunction`.
def add(a, b):
return a + b


add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# From https://www.tensorflow.org/guide/function#usage.

import tensorflow as tf


@tf.function # The decorator converts `add` into a `PolymorphicFunction`.
def add(a, b):
return a + b

# add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package edu.cuny.hunter.hybridize.tests;

import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.CANT_APPROXIMATE_RECURSION;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_NO_TENSOR_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.HAS_TENSOR_PARAMETERS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.IS_RECURSIVE;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.UNDETERMINABLE_SIDE_EFFECTS;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure.UNDETERMINABLE_TENSOR_PARAMETER;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P1;
import static edu.cuny.hunter.hybridize.core.analysis.PreconditionSuccess.P2;
import static edu.cuny.hunter.hybridize.core.analysis.Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID;
Expand Down Expand Up @@ -640,7 +643,18 @@ public void testAmbiguousDefinition() throws Exception {
for (Function function : functions) {
assertNotNull(function);
assertFalse(function.getIsHybrid());
assertFalse(function.getLikelyHasTensorParameter());

switch (function.getIdentifier()) {
case "Test.value":
case "Test.name":
assertNull(function.getLikelyHasTensorParameter());
break;
case "Test.__init__":
assertFalse(function.getLikelyHasTensorParameter());
break;
default:
throw new IllegalStateException("Unknown function: " + function + ".");
}

switch (function.getIdentifier()) {
case "Test.value":
Expand Down Expand Up @@ -4395,6 +4409,40 @@ public void testHasLikelyTensorParameter147() throws Exception {
testHasLikelyTensorParameterHelper(false, false);
}

/**
* Test for https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/294. Control case.
*/
@Test
public void testHasLikelyTensorParameter148() throws Exception {
Function function = getFunction("add");

assertTrue(function.getIsHybrid());
assertTrue(function.getLikelyHasTensorParameter());
assertEquals(OPTIMIZE_HYBRID_FUNCTION, function.getRefactoring());
assertNull(function.getPassingPrecondition());
assertTrue(function.getTransformations().isEmpty());
assertTrue(function.getStatus().hasError());
assertNotNull(function.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS));
}

/**
* Test for https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/294. No call.
*/
@Test
public void testHasLikelyTensorParameter149() throws Exception {
Function function = getFunction("add");

assertTrue(function.getIsHybrid());
assertNull(function.getLikelyHasTensorParameter());
assertEquals(OPTIMIZE_HYBRID_FUNCTION, function.getRefactoring());
assertNull(function.getPassingPrecondition());
assertTrue(function.getTransformations().isEmpty());
assertTrue(function.getStatus().hasError());
assertNotNull(function.getEntryMatchingFailure(UNDETERMINABLE_SIDE_EFFECTS));
assertNotNull(function.getEntryMatchingFailure(CANT_APPROXIMATE_RECURSION));
assertNotNull(function.getEntryMatchingFailure(UNDETERMINABLE_TENSOR_PARAMETER));
}

// TODO: Test arbitrary expression.
// TODO: Test cast/assert statements?
// TODO: https://www.tensorflow.org/guide/function#pass_tensors_instead_of_python_literals. How do we deal with union types? Do we want
Expand Down Expand Up @@ -4570,7 +4618,7 @@ public void testModel5() throws Exception {
break;
case "call":
// NOTE: Change to assertTrue once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/229 is fixed.
assertFalse("Expecting " + simpleName + " not to have a tensor param.", f.getLikelyHasTensorParameter());
assertNull("Expecting " + simpleName + " not to have a tensor param.", f.getLikelyHasTensorParameter());
// Can't infer side-effects here because there's no invocation of this method.
checkSideEffectStatus(f);
break;
Expand Down Expand Up @@ -4606,7 +4654,7 @@ public void testModel6() throws Exception {
break;
case "__call__":
// NOTE: Change to assertTrue once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/229 is fixed.
assertFalse("Expecting " + simpleName + " not to have a tensor param.", f.getLikelyHasTensorParameter());
assertNull("Expecting " + simpleName + " not to have a tensor param.", f.getLikelyHasTensorParameter());
// No invocation, so we won't be able to infer side-effects.
checkSideEffectStatus(f);
break;
Expand Down