Skip to content

Commit

Permalink
Explicitly use keras 2.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588123050
  • Loading branch information
zoyahav authored and tfx-copybara committed Dec 5, 2023
1 parent a6cca79 commit e9cd22c
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 27 deletions.
9 changes: 5 additions & 4 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

* Bumped the Ubuntu version on which `tensorflow_transform` is tested to 20.04
(previously was 16.04).
* Explicitly use Keras 2 or `tf_keras` if Keras 3 is installed.

## Breaking Changes

Expand Down Expand Up @@ -219,7 +220,7 @@
`preprocessing_fn` inputs which are not used in any TFT analyzers, but end
up in a control dependency (automatic control dependencies are not present
in TF1, hence this change will only affect the native TF2 implementation).
* Assign different resource hint tags to both orginal and cloned PTransforms
* Assign different resource hint tags to both original and cloned PTransforms
in deep copy optimization. The reason of adding these tags is to prevent
root Reads that are generated from deep copy being merged due to common
subexpression elimination.
Expand Down Expand Up @@ -746,7 +747,7 @@
`tft.AnalyzeDataset`, `tft.AnalyzeDatasetWithCache`,
`tft.AnalyzeAndTransformDataset` and `tft.TransformDataset`. The default
behavior will continue to use Tensorflow's compat.v1 APIs. This can be
overriden by setting `tft.Context.force_tf_compat_v1=False`. The default
overridden by setting `tft.Context.force_tf_compat_v1=False`. The default
behavior for TF 2 users will be switched to the new native implementation in
a future release.

Expand Down Expand Up @@ -1197,7 +1198,7 @@
* 'tft.vocabulary' and 'tft.compute_and_apply_vocabulary' now accept an
optional `weights` argument. When `weights` is provided, weighted frequencies
are used instead of frequencies based on counts.
* 'tft.quantiles' and 'tft.bucketize' now accept an optoinal `weights` argument.
* 'tft.quantiles' and 'tft.bucketize' now accept an optional `weights` argument.
When `weights` is provided, weighted count is used for quantiles instead of
the counts themselves.
* Updated examples to construct the schema using
Expand Down Expand Up @@ -1276,7 +1277,7 @@
for `TFTransformOutput.TRANSFORMED_METADATA_DIR` and
`TFTransformOutput.TRANSFORM_FN_DIR` respectively.
* `partially_apply_saved_transform` is deprecated, users should use the
`transform_raw_features` method of `TFTransformOuptut` instead. These differ
`transform_raw_features` method of `TFTransformOutput` instead. These differ
in that `partially_apply_saved_transform` can also be used to return both the
input placeholders and the outputs. But users do not need this functionality
because they will typically create the input placeholders themselves based
Expand Down
19 changes: 10 additions & 9 deletions examples/census_example_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import tensorflow as tf
import tensorflow_transform as tft
import census_example_common as common
from tensorflow_transform.keras_lib import tf_keras

# Functions for training

Expand Down Expand Up @@ -193,28 +194,28 @@ def train_and_evaluate(raw_train_eval_data_path_pattern,
if isinstance(spec, tf.io.FixedLenFeature):
# TODO(b/208879020): Move into schema such that spec.shape is [1] and not
# [] for scalars.
inputs[key] = tf.keras.layers.Input(
inputs[key] = tf_keras.layers.Input(
shape=spec.shape or [1], name=key, dtype=spec.dtype)
dense_inputs[key] = inputs[key]
elif isinstance(spec, tf.io.SparseFeature):
inputs[key] = tf.keras.layers.Input(
inputs[key] = tf_keras.layers.Input(
shape=spec.size, name=key, dtype=spec.dtype, sparse=True
)
sparse_inputs[key] = inputs[key]
else:
raise ValueError('Spec type is not supported: ', key, spec)

outputs = [
tf.keras.layers.Dense(10, activation='relu')(x)
tf_keras.layers.Dense(10, activation='relu')(x)
for x in tf.nest.flatten(sparse_inputs)
]
stacked_inputs = tf.concat(tf.nest.flatten(dense_inputs) + outputs, axis=1)
output = tf.keras.layers.Dense(100, activation='relu')(stacked_inputs)
output = tf.keras.layers.Dense(70, activation='relu')(output)
output = tf.keras.layers.Dense(50, activation='relu')(output)
output = tf.keras.layers.Dense(20, activation='relu')(output)
output = tf.keras.layers.Dense(2, activation='sigmoid')(output)
model = tf.keras.Model(inputs=inputs, outputs=output)
output = tf_keras.layers.Dense(100, activation='relu')(stacked_inputs)
output = tf_keras.layers.Dense(70, activation='relu')(output)
output = tf_keras.layers.Dense(50, activation='relu')(output)
output = tf_keras.layers.Dense(20, activation='relu')(output)
output = tf_keras.layers.Dense(2, activation='sigmoid')(output)
model = tf_keras.Model(inputs=inputs, outputs=output)

model.compile(optimizer='adam',
loss='binary_crossentropy',
Expand Down
5 changes: 3 additions & 2 deletions examples/census_example_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import census_example_v2
from tensorflow_transform import test_case as tft_test_case
import local_model_server
from tensorflow_transform.keras_lib import tf_keras

from google.protobuf import text_format
from tensorflow.python import tf2 # pylint: disable=g-direct-tensorflow-import
Expand Down Expand Up @@ -168,8 +169,8 @@ def testCensusExampleAccuracy(self, read_raw_data_for_training):
census_example_common.EXPORTED_MODEL_DIR)

actual_model_path = os.path.join(model_path, '1')
tf.keras.backend.clear_session()
model = tf.keras.models.load_model(actual_model_path)
tf_keras.backend.clear_session()
model = tf_keras.models.load_model(actual_model_path)
model.summary()

example = text_format.Parse(_PREDICT_TF_EXAMPLE_TEXT_PB, tf.train.Example())
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_transform/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tensorflow as tf
from tensorflow_transform.graph_context import TFGraphContext
from tensorflow_transform.keras_lib import tf_keras
from tensorflow.python.trackable import base # pylint: disable=g-direct-tensorflow-import

__all__ = ['annotate_asset', 'make_and_track_object']
Expand Down Expand Up @@ -168,7 +169,7 @@ def make_and_track_object(trackable_factory_callable: Callable[[],
if result is None:
with tf.init_scope():
result = trackable_factory_callable()
if name is None and isinstance(result, tf.keras.layers.Layer):
if name is None and isinstance(result, tf_keras.layers.Layer):
raise ValueError(
'Please pass a unique `name` to this API to ensure Keras objects '
'are tracked correctly.')
Expand Down
13 changes: 7 additions & 6 deletions tensorflow_transform/beam/impl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
from tensorflow_transform.beam import tft_unit
from tensorflow_transform.keras_lib import tf_keras
from tfx_bsl.tfxio import tensor_adapter

from google.protobuf import text_format
Expand Down Expand Up @@ -4665,12 +4666,12 @@ def testEmptySchema(self):
def testLoadKerasModelInPreprocessingFn(self):
def _create_model(features, target):
inputs = [
tf.keras.Input(shape=(1,), name=f, dtype=tf.float32) for f in features
tf_keras.Input(shape=(1,), name=f, dtype=tf.float32) for f in features
]
x = tf.keras.layers.Concatenate()(inputs)
x = tf.keras.layers.Dense(64, activation='relu')(x)
outputs = tf.keras.layers.Dense(1, activation='sigmoid', name=target)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
x = tf_keras.layers.Concatenate()(inputs)
x = tf_keras.layers.Dense(64, activation='relu')(x)
outputs = tf_keras.layers.Dense(1, activation='sigmoid', name=target)(x)
model = tf_keras.Model(inputs=inputs, outputs=outputs)
model.compile(
loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

Expand All @@ -4694,7 +4695,7 @@ def _create_model(features, target):

def preprocessing_fn(inputs):
model = tft.make_and_track_object(
lambda: tf.keras.models.load_model(keras_model_dir), name='keras')
lambda: tf_keras.models.load_model(keras_model_dir), name='keras')
return {'prediction': model(inputs)}

input_data = [{'f1': 1.0, 'f2': 0.0}, {'f1': 2.0, 'f2': 3.0}]
Expand Down
28 changes: 28 additions & 0 deletions tensorflow_transform/keras_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2023 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Imports keras 2."""
import tensorflow as tf

# Expose the Keras 2 API under the module-level name `tf_keras`.
#
# `tf.keras.version` is looked up defensively: when the attribute is missing,
# `version_fn` is None and `tf.keras` itself is used as Keras 2 (the `else`
# branch below).
version_fn = getattr(tf.keras, 'version', None)
if version_fn and version_fn().startswith('3.'):
  # `tf.keras` points to `keras 3`, so use the standalone `tf_keras` package.
  try:
    import tf_keras  # pylint: disable=g-import-not-at-top,unused-import
  except ImportError:
    # `from None` suppresses the original traceback so that users only see the
    # actionable installation hint.
    raise ImportError(  # pylint: disable=raise-missing-from
        # The trailing space matters: adjacent literals are concatenated
        # verbatim, so without it the two sentences run together.
        'Keras 2 requires the `tf_keras` package. '
        'Please install it with `pip install tf_keras`.'
    ) from None
else:
  tf_keras = tf.keras  # Keras 2
11 changes: 6 additions & 5 deletions tensorflow_transform/output_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tensorflow_transform import common_types
from tensorflow_transform import graph_tools
from tensorflow_transform.analyzers import sanitized_vocab_filename
from tensorflow_transform.keras_lib import tf_keras
from tensorflow_transform.saved import saved_transform_io
from tensorflow_transform.saved import saved_transform_io_v2
from tensorflow_transform.tf_metadata import dataset_metadata
Expand Down Expand Up @@ -259,7 +260,7 @@ def num_buckets_for_transformed_feature(self, name: str) -> int:
name, domain.min))
return domain.max + 1

def transform_features_layer(self) -> tf.keras.Model:
def transform_features_layer(self) -> tf_keras.Model:
"""Creates a `TransformFeaturesLayer` from this transform output.
If a `TransformFeaturesLayer` has already been created for self, the same
Expand Down Expand Up @@ -427,8 +428,8 @@ def post_transform_statistics_path(self) -> str:


# TODO(b/162055065): Possibly switch back to inherit from Layer when possible.
@tf.keras.utils.register_keras_serializable(package='TensorFlowTransform')
class TransformFeaturesLayer(tf.keras.Model):
@tf_keras.utils.register_keras_serializable(package='TensorFlowTransform')
class TransformFeaturesLayer(tf_keras.Model):
"""A Keras layer for applying a tf.Transform output to input layers."""

def __init__(self,
Expand Down Expand Up @@ -520,10 +521,10 @@ def method_override(*args, **kwargs):
# TODO(zoyahav): Get rid of property attributes docs as well.
def _override_parent_methods(keep_items):
"""Makes inheritted attributes of the TFT layer unusable and undocumented."""
for name in dir(tf.keras.Model):
for name in dir(tf_keras.Model):
if name.startswith('_') or name in keep_items:
continue
if callable(getattr(tf.keras.Model, name)):
if callable(getattr(tf_keras.Model, name)):
setattr(TransformFeaturesLayer, name, _make_method_override(name))
elif not isinstance(getattr(TransformFeaturesLayer, name), property):
doc_controls.do_not_generate_docs(getattr(TransformFeaturesLayer, name))
Expand Down

0 comments on commit e9cd22c

Please sign in to comment.