diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index ea2acf3814..2da8ff3eb0 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -474,9 +474,11 @@ def flex_attn_fn(
     )
     check_valid_inputs(query, key, value)
-
+    query_offset = 0
     if past_key_value is not None:
         if len(past_key_value) != 0:
+            assert past_key_value[0].shape[1] == past_key_value[1].shape[1]
+            query_offset = past_key_value[0].shape[1]
             key = torch.cat([past_key_value[0], key], dim=1)
             value = torch.cat([past_key_value[1], value], dim=1)
@@ -564,10 +566,12 @@ def _check_mod_list(mod_list: list[dict[str, Any]], name: str):
         B=query.shape[0],
         block_mask_list=block_mask_list,  # type: ignore
         compiled_create_block_mask=compiled_create_block_mask,
+        query_offset=query_offset,
         sequence_id_info=sequence_id_info,
     )
     score_mod = generate_score_mod(
         score_mod_list=score_mod_list,  # type: ignore
+        query_offset=query_offset,
         sequence_id_info=sequence_id_info,
     )
diff --git a/llmfoundry/models/layers/flex_attn_utils.py b/llmfoundry/models/layers/flex_attn_utils.py
index f7b66fd8f5..41b6813741 100644
--- a/llmfoundry/models/layers/flex_attn_utils.py
+++ b/llmfoundry/models/layers/flex_attn_utils.py
@@ -19,9 +19,10 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
+        query_offset: int,
         sequence_id_info: Optional[dict[str, Any]],
     ) -> torch.Tensor:
-        del sequence_id_info, b, h, q_idx, kv_idx
+        del sequence_id_info, query_offset, b, h, q_idx, kv_idx
         raise NotImplementedError
 
     def _score_mod_fn(
@@ -31,9 +32,10 @@ def _score_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
+        query_offset: int,
         sequence_id_info: Optional[dict[str, Any]],
     ) -> torch.Tensor:
-        del sequence_id_info, score, b, h, q_idx, kv_idx
+        del sequence_id_info, query_offset, score, b, h, q_idx, kv_idx
         raise NotImplementedError
 
     def __init__(self, mod_type: str) -> None:
@@ -51,9 +53,11 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
+        query_offset: int,
         sequence_id_info: Optional[dict[str, Any]],
     ) -> torch.Tensor:
         del sequence_id_info, b, h
+        q_idx = q_idx + query_offset
         return q_idx >= kv_idx
 
     def __init__(self) -> None:
@@ -69,9 +73,11 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
+        query_offset: int,
         sequence_id_info: Optional[dict[str, Any]],
     ) -> torch.Tensor:
         del sequence_id_info, b, h
+        q_idx = q_idx + query_offset
         return q_idx - kv_idx <= self.sliding_window_size
 
     def __init__(self, sliding_window_size: int) -> None:
@@ -88,9 +94,11 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
+        query_offset: int,
         sequence_id_info: Optional[dict[str, Any]],
     ) -> torch.Tensor:
         del h
+        q_idx = q_idx + query_offset
         if sequence_id_info is None:
             raise ValueError(
                 'sequence_id_info is required for SequenceIdMaskMod',
             )
@@ -113,9 +121,10 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
+        query_offset: int,
         sequence_id_info: Optional[dict[str, Any]],
     ) -> torch.Tensor:
-        del h, q_idx
+        del h, q_idx, query_offset
         if sequence_id_info is None:
             raise ValueError(
                 'sequence_id_info is required for SequenceIdMaskMod',
             )
@@ -137,9 +146,11 @@ def _mask_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
+        query_offset: int,
         sequence_id_info: Optional[dict[str, Any]],
     ) -> torch.Tensor:
         del h
+        q_idx = q_idx + query_offset
         if sequence_id_info is None:
             raise ValueError(
                 'sequence_id_info is required for LocalGlobalMaskMod',
             )
@@ -177,9 +188,11 @@ def _score_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
+        query_offset: int,
         sequence_id_info: Optional[dict[str, Any]],
     ) -> torch.Tensor:
         del sequence_id_info, b
+        q_idx = q_idx + query_offset
         bias = -self.alibi_slopes[h] * torch.abs(q_idx - kv_idx)
         return score + bias
@@ -198,9 +211,10 @@ def _score_mod_fn(
         h: torch.Tensor,
         q_idx: torch.Tensor,
         kv_idx: torch.Tensor,
+        query_offset: int,
         sequence_id_info: Optional[dict[str, Any]],
     ) -> torch.Tensor:
-        del sequence_id_info, b, h, q_idx, kv_idx
+        del sequence_id_info, query_offset, b, h, q_idx, kv_idx
         return self.attn_logit_softcapping * torch.tanh(
             score / self.attn_logit_softcapping,
         )
@@ -216,6 +230,7 @@ def generate_block_mask(
     B: int,
     block_mask_list: Optional[list[FlexAttentionMod]],
     compiled_create_block_mask: Any,
+    query_offset: int,
     sequence_id_info: Optional[dict[str, Any]],
 ):
     if block_mask_list is None:
@@ -226,12 +241,17 @@
         if i == 0:
             block_mask_fn = partial(
                 block_mask.mod_fn,
+                query_offset=query_offset,
                 sequence_id_info=sequence_id_info,
             )
         else:
             block_mask_fn = and_masks(
                 block_mask_fn,
-                partial(block_mask.mod_fn, sequence_id_info=sequence_id_info),
+                partial(
+                    block_mask.mod_fn,
+                    query_offset=query_offset,
+                    sequence_id_info=sequence_id_info,
+                ),
             )
 
     block_mask = compiled_create_block_mask(
@@ -247,6 +267,7 @@
 def generate_score_mod(
     score_mod_list: Optional[list[FlexAttentionMod]],
+    query_offset: int,
     sequence_id_info: Optional[dict[str, Any]],
 ):
     if score_mod_list is None:
@@ -256,12 +277,17 @@
         if i == 0:
             wrapped_score_mod = partial(
                 score_mod.mod_fn,
+                query_offset=query_offset,
                 sequence_id_info=sequence_id_info,
             )
         else:
             wrapped_score_mod = _wrap_score_mod_fns(
                 wrapped_score_mod,
-                partial(score_mod.mod_fn, sequence_id_info=sequence_id_info),
+                partial(
+                    score_mod.mod_fn,
+                    query_offset=query_offset,
+                    sequence_id_info=sequence_id_info,
+                ),
             )
 
     return wrapped_score_mod
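
Note (not part of the patch): a minimal standalone sketch of the behavior the new query_offset argument enables. During cached decoding, q_idx indexes only the new query tokens while kv_idx runs over the cached-plus-new keys, so the mask and score mods shift q_idx by the cache length exactly as the diff does (q_idx = q_idx + query_offset) before applying their predicate. The helper name causal_mask_mod below is illustrative, not an identifier from the repository.

import torch

def causal_mask_mod(b, h, q_idx, kv_idx, query_offset):
    # q_idx indexes only the new (uncached) query tokens; kv_idx indexes the
    # full cached-plus-new key sequence. Adding the cache length to q_idx
    # recovers the absolute query position before applying the causal rule.
    q_idx = q_idx + query_offset
    return q_idx >= kv_idx

# Toy check: 2 new query tokens appended after a 4-token KV cache.
query_offset = 4
q_idx = torch.arange(2).unsqueeze(-1)   # new-token indices 0, 1 (absolute 4, 5)
kv_idx = torch.arange(6).unsqueeze(0)   # keys cover absolute positions 0..5
print(causal_mask_mod(None, None, q_idx, kv_idx, query_offset))
# tensor([[ True,  True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True,  True]])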