fix llava index errors
younesbelkada committed Dec 14, 2023
1 parent 2788f8d commit 726f744
Showing 2 changed files with 15 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/transformers/models/llava/modeling_llava.py
@@ -424,6 +424,7 @@ def forward(
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)

# Get the target length
target_seqlen = first_layer_past_key_value.shape[-1] + 1

@@ -433,8 +434,14 @@ def forward(
device=attention_mask.device,
)

+ # Ensure indices are within bounds to avoid CUDA index errors
+ # See https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6 for more details
+ valid_indices = non_attended_tokens < extended_attention_mask.shape[1]
+ new_batch_index = batch_index[valid_indices]
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
+
# Zero-out the places where we don't need to attend
- extended_attention_mask[batch_index, non_attended_tokens] = 0
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
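
For context, here is a minimal standalone sketch (not part of the commit; the tensor shapes and values are made up for illustration) of why the bounds check matters: if non_attended_tokens contains a position outside the width of the freshly created extended_attention_mask, the advanced-indexing assignment raises an index error, which on CUDA surfaces as a device-side assert. Filtering with valid_indices keeps only the in-range positions:

import torch

# Illustrative values only: 2 sequences, extension mask 4 positions wide.
batch_index = torch.tensor([0, 0, 1])
non_attended_tokens = torch.tensor([1, 5, 2])  # position 5 is out of range

extended_attention_mask = torch.ones((2, 4), dtype=torch.long)

# Without the bounds check this assignment would fail, because index 5
# exceeds extended_attention_mask.shape[1]:
#   extended_attention_mask[batch_index, non_attended_tokens] = 0

# The fix: keep only positions that fall inside the new mask.
valid_indices = non_attended_tokens < extended_attention_mask.shape[1]
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

print(extended_attention_mask)
# tensor([[1, 0, 1, 1],
#         [1, 1, 0, 1]])
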
8 changes: 7 additions & 1 deletion src/transformers/models/vipllava/modeling_vipllava.py
@@ -429,8 +429,14 @@ def forward(
device=attention_mask.device,
)

+ # Ensure indices are within bounds to avoid CUDA index errors
+ # See https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6 for more details
+ valid_indices = non_attended_tokens < extended_attention_mask.shape[1]
+ new_batch_index = batch_index[valid_indices]
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
+
# Zero-out the places where we don't need to attend
- extended_attention_mask[batch_index, non_attended_tokens] = 0
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
