Skip to content

Commit

Permalink
Fix indexing of lasttoken pooling for longest sequence (#2111)
Browse files Browse the repository at this point in the history
* Update Pooling.py

fixing last token indexing for sequences that span the entire length

* Update Pooling.py

change to seq_len

* Typo: lenth -> length

---------

Co-authored-by: Tom Aarsen <[email protected]>
  • Loading branch information
ssharpe42 and tomaarsen authored Dec 13, 2023
1 parent 0417a0b commit 6b524f8
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sentence_transformers/models/Pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def forward(self, features: Dict[str, Tensor]):
# attention_mask shape: (bs, seq_len)
# Get shape [bs] indices of the last token (i.e. the last token for each batch item)
# argmin gives us the index of the first 0 in the attention mask; We get the last 1 index by subtracting 1
gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1 # Shape [bs]
# Any sequence where min == 1, we use the entire sequence length since argmin = 0
values, indices = torch.min(attention_mask, 1, keepdim = False)
gather_indices = torch.where(values==0, indices, seq_len) - 1 # Shape [bs]

# There are empty sequences, where the index would become -1 which will crash
gather_indices = torch.clamp(gather_indices, min=0)
Expand Down

0 comments on commit 6b524f8

Please sign in to comment.