Skip to content

Commit

Permalink
Handle containers in type hints.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Dec 8, 2023
1 parent cfcf75b commit 45687e7
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@
import org.osgi.framework.FrameworkUtil;
import org.python.pydev.core.IPythonNature;
import org.python.pydev.core.docutils.PySelection;
import org.python.pydev.parser.jython.SimpleNode;
import org.python.pydev.parser.jython.ast.Attribute;
import org.python.pydev.parser.jython.ast.Call;
import org.python.pydev.parser.jython.ast.FunctionDef;
import org.python.pydev.parser.jython.ast.NameTok;
import org.python.pydev.parser.jython.ast.VisitorBase;
import org.python.pydev.parser.jython.ast.argumentsType;
import org.python.pydev.parser.jython.ast.decoratorsType;
import org.python.pydev.parser.jython.ast.exprType;
Expand Down Expand Up @@ -627,8 +629,7 @@ public TypeReference getDeclaringClass() {
return TypeReference.findOrCreate(PythonTypes.pythonLoader, typeName);
}

public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph callGraph, IProgressMonitor monitor)
throws BadLocationException, CantInferTensorParametersException {
public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph callGraph, IProgressMonitor monitor) throws Exception {
monitor.beginTask("Analyzing whether function has a tensor parameter.", IProgressMonitor.UNKNOWN);
// TODO: Use cast/assert statements?
FunctionDef functionDef = this.getFunctionDefinition().getFunctionDef();
Expand Down Expand Up @@ -661,34 +662,36 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c
LOG.info("Found type for parameter " + paramName + " in " + this + ": " + argTypeInfo.getActTok() + ".");

exprType node = argTypeInfo.getNode();
Attribute typeHintExpr = (Attribute) node;

// Look up the definition.
IDocument document = this.getContainingDocument();
PySelection selection = Util.getSelection(typeHintExpr.attr, document);

String fqn;
try {
fqn = Util.getFullyQualifiedName(typeHintExpr, containingModuleName, containingFile, selection,
this.getNature(), monitor);
} catch (AmbiguousDeclaringModuleException e) {
LOG.warn(String.format(
"Can't determine FQN of type hint expression: %s in selection: %s, module: %s, file: %s, and project: %s.",
typeHintExpr, selection.getSelectedText(), containingModuleName, containingFile.getName(),
this.getProject()), e);

monitor.worked(1);
continue; // next parameter.
}
Set<Attribute> allAttributes = getAllAttributes(node);

for (Attribute typeHintExpr : allAttributes) {
// Look up the definition.
IDocument document = this.getContainingDocument();
PySelection selection = Util.getSelection(typeHintExpr.attr, document);

String fqn;
try {
fqn = Util.getFullyQualifiedName(typeHintExpr, containingModuleName, containingFile, selection,
this.getNature(), monitor);
} catch (AmbiguousDeclaringModuleException e) {
LOG.warn(String.format(
"Can't determine FQN of type hint expression: %s in selection: %s, module: %s, file: %s, and project: %s.",
typeHintExpr, selection.getSelectedText(), containingModuleName, containingFile.getName(),
this.getProject()), e);

monitor.worked(1);
continue; // next parameter.
}

LOG.info("Found FQN: " + fqn + ".");
LOG.info("Found FQN: " + fqn + ".");

if (fqn.equals(TF_TENSOR_FQN)) { // TODO: Also check for subtypes.
// TODO: Also check for tensor-like stuff.
this.likelyHasTensorParameter = Boolean.TRUE;
LOG.info(this + " likely has a tensor parameter due to a type hint.");
monitor.done();
return;
if (fqn.equals(TF_TENSOR_FQN)) { // TODO: Also check for subtypes.
// TODO: Also check for tensor-like stuff.
this.likelyHasTensorParameter = Boolean.TRUE;
LOG.info(this + " likely has a tensor parameter due to a type hint.");
monitor.done();
return;
}
}
}
}
Expand Down Expand Up @@ -794,6 +797,34 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c
monitor.done();
}

private static Set<Attribute> getAllAttributes(exprType node) throws Exception {
Set<Attribute> ret = Sets.newHashSet();

if (node instanceof Attribute)
ret.add((Attribute) node);

node.traverse(new VisitorBase() {

@Override
public Object visitAttribute(Attribute node) throws Exception {
ret.add(node);
return super.visitAttribute(node);
}

@Override
protected Object unhandled_node(SimpleNode node) throws Exception {
return null;
}

@Override
public void traverse(SimpleNode node) throws Exception {
node.traverse(this);
}
});

return ret;
}

private static Set<NewSiteReference> getAllNewSiteReferences(int use, DefUse du) {
return getAllNewSiteReferences(use, du, new HashSet<>());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,12 @@ private RefactoringStatus checkFunctions(IProgressMonitor monitor) throws Operat

try {
func.inferTensorTensorParameters(analysis, callGraph, subMonitor.split(IProgressMonitor.UNKNOWN));
} catch (BadLocationException e) {
throw new RuntimeException("Could not infer tensor parameters for: " + func + ".", e);
} catch (CantInferTensorParametersException e) {
LOG.warn("Unable to compute whether " + func + " has tensor parameters.", e);
func.addFailure(PreconditionFailure.UNDETERMINABLE_TENSOR_PARAMETER,
"Can't infer tensor parameters for this function.");
} catch (Exception e) {
throw new RuntimeException("Could not infer tensor parameters for: " + func + ".", e);
}

// Check Python side-effects.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# From https://www.tensorflow.org/guide/function#usage.

import tensorflow as tf


@tf.function(experimental_follow_type_hints=True)
def add(t: tuple[tf.Tensor, tf.Tensor]):
return t[0] + t[1]


arg = (2, 2)
assert type(arg) == tuple
result = add(arg)
print(result)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
Expand Up @@ -4540,6 +4540,22 @@ public void testHasLikelyTensorParameter155() throws Exception {
assertNotNull(function.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS));
}

/**
* Test for https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/283.
*/
@Test
public void testHasLikelyTensorParameter156() throws Exception {
Function function = getFunction("add");

assertTrue(function.getIsHybrid());
assertTrue(function.getLikelyHasTensorParameter());
assertEquals(OPTIMIZE_HYBRID_FUNCTION, function.getRefactoring());
assertNull(function.getPassingPrecondition());
assertTrue(function.getTransformations().isEmpty());
assertTrue(function.getStatus().hasError());
assertNotNull(function.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS));
}

// TODO: Test arbitrary expression.
// TODO: Test cast/assert statements?
// TODO: https://www.tensorflow.org/guide/function#pass_tensors_instead_of_python_literals. How do we deal with union types? Do we want
Expand Down

0 comments on commit 45687e7

Please sign in to comment.