Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Nov 21, 2023
2 parents a88b9ea + 22bdd33 commit 6b7717c
Show file tree
Hide file tree
Showing 31 changed files with 600 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package edu.cuny.hunter.hybridize.core.analysis;

public class CantComputeRecursionException extends Exception {

private static final long serialVersionUID = 6647237480647044100L;

public CantComputeRecursionException() {
}

public CantComputeRecursionException(String message) {
super(message);
}

public CantComputeRecursionException(Throwable cause) {
super(cause);
}

public CantComputeRecursionException(String message, Throwable cause) {
super(message, cause);
}

public CantComputeRecursionException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
super(message, cause, enableSuppression, writableStackTrace);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.eclipse.core.resources.IProject;
Expand Down Expand Up @@ -356,12 +357,34 @@ public boolean getReduceRetracingParamExists() {

private RefactoringStatus status = new RefactoringStatus();

private Boolean isRecursive;

private static Map<MethodReference, Map<InstanceKey, Map<CallGraph, Boolean>>> creationsCache = Maps.newHashMap();

public Function(FunctionDefinition fd) {
this.functionDefinition = fd;
}

public void computeRecursion(CallGraph callGraph) throws CantComputeRecursionException {
// Get the nodes representing this function.
Set<CGNode> nodes = this.getNodes(callGraph);

if (nodes.isEmpty())
throw new CantComputeRecursionException("Can't compute if " + this + " is recusive without a call graph node.");

CGNode cgNode = nodes.iterator().next();

if (Util.calls(cgNode, this.getMethodReference(), callGraph)) {
// it's recursive.
LOG.info(this + " is recursive.");
this.setIsRecursive(true);
} else {
// not recursive.
LOG.info(this + " is not recursive.");
this.setIsRecursive(false);
}
}

/**
* Infer Python side-effects potentially produced by executing this {@link Function}.
*
Expand Down Expand Up @@ -513,12 +536,8 @@ private static boolean allCreationsWithinClosureInteral2(MethodReference methodR
CGNode next = succNodes.next();
MethodReference reference = next.getMethod().getReference();

if (!seen.contains(reference)) {
seen.add(reference);

if (allCreationsWithinClosureInteral(reference, instanceKey, callGraph, seen))
return true;
}
if (!seen.contains(reference) && allCreationsWithinClosureInteral(reference, instanceKey, callGraph, seen))
return true;
}

return false;
Expand Down Expand Up @@ -883,16 +902,26 @@ public void check() {

if (this.getLikelyHasTensorParameter()) {
if (this.getHasPythonSideEffects() != null && !this.getHasPythonSideEffects()) {
this.addTransformation(Transformation.CONVERT_TO_HYBRID);
this.setPassingPrecondition(P1);
} else if (this.getHasPythonSideEffects() != null) // it has side-effects.
if (this.getIsRecursive() != null && !this.getIsRecursive()) {
this.addTransformation(Transformation.CONVERT_TO_HYBRID);
this.setPassingPrecondition(P1);
} else if (this.getIsRecursive() != null) // it's recursive.
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
} else if (this.getHasPythonSideEffects() != null) { // it has side-effects.
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS, "Can't hybridize a function with Python side-effects.");

if (this.getIsRecursive() != null && this.getIsRecursive())
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
}
} else { // no tensor parameters.
this.addFailure(PreconditionFailure.HAS_NO_TENSOR_PARAMETERS,
"This function has no tensor parameters and may not benefit from hybridization.");

if (this.getHasPythonSideEffects() != null && this.getHasPythonSideEffects())
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS, "Can't hybridize a function with Python side-effects.");

if (this.getIsRecursive() != null && this.getIsRecursive())
this.addFailure(PreconditionFailure.IS_RECURSIVE, "Can't hybridize a recursive function.");
}
} else { // Hybrid. Use table 2.
this.setRefactoring(OPTIMIZE_HYBRID_FUNCTION);
Expand All @@ -912,6 +941,11 @@ public void check() {
this.addFailure(PreconditionFailure.HAS_PYTHON_SIDE_EFFECTS,
"De-hybridizing a function with Python side-effects may alter semantics.");
}

// Here, we have a hybrid function with a tensor parameter.
if (this.getIsRecursive() != null && this.getIsRecursive()) // if it's recursive.
// issue a warning.
this.addWarning("Recursive tf.functions are not supported by TensorFlow.");
}

// Warn if the function has side-effects.
Expand Down Expand Up @@ -1126,4 +1160,23 @@ public static void clearCaches() {
creationsCache.clear();
}

public Boolean getIsRecursive() {
return this.isRecursive;
}

protected void setIsRecursive(Boolean isRecursive) {
this.isRecursive = isRecursive;
}

private Set<RefactoringStatusEntry> getRefactoringStatusEntries(Predicate<? super RefactoringStatusEntry> predicate) {
return Arrays.stream(this.getStatus().getEntries()).filter(predicate).collect(Collectors.toSet());
}

public Set<RefactoringStatusEntry> getWarnings() {
return this.getRefactoringStatusEntries(RefactoringStatusEntry::isWarning);
}

public Set<RefactoringStatusEntry> getErrors() {
return this.getRefactoringStatusEntries(RefactoringStatusEntry::isError);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@ public enum PreconditionFailure {

HAS_NO_TENSOR_PARAMETERS(6),

HAS_TENSOR_PARAMETERS(7);
HAS_TENSOR_PARAMETERS(7),

/**
* Functions that are recursive can't be hybridized. Also, de-hybridizing hybrid recursive functions may alter semantics.
*/
IS_RECURSIVE(8),

/**
* Can't find the CG node corresponding to the function.
*/
CANT_APPROXIMATE_RECURSION(9);

static {
// check that the codes are unique.
if (Arrays.stream(PreconditionFailure.values()).map(PreconditionFailure::getCode).distinct()
.count() != PreconditionFailure.values().length)
throw new IllegalStateException("Codes aren't unique.");
assert Arrays.stream(PreconditionFailure.values()).map(PreconditionFailure::getCode).distinct()
.count() == PreconditionFailure.values().length : "Codes must be unique.";
}

public static void main(String[] args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import java.io.File;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

import org.eclipse.core.runtime.ILog;
Expand All @@ -30,6 +31,11 @@
import org.python.pydev.parser.visitors.NodeUtils;
import org.python.pydev.shared_core.string.CoreTextSelection;

import com.google.common.collect.Sets;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.types.MethodReference;

public class Util {

private static final ILog LOG = getLog(Util.class);
Expand Down Expand Up @@ -248,4 +254,28 @@ public static boolean isBuiltIn(decoratorsType decorator) {
String decoratorRepresentation = NodeUtils.getRepresentationString(decorator.func);
return decoratorRepresentation.equals("property");
}

public static boolean calls(CGNode node, MethodReference methodReference, CallGraph callGraph) {
Set<MethodReference> seen = Sets.newHashSet();
return callsInternal(node, methodReference, callGraph, seen);
}

private static boolean callsInternal(CGNode node, MethodReference methodReference, CallGraph callGraph, Set<MethodReference> seen) {
seen.add(node.getMethod().getReference());

// check the callees.
for (Iterator<CGNode> succNodes = callGraph.getSuccNodes(node); succNodes.hasNext();) {
CGNode next = succNodes.next();
MethodReference reference = next.getMethod().getReference();

if (methodReference.equals(reference))
return true;

// otherwise, check its callees.
if (!seen.contains(reference) && calls(next, methodReference, callGraph))
return true;
}

return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

import edu.cuny.citytech.refactoring.common.core.RefactoringProcessor;
import edu.cuny.citytech.refactoring.common.core.TimeCollector;
import edu.cuny.hunter.hybridize.core.analysis.CantComputeRecursionException;
import edu.cuny.hunter.hybridize.core.analysis.Function;
import edu.cuny.hunter.hybridize.core.analysis.FunctionDefinition;
import edu.cuny.hunter.hybridize.core.analysis.PreconditionFailure;
Expand Down Expand Up @@ -101,6 +102,8 @@ private static RefactoringStatus checkParameters(Function func) {

private boolean alwaysCheckPythonSideEffects;

private boolean alwaysCheckRecursion;

private boolean processFunctionsInParallel = true;

public HybridizeFunctionRefactoringProcessor() {
Expand All @@ -121,6 +124,12 @@ public HybridizeFunctionRefactoringProcessor(boolean alwaysCheckPythonSideEffect
this.processFunctionsInParallel = processFunctionsInParallel;
}

public HybridizeFunctionRefactoringProcessor(boolean alwaysCheckPythonSideEffects, boolean processFunctionsInParallel,
boolean alwaysCheckRecusion) {
this(alwaysCheckPythonSideEffects, processFunctionsInParallel);
this.alwaysCheckRecursion = alwaysCheckRecusion;
}

public HybridizeFunctionRefactoringProcessor(Set<FunctionDefinition> functionDefinitionSet)
throws TooManyMatchesException /* FIXME: This exception sounds too low-level. */ {
this();
Expand Down Expand Up @@ -150,6 +159,13 @@ public HybridizeFunctionRefactoringProcessor(Set<FunctionDefinition> functionDef
this.processFunctionsInParallel = processFunctionsInParallel;
}

public HybridizeFunctionRefactoringProcessor(Set<FunctionDefinition> functionDefinitionSet, boolean alwaysCheckPythonSideEffects,
boolean processFunctionsInParallel, boolean alwaysCheckRecursion)
throws TooManyMatchesException /* FIXME: This exception sounds too low-level. */ {
this(functionDefinitionSet, alwaysCheckPythonSideEffects, processFunctionsInParallel);
this.alwaysCheckRecursion = alwaysCheckRecursion;
}

@Override
public RefactoringStatus checkFinalConditions(IProgressMonitor pm, CheckConditionsContext context)
throws CoreException, OperationCanceledException {
Expand Down Expand Up @@ -249,6 +265,18 @@ private RefactoringStatus checkFunctions(IProgressMonitor monitor) throws Operat
"Can't infer side-effects, most likely due to a call graph issue caused by a decorator or a missing function call.");
}

// Check recursion.
try {
// NOTE: Whether a hybrid function is recursive is irrelevant; if the function has no tensor parameter, de-hybridizing
// it does not violate semantics preservation as potential retracing happens regardless. We do, however, issue a
// refactoring warning when a hybrid function with a tensor parameter is recursive.
if (this.getAlwaysCheckRecursion() || func.getLikelyHasTensorParameter())
func.computeRecursion(callGraph);
} catch (CantComputeRecursionException e) {
LOG.warn("Unable to compute whether " + this + " is recursive.", e);
func.addFailure(PreconditionFailure.CANT_APPROXIMATE_RECURSION, "Can't compute whether this function is recursive.");
}

// check the function preconditions.
func.check();

Expand Down Expand Up @@ -367,6 +395,10 @@ private boolean getAlwaysCheckPythonSideEffects() {
return this.alwaysCheckPythonSideEffects;
}

public boolean getAlwaysCheckRecursion() {
return alwaysCheckRecursion;
}

@Override
public boolean isApplicable() throws CoreException {
// TODO Auto-generated method stub
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ public static Refactoring createRefactoring(Set<FunctionDefinition> functionDefi
}

public static HybridizeFunctionRefactoringProcessor createHybridizeFunctionRefactoring(IProject[] projects,
boolean alwaysCheckPythonSideEffects, boolean processFunctionsInParallel)
boolean alwaysCheckPythonSideEffects, boolean processFunctionsInParallel, boolean alwaysCheckRecusion)
throws ExecutionException, CoreException, IOException {
Set<FunctionDefinition> functionDefinitions = getFunctionDefinitions(Arrays.asList(projects));
return new HybridizeFunctionRefactoringProcessor(functionDefinitions, alwaysCheckPythonSideEffects, processFunctionsInParallel);
return new HybridizeFunctionRefactoringProcessor(functionDefinitions, alwaysCheckPythonSideEffects, processFunctionsInParallel,
alwaysCheckRecusion);
}

public static Set<FunctionDefinition> getFunctionDefinitions(Iterable<?> iterable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public class EvaluateHybridizeFunctionRefactoringHandler extends EvaluateRefacto

private static final String ALWAYS_CHECK_PYTHON_SIDE_EFFECTS_PROPERTY_KEY = "edu.cuny.hunter.hybridize.eval.alwaysCheckPythonSideEffects";

private static final String ALWAYS_CHECK_RECURSION_PROPERTY_KEY = "edu.cuny.hunter.hybridize.eval.alwaysCheckRecursion";

private static final String PROCESS_FUNCTIONS_IN_PARALLEL_PROPERTY_KEY = "edu.cuny.hunter.hybridize.eval.processFunctionsInParallel";

private static String[] buildAttributeColumnNames(String... additionalColumnNames) {
Expand All @@ -95,6 +97,8 @@ private static Object[] buildAttributeColumnValues(Function function, Object...

private boolean alwaysCheckPythonSideEffects = Boolean.getBoolean(ALWAYS_CHECK_PYTHON_SIDE_EFFECTS_PROPERTY_KEY);

private boolean alwaysCheckRecursion = Boolean.getBoolean(ALWAYS_CHECK_RECURSION_PROPERTY_KEY);

private boolean processFunctionsInParallel = Boolean.getBoolean(PROCESS_FUNCTIONS_IN_PARALLEL_PROPERTY_KEY);

@Override
Expand All @@ -116,7 +120,7 @@ public Object execute(ExecutionEvent event) throws ExecutionException {

try (CSVPrinter resultsPrinter = createCSVPrinter(RESULTS_CSV_FILENAME, resultsHeader.toArray(String[]::new));
CSVPrinter functionsPrinter = createCSVPrinter(FUNCTIONS_CSV_FILENAME, buildFunctionAttributeColumnNames());
CSVPrinter candidatesPrinter = createCSVPrinter(CANDIDATES_CSV_FILENAME, buildFunctionAttributeColumnNames());
CSVPrinter candidatesPrinter = createCSVPrinter(CANDIDATES_CSV_FILENAME, buildAttributeColumnNames());
CSVPrinter transformationsPrinter = createCSVPrinter(TRANSFORMATIONS_CSV_FILENAME,
buildAttributeColumnNames("transformation"));
CSVPrinter optimizableFunctionPrinter = createCSVPrinter(OPTMIZABLE_CSV_FILENAME, buildAttributeColumnNames());
Expand All @@ -138,7 +142,7 @@ public Object execute(ExecutionEvent event) throws ExecutionException {

resultsTimeCollector.start();
HybridizeFunctionRefactoringProcessor processor = createHybridizeFunctionRefactoring(new IProject[] { project },
this.getAlwaysCheckPythonSideEffects(), this.getProcessFunctionsInParallel());
this.getAlwaysCheckPythonSideEffects(), this.getProcessFunctionsInParallel(), this.getAlwaysCheckRecusion());
resultsTimeCollector.stop();

// run the precondition checking.
Expand Down Expand Up @@ -166,7 +170,7 @@ public Object execute(ExecutionEvent event) throws ExecutionException {

// candidate functions.
for (Function function : candidates) {
printFunction(candidatesPrinter, function);
candidatesPrinter.printRecord(buildAttributeColumnValues(function));

// transformations.
for (Transformation transformation : function.getTransformations())
Expand Down Expand Up @@ -265,14 +269,14 @@ private static void printStatuses(CSVPrinter printer, Collection<RefactoringStat

private static String[] buildFunctionAttributeColumnNames() {
return buildAttributeColumnNames("method reference", "type reference", "parameters", "tensor parameter", "hybrid", "side-effects",
"autograph", "experimental_autograph_options", "experimental_follow_type_hints", "experimental_implements", "func",
"input_signature", "jit_compile", "reduce_retracing", "refactoring", "passing precondition", "status");
"recursive", "autograph", "experimental_autograph_options", "experimental_follow_type_hints", "experimental_implements",
"func", "input_signature", "jit_compile", "reduce_retracing", "refactoring", "passing precondition", "status");
}

private static void printFunction(CSVPrinter printer, Function function) throws IOException {
Object[] initialColumnValues = buildAttributeColumnValues(function, function.getMethodReference(), function.getDeclaringClass(),
function.getNumberOfParameters(), function.getLikelyHasTensorParameter(), function.getIsHybrid(),
function.getHasPythonSideEffects());
function.getHasPythonSideEffects(), function.getIsRecursive());

for (Object columnValue : initialColumnValues)
printer.print(columnValue);
Expand Down Expand Up @@ -353,6 +357,10 @@ public boolean getAlwaysCheckPythonSideEffects() {
return alwaysCheckPythonSideEffects;
}

private boolean getAlwaysCheckRecusion() {
return alwaysCheckRecursion;
}

private boolean getProcessFunctionsInParallel() {
return this.processFunctionsInParallel;
}
Expand Down
Loading

0 comments on commit 6b7717c

Please sign in to comment.