Commit

ShashankMosaicML committed Dec 5, 2024
1 parent 18c4bb9 commit bf1cb6c
Showing 4 changed files with 30 additions and 4 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/flex_attn_utils.py
@@ -246,7 +246,7 @@ def generate_block_mask(
             )
         else:
             block_mask_fn = and_masks(
-                block_mask_fn,
+                block_mask_fn,  # type: ignore
                 partial(
                     block_mask.mod_fn,
                     query_offset=query_offset,
@@ -282,7 +282,7 @@ def generate_score_mod(
             )
         else:
             wrapped_score_mod = _wrap_score_mod_fns(
-                wrapped_score_mod,
+                wrapped_score_mod,  # type: ignore
                 partial(
                     score_mod.mod_fn,
                     query_offset=query_offset,
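The hunks above only add `# type: ignore` where mask and score functions are chained. For context, here is a minimal, self-contained sketch (not llm-foundry code) of the underlying pattern: FlexAttention mask functions are composed with `and_masks`, and `functools.partial` binds extra arguments such as offsets before composition. The `sliding_window_mask` helper, the window size, and the sequence lengths are illustrative assumptions, and the snippet assumes a torch version with FlexAttention available (the tests later in this diff gate on torch >= 2.6).

# Minimal sketch (not llm-foundry code) of composing FlexAttention mask
# functions, assuming torch >= 2.6.
from functools import partial

from torch.nn.attention.flex_attention import and_masks, create_block_mask


def causal_mask(b, h, q_idx, kv_idx):
    # A query position may attend only to itself and earlier key positions.
    return q_idx >= kv_idx


def sliding_window_mask(b, h, q_idx, kv_idx, window_size):
    # Illustrative extra constraint: keep attention within a local window.
    return q_idx - kv_idx <= window_size


# and_masks returns a mask function that is the logical AND of its inputs, so
# extra constraints can be layered onto an existing mask, which is what the
# generate_block_mask loop above does; partial binds the non-standard argument.
combined = and_masks(causal_mask, partial(sliding_window_mask, window_size=64))
block_mask = create_block_mask(
    combined, B=None, H=None, Q_LEN=256, KV_LEN=256, device='cpu',
)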
9 changes: 7 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -23,6 +23,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from composer.docs.source import conf
 from composer.models import HuggingFaceModel
 from composer.utils import dist
 from tabulate import tabulate
@@ -424,8 +425,12 @@ def __init__(self, config: MPTConfig):
         self.shift_labels = True

         if self.attn_impl == 'flex':
-            self.compiled_flex_attention = torch.compile(flex_attention)
-            self.compiled_create_block_mask = torch.compile(create_block_mask)
+            self.compiled_flex_attention = torch.compile(
+                flex_attention,
+            ) if config.attn_config['flex_attn_compile'] else flex_attention
+            self.compiled_create_block_mask = torch.compile(
+                create_block_mask,
+            ) if config.attn_config['flex_attn_compile'] else create_block_mask

         self.blocks = self.construct_blocks(config=config,)

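The constructor change above is the core of the commit: `torch.compile` is now applied to FlexAttention's entry points only when the new `flex_attn_compile` flag is set, rather than unconditionally. Below is a standalone sketch of the same pattern, with `use_compile` standing in for `config.attn_config['flex_attn_compile']`; it is an illustration of the technique, not the repository's code.

# Standalone sketch of the conditional-compilation pattern introduced above.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def build_flex_fns(use_compile: bool):
    # With use_compile=True the FlexAttention entry points are wrapped in
    # torch.compile, which traces and fuses the attention kernels; with
    # use_compile=False the eager functions are returned, avoiding compile
    # overhead and simplifying debugging (the tests below pass
    # 'flex_attn_compile': False for this reason).
    attention_fn = torch.compile(flex_attention) if use_compile else flex_attention
    block_mask_fn = torch.compile(create_block_mask) if use_compile else create_block_mask
    return attention_fn, block_mask_fn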
1 change: 1 addition & 0 deletions llmfoundry/models/utils/config_defaults.py
@@ -35,6 +35,7 @@
     },
     'kv_dim': None,
     'flex_attn_mod_list': [],
+    'flex_attn_compile': True,
 }

 init_config_defaults: dict = {
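With `'flex_attn_compile': True` as the default, compilation stays enabled unless a model config overrides it. A hypothetical override in an attention config dict might look like the following; the fragment is illustrative and not taken from the repository.

# Hypothetical attn_config fragment that opts out of compiling FlexAttention
# while still using the flex implementation.
attn_config = {
    'attn_impl': 'flex',
    'flex_attn_compile': False,  # overrides the new default of True
}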
20 changes: 20 additions & 0 deletions tests/models/test_model.py
@@ -530,6 +530,7 @@ def test_determinism(

     test_cfg.model.attn_config = {
         'attn_impl': attn_impl,
+        'flex_attn_compile': False,
     }
     if hasattr(test_cfg.model, 'ffn_config'):
         test_cfg.model.ffn_config['ffn_type'] = ffn_type
@@ -1803,6 +1804,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -1979,6 +1981,7 @@ def test_forward_with_cache(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2154,6 +2157,7 @@ def test_generate_with_past_kv(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2296,6 +2300,7 @@ def test_generation_kwargs_dont_crash(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2518,6 +2523,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2657,6 +2663,12 @@ def test_hf_init(
 @pytest.mark.gpu
 @pytest.mark.parametrize('attn_impl', ['torch', 'flash', 'flex'])
 def test_head_dim_8_flash_mqa_attn(attn_impl: str, batch_size: int = 2):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.6.0'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+        )
     test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
     test_cfg.device = torch.cuda.current_device()

@@ -2673,6 +2685,7 @@ def test_head_dim_8_flash_mqa_attn(attn_impl: str, batch_size: int = 2):
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             'attn_type': 'multiquery_attention',
         },
     )
@@ -2722,6 +2735,7 @@ def test_construct_blocks(attn_impl: str):
         max_seq_len=64,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             'attn_type': 'grouped_query_attention',
             'kv_n_heads': 4,
         },
@@ -2807,6 +2821,12 @@ def test_reuse_prev_layer_kv_cache(
     request: pytest.FixtureRequest,
     batch_size: int = 2,
 ):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.6.0'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+        )
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     model_config_overrides = {
         'block_overrides': {
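The test changes above follow two patterns: FlexAttention cases are skipped on torch releases older than 2.6, and `flex_attn_compile` is set to False so the tests run eagerly. The following is a small sketch of that version gate, extracted as a helper for clarity; the helper name is an assumption, not part of the commit.

# Sketch of the torch-version gate used in the new test code above.
import torch
from packaging import version


def flex_attention_supported() -> bool:
    # Splitting on '.dev' normalizes nightly versions such as
    # '2.6.0.dev20241112' before comparing against the 2.6.0 floor.
    torch_version = version.parse(torch.__version__.split('.dev')[0])
    return torch_version >= version.parse('2.6.0')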
