
Commit

Refactoring attention (#1182)
* refactoring

* adding back a function that got deleted by mistake

* adding co-authors
Co-Authored-By: Vitaliy Chiley <[email protected]>
Co-Authored-By: Cheng Li <[email protected]>

* adding co-authors
Co-Authored-By: Vitaliy Chiley <[email protected]>

* adding co-authors
Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Vitaliy Chiley <[email protected]>

* Update config_utils.py

adding co-authors
Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Cheng Li <[email protected]>
Co-authored-by: Cheng Li <@cli99>

* lint

Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Vitaliy Chiley <[email protected]>

* Adding_co_authors

Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Cheng Li <[email protected]>

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Daniel King <[email protected]>

* addressing comments

* adding_co_authors

Co-authored-by: Cheng Li <[email protected]>

* Update llmfoundry/utils/config_utils.py

---------

Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Cheng Li <[email protected]>
Co-authored-by: Daniel King <[email protected]>
Co-authored-by: Cheng Li <[email protected]>
6 people authored May 8, 2024
1 parent 46b8bee commit ac563e6
Showing 6 changed files with 268 additions and 139 deletions.
228 changes: 140 additions & 88 deletions llmfoundry/models/layers/attention.py
@@ -5,10 +5,9 @@

import math
import warnings
from typing import Any, Optional
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import transformers
from einops import rearrange
from packaging import version
@@ -233,7 +232,6 @@ def flash_attn_fn(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
alibi_slopes: Optional[torch.Tensor] = None,
@@ -506,6 +504,54 @@ def forward(
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
query, key, value = self.get_qkv(x)

if rotary_emb_w_meta_info is not None:
query, key, value = self._apply_rotary_embeddings(
rotary_emb_w_meta_info,
query,
key,
value,
)

extra_attn_kwargs = self.get_implementation_specific_args(
attention_mask,
alibi_slopes,
flash_attn_padding_info,
)

context, attn_weights, past_key_value = self.attn_fn(
query,
key,
value,
n_heads=self.n_heads,
kv_n_heads=self.kv_n_heads,
past_key_value=past_key_value,
softmax_scale=self.softmax_scale,
attn_bias=attn_bias,
is_causal=is_causal,
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
**extra_attn_kwargs,
)

return self.out_proj(context), attn_weights, past_key_value

def get_qkv(
self,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes and returns the query, key, and value tensors.
Args:
x (torch.Tensor): The input tensor.
Returns:
query (torch.Tensor): The query tensor.
key (torch.Tensor): The key tensor.
value (torch.Tensor): The value tensor.
"""
qkv = self.Wqkv(x)

if self.clip_qkv:
@@ -520,8 +566,6 @@ def forward(
dim=2,
)

key_padding_mask = attention_mask

if self.qk_ln or self.qk_gn:
# Applying layernorm to qk
q_shape, k_shape = query.shape, key.shape
@@ -533,97 +577,105 @@ def forward(
query = self.q_ln(query).to(dtype).view(q_shape)
key = self.k_ln(key).to(dtype).view(k_shape)

if rotary_emb_w_meta_info is not None:
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
seq_len = rotary_emb_w_meta_info['seq_len']
offset_info = rotary_emb_w_meta_info['offset_info']
bsz, seqlen = query.shape[:2]
query = query.view(bsz, seqlen, -1, self.head_dim)
key = key.view(bsz, seqlen, -1, self.head_dim)

if rotary_emb_w_meta_info['impl'] == 'dail':
value = value.view(bsz, seqlen, -1, self.head_dim)

kv = torch.stack([key, value], dim=2)
query, kv = rotary_emb(
query,
kv,
seqlen_offset=offset_info,
max_seqlen=seq_len,
return query, key, value

def _apply_rotary_embeddings(
self,
rotary_emb_w_meta_info: Dict[str, Any],
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
seq_len = rotary_emb_w_meta_info['seq_len']
offset_info = rotary_emb_w_meta_info['offset_info']
bsz, seqlen = query.shape[:2]
query = query.view(bsz, seqlen, -1, self.head_dim)
key = key.view(bsz, seqlen, -1, self.head_dim)

if rotary_emb_w_meta_info['impl'] == 'dail':
value = value.view(bsz, seqlen, -1, self.head_dim)

kv = torch.stack([key, value], dim=2)
query, kv = rotary_emb(
query,
kv,
seqlen_offset=offset_info,
max_seqlen=seq_len,
)
[key, value] = torch.unbind(kv, dim=2)

value = value.view(bsz, seqlen, -1)
elif rotary_emb_w_meta_info['impl'] == 'hf':
if is_transformers_version_gte('4.38'):
(cos, sin) = rotary_emb(
x=value,
position_ids=offset_info,
)
else:
(cos, sin) = rotary_emb(x=value, seq_len=seq_len)
if is_transformers_version_gte('4.38'):
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=None,
unsqueeze_dim=2,
)
elif is_transformers_version_gte('4.36'):
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=offset_info,
unsqueeze_dim=2,
)
else:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=offset_info,
)
[key, value] = torch.unbind(kv, dim=2)

value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
elif rotary_emb_w_meta_info['impl'] == 'hf':
if is_transformers_version_gte('4.38'):
(cos, sin) = rotary_emb(
x=value,
position_ids=offset_info,
)
else:
(cos, sin) = rotary_emb(x=value, seq_len=seq_len)
if is_transformers_version_gte('4.38'):
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=None,
unsqueeze_dim=2,
)
elif is_transformers_version_gte('4.36'):
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=offset_info,
unsqueeze_dim=2,
)
else:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
query, key = apply_rotary_pos_emb(
q=query,
k=key,
cos=cos,
sin=sin,
position_ids=offset_info,
)
query = query.transpose(1, 2)
key = key.transpose(1, 2)

query = query.view(bsz, seqlen, self.d_model)
key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)

extra_attn_kwargs = {}
query = query.transpose(1, 2)
key = key.transpose(1, 2)

query = query.view(bsz, seqlen, -1)
key = key.view(bsz, seqlen, -1)
return query, key, value

def get_implementation_specific_args(
self,
attention_mask: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> dict[str, Any]:
"""Returns attention implementation specific args.
Args:
attention_mask (Optional[torch.Tensor]): The attention mask.
alibi_slopes (Optional[torch.Tensor]): The alibi slopes.
flash_attn_padding_info (Optional[dict[str, torch.Tensor]]): The padding information, only required for flash attention.
Returns:
extra_attn_kwargs (dict[str, Any]): Implementation specific args.
"""
if self.attn_impl == 'flash':
key_padding_mask = None
extra_attn_kwargs = {
'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
'sliding_window_size': self.sliding_window_size,
'alibi_slopes': alibi_slopes,
'flash_attn_padding_info': flash_attn_padding_info,
'key_padding_mask': None,
}

context, attn_weights, past_key_value = self.attn_fn(
query,
key,
value,
self.n_heads,
self.kv_n_heads,
past_key_value=past_key_value,
softmax_scale=self.softmax_scale,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
**extra_attn_kwargs,
)

return self.out_proj(context), attn_weights, past_key_value
else:
extra_attn_kwargs = {'key_padding_mask': attention_mask}
return extra_attn_kwargs


@attention_classes.register_class('multihead_attention')
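For orientation, here is a minimal, self-contained sketch of the shape this refactor gives forward(): query, key, and value come from a get_qkv() helper, implementation-specific kwargs come from get_implementation_specific_args(), and the attention call itself stays pluggable. The SketchAttention class below, its plain scaled-dot-product stand-in for self.attn_fn, and the tensor shapes are illustrative assumptions, not the actual GroupedQueryAttention code.

import torch
import torch.nn as nn


class SketchAttention(nn.Module):
    # Illustrative stand-in, not llm-foundry's GroupedQueryAttention.

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.Wqkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def get_qkv(self, x: torch.Tensor):
        # Single fused projection, then split into query/key/value.
        return self.Wqkv(x).chunk(3, dim=-1)

    def get_implementation_specific_args(self, attention_mask=None):
        # Non-flash implementations receive the mask as key_padding_mask.
        return {'key_padding_mask': attention_mask}

    def forward(self, x: torch.Tensor, attention_mask=None):
        query, key, value = self.get_qkv(x)
        extra_attn_kwargs = self.get_implementation_specific_args(attention_mask)

        # Stand-in for self.attn_fn: plain scaled dot-product attention.
        b, s, _ = query.shape
        q = query.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)
        k = key.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)
        v = value.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)

        mask = extra_attn_kwargs['key_padding_mask']
        attn_mask = mask[:, None, None, :].to(torch.bool) if mask is not None else None
        context = nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            is_causal=attn_mask is None,  # SDPA forbids combining a mask with is_causal=True
        )
        context = context.transpose(1, 2).reshape(b, s, -1)
        return self.out_proj(context)


attn = SketchAttention(d_model=64, n_heads=4)
out = attn(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 64])

Splitting these steps out of forward() means a subclass or an alternative attention implementation only needs to override the relevant helper rather than the whole method.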
66 changes: 39 additions & 27 deletions llmfoundry/models/layers/blocks.py
@@ -3,7 +3,7 @@

"""GPT Blocks used for the GPT Model."""

from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Set, Tuple

import torch
import torch.nn as nn
@@ -88,6 +88,8 @@ def __init__(
self.norm_attn_norm = FusedNormAttentionNorm(
d_model=d_model,
n_heads=n_heads,
args_to_exclude_in_attn_class=self.
args_to_exclude_in_attn_class,
attn_config=attn_config,
ffn_has_norm=ffn_has_norm,
fc_type=fc_type,
Expand All @@ -99,21 +101,10 @@ def __init__(
else:
assert isinstance(attn_config['attn_type'], str)
# Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type',
'alibi',
'attn_uses_sequence_id',
'alibi_bias_max',
'rope',
'rope_theta',
'rope_impl',
'rope_dail_config',
'rope_hf_config',
}
attn_config_subset_for_attn_class = {
k: v
for k, v in attn_config.items()
if k not in args_to_exclude_in_attn_class
if k not in self.args_to_exclude_in_attn_class
}

self.norm_1 = build_norm(
@@ -153,6 +144,20 @@ def __init__(
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn

@property
def args_to_exclude_in_attn_class(self):
return {
'attn_type',
'alibi',
'attn_uses_sequence_id',
'alibi_bias_max',
'rope',
'rope_theta',
'rope_impl',
'rope_dail_config',
'rope_hf_config',
}

def forward(
self,
x: torch.Tensor,
@@ -196,6 +201,24 @@ def forward(
if self.norm_2 is not None:
m = self.norm_2(x)

n = self.apply_ffn(attention_mask, m)
x = x + self.resid_ffn_dropout(n)
return x, attn_weights, past_key_value

def apply_ffn(
self,
attention_mask: Optional[torch.ByteTensor],
m: torch.Tensor,
) -> torch.Tensor:
"""Apply feed forward layers to the input.
Args:
attention_mask (Optional[torch.ByteTensor]): The attention mask.
m (torch.Tensor): The input.
Returns:
n (torch.Tensor): The output.
"""
batch_size, seq_len = m.size()[:2]
indices = None
if not self.use_pad_tok_in_ffn:
@@ -205,8 +228,7 @@ def forward(
if not self.use_pad_tok_in_ffn:
assert pad_input is not None
n = pad_input(n, indices, batch_size, seq_len)
x = x + self.resid_ffn_dropout(n)
return x, attn_weights, past_key_value
return n


class FusedNormAttentionNorm(nn.Module):
Expand All @@ -215,6 +237,7 @@ def __init__(
self,
d_model: int,
n_heads: int,
args_to_exclude_in_attn_class: Set[str],
attn_config: Optional[Dict] = None,
ffn_has_norm: bool = False,
fc_type: str = 'torch',
@@ -228,18 +251,7 @@ def __init__(
assert attn_config is not None
assert isinstance(attn_config['attn_type'], str)

# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type',
'alibi',
'attn_uses_sequence_id',
'alibi_bias_max',
'rope',
'rope_theta',
'rope_impl',
'rope_dail_config',
'rope_hf_config',
}
# Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
attn_config_subset_for_attn_class = {
k: v
for k, v in attn_config.items()
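The new apply_ffn() hook drops pad tokens before the feed-forward network and re-pads the output when use_pad_tok_in_ffn is False. Below is a minimal sketch of that unpad, FFN, re-pad flow; the unpad() and repad() helpers are simplified stand-ins written with plain torch indexing, assumed here for illustration rather than the flash-attn unpad_input/pad_input utilities the real code uses.

import torch
import torch.nn as nn


def unpad(m: torch.Tensor, attention_mask: torch.Tensor):
    # Flatten (batch, seq, d) to (num_real_tokens, d), remembering positions.
    batch, seq, d = m.shape
    flat_mask = attention_mask.reshape(-1).bool()
    indices = torch.nonzero(flat_mask, as_tuple=False).flatten()
    return m.reshape(batch * seq, d)[indices], indices


def repad(n: torch.Tensor, indices: torch.Tensor, batch: int, seq: int):
    # Scatter the real tokens back into a zero-padded (batch, seq, d) tensor.
    out = torch.zeros(batch * seq, n.shape[-1], dtype=n.dtype, device=n.device)
    out[indices] = n
    return out.view(batch, seq, -1)


ffn = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
m = torch.randn(2, 5, 16)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

tokens, indices = unpad(m, attention_mask)  # only the 8 real tokens go through the FFN
n = repad(ffn(tokens), indices, 2, 5)       # pad positions come back as zeros
print(n.shape)  # torch.Size([2, 5, 16])

Skipping pad tokens saves FFN compute on heavily padded batches, while use_pad_tok_in_ffn=True keeps the full padded sequence for FFN variants that need to see every position.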