[AttentionMaskConverter] fix sdpa unmask unattended (#28369)
fix tensor device
zspo authored Jan 8, 2024
1 parent 98dba52 commit 87a6cf4
Showing 1 changed file with 2 additions and 2 deletions: src/transformers/modeling_attn_mask_utils.py
@@ -234,8 +234,8 @@ def _unmask_unattended(

        # Get the index of the first non-zero value for every sample in the batch.
        # In the above example, indices = [[2], [0], [1]]]
-       tmp = torch.arange(attention_mask.shape[1], 0, -1)
-       indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
+       tmp = torch.arange(attention_mask.shape[1], 0, -1, device=attention_mask.device)
+       indices = torch.argmax(attention_mask * tmp, 1, keepdim=True)

        # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
        # expanded mask will be completely unattended.
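For context, a minimal sketch of what the two changed lines compute, run in isolation rather than inside _unmask_unattended; the example mask is an assumption chosen to match the [[2], [0], [1]] comment in the diff:

    import torch

    # Example left-padded attention mask (assumed; chosen to match the comment above).
    attention_mask = torch.tensor(
        [
            [0, 0, 1, 1, 1],  # first attended token at index 2
            [1, 1, 1, 1, 1],  # first attended token at index 0
            [0, 1, 1, 1, 1],  # first attended token at index 1
        ]
    )

    # After the fix, the arange is created directly on attention_mask.device,
    # so there is no .cpu() round-trip and `indices` stays on that device.
    tmp = torch.arange(attention_mask.shape[1], 0, -1, device=attention_mask.device)
    indices = torch.argmax(attention_mask * tmp, 1, keepdim=True)
    print(indices)  # tensor([[2], [0], [1]])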

2 comments on commit 87a6cf4

@kingb12 commented on 87a6cf4, Jan 8, 2024


FYI, I think I ran into an issue possibly caused by this: on this line, indices is on my model/data's device while range_tensor is on CPU:

        # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
        range_tensor[range_tensor >= indices] = 0

Downgrading to 4.36.2 resolved it. I can submit an issue if that helps, but I didn't have a great way to produce a minimal working example. It only popped up with batch size > 1, but I'm guessing that has more to do with when this code path gets called?
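As an illustration (not from the thread), the mismatch described above would surface roughly like this when the mask lives on a GPU; the shapes, the CUDA device, and how range_tensor is built here are assumptions for the sketch:

    import torch

    if torch.cuda.is_available():
        attention_mask = torch.tensor([[0, 0, 1, 1, 1]], device="cuda")

        # With the change above, `indices` lands on the mask's device (CUDA)...
        tmp = torch.arange(attention_mask.shape[1], 0, -1, device=attention_mask.device)
        indices = torch.argmax(attention_mask * tmp, 1, keepdim=True)

        # ...while a range_tensor built on the CPU then hits the quoted line:
        range_tensor = torch.arange(attention_mask.shape[1]).unsqueeze(0)
        try:
            range_tensor[range_tensor >= indices] = 0
        except RuntimeError as e:
            print(e)  # e.g. "Expected all tensors to be on the same device, ..."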

@ArthurZucker (Collaborator) commented on 87a6cf4


This will be reverted to being all on CPU in #28400
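For reference, a hedged sketch of what "all on CPU" could look like for these lines, based on the pre-#28369 code; this is an assumption about the shape of the revert, not the actual diff of #28400:

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])

    # Do the index math on CPU regardless of where the mask lives.
    tmp = torch.arange(attention_mask.shape[1], 0, -1)                    # CPU
    indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)   # CPU
    range_tensor = torch.arange(attention_mask.shape[1]).unsqueeze(0)     # CPU
    range_tensor = range_tensor.repeat(indices.size(0), 1)
    range_tensor[range_tensor >= indices] = 0  # all operands on the same device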
