diff --git a/llmfoundry/models/layers/flex_attn_utils.py b/llmfoundry/models/layers/flex_attn_utils.py index df412768ed..1489df2de2 100644 --- a/llmfoundry/models/layers/flex_attn_utils.py +++ b/llmfoundry/models/layers/flex_attn_utils.py @@ -110,8 +110,7 @@ def _mask_mod_fn( ) sequence_id = sequence_id_info['sequence_id'] # Check if the query and key belong to the same sequence and the query token is not a padding token. - return (sequence_id[b, q_idx] - == sequence_id[b, kv_idx]) & (sequence_id[b, kv_idx] != -1) + return (sequence_id[b, q_idx] == sequence_id[b, kv_idx]) def __init__(self) -> None: super().__init__(mod_type='mask')