Skip to content

Commit

Permalink
Add tf.keras.layers.Dense.
Browse files Browse the repository at this point in the history
Doesn't always work due to wala#127.
  • Loading branch information
khatchad committed Jan 19, 2024
1 parent 44d4434 commit 243ecb9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,23 @@ public void testTf2()
testTf2("tensorflow_gan_tutorial.py", "train_step", 1, 2, 2);
testTf2("tensorflow_gan_tutorial2.py", "train_step", 1, 2, 2);
testTf2("tensorflow_eager_execution.py", "MyModel.call", 1, 1, 3);
testTf2("neural_network.py", "NeuralNet.call", 1, 5, 3);
testTf2("neural_network.py", "cross_entropy_loss", 2, 6, 2, 3);
testTf2("neural_network.py", "run_optimization", 2, 3, 2, 3);
// FIXME: This test is disabled because, currently, the number of expected tensor parameters
// differs between calls to accuracy(). They should be consistent.
// testTf2("neural_network.py", "accuracy", 2, 5, 2, 3);
testTf2("neural_network.py", "NeuralNet.call", 1, 1, 3);
testTf2(
"neural_network.py",
"cross_entropy_loss",
1,
4,
3); // NOTE: Change to 2 tensor parameters once https://github.com/wala/ML/issues/127 is
// fixed. Values 2 and 3 will correspond to the tensor parameters.
testTf2("neural_network.py", "run_optimization", 2, 2, 2, 3);
testTf2(
"neural_network.py",
"accuracy",
1,
3,
3); // NOTE: Change to 2 tensor parameters and 5 tensor variables once
// https://github.com/wala/ML/issues/127 is fixed. Values 2 and 3 will correspond to the
// tensor parameters.
}

private void testTf2(
Expand Down
30 changes: 30 additions & 0 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@
<putfield class="LRoot" field="Input" fieldType="LRoot" ref="keras" value="Input" />
<putfield class="LRoot" field="Input" fieldType="LRoot" ref="layers" value="Input" />

<new def="Dense" class="Ltensorflow/keras/layers/Dense" />
<putfield class="LRoot" field="Dense" fieldType="LRoot" ref="layers" value="Dense" />

<new def="Variable" class="Ltensorflow/functions/Variable" />
<putfield class="LRoot" field="Variable" fieldType="LRoot" ref="x" value="Variable" />
<putfield class="LRoot" field="Variable" fieldType="LRoot" ref="variables" value="Variable" />
Expand Down Expand Up @@ -800,6 +803,33 @@
</class>
</package>

<package name="tensorflow/keras/layers">
<class name="Dense" allocatable="true">
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/keras/layers/Dense -->
<method name="do" descriptor="()LRoot;" numArgs="11" paramNames="self units activation use_bias kernel_initializer bias_initializer kernel_regularizer bias_regularizer activity_regularizer kernel_constraint bias_constraint">
<new def="__call__" class="Ltensorflow/keras/layers/__call__" />
<putfield class="LRoot" field="__call__" fieldType="LRoot" ref="arg0" value="__call__" />
<new def="call" class="Ltensorflow/keras/layers/call" />
<putfield class="LRoot" field="call" fieldType="LRoot" ref="arg0" value="call" />
<return value="arg0" />
</method>
</class>
<!-- FIXME: These methods must be called explicitly. The implicit cases blocked on https://github.com/wala/ML/issues/127. -->
<class name="__call__" allocatable="true">
<!-- https://github.com/keras-team/keras/blob/07e13740fd181fc3ddec7d9a594d8a08666645f6/keras/layers/core/dense.py#L166-L240 -->
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="self inputs">
<return value="inputs" />
</method>
</class>
<!-- FIXME: Workaround for https://github.com/wala/ML/issues/106. -->
<class name="call" allocatable="true">
<!-- https://github.com/keras-team/keras/blob/07e13740fd181fc3ddec7d9a594d8a08666645f6/keras/layers/core/dense.py#L166-L240 -->
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="self inputs">
<return value="inputs" />
</method>
</class>
</package>

<package name="tensorflow/data">
<class name="Dataset" allocatable="true">
<!-- "read_dataset" means that this function reads a tensor iterable. -->
Expand Down

0 comments on commit 243ecb9

Please sign in to comment.