Esm checkpointing (#26454)
* Fixed in-place operation error in EsmEmbeddings

* Fixed in-place operation error in EsmEmbeddings again

---------

Co-authored-by: Schreiber-Finance <[email protected]>
Amelie-Schreiber and Schreiber-Finance authored Sep 28, 2023
1 parent 5e11d72 commit 4e931a8
Showing 1 changed file with 5 additions and 5 deletions.
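All five changes follow the same pattern: an in-place tensor update is replaced by an out-of-place one, so autograd (and in particular activation recomputation under gradient checkpointing) never sees a saved tensor that was modified after the fact. The snippet below is a minimal sketch of that failure mode and its fix; it is not taken from the patch, and the tensors and ops are illustrative only.

import torch

x = torch.randn(4, requires_grad=True)

# In-place version: sigmoid's backward needs its own output, so editing `y`
# in place invalidates the saved tensor and backward() raises
# "one of the variables needed for gradient computation has been modified
# by an inplace operation".
y = x.sigmoid()
y.masked_fill_(x > 0, 0.0)
try:
    y.sum().backward()
except RuntimeError as err:
    print("in-place op broke autograd:", err)

# Out-of-place version (the pattern this patch switches to): masked_fill
# returns a new tensor, the saved sigmoid output is untouched, and the
# backward pass succeeds.
y = x.sigmoid()
y = y.masked_fill(x > 0, 0.0)
y.sum().backward()
print("out-of-place grad:", x.grad)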
10 changes: 5 additions & 5 deletions src/transformers/models/esm/modeling_esm.py
@@ -214,7 +214,7 @@ def forward(
         # This is analogous to the way that dropout layers scale down outputs during evaluation when not
         # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
         if self.token_dropout:
-            embeddings.masked_fill_((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
             mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
             src_lengths = attention_mask.sum(-1)
             mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
@@ -224,7 +224,7 @@ def forward(
 
         if self.position_embedding_type == "absolute":
             position_embeddings = self.position_embeddings(position_ids)
-            embeddings += position_embeddings
+            embeddings = embeddings + position_embeddings
 
         if self.layer_norm is not None:
             embeddings = self.layer_norm(embeddings)
@@ -399,7 +399,7 @@ def __init__(self, config):
     def forward(self, hidden_states, input_tensor):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states += input_tensor
+        hidden_states = hidden_states + input_tensor
         return hidden_states


@@ -474,7 +474,7 @@ def __init__(self, config):
     def forward(self, hidden_states, input_tensor):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states += input_tensor
+        hidden_states = hidden_states + input_tensor
         return hidden_states


@@ -633,7 +633,7 @@ def custom_forward(*inputs):
 
             hidden_states = layer_outputs[0]
             if use_cache:
-                next_decoder_cache += (layer_outputs[-1],)
+                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
             if self.config.add_cross_attention:
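Since the commit title ties the fix to checkpointing, a rough way to exercise the patched code path is to run a backward pass with gradient checkpointing enabled. This is a sketch under assumptions, not part of the commit: the checkpoint name (facebook/esm2_t6_8M_UR50D) and the toy protein sequence are illustrative choices.

from transformers import AutoTokenizer, EsmForMaskedLM

model_name = "facebook/esm2_t6_8M_UR50D"  # small public ESM-2 checkpoint, illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForMaskedLM.from_pretrained(model_name)

model.gradient_checkpointing_enable()  # recompute activations during the backward pass
model.train()

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
inputs["labels"] = inputs["input_ids"].clone()  # MLM loss over the same tokens

loss = model(**inputs).loss
loss.backward()  # the kind of backward pass that in-place ops could previously trip up
print("loss:", float(loss))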
