Skip to content

Commit

Permalink
Allow passing key_value_states for x-attn through MPT Block (#1511)
Browse files Browse the repository at this point in the history
  • Loading branch information
gupta-abhay authored Sep 3, 2024
1 parent 02802c5 commit 4ab483f
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,13 @@ def forward(
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
extra_kwargs['key_value_states'] = key_value_states
if self.fuse_norm_attn_norm:
x, m, attn_weights, past_key_value = self.norm_attn_norm(
x,
Expand Down Expand Up @@ -327,12 +329,14 @@ def forward(
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[tuple[torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
extra_kwargs['key_value_states'] = key_value_states
b, attn_weights, past_key_value = self.attn(
a,
past_key_value=past_key_value,
Expand Down

0 comments on commit 4ab483f

Please sign in to comment.