Make maybe_resize_bias() protected
hackyon committed Feb 15, 2024
1 parent 12b9e55 commit b7216df
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/models/lxmert/modeling_lxmert.py

@@ -705,7 +705,7 @@ def __init__(self, config, lxmert_model_embedding_weights):
         self.decoder.weight = lxmert_model_embedding_weights
         self.bias = nn.Parameter(torch.zeros(lxmert_model_embedding_weights.size(0)))
 
-    def maybe_resize_bias(self, new_size: int):
+    def _maybe_resize_bias(self, new_size: int):
         if new_size != self.bias.shape[0]:
             self.bias.data = nn.functional.pad(self.bias.data, (0, new_size - self.bias.shape[0]), "constant", 0)
 
@@ -1082,7 +1082,7 @@ def __init__(self, config):
 
     def _tie_weights(self):
         self.cls.predictions.decoder.weight = self.lxmert.embeddings.word_embeddings.weight
-        self.cls.predictions.maybe_resize_bias(self.lxmert.embeddings.word_embeddings.weight.shape[0])
+        self.cls.predictions._maybe_resize_bias(self.lxmert.embeddings.word_embeddings.weight.shape[0])
 
     def resize_num_qa_labels(self, num_labels):
        """
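
For context, here is a minimal, self-contained sketch of the pattern this commit touches: when tied word embeddings are resized, the prediction head's bias must be zero-padded to match the new vocabulary size. TinyLMHead and the surrounding setup are hypothetical stand-ins, not code from the repository; only _maybe_resize_bias mirrors the diff above.

import torch
from torch import nn


class TinyLMHead(nn.Module):
    """Toy stand-in for an LM prediction head with tied weights (illustration only)."""

    def __init__(self, embedding_weights: nn.Parameter):
        super().__init__()
        vocab_size, hidden_size = embedding_weights.shape
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.decoder.weight = embedding_weights  # tie to the input embeddings
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def _maybe_resize_bias(self, new_size: int) -> None:
        # Zero-pad the bias on the right so it matches the (possibly grown)
        # vocabulary size; a no-op when the size is already correct.
        if new_size != self.bias.shape[0]:
            self.bias.data = nn.functional.pad(
                self.bias.data, (0, new_size - self.bias.shape[0]), "constant", 0
            )


embeddings = nn.Embedding(10, 8)       # vocab of 10 tokens, hidden size 8
head = TinyLMHead(embeddings.weight)

# After something like resize_token_embeddings grows the vocab to 12,
# re-tying the weights must also bring the bias up to the new size:
head.decoder.weight = nn.Embedding(12, 8).weight
head._maybe_resize_bias(head.decoder.weight.shape[0])
assert head.bias.shape[0] == 12        # new entries are zero-initialized

The leading underscore in the renamed helper follows the usual Python convention for a protected member: the resize is an internal detail of weight tying, and callers are expected to go through _tie_weights rather than invoke it directly.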
