Allow MPT models to return attention weights
lorabit110 committed Sep 15, 2023
1 parent 7ec2fe0 commit a4f5679
Showing 2 changed files with 3 additions and 0 deletions.
llmfoundry/models/layers/blocks.py (2 additions, 0 deletions)
@@ -91,6 +91,7 @@ def forward(
         attn_bias: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.ByteTensor] = None,
         is_causal: bool = True,
+        output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
@@ -100,6 +101,7 @@ def forward(
             attn_bias=attn_bias,
             attention_mask=attention_mask,
             is_causal=is_causal,
+            needs_weights=output_attentions,
         )
         x = x + self.resid_attn_dropout(b)
         m = x
llmfoundry/models/mpt/modeling_mpt.py (1 addition, 0 deletions)
@@ -434,6 +434,7 @@ def forward(
                 attn_bias=attn_bias,
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
+                output_attentions=output_attentions == True,
             )
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
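
With this change, output_attentions is forwarded through each MPTBlock as the attention module's needs_weights argument, so per-layer attention weights can be surfaced to the caller. Below is a minimal usage sketch under stated assumptions: the checkpoint name is only illustrative, the checkpoint's bundled modeling code must include this commit, and the outputs are assumed to follow the usual Hugging Face convention of exposing per-layer weights on outputs.attentions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any MPT checkpoint whose modeling code includes
# this change is assumed to behave the same way.
model_name = "mosaicml/mpt-7b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

inputs = tokenizer("Attention weights, please.", return_tensors="pt")

with torch.no_grad():
    # output_attentions=True is passed block by block as needs_weights=True,
    # so each attention layer computes and returns its softmax weights.
    outputs = model(**inputs, output_attentions=True)

# Assumed Hugging Face convention: one tensor per layer,
# shaped (batch, n_heads, seq_len, seq_len).
print(len(outputs.attentions), outputs.attentions[0].shape)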
