diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 0b616fdd28..b609d13f2f 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -476,7 +476,7 @@ def flex_attn_fn(
     check_valid_inputs(query, key, value)
     assert key.shape[1] == value.shape[1]
     assert query.shape[1] == key.shape[1]
-    query_offset = 0
+    query_offset = torch.tensor(0, device=query.device)
     if past_key_value is not None:
         if len(past_key_value) != 0:
             assert past_key_value[0].shape[1] == past_key_value[1].shape[1]
@@ -573,7 +573,7 @@ def _check_mod_list(mod_list: list[dict[str, Any]], name: str):
         )
     score_mod = generate_score_mod(
         score_mod_list=score_mod_list,  # type: ignore
-        query_offset=torch.tensor(query_offset, device=query.device),
+        query_offset=query_offset,
         sequence_id_info=sequence_id_info,
     )
diff --git a/llmfoundry/models/layers/flex_attn_utils.py b/llmfoundry/models/layers/flex_attn_utils.py
index 0c72a68579..5ab5b1d616 100644
--- a/llmfoundry/models/layers/flex_attn_utils.py
+++ b/llmfoundry/models/layers/flex_attn_utils.py
@@ -6,7 +6,12 @@
 from typing import Any, Optional
 
 import torch
-from torch.nn.attention.flex_attention import _score_mod_signature, and_masks
+from packaging import version
+from torch.nn.attention.flex_attention import (
+    _DEFAULT_SPARSE_BLOCK_SIZE,
+    _score_mod_signature,
+    and_masks,
+)
 
 from llmfoundry.layers_registry import flex_attention_mods
 
@@ -19,7 +24,7 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
-        query_offset: int,
+        query_offset: torch.Tensor,
         sequence_id_info: Optional[dict[str, torch.Tensor]],
     ) -> torch.Tensor:
         del sequence_id_info, query_offset, b, h, q_idx, kv_idx
@@ -32,7 +37,7 @@ def _score_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
-        query_offset: int,
+        query_offset: torch.Tensor,
         sequence_id_info: Optional[dict[str, torch.Tensor]],
     ) -> torch.Tensor:
         del sequence_id_info, query_offset, score, b, h, q_idx, kv_idx
@@ -53,7 +58,7 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
-        query_offset: int,
+        query_offset: torch.Tensor,
         sequence_id_info: Optional[dict[str, torch.Tensor]],
     ) -> torch.Tensor:
         del sequence_id_info, b, h
@@ -73,7 +78,7 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
-        query_offset: int,
+        query_offset: torch.Tensor,
         sequence_id_info: Optional[dict[str, torch.Tensor]],
     ) -> torch.Tensor:
         del sequence_id_info, b, h
@@ -94,7 +99,7 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
-        query_offset: int,
+        query_offset: torch.Tensor,
         sequence_id_info: Optional[dict[str, torch.Tensor]],
     ) -> torch.Tensor:
         del h
@@ -121,7 +126,7 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
-        query_offset: int,
+        query_offset: torch.Tensor,
         sequence_id_info: Optional[dict[str, torch.Tensor]],
     ) -> torch.Tensor:
         del h, q_idx, query_offset
@@ -146,7 +151,7 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
-        query_offset: int,
+        query_offset: torch.Tensor,
         sequence_id_info: Optional[dict[str, torch.Tensor]],
     ) -> torch.Tensor:
         del h
@@ -188,7 +193,7 @@ def _score_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
-        query_offset: int,
+        query_offset: torch.Tensor,
         sequence_id_info: Optional[dict[str, torch.Tensor]],
     ) -> torch.Tensor:
         del sequence_id_info, b
@@ -211,7 +216,7 @@ def _score_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
-        query_offset: int,
+        query_offset: torch.Tensor,
         sequence_id_info: Optional[dict[str, torch.Tensor]],
     ) -> torch.Tensor:
         del sequence_id_info, query_offset, b, h, q_idx, kv_idx
@@ -230,7 +235,7 @@ def generate_block_mask(
     B: int,
     block_mask_list: Optional[list[FlexAttentionMod]],
     compiled_create_block_mask: Any,
-    query_offset: int,
+    query_offset: torch.Tensor,
    sequence_id_info: Optional[dict[str, torch.Tensor]],
 ):
     if block_mask_list is None:
@@ -254,12 +259,18 @@ def generate_block_mask(
             ),
         )
 
+    extra_args = {}
+    if version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.6.0') and Q_LEN % _DEFAULT_SPARSE_BLOCK_SIZE != 0:
+        extra_args['BLOCK_SIZE'] = Q_LEN
     block_mask = compiled_create_block_mask(
         block_mask_fn,
         B=B,
         H=None,  # Setting this to None speeds up block mask generation, but this means the mask has to be the same across all heads.
         Q_LEN=Q_LEN,
         KV_LEN=KV_LEN,
+        **extra_args,
     )
     return block_mask
 
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index d3eb834365..cee165d25d 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -25,6 +25,7 @@
 import torch.nn.functional as F
 from composer.models import HuggingFaceModel
 from composer.utils import dist
+from packaging import version
 from tabulate import tabulate
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 
@@ -423,7 +424,11 @@ def __init__(self, config: MPTConfig):
         self.mb_args = None
         self.shift_labels = True
 
-        flex_attn_compile = config.attn_config.pop('flex_attn_compile', False)
+        flex_attn_compile = config.attn_config.pop(
+            'flex_attn_compile',
+            version.parse(torch.__version__.split('.dev')[0]) >=
+            version.parse('2.6.0'),
+        )
         if self.attn_impl == 'flex':
             self.compiled_flex_attention = torch.compile(
                 flex_attention,
diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py
index a101bbaac8..e9bae22215 100644
--- a/llmfoundry/models/utils/config_defaults.py
+++ b/llmfoundry/models/utils/config_defaults.py
@@ -6,24 +6,42 @@
 ffn_config_defaults: dict = {
     'ffn_type': 'mptmlp',
 }
+import torch
+from packaging import version
 
 attn_config_defaults: dict = {
-    'attn_type': 'multihead_attention',
-    'attn_pdrop': 0.0,
-    'attn_impl': 'flash',
-    'qk_ln': False,
-    'qk_gn': False,
-    'fused_qkv': True,
-    'clip_qkv': None,
-    'softmax_scale': None,
-    'attn_uses_sequence_id': False,
-    'sliding_window_size': -1,
-    'attn_logit_softcapping': None,
-    'alibi': False,
-    'alibi_bias_max': 8,
-    'rope': False,
-    'rope_theta': 10000,
-    'rope_impl': 'dail',
+    'attn_type':
+        'multihead_attention',
+    'attn_pdrop':
+        0.0,
+    'attn_impl':
+        'flash',
+    'qk_ln':
+        False,
+    'qk_gn':
+        False,
+    'fused_qkv':
+        True,
+    'clip_qkv':
+        None,
+    'softmax_scale':
+        None,
+    'attn_uses_sequence_id':
+        False,
+    'sliding_window_size':
+        -1,
+    'attn_logit_softcapping':
+        None,
+    'alibi':
+        False,
+    'alibi_bias_max':
+        8,
+    'rope':
+        False,
+    'rope_theta':
+        10000,
+    'rope_impl':
+        'dail',
     'rope_dail_config': {
         'type': 'original',
         'pos_idx_in_fp32': True,
@@ -33,9 +51,12 @@
         'type': 'no_scaling',
         'factor': 1.0,
     },
-    'kv_dim': None,
+    'kv_dim':
+        None,
     'flex_attn_mod_list': [],
-    'flex_attn_compile': True,
+    'flex_attn_compile':
+        version.parse(torch.__version__.split('.dev')[0]) >=
+        version.parse('2.6.0'),
 }
 
 init_config_defaults: dict = {
diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py
index 9533fd5db1..c83e0725b8 100644
--- a/tests/models/layers/test_attention.py
+++ b/tests/models/layers/test_attention.py
@@ -16,6 +16,12 @@
 from llmfoundry.models.layers.layer_builders import build_attention_layer
 from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info
 
+compiled_flex_attention = flex_attention
+compiled_create_block_mask = create_block_mask
+if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.6.0'):
+    compiled_flex_attention = torch.compile(flex_attention)
+    compiled_create_block_mask = torch.compile(create_block_mask)
+
 
 @pytest.mark.parametrize(
     'attn_name',
@@ -177,9 +183,9 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
     # Test that sliding window attention works as expected.
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     dtype = torch.bfloat16
     device = 'cuda'
@@ -215,9 +221,8 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
         }
     elif attn_impl == 'flex':
         attn_extra_kwargs = {
-            'compiled_flex_attention':
-                flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-            'compiled_create_block_mask': torch.compile(create_block_mask),
+            'compiled_flex_attention': compiled_flex_attention,
+            'compiled_create_block_mask': compiled_create_block_mask,
             'sequence_id_info': {},
         }
 
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 84dda99341..0bed566cc5 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -58,6 +58,10 @@
 from llmfoundry.utils.builders import build_composer_model
 from llmfoundry.utils.config_utils import to_dict_container
 
+FLEX_ATTN_COMPILE = version.parse(
+    torch.__version__.split('.dev')[0],
+) >= version.parse('2.6.0')
+
 
 def get_config(
     conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml',
@@ -83,6 +87,7 @@ def _get_objs(
     conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml',
     model_config_overrides: Optional[dict] = None,
     attn_impl: str = 'torch',
+    flex_attn_compile: bool = FLEX_ATTN_COMPILE,
 ):
     warnings.filterwarnings(
         action='ignore',
@@ -112,6 +117,7 @@ def _get_objs(
     test_cfg.precision = 'amp_bf16' if is_gpu else 'fp32'
     test_cfg.model.attn_config = {
         'attn_impl': attn_impl,
+        'flex_attn_compile': flex_attn_compile,
     }
     test_cfg.model.init_device = device
     test_cfg.device = device
@@ -520,9 +526,9 @@ def test_determinism(
 ):
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     with open(conf_path) as f:
@@ -530,7 +536,7 @@ def test_determinism(
 
     test_cfg.model.attn_config = {
         'attn_impl': attn_impl,
-        'flex_attn_compile': False,
+        'flex_attn_compile': FLEX_ATTN_COMPILE,
     }
     if hasattr(test_cfg.model, 'ffn_config'):
         test_cfg.model.ffn_config['ffn_type'] = ffn_type
@@ -1090,9 +1096,9 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict):
         )
     if attention_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
 
     composer_device = get_device(None)
@@ -1109,6 +1115,7 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict):
         attn_config={
             'attn_impl': attention_impl,
             'attn_uses_sequence_id': True,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         init_config={
@@ -1219,9 +1226,9 @@ def test_forward_with_padding(
     # Test that different placement of padding does not affect the output.
     if attention_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     alibi = pos_emb_config['alibi']
     if alibi and not check_alibi_support(attention_impl):
@@ -1247,6 +1254,7 @@ def test_forward_with_padding(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attention_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         init_config={
@@ -1474,7 +1482,7 @@ def test_forward_with_padding(
         },
     }],
 )
-@pytest.mark.parametrize('tie_word_embeddings', [False])  # [True, False])
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_generate(
     attention_impl: str,
     precision: str,
@@ -1485,9 +1493,9 @@ def test_generate(
     # padding in the input.
     if attention_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     if pos_emb_config['alibi'] and not check_alibi_support(attention_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
@@ -1513,6 +1521,7 @@ def test_generate(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attention_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         tie_word_embeddings=tie_word_embeddings,
@@ -1786,9 +1795,9 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
         )
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
 
     composer_device = get_device(None)
@@ -1804,7 +1813,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
-            'flex_attn_compile': False,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -1955,9 +1964,9 @@ def test_forward_with_cache(
     # same output.
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
@@ -1981,7 +1990,7 @@ def test_forward_with_cache(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
-            'flex_attn_compile': False,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2132,9 +2141,9 @@ def test_generate_with_past_kv(
 ):
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
@@ -2157,7 +2166,7 @@ def test_generate_with_past_kv(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
-            'flex_attn_compile': False,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2272,9 +2281,9 @@ def test_generation_kwargs_dont_crash(
 ):
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
@@ -2300,7 +2309,7 @@ def test_generation_kwargs_dont_crash(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
-            'flex_attn_compile': False,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2494,9 +2503,9 @@ def test_forward_with_output_attentions_and_output_hidden_states(
 ):
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
@@ -2523,7 +2532,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
-            'flex_attn_compile': False,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2665,9 +2674,9 @@ def test_hf_init(
 def test_head_dim_8_flash_mqa_attn(attn_impl: str, batch_size: int = 2):
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
     test_cfg.device = torch.cuda.current_device()
@@ -2685,7 +2694,7 @@ def test_head_dim_8_flash_mqa_attn(attn_impl: str, batch_size: int = 2):
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
-            'flex_attn_compile': False,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             'attn_type': 'multiquery_attention',
         },
     )
@@ -2721,9 +2730,9 @@ def test_head_dim_8_flash_mqa_attn(attn_impl: str, batch_size: int = 2):
 def test_construct_blocks(attn_impl: str):
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
 
     n_layers = 13
@@ -2735,7 +2744,7 @@ def test_construct_blocks(attn_impl: str):
         max_seq_len=64,
         attn_config={
             'attn_impl': attn_impl,
-            'flex_attn_compile': False,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             'attn_type': 'grouped_query_attention',
             'kv_n_heads': 4,
         },
@@ -2823,9 +2832,9 @@ def test_reuse_prev_layer_kv_cache(
 ):
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.6.0'):
+    ) < version.parse('2.5.1'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     model_config_overrides = {
@@ -2853,6 +2862,7 @@ def test_reuse_prev_layer_kv_cache(
         conf_path=conf_path,
         model_config_overrides=model_config_overrides,
         attn_impl=attn_impl,
+        flex_attn_compile=FLEX_ATTN_COMPILE,
     )
 
     batch = gen_random_batch(batch_size, test_cfg)
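
Note (not part of the diff): the same torch-version gate recurs in config_defaults.py, modeling_mpt.py, and the tests above. A minimal sketch of that gate follows, under the assumption that only `torch` and `packaging` are available; the helper `_torch_release` and the two module-level names are illustrative and do not exist in llm-foundry.

import torch
from packaging import version


def _torch_release() -> version.Version:
    # torch.__version__ can look like '2.6.0.dev20241001+cu121'; drop the '.dev'
    # suffix before comparing, as the diff does everywhere it parses the version.
    return version.parse(torch.__version__.split('.dev')[0])


# Default for 'flex_attn_compile': only compile FlexAttention on torch >= 2.6.0
# (mirrors config_defaults.py and modeling_mpt.py above).
FLEX_ATTN_COMPILE_DEFAULT = _torch_release() >= version.parse('2.6.0')

# FlexAttention itself is treated as usable from torch 2.5.1 onward, which is why
# the test skips above were relaxed from '< 2.6.0' to '< 2.5.1'.
FLEX_ATTN_SUPPORTED = _torch_release() >= version.parse('2.5.1')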