diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1fd545691de60e..1f61fe62a65afa 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1067,8 +1067,8 @@ def __init__(self, config: BlenderbotConfig): super().__init__(config) padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) - + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) self.encoder = BlenderbotEncoder(config, self.shared) self.decoder = BlenderbotDecoder(config, self.shared)