Skip to content

Commit

Permalink
Fix tf-addons for upcoming keras 3 default.
Browse files Browse the repository at this point in the history
Keras 3.0 will become default in TF 2.16 (and is currently default in
tf-nightly).  This breaks this tf-addons package.  Here we make
minimal changes to return functionality in a backward-compatible way.
  • Loading branch information
cantonios committed Dec 4, 2023
1 parent 9688ebd commit f482838
Show file tree
Hide file tree
Showing 14 changed files with 757 additions and 38 deletions.
6 changes: 3 additions & 3 deletions tensorflow_addons/image/tests/distort_image_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_adjust_random_hue_in_yiq(shape, style, dtype):
y_np = _adjust_hue_in_yiq_np(x_np, delta_h)
y_tf = _adjust_hue_in_yiq_tf(x_np, delta_h)
test_utils.assert_allclose_according_to_type(
y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=0.8
y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=1.1
)


Expand All @@ -121,11 +121,11 @@ def test_invalid_channels_hsv():

def test_adjust_hsv_in_yiq_unknown_shape():
fn = tf.function(distort_image_ops.adjust_hsv_in_yiq).get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float64)
tf.TensorSpec(shape=None, dtype=tf.float32)
)
for shape in (2, 3, 3), (4, 2, 3, 3):
image_np = np.random.rand(*shape) * 255.0
image_tf = tf.constant(image_np)
image_tf = tf.constant(image_np, dtype=tf.float32)
np.testing.assert_allclose(
_adjust_hue_in_yiq_np(image_np, 0), fn(image_tf), rtol=2e-4, atol=1e-4
)
Expand Down
17 changes: 14 additions & 3 deletions tensorflow_addons/optimizers/discriminative_layer_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,20 @@
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typeguard import typechecked

if Version(tf.__version__).release >= Version("2.13").release:
# New versions of Keras require importing from `keras.src` when
# importing internal symbols.
if Version(tf.__version__).release >= Version("2.16").release:
# Determine if loading keras 2 or 3.
if (
hasattr(tf.keras, "version")
and Version(tf.keras.version()).release >= Version("3.0").release
):
# New versions of Keras require importing from `keras.src` when
# importing internal symbols.
from keras.src import backend
from keras.src.utils import tf_utils
else:
from tf_keras.src import backend
from tf_keras.src.utils import tf_utils
elif Version(tf.__version__).release >= Version("2.13").release:
from keras.src import backend
from keras.src.utils import tf_utils
else:
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_addons/optimizers/lazy_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,6 @@ def _resource_scatter_operate(self, resource, indices, update, resource_scatter_
}

return resource_scatter_op(**resource_update_kwargs)

def get_config(self):
return super().get_config()
133 changes: 133 additions & 0 deletions tensorflow_addons/rnn/abstract_rnn_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for RNN cells.
Adapted from legacy github.com/keras-team/tf-keras.
"""

import tensorflow as tf


def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
if inputs is not None:
batch_size = tf.shape(inputs)[0]
dtype = inputs.dtype
return _generate_zero_filled_state(batch_size, cell.state_size, dtype)


def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
"""Generate a zero filled tensor with shape [batch_size, state_size]."""
if batch_size_tensor is None or dtype is None:
raise ValueError(
"batch_size and dtype cannot be None while constructing initial state: "
"batch_size={}, dtype={}".format(batch_size_tensor, dtype)
)

def create_zeros(unnested_state_size):
flat_dims = tf.TensorShape(unnested_state_size).as_list()
init_state_size = [batch_size_tensor] + flat_dims
return tf.zeros(init_state_size, dtype=dtype)

if tf.nest.is_nested(state_size):
return tf.nest.map_structure(create_zeros, state_size)
else:
return create_zeros(state_size)


class AbstractRNNCell(tf.keras.layers.Layer):
"""Abstract object representing an RNN cell.
This is a base class for implementing RNN cells with custom behavior.
Every `RNNCell` must have the properties below and implement `call` with
the signature `(output, next_state) = call(input, state)`.
Examples:
```python
class MinimalRNNCell(AbstractRNNCell):
def __init__(self, units, **kwargs):
self.units = units
super(MinimalRNNCell, self).__init__(**kwargs)
@property
def state_size(self):
return self.units
def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
def call(self, inputs, states):
prev_output = states[0]
h = backend.dot(inputs, self.kernel)
output = h + backend.dot(prev_output, self.recurrent_kernel)
return output, output
```
This definition of cell differs from the definition used in the literature.
In the literature, 'cell' refers to an object with a single scalar output.
This definition refers to a horizontal array of such units.
An RNN cell, in the most abstract setting, is anything that has
a state and performs some operation that takes a matrix of inputs.
This operation results in an output matrix with `self.output_size` columns.
If `self.state_size` is an integer, this operation also results in a new
state matrix with `self.state_size` columns. If `self.state_size` is a
(possibly nested tuple of) TensorShape object(s), then it should return a
matching structure of Tensors having shape `[batch_size].concatenate(s)`
for each `s` in `self.batch_size`.
"""

def call(self, inputs, states):
"""The function that contains the logic for one RNN step calculation.
Args:
inputs: the input tensor, which is a slide from the overall RNN input by
the time dimension (usually the second dimension).
states: the state tensor from previous step, which has the same shape
as `(batch, state_size)`. In the case of timestep 0, it will be the
initial state user specified, or zero filled tensor otherwise.
Returns:
A tuple of two tensors:
1. output tensor for the current timestep, with size `output_size`.
2. state tensor for next step, which has the shape of `state_size`.
"""
raise NotImplementedError("Abstract method")

@property
def state_size(self):
"""size(s) of state(s) used by this cell.
It can be represented by an Integer, a TensorShape or a tuple of Integers
or TensorShapes.
"""
raise NotImplementedError("Abstract method")

@property
def output_size(self):
"""Integer or TensorShape: size of outputs produced by this cell."""
raise NotImplementedError("Abstract method")

def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
4 changes: 2 additions & 2 deletions tensorflow_addons/rnn/esn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
"""Implements ESN Cell."""

import tensorflow as tf
import tensorflow.keras as keras
from typeguard import typechecked

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils.types import (
Activation,
Initializer,
)


@tf.keras.utils.register_keras_serializable(package="Addons")
class ESNCell(keras.layers.AbstractRNNCell):
class ESNCell(AbstractRNNCell):
"""Echo State recurrent Network (ESN) cell.
This implements the recurrent cell from the paper:
H. Jaeger
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_addons/rnn/nas_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
"""Implements NAS Cell."""

import tensorflow as tf
import tensorflow.keras as keras
from typeguard import typechecked

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils.types import (
FloatTensorLike,
TensorLike,
Expand All @@ -27,7 +27,7 @@


@tf.keras.utils.register_keras_serializable(package="Addons")
class NASCell(keras.layers.AbstractRNNCell):
class NASCell(AbstractRNNCell):
"""Neural Architecture Search (NAS) recurrent network cell.
This implements the recurrent cell from the paper:
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/seq2seq/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ py_library(
"//tensorflow_addons/custom_ops/seq2seq:_beam_search_ops.so",
],
deps = [
"//tensorflow_addons/rnn",
"//tensorflow_addons/testing",
"//tensorflow_addons/utils",
],
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_addons/seq2seq/attention_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tensorflow as tf

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.types import (
AcceptableDTypes,
Expand Down Expand Up @@ -1577,7 +1578,7 @@ def _compute_attention(
return attention, alignments, next_attention_state


class AttentionWrapper(tf.keras.layers.AbstractRNNCell):
class AttentionWrapper(AbstractRNNCell):
"""Wraps another RNN cell with attention.
Example:
Expand Down
16 changes: 7 additions & 9 deletions tensorflow_addons/text/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@ package(default_visibility = ["//visibility:public"])
py_library(
name = "text",
srcs = glob(["*.py"]),
data = select({
"//tensorflow_addons:windows": [
"//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
"//tensorflow_addons/testing",
"//tensorflow_addons/utils",
],
data = [
"//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
"//tensorflow_addons/rnn",
"//tensorflow_addons/testing",
"//tensorflow_addons/utils",
] + select({
"//tensorflow_addons:windows": [],
"//conditions:default": [
"//tensorflow_addons/custom_ops/text:_parse_time_op.so",
"//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
"//tensorflow_addons/testing",
"//tensorflow_addons/utils",
],
}),
)
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_addons/text/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import tensorflow as tf

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils.types import TensorLike
from typeguard import typechecked
from typing import Optional, Tuple
Expand Down Expand Up @@ -403,7 +404,7 @@ def viterbi_decode(score: TensorLike, transition_params: TensorLike) -> tf.Tenso
return viterbi, viterbi_score


class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
class CrfDecodeForwardRnnCell(AbstractRNNCell):
"""Computes the forward decoding in a linear-chain CRF."""

@typechecked
Expand Down
10 changes: 1 addition & 9 deletions tensorflow_addons/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,10 @@
import pytest
import tensorflow as tf

from packaging.version import Version
from tensorflow_addons import options
from tensorflow_addons.utils import resource_loader

if Version(tf.__version__).release >= Version("2.13").release:
# New versions of Keras require importing from `keras.src` when
# importing internal symbols.
from keras.src.testing_infra.test_utils import layer_test # noqa: F401
elif Version(tf.__version__) >= Version("2.9"):
from keras.testing_infra.test_utils import layer_test # noqa: F401
else:
from keras.testing_utils import layer_test # noqa: F401
from tensorflow_addons.utils.tf_test_utils import layer_test # noqa

NUMBER_OF_WORKERS = int(os.environ.get("PYTEST_XDIST_WORKER_COUNT", "1"))
WORKER_ID = int(os.environ.get("PYTEST_XDIST_WORKER", "gw0")[2])
Expand Down
Loading

0 comments on commit f482838

Please sign in to comment.