configuring with torch 2.5.1 and 2.6.0.dev
ShashankMosaicML committed Dec 5, 2024
1 parent 18afcc5 commit 434aa83
Showing 6 changed files with 122 additions and 70 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -476,7 +476,7 @@ def flex_attn_fn(
check_valid_inputs(query, key, value)
assert key.shape[1] == value.shape[1]
assert query.shape[1] == key.shape[1]
- query_offset = 0
+ query_offset = torch.tensor(0, device=query.device)
if past_key_value is not None:
if len(past_key_value) != 0:
assert past_key_value[0].shape[1] == past_key_value[1].shape[1]
@@ -573,7 +573,7 @@ def _check_mod_list(mod_list: list[dict[str, Any]], name: str):
)
score_mod = generate_score_mod(
score_mod_list=score_mod_list, # type: ignore
- query_offset=torch.tensor(query_offset, device=query.device),
+ query_offset=query_offset,
sequence_id_info=sequence_id_info,
)

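The net effect of these two hunks is that `query_offset` is built once as a 0-dim tensor on the query's device and passed through unchanged, instead of being a Python int that gets wrapped in `torch.tensor(...)` just before `generate_score_mod`. A minimal, illustrative sketch (the closure and names below are not from the repository) of how a tensor offset composes with a FlexAttention-style score mod:

```python
import torch

def make_causal_score_mod(query_offset: torch.Tensor):
    """Return a score_mod-style callable that shifts q_idx by a tensor offset."""
    def score_mod(score, b, h, q_idx, kv_idx):
        # Mask out positions the (offset) query index should not attend to.
        return torch.where(q_idx + query_offset >= kv_idx, score, -float('inf'))
    return score_mod

# A 0-dim tensor offset (the diff creates it on query.device).
offset = torch.tensor(2)
score_mod = make_causal_score_mod(offset)

# Emulate a single element the way flex_attention evaluates score mods:
# q_idx 0 plus offset 2 is still < kv_idx 3, so the score is masked to -inf.
out = score_mod(torch.tensor(1.0), torch.tensor(0), torch.tensor(0),
                torch.tensor(0), torch.tensor(3))
print(out)  # tensor(-inf)
```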
33 changes: 22 additions & 11 deletions llmfoundry/models/layers/flex_attn_utils.py
@@ -6,7 +6,12 @@
from typing import Any, Optional

import torch
- from torch.nn.attention.flex_attention import _score_mod_signature, and_masks
+ from packaging import version
+ from torch.nn.attention.flex_attention import (
+     _DEFAULT_SPARSE_BLOCK_SIZE,
+     _score_mod_signature,
+     and_masks,
+ )

from llmfoundry.layers_registry import flex_attention_mods

@@ -19,7 +24,7 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
) -> torch.Tensor:
del sequence_id_info, query_offset, b, h, q_idx, kv_idx
@@ -32,7 +37,7 @@ def _score_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
) -> torch.Tensor:
del sequence_id_info, query_offset, score, b, h, q_idx, kv_idx
@@ -53,7 +58,7 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
) -> torch.Tensor:
del sequence_id_info, b, h
@@ -73,7 +78,7 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
) -> torch.Tensor:
del sequence_id_info, b, h
@@ -94,7 +99,7 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
) -> torch.Tensor:
del h
@@ -121,7 +126,7 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
) -> torch.Tensor:
del h, q_idx, query_offset
@@ -146,7 +151,7 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
) -> torch.Tensor:
del h
@@ -188,7 +193,7 @@ def _score_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
) -> torch.Tensor:
del sequence_id_info, b
@@ -211,7 +216,7 @@ def _score_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
) -> torch.Tensor:
del sequence_id_info, query_offset, b, h, q_idx, kv_idx
@@ -230,7 +235,7 @@ def generate_block_mask(
B: int,
block_mask_list: Optional[list[FlexAttentionMod]],
compiled_create_block_mask: Any,
- query_offset: int,
+ query_offset: torch.Tensor,
sequence_id_info: Optional[dict[str, torch.Tensor]],
):
if block_mask_list is None:
@@ -254,12 +259,18 @@
),
)

+ extra_args = {}
+ if version.parse(
+     torch.__version__.split('.dev')[0],
+ ) < version.parse('2.6.0') and Q_LEN % _DEFAULT_SPARSE_BLOCK_SIZE != 0:
+     extra_args['BLOCK_SIZE'] = Q_LEN
block_mask = compiled_create_block_mask(
block_mask_fn,
B=B,
H=None, # Setting this to None speeds up block mask generation, but this means the mask has to be the same across all heads.
Q_LEN=Q_LEN,
KV_LEN=KV_LEN,
+ **extra_args,
)

return block_mask
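The other change in this file is the `extra_args` gate: on torch releases older than 2.6.0, `create_block_mask` is given `BLOCK_SIZE=Q_LEN` whenever the query length is not a multiple of `_DEFAULT_SPARSE_BLOCK_SIZE`, sidestepping partial trailing blocks on those versions. A standalone sketch of that gate with illustrative lengths and a plain causal mask (not the repository's mask mods); it assumes `create_block_mask` accepts a CPU device for demonstration purposes:

```python
import torch
from packaging import version
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,  # 128 at the time of writing
    create_block_mask,
)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

Q_LEN = KV_LEN = 100  # deliberately not a multiple of the default sparse block size

extra_args = {}
if version.parse(
    torch.__version__.split('.dev')[0],
) < version.parse('2.6.0') and Q_LEN % _DEFAULT_SPARSE_BLOCK_SIZE != 0:
    # Older torch: cover the whole (short) sequence with a single block
    # instead of relying on partial-block handling.
    extra_args['BLOCK_SIZE'] = Q_LEN

block_mask = create_block_mask(
    causal_mask,
    B=None,
    H=None,
    Q_LEN=Q_LEN,
    KV_LEN=KV_LEN,
    device='cpu',  # keeps the sketch runnable without a GPU
    **extra_args,
)
print(block_mask)
```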
7 changes: 6 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -25,6 +25,7 @@
import torch.nn.functional as F
from composer.models import HuggingFaceModel
from composer.utils import dist
+ from packaging import version
from tabulate import tabulate
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

@@ -423,7 +424,11 @@ def __init__(self, config: MPTConfig):
self.mb_args = None
self.shift_labels = True

- flex_attn_compile = config.attn_config.pop('flex_attn_compile', False)
+ flex_attn_compile = config.attn_config.pop(
+     'flex_attn_compile',
+     version.parse(torch.__version__.split('.dev')[0]) >=
+     version.parse('2.6.0'),
+ )
if self.attn_impl == 'flex':
self.compiled_flex_attention = torch.compile(
flex_attention,
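The new default keys off the installed torch version, stripping any `.dev` nightly suffix from `torch.__version__` before comparing, so a `2.6.0.devYYYYMMDD` build enables `flex_attn_compile` while stable 2.5.1 leaves it off unless the config sets it explicitly. A small sketch of the check in isolation (the helper name is illustrative):

```python
import torch
from packaging import version

def default_flex_attn_compile() -> bool:
    # '2.6.0.dev20241205+cu124'.split('.dev')[0] -> '2.6.0'
    # '2.5.1'.split('.dev')[0]                   -> '2.5.1'
    base = torch.__version__.split('.dev')[0]
    return version.parse(base) >= version.parse('2.6.0')

print(torch.__version__, '->', default_flex_attn_compile())
```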
57 changes: 39 additions & 18 deletions llmfoundry/models/utils/config_defaults.py
@@ -6,24 +6,42 @@
ffn_config_defaults: dict = {
'ffn_type': 'mptmlp',
}
+ import torch
+ from packaging import version

attn_config_defaults: dict = {
- 'attn_type': 'multihead_attention',
- 'attn_pdrop': 0.0,
- 'attn_impl': 'flash',
- 'qk_ln': False,
- 'qk_gn': False,
- 'fused_qkv': True,
- 'clip_qkv': None,
- 'softmax_scale': None,
- 'attn_uses_sequence_id': False,
- 'sliding_window_size': -1,
- 'attn_logit_softcapping': None,
- 'alibi': False,
- 'alibi_bias_max': 8,
- 'rope': False,
- 'rope_theta': 10000,
- 'rope_impl': 'dail',
+ 'attn_type':
+     'multihead_attention',
+ 'attn_pdrop':
+     0.0,
+ 'attn_impl':
+     'flash',
+ 'qk_ln':
+     False,
+ 'qk_gn':
+     False,
+ 'fused_qkv':
+     True,
+ 'clip_qkv':
+     None,
+ 'softmax_scale':
+     None,
+ 'attn_uses_sequence_id':
+     False,
+ 'sliding_window_size':
+     -1,
+ 'attn_logit_softcapping':
+     None,
+ 'alibi':
+     False,
+ 'alibi_bias_max':
+     8,
+ 'rope':
+     False,
+ 'rope_theta':
+     10000,
+ 'rope_impl':
+     'dail',
'rope_dail_config': {
'type': 'original',
'pos_idx_in_fp32': True,
@@ -33,9 +51,12 @@
'type': 'no_scaling',
'factor': 1.0,
},
- 'kv_dim': None,
+ 'kv_dim':
+     None,
'flex_attn_mod_list': [],
- 'flex_attn_compile': True,
+ 'flex_attn_compile':
+     version.parse(torch.__version__.split('.dev')[0]) >=
+     version.parse('2.6.0'),
}

init_config_defaults: dict = {
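Because `MPTModel.__init__` (previous file) pops the key with the same version expression as the fallback, an explicit `flex_attn_compile` in a user config still wins; only configs that omit the key pick up the torch-version-dependent default. A hedged sketch of that interaction using a bare dict in place of a real attention config:

```python
import torch
from packaging import version

torch_ge_2_6 = version.parse(
    torch.__version__.split('.dev')[0],
) >= version.parse('2.6.0')

# Key omitted: the version-dependent default is used.
attn_config = {'attn_impl': 'flex'}
assert attn_config.pop('flex_attn_compile', torch_ge_2_6) == torch_ge_2_6

# Key set explicitly: the user's value is returned regardless of torch version.
attn_config = {'attn_impl': 'flex', 'flex_attn_compile': False}
assert attn_config.pop('flex_attn_compile', torch_ge_2_6) is False
```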
15 changes: 10 additions & 5 deletions tests/models/layers/test_attention.py
@@ -16,6 +16,12 @@
from llmfoundry.models.layers.layer_builders import build_attention_layer
from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info

+ compiled_flex_attention = flex_attention
+ compiled_create_block_mask = create_block_mask
+ if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.6.0'):
+     compiled_flex_attention = torch.compile(flex_attention)
+     compiled_create_block_mask = torch.compile(create_block_mask)


@pytest.mark.parametrize(
'attn_name',
@@ -177,9 +183,9 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
# Test that sliding window attention works as expected.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
- ) < version.parse('2.6.0'):
+ ) < version.parse('2.5.1'):
pytest.skip(
- 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+ 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
)
dtype = torch.bfloat16
device = 'cuda'
@@ -215,9 +221,8 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
}
elif attn_impl == 'flex':
attn_extra_kwargs = {
- 'compiled_flex_attention':
-     flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
- 'compiled_create_block_mask': torch.compile(create_block_mask),
+ 'compiled_flex_attention': compiled_flex_attention,
+ 'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {},
}

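The module-level gate added at the top of this test file replaces the old TODO (which passed eager `flex_attention` unconditionally): on torch >= 2.6.0 the tests exercise the compiled `flex_attention` and `create_block_mask`, and on older versions they fall back to the eager callables. The same pattern, extracted as a self-contained sketch:

```python
import torch
from packaging import version
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.6.0'):
    # Newer torch: compile both entry points to get the fused kernels.
    compiled_flex_attention = torch.compile(flex_attention)
    compiled_create_block_mask = torch.compile(create_block_mask)
else:
    # Older torch: keep the eager implementations so the tests can still run.
    compiled_flex_attention = flex_attention
    compiled_create_block_mask = create_block_mask
```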