diff --git a/sentence_transformers/losses/MultipleNegativesRankingLoss.py b/sentence_transformers/losses/MultipleNegativesRankingLoss.py index 1aea7acfa..c9a57cef0 100644 --- a/sentence_transformers/losses/MultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/MultipleNegativesRankingLoss.py @@ -99,13 +99,21 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f self.cross_entropy_loss = nn.CrossEntropyLoss() def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor: - reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features] - embeddings_a = reps[0] - embeddings_b = torch.cat(reps[1:]) - - scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale - # Example a[i] should match with b[i] + # Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives) + embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features] + anchors = embeddings[0] # (batch_size, embedding_dim) + candidates = torch.cat(embeddings[1:]) # (batch_size * (1 + num_negatives), embedding_dim) + + # For every anchor, we compute the similarity to all other candidates (positives and negatives), + # also from other anchors. This gives us a lot of in-batch negatives. + scores = ( + self.similarity_fct(anchors, candidates) * self.scale + ) # (batch_size, batch_size * (1 + num_negatives)) + + # anchor[i] should be most similar to candidates[i], as that is the paired positive, + # so the label for anchor[i] is i range_labels = torch.arange(0, scores.size(0), device=scores.device) + return self.cross_entropy_loss(scores, range_labels) def get_config_dict(self) -> dict[str, Any]: