Merge remote-tracking branch 'upstream/main' into fix-fa-2-from-config
younesbelkada committed Dec 14, 2023
2 parents a9be74d + 050e0b4 commit 8160f44
Showing 75 changed files with 11,042 additions and 505 deletions.
3 changes: 2 additions & 1 deletion docs/source/en/model_doc/seamless_m4t.md
@@ -15,7 +15,8 @@ specific language governing permissions and limitations under the License.
## Overview

The SeamlessM4T model was proposed in [SeamlessM4T — Massively Multilingual & Multimodal Machine Translation](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) by the Seamless Communication team from Meta AI.
-This is the version 1 release of the model. For the updated version 2 release, refer to the [Seamless M4T v2 docs](./seamless_m4t_v2.md).
+
+This is the **version 1** release of the model. For the updated **version 2** release, refer to the [Seamless M4T v2 docs](https://huggingface.co/docs/transformers/main/model_doc/seamless_m4t_v2).

SeamlessM4T is a collection of models designed to provide high quality translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.

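As a quick illustration of the v1 API this doc page covers, here is a minimal text-to-speech translation sketch; the `facebook/hf-seamless-m4t-medium` checkpoint name is an assumption for illustration, not part of this diff:

```python
from transformers import AutoProcessor, SeamlessM4TModel

# Sketch only: the checkpoint name is assumed, not taken from this diff.
processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium")

# Translate English text into Russian speech.
text_inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")
audio_array = model.generate(**text_inputs, tgt_lang="rus")[0].cpu().numpy().squeeze()
```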
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/seamless_m4t_v2.md
@@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License.

The SeamlessM4T-v2 model was proposed in [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team from Meta AI.

-SeamlessM4T-v2 is a collection of models designed to provide high quality translation, allowing people from different linguistic communities to communicate effortlessly through speech and text. It is an improvement on the [previous version](./seamless_m4t.md). For more details on the differences between v1 and v2, refer to section [Difference with SeamlessM4T-v1](#difference-with-seamlessm4t-v1).
+SeamlessM4T-v2 is a collection of models designed to provide high quality translation, allowing people from different linguistic communities to communicate effortlessly through speech and text. It is an improvement on the [previous version](https://huggingface.co/docs/transformers/main/model_doc/seamless_m4t). For more details on the differences between v1 and v2, refer to section [Difference with SeamlessM4T-v1](#difference-with-seamlessm4t-v1).

SeamlessM4T-v2 enables multiple tasks without relying on separate models:

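For comparison, a hedged sketch of the v2 API, here doing text-to-text translation (the `facebook/seamless-m4t-v2-large` checkpoint name is assumed, not taken from this diff):

```python
from transformers import AutoProcessor, SeamlessM4Tv2Model

# Sketch only: the checkpoint name is assumed, not taken from this diff.
processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large")

# Text-to-text translation: generate_speech=False returns token ids instead of a waveform.
text_inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")
output_tokens = model.generate(**text_inputs, tgt_lang="fra", generate_speech=False)
translated_text = processor.decode(output_tokens[0].tolist()[0], skip_special_tokens=True)
```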
43 changes: 25 additions & 18 deletions src/transformers/modeling_tf_utils.py
@@ -35,7 +35,6 @@
from huggingface_hub import Repository, list_repo_files
from keras import backend as K
from packaging.version import parse
-from tensorflow.python.util.keras_deps import get_call_context_function

from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
@@ -1122,6 +1121,10 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        )
        return dummies

+    def build_in_name_scope(self):
+        with tf.name_scope(self.name):
+            self.build(input_shape=None)

    @property
    def framework(self) -> str:
        """
@@ -1130,15 +1133,7 @@ def framework(self) -> str:
return "tf"

def build(self, input_shape=None):
call_context = get_call_context_function()
if self.built or call_context().in_call:
self.built = True
else:
self.built = True
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
# Setting it in build() allows users to override the shape when loading a non-pretrained model from config
self._set_save_spec(self.input_signature)
self(self.dummy_inputs, training=False)
pass # This is just here to make sure we don't call the superclass build()

def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
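The change above removes the old implicit build (which ran a dummy forward pass and set a serving spec) in favor of an explicit wrapper: `build()` becomes a no-op so Keras never builds implicitly, while `build_in_name_scope()` runs `build()` under `tf.name_scope(self.name)` so weights keep the prefixed names that checkpoint loading matches against. A standalone sketch of the effect, using a toy layer that is not part of this diff:

```python
import tensorflow as tf

# Toy layer for illustration only: weights created while tf.name_scope(self.name)
# is active pick up that prefix in their variable names.
class Toy(tf.keras.layers.Layer):
    def build(self, input_shape=None):
        self.w = self.add_weight("w", shape=(1,))
        super().build(input_shape)

    def build_in_name_scope(self):
        with tf.name_scope(self.name):
            self.build(input_shape=None)

toy = Toy(name="toy")
toy.build_in_name_scope()
print(toy.w.name)  # expected "toy/w:0" rather than a bare "w:0"
```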
@@ -1869,7 +1864,7 @@ def set_input_embeddings(self, value):
            main_layer.set_input_embeddings(value)
        except AttributeError:
            logger.info("Building the model")
-            self.build()
+            self.build_in_name_scope()
            main_layer.set_input_embeddings(value)

    def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1886,7 +1881,7 @@ def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
                return lm_head.get_output_embeddings()
            except AttributeError:
                logger.info("Building the model")
-                self.build()
+                self.build_in_name_scope()

                return lm_head().get_output_embeddings()

@@ -1906,7 +1901,7 @@ def set_output_embeddings(self, value):
                lm_head.set_output_embeddings(value)
            except AttributeError:
                logger.info("Building the model")
-                self.build()
+                self.build_in_name_scope()
                lm_head.set_output_embeddings(value)

    def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1944,7 +1939,7 @@ def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
            try:
                return lm_head.get_bias()
            except AttributeError:
-                self.build()
+                self.build_in_name_scope()

                return lm_head.get_bias()
        return None
@@ -1962,7 +1957,7 @@ def set_bias(self, value):
            try:
                lm_head.set_bias(value)
            except AttributeError:
-                self.build()
+                self.build_in_name_scope()
                lm_head.set_bias(value)

    def get_lm_head(self) -> tf.keras.layers.Layer:
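All of these accessors share one pattern: weights created in `build()` do not exist on an unbuilt model, so the `AttributeError` path builds the model under its name scope and retries. From the user side this keeps head accessors working on a model created from a config alone; a sketch, with the model class chosen purely as an example:

```python
from transformers import BertConfig, TFBertForMaskedLM

# Example only: a model instantiated from a config has no variables yet.
config = BertConfig()
model = TFBertForMaskedLM(config)

# The first access triggers build_in_name_scope() internally, then retries.
bias = model.get_bias()
```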
@@ -2049,7 +2044,7 @@ def _get_word_embedding_weight(model, embedding_layer):
    # The reason why the attributes don't exist might be
    # because the model is not built, so retry getting
    # the argument after building the model
-    model.build()
+    model.build_in_name_scope()

    embeds = getattr(embedding_layer, "weight", None)
    if embeds is not None:
@@ -2914,9 +2909,9 @@ def from_pretrained(
        # we might need to extend the variable scope for composite models
        if load_weight_prefix is not None:
            with tf.compat.v1.variable_scope(load_weight_prefix):
-                model.build()  # build the network with dummy inputs
+                model.build_in_name_scope()  # build the network with dummy inputs
        else:
-            model.build()  # build the network with dummy inputs
+            model.build_in_name_scope()  # build the network with dummy inputs

        if safetensors_from_pt:
            from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
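The `load_weight_prefix` branch exists for composite models, whose sub-models must create variables under the parent's extended scope so that checkpoint names line up; plain models take the unprefixed `build_in_name_scope()` path. An encoder-decoder wrapper is one example of a model that sets this prefix internally (illustrative usage, assuming the `bert-base-uncased` checkpoint):

```python
from transformers import TFEncoderDecoderModel

# Illustration only: the composite model passes a load_weight_prefix internally
# so both sub-models' variables land under one shared scope.
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)
```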
@@ -3215,6 +3210,9 @@ def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        self.initializer_range = initializer_range

    def build(self, input_shape):
+        if self.built:
+            return
+        self.built = True
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
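The new early-return guard makes `build()` idempotent: now that builds can be triggered from several call sites (head accessors, `from_pretrained`, a real forward pass), a repeated call must not create duplicate variables. A minimal sketch of the pattern with a hypothetical layer:

```python
import tensorflow as tf

# Hypothetical layer showing the guard: weights are created exactly once,
# however many code paths end up calling build().
class Guarded(tf.keras.layers.Layer):
    def __init__(self, nf, nx, **kwargs):
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx

    def build(self, input_shape):
        if self.built:
            return  # later calls are no-ops
        self.built = True
        self.weight = self.add_weight("weight", shape=[self.nx, self.nf])

layer = Guarded(nf=8, nx=4)
layer.build(None)
layer.build(None)  # harmless: no duplicate "weight" variable is created
```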
@@ -3398,6 +3396,7 @@ def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **
        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
+        self.hidden_size = config.hidden_size

    def call(self, inputs, cls_index=None, training=False):
        if not isinstance(inputs, (dict, tuple, list)):
@@ -3450,6 +3449,14 @@ def call(self, inputs, cls_index=None, training=False):

        return output

+    def build(self, input_shape):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "summary", None) is not None:
+            with tf.name_scope("summary"):
+                self.summary.build(self.hidden_size)
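With the dummy forward pass gone, layers can no longer rely on Keras inferring sublayer input shapes on first call, so `TFSequenceSummary` records `config.hidden_size` at init time and builds its `summary` projection explicitly, inside `tf.name_scope("summary")` to preserve the old weight-name prefix. A sketch of the same pattern with a hypothetical head:

```python
import tensorflow as tf

# Hypothetical head showing the explicit-build pattern: the input width is
# recorded in __init__, then the sublayer is built by hand under a name scope.
class SummaryHead(tf.keras.layers.Layer):
    def __init__(self, hidden_size, num_labels, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.summary = tf.keras.layers.Dense(num_labels, name="summary")

    def build(self, input_shape):
        if self.built:
            return
        self.built = True
        with tf.name_scope("summary"):
            # Dense only needs the last input dimension to build its kernel.
            self.summary.build((None, self.hidden_size))

head = SummaryHead(hidden_size=768, num_labels=2)
head.build(None)
```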


def get_initializer(initializer_range: float = 0.02) -> tf.keras.initializers.TruncatedNormal:
"""

0 comments on commit 8160f44
