Skip to content

Commit

Permalink
Progress.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Dec 7, 2023
1 parent b4955aa commit e98c862
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package edu.cuny.hunter.hybridize.core.analysis;

public class CantInferTensorParametersException extends Exception {

public CantInferTensorParametersException(String message) {
super(message);
}

private static final long serialVersionUID = 4856240199804486169L;

}
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,9 @@ public TypeReference getDeclaringClass() {
return TypeReference.findOrCreate(PythonTypes.pythonLoader, typeName);
}

public void inferTensorTensorParameters(TensorTypeAnalysis analysis, IProgressMonitor monitor) throws BadLocationException {
public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph callGraph, IProgressMonitor monitor)
throws BadLocationException, CantInferTensorParametersException {
monitor.beginTask("Analyzing whether function has a tensor parameter.", IProgressMonitor.UNKNOWN);
// TODO: What if there are no current calls to the function? How will we determine its type?
// TODO: Use cast/assert statements?
FunctionDef functionDef = this.getFunctionDefinition().getFunctionDef();
argumentsType params = functionDef.args;
Expand Down Expand Up @@ -686,6 +686,14 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, IProgressMo
}
}

// Is this function in the call graph?
Set<CGNode> nodes = this.getNodes(callGraph);

if (nodes.isEmpty())
// if there are no nodes representing this function, then it most likely isn't called.
throw new CantInferTensorParametersException(
"Can't infer tensor parameters for " + this + " without a call graph node.");

// Check the tensor type analysis. Check that the methods are the same, the parameters, and so on. If we match the
// pointer key, then we know it's a tensor if the TensorType is not null.
for (Pair<PointerKey, TensorVariable> pair : analysis) {
Expand Down Expand Up @@ -900,7 +908,7 @@ public void check() {
if (!this.getIsHybrid()) { // Eager. Table 1.
this.setRefactoring(CONVERT_EAGER_FUNCTION_TO_HYBRID);

if (this.getLikelyHasTensorParameter()) {
if (this.getLikelyHasTensorParameter() != null && this.getLikelyHasTensorParameter()) {
if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) {
if (this.getIsRecursive() != null && !this.getIsRecursive()) {
this.addTransformation(Transformation.CONVERT_TO_HYBRID);
Expand All @@ -926,7 +934,7 @@ public void check() {
} else { // Hybrid. Use table 2.
this.setRefactoring(OPTIMIZE_HYBRID_FUNCTION);

if (!this.getLikelyHasTensorParameter()) {
if (this.getLikelyHasTensorParameter() != null && !this.getLikelyHasTensorParameter()) {
if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) {
this.addTransformation(CONVERT_TO_EAGER);
this.setPassingPrecondition(P2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public enum PreconditionFailure {
CANT_APPROXIMATE_RECURSION(9),

/**
* Either there is no call to the function, there is a call but don't handle it, or something about decorators?.
* Either there is no call to the function, there is a call but don't handle it, or something about decorators?
*/
UNDETERMINABLE_TENSOR_PARAMETER(10);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import edu.cuny.citytech.refactoring.common.core.RefactoringProcessor;
import edu.cuny.citytech.refactoring.common.core.TimeCollector;
import edu.cuny.hunter.hybridize.core.analysis.CantComputeRecursionException;
import edu.cuny.hunter.hybridize.core.analysis.CantInferTensorParametersException;
import edu.cuny.hunter.hybridize.core.analysis.Function;
import edu.cuny.hunter.hybridize.core.analysis.FunctionDefinition;
import edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure;
Expand Down Expand Up @@ -250,14 +251,19 @@ private RefactoringStatus checkFunctions(IProgressMonitor monitor) throws Operat
}

try {
func.inferTensorTensorParameters(analysis, subMonitor.split(IProgressMonitor.UNKNOWN));
func.inferTensorTensorParameters(analysis, callGraph, subMonitor.split(IProgressMonitor.UNKNOWN));
} catch (BadLocationException e) {
throw new RuntimeException("Could not infer tensor parameters for: " + func + ".", e);
} catch (CantInferTensorParametersException e) {
LOG.warn("Unable to compute whether " + func + " has tensor parameters.", e);
func.addFailure(PreconditionFailure.UNDETERMINABLE_TENSOR_PARAMETER,
"Can't infer tensor parameters for this function.");
}

// Check Python side-effects.
try {
if (this.getAlwaysCheckPythonSideEffects() || func.getIsHybrid() || func.getLikelyHasTensorParameter())
if (this.getAlwaysCheckPythonSideEffects() || func.getIsHybrid()
|| func.getLikelyHasTensorParameter() != null && func.getLikelyHasTensorParameter())
func.inferPythonSideEffects(callGraph, builder.getPointerAnalysis());
} catch (UndeterminablePythonSideEffectsException e) {
LOG.warn("Unable to infer side-effects of: " + func + ".", e);
Expand All @@ -270,7 +276,7 @@ private RefactoringStatus checkFunctions(IProgressMonitor monitor) throws Operat
// NOTE: Whether a hybrid function is recursive is irrelevant; if the function has no tensor parameter, de-hybridizing
// it does not violate semantics preservation as potential retracing happens regardless. We do, however, issue a
// refactoring warning when a hybrid function with a tensor parameter is recursive.
if (this.getAlwaysCheckRecursion() || func.getLikelyHasTensorParameter())
if (this.getAlwaysCheckRecursion() || func.getLikelyHasTensorParameter() != null && func.getLikelyHasTensorParameter())
func.computeRecursion(callGraph);
} catch (CantComputeRecursionException e) {
LOG.warn("Unable to compute whether " + this + " is recursive.", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,18 @@ public void testAmbiguousDefinition() throws Exception {
for (Function function : functions) {
assertNotNull(function);
assertFalse(function.getIsHybrid());
assertFalse(function.getLikelyHasTensorParameter());

switch (function.getIdentifier()) {
case "Test.value":
case "Test.name":
assertNull(function.getLikelyHasTensorParameter());
break;
case "Test.__init__":
assertFalse(function.getLikelyHasTensorParameter());
break;
default:
throw new IllegalStateException("Unknown function: " + function + ".");
}

switch (function.getIdentifier()) {
case "Test.value":
Expand Down Expand Up @@ -4607,7 +4618,7 @@ public void testModel5() throws Exception {
break;
case "call":
// NOTE: Change to assertTrue once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/229 is fixed.
assertFalse("Expecting " + simpleName + " not to have a tensor param.", f.getLikelyHasTensorParameter());
assertNull("Expecting " + simpleName + " not to have a tensor param.", f.getLikelyHasTensorParameter());
// Can't infer side-effects here because there's no invocation of this method.
checkSideEffectStatus(f);
break;
Expand Down Expand Up @@ -4643,7 +4654,7 @@ public void testModel6() throws Exception {
break;
case "__call__":
// NOTE: Change to assertTrue once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/229 is fixed.
assertFalse("Expecting " + simpleName + " not to have a tensor param.", f.getLikelyHasTensorParameter());
assertNull("Expecting " + simpleName + " not to have a tensor param.", f.getLikelyHasTensorParameter());
// No invocation, so we won't be able to infer side-effects.
checkSideEffectStatus(f);
break;
Expand Down

0 comments on commit e98c862

Please sign in to comment.