Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix indexing of lasttoken pooling for longest sequence #2111

Merged
merged 4 commits into from
Dec 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sentence_transformers/models/Pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def forward(self, features: Dict[str, Tensor]):
# attention_mask shape: (bs, seq_len)
# Get shape [bs] indices of the last token (i.e. the last token for each batch item)
# argmin gives us the index of the first 0 in the attention mask; We get the last 1 index by subtracting 1
gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1 # Shape [bs]
# Any sequence where min == 1, we use the entire sequence length since argmin = 0
values, indices = torch.min(attention_mask, 1, keepdim = False)
gather_indices = torch.where(values==0, indices, seq_len) - 1 # Shape [bs]

# There are empty sequences, where the index would become -1 which will crash
gather_indices = torch.clamp(gather_indices, min=0)
Expand Down