[WIP] support cached losses in combination with matryoshka loss
Marcel Brunnbauer committed Nov 18, 2024
1 parent e156f38 commit 132b6e7
Showing 4 changed files with 55 additions and 114 deletions.
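For context, the end state this work targets is that MatryoshkaLoss can simply wrap one of the cached losses. A minimal usage sketch (model name and hyperparameters are placeholders, not part of this commit):

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.losses import CachedMultipleNegativesRankingLoss, MatryoshkaLoss

    model = SentenceTransformer("microsoft/mpnet-base")  # placeholder base model
    inner_loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)
    # Previously, wrapping a cached loss only emitted a compatibility warning; with this change,
    # the Matryoshka truncation is applied to the cached embeddings instead.
    loss = MatryoshkaLoss(model, inner_loss, matryoshka_dims=[768, 512, 256, 128, 64])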
61 changes: 4 additions & 57 deletions sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -229,64 +229,11 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
"""Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
if len(reps) != len(reps_guided):
raise ValueError("reps and reps_guided must have the same length")

# Concatenate embeddings along the batch dimension
concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
batch_size = concatenated_reps[0].shape[0]

losses: list[torch.Tensor] = []
for b in tqdm.trange(
0,
batch_size,
self.mini_batch_size,
desc="Preparing caches",
disable=not self.show_progress_bar,
):
e = b + self.mini_batch_size

# Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])

# Define the anchor threshold for each similarity matrix
guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)

# Compute similarity scores for the current mini-batch
ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1]) # anchor-positive similarity
aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0]) # anchor-anchor similarity
pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1]) # positive-positive similarity

# Apply thresholds based on guided model similarities
ap_sim[guided_ap_sim > guided_sim] = -torch.inf
aa_sim[guided_aa_sim > guided_sim] = -torch.inf
pp_sim[guided_pp_sim > guided_sim] = -torch.inf

# Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

# If there are negatives (len(reps) > 2), process them
if len(concatenated_reps) > 2:
for i in range(2, len(concatenated_reps)): # Start from 2 since first 2 are anchor-positive
guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
neg_sim[guided_neg_sim > guided_sim] = -torch.inf
scores = torch.cat([scores, neg_sim], dim=1)

# Normalize the scores and calculate the cross-entropy loss
scores = scores / self.temperature
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
loss_mbatch.backward()
losses.append(loss_mbatch.detach())

loss = sum(losses).requires_grad_()
loss = self.calculate_loss(reps, reps_guided)
loss.backward()
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps] # Cache the gradients
self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

return loss

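The three added lines that replace the mini-batch loop rely on the surrogate-loss trick the cached losses already use: gradients are taken with respect to the cached embeddings here, and the detached loss handed back to the trainer later fires a backward hook that replays those gradients through the encoder mini-batch by mini-batch. A toy, self-contained illustration of the caching step (assumed pattern, heavily simplified):

    import torch

    emb = torch.randn(8, 16, requires_grad=True)  # stands in for one cached mini-batch of embeddings
    loss = emb.pow(2).sum()                       # stands in for self.calculate_loss(reps, reps_guided)
    loss.backward()                               # fills emb.grad without running the encoder again
    cache = [emb.grad]                            # what calculate_loss_and_cache_gradients stores in self.cache
    surrogate = loss.detach().requires_grad_()    # the value actually returned to the trainer

The backward hook registered on that returned loss then re-encodes each mini-batch with gradients enabled and backpropagates the cached gradients into the model weights.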
25 changes: 3 additions & 22 deletions sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -213,28 +213,9 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
"""Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)

batch_size = len(embeddings_a)
labels = torch.tensor(
range(batch_size), dtype=torch.long, device=embeddings_a.device
) # (bsz, (1 + nneg) * bsz) Example a[i] should match with b[i]
losses: list[torch.Tensor] = []
for b in tqdm.trange(
0,
batch_size,
self.mini_batch_size,
desc="Preparing caches",
disable=not self.show_progress_bar,
):
e = b + self.mini_batch_size
scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
loss_mbatch.backward()
losses.append(loss_mbatch.detach())

loss = sum(losses).requires_grad_()
loss = self.calculate_loss(reps)
loss.backward()
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

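This simplification assumes each cached loss already exposes a plain calculate_loss (used on the no-gradient evaluation path) that can be reused here. Judging from the deleted loop above, for CachedMultipleNegativesRankingLoss it presumably looks roughly like the following sketch (not the verbatim file contents):

    def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
        # Same mini-batch cross-entropy computation as before, just without the per-mini-batch backward()
        embeddings_a = torch.cat(reps[0])  # (bsz, hdim)
        embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]])  # ((1 + nneg) * bsz, hdim)

        batch_size = len(embeddings_a)
        labels = torch.arange(batch_size, device=embeddings_a.device)
        losses: list[torch.Tensor] = []
        for b in range(0, batch_size, self.mini_batch_size):
            e = b + self.mini_batch_size
            scores = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
            losses.append(self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size)

        return sum(losses)

This is also the function that the CachedLossDecorator added below wraps, so the truncation to each Matryoshka dimension happens on the cached embeddings rather than inside the model's forward pass.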
32 changes: 4 additions & 28 deletions sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
@@ -182,35 +182,11 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
"""Calculate the symmetric loss and cache gradients."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)

batch_size = len(embeddings_a)
labels = torch.arange(batch_size, device=embeddings_a.device)

losses: list[torch.Tensor] = []
for b in tqdm.trange(
0,
batch_size,
self.mini_batch_size,
desc="Preparing caches",
disable=not self.show_progress_bar,
):
e = min(b + self.mini_batch_size, batch_size)
scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
forward_loss: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e])

positive_scores = scores[:, b:e]
backward_loss: torch.Tensor = self.cross_entropy_loss(positive_scores.t(), labels[: len(positive_scores)])

loss_mbatch = (forward_loss + backward_loss) / 2
loss_mbatch.backward()
losses.append(loss_mbatch.detach())

loss = sum(losses) / len(losses)
loss = loss.requires_grad_()
loss = self.calculate_loss(reps)
loss.backward()
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps]
self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

return loss

51 changes: 44 additions & 7 deletions sentence_transformers/losses/MatryoshkaLoss.py
@@ -9,9 +9,7 @@
from torch import Tensor, nn

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses.CachedGISTEmbedLoss import CachedGISTEmbedLoss
from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import CachedMultipleNegativesRankingLoss

from sentence_transformers.losses.CachedGISTEmbedLoss import CachedGISTEmbedLoss
from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import CachedMultipleNegativesRankingLoss
from sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss import CachedMultipleNegativesSymmetricRankingLoss

class ForwardDecorator:
def __init__(self, fn) -> None:
@@ -52,6 +50,43 @@ def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
return output



class CachedLossDecorator:
"""Wraps a cached loss's ``calculate_loss`` so that the loss is computed once per Matryoshka
dimension on truncated, re-normalized embeddings and then combined as a weighted sum."""

def __init__(self, fn, matryoshka_dims: list[int], matryoshka_weights: list[float | int], n_dims_per_step: int = -1) -> None:
self.fn = fn
self.matryoshka_dims = matryoshka_dims
self.matryoshka_weights = matryoshka_weights
self.n_dims_per_step = n_dims_per_step

def shrink(self, tensor: Tensor, dim: int) -> Tensor:
tensor_dim = tensor.shape[-1]
if dim > tensor_dim:
raise ValueError(
f"Dimension {dim} in matryoshka_dims cannot be greater than the model's embedding dimension: {tensor_dim}"
)
tensor = tensor[..., : dim]
tensor = F.normalize(tensor, p=2, dim=-1)
return tensor

def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
dim_indices = range(len(self.matryoshka_dims))
if self.n_dims_per_step > 0 and self.n_dims_per_step < len(dim_indices):
dim_indices = random.sample(dim_indices, self.n_dims_per_step)

loss = 0.0
for idx in dim_indices:
dim = self.matryoshka_dims[idx]
weight = self.matryoshka_weights[idx]

truncated = [[self.shrink(r, dim) for r in rs] for rs in reps]
loss += weight * self.fn(truncated, *args)

return loss


class MatryoshkaLoss(nn.Module):
def __init__(
self,
@@ -123,10 +158,6 @@ def __init__(
super().__init__()
self.model = model
self.loss = loss
if isinstance(loss, CachedMultipleNegativesRankingLoss):
warnings.warn("MatryoshkaLoss is not compatible with CachedMultipleNegativesRankingLoss.", stacklevel=2)
if isinstance(loss, CachedGISTEmbedLoss):
warnings.warn("MatryoshkaLoss is not compatible with CachedGISTEmbedLoss.", stacklevel=2)

if matryoshka_weights is None:
matryoshka_weights = [1] * len(matryoshka_dims)
@@ -135,7 +166,13 @@ def __init__(
self.matryoshka_dims, self.matryoshka_weights = zip(*sorted(dims_weights, key=lambda x: x[0], reverse=True))
self.n_dims_per_step = n_dims_per_step

if isinstance(loss, (CachedMultipleNegativesRankingLoss, CachedMultipleNegativesSymmetricRankingLoss, CachedGISTEmbedLoss)):
loss.calculate_loss = CachedLossDecorator(loss.calculate_loss, self.matryoshka_dims, self.matryoshka_weights, self.n_dims_per_step)

def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
if isinstance(self.loss, (CachedMultipleNegativesRankingLoss, CachedMultipleNegativesSymmetricRankingLoss, CachedGISTEmbedLoss)):
return self.loss(sentence_features, labels)

original_forward = self.model.forward
try:
decorated_forward = ForwardDecorator(original_forward)
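To see what the decorator does in isolation, a quick toy check along these lines can help (the dummy loss function and tensor shapes are invented for illustration; any callable that accepts the nested reps list will do):

    import torch
    from sentence_transformers.losses.MatryoshkaLoss import CachedLossDecorator  # available once this change is applied

    def dummy_calculate_loss(reps):
        # Stand-in for a cached loss's calculate_loss: sum of squared norms of all cached embeddings
        return sum(r.pow(2).sum() for rs in reps for r in rs)

    reps = [[torch.randn(4, 768)], [torch.randn(4, 768)]]  # [anchors, positives], one mini-batch each
    decorated = CachedLossDecorator(dummy_calculate_loss, matryoshka_dims=[768, 256], matryoshka_weights=[1, 1])
    print(decorated(reps))  # weighted sum of the dummy loss over the 768- and 256-dim truncations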
