From c50b5675d648d7c4bbe395d763cd468c3c4b56b7 Mon Sep 17 00:00:00 2001 From: Meliksah Turker Date: Mon, 25 Nov 2024 15:51:26 +0300 Subject: [PATCH] prepare_fa2_from_position_ids function bugfix (#33269) contiguous() is called before view() for key and value within the prepare_fa2_from_position_ids function --- src/transformers/modeling_flash_attention_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 045d2f6d646010..1b9274e21f5205 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -163,8 +163,8 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). """ query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)