Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with dataset iterator processing #201

Merged
merged 3 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ public void testDataset18()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test("tf2_test_dataset18.py", "add", 2, 2, 2, 3);
test("tf2_test_dataset18.py", "f", 1, 1, 2);
test("tf2_test_dataset18.py", "g", 0, 2);
test("tf2_test_dataset18.py", "g", 0, 1);
}

/** Test a dataset that uses an iterator. */
Expand Down Expand Up @@ -1107,6 +1107,35 @@ public void testDataset33()
test("tf2_test_dataset33.py", "f", 1, 1, 2);
}

/**
 * Test calling {@code next()} on a dataset iterator that was first stored in an object field
 * ({@code c.some_iter} in the fixture) and is read back from that field inside the loop.
 */
@Test
public void testDataset34()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  String filename = "tf2_test_dataset34.py";
  test(filename, "add", 2, 2, 2, 3);
}

/**
 * Test calling {@code next()} on a dataset iterator that is produced by a helper function
 * ({@code gen_iter()} in the fixture) instead of a direct call to {@code iter()} at the use site.
 */
@Test
public void testDataset35()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  String filename = "tf2_test_dataset35.py";
  test(filename, "add", 2, 2, 2, 3);
}

/**
 * Test calling {@code next()} on an iterator over a dataset built with {@code from_generator()}.
 * In the fixture, the iterator is stored in an object field and each element is unpacked into two
 * variables that are passed to {@code id1()} and {@code id2()}, respectively.
 */
@Test
public void testDataset36()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  String filename = "tf2_test_dataset36.py";
  test(filename, "id1", 1, 1, 2);
  // NOTE(review): the corresponding check for `id2` is disabled; the reason is not recorded
  // here. Confirm whether it is a known analysis limitation before re-enabling or removing it.
  // test(filename, "id2", 1, 1, 2);
}

/**
 * Test calling {@code next()} on a dataset iterator that is created inside a helper function and
 * returned wrapped in an object ({@code gen_iter()} returning {@code C(my_iter)} in the fixture).
 */
@Test
public void testDataset37()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  String filename = "tf2_test_dataset37.py";
  test(filename, "add", 2, 2, 2, 3);
}

/**
* Test enumerating a dataset (https://github.com/wala/ML/issues/140). The first element of the
* tuple returned isn't a tensor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,47 +131,7 @@ private static Set<PointsToSetVariable> getDataflowSources(
if (inst instanceof SSAAbstractInvokeInstruction) {
// We potentially have a function call that generates a tensor.
SSAAbstractInvokeInstruction ni = (SSAAbstractInvokeInstruction) inst;

if (ni.getCallSite()
.getDeclaredTarget()
.getName()
.toString()
.equals(TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME)
&& ni.getException() != vn) {
sources.add(src);
logger.info("Added dataflow source from tensor generator: " + src + ".");
} else if (ni.getNumberOfUses() > 1) {
// Get the invoked function from the PA.
int target = ni.getUse(0);
PointerKey targetKey =
pointerAnalysis.getHeapModel().getPointerKeyForLocal(localPointerKeyNode, target);

for (InstanceKey ik : pointerAnalysis.getPointsToSet(targetKey)) {
if (ik instanceof ConcreteTypeKey) {
ConcreteTypeKey ctk = (ConcreteTypeKey) ik;
IClass type = ctk.getType();
TypeReference reference = type.getReference();

if (reference.equals(NEXT.getDeclaringClass())) {
// it's a call to `next()`. Look up the call to `iter()`.
int iterator = ni.getUse(1);
SSAInstruction iteratorDef = du.getDef(iterator);

// Let's see if the iterator is over a tensor dataset.
if (iteratorDef != null && iteratorDef.getNumberOfUses() > 1) {
// Get the argument.
int iterArg = iteratorDef.getUse(1);
processInstructionInterprocedurally(
iteratorDef, iterArg, localPointerKeyNode, src, sources, pointerAnalysis);
} else
// Use the original instruction. NOTE: We can only do this because `iter()` is
// currently just passing-through its argument.
processInstructionInterprocedurally(
ni, iterator, localPointerKeyNode, src, sources, pointerAnalysis);
}
}
}
}
processInstruction(ni, du, localPointerKeyNode, src, vn, sources, pointerAnalysis);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe processInvokeInstruction?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay. Yes, I also initially thought of that. But, then I decided to use method overloading:

  private static boolean processInstruction(
      SSAAbstractInvokeInstruction instruction,
      ...

In other words, the first parameter's type implicitly informs that an invoke instruction is being "processed." I can change it back if desired.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My preference is to change the method name to processInvokeInstruction for a bit more clarity (in my opinion). But it's up to you; you read this code a lot more than I do 🙂 Not a huge deal

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this is a good suggestion, but I wonder if we change this one, how many other ones are in a similar boat :).

} else if (inst instanceof EachElementGetInstruction) {
// We are potentially pulling a tensor out of a tensor iterable.
EachElementGetInstruction eachElementGetInstruction = (EachElementGetInstruction) inst;
Expand Down Expand Up @@ -215,15 +175,112 @@ private static Set<PointsToSetVariable> getDataflowSources(
} else if (def instanceof EachElementGetInstruction
|| def instanceof PythonPropertyRead
|| def instanceof PythonInvokeInstruction) {
processInstruction(
def, du, localPointerKeyNode, src, sources, callGraph, pointerAnalysis);
boolean added = false;
// we may be invoking `next()` on a dataset.
if (def instanceof SSAAbstractInvokeInstruction && def.getNumberOfUses() > 1) {
SSAAbstractInvokeInstruction invokeInstruction = (SSAAbstractInvokeInstruction) def;
added =
processInstruction(
invokeInstruction,
du,
localPointerKeyNode,
src,
vn,
sources,
pointerAnalysis);
}

if (!added)
processInstruction(
def, du, localPointerKeyNode, src, sources, callGraph, pointerAnalysis);
}
}
}
}
return sources;
}

/**
 * Decides whether the given {@link SSAAbstractInvokeInstruction} results in a tensor value and, if
 * so, records the given {@link PointsToSetVariable} as a tensor dataflow source.
 *
 * <p>Two kinds of calls are recognized: a call whose declared target is the synthetic tensor
 * generator function, and a call to {@code next()}, for which the iterator (or the argument of the
 * call defining it) is handed to {@code processInstructionInterprocedurally} to look for a dataset
 * source. Nothing is recorded when the given value number is the instruction's exception.
 *
 * @param instruction The {@link SSAAbstractInvokeInstruction} to consider.
 * @param du The {@link DefUse} for the given {@link SSAAbstractInvokeInstruction}.
 * @param node The {@link CGNode} containing the given {@link SSAAbstractInvokeInstruction}.
 * @param src The {@link PointsToSetVariable} to add to the given {@link Set} of {@link
 *     PointsToSetVariable}s if a tensor flows from the given {@link SSAAbstractInvokeInstruction}.
 * @param vn The value number in the given {@link CGNode} corresponding to the given {@link
 *     PointsToSetVariable}.
 * @param sources The {@link Set} of {@link PointsToSetVariable}s representing tensor dataflow
 *     sources.
 * @param pointerAnalysis The {@link PointerAnalysis} for the given {@link CGNode}.
 * @return True iff the given source was added to the set.
 */
private static boolean processInstruction(
    SSAAbstractInvokeInstruction instruction,
    DefUse du,
    CGNode node,
    PointsToSetVariable src,
    int vn,
    Set<PointsToSetVariable> sources,
    PointerAnalysis<InstanceKey> pointerAnalysis) {
  // An exception value is never a data source.
  if (instruction.getException() == vn) {
    return false;
  }

  String declaredTargetName =
      instruction.getCallSite().getDeclaredTarget().getName().toString();

  if (declaredTargetName.equals(TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME)) {
    sources.add(src);
    logger.info("Added dataflow source from tensor generator: " + src + ".");
    return true;
  }

  // A call to `next()` needs the callee (use 0) and the iterator (use 1).
  if (instruction.getNumberOfUses() <= 1) {
    return false;
  }

  // Resolve the invoked function through the pointer analysis.
  int calleeVn = instruction.getUse(0);
  PointerKey calleeKey = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, calleeVn);

  boolean sourceAdded = false;

  for (InstanceKey calleeInstance : pointerAnalysis.getPointsToSet(calleeKey)) {
    if (!(calleeInstance instanceof ConcreteTypeKey)) {
      continue;
    }

    TypeReference calleeType = ((ConcreteTypeKey) calleeInstance).getType().getReference();

    // Only calls to `next()` are of interest from here on.
    if (!calleeType.equals(NEXT.getDeclaringClass())) {
      continue;
    }

    // Look up where the iterator passed to `next()` is defined.
    int iteratorVn = instruction.getUse(1);
    SSAInstruction iteratorDef = du.getDef(iteratorVn);

    if (iteratorDef == null || iteratorDef.getNumberOfUses() <= 1) {
      // No usable definition; fall back to the original instruction. NOTE: We can only do this
      // because `iter()` is currently just passing-through its argument.
      sourceAdded |=
          processInstructionInterprocedurally(
              instruction, iteratorVn, node, src, sources, pointerAnalysis);
      continue;
    }

    // First, check the value defined by the iterator's definition for a dataset source. NOTE: We
    // can only do this because `iter()` is currently just passing-through its argument.
    boolean addedViaIterator =
        processInstructionInterprocedurally(
            iteratorDef, iteratorDef.getDef(), node, src, sources, pointerAnalysis);
    sourceAdded |= addedViaIterator;

    if (!addedViaIterator && iteratorDef instanceof SSAAbstractInvokeInstruction) {
      // It may be a call to `iter()`; try its argument instead.
      int iterArgVn = iteratorDef.getUse(1);
      sourceAdded |=
          processInstructionInterprocedurally(
              iteratorDef, iterArgVn, node, src, sources, pointerAnalysis);
    }
  }

  return sourceAdded;
}

/**
* Processes the given {@link SSAInstruction} to decide if the given {@link PointsToSetVariable}
* is added to the given {@link Set} of {@link PointsToSetVariable}s as tensor dataflow sources.
Expand Down
24 changes: 24 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset34.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Fixture: a dataset iterator is stored in an object attribute and later consumed with `next()`
# through that attribute. Only `#` comments are used here so that no statements are added.
import tensorflow as tf


class C:

    def __init__(self, some_iter):
        # Keep a reference to whatever iterator the caller passes in.
        self.some_iter = some_iter

    def __str__(self):
        return str(self.some_iter)


def add(a, b):
    return a + b


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
my_iter = iter(dataset)
c = C(my_iter)
length = len(dataset)

# Read the iterator back from the attribute instead of using the local `my_iter`.
for _ in range(length):
    element = next(c.some_iter)
    add(element, element)
19 changes: 19 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset35.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Fixture: the dataset iterator is created by a helper function and returned to the caller, which
# then consumes it with `next()`. Only `#` comments are used here so that no statements are added.
import tensorflow as tf


def add(a, b):
    return a + b


def gen_iter(ds):
    # The call to `iter()` happens here, not at the site where `next()` is called.
    return iter(ds)


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])

my_iter = gen_iter(dataset)
length = len(dataset)

for _ in range(length):
    element = next(my_iter)
    add(element, element)
34 changes: 34 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset36.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Fixture: the dataset is built from a generator, its iterator is stored in an object attribute,
# and each element obtained with `next()` is unpacked into two variables. Only `#` comments are
# used here so that no statements are added.
import tensorflow as tf


class C:

    def __init__(self, some_iter):
        # Keep a reference to whatever iterator the caller passes in.
        self.some_iter = some_iter

    def __str__(self):
        return str(self.some_iter)


def id1(a):
    return a


def id2(a):
    return a


def gen():
    # Yields one pair: a plain string and a `tf.constant`.
    yield "42", tf.constant("43")


dataset = tf.data.Dataset.from_generator(gen, output_types=(tf.string, tf.string))

my_iter = iter(dataset)
c = C(my_iter)
length = 1

# Read the iterator back from the attribute and unpack the element it produces.
for _ in range(length):
    x, y = next(c.some_iter)
    id1(x)
    id2(y)
28 changes: 28 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset37.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Fixture: a helper function creates the dataset iterator and returns it wrapped in an object; the
# caller then consumes it with `next()` through the object's attribute. Only `#` comments are used
# here so that no statements are added.
import tensorflow as tf


class C:

    def __init__(self, some_iter):
        # Keep a reference to whatever iterator the caller passes in.
        self.some_iter = some_iter

    def __str__(self):
        return str(self.some_iter)


def add(a, b):
    return a + b


def gen_iter(dataset):
    # Both the call to `iter()` and the wrapping in `C` happen inside this function.
    my_iter = iter(dataset)
    return C(my_iter)


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
c = gen_iter(dataset)
length = len(dataset)

for _ in range(length):
    element = next(c.some_iter)
    add(element, element)
Loading