Make MultipleNegativesRankingLoss easier to understand
Because this is one of the most common loss functions, I think it's useful to comment-spam it a bit.
tomaarsen committed Nov 28, 2024
commit d5359a4 (parent: a542b0a)
Showing 1 changed file with 14 additions and 6 deletions.
sentence_transformers/losses/MultipleNegativesRankingLoss.py (14 additions & 6 deletions)
@@ -99,13 +99,21 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f
         self.cross_entropy_loss = nn.CrossEntropyLoss()
 
     def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
-        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
-        embeddings_a = reps[0]
-        embeddings_b = torch.cat(reps[1:])
-
-        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
-        # Example a[i] should match with b[i]
+        # Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives)
+        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
+        anchors = embeddings[0]  # (batch_size, embedding_dim)
+        candidates = torch.cat(embeddings[1:])  # (batch_size * (1 + num_negatives), embedding_dim)
+
+        # For every anchor, we compute the similarity to all other candidates (positives and negatives),
+        # also from other anchors. This gives us a lot of in-batch negatives.
+        scores = (
+            self.similarity_fct(anchors, candidates) * self.scale
+        )  # (batch_size, batch_size * (1 + num_negatives))
+
+        # anchor[i] should be most similar to candidates[i], as that is the paired positive,
+        # so the label for anchor[i] is i
         range_labels = torch.arange(0, scores.size(0), device=scores.device)
 
         return self.cross_entropy_loss(scores, range_labels)
 
     def get_config_dict(self) -> dict[str, Any]:
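To make the effect of the change concrete, here is a minimal, self-contained sketch of what the rewritten forward() computes. This is not the library code: the toy tensor sizes, the random embeddings, and the use of plain cosine similarity with the default scale of 20.0 are assumptions for illustration only.

    import torch
    import torch.nn.functional as F

    batch_size, embedding_dim, scale = 4, 8, 20.0

    # Stand-ins for the encoder outputs: one anchor and one paired positive per row.
    anchors = torch.randn(batch_size, embedding_dim)     # (batch_size, embedding_dim)
    candidates = torch.randn(batch_size, embedding_dim)  # (batch_size * (1 + num_negatives), embedding_dim); here num_negatives = 0

    # Cosine similarity of every anchor against every candidate; the off-diagonal
    # entries are the in-batch negatives mentioned in the new comments.
    scores = F.normalize(anchors, dim=-1) @ F.normalize(candidates, dim=-1).T * scale

    # anchor[i] is paired with candidates[i], so the target class for row i is i.
    range_labels = torch.arange(0, scores.size(0))
    loss = F.cross_entropy(scores, range_labels)

With hard negatives, each negative column in the batch appends another batch_size rows to candidates, which show up as extra columns in scores; the labels stay 0 .. batch_size - 1.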

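As a usage note (not part of this commit): the loss expects batches of (anchor, positive[, negative, ...]) texts, and larger batches mean more in-batch negatives. Below is a sketch against the long-standing fit() training API; the model name and example texts are made up for illustration.

    from torch.utils.data import DataLoader

    from sentence_transformers import InputExample, SentenceTransformer, losses

    model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model; any bi-encoder works
    train_examples = [
        InputExample(texts=["How do I reset my password?", "Steps to reset a forgotten password"]),
        InputExample(texts=["Best pizza in town", "Top-rated pizzerias nearby"]),
    ]
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
    train_loss = losses.MultipleNegativesRankingLoss(model=model)

    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=0)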