From 45687e7b3291e2cbbb69c74e147dec86a4a14623 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Fri, 8 Dec 2023 17:35:00 -0500 Subject: [PATCH] Handle containers in type hints. --- .../hybridize/core/analysis/Function.java | 87 +++++++++++++------ ...HybridizeFunctionRefactoringProcessor.java | 4 +- .../testHasLikelyTensorParameter156/in/A.py | 14 +++ .../in/requirements.txt | 1 + .../HybridizeFunctionRefactoringTest.java | 16 ++++ 5 files changed, 92 insertions(+), 30 deletions(-) create mode 100644 edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter156/in/A.py create mode 100644 edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter156/in/requirements.txt diff --git a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java index 7b34e7353..26004a445 100644 --- a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java +++ b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java @@ -32,10 +32,12 @@ import org.osgi.framework.FrameworkUtil; import org.python.pydev.core.IPythonNature; import org.python.pydev.core.docutils.PySelection; +import org.python.pydev.parser.jython.SimpleNode; import org.python.pydev.parser.jython.ast.Attribute; import org.python.pydev.parser.jython.ast.Call; import org.python.pydev.parser.jython.ast.FunctionDef; import org.python.pydev.parser.jython.ast.NameTok; +import org.python.pydev.parser.jython.ast.VisitorBase; import org.python.pydev.parser.jython.ast.argumentsType; import org.python.pydev.parser.jython.ast.decoratorsType; import org.python.pydev.parser.jython.ast.exprType; @@ -627,8 +629,7 @@ public TypeReference getDeclaringClass() { return TypeReference.findOrCreate(PythonTypes.pythonLoader, typeName); } - public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph callGraph, IProgressMonitor monitor) - throws BadLocationException, CantInferTensorParametersException { + public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph callGraph, IProgressMonitor monitor) throws Exception { monitor.beginTask("Analyzing whether function has a tensor parameter.", IProgressMonitor.UNKNOWN); // TODO: Use cast/assert statements? FunctionDef functionDef = this.getFunctionDefinition().getFunctionDef(); @@ -661,34 +662,36 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c LOG.info("Found type for parameter " + paramName + " in " + this + ": " + argTypeInfo.getActTok() + "."); exprType node = argTypeInfo.getNode(); - Attribute typeHintExpr = (Attribute) node; - - // Look up the definition. - IDocument document = this.getContainingDocument(); - PySelection selection = Util.getSelection(typeHintExpr.attr, document); - - String fqn; - try { - fqn = Util.getFullyQualifiedName(typeHintExpr, containingModuleName, containingFile, selection, - this.getNature(), monitor); - } catch (AmbiguousDeclaringModuleException e) { - LOG.warn(String.format( - "Can't determine FQN of type hint expression: %s in selection: %s, module: %s, file: %s, and project: %s.", - typeHintExpr, selection.getSelectedText(), containingModuleName, containingFile.getName(), - this.getProject()), e); - - monitor.worked(1); - continue; // next parameter. - } + Set allAttributes = getAllAttributes(node); + + for (Attribute typeHintExpr : allAttributes) { + // Look up the definition. + IDocument document = this.getContainingDocument(); + PySelection selection = Util.getSelection(typeHintExpr.attr, document); + + String fqn; + try { + fqn = Util.getFullyQualifiedName(typeHintExpr, containingModuleName, containingFile, selection, + this.getNature(), monitor); + } catch (AmbiguousDeclaringModuleException e) { + LOG.warn(String.format( + "Can't determine FQN of type hint expression: %s in selection: %s, module: %s, file: %s, and project: %s.", + typeHintExpr, selection.getSelectedText(), containingModuleName, containingFile.getName(), + this.getProject()), e); + + monitor.worked(1); + continue; // next parameter. + } - LOG.info("Found FQN: " + fqn + "."); + LOG.info("Found FQN: " + fqn + "."); - if (fqn.equals(TF_TENSOR_FQN)) { // TODO: Also check for subtypes. - // TODO: Also check for tensor-like stuff. - this.likelyHasTensorParameter = Boolean.TRUE; - LOG.info(this + " likely has a tensor parameter due to a type hint."); - monitor.done(); - return; + if (fqn.equals(TF_TENSOR_FQN)) { // TODO: Also check for subtypes. + // TODO: Also check for tensor-like stuff. + this.likelyHasTensorParameter = Boolean.TRUE; + LOG.info(this + " likely has a tensor parameter due to a type hint."); + monitor.done(); + return; + } } } } @@ -794,6 +797,34 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c monitor.done(); } + private static Set getAllAttributes(exprType node) throws Exception { + Set ret = Sets.newHashSet(); + + if (node instanceof Attribute) + ret.add((Attribute) node); + + node.traverse(new VisitorBase() { + + @Override + public Object visitAttribute(Attribute node) throws Exception { + ret.add(node); + return super.visitAttribute(node); + } + + @Override + protected Object unhandled_node(SimpleNode node) throws Exception { + return null; + } + + @Override + public void traverse(SimpleNode node) throws Exception { + node.traverse(this); + } + }); + + return ret; + } + private static Set getAllNewSiteReferences(int use, DefUse du) { return getAllNewSiteReferences(use, du, new HashSet<>()); } diff --git a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/refactorings/HybridizeFunctionRefactoringProcessor.java b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/refactorings/HybridizeFunctionRefactoringProcessor.java index 8d8936370..477ee7d77 100644 --- a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/refactorings/HybridizeFunctionRefactoringProcessor.java +++ b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/refactorings/HybridizeFunctionRefactoringProcessor.java @@ -252,12 +252,12 @@ private RefactoringStatus checkFunctions(IProgressMonitor monitor) throws Operat try { func.inferTensorTensorParameters(analysis, callGraph, subMonitor.split(IProgressMonitor.UNKNOWN)); - } catch (BadLocationException e) { - throw new RuntimeException("Could not infer tensor parameters for: " + func + ".", e); } catch (CantInferTensorParametersException e) { LOG.warn("Unable to compute whether " + func + " has tensor parameters.", e); func.addFailure(PreconditionFailure.UNDETERMINABLE_TENSOR_PARAMETER, "Can't infer tensor parameters for this function."); + } catch (Exception e) { + throw new RuntimeException("Could not infer tensor parameters for: " + func + ".", e); } // Check Python side-effects. diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter156/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter156/in/A.py new file mode 100644 index 000000000..f628931e9 --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter156/in/A.py @@ -0,0 +1,14 @@ +# From https://www.tensorflow.org/guide/function#usage. + +import tensorflow as tf + + +@tf.function(experimental_follow_type_hints=True) +def add(t: tuple[tf.Tensor, tf.Tensor]): + return t[0] + t[1] + + +arg = (2, 2) +assert type(arg) == tuple +result = add(arg) +print(result) diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter156/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter156/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter156/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java b/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java index 3bd247fd9..2e737f765 100644 --- a/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java +++ b/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java @@ -4540,6 +4540,22 @@ public void testHasLikelyTensorParameter155() throws Exception { assertNotNull(function.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS)); } + /** + * Test for https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/283. + */ + @Test + public void testHasLikelyTensorParameter156() throws Exception { + Function function = getFunction("add"); + + assertTrue(function.getIsHybrid()); + assertTrue(function.getLikelyHasTensorParameter()); + assertEquals(OPTIMIZE_HYBRID_FUNCTION, function.getRefactoring()); + assertNull(function.getPassingPrecondition()); + assertTrue(function.getTransformations().isEmpty()); + assertTrue(function.getStatus().hasError()); + assertNotNull(function.getEntryMatchingFailure(HAS_TENSOR_PARAMETERS)); + } + // TODO: Test arbitrary expression. // TODO: Test cast/assert statements? // TODO: https://www.tensorflow.org/guide/function#pass_tensors_instead_of_python_literals. How do we deal with union types? Do we want