
Commit a3f50d0

fix gpt2 attention output

Cyrilvallez committed Dec 16, 2024
1 parent fe90ec0, commit a3f50d0
Showing 2 changed files with 2 additions and 2 deletions.
@@ -399,7 +399,7 @@ def forward(
             attn_output = cross_attn_outputs[0]
             # residual connection
             hidden_states = residual + attn_output
-            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
+            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
src/transformers/models/gpt2/modeling_gpt2.py (1 addition & 1 deletion)

@@ -409,7 +409,7 @@ def forward(
             attn_output = cross_attn_outputs[0]
             # residual connection
             hidden_states = residual + attn_output
-            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
+            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
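The fix is a one-character slice change, and a small sketch shows why it matters. Assuming (this is an inference, not stated in the diff) that the cross-attention call now returns a 2-tuple `(attn_output, attn_weights)` rather than a longer tuple, slicing from index 2 produces an empty tuple, so the cross-attention weights were silently dropped from the block's outputs; slicing from index 1 keeps them:

```python
# Minimal sketch of the off-by-one this commit fixes. The tuple layout is an
# assumption: the cross-attention call is taken to return
# (attn_output, attn_weights), with strings standing in for the tensors.
cross_attn_outputs = ("attn_output", "cross_attn_weights")
outputs = ("hidden_states",)

# Old slice: starts past the end of the 2-tuple, so nothing is appended
# and the cross-attention weights are silently lost.
old = outputs + cross_attn_outputs[2:]

# Fixed slice: everything after attn_output (i.e. the weights) is kept.
new = outputs + cross_attn_outputs[1:]

print(old)  # ('hidden_states',)
print(new)  # ('hidden_states', 'cross_attn_weights')
```

Python slicing never raises for out-of-range bounds, which is why the old code failed silently instead of crashing: `t[2:]` on a 2-tuple is simply `()`.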
