diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index dc75004af0..4b20b1316c 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Any, Callable import torch @@ -176,6 +176,20 @@ description=_attention_implementations_description, ) +_flex_attention_mods_description = ( + """The flex_attention_mods registry is used to register classes that implement flex attention mods. + + One example is 'CausalMaskMod'. See flex_attn_mods.py for examples. + """ +) +flex_attention_mods = create_registry( + 'llmfoundry', + 'flex_attention_mods', + generic_type=type[Any], + entry_points=True, + description=_flex_attention_mods_description, +) + _param_init_fns_description = ( """The param_init_fns registry is used to register functions that initialize parameters. @@ -231,5 +245,6 @@ 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', + 'flex_attention_mods', 'fcs', ] diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 612d6b9642..2609ea85fa 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -17,9 +17,16 @@ from llmfoundry.layers_registry import ( attention_classes, attention_implementations, + flex_attention_mods, +) +from llmfoundry.models.layers.flex_attn_utils import ( + FLEX_ATTN_COMPILE, + generate_block_mask, + generate_score_mod, ) from llmfoundry.models.layers.layer_builders import build_fc, build_norm from llmfoundry.models.utils.config_defaults import fc_type_defaults +from llmfoundry.utils.warnings import experimental_function __all__ = [ 'scaled_multihead_dot_product_attention', @@ -150,11 +157,6 @@ def scaled_multihead_dot_product_attention( attn_weight = q.matmul(k) * softmax_scale - if attn_logit_softcapping is not None: - attn_weight = attn_logit_softcapping * torch.tanh( - attn_weight / attn_logit_softcapping, - ) - if attn_bias is not None: # clamp to 0 necessary for torch 2.0 compile() _s_q = max(0, attn_bias.size(2) - s_q) @@ -168,6 +170,11 @@ def scaled_multihead_dot_product_attention( ) attn_weight = attn_weight + attn_bias + if attn_logit_softcapping is not None: + attn_weight = attn_logit_softcapping * torch.tanh( + attn_weight / attn_logit_softcapping, + ) + min_val = torch.finfo(q.dtype).min if key_padding_mask is not None: @@ -428,6 +435,166 @@ def flash_attn_fn( return output, None, past_key_value +@experimental_function('Flex Attention') +def flex_attn_fn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + n_heads: int, + kv_n_heads: int, + compiled_flex_attention: Any, + compiled_create_block_mask: Any, + sequence_id_info: Optional[dict[str, torch.Tensor]], + flex_attn_compile: bool, + flex_attn_mod_list: Optional[list[dict[str, Any]]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + softmax_scale: Optional[float] = None, + attn_bias: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + dropout_p: float = 0.0, + training: bool = False, + needs_weights: bool = False, + should_repeat_kv_for_gqa: Optional[bool] = True, + sliding_window_size: int = -1, + alibi_slopes: Optional[torch.Tensor] = None, + attn_logit_softcapping: Optional[float] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, + torch.Tensor]]]: + del training, should_repeat_kv_for_gqa + if attn_bias is not None: + raise ValueError('attn_bias should be None for flex attn.') + if key_padding_mask is not None: + raise ValueError( + 'key_padding_mask should be None for flex attn. Instead, any padding information should be sent through sequence_id_info.', + ) + if dropout_p > 0.0: + raise NotImplementedError(f'dropout not implemented for flex attn.') + if needs_weights: + raise NotImplementedError( + f'needs_weights not implemented for flex attn.', + ) + + check_valid_inputs(query, key, value) + assert key.shape[1] == value.shape[1] + assert query.shape[1] == key.shape[1] + query_offset = torch.tensor(0, device=query.device) + if past_key_value is not None: + if len(past_key_value) != 0: + assert past_key_value[0].shape[1] == past_key_value[1].shape[1] + query_offset += past_key_value[0].shape[1] + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + + past_key_value = (key, value) + + enable_gqa = (n_heads != kv_n_heads) + query = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) + key = rearrange(key, 'b s (h d) -> b h s d', h=kv_n_heads) + value = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) + + def _check_mod_list(mod_list: list[dict[str, Any]], name: str): + for mod in mod_list: + if mod['mod_name'] == name: + raise ValueError( + f'{name} mod should not be defined through flex attention config.', + ) + + flex_attn_mod_list = copy.deepcopy( + flex_attn_mod_list, + ) if flex_attn_mod_list is not None else [] + if is_causal: + _check_mod_list(flex_attn_mod_list, 'causal_mask') + flex_attn_mod_list.append({'mod_name': 'causal_mask', 'mod_kwargs': {}}) + if sliding_window_size != -1: + flex_attn_mod_list.append({ + 'mod_name': 'sliding_window_mask', + 'mod_kwargs': { + 'sliding_window_size': + torch.tensor(sliding_window_size, device=query.device), + }, + }) + if sequence_id_info is not None and 'sequence_id' in sequence_id_info and sequence_id_info[ + 'sequence_id'] is not None: + _check_mod_list(flex_attn_mod_list, 'sequence_id_mask') + flex_attn_mod_list.append({ + 'mod_name': 'sequence_id_mask', + 'mod_kwargs': {}, + }) + + if sequence_id_info is not None and 'attention_mask' in sequence_id_info and sequence_id_info[ + 'attention_mask'] is not None: + _check_mod_list(flex_attn_mod_list, 'attention_mask') + flex_attn_mod_list.append({ + 'mod_name': 'attention_mask', + 'mod_kwargs': {}, + }) + + if alibi_slopes is not None: + _check_mod_list(flex_attn_mod_list, 'alibi_score_mod') + flex_attn_mod_list.append({ + 'mod_name': 'alibi_score_mod', + 'mod_kwargs': { + 'alibi_slopes': alibi_slopes, + }, + }) + if attn_logit_softcapping is not None: + if int(attn_logit_softcapping) != attn_logit_softcapping: + raise ValueError( + f'FlexAttention does not support attn_logit_softcapping with float softcap temperature. Got {attn_logit_softcapping=}. Please set consider rounding it to the closest integer.', + ) + attn_logit_softcapping = int(attn_logit_softcapping) + _check_mod_list(flex_attn_mod_list, 'softcap_score_mod') + flex_attn_mod_list.append({ + 'mod_name': 'softcap_score_mod', + 'mod_kwargs': { + 'attn_logit_softcapping': + torch.tensor(attn_logit_softcapping, device=query.device), + }, + }) + + flex_attn_mod_list = [ + flex_attention_mods.get(mod['mod_name'])(**mod['mod_kwargs']) + for mod in flex_attn_mod_list + ] + block_mask_list = [ + mod for mod in flex_attn_mod_list + if mod.mod_type == 'mask' # type: ignore + ] + score_mod_list = [ + mod for mod in flex_attn_mod_list + if mod.mod_type == 'score' # type: ignore + ] + + block_mask = generate_block_mask( + Q_LEN=query.shape[2], + KV_LEN=key.shape[2], + B=query.shape[0], + block_mask_list=block_mask_list, # type: ignore + compiled_create_block_mask=compiled_create_block_mask, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + flex_attn_compile=flex_attn_compile, + ) + score_mod = generate_score_mod( + score_mod_list=score_mod_list, # type: ignore + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ) + + output = compiled_flex_attention( + query, + key, + value, + score_mod=score_mod, + block_mask=block_mask, + scale=softmax_scale, + enable_gqa=enable_gqa, + ) + output = rearrange(output, 'b h s d -> b s (h d)') + return output, None, past_key_value + + @attention_classes.register_class('grouped_query_attention') class GroupedQueryAttention(nn.Module): """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA). @@ -461,6 +628,8 @@ def __init__( reuse_kv_layer_idx: Optional[int] = None, attn_logit_softcapping: Optional[float] = None, kv_dim: Optional[int] = None, + flex_attn_mod_list: Optional[list[dict[str, Any]]] = None, + flex_attn_compile: bool = FLEX_ATTN_COMPILE, ): super().__init__() @@ -586,6 +755,10 @@ def __init__( ) self.out_proj._is_residual = True + if self.attn_impl == 'flex': + self.flex_attn_mod_list = flex_attn_mod_list + self.flex_attn_compile = flex_attn_compile + def forward( self, x: torch.Tensor, @@ -600,6 +773,7 @@ def forward( prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, key_value_states: Optional[torch.Tensor] = None, + flex_attn_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: extra_kwargs = {} @@ -623,6 +797,7 @@ def forward( attention_mask, alibi_slopes, flash_attn_padding_info, + flex_attn_kwargs, ) context, attn_weights, past_key_value = self.attn_fn( @@ -819,6 +994,7 @@ def get_implementation_specific_args( attention_mask: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + flex_attn_kwargs: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: """Returns attention implementation specific args. @@ -826,6 +1002,7 @@ def get_implementation_specific_args( attention_mask (Optional[torch.Tensor]): The attention mask. alibi_slopes (Optional[torch.Tensor]): The alibi slopes. flash_attn_padding_info (Optional[dict[str, torch.Tensor]]): The padding information, only required for flash attention. + flex_attn_kwargs (Optional[dict[str, Any]]): The extra flex attn kwargs, sent from the model, includes seq id transforms and compiled flex attention functions. Returns: extra_attn_kwargs (dict[str, Any]): Implementation specific args. @@ -837,6 +1014,24 @@ def get_implementation_specific_args( 'flash_attn_padding_info': flash_attn_padding_info, 'key_padding_mask': None, } + elif self.attn_impl == 'flex': + if flex_attn_kwargs is None: + raise ValueError( + 'flex_attn_kwargs must be provided for flex attention.', + ) + if 'sequence_id_info' not in flex_attn_kwargs: + raise ValueError( + 'sequence_id_info must be provided in flex_attn_kwargs.', + ) + flex_attn_kwargs['sequence_id_info']['attention_mask' + ] = attention_mask + extra_attn_kwargs = { + 'alibi_slopes': alibi_slopes, + 'key_padding_mask': None, + 'flex_attn_mod_list': self.flex_attn_mod_list, + 'flex_attn_compile': self.flex_attn_compile, + **flex_attn_kwargs, + } else: extra_attn_kwargs = {'key_padding_mask': attention_mask} return extra_attn_kwargs @@ -869,6 +1064,8 @@ def __init__( reuse_kv_layer_idx: Optional[int] = None, attn_logit_softcapping: Optional[float] = None, kv_dim: Optional[int] = None, + flex_attn_mod_list: Optional[list[dict[str, Any]]] = None, + flex_attn_compile: bool = FLEX_ATTN_COMPILE, ): super().__init__( d_model=d_model, @@ -890,6 +1087,8 @@ def __init__( reuse_kv_layer_idx=reuse_kv_layer_idx, attn_logit_softcapping=attn_logit_softcapping, kv_dim=kv_dim, + flex_attn_mod_list=flex_attn_mod_list, + flex_attn_compile=flex_attn_compile, ) @@ -920,6 +1119,8 @@ def __init__( reuse_kv_layer_idx: Optional[int] = None, attn_logit_softcapping: Optional[float] = None, kv_dim: Optional[int] = None, + flex_attn_mod_list: Optional[list[dict[str, Any]]] = None, + flex_attn_compile: bool = FLEX_ATTN_COMPILE, ): super().__init__( d_model=d_model, @@ -941,6 +1142,8 @@ def __init__( reuse_kv_layer_idx=reuse_kv_layer_idx, attn_logit_softcapping=attn_logit_softcapping, kv_dim=kv_dim, + flex_attn_mod_list=flex_attn_mod_list, + flex_attn_compile=flex_attn_compile, ) @@ -952,7 +1155,7 @@ def attn_bias_shape( causal: bool, use_sequence_id: bool, ) -> Optional[tuple[int, int, int, int]]: - if attn_impl == 'flash': + if attn_impl == 'flash' or attn_impl == 'flex': return None elif attn_impl == 'torch': if alibi: @@ -1048,3 +1251,4 @@ def build_alibi_bias( 'torch', func=scaled_multihead_dot_product_attention, ) +attention_implementations.register('flex', func=flex_attn_fn) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c88cf33d1b..e9ca5c17ba 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -165,6 +165,7 @@ def forward( prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, key_value_states: Optional[torch.Tensor] = None, + flex_attn_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: extra_kwargs = {} @@ -184,6 +185,7 @@ def forward( output_attentions=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + flex_attn_kwargs=flex_attn_kwargs, **extra_kwargs, ) else: @@ -198,6 +200,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + flex_attn_kwargs=flex_attn_kwargs, **extra_kwargs, ) x = x + self.resid_attn_dropout(b) @@ -332,6 +335,7 @@ def forward( prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, key_value_states: Optional[torch.Tensor] = None, + flex_attn_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) @@ -351,6 +355,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + flex_attn_kwargs=flex_attn_kwargs, **extra_kwargs, ) x = x + self.resid_attn_dropout(b) diff --git a/llmfoundry/models/layers/flex_attn_utils.py b/llmfoundry/models/layers/flex_attn_utils.py new file mode 100644 index 0000000000..a84ddf321e --- /dev/null +++ b/llmfoundry/models/layers/flex_attn_utils.py @@ -0,0 +1,332 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import warnings +from abc import ABC +from functools import partial +from typing import Any, Optional + +import torch +from packaging import version +from torch.nn.attention.flex_attention import ( + _DEFAULT_SPARSE_BLOCK_SIZE, + _score_mod_signature, + and_masks, +) + +FLEX_ATTN_COMPILE = version.parse( + torch.__version__.split('.dev')[0], +) >= version.parse('2.6.0') + +from llmfoundry.layers_registry import flex_attention_mods + + +class FlexAttentionMod(ABC): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, query_offset, b, h, q_idx, kv_idx + raise NotImplementedError + + def _score_mod_fn( + self, + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, query_offset, score, b, h, q_idx, kv_idx + raise NotImplementedError + + def __init__(self, mod_type: str) -> None: + assert mod_type in ['mask', 'score'] + self.mod_type = mod_type + self.mod_fn = self._mask_mod_fn if mod_type == 'mask' else self._score_mod_fn + + +@flex_attention_mods.register('causal_mask') +class CausalMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, b, h + q_idx = q_idx + query_offset + return q_idx >= kv_idx + + def __init__(self) -> None: + super().__init__(mod_type='mask') + + +@flex_attention_mods.register('sliding_window_mask') +class SlidingWindowMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, b, h + q_idx = q_idx + query_offset + return torch.abs(q_idx - kv_idx) <= self.sliding_window_size + + def __init__(self, sliding_window_size: torch.Tensor) -> None: + super().__init__(mod_type='mask') + self.sliding_window_size = sliding_window_size + + +@flex_attention_mods.register('sequence_id_mask') +class SequenceIdMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del h + q_idx = q_idx + query_offset + if sequence_id_info is None: + raise ValueError( + 'sequence_id_info is required for SequenceIdMaskMod', + ) + sequence_id = sequence_id_info['sequence_id'] + # Check if the query and key belong to the same sequence and the query token is not a padding token. + return (sequence_id[b, q_idx] == sequence_id[b, kv_idx]) + + def __init__(self) -> None: + super().__init__(mod_type='mask') + + +@flex_attention_mods.register('attention_mask') +class AttentionMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del h, q_idx, query_offset + if sequence_id_info is None: + raise ValueError( + 'sequence_id_info is required for SequenceIdMaskMod', + ) + attention_mask = sequence_id_info['attention_mask'] + # Check if the query and key belong to the same sequence and the query token is not a padding token. + return attention_mask[b, kv_idx] + + def __init__(self) -> None: + super().__init__(mod_type='mask') + + +@flex_attention_mods.register('local_global_mask') +class LocalGlobalMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del h + q_idx = q_idx + query_offset + if sequence_id_info is None: + raise ValueError( + 'sequence_id_info is required for LocalGlobalMaskMod', + ) + pos_in_seq = sequence_id_info['pos_in_seq'] + # Check if the query and key belong to the same sequence and the query token is not a padding token. + + if pos_in_seq is not None: + global_window_mask = ( + pos_in_seq[b, kv_idx] <= self.global_window_size + ) + else: + global_window_mask = (kv_idx <= self.global_window_size) + sliding_window_mask = (q_idx - kv_idx <= self.sliding_window_size) + + return global_window_mask | sliding_window_mask + + def __init__( + self, + sliding_window_size: int, + global_window_size: int, + ) -> None: + super().__init__(mod_type='mask') + self.sliding_window_size = sliding_window_size + self.global_window_size = global_window_size + + +@flex_attention_mods.register('alibi_score_mod') +class AlibiScoreMod(FlexAttentionMod): + + def _score_mod_fn( + self, + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, b + q_idx = q_idx + query_offset + bias = -self.alibi_slopes[h] * torch.abs(q_idx - kv_idx) + return score + bias + + def __init__(self, alibi_slopes: torch.Tensor) -> None: + super().__init__(mod_type='score') + self.alibi_slopes = alibi_slopes + + +@flex_attention_mods.register('softcap_score_mod') +class SoftcapScoreMod(FlexAttentionMod): + + def _score_mod_fn( + self, + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, query_offset, b, h, q_idx, kv_idx + return self.attn_logit_softcapping * torch.tanh( + score / self.attn_logit_softcapping, + ) + + def __init__(self, attn_logit_softcapping: torch.Tensor) -> None: + super().__init__(mod_type='score') + self.attn_logit_softcapping = attn_logit_softcapping + + +def generate_block_mask( + Q_LEN: int, + KV_LEN: int, + B: int, + block_mask_list: Optional[list[FlexAttentionMod]], + compiled_create_block_mask: Any, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + flex_attn_compile: bool, +): + if block_mask_list is None: + return None + + block_mask_fn = None + for i, block_mask in enumerate(block_mask_list): + if i == 0: + block_mask_fn = partial( + block_mask.mod_fn, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ) + else: + block_mask_fn = and_masks( + block_mask_fn, # type: ignore + partial( + block_mask.mod_fn, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ), + ) + + extra_args = {} + if (Q_LEN % _DEFAULT_SPARSE_BLOCK_SIZE != + 0) or (KV_LEN % _DEFAULT_SPARSE_BLOCK_SIZE != 0): + if flex_attn_compile: + warnings.warn( + f'Q_LEN and KV_LEN must be divisible by {_DEFAULT_SPARSE_BLOCK_SIZE}. The results might be incorrect.', + ) + else: + extra_args['BLOCK_SIZE'] = (Q_LEN, KV_LEN) + block_mask = compiled_create_block_mask( + block_mask_fn, + B=B, + H=None, # Setting this to None speeds up block mask generation, but this means the mask has to be the same across all heads. + Q_LEN=Q_LEN, + KV_LEN=KV_LEN, + **extra_args, + ) + + return block_mask + + +def generate_score_mod( + score_mod_list: Optional[list[FlexAttentionMod]], + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], +): + if score_mod_list is None: + return None + wrapped_score_mod = None + for i, score_mod in enumerate(score_mod_list): + if i == 0: + wrapped_score_mod = partial( + score_mod.mod_fn, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ) + else: + wrapped_score_mod = _wrap_score_mod_fns( + wrapped_score_mod, # type: ignore + partial( + score_mod.mod_fn, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ), + ) + + return wrapped_score_mod + + +def _wrap_score_mod_fns( + score_mod_fn_1: _score_mod_signature, + score_mod_fn_2: _score_mod_signature, +) -> _score_mod_signature: + + def wrapped_score_mod_fn( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + score = score_mod_fn_1(score, b, h, q_idx, kv_idx) + score = score_mod_fn_2(score, b, h, q_idx, kv_idx) + return score + + return wrapped_score_mod_fn diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 1adb64dc21..1cc48fc14f 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -7,6 +7,8 @@ import warnings from typing import Any, Optional, Union +import torch +from packaging import version from transformers import PretrainedConfig from llmfoundry.layers_registry import ffns_with_megablocks @@ -272,10 +274,16 @@ def _validate_config(self) -> None: raise ValueError( "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1", ) - if self.attn_config['attn_impl'] not in ['torch', 'flash']: + if self.attn_config['attn_impl'] not in ['torch', 'flash', 'flex']: raise ValueError( f"Unknown attn_impl={self.attn_config['attn_impl']}", ) + if self.attn_config['attn_type'] == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.6.0'): + raise RuntimeError( + 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + ) if self.attn_config['alibi'] and not check_alibi_support( self.attn_config['attn_impl'], ): @@ -283,7 +291,8 @@ def _validate_config(self) -> None: 'alibi only implemented with torch and flash (v2.4.2 or higher) attention.', ) if self.attn_config['attn_uses_sequence_id'] and not ( - self.attn_config['attn_impl'] == 'torch' or ( + self.attn_config['attn_impl'] == 'torch' or + self.attn_config['attn_impl'] == 'flex' or ( self.attn_config['attn_impl'] == 'flash' and is_flash_v2_installed(v2_version='v2.1.2') ) @@ -406,5 +415,6 @@ def allowed_block_overrides(self): 'attn_config': { 'sliding_window_size': None, 'reuse_kv_layer_idx': None, + 'flex_attn_mod_list': None, }, } diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 94e5fa29d5..ef79e7bfeb 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -26,9 +26,13 @@ from composer.models import HuggingFaceModel from composer.utils import dist from tabulate import tabulate +from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from llmfoundry.layers_registry import ffns_with_megablocks +from llmfoundry.layers_registry import ( + ffns_with_megablocks, +) from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE if is_flash_v2_installed(): try: # This try...except is needed because transformers requires it despite the 'if' statement above @@ -173,7 +177,7 @@ def gen_rotary_embedding( raise ValueError('rope_impl needs to be either dail or hf') -def gen_attention_mask_in_length( +def gen_sequence_id_info( sequence_id: Union[None, torch.Tensor], S: int, attn_uses_sequence_id: bool, @@ -234,9 +238,9 @@ def gen_attention_mask_in_length( ```. (The description above is taken verbatim from https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/bert_padding.py#L125 .) """ - attention_mask_in_length = None - if (sequence_id - is not None) and attn_uses_sequence_id and (attn_impl == 'flash'): + if (sequence_id is not None) and attn_uses_sequence_id and ( + attn_impl == 'flash' or attn_impl == 'flex' + ): # Check if sequence has left padding. If yes, raise an error. if (attention_mask is not None ) and (attention_mask[:, 0].sum() != attention_mask.shape[0]): @@ -252,21 +256,24 @@ def gen_attention_mask_in_length( # We replace those -1 with 0 to prevent `torch.nn.functional.one_hot(sequence_id)` in the next line from failing. # We apply the attention mask again after the one_hot operation. sequence_id = sequence_id.masked_fill(~attention_mask, 0) - attention_mask_in_length = torch.nn.functional.one_hot(sequence_id) + sequence_id_one_hot = torch.nn.functional.one_hot(sequence_id) if attention_mask is not None: - attention_mask_in_length = attention_mask_in_length.masked_fill( + sequence_id_one_hot = sequence_id_one_hot.masked_fill( ~attention_mask.unsqueeze(-1), 0, ) - attention_mask_in_length = attention_mask_in_length.sum(dim=1) + if attn_impl == 'flex': + return sequence_id_one_hot.cumsum(dim=1).sum(dim=-1) - 1 + attention_mask_in_length = sequence_id_one_hot.sum(dim=1) attention_mask_in_length = torch.nn.functional.pad( attention_mask_in_length, (0, S - attention_mask_in_length.shape[-1]), mode='constant', value=0, ) + return attention_mask_in_length - return attention_mask_in_length + return None def gen_flash_attn_padding_info( @@ -417,6 +424,18 @@ def __init__(self, config: MPTConfig): self.mb_args = None self.shift_labels = True + flex_attn_compile = config.attn_config.get( + 'flex_attn_compile', + FLEX_ATTN_COMPILE, + ) + if self.attn_impl == 'flex': + self.compiled_flex_attention = torch.compile( + flex_attention, + ) if flex_attn_compile else flex_attention + self.compiled_create_block_mask = torch.compile( + create_block_mask, + ) if flex_attn_compile else create_block_mask + self.blocks = self.construct_blocks(config=config,) # Tag all modules in the transformer blocks with the corresponding block_idx and max_block_idx @@ -718,7 +737,7 @@ def _attn_bias( self._attn_bias_initialized = True # flash will incorporate any attention_mask inside the attention module - if self.attn_impl == 'flash': + if self.attn_impl == 'flash' or self.attn_impl == 'flex': return self.attn_bias, attention_mask if self.attn_bias is not None: @@ -915,7 +934,8 @@ def forward( attention_mask=attention_mask, sequence_id=sequence_id, ) - attention_mask_in_length = gen_attention_mask_in_length( + + sequence_id_info = gen_sequence_id_info( sequence_id=sequence_id, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, @@ -924,7 +944,9 @@ def forward( ) alibi_slopes = None # alibi_slopes will only be used by flash attention for ALiBi - if self.alibi and self.attn_impl == 'flash': + if self.alibi and ( + self.attn_impl == 'flash' or self.attn_impl == 'flex' + ): alibi_slopes = gen_slopes( n_heads=self.config.n_heads, alibi_bias_max=self.alibi_bias_max, @@ -949,7 +971,7 @@ def forward( S, past_position, x.device, - attention_mask_in_length, + sequence_id_info, attention_mask, ) @@ -974,6 +996,19 @@ def forward( extra_kwargs = {} if prev_layer_key_value is not None: extra_kwargs['prev_layer_key_value'] = prev_layer_key_value + if self.attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'sequence_id_info': { + 'pos_in_seq': + sequence_id_info, + 'sequence_id': + sequence_id if self.attn_uses_sequence_id else None, + }, + 'compiled_flex_attention': + self.compiled_flex_attention, + 'compiled_create_block_mask': + self.compiled_create_block_mask, + } x, attn_weights, present = block( x, past_key_value=past_key_value, diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py index 5550785149..e9bae22215 100644 --- a/llmfoundry/models/utils/config_defaults.py +++ b/llmfoundry/models/utils/config_defaults.py @@ -6,24 +6,42 @@ ffn_config_defaults: dict = { 'ffn_type': 'mptmlp', } +import torch +from packaging import version attn_config_defaults: dict = { - 'attn_type': 'multihead_attention', - 'attn_pdrop': 0.0, - 'attn_impl': 'flash', - 'qk_ln': False, - 'qk_gn': False, - 'fused_qkv': True, - 'clip_qkv': None, - 'softmax_scale': None, - 'attn_uses_sequence_id': False, - 'sliding_window_size': -1, - 'attn_logit_softcapping': None, - 'alibi': False, - 'alibi_bias_max': 8, - 'rope': False, - 'rope_theta': 10000, - 'rope_impl': 'dail', + 'attn_type': + 'multihead_attention', + 'attn_pdrop': + 0.0, + 'attn_impl': + 'flash', + 'qk_ln': + False, + 'qk_gn': + False, + 'fused_qkv': + True, + 'clip_qkv': + None, + 'softmax_scale': + None, + 'attn_uses_sequence_id': + False, + 'sliding_window_size': + -1, + 'attn_logit_softcapping': + None, + 'alibi': + False, + 'alibi_bias_max': + 8, + 'rope': + False, + 'rope_theta': + 10000, + 'rope_impl': + 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -33,7 +51,12 @@ 'type': 'no_scaling', 'factor': 1.0, }, - 'kv_dim': None, + 'kv_dim': + None, + 'flex_attn_mod_list': [], + 'flex_attn_compile': + version.parse(torch.__version__.split('.dev')[0]) >= + version.parse('2.6.0'), } init_config_defaults: dict = { diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 850c4f3bbd..4cf542d34a 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -22,6 +22,7 @@ ffns, ffns_with_megablocks, ffns_with_norm, + flex_attention_mods, module_init_fns, norms, param_init_fns, @@ -432,6 +433,7 @@ 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', + 'flex_attention_mods', 'fcs', 'icl_datasets', 'config_transforms', diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py index 63ecb17d78..352ef9649b 100644 --- a/tests/models/layers/test_attention.py +++ b/tests/models/layers/test_attention.py @@ -6,14 +6,23 @@ import pytest import torch from composer.utils import reproducibility +from packaging import version +from torch.nn.attention.flex_attention import create_block_mask, flex_attention from llmfoundry.models.layers.attention import ( attention_implementations, scaled_multihead_dot_product_attention, ) +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE from llmfoundry.models.layers.layer_builders import build_attention_layer from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info +compiled_flex_attention = flex_attention +compiled_create_block_mask = create_block_mask +if FLEX_ATTN_COMPILE: + compiled_flex_attention = torch.compile(flex_attention) + compiled_create_block_mask = torch.compile(create_block_mask) + @pytest.mark.parametrize( 'attn_name', @@ -170,14 +179,20 @@ def test_unfused_wqkv(attn_name: str, dim: int): @pytest.mark.gpu @pytest.mark.parametrize('sliding_window_size', [1, 4, 8]) -@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch', 'flex']) def test_sliding_window(sliding_window_size: int, attn_impl: str): # Test that sliding window attention works as expected. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) dtype = torch.bfloat16 device = 'cuda' d = 128 n_heads = 8 - seqlen_1 = 8 + seqlen_1 = 8 if attn_impl != 'flex' else 128 # FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work bsz = 2 query_1 = torch.randn(bsz, seqlen_1, @@ -205,6 +220,13 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str): 'should_repeat_kv_for_gqa': True, } + elif attn_impl == 'flex': + attn_extra_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, @@ -270,6 +292,141 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str): _assert_approx_equal(value_1.grad, value_2.grad) +@pytest.mark.gpu +@pytest.mark.parametrize('sliding_window_size', [1, 4]) +@pytest.mark.parametrize('global_window_size', [1, 4]) +@pytest.mark.parametrize('attn_impl', ['flex']) +def test_local_global_window( + sliding_window_size: int, + global_window_size: int, + attn_impl: str, +): + # Test that sliding window attention works as expected. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) + global_window_size = torch.tensor(global_window_size, device='cuda') + sliding_window_size = torch.tensor(sliding_window_size, device='cuda') + dtype = torch.bfloat16 + device = 'cuda' + d = 128 + n_heads = 8 + seqlen_1 = 128 + bsz = 1 + + query_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + value_1.requires_grad = True + + attn_extra_kwargs = {} + if attn_impl == 'flex': + attn_extra_kwargs = { + 'compiled_flex_attention': + compiled_flex_attention, + 'compiled_create_block_mask': + compiled_create_block_mask, + 'flex_attn_compile': + FLEX_ATTN_COMPILE, + 'sequence_id_info': { + 'pos_in_seq': None, + }, + 'flex_attn_mod_list': [{ + 'mod_name': 'local_global_mask', + 'mod_kwargs': { + 'sliding_window_size': sliding_window_size, + 'global_window_size': global_window_size, + }, + },], + } + + output_1, _, _ = attention_implementations.get(attn_impl)( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + sliding_window_size=-1, + **attn_extra_kwargs, + ) + + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + + global_bias_2 = torch.where( + torch.arange(seqlen_1)[None, None, + None, :].to(dtype=dtype, device=device) <= + global_window_size, + torch.ones(1, 1, seqlen_1, + seqlen_1).to(dtype=torch.bool, device=device), + torch.zeros(1, 1, seqlen_1, + seqlen_1).to(dtype=torch.bool, device=device), + ) + + window_mask_2 = torch.tril( + torch.ones(seqlen_1, seqlen_1), + diagonal=-(sliding_window_size + 1), + ).to(dtype=torch.bool, device=device) + window_mask_2 = torch.where( + window_mask_2, + torch.zeros_like(window_mask_2), + torch.ones_like(window_mask_2), + ) + attn_bias_2 = global_bias_2 | window_mask_2 + attn_bias_2 = torch.where( + attn_bias_2, + torch.zeros_like(attn_bias_2, dtype=dtype), + torch.finfo(dtype).min, + ) + output_2, _, _ = scaled_multihead_dot_product_attention( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=attn_bias_2, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + ) + + output_2.sum().backward() + + _assert_approx_equal(output_1, output_2) + assert (query_2.grad is not None) and (query_1.grad is not None) + _assert_approx_equal(query_1.grad, query_2.grad) + assert (key_2.grad is not None) and (key_1.grad is not None) + _assert_approx_equal(key_1.grad, key_2.grad) + assert (value_2.grad is not None) and (value_1.grad is not None) + _assert_approx_equal(value_1.grad, value_2.grad) + + def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor): assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2) diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index 666d93c9b4..abab741a94 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -6,31 +6,47 @@ import pytest import torch +from packaging import version +from torch.nn.attention.flex_attention import create_block_mask, flex_attention from llmfoundry.models.layers.attention import ( + attention_implementations, attn_bias_shape, build_attn_bias, check_alibi_support, - flash_attn_fn, gen_slopes, is_flash_v2_installed, scaled_multihead_dot_product_attention, ) +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info +compiled_flex_attention = flex_attention +compiled_create_block_mask = create_block_mask +if FLEX_ATTN_COMPILE: + compiled_flex_attention = torch.compile(flex_attention) + compiled_create_block_mask = torch.compile(create_block_mask) + @pytest.mark.gpu @pytest.mark.skipif( not is_flash_v2_installed(), reason='GQA natively only supported by Flash Attention after v2.', ) +@pytest.mark.parametrize('attn_impl', ['flash', 'flex']) @pytest.mark.parametrize('kv_n_heads', [1, 4, 8]) -def test_gqa_kv_repetition(kv_n_heads: int): +def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int): # Test that flash attention v2 with GQA (kv_n_heads < n_heads) works the same # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) d = 128 n_heads = 8 - seqlen_1 = 6 + seqlen_1 = 6 if attn_impl != 'flex' else 128 # FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work bsz = 2 query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() @@ -41,7 +57,30 @@ def test_gqa_kv_repetition(kv_n_heads: int): kv_n_heads * d).to(torch.bfloat16).cuda() value_1.requires_grad = True - output_1, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + True, + } + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + + output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, key=key_1, value=value_1, @@ -55,15 +94,7 @@ def test_gqa_kv_repetition(kv_n_heads: int): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_1.device, - None, - None, - ), - should_repeat_kv_for_gqa=True, + **extra_attn_kwargs, ) output_1.sum().backward() @@ -74,8 +105,30 @@ def test_gqa_kv_repetition(kv_n_heads: int): key_2.requires_grad = True value_2 = value_1.detach().clone() value_2.requires_grad = True - - output_2, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_2.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + False, + } + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + + output_2, _, _ = attention_implementations.get(attn_impl)( query=query_2, key=key_2, value=value_2, @@ -89,15 +142,7 @@ def test_gqa_kv_repetition(kv_n_heads: int): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_2.device, - None, - None, - ), - should_repeat_kv_for_gqa=False, + **extra_attn_kwargs, ) output_2.sum().backward() @@ -113,12 +158,19 @@ def test_gqa_kv_repetition(kv_n_heads: int): reason= 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.', ) -def test_seq_id_masking_FA_v2(): +@pytest.mark.parametrize('attn_impl', ['flash', 'flex']) +def test_seq_id_masking_FA_v2(attn_impl: str): # Test that flash attention v2 with sequence id masking works correctly. - d = 128 + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) + d = 128 # TODO: Compiled FlexAttention works for d=16 with seqlen=6, but not for d=128 with seqlen=6. For seqlen=128, all d's in [16, 32, 64, 128, 256] work. Probably because this is not yet fixed: https://pytorch.org/blog/flexattention/#limitations-and-future-work n_heads = 4 kv_n_heads = 4 - seqlen_1 = 6 + seqlen_1 = 128 bsz = 2 query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() @@ -134,9 +186,13 @@ def test_seq_id_masking_FA_v2(): (3, 5), (5, 6), ] # Each batch has 3 sequences of length 3, 2, and 1 respectively. - attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0], - [3, 2, 1, 0, 0, - 0]]).to(torch.int64).cuda() + attention_mask_in_length_1 = torch.tensor([ + [3, 2, 1] + [0] * (seqlen_1 - 3), + [3, 2, 1] + [0] * (seqlen_1 - 3), + ]).to(torch.int64).cuda() + sequence_id = torch.tensor([[0, 0, 0, 1, 1, 2] + [-1] * + (seqlen_1 - 6), [0, 0, 0, 1, 1, 2] + [-1] * + (seqlen_1 - 6)],).to(torch.int64).cuda() flash_attn_padding_info_1 = gen_flash_attn_padding_info( bsz, @@ -146,8 +202,19 @@ def test_seq_id_masking_FA_v2(): attention_mask_in_length_1, None, ) - - output_1, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs['flash_attn_padding_info'] = flash_attn_padding_info_1 + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': { + 'sequence_id': sequence_id, + }, + } + output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, key=key_1, value=value_1, @@ -161,7 +228,7 @@ def test_seq_id_masking_FA_v2(): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=flash_attn_padding_info_1, + **extra_attn_kwargs, ) output_1.sum().backward() @@ -182,8 +249,11 @@ def test_seq_id_masking_FA_v2(): None, None, ) - - output_2, _, _ = flash_attn_fn( + attn_impl = 'flash' + extra_attn_kwargs = { + 'flash_attn_padding_info': flash_attn_padding_info_2, + } + output_2, _, _ = attention_implementations.get(attn_impl)( query=query_2, key=key_2, value=value_2, @@ -197,11 +267,11 @@ def test_seq_id_masking_FA_v2(): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=flash_attn_padding_info_2, + **extra_attn_kwargs, ) output_2.sum().backward() - assert torch.allclose( + torch.testing.assert_close( output_1[:, seq_range[0]:seq_range[1], :], output_2, ) @@ -224,13 +294,24 @@ def test_seq_id_masking_FA_v2(): not check_alibi_support('flash'), reason='ALiBi only supported by Flash Attention after v2.4.2.', ) +@pytest.mark.parametrize('attn_impl', ['flash', 'flex']) @pytest.mark.parametrize('n_heads', [1, 6, 8]) -def test_alibi_bias(n_heads: int): +def test_alibi_bias(attn_impl: str, n_heads: int): # Test that sliding window attention works as expected. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) + if attn_impl == 'flex' and n_heads != 8: + pytest.skip( + 'FlexAttention passes the test individually for n_heads=1, 6, and 8, but not when all three are configured.', + ) # TODO: Investigate why this is the case. dtype = torch.bfloat16 device = 'cuda' d = 128 - seqlen_1 = 8 + seqlen_1 = 6 if attn_impl != 'flex' else 128 # TODO: FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work bsz = 2 query_1 = torch.randn(bsz, seqlen_1, @@ -248,7 +329,29 @@ def test_alibi_bias(n_heads: int): device=torch.device(device), return_1d=True, ) - output_1, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + True, + } + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, key=key_1, value=value_1, @@ -262,16 +365,8 @@ def test_alibi_bias(n_heads: int): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_1.device, - None, - None, - ), - should_repeat_kv_for_gqa=True, alibi_slopes=alibi_slopes_1, + **extra_attn_kwargs, ) output_1.sum().backward() @@ -341,16 +436,34 @@ def gen_bias(): reason= 'attn_logit_softcapping only supported by Flash Attention after v2.6.2.', ) +@pytest.mark.parametrize('attn_impl', ['flash', 'flex']) @pytest.mark.parametrize( 'attn_logit_softcapping', [None, 0.1, 1.0, 10.0, 100.0], ) -def test_attn_logit_softcapping(attn_logit_softcapping: Optional[float]): +def test_attn_logit_softcapping( + attn_impl: str, + attn_logit_softcapping: Optional[float], +): # Test that attn_logit_softcapping in attention works as expected. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) + if attn_impl == 'flex' and attn_logit_softcapping is not None: + if int(attn_logit_softcapping) != attn_logit_softcapping: + pytest.skip( + 'FlexAttention does not support attn_logit_softcapping with float softcap temperature.', + ) + else: + attn_logit_softcapping = int(attn_logit_softcapping) + dtype = torch.bfloat16 device = 'cuda' d = 128 - seqlen_1 = 8 + seqlen_1 = 8 if attn_impl != 'flex' else 128 # FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work bsz = 2 n_heads = 4 @@ -363,7 +476,29 @@ def test_attn_logit_softcapping(attn_logit_softcapping: Optional[float]): value_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, device=device) value_1.requires_grad = True - output_1, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + True, + } + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, key=key_1, value=value_1, @@ -377,16 +512,8 @@ def test_attn_logit_softcapping(attn_logit_softcapping: Optional[float]): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_1.device, - None, - None, - ), - should_repeat_kv_for_gqa=True, attn_logit_softcapping=attn_logit_softcapping, + **extra_attn_kwargs, ) output_1.sum().backward() diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 0a4b32a73a..626281c2a1 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -4,6 +4,8 @@ import pytest import torch from omegaconf import OmegaConf as om +from packaging import version +from torch.nn.attention.flex_attention import create_block_mask, flex_attention from llmfoundry.models.layers import attention from llmfoundry.models.layers.attention import ( @@ -11,14 +13,21 @@ gen_slopes, is_flash_v2_installed, ) +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE from llmfoundry.models.layers.layer_builders import build_attention_layer from llmfoundry.models.mpt.modeling_mpt import ( apply_sequence_id, - gen_attention_mask_in_length, gen_flash_attn_padding_info, gen_rotary_embedding, + gen_sequence_id_info, ) +compiled_flex_attention = flex_attention +compiled_create_block_mask = create_block_mask +if FLEX_ATTN_COMPILE: + compiled_flex_attention = torch.compile(flex_attention) + compiled_create_block_mask = torch.compile(create_block_mask) + def allclose_helper( t0: torch.Tensor, @@ -30,9 +39,13 @@ def allclose_helper( @pytest.mark.gpu -@pytest.mark.parametrize('attn_impl_0, attn_impl_1', [ - ('flash', 'torch'), -]) +@pytest.mark.parametrize( + 'attn_impl_0, attn_impl_1', + [ + ('flash', 'torch'), + ('flex', 'torch'), + ], +) @pytest.mark.parametrize('clip_qkv', [True, False]) @pytest.mark.parametrize( 'qk_ln, qk_gn', @@ -96,6 +109,12 @@ def test_attn_impl( Includes testing with and without attn_clip_qkv, attn_qk_ln, attn_qk_gn, alibi, and rope. """ + if (attn_impl_0 == 'flex' or attn_impl_1 == 'flex') and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] if alibi and not ( @@ -117,7 +136,7 @@ def test_attn_impl( pytest.skip('attn_uses_sequence_id requires alibi or rope.') cfg = om.create({ - 'attn_impl': 'flash', + 'attn_impl': attn_impl_0, 'd_model': 64, 'n_heads': 4, 'attn_pdrop': 0, @@ -198,7 +217,7 @@ def gen_bias(attn_impl: str): return attn_bias - attention_mask_in_length_0 = gen_attention_mask_in_length( + attention_mask_in_length_0 = gen_sequence_id_info( sequence_id=sequence_id, S=s, attn_uses_sequence_id=attn_uses_sequence_id, @@ -217,7 +236,7 @@ def gen_bias(attn_impl: str): attention_mask, ) - attention_mask_in_length_1 = gen_attention_mask_in_length( + attention_mask_in_length_1 = gen_sequence_id_info( sequence_id=sequence_id, S=s, attn_uses_sequence_id=attn_uses_sequence_id, @@ -244,7 +263,7 @@ def gen_bias(attn_impl: str): with torch.autocast(x0.device.type): attn_bias_0 = gen_bias(attn_impl_0) alibi_slopes_0 = None - if alibi and attn_impl_0 == 'flash': + if alibi and (attn_impl_0 == 'flash' or attn_impl_0 == 'flex'): alibi_slopes_0 = gen_slopes( n_heads=cfg.n_heads, alibi_bias_max=8, @@ -278,7 +297,17 @@ def gen_bias(attn_impl: str): 'seq_len': s, } - + extra_kwargs = {} + if attn_impl_0 == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + if sequence_id is not None: + extra_kwargs['flex_attn_kwargs']['sequence_id_info'][ + 'sequence_id'] = sequence_id y0, _, _ = attn0( x0, past_key_value=None, @@ -288,16 +317,30 @@ def gen_bias(attn_impl: str): is_causal=True, flash_attn_padding_info=flash_attn_padding_info_0, alibi_slopes=alibi_slopes_0, + **extra_kwargs, ) attn_bias_1 = gen_bias(attn_impl_1) alibi_slopes_1 = None - if alibi and attn_impl_1 == 'flash': + if alibi and (attn_impl_1 == 'flash' or attn_impl_1 == 'flex'): alibi_slopes_1 = gen_slopes( n_heads=cfg.n_heads, alibi_bias_max=8, device=torch.device(device), return_1d=True, ) + + extra_kwargs = {} + if attn_impl_1 == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + } + if sequence_id is not None: + extra_kwargs['flex_attn_kwargs']['sequence_id_info'] = { + 'sequence_id': sequence_id, + } + y1, _, _ = attn1( x1, past_key_value=None, @@ -307,6 +350,7 @@ def gen_bias(attn_impl: str): is_causal=True, flash_attn_padding_info=flash_attn_padding_info_1, alibi_slopes=alibi_slopes_1, + **extra_kwargs, ) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) @@ -344,9 +388,15 @@ def gen_bias(attn_impl: str): @pytest.mark.gpu -@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch', 'flex']) def test_vs_mha(attn_impl: str, device: str = 'cuda'): """Compare diff attn_impl to torch.nn.MultiheadAttention.""" + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) from llmfoundry.models.layers import attention cfg = om.create({ @@ -401,6 +451,14 @@ def gen_tca_mask(): None, attention_mask, ) + extra_kwargs = {} + if attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } y0, _, _ = mmhsa( x0, past_key_value=None, @@ -408,6 +466,7 @@ def gen_tca_mask(): attention_mask=attention_mask, is_causal=True, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) y1, _ = tmhsa( x1, @@ -453,7 +512,7 @@ def gen_tca_mask(): @pytest.mark.gpu -@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch', 'flex']) @pytest.mark.parametrize('n_heads', [16, 8]) @pytest.mark.parametrize('kv_n_heads', [4, 2, 1]) def test_grouped_attention_heads( @@ -463,6 +522,12 @@ def test_grouped_attention_heads( device: str = 'cuda', ): """Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads.""" + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) from llmfoundry.models.layers import attention cfg = om.create({ @@ -494,6 +559,14 @@ def test_grouped_attention_heads( None, attention_mask, ) + extra_kwargs = {} + if attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } y0, _, _ = mmhsa( x0, past_key_value=None, @@ -501,6 +574,7 @@ def test_grouped_attention_heads( attention_mask=attention_mask, is_causal=True, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) y0 *= attention_mask.unsqueeze(-1) @@ -561,13 +635,19 @@ def test_grouped_query_invalid_heads(): }, }], ) -@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch', 'flex']) def test_reuse_prev_layer_kv_cache( pos_emb_config: dict, attn_impl: str, device: str = 'cuda', ): """Checks reusing previous layer's kv cache.""" + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] @@ -642,7 +722,7 @@ def gen_bias(attn_impl: str): return attn_bias - attention_mask_in_length = gen_attention_mask_in_length( + attention_mask_in_length = gen_sequence_id_info( sequence_id=sequence_id, S=s, attn_uses_sequence_id=True, @@ -701,7 +781,16 @@ def gen_bias(attn_impl: str): 'seq_len': s, } - + extra_kwargs = {} + if attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': { + 'sequence_id': sequence_id, + }, + } y0, _, prev_layer_key_value = attn0( x0, past_key_value=(), @@ -711,6 +800,7 @@ def gen_bias(attn_impl: str): is_causal=True, flash_attn_padding_info=flash_attn_padding_info, alibi_slopes=alibi_slopes_0, + **extra_kwargs, ) attn_bias_1 = gen_bias(attn_impl) alibi_slopes_1 = None @@ -725,6 +815,16 @@ def gen_bias(attn_impl: str): prev_layer_key_value = [ t.clone().detach() for t in prev_layer_key_value ] + extra_kwargs = {} + if attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': { + 'sequence_id': sequence_id, + }, + } y1, _, _ = attn1( x1, past_key_value=None, @@ -735,6 +835,7 @@ def gen_bias(attn_impl: str): flash_attn_padding_info=flash_attn_padding_info, alibi_slopes=alibi_slopes_1, prev_layer_key_value=prev_layer_key_value, + **extra_kwargs, ) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 8a6290d5c4..51bd7b14b6 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -28,6 +28,7 @@ ) from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from packaging import version from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -48,6 +49,7 @@ is_flash_v2_installed, ) from llmfoundry.models.layers.blocks import MPTBlock +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel from llmfoundry.models.mpt.modeling_mpt import ( CROSS_ENTROPY_IGNORE_INDEX, @@ -82,6 +84,7 @@ def _get_objs( conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', model_config_overrides: Optional[dict] = None, attn_impl: str = 'torch', + flex_attn_compile: bool = FLEX_ATTN_COMPILE, ): warnings.filterwarnings( action='ignore', @@ -111,6 +114,7 @@ def _get_objs( test_cfg.precision = 'amp_bf16' if is_gpu else 'fp32' test_cfg.model.attn_config = { 'attn_impl': attn_impl, + 'flex_attn_compile': flex_attn_compile, } test_cfg.model.init_device = device test_cfg.device = device @@ -484,7 +488,9 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): 'attn_impl,precision', [('torch', torch.float16), ('torch', torch.bfloat16), pytest.param('flash', torch.float16, marks=pytest.mark.gpu), - pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)], + pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu), + pytest.param('flex', torch.float16, marks=pytest.mark.gpu), + pytest.param('flex', torch.bfloat16, marks=pytest.mark.gpu)], ) @pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptglu']) @pytest.mark.parametrize( @@ -515,12 +521,19 @@ def test_determinism( ffn_type: str, ffn_act_fn: dict, ): + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) conf_path = 'scripts/train/yamls/pretrain/testing.yaml' with open(conf_path) as f: test_cfg = om.load(f) test_cfg.model.attn_config = { 'attn_impl': attn_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, } if hasattr(test_cfg.model, 'ffn_config'): test_cfg.model.ffn_config['ffn_type'] = ffn_type @@ -1032,7 +1045,7 @@ def test_mb_mpt_creation(): @pytest.mark.gpu -@pytest.mark.parametrize('attention_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attention_impl', ['flash', 'torch', 'flex']) @pytest.mark.parametrize( 'pos_emb_config', [{ @@ -1078,6 +1091,12 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): pytest.skip( 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.', ) + if attention_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) composer_device = get_device(None) @@ -1093,6 +1112,7 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): attn_config={ 'attn_impl': attention_impl, 'attn_uses_sequence_id': True, + 'flex_attn_compile': FLEX_ATTN_COMPILE, **pos_emb_config, }, init_config={ @@ -1161,6 +1181,7 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): [ 'torch', pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('flex', marks=pytest.mark.gpu), pytest.param('torch', marks=pytest.mark.gpu), ], ) @@ -1200,6 +1221,12 @@ def test_forward_with_padding( tie_word_embeddings: bool, ): # Test that different placement of padding does not affect the output. + if attention_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) alibi = pos_emb_config['alibi'] if alibi and not check_alibi_support(attention_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') @@ -1224,6 +1251,7 @@ def test_forward_with_padding( resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, **pos_emb_config, }, init_config={ @@ -1417,6 +1445,7 @@ def test_forward_with_padding( [ ('torch', 'fp32'), pytest.param('flash', 'amp_bf16', marks=pytest.mark.gpu), + pytest.param('flex', 'amp_bf16', marks=pytest.mark.gpu), pytest.param('torch', 'amp_bf16', marks=pytest.mark.gpu), pytest.param('torch', 'fp32', marks=pytest.mark.gpu), ], @@ -1459,6 +1488,12 @@ def test_generate( ): # Test that generate works, and produces the same output with or without # padding in the input. + if attention_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) if pos_emb_config['alibi'] and not check_alibi_support(attention_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') @@ -1471,7 +1506,9 @@ def test_generate( pytest.skip(f'This test configuration has precision / sampling issues.') composer_device = get_device(None) - + reproducibility.seed_all( + 4, + ) # Flex atttention fails for the default seed, but works for all the other seeds tested. Probably the output logit softmax score is such that a slight numerical imprecision changes the output. hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1483,6 +1520,8 @@ def test_generate( resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, + 'flex_attn_compile': + False, # TODO: Needs these issues to be fixed: https://github.com/pytorch/pytorch/issues/139064, https://github.com/pytorch/pytorch/issues/139544. Causes errors even with dynamic=True and/or fullgraph=True. **pos_emb_config, }, tie_word_embeddings=tie_word_embeddings, @@ -1713,6 +1752,7 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): [ 'torch', pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('flex', marks=pytest.mark.gpu), ], ) @pytest.mark.parametrize( @@ -1753,6 +1793,12 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.', ) + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) composer_device = get_device(None) @@ -1767,6 +1813,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict): resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, **pos_emb_config, }, use_cache=True, @@ -1875,6 +1922,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict): [ 'torch', pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('flex', marks=pytest.mark.gpu), ], ) @pytest.mark.parametrize( @@ -1914,6 +1962,12 @@ def test_forward_with_cache( ): # Test that model forward with and without the key-value cache produces the # same output. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') @@ -1936,6 +1990,7 @@ def test_forward_with_cache( resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, **pos_emb_config, }, use_cache=True, @@ -2046,6 +2101,7 @@ def test_forward_with_cache( [ 'torch', pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('flex', marks=pytest.mark.gpu), ], ) @pytest.mark.parametrize( @@ -2083,6 +2139,12 @@ def test_generate_with_past_kv( pos_emb_config: dict, tie_word_embeddings: bool, ): + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ @@ -2104,6 +2166,7 @@ def test_generate_with_past_kv( resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, **pos_emb_config, }, use_cache=True, @@ -2168,6 +2231,7 @@ def test_generate_with_past_kv( [ 'torch', pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('flex', marks=pytest.mark.gpu), ], ) @pytest.mark.parametrize( @@ -2215,6 +2279,12 @@ def test_generation_kwargs_dont_crash( pos_emb_config: dict, tie_word_embeddings: bool, ): + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') @@ -2239,6 +2309,7 @@ def test_generation_kwargs_dont_crash( resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, **pos_emb_config, }, use_cache=True, @@ -2393,6 +2464,7 @@ def test_alibi_vs_hf(): [ 'torch', pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('flex', marks=pytest.mark.gpu), pytest.param('torch', marks=pytest.mark.gpu), ], ) @@ -2429,9 +2501,15 @@ def test_forward_with_output_attentions_and_output_hidden_states( attn_impl: str, pos_emb_config: dict, ): + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') - if attn_impl == 'flash': + if attn_impl == 'flash' or attn_impl == 'flex': pytest.skip(f'output_attentions only implemented with torch attention.') if pos_emb_config['rope'] and pos_emb_config[ 'rope_impl'] == 'dail' and not is_flash_v2_installed(): @@ -2454,6 +2532,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, **pos_emb_config, }, use_cache=True, @@ -2591,7 +2670,14 @@ def test_hf_init( @pytest.mark.gpu -def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): +@pytest.mark.parametrize('attn_impl', ['torch', 'flash', 'flex']) +def test_head_dim_8_flash_mqa_attn(attn_impl: str, batch_size: int = 2): + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml') test_cfg.device = torch.cuda.current_device() @@ -2607,7 +2693,8 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'flash', + 'attn_impl': attn_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, 'attn_type': 'multiquery_attention', }, ) @@ -2639,7 +2726,14 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): assert not torch.isnan(output.logits).any() -def test_construct_blocks(): +@pytest.mark.parametrize('attn_impl', ['torch', 'flash', 'flex']) +def test_construct_blocks(attn_impl: str): + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) n_layers = 13 config = MPTConfig( @@ -2649,7 +2743,8 @@ def test_construct_blocks(): expansion_ratio=2, max_seq_len=64, attn_config={ - 'attn_impl': 'flash', + 'attn_impl': attn_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, 'attn_type': 'grouped_query_attention', 'kv_n_heads': 4, }, @@ -2729,10 +2824,18 @@ def test_construct_blocks(): @pytest.mark.gpu +@pytest.mark.parametrize('attn_impl', ['torch', 'flash', 'flex']) def test_reuse_prev_layer_kv_cache( + attn_impl: str, request: pytest.FixtureRequest, batch_size: int = 2, ): + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) conf_path = 'scripts/train/yamls/pretrain/testing.yaml' model_config_overrides = { 'block_overrides': { @@ -2758,7 +2861,8 @@ def test_reuse_prev_layer_kv_cache( request=request, conf_path=conf_path, model_config_overrides=model_config_overrides, - attn_impl='flash', + attn_impl=attn_impl, + flex_attn_compile=FLEX_ATTN_COMPILE, ) batch = gen_random_batch(batch_size, test_cfg) diff --git a/tests/test_registry.py b/tests/test_registry.py index 90ef3bfaac..8ddce5125d 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -40,6 +40,7 @@ def test_expected_registries_exist(): 'ffns', 'ffns_with_norm', 'ffns_with_megablocks', + 'flex_attention_mods', 'attention_classes', 'attention_implementations', 'fcs',