Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows interweaving of arbitrary kinds of 'attention' layers, like sliding window, reuse prev layer kv cache etc. #1299

Merged
Merged
Show file tree
Hide file tree
Changes from 84 commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
8804472
[WIP] Allows interweaving of arbitrary kinds of 'attention' layers, l…
ShashankMosaicML Jun 21, 2024
a50755e
lint
ShashankMosaicML Jun 21, 2024
fcc28a1
applying overrides to blocks rather than just attentions
ShashankMosaicML Jun 21, 2024
877d80e
add docstring
ShashankMosaicML Jun 21, 2024
1f11415
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 21, 2024
dd1c64b
minor
ShashankMosaicML Jun 21, 2024
d3b32e6
Merge branch 'mixed_attention_modules' of github.com:ShashankMosaicML…
ShashankMosaicML Jun 21, 2024
fc1bf0b
changing yaml specification style
ShashankMosaicML Jun 22, 2024
81e2930
..
ShashankMosaicML Jun 22, 2024
9b3d813
fixes
ShashankMosaicML Jun 22, 2024
aafcebb
fix
ShashankMosaicML Jun 22, 2024
b46756f
fix
ShashankMosaicML Jun 22, 2024
ad6ba32
fix
ShashankMosaicML Jun 22, 2024
3ea79fd
refactoring
ShashankMosaicML Jun 22, 2024
13802cb
add warning
ShashankMosaicML Jun 22, 2024
9b6ae9c
compute only query vector when reusing kv
ShashankMosaicML Jun 22, 2024
c774a4b
refactor
ShashankMosaicML Jun 22, 2024
8dee35e
fixing
ShashankMosaicML Jun 22, 2024
8ff15b4
adding test for reusing previous layer kv cache
ShashankMosaicML Jun 23, 2024
b1ee62a
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 23, 2024
04e9888
adding error messages
ShashankMosaicML Jun 23, 2024
5eee910
..
ShashankMosaicML Jun 23, 2024
2a6c986
adding test
ShashankMosaicML Jun 23, 2024
7bf89f2
add logging
ShashankMosaicML Jun 23, 2024
dcc5cc0
adding logging
ShashankMosaicML Jun 23, 2024
06d03c1
minor
ShashankMosaicML Jun 23, 2024
ec42e72
bug fix, adding test
ShashankMosaicML Jun 23, 2024
cc1f2f3
minor
ShashankMosaicML Jun 23, 2024
9f7b346
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 24, 2024
214456f
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 24, 2024
62a3030
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 25, 2024
c955b98
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 25, 2024
63f4196
addressing some comments
ShashankMosaicML Jun 25, 2024
8e89db9
Merge branch 'mixed_attention_modules' of github.com:ShashankMosaicML…
ShashankMosaicML Jun 25, 2024
89bc22e
addressing some comments
ShashankMosaicML Jun 25, 2024
0a3f1b4
setting absolute absolute value for reuse_kv_layer_idx
ShashankMosaicML Jun 25, 2024
b74330e
lint
ShashankMosaicML Jun 25, 2024
5079d2e
adding tests for override_block_args
ShashankMosaicML Jun 25, 2024
2a7e948
adding error if reusing kv cache from a mismatch layer
ShashankMosaicML Jun 26, 2024
1daf534
fixing test
ShashankMosaicML Jun 26, 2024
a57af04
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 26, 2024
afa4c09
fixing code, test
ShashankMosaicML Jun 26, 2024
a8a8a8b
fix
ShashankMosaicML Jun 26, 2024
6f468ef
..
ShashankMosaicML Jun 26, 2024
0f96a87
refactoring
ShashankMosaicML Jun 26, 2024
085de4c
fix
ShashankMosaicML Jun 26, 2024
65b816e
..
ShashankMosaicML Jun 26, 2024
e919422
..
ShashankMosaicML Jun 26, 2024
b2c9ee9
..
ShashankMosaicML Jun 26, 2024
7e77979
refactoring
ShashankMosaicML Jun 26, 2024
0ff081a
..
ShashankMosaicML Jun 26, 2024
d15d34b
..
ShashankMosaicML Jun 26, 2024
114b42f
..
ShashankMosaicML Jun 26, 2024
380e954
adding test for _get_modules_order_expanded
ShashankMosaicML Jun 26, 2024
3e3c974
fixing test
ShashankMosaicML Jun 26, 2024
7f9ef5c
fixing test
ShashankMosaicML Jun 26, 2024
18b77af
lint
ShashankMosaicML Jun 26, 2024
f417e95
lint
ShashankMosaicML Jun 26, 2024
72eb7e1
adding test
ShashankMosaicML Jun 26, 2024
89c3985
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 26, 2024
391978c
addressing comment
ShashankMosaicML Jun 26, 2024
0659e32
..
ShashankMosaicML Jun 27, 2024
180c004
fixing test
ShashankMosaicML Jun 27, 2024
2f073f7
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 27, 2024
4f9893a
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 27, 2024
023070f
changing yaml format
ShashankMosaicML Jun 27, 2024
ca9a2e9
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 27, 2024
dc9890d
fix configuation
ShashankMosaicML Jun 27, 2024
4be19c6
fixing test
ShashankMosaicML Jun 27, 2024
d4a417a
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 27, 2024
a5298c3
allowing repeat at top level
ShashankMosaicML Jun 27, 2024
ecd560c
Merge branch 'mixed_attention_modules' of github.com:ShashankMosaicML…
ShashankMosaicML Jun 27, 2024
700ede1
allowing overriding error
ShashankMosaicML Jun 28, 2024
ff055a4
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 29, 2024
ca047fa
addressing comments
ShashankMosaicML Jun 29, 2024
3ca3d55
lint
ShashankMosaicML Jun 29, 2024
3fbec2e
addressing comments
ShashankMosaicML Jun 29, 2024
e7d85bb
fix
ShashankMosaicML Jun 29, 2024
8d97354
..
ShashankMosaicML Jun 29, 2024
54abbd2
..
ShashankMosaicML Jun 29, 2024
c6ac78a
Merge branch 'main' into mixed_attention_modules
ShashankMosaicML Jun 29, 2024
52536cd
..
ShashankMosaicML Jun 29, 2024
6d62c2c
..
ShashankMosaicML Jun 29, 2024
2b53237
..
ShashankMosaicML Jun 29, 2024
d53a1e6
addressing comment
ShashankMosaicML Jun 30, 2024
6554a22
fixing test
ShashankMosaicML Jun 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 69 additions & 20 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def __init__(
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
):
super().__init__()

Expand All @@ -428,6 +429,7 @@ def __init__(
self.n_heads = n_heads
self.kv_n_heads = kv_n_heads
self.sliding_window_size = sliding_window_size
self.reuse_kv_layer_idx = reuse_kv_layer_idx

self.head_dim = d_model // n_heads

Expand Down Expand Up @@ -458,18 +460,29 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

self.Wqkv = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
i * self.head_dim
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
]
self.Wqkv._fused = (0, fuse_splits)
if self.reuse_kv_layer_idx is None:
self.Wqkv = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
i * self.head_dim
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
]
self.Wqkv._fused = (0, fuse_splits)
else:
self.Wq = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
self.Wq._fused = (0, fuse_splits)

if self.qk_ln or self.qk_gn:
norm_size = self.head_dim if qk_gn else d_model
Expand All @@ -478,13 +491,14 @@ def __init__(
normalized_shape=norm_size,
device=device,
)
if qk_ln:
norm_size = self.head_dim * kv_n_heads
self.k_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)
if self.reuse_kv_layer_idx is None:
if qk_ln:
norm_size = self.head_dim * kv_n_heads
self.k_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)

self.attn_fn = attention_implementations.get(self.attn_impl)

Expand All @@ -507,9 +521,11 @@ def forward(
needs_weights: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[Tuple[torch.Tensor,
torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
query, key, value = self.get_qkv(x)
query, key, value = self.get_qkv(x, prev_layer_key_value)

if rotary_emb_w_meta_info is not None:
query, key, value = self._apply_rotary_embeddings(
Expand Down Expand Up @@ -546,6 +562,7 @@ def forward(
def get_qkv(
self,
x: torch.Tensor,
prev_layer_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes and returns the query, key, and value tensors.

Expand All @@ -557,6 +574,27 @@ def get_qkv(
key (torch.Tensor): The key tensor.
value (torch.Tensor): The value tensor.
"""
if self.reuse_kv_layer_idx is not None:
if prev_layer_key_value is None:
raise ValueError(
'prev_layer_key_value is None, cannot reuse_prev_layer_kv.',
)
key, value = prev_layer_key_value

query = self.Wq(x)
if self.clip_qkv:
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)

if self.qk_ln or self.qk_gn:
# Applying layernorm to qk
q_shape = query.shape
if self.qk_gn:
b, s = query.shape[:2]
query = query.view(b, s, self.n_heads, -1)
dtype = query.dtype
query = self.q_ln(query).to(dtype).view(q_shape)
return query, key, value

qkv = self.Wqkv(x)

if self.clip_qkv:
Expand Down Expand Up @@ -591,6 +629,10 @@ def _apply_rotary_embeddings(
key: torch.Tensor,
value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.reuse_kv_layer_idx is not None:
orig_key, orig_value = key, value
key, value = torch.empty_like(key), torch.empty_like(value)

rotary_emb = rotary_emb_w_meta_info['rotary_emb']
seq_len = rotary_emb_w_meta_info['seq_len']
offset_info = rotary_emb_w_meta_info['offset_info']
Expand All @@ -602,6 +644,7 @@ def _apply_rotary_embeddings(
value = value.view(bsz, seqlen, -1, self.head_dim)

kv = torch.stack([key, value], dim=2)
# Note: Rotates in place (https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/flash_attn/layers/rotary.py#L429)
query, kv = rotary_emb(
query,
kv,
Expand Down Expand Up @@ -652,6 +695,8 @@ def _apply_rotary_embeddings(

query = query.view(bsz, seqlen, -1)
key = key.view(bsz, seqlen, -1)
if self.reuse_kv_layer_idx is not None:
return query, orig_key, orig_value # type: ignore
return query, key, value

def get_implementation_specific_args(
Expand Down Expand Up @@ -705,6 +750,7 @@ def __init__(
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
):
super().__init__(
d_model=d_model,
Expand All @@ -721,6 +767,7 @@ def __init__(
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
)


Expand All @@ -746,6 +793,7 @@ def __init__(
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
):
super().__init__(
d_model=d_model,
Expand All @@ -762,6 +810,7 @@ def __init__(
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
)


Expand Down
7 changes: 7 additions & 0 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def forward(
output_attentions: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[Tuple[torch.Tensor,
torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
if self.fuse_norm_attn_norm:
Expand All @@ -171,6 +173,7 @@ def forward(
output_attentions=output_attentions,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
prev_layer_key_value=prev_layer_key_value,
)
else:
a = self.norm_1(x)
Expand All @@ -184,6 +187,7 @@ def forward(
needs_weights=output_attentions,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
prev_layer_key_value=prev_layer_key_value,
)
x = x + self.resid_attn_dropout(b)
m = x
Expand Down Expand Up @@ -308,6 +312,8 @@ def forward(
output_attentions: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[Tuple[torch.Tensor,
torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
Expand All @@ -321,6 +327,7 @@ def forward(
needs_weights=output_attentions,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
prev_layer_key_value=prev_layer_key_value,
)
x = x + self.resid_attn_dropout(b)
m = x
Expand Down
61 changes: 61 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ffn_config_defaults,
init_config_defaults,
)
from llmfoundry.utils.warnings import ExperimentalWarning


class MPTConfig(PretrainedConfig):
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(
fc_type: Union[str, Dict] = 'torch',
tie_word_embeddings: bool = True,
use_pad_tok_in_ffn: bool = True,
block_overrides: Optional[Dict[str, Any]] = None,
**kwargs: Any,
):
"""The MPT configuration class.
Expand Down Expand Up @@ -117,6 +119,30 @@ def __init__(
also be a dictionary that specifies the fc layer name and any kwargs for the fc layer.
tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
block_overrides: This allows for overriding default block configs for certain layers. This must contain `overrides` and `order`. `order` is a nested list which describes the order of the layers. For each kind of layer, specify the `overrides` in the overrides config (default refers to a layer that does not apply any overrides).
To specify this model (https://research.character.ai/optimizing-inference/) , the following config will be needed:
block_overrides:
order:
- name: default
- repeat: 2
order:
- name: sliding_window_layer
- name: sliding_window_layer_reuse
- name: sliding_window_layer
- repeat: 2
name: sliding_window_layer_reuse
- name: reuse_kv_layer
overrides:
sliding_window_layer:
attn_config:
sliding_window_size: 1024
sliding_window_layer_reuse:
attn_config:
sliding_window_size: 1024
reuse_kv_layer_idx: -1 # Relative index of the layer whose kv cache to reuse
reuse_kv_layer:
attn_config:
reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse
"""
self.d_model = d_model
self.n_heads = n_heads
Expand Down Expand Up @@ -145,6 +171,15 @@ def __init__(
init_config_defaults,
)

if 'reuse_kv_layer_idx' in self.attn_config and self.attn_config[
'attn_impl'] == 'torch':
raise NotImplementedError(
'reusing kv cache from a previous layer is not implemented for torch attention.',
)
if block_overrides is not None:
ShashankMosaicML marked this conversation as resolved.
Show resolved Hide resolved
self._validate_block_overrides(block_overrides)
self.block_overrides = block_overrides

if isinstance(fc_type, str):
fc_type = {'name': fc_type}
self.fc_type = fc_type
Expand All @@ -169,6 +204,23 @@ def __init__(

self._validate_config()

def _validate_block_overrides(self, block_overrides: Dict[str, Any]):
warnings.warn(ExperimentalWarning('block_overrides'))
if 'order' not in block_overrides:
raise ValueError('`order` should be defined in block_overrides',)
if 'overrides' not in block_overrides:
raise ValueError(
'`overrides` should be defined in block_overrides',
)
for name, override in block_overrides['overrides'].items():
if name == 'default':
raise ValueError('block overrides cannot be named "default".',)
if 'attn_config' in override and 'reuse_kv_layer_idx' in override[
'attn_config'] and self.attn_config['attn_impl'] == 'torch':
raise NotImplementedError(
'reusing kv cache from a previous layer is not implemented for torch attention.',
)

def _set_config_defaults(
self,
config: Dict[str, Any],
Expand Down Expand Up @@ -335,3 +387,12 @@ def _validate_config(self) -> None:
)

self.validate_attention_config()

@property
def allowed_block_overrides(self):
return {
'attn_config': {
'sliding_window_size': None,
'reuse_kv_layer_idx': None,
},
}
Loading
Loading