From 6866aab7af8e7898f20fa5fde5e3a39bac449e0b Mon Sep 17 00:00:00 2001 From: Tatiana Castro Velez Date: Mon, 4 Dec 2023 18:50:24 -0500 Subject: [PATCH] Adding callback test --- .../testHasLikelyTensorParameter148/in/A.py | 16 +++++++++++ .../in/requirements.txt | 1 + .../HybridizeFunctionRefactoringTest.java | 27 +++++++++++++++++++ 3 files changed, 44 insertions(+) create mode 100644 edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter148/in/A.py create mode 100644 edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter148/in/requirements.txt diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter148/in/A.py b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter148/in/A.py new file mode 100644 index 000000000..d5a43b03f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter148/in/A.py @@ -0,0 +1,16 @@ +# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/distribute/Strategy#example_usage_2. + +import tensorflow as tf + +strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) +tensor_input = tf.constant(3.0) + + +@tf.function +def replica_fn(input): + return input * 2.0 + + +# Indirect call to replica_fun(). +result = strategy.run(replica_fn, (tensor_input,)) +print(result) diff --git a/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter148/in/requirements.txt b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter148/in/requirements.txt new file mode 100644 index 000000000..b154f958f --- /dev/null +++ b/edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testHasLikelyTensorParameter148/in/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.9.3 diff --git a/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java b/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java index 0fdb37cd3..c51a3abfe 100644 --- a/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java +++ b/edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java @@ -4395,6 +4395,33 @@ public void testHasLikelyTensorParameter147() throws Exception { testHasLikelyTensorParameterHelper(false, false); } + /** + * Test for https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/280. + */ + @Test + public void testHasLikelyTensorParameter148() throws Exception { + Set functions = this.getFunctions(); + assertNotNull(functions); + assertEquals(1, functions.size()); + Function function = functions.iterator().next(); + assertNotNull(function); + assertEquals(true, function.getIsHybrid()); + + argumentsType params = function.getParameters(); + + // two params. + exprType[] actualParams = params.args; + assertEquals(1, actualParams.length); + + exprType actualParameter = actualParams[0]; + assertNotNull(actualParameter); + + String paramName = NodeUtils.getRepresentationString(actualParameter); + assertEquals("input", paramName); + + assertTrue("Expecting function with likely tensor parameter.", function.getLikelyHasTensorParameter()); + } + // TODO: Test arbitrary expression. // TODO: Test cast/assert statements? // TODO: https://www.tensorflow.org/guide/function#pass_tensors_instead_of_python_literals. How do we deal with union types? Do we want