diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 3a7dbc198e3732..38ee66d391e797 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -424,6 +424,7 @@ def forward(
                 # that are set to 0
                 first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
                 batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
+
                 # Get the target length
                 target_seqlen = first_layer_past_key_value.shape[-1] + 1
 
@@ -433,8 +434,14 @@ def forward(
                     device=attention_mask.device,
                 )
 
+                # Ensure indices are within bounds to avoid CUDA index errors
+                # See https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6 for more details
+                valid_indices = non_attended_tokens < extended_attention_mask.shape[1]
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
                 # Zero-out the places where we don't need to attend
-                extended_attention_mask[batch_index, non_attended_tokens] = 0
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
 
                 attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
                 position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 0b1dc3fa86b383..9f0920d2dfcdb5 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -429,8 +429,14 @@ def forward(
                     device=attention_mask.device,
                 )
 
+                # Ensure indices are within bounds to avoid CUDA index errors
+                # See https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6 for more details
+                valid_indices = non_attended_tokens < extended_attention_mask.shape[1]
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
                 # Zero-out the places where we don't need to attend
-                extended_attention_mask[batch_index, non_attended_tokens] = 0
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
 
                 attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
                 position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
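
For reference, below is a minimal standalone sketch of the bounds filtering introduced in both files above. The tensor names mirror the patch, but the shapes and index values are made up for illustration and are not taken from the models:

import torch

# A toy stand-in for the (batch, target_seqlen) mask built in the patch.
extended_attention_mask = torch.ones(2, 4, dtype=torch.long)

# Candidate positions to zero out, as produced by torch.where in the patch;
# index 7 is deliberately out of bounds for dim 1 to trigger the failure mode.
batch_index = torch.tensor([0, 1, 1])
non_attended_tokens = torch.tensor([1, 3, 7])

# Keep only indices that fit inside the mask, mirroring the patched code.
# Indexing with the unfiltered tensors would raise a device-side assert on
# CUDA (or an IndexError on CPU).
valid_indices = non_attended_tokens < extended_attention_mask.shape[1]
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
print(extended_attention_mask)
# tensor([[1, 0, 1, 1],
#         [1, 1, 1, 0]])

The filtering simply drops out-of-range positions rather than clamping them, which is safe here because a position beyond the mask's length cannot be attended to in the first place.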