Skip to content

Commit

Permalink
In open source, import tf_keras instead of setting TF_USE_LEGACY_KERA…
Browse files Browse the repository at this point in the history
…S=1.

With this change, other imported libraries are free to use Keras 3 while TFP uses Keras 2.

PiperOrigin-RevId: 578625610
  • Loading branch information
jburnim authored and tensorflower-gardener committed Nov 1, 2023
1 parent ef27f46 commit a446360
Show file tree
Hide file tree
Showing 110 changed files with 835 additions and 659 deletions.
2 changes: 1 addition & 1 deletion STYLE_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ they supersede all previous conventions.
1. Submodule names should be singular, except where they overlap to TF.
Justification: Having plural looks strange in user code, ie,
tf.optimizer.Foo reads nicer than tf.keras.optimizers.Foo since submodules
tf.optimizer.Foo reads nicer than tf_keras.optimizers.Foo since submodules
are only used to access a single, specific thing (at a time).
1. Use `tf.newaxis` rather than `None` to `tf.expand_dims`.
Expand Down
2 changes: 1 addition & 1 deletion SUBSTRATES.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ vmap, etc.), we will special-case using an `if JAX_MODE:` block.
tests, TFP impl, etc), with `tfp.math.value_and_gradient` or similar. Then,
we can special-case `JAX_MODE` inside the body of `value_and_gradient`.

* __`tf.Variable`, `tf.keras.optimizers.Optimizer`__
* __`tf.Variable`, `tf_keras.optimizers.Optimizer`__

TF provides a `Variable` abstraction so that graph functions may modify
state, including using the Keras `Optimizer` subclasses like `Adam`. JAX,
Expand Down
1 change: 1 addition & 0 deletions tensorflow_probability/examples/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ py_library(
# six dep,
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/internal:tf_keras",
],
)

Expand Down
17 changes: 9 additions & 8 deletions tensorflow_probability/examples/bayesian_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import tf_keras

tf.enable_v2_behavior()

Expand Down Expand Up @@ -174,26 +175,26 @@ def create_model():
# and two fully connected dense layers. We use the Flipout
# Monte Carlo estimator for these layers, which enables lower variance
# stochastic gradients than naive reparameterization.
model = tf.keras.models.Sequential([
model = tf_keras.models.Sequential([
tfp.layers.Convolution2DFlipout(
6, kernel_size=5, padding='SAME',
kernel_divergence_fn=kl_divergence_function,
activation=tf.nn.relu),
tf.keras.layers.MaxPooling2D(
tf_keras.layers.MaxPooling2D(
pool_size=[2, 2], strides=[2, 2],
padding='SAME'),
tfp.layers.Convolution2DFlipout(
16, kernel_size=5, padding='SAME',
kernel_divergence_fn=kl_divergence_function,
activation=tf.nn.relu),
tf.keras.layers.MaxPooling2D(
tf_keras.layers.MaxPooling2D(
pool_size=[2, 2], strides=[2, 2],
padding='SAME'),
tfp.layers.Convolution2DFlipout(
120, kernel_size=5, padding='SAME',
kernel_divergence_fn=kl_divergence_function,
activation=tf.nn.relu),
tf.keras.layers.Flatten(),
tf_keras.layers.Flatten(),
tfp.layers.DenseFlipout(
84, kernel_divergence_fn=kl_divergence_function,
activation=tf.nn.relu),
Expand All @@ -203,7 +204,7 @@ def create_model():
])

# Model compilation.
optimizer = tf.keras.optimizers.Adam(lr=FLAGS.learning_rate)
optimizer = tf_keras.optimizers.Adam(lr=FLAGS.learning_rate)
# We use the categorical_crossentropy loss since the MNIST dataset contains
# ten labels. The Keras API will then automatically add the
# Kullback-Leibler divergence (contained on the individual layers of
Expand All @@ -214,7 +215,7 @@ def create_model():
return model


class MNISTSequence(tf.keras.utils.Sequence):
class MNISTSequence(tf_keras.utils.Sequence):
"""Produces a sequence of MNIST digits with labels."""

def __init__(self, data=None, batch_size=128, fake_data_size=None):
Expand Down Expand Up @@ -272,7 +273,7 @@ def __preprocessing(images, labels):
images = 2 * (images / 255.) - 1.
images = images[..., tf.newaxis]

labels = tf.keras.utils.to_categorical(labels)
labels = tf_keras.utils.to_categorical(labels)
return images, labels

def __len__(self):
Expand All @@ -298,7 +299,7 @@ def main(argv):
heldout_seq = MNISTSequence(batch_size=FLAGS.batch_size,
fake_data_size=NUM_HELDOUT_EXAMPLES)
else:
train_set, heldout_set = tf.keras.datasets.mnist.load_data()
train_set, heldout_set = tf_keras.datasets.mnist.load_data()
train_seq = MNISTSequence(data=train_set, batch_size=FLAGS.batch_size)
heldout_seq = MNISTSequence(data=heldout_set, batch_size=FLAGS.batch_size)

Expand Down
4 changes: 3 additions & 1 deletion tensorflow_probability/examples/cifar10_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from tensorflow_probability.examples.models.bayesian_resnet import bayesian_resnet
from tensorflow_probability.examples.models.bayesian_vgg import bayesian_vgg

from tensorflow_probability.python.internal import tf_keras

matplotlib.use("Agg")
warnings.simplefilter(action="ignore")
tfd = tfp.distributions
Expand Down Expand Up @@ -169,7 +171,7 @@ def main(argv):
if FLAGS.fake_data:
(x_train, y_train), (x_test, y_test) = build_fake_data()
else:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
(x_train, y_train), (x_test, y_test) = tf_keras.datasets.cifar10.load_data()

(images, labels, handle,
training_iterator,
Expand Down
Loading

0 comments on commit a446360

Please sign in to comment.