From e5ae9dc25b824cf4c067220f96b2a4e46b181240 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Mon, 11 Dec 2023 22:47:24 -0500 Subject: [PATCH] Add stuff back. --- .../hybridize/core/analysis/Function.java | 15 ++----- .../HybridizeFunctionRefactoringTest.java | 45 +++++++++++++++++++ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java index 1fb15b5e..d4e977c1 100644 --- a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java +++ b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java @@ -696,20 +696,13 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c } // Check for containers of tensors. - if ((this.likelyHasTensorParameter == null || this.likelyHasTensorParameter == FALSE) - && tensorAnalysisIncludesParameterContainer(analysis, paramInx, callGraph, - monitor.slice(IProgressMonitor.UNKNOWN))) { + if (this.likelyHasTensorParameter == null && tensorAnalysisIncludesParameterContainer(analysis, paramInx, callGraph, + monitor.slice(IProgressMonitor.UNKNOWN))) { this.likelyHasTensorParameter = Boolean.TRUE; LOG.info(this + " likely has a tensor-like parameter: " + paramName + " due to tensor analysis."); - monitor.worked(1); - continue; // next parameter. - } - - // if there is at least one parameter and we haven't found a tensor parameter. - if (actualParams.length > 0) { - // then we must have encountered a "non-tensor" parameter. + } else if (this.likelyHasNonTensorParameters == null) { this.likelyHasNonTensorParameters = TRUE; - LOG.info(this + " likely has a non-tensor parameter: " + paramName); + LOG.info(this + " likely has a non-tensor-like parameter: " + paramName + " due to tensor analysis."); } monitor.worked(1); diff --git a/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java b/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java index dd6a184b..880c2359 100644 --- a/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java +++ b/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java @@ -5861,4 +5861,49 @@ public void testLikelyHasNonTensorParameter4() throws Exception { assertTrue("This function has one tensor parameter and one non-tensor parameter.", f.getLikelyHasTensorParameter()); assertTrue("This function has one tensor parameter and one non-tensor parameter.", f.getLikelyHasNonTensorParameters()); } + + @Test + public void testLikelyHasNonTensorParameter5() throws Exception { + Function f = getFunction("f"); + assertTrue("This function has one parameter with one tensor argument and one non-tensor argument.", + f.getLikelyHasTensorParameter()); + assertTrue("This function has one parameter with one tensor argument and one non-tensor argument.", + f.getLikelyHasNonTensorParameters()); + } + + @Test + public void testLikelyHasNonTensorParameter6() throws Exception { + Function f = getFunction("f"); + assertTrue("This function has one parameter with one tensor argument and one non-tensor argument.", + f.getLikelyHasTensorParameter()); + assertTrue("This function has one parameter with one tensor argument and one non-tensor argument.", + f.getLikelyHasNonTensorParameters()); + } + + @Test + public void testLikelyHasNonTensorParameter7() throws Exception { + Function f = getFunction("f"); + assertTrue(f.getLikelyHasTensorParameter()); + assertTrue(f.getLikelyHasNonTensorParameters()); + } + + /** + * Test https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/function#retracing, + */ + @Test + public void testRetracing() throws Exception { + Function f = getFunction("f"); + assertTrue(f.getLikelyHasTensorParameter()); + // TODO. + } + + /** + * Test https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/function#retracing, + */ + @Test + public void testRetracing2() throws Exception { + Function f = getFunction("f"); + assertTrue(f.getLikelyHasTensorParameter()); + // TODO. + } }