Skip to content

Commit

Permalink
Handle containers of containers.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Dec 8, 2023
1 parent c1b2195 commit cfcf75b
Showing 1 changed file with 49 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
import com.ibm.wala.cast.python.ml.analysis.TensorVariable;
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
import com.ibm.wala.cast.python.ssa.PythonPropertyWrite;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.cast.tree.CAstSourcePositionMap.Position;
import com.ibm.wala.cast.types.AstMethodReference;
Expand Down Expand Up @@ -751,33 +752,29 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c
if (useNum < invokeInstruction.getNumberOfUses()) {
int paramUse = invokeInstruction.getUse(useNum);
DefUse du = callerOfThisFunction.getDU();
SSAInstruction paramDef = du.getDef(paramUse);

if (paramDef != null && paramDef instanceof SSANewInstruction) {
SSANewInstruction paramNewInstruction = (SSANewInstruction) paramDef;
NewSiteReference paramNewSiteReference = paramNewInstruction.getNewSite();
Set<NewSiteReference> allNewSiteReferences = getAllNewSiteReferences(paramUse, du);

for (Pair<PointerKey, TensorVariable> pair : analysis) {
PointerKey pointerKey = pair.fst;
for (Pair<PointerKey, TensorVariable> pair : analysis) {
PointerKey pointerKey = pair.fst;

if (pointerKey instanceof InstanceFieldPointerKey) {
InstanceFieldPointerKey ifpk = (InstanceFieldPointerKey) pointerKey;
InstanceKey instanceKey = ifpk.getInstanceKey();
if (pointerKey instanceof InstanceFieldPointerKey) {
InstanceFieldPointerKey ifpk = (InstanceFieldPointerKey) pointerKey;
InstanceKey instanceKey = ifpk.getInstanceKey();

if (instanceKey instanceof AllocationSiteInNode) {
AllocationSiteInNode asin = (AllocationSiteInNode) instanceKey;
if (instanceKey instanceof AllocationSiteInNode) {
AllocationSiteInNode asin = (AllocationSiteInNode) instanceKey;

if (asin.getNode().equals(callerOfThisFunction)
&& asin.getSite().equals(paramNewSiteReference)) {
// We have a match.
// check the existence of the tensor variable.
assert pair.snd != null : "Tensor variable should be non-null if there is a PK.";
if (asin.getNode().equals(callerOfThisFunction)
&& allNewSiteReferences.contains(asin.getSite())) {
// We have a match.
// check the existence of the tensor variable.
assert pair.snd != null : "Tensor variable should be non-null if there is a PK.";

this.likelyHasTensorParameter = Boolean.TRUE;
LOG.info(this + " likely has a tensor-like parameter due to tensor analysis.");
monitor.done();
return;
}
this.likelyHasTensorParameter = Boolean.TRUE;
LOG.info(this + " likely has a tensor-like parameter due to tensor analysis.");
monitor.done();
return;
}
}
}
Expand All @@ -787,8 +784,8 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c
}
}
}
monitor.worked(1);
}
monitor.worked(1);
}
}

Expand All @@ -797,6 +794,36 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c
monitor.done();
}

private static Set<NewSiteReference> getAllNewSiteReferences(int use, DefUse du) {
return getAllNewSiteReferences(use, du, new HashSet<>());
}

private static Set<NewSiteReference> getAllNewSiteReferences(int use, DefUse du, Set<PythonPropertyWrite> seen) {
Set<NewSiteReference> ret = new HashSet<>();
SSAInstruction def = du.getDef(use);

if (def != null && def instanceof SSANewInstruction) {
SSANewInstruction newInstruction = (SSANewInstruction) def;
NewSiteReference newSite = newInstruction.getNewSite();
ret.add(newSite);

for (Iterator<SSAInstruction> uses = du.getUses(def.getDef()); uses.hasNext();) {
SSAInstruction useInstruction = uses.next();

if (useInstruction instanceof PythonPropertyWrite) {
PythonPropertyWrite write = (PythonPropertyWrite) useInstruction;

if (!seen.contains(write)) {
seen.add(write);
int value = write.getValue();
ret.addAll(getAllNewSiteReferences(value, du, seen));
}
}
}
}
return ret;
}

/**
* Returns true iff lhsParamExpr corresponds to rhsPointerKey.
*
Expand Down

0 comments on commit cfcf75b

Please sign in to comment.