From 3eb53b2803d283aa37c1d5daacfc3fbba460735f Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Wed, 8 Nov 2023 16:15:01 -0800 Subject: [PATCH] Fix tf-addons for upcoming keras 3 default. Keras 3.0 will become default in TF 2.16 (and is currently default in tf-nightly). This breaks this tf-addons package. Here we make minimal changes to return functionality in a backward-compatible way. --- .../image/tests/distort_image_ops_test.py | 6 +- .../discriminative_layer_training.py | 17 +- tensorflow_addons/optimizers/lazy_adam.py | 3 + tensorflow_addons/rnn/abstract_rnn_cell.py | 133 ++++++++ tensorflow_addons/rnn/esn_cell.py | 4 +- tensorflow_addons/rnn/nas_cell.py | 4 +- tensorflow_addons/seq2seq/BUILD | 1 + .../seq2seq/attention_wrapper.py | 3 +- tensorflow_addons/text/BUILD | 16 +- tensorflow_addons/text/crf.py | 3 +- tensorflow_addons/utils/test_utils.py | 10 +- tensorflow_addons/utils/tf_inspect.py | 282 +++++++++++++++++ tensorflow_addons/utils/tf_test_utils.py | 294 ++++++++++++++++++ tensorflow_addons/utils/types.py | 23 +- 14 files changed, 761 insertions(+), 38 deletions(-) create mode 100644 tensorflow_addons/rnn/abstract_rnn_cell.py create mode 100644 tensorflow_addons/utils/tf_inspect.py create mode 100644 tensorflow_addons/utils/tf_test_utils.py diff --git a/tensorflow_addons/image/tests/distort_image_ops_test.py b/tensorflow_addons/image/tests/distort_image_ops_test.py index 661c1454ab..ee7f801174 100644 --- a/tensorflow_addons/image/tests/distort_image_ops_test.py +++ b/tensorflow_addons/image/tests/distort_image_ops_test.py @@ -94,7 +94,7 @@ def test_adjust_random_hue_in_yiq(shape, style, dtype): y_np = _adjust_hue_in_yiq_np(x_np, delta_h) y_tf = _adjust_hue_in_yiq_tf(x_np, delta_h) test_utils.assert_allclose_according_to_type( - y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=0.8 + y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=1.1 ) @@ -121,11 +121,11 @@ def test_invalid_channels_hsv(): def test_adjust_hsv_in_yiq_unknown_shape(): fn = tf.function(distort_image_ops.adjust_hsv_in_yiq).get_concrete_function( - tf.TensorSpec(shape=None, dtype=tf.float64) + tf.TensorSpec(shape=None, dtype=tf.float32) ) for shape in (2, 3, 3), (4, 2, 3, 3): image_np = np.random.rand(*shape) * 255.0 - image_tf = tf.constant(image_np) + image_tf = tf.constant(image_np, dtype=tf.float32) np.testing.assert_allclose( _adjust_hue_in_yiq_np(image_np, 0), fn(image_tf), rtol=2e-4, atol=1e-4 ) diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py index a82f1b2d3e..d41c5ee997 100644 --- a/tensorflow_addons/optimizers/discriminative_layer_training.py +++ b/tensorflow_addons/optimizers/discriminative_layer_training.py @@ -22,9 +22,20 @@ from tensorflow_addons.optimizers import KerasLegacyOptimizer from typeguard import typechecked -if Version(tf.__version__).release >= Version("2.13").release: - # New versions of Keras require importing from `keras.src` when - # importing internal symbols. +if Version(tf.__version__).release >= Version("2.16").release: + # Determine if loading keras 2 or 3. + if ( + hasattr(tf.keras, "version") + and Version(tf.keras.version()).release >= Version("3.0").release + ): + # New versions of Keras require importing from `keras.src` when + # importing internal symbols. + from keras.src import backend + from keras.src.utils import tf_utils + else: + from tf_keras.src import backend + from tf_keras.src.utils import tf_utils +elif Version(tf.__version__).release >= Version("2.13").release: from keras.src import backend from keras.src.utils import tf_utils else: diff --git a/tensorflow_addons/optimizers/lazy_adam.py b/tensorflow_addons/optimizers/lazy_adam.py index ad8570bc3c..b09e4e96ad 100644 --- a/tensorflow_addons/optimizers/lazy_adam.py +++ b/tensorflow_addons/optimizers/lazy_adam.py @@ -149,3 +149,6 @@ def _resource_scatter_operate(self, resource, indices, update, resource_scatter_ } return resource_scatter_op(**resource_update_kwargs) + + def get_config(self): + return super().get_config() diff --git a/tensorflow_addons/rnn/abstract_rnn_cell.py b/tensorflow_addons/rnn/abstract_rnn_cell.py new file mode 100644 index 0000000000..de5225bf76 --- /dev/null +++ b/tensorflow_addons/rnn/abstract_rnn_cell.py @@ -0,0 +1,133 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for RNN cells. + +Adapted from legacy github.com/keras-team/tf-keras. +""" + +import tensorflow as tf + + +def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype): + if inputs is not None: + batch_size = tf.shape(inputs)[0] + dtype = inputs.dtype + return _generate_zero_filled_state(batch_size, cell.state_size, dtype) + + +def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): + """Generate a zero filled tensor with shape [batch_size, state_size].""" + if batch_size_tensor is None or dtype is None: + raise ValueError( + "batch_size and dtype cannot be None while constructing initial state: " + "batch_size={}, dtype={}".format(batch_size_tensor, dtype) + ) + + def create_zeros(unnested_state_size): + flat_dims = tf.TensorShape(unnested_state_size).as_list() + init_state_size = [batch_size_tensor] + flat_dims + return tf.zeros(init_state_size, dtype=dtype) + + if tf.nest.is_nested(state_size): + return tf.nest.map_structure(create_zeros, state_size) + else: + return create_zeros(state_size) + + +class AbstractRNNCell(tf.keras.layers.Layer): + """Abstract object representing an RNN cell. + + This is a base class for implementing RNN cells with custom behavior. + + Every `RNNCell` must have the properties below and implement `call` with + the signature `(output, next_state) = call(input, state)`. + + Examples: + + ```python + class MinimalRNNCell(AbstractRNNCell): + + def __init__(self, units, **kwargs): + self.units = units + super(MinimalRNNCell, self).__init__(**kwargs) + + @property + def state_size(self): + return self.units + + def build(self, input_shape): + self.kernel = self.add_weight(shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.built = True + + def call(self, inputs, states): + prev_output = states[0] + h = backend.dot(inputs, self.kernel) + output = h + backend.dot(prev_output, self.recurrent_kernel) + return output, output + ``` + + This definition of cell differs from the definition used in the literature. + In the literature, 'cell' refers to an object with a single scalar output. + This definition refers to a horizontal array of such units. + + An RNN cell, in the most abstract setting, is anything that has + a state and performs some operation that takes a matrix of inputs. + This operation results in an output matrix with `self.output_size` columns. + If `self.state_size` is an integer, this operation also results in a new + state matrix with `self.state_size` columns. If `self.state_size` is a + (possibly nested tuple of) TensorShape object(s), then it should return a + matching structure of Tensors having shape `[batch_size].concatenate(s)` + for each `s` in `self.batch_size`. + """ + + def call(self, inputs, states): + """The function that contains the logic for one RNN step calculation. + + Args: + inputs: the input tensor, which is a slide from the overall RNN input by + the time dimension (usually the second dimension). + states: the state tensor from previous step, which has the same shape + as `(batch, state_size)`. In the case of timestep 0, it will be the + initial state user specified, or zero filled tensor otherwise. + + Returns: + A tuple of two tensors: + 1. output tensor for the current timestep, with size `output_size`. + 2. state tensor for next step, which has the shape of `state_size`. + """ + raise NotImplementedError("Abstract method") + + @property + def state_size(self): + """size(s) of state(s) used by this cell. + + It can be represented by an Integer, a TensorShape or a tuple of Integers + or TensorShapes. + """ + raise NotImplementedError("Abstract method") + + @property + def output_size(self): + """Integer or TensorShape: size of outputs produced by this cell.""" + raise NotImplementedError("Abstract method") + + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) diff --git a/tensorflow_addons/rnn/esn_cell.py b/tensorflow_addons/rnn/esn_cell.py index 835da96e98..2147c07de8 100644 --- a/tensorflow_addons/rnn/esn_cell.py +++ b/tensorflow_addons/rnn/esn_cell.py @@ -15,9 +15,9 @@ """Implements ESN Cell.""" import tensorflow as tf -import tensorflow.keras as keras from typeguard import typechecked +from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell from tensorflow_addons.utils.types import ( Activation, Initializer, @@ -25,7 +25,7 @@ @tf.keras.utils.register_keras_serializable(package="Addons") -class ESNCell(keras.layers.AbstractRNNCell): +class ESNCell(AbstractRNNCell): """Echo State recurrent Network (ESN) cell. This implements the recurrent cell from the paper: H. Jaeger diff --git a/tensorflow_addons/rnn/nas_cell.py b/tensorflow_addons/rnn/nas_cell.py index ce6ca766ce..f5304d1c12 100644 --- a/tensorflow_addons/rnn/nas_cell.py +++ b/tensorflow_addons/rnn/nas_cell.py @@ -15,9 +15,9 @@ """Implements NAS Cell.""" import tensorflow as tf -import tensorflow.keras as keras from typeguard import typechecked +from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell from tensorflow_addons.utils.types import ( FloatTensorLike, TensorLike, @@ -27,7 +27,7 @@ @tf.keras.utils.register_keras_serializable(package="Addons") -class NASCell(keras.layers.AbstractRNNCell): +class NASCell(AbstractRNNCell): """Neural Architecture Search (NAS) recurrent network cell. This implements the recurrent cell from the paper: diff --git a/tensorflow_addons/seq2seq/BUILD b/tensorflow_addons/seq2seq/BUILD index 0674740e58..8f7b8470b3 100644 --- a/tensorflow_addons/seq2seq/BUILD +++ b/tensorflow_addons/seq2seq/BUILD @@ -10,6 +10,7 @@ py_library( "//tensorflow_addons/custom_ops/seq2seq:_beam_search_ops.so", ], deps = [ + "//tensorflow_addons/rnn", "//tensorflow_addons/testing", "//tensorflow_addons/utils", ], diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py index b1b6f93f2d..d44cbe6a47 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper.py +++ b/tensorflow_addons/seq2seq/attention_wrapper.py @@ -23,6 +23,7 @@ import tensorflow as tf +from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell from tensorflow_addons.utils import keras_utils from tensorflow_addons.utils.types import ( AcceptableDTypes, @@ -1577,7 +1578,7 @@ def _compute_attention( return attention, alignments, next_attention_state -class AttentionWrapper(tf.keras.layers.AbstractRNNCell): +class AttentionWrapper(AbstractRNNCell): """Wraps another RNN cell with attention. Example: diff --git a/tensorflow_addons/text/BUILD b/tensorflow_addons/text/BUILD index ae4005d391..79afb5637f 100644 --- a/tensorflow_addons/text/BUILD +++ b/tensorflow_addons/text/BUILD @@ -7,17 +7,15 @@ package(default_visibility = ["//visibility:public"]) py_library( name = "text", srcs = glob(["*.py"]), - data = select({ - "//tensorflow_addons:windows": [ - "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so", - "//tensorflow_addons/testing", - "//tensorflow_addons/utils", - ], + data = [ + "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so", + "//tensorflow_addons/rnn", + "//tensorflow_addons/testing", + "//tensorflow_addons/utils", + ] + select({ + "//tensorflow_addons:windows": [], "//conditions:default": [ "//tensorflow_addons/custom_ops/text:_parse_time_op.so", - "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so", - "//tensorflow_addons/testing", - "//tensorflow_addons/utils", ], }), ) diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index 3820b08a94..287481e546 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -17,6 +17,7 @@ import numpy as np import tensorflow as tf +from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell from tensorflow_addons.utils.types import TensorLike from typeguard import typechecked from typing import Optional, Tuple @@ -403,7 +404,7 @@ def viterbi_decode(score: TensorLike, transition_params: TensorLike) -> tf.Tenso return viterbi, viterbi_score -class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell): +class CrfDecodeForwardRnnCell(AbstractRNNCell): """Computes the forward decoding in a linear-chain CRF.""" @typechecked diff --git a/tensorflow_addons/utils/test_utils.py b/tensorflow_addons/utils/test_utils.py index f998fb4a45..31a43a5536 100644 --- a/tensorflow_addons/utils/test_utils.py +++ b/tensorflow_addons/utils/test_utils.py @@ -22,18 +22,10 @@ import pytest import tensorflow as tf -from packaging.version import Version from tensorflow_addons import options from tensorflow_addons.utils import resource_loader -if Version(tf.__version__).release >= Version("2.13").release: - # New versions of Keras require importing from `keras.src` when - # importing internal symbols. - from keras.src.testing_infra.test_utils import layer_test # noqa: F401 -elif Version(tf.__version__) >= Version("2.9"): - from keras.testing_infra.test_utils import layer_test # noqa: F401 -else: - from keras.testing_utils import layer_test # noqa: F401 +from tensorflow_addons.utils.tf_test_utils import layer_test # noqa NUMBER_OF_WORKERS = int(os.environ.get("PYTEST_XDIST_WORKER_COUNT", "1")) WORKER_ID = int(os.environ.get("PYTEST_XDIST_WORKER", "gw0")[2]) diff --git a/tensorflow_addons/utils/tf_inspect.py b/tensorflow_addons/utils/tf_inspect.py new file mode 100644 index 0000000000..8ca132091a --- /dev/null +++ b/tensorflow_addons/utils/tf_inspect.py @@ -0,0 +1,282 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFDecorator-aware replacements for the inspect module.""" +import collections +import functools +import inspect as _inspect + +import tensorflow as tf + +if hasattr(_inspect, "ArgSpec"): + ArgSpec = _inspect.ArgSpec +else: + ArgSpec = collections.namedtuple( + "ArgSpec", + [ + "args", + "varargs", + "keywords", + "defaults", + ], + ) + +if hasattr(_inspect, "FullArgSpec"): + FullArgSpec = _inspect.FullArgSpec +else: + FullArgSpec = collections.namedtuple( + "FullArgSpec", + [ + "args", + "varargs", + "varkw", + "defaults", + "kwonlyargs", + "kwonlydefaults", + "annotations", + ], + ) + + +def _convert_maybe_argspec_to_fullargspec(argspec): + if isinstance(argspec, FullArgSpec): + return argspec + return FullArgSpec( + args=argspec.args, + varargs=argspec.varargs, + varkw=argspec.keywords, + defaults=argspec.defaults, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ) + + +if hasattr(_inspect, "getfullargspec"): + _getfullargspec = _inspect.getfullargspec + + def _getargspec(target): + """A python3 version of getargspec. + + Calls `getfullargspec` and assigns args, varargs, + varkw, and defaults to a python 2/3 compatible `ArgSpec`. + + The parameter name 'varkw' is changed to 'keywords' to fit the + `ArgSpec` struct. + + Args: + target: the target object to inspect. + + Returns: + An ArgSpec with args, varargs, keywords, and defaults parameters + from FullArgSpec. + """ + fullargspecs = getfullargspec(target) + argspecs = ArgSpec( + args=fullargspecs.args, + varargs=fullargspecs.varargs, + keywords=fullargspecs.varkw, + defaults=fullargspecs.defaults, + ) + return argspecs + +else: + _getargspec = _inspect.getargspec + + def _getfullargspec(target): + """A python2 version of getfullargspec. + + Args: + target: the target object to inspect. + + Returns: + A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations. + """ + return _convert_maybe_argspec_to_fullargspec(getargspec(target)) + + +def currentframe(): + """TFDecorator-aware replacement for inspect.currentframe.""" + return _inspect.stack()[1][0] + + +def getargspec(obj): + """TFDecorator-aware replacement for `inspect.getargspec`. + + Note: `getfullargspec` is recommended as the python 2/3 compatible + replacement for this function. + + Args: + obj: A function, partial function, or callable object, possibly decorated. + + Returns: + The `ArgSpec` that describes the signature of the outermost decorator that + changes the callable's signature, or the `ArgSpec` that describes + the object if not decorated. + + Raises: + ValueError: When callable's signature can not be expressed with + ArgSpec. + TypeError: For objects of unsupported types. + """ + if isinstance(obj, functools.partial): + return _get_argspec_for_partial(obj) + + decorators, target = tf.__internal__.decorator.unwrap(obj) + + spec = next( + (d.decorator_argspec for d in decorators if d.decorator_argspec is not None), + None, + ) + if spec: + return spec + + try: + # Python3 will handle most callables here (not partial). + return _getargspec(target) + except TypeError: + pass + + if isinstance(target, type): + try: + return _getargspec(target.__init__) + except TypeError: + pass + + try: + return _getargspec(target.__new__) + except TypeError: + pass + + # The `type(target)` ensures that if a class is received we don't return + # the signature of its __call__ method. + return _getargspec(type(target).__call__) + + +def _get_argspec_for_partial(obj): + """Implements `getargspec` for `functools.partial` objects. + + Args: + obj: The `functools.partial` object + Returns: + An `inspect.ArgSpec` + Raises: + ValueError: When callable's signature can not be expressed with + ArgSpec. + """ + # When callable is a functools.partial object, we construct its ArgSpec with + # following strategy: + # - If callable partial contains default value for positional arguments (ie. + # object.args), then final ArgSpec doesn't contain those positional + # arguments. + # - If callable partial contains default value for keyword arguments (ie. + # object.keywords), then we merge them with wrapped target. Default values + # from callable partial takes precedence over those from wrapped target. + # + # However, there is a case where it is impossible to construct a valid + # ArgSpec. Python requires arguments that have no default values must be + # defined before those with default values. ArgSpec structure is only valid + # when this presumption holds true because default values are expressed as a + # tuple of values without keywords and they are always assumed to belong to + # last K arguments where K is number of default values present. + # + # Since functools.partial can give default value to any argument, this + # presumption may no longer hold in some cases. For example: + # + # def func(m, n): + # return 2 * m + n + # partialed = functools.partial(func, m=1) + # + # This example will result in m having a default value but n doesn't. This + # is usually not allowed in Python and can not be expressed in ArgSpec + # correctly. + # + # Thus, we must detect cases like this by finding first argument with + # default value and ensures all following arguments also have default + # values. When this is not true, a ValueError is raised. + + n_prune_args = len(obj.args) + partial_keywords = obj.keywords or {} + + args, varargs, keywords, defaults = getargspec(obj.func) + + # Pruning first n_prune_args arguments. + args = args[n_prune_args:] + + # Partial function may give default value to any argument, therefore length + # of default value list must be len(args) to allow each argument to + # potentially be given a default value. + no_default = object() + all_defaults = [no_default] * len(args) + + if defaults: + all_defaults[-len(defaults) :] = defaults + + # Fill in default values provided by partial function in all_defaults. + for kw, default in partial_keywords.items(): + if kw in args: + idx = args.index(kw) + all_defaults[idx] = default + elif not keywords: + raise ValueError( + "Function does not have **kwargs parameter, but " + "contains an unknown partial keyword." + ) + + # Find first argument with default value set. + first_default = next( + (idx for idx, x in enumerate(all_defaults) if x is not no_default), None + ) + + # If no default values are found, return ArgSpec with defaults=None. + if first_default is None: + return ArgSpec(args, varargs, keywords, None) + + # Checks if all arguments have default value set after first one. + invalid_default_values = [ + args[i] + for i, j in enumerate(all_defaults) + if j is no_default and i > first_default + ] + + if invalid_default_values: + raise ValueError( + f"Some arguments {invalid_default_values} do not have " + "default value, but they are positioned after those with " + "default values. This can not be expressed with ArgSpec." + ) + + return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:])) + + +def getfullargspec(obj): + """TFDecorator-aware replacement for `inspect.getfullargspec`. + + This wrapper emulates `inspect.getfullargspec` in[^)]* Python2. + + Args: + obj: A callable, possibly decorated. + + Returns: + The `FullArgSpec` that describes the signature of + the outermost decorator that changes the callable's signature. If the + callable is not decorated, `inspect.getfullargspec()` will be called + directly on the callable. + """ + decorators, target = tf.__internal__.decorator.unwrap(obj) + + for d in decorators: + if d.decorator_argspec is not None: + return _convert_maybe_argspec_to_fullargspec(d.decorator_argspec) + return _getfullargspec(target) diff --git a/tensorflow_addons/utils/tf_test_utils.py b/tensorflow_addons/utils/tf_test_utils.py new file mode 100644 index 0000000000..01b88e4807 --- /dev/null +++ b/tensorflow_addons/utils/tf_test_utils.py @@ -0,0 +1,294 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for unit-testing TF-Keras.""" + + +import threading + +import numpy as np +import tensorflow as tf + +from tensorflow.keras import backend +from tensorflow.keras import layers +from tensorflow.keras import models +from tensorflow_addons.utils import tf_inspect + + +def string_test(actual, expected): + np.testing.assert_array_equal(actual, expected) + + +def numeric_test(actual, expected): + np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-6) + + +def layer_test( + layer_cls, + kwargs=None, + input_shape=None, + input_dtype=None, + input_data=None, + expected_output=None, + expected_output_dtype=None, + expected_output_shape=None, + validate_training=True, + adapt_data=None, + custom_objects=None, + test_harness=None, + supports_masking=None, +): + """Test routine for a layer with a single input and single output. + + Args: + layer_cls: Layer class object. + kwargs: Optional dictionary of keyword arguments for instantiating the + layer. + input_shape: Input shape tuple. + input_dtype: Data type of the input data. + input_data: Numpy array of input data. + expected_output: Numpy array of the expected output. + expected_output_dtype: Data type expected for the output. + expected_output_shape: Shape tuple for the expected shape of the output. + validate_training: Whether to attempt to validate training on this layer. + This might be set to False for non-differentiable layers that output + string or integer values. + adapt_data: Optional data for an 'adapt' call. If None, adapt() will not + be tested for this layer. This is only relevant for PreprocessingLayers. + custom_objects: Optional dictionary mapping name strings to custom objects + in the layer class. This is helpful for testing custom layers. + test_harness: The Tensorflow test, if any, that this function is being + called in. + supports_masking: Optional boolean to check the `supports_masking` + property of the layer. If None, the check will not be performed. + + Returns: + The output data (Numpy array) returned by the layer, for additional + checks to be done by the calling code. + + Raises: + ValueError: if `input_shape is None`. + """ + if input_data is None: + if input_shape is None: + raise ValueError("input_shape is None") + if not input_dtype: + input_dtype = "float32" + input_data_shape = list(input_shape) + for i, e in enumerate(input_data_shape): + if e is None: + input_data_shape[i] = np.random.randint(1, 4) + input_data = 10 * np.random.random(input_data_shape) + if input_dtype[:5] == "float": + input_data -= 0.5 + input_data = input_data.astype(input_dtype) + elif input_shape is None: + input_shape = input_data.shape + if input_dtype is None: + input_dtype = input_data.dtype + if expected_output_dtype is None: + expected_output_dtype = input_dtype + + if tf.as_dtype(expected_output_dtype) == tf.string: + if test_harness: + assert_equal = test_harness.assertAllEqual + else: + assert_equal = string_test + else: + if test_harness: + assert_equal = test_harness.assertAllClose + else: + assert_equal = numeric_test + + # instantiation + kwargs = kwargs or {} + layer = layer_cls(**kwargs) + + if supports_masking is not None and layer.supports_masking != supports_masking: + raise AssertionError( + "When testing layer %s, the `supports_masking` property is %r" + "but expected to be %r.\nFull kwargs: %s" + % ( + layer_cls.__name__, + layer.supports_masking, + supports_masking, + kwargs, + ) + ) + + # Test adapt, if data was passed. + if adapt_data is not None: + layer.adapt(adapt_data) + + # test get_weights , set_weights at layer level + weights = layer.get_weights() + layer.set_weights(weights) + + # test and instantiation from weights + if "weights" in tf_inspect.getargspec(layer_cls.__init__): + kwargs["weights"] = weights + layer = layer_cls(**kwargs) + + # test in functional API + x = layers.Input(shape=input_shape[1:], dtype=input_dtype) + y = layer(x) + if backend.dtype(y) != expected_output_dtype: + raise AssertionError( + "When testing layer %s, for input %s, found output " + "dtype=%s but expected to find %s.\nFull kwargs: %s" + % ( + layer_cls.__name__, + x, + backend.dtype(y), + expected_output_dtype, + kwargs, + ) + ) + + def assert_shapes_equal(expected, actual): + """Asserts that the output shape from the layer matches the actual + shape.""" + if len(expected) != len(actual): + raise AssertionError( + "When testing layer %s, for input %s, found output_shape=" + "%s but expected to find %s.\nFull kwargs: %s" + % (layer_cls.__name__, x, actual, expected, kwargs) + ) + + for expected_dim, actual_dim in zip(expected, actual): + if isinstance(expected_dim, tf.compat.v1.Dimension): + expected_dim = expected_dim.value + if isinstance(actual_dim, tf.compat.v1.Dimension): + actual_dim = actual_dim.value + if expected_dim is not None and expected_dim != actual_dim: + raise AssertionError( + "When testing layer %s, for input %s, found output_shape=" + "%s but expected to find %s.\nFull kwargs: %s" + % (layer_cls.__name__, x, actual, expected, kwargs) + ) + + if expected_output_shape is not None: + assert_shapes_equal(tf.TensorShape(expected_output_shape), y.shape) + + # check shape inference + model = models.Model(x, y) + computed_output_shape = tuple( + layer.compute_output_shape(tf.TensorShape(input_shape)).as_list() + ) + computed_output_signature = layer.compute_output_signature( + tf.TensorSpec(shape=input_shape, dtype=input_dtype) + ) + actual_output = model.predict(input_data) + actual_output_shape = actual_output.shape + assert_shapes_equal(computed_output_shape, actual_output_shape) + assert_shapes_equal(computed_output_signature.shape, actual_output_shape) + if computed_output_signature.dtype != actual_output.dtype: + raise AssertionError( + "When testing layer %s, for input %s, found output_dtype=" + "%s but expected to find %s.\nFull kwargs: %s" + % ( + layer_cls.__name__, + x, + actual_output.dtype, + computed_output_signature.dtype, + kwargs, + ) + ) + if expected_output is not None: + assert_equal(actual_output, expected_output) + + # test serialization, weight setting at model level + model_config = model.get_config() + recovered_model = models.Model.from_config(model_config, custom_objects) + if model.weights: + weights = model.get_weights() + recovered_model.set_weights(weights) + output = recovered_model.predict(input_data) + assert_equal(output, actual_output) + + # test training mode (e.g. useful for dropout tests) + # Rebuild the model to avoid the graph being reused between predict() and + # See b/120160788 for more details. This should be mitigated after 2.0. + layer_weights = layer.get_weights() # Get the layer weights BEFORE training. + if validate_training: + model = models.Model(x, layer(x)) + if _thread_local_data.run_eagerly is not None: + model.compile( + "rmsprop", + "mse", + weighted_metrics=["acc"], + run_eagerly=should_run_eagerly(), + ) + else: + model.compile("rmsprop", "mse", weighted_metrics=["acc"]) + model.train_on_batch(input_data, actual_output) + + # test as first layer in Sequential API + layer_config = layer.get_config() + layer_config["batch_input_shape"] = input_shape + layer = layer.__class__.from_config(layer_config) + + # Test adapt, if data was passed. + if adapt_data is not None: + layer.adapt(adapt_data) + + model = models.Sequential() + model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype)) + model.add(layer) + + layer.set_weights(layer_weights) + actual_output = model.predict(input_data) + actual_output_shape = actual_output.shape + for expected_dim, actual_dim in zip(computed_output_shape, actual_output_shape): + if expected_dim is not None: + if expected_dim != actual_dim: + raise AssertionError( + "When testing layer %s **after deserialization**, " + "for input %s, found output_shape=" + "%s but expected to find inferred shape %s.\n" + "Full kwargs: %s" + % ( + layer_cls.__name__, + x, + actual_output_shape, + computed_output_shape, + kwargs, + ) + ) + if expected_output is not None: + assert_equal(actual_output, expected_output) + + # test serialization, weight setting at model level + model_config = model.get_config() + recovered_model = models.Sequential.from_config(model_config, custom_objects) + if model.weights: + weights = model.get_weights() + recovered_model.set_weights(weights) + output = recovered_model.predict(input_data) + assert_equal(output, actual_output) + + # for further checks in the caller function + return actual_output + + +_thread_local_data = threading.local() +_thread_local_data.model_type = None +_thread_local_data.run_eagerly = None +_thread_local_data.saved_model_format = None +_thread_local_data.save_kwargs = None + + +def should_run_eagerly(): + """Returns whether the models we are testing should be run eagerly.""" + return _thread_local_data.run_eagerly and tf.executing_eagerly() diff --git a/tensorflow_addons/utils/types.py b/tensorflow_addons/utils/types.py index de8da2a5dd..6b8c00e5ea 100644 --- a/tensorflow_addons/utils/types.py +++ b/tensorflow_addons/utils/types.py @@ -22,15 +22,22 @@ from packaging.version import Version -# TODO: Remove once https://github.com/tensorflow/tensorflow/issues/44613 is resolved -if Version(tf.__version__).release >= Version("2.13").release: - # New versions of Keras require importing from `keras.src` when - # importing internal symbols. - from keras.src.engine import keras_tensor +# Find KerasTensor. +if Version(tf.__version__).release >= Version("2.16").release: + # Determine if loading keras 2 or 3. + if ( + hasattr(tf.keras, "version") + and Version(tf.keras.version()).release >= Version("3.0").release + ): + from keras import KerasTensor + else: + from tf_keras.src.engine.keras_tensor import KerasTensor +elif Version(tf.__version__).release >= Version("2.13").release: + from keras.src.engine.keras_tensor import KerasTensor elif Version(tf.__version__).release >= Version("2.5").release: - from keras.engine import keras_tensor + from keras.engine.keras_tensor import KerasTensor else: - from tensorflow.python.keras.engine import keras_tensor + from tensorflow.python.keras.engine.keras_tensor import KerasTensor Number = Union[ @@ -68,7 +75,7 @@ tf.Tensor, tf.SparseTensor, tf.Variable, - keras_tensor.KerasTensor, + KerasTensor, ] FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64] AcceptableDTypes = Union[tf.DType, np.dtype, type, int, str, None]