From 256154b35d10b284660568fd96e55c331651f5cd Mon Sep 17 00:00:00 2001
From: Sam Sharpe
Date: Wed, 24 May 2023 08:34:30 -0400
Subject: [PATCH 1/3] Update Pooling.py

fixing last token indexing for sequences that span the entire length
---
 sentence_transformers/models/Pooling.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/sentence_transformers/models/Pooling.py b/sentence_transformers/models/Pooling.py
index 7ab2dd81d..ad560e5b5 100644
--- a/sentence_transformers/models/Pooling.py
+++ b/sentence_transformers/models/Pooling.py
@@ -139,7 +139,9 @@ def forward(self, features: Dict[str, Tensor]):
             # attention_mask shape: (bs, seq_len)
             # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
             # argmin gives us the index of the first 0 in the attention mask; We get the last 1 index by subtracting 1
-            gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1  # Shape [bs]
+            # Any sequence where min == 1, we use the entire sequence lenth since argmin = 0
+            values, indices = torch.min(attention_mask, 1, keepdim = False)
+            gather_indices = torch.where(values==0, indices, attention_mask.shape[1]) - 1  # Shape [bs]
 
             # There are empty sequences, where the index would become -1 which will crash
             gather_indices = torch.clamp(gather_indices, min=0)

From 581fe58939377595df0cfb5698f207bc1e777b19 Mon Sep 17 00:00:00 2001
From: Sam Sharpe
Date: Thu, 25 May 2023 08:58:52 -0400
Subject: [PATCH 2/3] Update Pooling.py

change to seq_len
---
 sentence_transformers/models/Pooling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sentence_transformers/models/Pooling.py b/sentence_transformers/models/Pooling.py
index ad560e5b5..0e0f93697 100644
--- a/sentence_transformers/models/Pooling.py
+++ b/sentence_transformers/models/Pooling.py
@@ -141,7 +141,7 @@ def forward(self, features: Dict[str, Tensor]):
             # argmin gives us the index of the first 0 in the attention mask; We get the last 1 index by subtracting 1
             # Any sequence where min == 1, we use the entire sequence lenth since argmin = 0
             values, indices = torch.min(attention_mask, 1, keepdim = False)
-            gather_indices = torch.where(values==0, indices, attention_mask.shape[1]) - 1  # Shape [bs]
+            gather_indices = torch.where(values==0, indices, seq_len) - 1  # Shape [bs]
 
             # There are empty sequences, where the index would become -1 which will crash
             gather_indices = torch.clamp(gather_indices, min=0)

From b875550548b116c12c68484e65b6146fc3ad0e7c Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Wed, 13 Dec 2023 19:44:17 +0100
Subject: [PATCH 3/3] Typo: lenth -> length
---
 sentence_transformers/models/Pooling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sentence_transformers/models/Pooling.py b/sentence_transformers/models/Pooling.py
index ca60d0608..64083652d 100644
--- a/sentence_transformers/models/Pooling.py
+++ b/sentence_transformers/models/Pooling.py
@@ -139,7 +139,7 @@ def forward(self, features: Dict[str, Tensor]):
             # attention_mask shape: (bs, seq_len)
             # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
             # argmin gives us the index of the first 0 in the attention mask; We get the last 1 index by subtracting 1
-            # Any sequence where min == 1, we use the entire sequence lenth since argmin = 0
+            # Any sequence where min == 1, we use the entire sequence length since argmin = 0
             values, indices = torch.min(attention_mask, 1, keepdim = False)
             gather_indices = torch.where(values==0, indices, seq_len) - 1  # Shape [bs]
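
A minimal standalone sketch of the indexing behavior these patches change, run on illustrative attention-mask values (the toy tensor, shapes, and prints are not part of the commits above):

# Sketch: why argmin-based last-token indexing fails for full-length sequences,
# and how the torch.min / torch.where version from the patches handles it.
import torch

attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],  # padded sequence of length 3
    [1, 1, 1, 1, 1],  # sequence spanning the entire seq_len
    [0, 0, 0, 0, 0],  # empty sequence
])
bs, seq_len = attention_mask.shape

# Old behavior: argmin returns the index of the first 0; for an all-ones row
# there is no 0, argmin returns 0, and 0 - 1 = -1 points at the wrong token.
old_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1
print(old_indices)  # tensor([ 2, -1, -1])

# Patched behavior: when the row minimum is 1 (no padding), fall back to seq_len
# so that subtracting 1 yields the true last-token index.
values, indices = torch.min(attention_mask, 1, keepdim=False)
gather_indices = torch.where(values == 0, indices, seq_len) - 1
gather_indices = torch.clamp(gather_indices, min=0)  # guard empty sequences
print(gather_indices)  # tensor([2, 4, 0])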