From 6b524f8c233ff013c2cfc9b49145da25a2d7b79d Mon Sep 17 00:00:00 2001 From: Sam Sharpe Date: Wed, 13 Dec 2023 15:18:28 -0500 Subject: [PATCH] Fix indexing of lasttoken pooling for longest sequence (#2111) * Update Pooling.py fixing last token indexing for sequences that span the entire length * Update Pooling.py change to seq_len * Typo: lenth -> length --------- Co-authored-by: Tom Aarsen --- sentence_transformers/models/Pooling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sentence_transformers/models/Pooling.py b/sentence_transformers/models/Pooling.py index aab02b8ea..64083652d 100644 --- a/sentence_transformers/models/Pooling.py +++ b/sentence_transformers/models/Pooling.py @@ -139,7 +139,9 @@ def forward(self, features: Dict[str, Tensor]): # attention_mask shape: (bs, seq_len) # Get shape [bs] indices of the last token (i.e. the last token for each batch item) # argmin gives us the index of the first 0 in the attention mask; We get the last 1 index by subtracting 1 - gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1 # Shape [bs] + # Any sequence where min == 1, we use the entire sequence length since argmin = 0 + values, indices = torch.min(attention_mask, 1, keepdim = False) + gather_indices = torch.where(values==0, indices, seq_len) - 1 # Shape [bs] # There are empty sequences, where the index would become -1 which will crash gather_indices = torch.clamp(gather_indices, min=0)