Commit
Fix tf-addons for upcoming keras 3 default.
Keras 3.0 will become the default in TF 2.16 (and is currently the default in tf-nightly). This breaks the tf-addons package. This commit makes minimal changes to restore functionality in a backward-compatible way.
Showing 14 changed files with 757 additions and 38 deletions.
@@ -0,0 +1,133 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for RNN cells.

Adapted from legacy github.com/keras-team/tf-keras.
"""

import tensorflow as tf


def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
    if inputs is not None:
        batch_size = tf.shape(inputs)[0]
        dtype = inputs.dtype
    return _generate_zero_filled_state(batch_size, cell.state_size, dtype)


def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
    """Generate a zero filled tensor with shape [batch_size, state_size]."""
    if batch_size_tensor is None or dtype is None:
        raise ValueError(
            "batch_size and dtype cannot be None while constructing initial state: "
            "batch_size={}, dtype={}".format(batch_size_tensor, dtype)
        )

    def create_zeros(unnested_state_size):
        flat_dims = tf.TensorShape(unnested_state_size).as_list()
        init_state_size = [batch_size_tensor] + flat_dims
        return tf.zeros(init_state_size, dtype=dtype)

    if tf.nest.is_nested(state_size):
        return tf.nest.map_structure(create_zeros, state_size)
    else:
        return create_zeros(state_size)


class AbstractRNNCell(tf.keras.layers.Layer):
    """Abstract object representing an RNN cell.

    This is a base class for implementing RNN cells with custom behavior.

    Every `RNNCell` must have the properties below and implement `call` with
    the signature `(output, next_state) = call(input, state)`.

    Examples:

    ```python
      class MinimalRNNCell(AbstractRNNCell):

        def __init__(self, units, **kwargs):
          self.units = units
          super(MinimalRNNCell, self).__init__(**kwargs)

        @property
        def state_size(self):
          return self.units

        def build(self, input_shape):
          self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                        initializer='uniform',
                                        name='kernel')
          self.recurrent_kernel = self.add_weight(
              shape=(self.units, self.units),
              initializer='uniform',
              name='recurrent_kernel')
          self.built = True

        def call(self, inputs, states):
          prev_output = states[0]
          h = tf.matmul(inputs, self.kernel)
          output = h + tf.matmul(prev_output, self.recurrent_kernel)
          return output, output
    ```

    This definition of cell differs from the definition used in the literature.
    In the literature, 'cell' refers to an object with a single scalar output.
    This definition refers to a horizontal array of such units.

    An RNN cell, in the most abstract setting, is anything that has
    a state and performs some operation that takes a matrix of inputs.
    This operation results in an output matrix with `self.output_size` columns.
    If `self.state_size` is an integer, this operation also results in a new
    state matrix with `self.state_size` columns.  If `self.state_size` is a
    (possibly nested tuple of) TensorShape object(s), then it should return a
    matching structure of Tensors having shape `[batch_size].concatenate(s)`
    for each `s` in `self.state_size`.
    """

    def call(self, inputs, states):
        """The function that contains the logic for one RNN step calculation.

        Args:
          inputs: the input tensor, which is a slice from the overall RNN input by
            the time dimension (usually the second dimension).
          states: the state tensor from previous step, which has the same shape
            as `(batch, state_size)`. In the case of timestep 0, it will be the
            initial state user specified, or zero filled tensor otherwise.

        Returns:
          A tuple of two tensors:
            1. output tensor for the current timestep, with size `output_size`.
            2. state tensor for next step, which has the shape of `state_size`.
        """
        raise NotImplementedError("Abstract method")

    @property
    def state_size(self):
        """size(s) of state(s) used by this cell.

        It can be represented by an Integer, a TensorShape or a tuple of Integers
        or TensorShapes.
        """
        raise NotImplementedError("Abstract method")

    @property
    def output_size(self):
        """Integer or TensorShape: size of outputs produced by this cell."""
        raise NotImplementedError("Abstract method")

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
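
To make the cell contract above concrete, here is a minimal, framework-free sketch of how `get_initial_state` and `call` fit together when a sequence is unrolled step by step. It is not part of this commit and does not use TensorFlow: `MinimalCell`, `matmul`, and `unroll` are hypothetical names introduced only for illustration, tensors are modelled as nested Python lists, and only an integer `state_size` is handled. Unlike the docstring example, `call` here returns the next state wrapped in a list so that `states[0]` stays valid in a hand-written loop (an assumption of this sketch; in Keras the enclosing RNN layer handles that wrapping).

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


class MinimalCell:
    """Hypothetical stand-in mirroring the MinimalRNNCell docstring example."""

    def __init__(self, units, kernel, recurrent_kernel):
        self.units = units
        self.kernel = kernel                      # shape: [input_dim, units]
        self.recurrent_kernel = recurrent_kernel  # shape: [units, units]

    @property
    def state_size(self):
        return self.units

    @property
    def output_size(self):
        return self.units

    def get_initial_state(self, batch_size):
        # Zero-filled state of shape [batch_size, state_size].
        if batch_size is None:
            raise ValueError("batch_size cannot be None")
        return [[0.0] * self.state_size for _ in range(batch_size)]

    def call(self, inputs, states):
        # One step: output = inputs @ kernel + prev_output @ recurrent_kernel.
        prev_output = states[0]
        h = matmul(inputs, self.kernel)
        r = matmul(prev_output, self.recurrent_kernel)
        output = [[a + b for a, b in zip(h_row, r_row)] for h_row, r_row in zip(h, r)]
        # Assumption of this sketch: the next state is returned as a list.
        return output, [output]


def unroll(cell, sequence):
    """Run the cell over a time-major sequence: [time][batch][input_dim]."""
    batch_size = len(sequence[0])
    states = [cell.get_initial_state(batch_size)]
    outputs = []
    for step_inputs in sequence:
        output, states = cell.call(step_inputs, states)
        outputs.append(output)
    return outputs, states


cell = MinimalCell(
    units=2,
    kernel=[[1.0, 2.0]],                      # input_dim = 1
    recurrent_kernel=[[1.0, 0.0], [0.0, 1.0]],  # identity, so the state accumulates
)
outputs, final_states = unroll(cell, [[[1.0]], [[1.0]], [[1.0]]])
print(outputs)       # [[[1.0, 2.0]], [[2.0, 4.0]], [[3.0, 6.0]]]
print(final_states)  # [[[3.0, 6.0]]]
```

With an identity recurrent kernel each step adds the projected input to the previous output, which makes it easy to check by hand that the state is threaded through `call` as the docstring describes.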