
Fix 29807 sinusoidal positional encodings in Flaubert, Informer and XLM (#29904)

* Fix sinusoidal_embeddings in FlaubertModel

* Fix for Informer

* Fix for XLM

* Move sinusoidal emb for XLM

* Move sinusoidal emb for Flaubert

* Small cleanup

* Add comments on tests code copied from

* Add with Distilbert->
hovnatan authored Apr 2, 2024
1 parent 83b26dd commit 416711c
Showing 6 changed files with 40 additions and 8 deletions.
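
For context, the helper touched below builds the classic interleaved sin/cos position table. A minimal standalone sketch of that computation (plain NumPy/PyTorch, not the library code itself):

import numpy as np
import torch

def sinusoidal_table(n_pos, dim):
    # angle[pos, j] = pos / 10000**(2 * (j // 2) / dim); even columns take sin, odd columns take cos
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out = torch.zeros(n_pos, dim)
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    return out

table = sinusoidal_table(512, 64)
print(table[0, :4])  # position 0 -> tensor([0., 1., 0., 1.])

This deterministic table is what the new tests compare the models' position embeddings against.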
8 changes: 5 additions & 3 deletions src/transformers/models/flaubert/modeling_flaubert.py
@@ -58,10 +58,10 @@
# Copied from transformers.models.xlm.modeling_xlm.create_sinusoidal_embeddings
def create_sinusoidal_embeddings(n_pos, dim, out):
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
+    out.requires_grad = False
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
-    out.requires_grad = False


# Copied from transformers.models.xlm.modeling_xlm.get_masks
@@ -370,6 +370,10 @@ def _init_weights(self, module):
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+        if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings:
+            create_sinusoidal_embeddings(
+                self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight
+            )


class FlaubertModel(FlaubertPreTrainedModel):
@@ -407,8 +411,6 @@ def __init__(self, config):  # , dico, is_encoder, with_output):

        # embeddings
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
-        if config.sinusoidal_embeddings:
-            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
        if config.n_langs > 1 and config.use_lang_emb:
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
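The reason the call site moves out of __init__: the transformers weight-initialization pass (post_init()/_init_weights()) runs after __init__ has built the submodules and re-initializes every nn.Embedding, so a sinusoidal table written in __init__ could be overwritten by the random embedding init. A simplified, self-contained illustration of that ordering, reusing sinusoidal_table from the sketch above (toy module, not the actual FlaubertPreTrainedModel code):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, n_pos=16, dim=8):
        super().__init__()
        self.position_embeddings = nn.Embedding(n_pos, dim)
        # Old pattern: fill the sinusoidal table during __init__ ...
        with torch.no_grad():
            self.position_embeddings.weight.copy_(sinusoidal_table(n_pos, dim))
        # ... but the init pass that follows visits every nn.Embedding
        # and silently overwrites the table.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

m = ToyModel()
print(m.position_embeddings.weight[0, :4])  # random values, no longer [0., 1., 0., 1.]

Filling the table inside _init_weights itself, as the hunk above does, makes it the last writer, so the sinusoidal values survive model construction.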
2 changes: 1 addition & 1 deletion src/transformers/models/informer/modeling_informer.py
@@ -890,7 +890,7 @@ def _init_weights(self, module):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
+        elif isinstance(module, nn.Embedding) and not isinstance(module, InformerSinusoidalPositionalEmbedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
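Informer keeps its sinusoidal table in a dedicated module (InformerSinusoidalPositionalEmbedding), so its fix is a guard rather than a move: the generic nn.Embedding branch of _init_weights now skips that module. A minimal sketch of the same guard pattern with toy class names (not the library's):

import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Embedding):
    """Embedding whose weight is precomputed and must not be re-initialized."""

def init_embedding_weights(module, std=0.02):
    # Skip the sinusoidal subclass so its precomputed table survives the init pass.
    if isinstance(module, nn.Embedding) and not isinstance(module, SinusoidalPositionalEmbedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

init_embedding_weights(nn.Embedding(10, 4))                    # re-initialized
init_embedding_weights(SinusoidalPositionalEmbedding(10, 4))   # left untouched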
8 changes: 5 additions & 3 deletions src/transformers/models/xlm/modeling_xlm.py
@@ -59,10 +59,10 @@

def create_sinusoidal_embeddings(n_pos, dim, out):
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
+    out.requires_grad = False
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
-    out.requires_grad = False


def get_masks(slen, lengths, causal, padding_mask=None):
@@ -245,6 +245,10 @@ def _init_weights(self, module):
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+        if isinstance(module, XLMModel) and self.config.sinusoidal_embeddings:
+            create_sinusoidal_embeddings(
+                self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight
+            )


@dataclass
@@ -414,8 +418,6 @@ def __init__(self, config):

        # embeddings
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
-        if config.sinusoidal_embeddings:
-            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
        if config.n_langs > 1 and config.use_lang_emb:
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
9 changes: 9 additions & 0 deletions tests/models/flaubert/test_modeling_flaubert.py
@@ -36,6 +36,7 @@
        FlaubertModel,
        FlaubertWithLMHeadModel,
    )
+    from transformers.models.flaubert.modeling_flaubert import create_sinusoidal_embeddings


class FlaubertModelTester(object):
@@ -431,6 +432,14 @@ def test_flaubert_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_flaubert_model(*config_and_inputs)

+    # Copied from tests/models/distilbert/test_modeling_distilbert.py with Distilbert->Flaubert
+    def test_flaubert_model_with_sinusoidal_encodings(self):
+        config = FlaubertConfig(sinusoidal_embeddings=True)
+        model = FlaubertModel(config=config)
+        sinusoidal_pos_embds = torch.empty((config.max_position_embeddings, config.emb_dim), dtype=torch.float32)
+        create_sinusoidal_embeddings(config.max_position_embeddings, config.emb_dim, sinusoidal_pos_embds)
+        self.model_tester.parent.assertTrue(torch.equal(model.position_embeddings.weight, sinusoidal_pos_embds))

    def test_flaubert_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_flaubert_lm_head(*config_and_inputs)
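Outside the test harness, the new Flaubert check reads as a short standalone script (essentially the test body above; the XLM test further down is identical up to class names):

import torch
from transformers import FlaubertConfig, FlaubertModel
from transformers.models.flaubert.modeling_flaubert import create_sinusoidal_embeddings

config = FlaubertConfig(sinusoidal_embeddings=True)
model = FlaubertModel(config)

expected = torch.empty((config.max_position_embeddings, config.emb_dim), dtype=torch.float32)
create_sinusoidal_embeddings(config.max_position_embeddings, config.emb_dim, expected)

# With this commit applied, the position table stays sinusoidal instead of being
# overwritten by the random embedding init.
assert torch.equal(model.position_embeddings.weight, expected)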
12 changes: 11 additions & 1 deletion tests/models/informer/test_modeling_informer.py
@@ -35,7 +35,11 @@
    import torch

    from transformers import InformerConfig, InformerForPrediction, InformerModel
-    from transformers.models.informer.modeling_informer import InformerDecoder, InformerEncoder
+    from transformers.models.informer.modeling_informer import (
+        InformerDecoder,
+        InformerEncoder,
+        InformerSinusoidalPositionalEmbedding,
+    )


@require_torch
@@ -164,6 +168,12 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict):

        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)

+        embed_positions = InformerSinusoidalPositionalEmbedding(
+            config.context_length + config.prediction_length, config.d_model
+        )
+        self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
+        self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight))

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
9 changes: 9 additions & 0 deletions tests/models/xlm/test_modeling_xlm.py
@@ -36,6 +36,7 @@
        XLMModel,
        XLMWithLMHeadModel,
    )
+    from transformers.models.xlm.modeling_xlm import create_sinusoidal_embeddings


class XLMModelTester:
@@ -432,6 +433,14 @@ def test_xlm_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_model(*config_and_inputs)

+    # Copied from tests/models/distilbert/test_modeling_distilbert.py with Distilbert->XLM
+    def test_xlm_model_with_sinusoidal_encodings(self):
+        config = XLMConfig(sinusoidal_embeddings=True)
+        model = XLMModel(config=config)
+        sinusoidal_pos_embds = torch.empty((config.max_position_embeddings, config.emb_dim), dtype=torch.float32)
+        create_sinusoidal_embeddings(config.max_position_embeddings, config.emb_dim, sinusoidal_pos_embds)
+        self.model_tester.parent.assertTrue(torch.equal(model.position_embeddings.weight, sinusoidal_pos_embds))

    def test_xlm_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)
