diff --git a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java index ad42285d..7b34e735 100644 --- a/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java +++ b/edu.cuny.hunter.hybridize.core/src/edu/cuny/hunter/hybridize/core/analysis/Function.java @@ -51,6 +51,7 @@ import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis; import com.ibm.wala.cast.python.ml.analysis.TensorVariable; import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction; +import com.ibm.wala.cast.python.ssa.PythonPropertyWrite; import com.ibm.wala.cast.python.types.PythonTypes; import com.ibm.wala.cast.tree.CAstSourcePositionMap.Position; import com.ibm.wala.cast.types.AstMethodReference; @@ -751,33 +752,29 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c if (useNum < invokeInstruction.getNumberOfUses()) { int paramUse = invokeInstruction.getUse(useNum); DefUse du = callerOfThisFunction.getDU(); - SSAInstruction paramDef = du.getDef(paramUse); - if (paramDef != null && paramDef instanceof SSANewInstruction) { - SSANewInstruction paramNewInstruction = (SSANewInstruction) paramDef; - NewSiteReference paramNewSiteReference = paramNewInstruction.getNewSite(); + Set allNewSiteReferences = getAllNewSiteReferences(paramUse, du); - for (Pair pair : analysis) { - PointerKey pointerKey = pair.fst; + for (Pair pair : analysis) { + PointerKey pointerKey = pair.fst; - if (pointerKey instanceof InstanceFieldPointerKey) { - InstanceFieldPointerKey ifpk = (InstanceFieldPointerKey) pointerKey; - InstanceKey instanceKey = ifpk.getInstanceKey(); + if (pointerKey instanceof InstanceFieldPointerKey) { + InstanceFieldPointerKey ifpk = (InstanceFieldPointerKey) pointerKey; + InstanceKey instanceKey = ifpk.getInstanceKey(); - if (instanceKey instanceof AllocationSiteInNode) { - AllocationSiteInNode asin = (AllocationSiteInNode) instanceKey; + if (instanceKey instanceof AllocationSiteInNode) { + AllocationSiteInNode asin = (AllocationSiteInNode) instanceKey; - if (asin.getNode().equals(callerOfThisFunction) - && asin.getSite().equals(paramNewSiteReference)) { - // We have a match. - // check the existence of the tensor variable. - assert pair.snd != null : "Tensor variable should be non-null if there is a PK."; + if (asin.getNode().equals(callerOfThisFunction) + && allNewSiteReferences.contains(asin.getSite())) { + // We have a match. + // check the existence of the tensor variable. + assert pair.snd != null : "Tensor variable should be non-null if there is a PK."; - this.likelyHasTensorParameter = Boolean.TRUE; - LOG.info(this + " likely has a tensor-like parameter due to tensor analysis."); - monitor.done(); - return; - } + this.likelyHasTensorParameter = Boolean.TRUE; + LOG.info(this + " likely has a tensor-like parameter due to tensor analysis."); + monitor.done(); + return; } } } @@ -787,8 +784,8 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c } } } - monitor.worked(1); } + monitor.worked(1); } } @@ -797,6 +794,36 @@ public void inferTensorTensorParameters(TensorTypeAnalysis analysis, CallGraph c monitor.done(); } + private static Set getAllNewSiteReferences(int use, DefUse du) { + return getAllNewSiteReferences(use, du, new HashSet<>()); + } + + private static Set getAllNewSiteReferences(int use, DefUse du, Set seen) { + Set ret = new HashSet<>(); + SSAInstruction def = du.getDef(use); + + if (def != null && def instanceof SSANewInstruction) { + SSANewInstruction newInstruction = (SSANewInstruction) def; + NewSiteReference newSite = newInstruction.getNewSite(); + ret.add(newSite); + + for (Iterator uses = du.getUses(def.getDef()); uses.hasNext();) { + SSAInstruction useInstruction = uses.next(); + + if (useInstruction instanceof PythonPropertyWrite) { + PythonPropertyWrite write = (PythonPropertyWrite) useInstruction; + + if (!seen.contains(write)) { + seen.add(write); + int value = write.getValue(); + ret.addAll(getAllNewSiteReferences(value, du, seen)); + } + } + } + } + return ret; + } + /** * Returns true iff lhsParamExpr corresponds to rhsPointerKey. *