Skip to content

Commit

Permalink
Revert "Revert "Add code and tests for non-tensor parameters.""
Browse files Browse the repository at this point in the history
This reverts commit b4c8139.
  • Loading branch information
khatchad committed Dec 12, 2023
1 parent 9474666 commit eaccb46
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ public boolean getReduceRetracingParamExists() {
*/
private Boolean likelyHasTensorParameter;

/**
* True iff this {@link Function} has at least one parameter that is not likely a tensor.
*/
private Boolean likelyHasNonTensorParameters;

/**
* True iff this {@link Function} has Python side-effects.
*/
Expand Down Expand Up @@ -638,6 +643,7 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c

if (params != null) {
exprType[] actualParams = params.args; // FIXME: Looks like we are only considering position parameters here.

if (actualParams != null) {
for (int paramInx = 0; paramInx < actualParams.length; paramInx++) {
exprType paramExpr = actualParams[paramInx];
Expand Down Expand Up @@ -699,6 +705,13 @@ && tensorAnalysisIncludesParameterContainer(analysis, paramInx, callGraph,
continue; // next parameter.
}

// if there is at least one parameter and we haven't found a tensor parameter.
if (actualParams.length > 0) {
// then we must have encountered a "non-tensor" parameter.
this.likelyHasNonTensorParameters = TRUE;
LOG.info(this + " likely has a non-tensor parameter: " + paramName);
}

monitor.worked(1);
}
}
Expand All @@ -709,6 +722,11 @@ && tensorAnalysisIncludesParameterContainer(analysis, paramInx, callGraph,
LOG.info(this + " does not likely have a tensor parameter.");
}

if (this.likelyHasNonTensorParameters == null) {
this.likelyHasNonTensorParameters = FALSE;
LOG.info(this + " does not likely have a non-tensor parameter.");
}

monitor.done();
}

Expand Down Expand Up @@ -1352,4 +1370,13 @@ public Set<RefactoringStatusEntry> getWarnings() {
public Set<RefactoringStatusEntry> getErrors() {
return this.getRefactoringStatusEntries(RefactoringStatusEntry::isError);
}

/**
* Returns true iff this {@link Function} has at least one parameter that is likely not a tensor.
*
* @return True iff this {@link Function} has at least one parameter that is likely not a tensor.
*/
public Boolean getLikelyHasNonTensorParameters() {
return likelyHasNonTensorParameters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def f():
pass


f()
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def f(a):
pass


f(5)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import tensorflow as tf


def f(a):
pass


f(tf.constant(5))
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import tensorflow as tf


def f(a, b):
pass


f(tf.constant(5), 5)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
Expand Up @@ -5833,4 +5833,32 @@ public void testCallback() throws Exception {
Function f = getFunction("replica_fn");
assertTrue(f.getLikelyHasTensorParameter());
}

@Test
public void testLikelyHasNonTensorParameter() throws Exception {
Function f = getFunction("f");
assertFalse("This function has no parameters.", f.getLikelyHasTensorParameter());
assertFalse("This function has no parameters.", f.getLikelyHasNonTensorParameters());
}

@Test
public void testLikelyHasNonTensorParameter2() throws Exception {
Function f = getFunction("f");
assertFalse("This function has one parameter.", f.getLikelyHasTensorParameter());
assertTrue("This function has one parameter.", f.getLikelyHasNonTensorParameters());
}

@Test
public void testLikelyHasNonTensorParameter3() throws Exception {
Function f = getFunction("f");
assertTrue("This function has one (tensor) parameter.", f.getLikelyHasTensorParameter());
assertFalse("This function has one (tensor) parameter.", f.getLikelyHasNonTensorParameters());
}

@Test
public void testLikelyHasNonTensorParameter4() throws Exception {
Function f = getFunction("f");
assertTrue("This function has one tensor parameter and one non-tensor parameter.", f.getLikelyHasTensorParameter());
assertTrue("This function has one tensor parameter and one non-tensor parameter.", f.getLikelyHasNonTensorParameters());
}
}

0 comments on commit eaccb46

Please sign in to comment.