Support Cached losses in combination with Matryoshka loss (#3068)
* support cached losses in combination with matryoshka loss

* Add some docstrings to the decorators in the MatryoshkaLoss

---------

Co-authored-by: Marcel Brunnbauer <[email protected]>
Co-authored-by: Tom Aarsen <[email protected]>
3 people authored Nov 20, 2024
1 parent 7ffbb4b commit 348190d
Showing 4 changed files with 91 additions and 126 deletions.
61 changes: 4 additions & 57 deletions sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -229,64 +229,11 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
"""Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
if len(reps) != len(reps_guided):
raise ValueError("reps and reps_guided must have the same length")

# Concatenate embeddings along the batch dimension
concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
batch_size = concatenated_reps[0].shape[0]

losses: list[torch.Tensor] = []
for b in tqdm.trange(
0,
batch_size,
self.mini_batch_size,
desc="Preparing caches",
disable=not self.show_progress_bar,
):
e = b + self.mini_batch_size

# Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])

# Define the anchor threshold for each similarity matrix
guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)

# Compute similarity scores for the current mini-batch
ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1]) # anchor-positive similarity
aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0]) # anchor-anchor similarity
pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1]) # positive-positive similarity

# Apply thresholds based on guided model similarities
ap_sim[guided_ap_sim > guided_sim] = -torch.inf
aa_sim[guided_aa_sim > guided_sim] = -torch.inf
pp_sim[guided_pp_sim > guided_sim] = -torch.inf

# Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

# If there are negatives (len(reps) > 2), process them
if len(concatenated_reps) > 2:
for i in range(2, len(concatenated_reps)): # Start from 2 since first 2 are anchor-positive
guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
neg_sim[guided_neg_sim > guided_sim] = -torch.inf
scores = torch.cat([scores, neg_sim], dim=1)

# Normalize the scores and calculate the cross-entropy loss
scores = scores / self.temperature
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
loss_mbatch.backward()
losses.append(loss_mbatch.detach())

loss = sum(losses).requires_grad_()
loss = self.calculate_loss(reps, reps_guided)
loss.backward()
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps] # Cache the gradients
self.cache = [[r.grad for r in rs] for rs in reps]

return loss
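
All three cached-loss files in this commit follow the same refactor: the per-minibatch loop that computed and backpropagated the loss inline is factored out into a calculate_loss method, and calculate_loss_and_cache_gradients now only calls that method, backpropagates, and stores the gradients with respect to the cached embeddings. A minimal sketch of that caching pattern, using toy tensors and a placeholder loss rather than the real GIST loss:

import torch
import torch.nn.functional as F

# Toy stand-ins for the cached per-minibatch embeddings: leaf tensors that require grad,
# grouped as [anchor minibatches, positive minibatches].
reps = [[torch.randn(4, 8, requires_grad=True) for _ in range(2)] for _ in range(2)]

def calculate_loss(reps: list[list[torch.Tensor]]) -> torch.Tensor:
    # Placeholder InfoNCE-style loss over the whole batch; the real classes implement
    # this in their own calculate_loss methods.
    anchors = torch.cat(reps[0])
    candidates = torch.cat(reps[1])
    scores = anchors @ candidates.T
    labels = torch.arange(scores.size(0))
    return F.cross_entropy(scores, labels)

# The pattern now used by calculate_loss_and_cache_gradients:
loss = calculate_loss(reps)
loss.backward()                        # fills .grad on every cached embedding tensor
loss = loss.detach().requires_grad_()  # detached surrogate that re-enters autograd later
cache = [[r.grad for r in rs] for rs in reps]  # gradients replayed during the second forward pass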

25 changes: 3 additions & 22 deletions sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -213,28 +213,9 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
"""Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)

batch_size = len(embeddings_a)
labels = torch.tensor(
range(batch_size), dtype=torch.long, device=embeddings_a.device
) # (bsz, (1 + nneg) * bsz) Example a[i] should match with b[i]
losses: list[torch.Tensor] = []
for b in tqdm.trange(
0,
batch_size,
self.mini_batch_size,
desc="Preparing caches",
disable=not self.show_progress_bar,
):
e = b + self.mini_batch_size
scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
loss_mbatch.backward()
losses.append(loss_mbatch.detach())

loss = sum(losses).requires_grad_()
loss = self.calculate_loss(reps)
loss.backward()
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

32 changes: 4 additions & 28 deletions sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
@@ -182,35 +182,11 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
"""Calculate the symmetric loss and cache gradients."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)

batch_size = len(embeddings_a)
labels = torch.arange(batch_size, device=embeddings_a.device)

losses: list[torch.Tensor] = []
for b in tqdm.trange(
0,
batch_size,
self.mini_batch_size,
desc="Preparing caches",
disable=not self.show_progress_bar,
):
e = min(b + self.mini_batch_size, batch_size)
scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
forward_loss: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e])

positive_scores = scores[:, b:e]
backward_loss: torch.Tensor = self.cross_entropy_loss(positive_scores.t(), labels[: len(positive_scores)])

loss_mbatch = (forward_loss + backward_loss) / 2
loss_mbatch.backward()
losses.append(loss_mbatch.detach())

loss = sum(losses) / len(losses)
loss = loss.requires_grad_()
loss = self.calculate_loss(reps)
loss.backward()
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps]
self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

return loss
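
The symmetric variant's relocated loss differs only in the term it computes per anchor minibatch: alongside the usual anchor-to-candidate cross-entropy it adds the reverse, positive-to-anchor direction over the in-minibatch block, and averages the two. A toy sketch of that term for the first minibatch; the shapes, scale, and dot-product similarity are illustrative:

import torch
import torch.nn.functional as F

anchors = F.normalize(torch.randn(4, 16), dim=-1)         # one mini-batch of anchors
candidates = F.normalize(torch.randn(12, 16), dim=-1)     # all positives (and negatives) in the batch
scores = anchors @ candidates.T * 20.0                     # similarity_fct(...) * self.scale
labels = torch.arange(4)
forward_loss = F.cross_entropy(scores, labels)             # anchor -> positive
backward_loss = F.cross_entropy(scores[:, :4].T, labels)   # positive -> anchor, in-minibatch block only
loss_mbatch = (forward_loss + backward_loss) / 2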

99 changes: 80 additions & 19 deletions sentence_transformers/losses/MatryoshkaLoss.py
@@ -1,19 +1,40 @@
from __future__ import annotations

import random
import warnings
from collections.abc import Iterable
from typing import Any

import torch.nn.functional as F
from torch import Tensor, nn

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses.CachedGISTEmbedLoss import CachedGISTEmbedLoss
from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import CachedMultipleNegativesRankingLoss
from sentence_transformers.losses import (
CachedGISTEmbedLoss,
CachedMultipleNegativesRankingLoss,
CachedMultipleNegativesSymmetricRankingLoss,
)


def shrink(tensor: Tensor, dim: int) -> Tensor:
tensor_dim = tensor.shape[-1]
if dim > tensor_dim:
raise ValueError(
f"Dimension {dim} in matryoshka_dims cannot be greater than the model's embedding dimension: {tensor_dim}"
)
tensor = tensor[..., :dim]
tensor = F.normalize(tensor, p=2, dim=-1)
return tensor


class ForwardDecorator:
"""
This decorator is used to cache the output of the Sentence Transformer's forward pass,
so that it can be shrunk and reused for multiple loss calculations. This prevents the
model from recalculating the embeddings for each desired Matryoshka dimensionality.
This decorator is applied to `SentenceTransformer.forward`.
"""

def __init__(self, fn) -> None:
self.fn = fn

@@ -26,16 +47,6 @@ def set_dim(self, dim) -> None:
self.dim = dim
self.idx = 0

def shrink(self, tensor: Tensor) -> Tensor:
tensor_dim = tensor.shape[-1]
if self.dim > tensor_dim:
raise ValueError(
f"Dimension {self.dim} in matryoshka_dims cannot be greater than the model's embedding dimension: {tensor_dim}"
)
tensor = tensor[..., : self.dim]
tensor = F.normalize(tensor, p=2, dim=-1)
return tensor

def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
# Growing cache:
if self.cache_dim is None or self.cache_dim == self.dim:
@@ -46,12 +57,46 @@ def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
else:
output = self.cache[self.idx]
if "token_embeddings" in output:
output["token_embeddings"] = self.shrink(output["token_embeddings"])
output["sentence_embedding"] = self.shrink(output["sentence_embedding"])
output["token_embeddings"] = shrink(output["token_embeddings"], self.dim)
output["sentence_embedding"] = shrink(output["sentence_embedding"], self.dim)
self.idx += 1
return output
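
The shrink helper that used to be a ForwardDecorator method is now a module-level function, so that the CachedLossDecorator below can reuse it: it truncates an embedding to its first dim components and re-normalizes. A quick illustrative snippet, assuming the shrink function from the diff above is in scope:

import torch
import torch.nn.functional as F

emb = F.normalize(torch.randn(2, 768), p=2, dim=-1)   # unit-norm, full-size embeddings
small = shrink(emb, 256)                               # keep the first 256 dims, then re-normalize
print(small.shape)                                     # torch.Size([2, 256])
print(small.norm(p=2, dim=-1))                         # ~1.0 for every row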


class CachedLossDecorator:
"""
This decorator is used with the Cached... losses to compute the underlying loss function
for each Matryoshka dimensionality. This is done by shrinking the pre-computed embeddings
to the desired dimensionality and then passing them to the underlying loss function once
for each desired dimensionality.
This decorator is applied to the `calculate_loss` method of the Cached... losses.
"""

def __init__(
self, fn, matryoshka_dims: list[int], matryoshka_weights: list[float | int], n_dims_per_step: int = -1
) -> None:
self.fn = fn
self.matryoshka_dims = matryoshka_dims
self.matryoshka_weights = matryoshka_weights
self.n_dims_per_step = n_dims_per_step

def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
dim_indices = range(len(self.matryoshka_dims))
if self.n_dims_per_step > 0 and self.n_dims_per_step < len(dim_indices):
dim_indices = random.sample(dim_indices, self.n_dims_per_step)

loss = 0.0
for idx in dim_indices:
dim = self.matryoshka_dims[idx]
weight = self.matryoshka_weights[idx]

truncated = [[shrink(r, dim) for r in rs] for rs in reps]
loss += weight * self.fn(truncated, *args)

return loss
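
To make the decorator's behavior concrete, here is a small sketch with a hypothetical toy_calculate_loss standing in for a cached loss's calculate_loss method; it assumes CachedLossDecorator and shrink from this file are in scope:

import torch

def toy_calculate_loss(reps: list[list[torch.Tensor]]) -> torch.Tensor:
    # Stand-in for e.g. CachedMultipleNegativesRankingLoss.calculate_loss
    return torch.stack([r.pow(2).sum() for rs in reps for r in rs]).mean()

decorated = CachedLossDecorator(toy_calculate_loss, matryoshka_dims=[64, 32], matryoshka_weights=[1, 1])
reps = [[torch.randn(4, 64)], [torch.randn(4, 64)]]
loss = decorated(reps)  # toy loss evaluated on the 64- and 32-dim truncations, weighted and summed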


class MatryoshkaLoss(nn.Module):
def __init__(
self,
@@ -123,10 +168,6 @@ def __init__(
super().__init__()
self.model = model
self.loss = loss
if isinstance(loss, CachedMultipleNegativesRankingLoss):
warnings.warn("MatryoshkaLoss is not compatible with CachedMultipleNegativesRankingLoss.", stacklevel=2)
if isinstance(loss, CachedGISTEmbedLoss):
warnings.warn("MatryoshkaLoss is not compatible with CachedGISTEmbedLoss.", stacklevel=2)

if matryoshka_weights is None:
matryoshka_weights = [1] * len(matryoshka_dims)
Expand All @@ -135,7 +176,27 @@ def __init__(
self.matryoshka_dims, self.matryoshka_weights = zip(*sorted(dims_weights, key=lambda x: x[0], reverse=True))
self.n_dims_per_step = n_dims_per_step

# The Cached... losses require a special treatment as their backward pass is incompatible with the
# ForwardDecorator approach. Instead, we use a CachedLossDecorator to compute the loss for each
# Matryoshka dimensionality given pre-computed embeddings passed to `calculate_loss`.
self.cached_losses = (
CachedMultipleNegativesRankingLoss,
CachedGISTEmbedLoss,
CachedMultipleNegativesSymmetricRankingLoss,
)
if isinstance(loss, self.cached_losses):
loss.calculate_loss = CachedLossDecorator(
loss.calculate_loss, self.matryoshka_dims, self.matryoshka_weights
)

def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
# For the Cached... losses, the CachedLossDecorator has been applied to the `calculate_loss` method,
# so we can directly call the loss function.
if isinstance(self.loss, self.cached_losses):
return self.loss(sentence_features, labels)

# Otherwise, we apply the ForwardDecorator to the model's forward pass, which will cache the output
# embeddings for each Matryoshka dimensionality, allowing it to be reused for the smaller dimensions.
original_forward = self.model.forward
try:
decorated_forward = ForwardDecorator(original_forward)
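
With these changes, a cached loss can be wrapped directly, where previously MatryoshkaLoss only emitted an incompatibility warning (see the removed warnings.warn calls above). A minimal usage sketch; the model name and mini_batch_size are illustrative:

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss, MatryoshkaLoss

model = SentenceTransformer("microsoft/mpnet-base")
base_loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)
loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[768, 512, 256, 128, 64])
# `loss` can now be passed to the trainer like any other loss.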
