Skip to content

Commit

Permalink
Fix the encoder-decoder / dual-encoder classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Dec 6, 2023
1 parent 4128761 commit 907e329
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 9 deletions.
1 change: 0 additions & 1 deletion src/transformers/models/convbert/modeling_tf_convbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ def build(self, input_shape=None):
self.dense.build(self.config.hidden_size)



class TFConvBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,3 +674,9 @@ def build(self, input_shape=None):
if getattr(self, "enc_to_dec_proj", None) is not None:
with tf.name_scope(self.enc_to_dec_proj.name):
self.enc_to_dec_proj.build(self.encoder.config.hidden_size)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,9 @@ def build(self, input_shape=None):
if getattr(self, "enc_to_dec_proj", None) is not None:
with tf.name_scope(self.enc_to_dec_proj.name):
self.enc_to_dec_proj.build(self.encoder.config.hidden_size)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
Original file line number Diff line number Diff line change
Expand Up @@ -223,19 +223,23 @@ def __init__(
self.config = config

def build(self, input_shape=None):
    """Create this model's own weights and build its sub-layers.

    Keras calls ``build`` lazily on first use; weights created here (rather
    than in ``__init__``) pick up the correct layer name scopes.

    Args:
        input_shape: Unused; accepted for Keras ``Layer.build`` compatibility.
    """
    # Guard must run ONCE, at the very top: the merged text had a second
    # `if self.built: return` after the weight creation, which returned
    # early and skipped building the projections and sub-models entirely.
    if self.built:
        return
    self.built = True
    # Build in the build() method to make sure the names are right
    initializer = tf.keras.initializers.Constant(self.config.logit_scale_init_value)
    self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale")
    if getattr(self, "visual_projection", None) is not None:
        with tf.name_scope(self.visual_projection.name):
            self.visual_projection.build(self.vision_embed_dim)
    if getattr(self, "text_projection", None) is not None:
        with tf.name_scope(self.text_projection.name):
            self.text_projection.build(self.text_embed_dim)
    # Guard the sub-model builds the same way as the projections, so a
    # partially-initialized instance does not raise AttributeError here.
    if getattr(self, "vision_model", None) is not None:
        with tf.name_scope(self.vision_model.name):
            self.vision_model.build(None)
    if getattr(self, "text_model", None) is not None:
        with tf.name_scope(self.text_model.name):
            self.text_model.build(None)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
Expand Down
1 change: 0 additions & 1 deletion src/transformers/models/xlm/modeling_tf_xlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,6 @@ def build(self, input_shape=None):
with tf.name_scope(layer.name):
layer.build([None, None, self.dim])


def get_input_embeddings(self):
    """Return the embedding layer that maps input token ids to vectors."""
    embedding_layer = self.embeddings
    return embedding_layer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1071,9 +1071,9 @@ def test_encoder_decoder_save_load_from_encoder_decoder(self):

# create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
encoder = TFBertModel(config.encoder)
encoder.build()
encoder.build_in_name_scope()
decoder = TFBertLMHeadModel(config.decoder)
decoder.build()
decoder.build_in_name_scope()

encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,9 +729,9 @@ def test_encoder_decoder_save_load_from_encoder_decoder(self):

# create two random ViT/GPT2 models for vit-gpt2 & initialize weights (+cross_attention weights)
encoder = TFViTModel(config.encoder)
encoder.build()
encoder.build_in_name_scope()
decoder = TFGPT2LMHeadModel(config.decoder)
decoder.build()
decoder.build_in_name_scope()

encoder_decoder_orig = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

Expand Down

0 comments on commit 907e329

Please sign in to comment.