Skip to content

Commit

Permalink
Add stuff back.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Dec 12, 2023
1 parent eaccb46 commit e5ae9dc
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -696,20 +696,13 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c
}

// Check for containers of tensors.
if ((this.likelyHasTensorParameter == null || this.likelyHasTensorParameter == FALSE)
&& tensorAnalysisIncludesParameterContainer(analysis, paramInx, callGraph,
monitor.slice(IProgressMonitor.UNKNOWN))) {
if (this.likelyHasTensorParameter == null && tensorAnalysisIncludesParameterContainer(analysis, paramInx, callGraph,
monitor.slice(IProgressMonitor.UNKNOWN))) {
this.likelyHasTensorParameter = Boolean.TRUE;
LOG.info(this + " likely has a tensor-like parameter: " + paramName + " due to tensor analysis.");
monitor.worked(1);
continue; // next parameter.
}

// if there is at least one parameter and we haven't found a tensor parameter.
if (actualParams.length > 0) {
// then we must have encountered a "non-tensor" parameter.
} else if (this.likelyHasNonTensorParameters == null) {
this.likelyHasNonTensorParameters = TRUE;
LOG.info(this + " likely has a non-tensor parameter: " + paramName);
LOG.info(this + " likely has a non-tensor-like parameter: " + paramName + " due to tensor analysis.");
}

monitor.worked(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5861,4 +5861,49 @@ public void testLikelyHasNonTensorParameter4() throws Exception {
assertTrue("This function has one tensor parameter and one non-tensor parameter.", f.getLikelyHasTensorParameter());
assertTrue("This function has one tensor parameter and one non-tensor parameter.", f.getLikelyHasNonTensorParameters());
}

@Test
public void testLikelyHasNonTensorParameter5() throws Exception {
Function f = getFunction("f");
assertTrue("This function has one parameter with one tensor argument and one non-tensor argument.",
f.getLikelyHasTensorParameter());
assertTrue("This function has one parameter with one tensor argument and one non-tensor argument.",
f.getLikelyHasNonTensorParameters());
}

@Test
public void testLikelyHasNonTensorParameter6() throws Exception {
Function f = getFunction("f");
assertTrue("This function has one parameter with one tensor argument and one non-tensor argument.",
f.getLikelyHasTensorParameter());
assertTrue("This function has one parameter with one tensor argument and one non-tensor argument.",
f.getLikelyHasNonTensorParameters());
}

@Test
public void testLikelyHasNonTensorParameter7() throws Exception {
Function f = getFunction("f");
assertTrue(f.getLikelyHasTensorParameter());
assertTrue(f.getLikelyHasNonTensorParameters());
}

/**
* Test https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/function#retracing,
*/
@Test
public void testRetracing() throws Exception {
Function f = getFunction("f");
assertTrue(f.getLikelyHasTensorParameter());
// TODO.
}

/**
* Test https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/function#retracing,
*/
@Test
public void testRetracing2() throws Exception {
Function f = getFunction("f");
assertTrue(f.getLikelyHasTensorParameter());
// TODO.
}
}

0 comments on commit e5ae9dc

Please sign in to comment.