From a4f567906d56051cd6161a20d709da662c0874f6 Mon Sep 17 00:00:00 2001
From: Yanan Xie
Date: Fri, 15 Sep 2023 10:30:44 -0700
Subject: [PATCH] Allow MPT models to return attention weights

---
 llmfoundry/models/layers/blocks.py    | 2 ++
 llmfoundry/models/mpt/modeling_mpt.py | 1 +
 2 files changed, 3 insertions(+)

diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index dd208302b8..2c5b5d1c7c 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -91,6 +91,7 @@ def forward(
         attn_bias: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.ByteTensor] = None,
         is_causal: bool = True,
+        output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
@@ -100,6 +101,7 @@ def forward(
             attn_bias=attn_bias,
             attention_mask=attention_mask,
             is_causal=is_causal,
+            needs_weights=output_attentions,
         )
         x = x + self.resid_attn_dropout(b)
         m = x
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 3371c67a0d..6a184ee6dd 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -434,6 +434,7 @@ def forward(
                 attn_bias=attn_bias,
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
+                output_attentions=bool(output_attentions),
             )
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
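
For reference, a minimal usage sketch (not part of the patch) of how the new flag is expected to surface per-layer attention weights through the Hugging Face-style forward API. The checkpoint name, the attn_config layout, and the tensor shapes below are assumptions for illustration, not taken from this patch:

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint; any MPT model built from a tree containing this patch should work.
name = "mosaicml/mpt-7b"

# Assumption: attention weights are only materialized by the plain torch attention
# path, so force attn_impl='torch' (config layout assumed from the MPT config class).
config = AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config["attn_impl"] = "torch"

tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(
    name, config=config, trust_remote_code=True)
model.eval()

inputs = tokenizer("MPT can now return attention weights.", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_attentions=True, return_dict=True)

# With the patch applied, each block forwards needs_weights=output_attentions to
# its attention module, so out.attentions holds one tensor per layer instead of
# None (expected shape per layer: batch x n_heads x seq_len x seq_len).
print(len(out.attentions), out.attentions[0].shape)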