In open source, import tf_keras instead of setting TF_USE_LEGACY_KERAS=1.

With this change, other imported libraries are free to use Keras 3 while TFP uses Keras 2.

PiperOrigin-RevId: 578368750
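
For context, a rough sketch of the two approaches the commit message contrasts (illustrative only, not code from this commit; the TF_USE_LEGACY_KERAS variable and the tf-keras package are real, the surrounding lines are assumptions):

    # Old approach: the env var must be set before TensorFlow is imported,
    # and it switches tf.keras to Keras 2 for every library in the process.
    import os
    os.environ['TF_USE_LEGACY_KERAS'] = '1'
    import tensorflow as tf
    model = tf.keras.models.Sequential()  # now backed by Keras 2

    # New approach: import Keras 2 under its own name from the tf-keras
    # package; tf.keras is left alone, so other libraries can use Keras 3.
    import tf_keras
    model = tf_keras.models.Sequential()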
Googler authored and tensorflower-gardener committed Nov 1, 2023
1 parent 1bf95b0 commit ef27f46
Showing 110 changed files with 659 additions and 833 deletions.
STYLE_GUIDE.md (2 changes: 1 addition & 1 deletion)

@@ -187,7 +187,7 @@ they supersede all previous conventions.
 1. Submodule names should be singular, except where they overlap to TF.
    Justification: Having plural looks strange in user code, ie,
-   tf.optimizer.Foo reads nicer than tf_keras.optimizers.Foo since submodules
+   tf.optimizer.Foo reads nicer than tf.keras.optimizers.Foo since submodules
    are only used to access a single, specific thing (at a time).
 1. Use `tf.newaxis` rather than `None` to `tf.expand_dims`.
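
(Aside on the guideline in the last context line above: a tiny sketch of the preferred spelling, illustrative and not part of the diff.)

    import tensorflow as tf

    x = tf.constant([1., 2., 3.])   # shape [3]
    y = x[..., tf.newaxis]          # shape [3, 1]; the spelling the guide prefers
    z = x[..., None]                # equivalent result, but less explicit
    w = tf.expand_dims(x, axis=-1)  # the verbose alternative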
SUBSTRATES.md (2 changes: 1 addition & 1 deletion)

@@ -75,7 +75,7 @@ vmap, etc.), we will special-case using an `if JAX_MODE:` block.
   tests, TFP impl, etc), with `tfp.math.value_and_gradient` or similar. Then,
   we can special-case `JAX_MODE` inside the body of `value_and_gradient`.

-* __`tf.Variable`, `tf_keras.optimizers.Optimizer`__
+* __`tf.Variable`, `tf.keras.optimizers.Optimizer`__

   TF provides a `Variable` abstraction so that graph functions may modify
   state, including using the Keras `Optimizer` subclasses like `Adam`. JAX,
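
(Aside: a minimal sketch of the Variable-plus-Optimizer pattern the passage above describes, written against the Keras 2 tf_keras package; illustrative only.)

    import tensorflow as tf
    import tf_keras

    v = tf.Variable(2.0)
    opt = tf_keras.optimizers.Adam(learning_rate=0.1)

    @tf.function  # a graph function that mutates state through the Variable
    def step():
      with tf.GradientTape() as tape:
        loss = (v - 3.0) ** 2
      opt.apply_gradients([(tape.gradient(loss, v), v)])
      return loss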
tensorflow_probability/examples/BUILD (1 change: 0 additions & 1 deletion)

@@ -84,7 +84,6 @@ py_library(
         # six dep,
         # tensorflow dep,
         "//tensorflow_probability",
-        "//tensorflow_probability/python/internal:tf_keras",
     ],
 )
tensorflow_probability/examples/bayesian_neural_network.py (17 changes: 8 additions & 9 deletions)

@@ -37,7 +37,6 @@
 import numpy as np
 import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp
-from tensorflow_probability.python.internal import tf_keras

 tf.enable_v2_behavior()

@@ -175,26 +174,26 @@ def create_model():
   # and two fully connected dense layers. We use the Flipout
   # Monte Carlo estimator for these layers, which enables lower variance
   # stochastic gradients than naive reparameterization.
-  model = tf_keras.models.Sequential([
+  model = tf.keras.models.Sequential([
       tfp.layers.Convolution2DFlipout(
           6, kernel_size=5, padding='SAME',
           kernel_divergence_fn=kl_divergence_function,
           activation=tf.nn.relu),
-      tf_keras.layers.MaxPooling2D(
+      tf.keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
       tfp.layers.Convolution2DFlipout(
          16, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
-      tf_keras.layers.MaxPooling2D(
+      tf.keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
       tfp.layers.Convolution2DFlipout(
          120, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
-      tf_keras.layers.Flatten(),
+      tf.keras.layers.Flatten(),
       tfp.layers.DenseFlipout(
          84, kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),

@@ -204,7 +203,7 @@ def create_model():
   ])

   # Model compilation.
-  optimizer = tf_keras.optimizers.Adam(lr=FLAGS.learning_rate)
+  optimizer = tf.keras.optimizers.Adam(lr=FLAGS.learning_rate)
   # We use the categorical_crossentropy loss since the MNIST dataset contains
   # ten labels. The Keras API will then automatically add the
   # Kullback-Leibler divergence (contained on the individual layers of

@@ -215,7 +214,7 @@
   return model


-class MNISTSequence(tf_keras.utils.Sequence):
+class MNISTSequence(tf.keras.utils.Sequence):
   """Produces a sequence of MNIST digits with labels."""

   def __init__(self, data=None, batch_size=128, fake_data_size=None):

@@ -273,7 +272,7 @@ def __preprocessing(images, labels):
     images = 2 * (images / 255.) - 1.
     images = images[..., tf.newaxis]

-    labels = tf_keras.utils.to_categorical(labels)
+    labels = tf.keras.utils.to_categorical(labels)
     return images, labels

   def __len__(self):

@@ -299,7 +298,7 @@ def main(argv):
     heldout_seq = MNISTSequence(batch_size=FLAGS.batch_size,
                                 fake_data_size=NUM_HELDOUT_EXAMPLES)
   else:
-    train_set, heldout_set = tf_keras.datasets.mnist.load_data()
+    train_set, heldout_set = tf.keras.datasets.mnist.load_data()
   train_seq = MNISTSequence(data=train_set, batch_size=FLAGS.batch_size)
   heldout_seq = MNISTSequence(data=heldout_set, batch_size=FLAGS.batch_size)
tensorflow_probability/examples/cifar10_bnn.py (4 changes: 1 addition & 3 deletions)

@@ -47,8 +47,6 @@
 from tensorflow_probability.examples.models.bayesian_resnet import bayesian_resnet
 from tensorflow_probability.examples.models.bayesian_vgg import bayesian_vgg
-
-from tensorflow_probability.python.internal import tf_keras

 matplotlib.use("Agg")
 warnings.simplefilter(action="ignore")
 tfd = tfp.distributions

@@ -171,7 +169,7 @@ def main(argv):
   if FLAGS.fake_data:
     (x_train, y_train), (x_test, y_test) = build_fake_data()
   else:
-    (x_train, y_train), (x_test, y_test) = tf_keras.datasets.cifar10.load_data()
+    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

   (images, labels, handle,
    training_iterator,
[Diff truncated: the remaining 105 changed files are not shown.]
