diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 5af5481f0a..82e8e94f74 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -164,11 +164,13 @@ def forward(
         flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
         prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        key_value_states: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
             torch.Tensor, torch.Tensor]]]:
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
+        extra_kwargs['key_value_states'] = key_value_states
         if self.fuse_norm_attn_norm:
             x, m, attn_weights, past_key_value = self.norm_attn_norm(
                 x,
@@ -327,12 +329,14 @@ def forward(
         flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
         prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        key_value_states: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[tuple[torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
+        extra_kwargs['key_value_states'] = key_value_states
         b, attn_weights, past_key_value = self.attn(
             a,
             past_key_value=past_key_value,
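
Not part of the diff above: a minimal sketch of the pattern this change enables, written against plain PyTorch rather than llmfoundry's actual classes. `ToyCrossAttnBlock` and `nn.MultiheadAttention` are illustrative stand-ins (assumptions, not the repo's `MPTBlock` or its attention implementation); the point is only that when `key_value_states` is supplied, attention keys/values come from an external sequence (cross-attention), and otherwise the block falls back to ordinary self-attention over its input.

```python
from typing import Optional

import torch
import torch.nn as nn


class ToyCrossAttnBlock(nn.Module):
    """Illustrative stand-in for an MPT-style block (not llmfoundry code)."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        # nn.MultiheadAttention is a stand-in for the real attention module.
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(
        self,
        x: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        a = self.norm_1(x)
        # Mirror the diff's intent: use the external states as keys/values
        # when provided, otherwise self-attend over the block input.
        kv = key_value_states if key_value_states is not None else a
        b, _ = self.attn(a, kv, kv, need_weights=False)
        return x + b


# Usage: self-attention vs. cross-attention over hypothetical encoder states.
block = ToyCrossAttnBlock(d_model=64, n_heads=4)
decoder_states = torch.randn(2, 8, 64)
encoder_states = torch.randn(2, 16, 64)
self_out = block(decoder_states)
cross_out = block(decoder_states, key_value_states=encoder_states)
```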