From 13453c22ac9f6f30bcd303c8871fc878fdf477a8 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 6 Dec 2023 20:40:08 +0000 Subject: [PATCH] Fix the encoder-decoder / dual-encoder classes --- .../models/convbert/modeling_tf_convbert.py | 1 - .../encoder_decoder/modeling_tf_encoder_decoder.py | 6 ++++++ .../modeling_tf_vision_encoder_decoder.py | 6 ++++++ .../modeling_tf_vision_text_dual_encoder.py | 10 +++++++--- src/transformers/models/xlm/modeling_tf_xlm.py | 1 - .../test_modeling_tf_encoder_decoder.py | 4 ++-- .../test_modeling_tf_vision_encoder_decoder.py | 4 ++-- 7 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/convbert/modeling_tf_convbert.py b/src/transformers/models/convbert/modeling_tf_convbert.py index f9f8a508a3ba26..4fe8b75d346dac 100644 --- a/src/transformers/models/convbert/modeling_tf_convbert.py +++ b/src/transformers/models/convbert/modeling_tf_convbert.py @@ -458,7 +458,6 @@ def build(self, input_shape=None): self.dense.build(self.config.hidden_size) - class TFConvBertOutput(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index ef3dc769a7b7dc..e9915da5cc9b53 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -674,3 +674,9 @@ def build(self, input_shape=None): if getattr(self, "enc_to_dec_proj", None) is not None: with tf.name_scope(self.enc_to_dec_proj.name): self.enc_to_dec_proj.build(self.encoder.config.hidden_size) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 5d727956c79b37..01c45ce1663dd0 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -725,3 +725,9 @@ def build(self, input_shape=None): if getattr(self, "enc_to_dec_proj", None) is not None: with tf.name_scope(self.enc_to_dec_proj.name): self.enc_to_dec_proj.build(self.encoder.config.hidden_size) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py index aef12fff1449a5..29ba688e928489 100644 --- a/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py @@ -223,19 +223,23 @@ def __init__( self.config = config def build(self, input_shape=None): + if self.built: + return + self.built = True # Build in the build() method to make sure the names are right initializer = tf.keras.initializers.Constant(self.config.logit_scale_init_value) self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale") - if self.built: - return - self.built = True if getattr(self, "visual_projection", None) is not None: with tf.name_scope(self.visual_projection.name): self.visual_projection.build(self.vision_embed_dim) if getattr(self, "text_projection", None) is not None: with tf.name_scope(self.text_projection.name): self.text_projection.build(self.text_embed_dim) + with tf.name_scope(self.vision_model.name): + self.vision_model.build(None) + with tf.name_scope(self.text_model.name): + self.text_model.build(None) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py index f5a262d45629f1..1d3a5e5d29b522 100644 --- a/src/transformers/models/xlm/modeling_tf_xlm.py +++ b/src/transformers/models/xlm/modeling_tf_xlm.py @@ -384,7 +384,6 @@ def build(self, input_shape=None): with tf.name_scope(layer.name): layer.build([None, None, self.dim]) - def get_input_embeddings(self): return self.embeddings diff --git a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py index 48d9a03e578926..c056e16c507a4c 100644 --- a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py @@ -1071,9 +1071,9 @@ def test_encoder_decoder_save_load_from_encoder_decoder(self): # create two random BERT models for bert2bert & initialize weights (+cross_attention weights) encoder = TFBertModel(config.encoder) - encoder.build() + encoder.build_in_name_scope() decoder = TFBertLMHeadModel(config.decoder) - decoder.build() + decoder.build_in_name_scope() encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder) diff --git a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index db38e4a9899298..9d81a476531e0c 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -729,9 +729,9 @@ def test_encoder_decoder_save_load_from_encoder_decoder(self): # create two random ViT/GPT2 models for vit-gpt2 & initialize weights (+cross_attention weights) encoder = TFViTModel(config.encoder) - encoder.build() + encoder.build_in_name_scope() decoder = TFGPT2LMHeadModel(config.decoder) - decoder.build() + decoder.build_in_name_scope() encoder_decoder_orig = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)