From bf1cb6c76062bd66331ed72994f6921d414cb40a Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Thu, 5 Dec 2024 00:42:37 -0800
Subject: [PATCH] ..

---
 llmfoundry/models/layers/flex_attn_utils.py |  4 ++--
 llmfoundry/models/mpt/modeling_mpt.py       |  8 ++++++--
 llmfoundry/models/utils/config_defaults.py  |  1 +
 tests/models/test_model.py                  | 20 ++++++++++++++++++++
 4 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/models/layers/flex_attn_utils.py b/llmfoundry/models/layers/flex_attn_utils.py
index 135650dc77..0c72a68579 100644
--- a/llmfoundry/models/layers/flex_attn_utils.py
+++ b/llmfoundry/models/layers/flex_attn_utils.py
@@ -246,7 +246,7 @@ def generate_block_mask(
             )
         else:
             block_mask_fn = and_masks(
-                block_mask_fn,
+                block_mask_fn,  # type: ignore
                 partial(
                     block_mask.mod_fn,
                     query_offset=query_offset,
@@ -282,7 +282,7 @@ def generate_score_mod(
             )
         else:
             wrapped_score_mod = _wrap_score_mod_fns(
-                wrapped_score_mod,
+                wrapped_score_mod,  # type: ignore
                 partial(
                     score_mod.mod_fn,
                     query_offset=query_offset,
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index df5e548d7f..825480ff99 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -424,8 +424,12 @@ def __init__(self, config: MPTConfig):
         self.shift_labels = True
 
         if self.attn_impl == 'flex':
-            self.compiled_flex_attention = torch.compile(flex_attention)
-            self.compiled_create_block_mask = torch.compile(create_block_mask)
+            self.compiled_flex_attention = torch.compile(
+                flex_attention,
+            ) if config.attn_config['flex_attn_compile'] else flex_attention
+            self.compiled_create_block_mask = torch.compile(
+                create_block_mask,
+            ) if config.attn_config['flex_attn_compile'] else create_block_mask
 
         self.blocks = self.construct_blocks(config=config,)
 
diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py
index 159fd3ad3d..a101bbaac8 100644
--- a/llmfoundry/models/utils/config_defaults.py
+++ b/llmfoundry/models/utils/config_defaults.py
@@ -35,6 +35,7 @@
     },
     'kv_dim': None,
     'flex_attn_mod_list': [],
+    'flex_attn_compile': True,
 }
 
 init_config_defaults: dict = {
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 280c5daaf7..84dda99341 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -530,6 +530,7 @@ def test_determinism(
 
     test_cfg.model.attn_config = {
         'attn_impl': attn_impl,
+        'flex_attn_compile': False,
     }
     if hasattr(test_cfg.model, 'ffn_config'):
         test_cfg.model.ffn_config['ffn_type'] = ffn_type
@@ -1803,6 +1804,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -1979,6 +1981,7 @@ def test_forward_with_cache(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2154,6 +2157,7 @@ def test_generate_with_past_kv(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2296,6 +2300,7 @@ def test_generation_kwargs_dont_crash(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2518,6 +2523,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2657,6 +2663,12 @@ def test_hf_init(
 @pytest.mark.gpu
 @pytest.mark.parametrize('attn_impl', ['torch', 'flash', 'flex'])
 def test_head_dim_8_flash_mqa_attn(attn_impl: str, batch_size: int = 2):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.6.0'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__} < 2.6.0.',
+        )
     test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
     test_cfg.device = torch.cuda.current_device()
 
@@ -2673,6 +2685,7 @@ def test_head_dim_8_flash_mqa_attn(attn_impl: str, batch_size: int = 2):
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             'attn_type': 'multiquery_attention',
         },
     )
@@ -2722,6 +2735,7 @@ def test_construct_blocks(attn_impl: str):
         max_seq_len=64,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': False,
             'attn_type': 'grouped_query_attention',
             'kv_n_heads': 4,
         },
@@ -2807,6 +2821,12 @@ def test_reuse_prev_layer_kv_cache(
     request: pytest.FixtureRequest,
     batch_size: int = 2,
 ):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.6.0'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__} < 2.6.0.',
+        )
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     model_config_overrides = {
         'block_overrides': {
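
For context, a minimal sketch (not part of the patch) of how the new `flex_attn_compile` flag could be set from user code. The `llmfoundry.models.mpt.MPTConfig` import path and the non-attention kwargs are assumed from the existing llm-foundry API rather than introduced here, and the 'flex' implementation itself is assumed to require torch >= 2.6, per the test skips above.

```python
# Illustrative sketch only: build an MPT config that uses FlexAttention but
# opts out of torch.compile on flex_attention / create_block_mask.
from llmfoundry.models.mpt import MPTConfig

hf_config = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    max_seq_len=64,
    attn_config={
        'attn_impl': 'flex',
        # Flag added by this patch; defaults to True in
        # llmfoundry/models/utils/config_defaults.py. Setting it to False
        # makes MPTModel skip torch.compile and call the eager
        # flex_attention / create_block_mask, as the updated tests do.
        'flex_attn_compile': False,
    },
)
```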