From 132b6e786d3aeb2a43eb90d895575c9769e98c26 Mon Sep 17 00:00:00 2001
From: Marcel Brunnbauer
Date: Mon, 18 Nov 2024 15:02:53 +0100
Subject: [PATCH] [WIP] support cached losses in combination with matryoshka
 loss

---
 .../losses/CachedGISTEmbedLoss.py             | 61 ++-----------------
 .../CachedMultipleNegativesRankingLoss.py     | 25 +-------
 ...edMultipleNegativesSymmetricRankingLoss.py | 32 ++--------
 .../losses/MatryoshkaLoss.py                  | 51 +++++++++++++---
 4 files changed, 55 insertions(+), 114 deletions(-)

diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py
index e8e131582..89c7a1a24 100644
--- a/sentence_transformers/losses/CachedGISTEmbedLoss.py
+++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -229,64 +229,11 @@ def embed_minibatch_iter(

     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
         """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
-        if len(reps) != len(reps_guided):
-            raise ValueError("reps and reps_guided must have the same length")
-
-        # Concatenate embeddings along the batch dimension
-        concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
-        concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]
-
-        labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
-        batch_size = concatenated_reps[0].shape[0]
-
-        losses: list[torch.Tensor] = []
-        for b in tqdm.trange(
-            0,
-            batch_size,
-            self.mini_batch_size,
-            desc="Preparing caches",
-            disable=not self.show_progress_bar,
-        ):
-            e = b + self.mini_batch_size
-
-            # Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
-            guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
-            guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
-            guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])
-
-            # Define the anchor threshold for each similarity matrix
-            guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)
-
-            # Compute similarity scores for the current mini-batch
-            ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1])  # anchor-positive similarity
-            aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
-            pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity
-
-            # Apply thresholds based on guided model similarities
-            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
-            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
-            pp_sim[guided_pp_sim > guided_sim] = -torch.inf
-
-            # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
-            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)
-
-            # If there are negatives (len(reps) > 2), process them
-            if len(concatenated_reps) > 2:
-                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
-                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
-                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
-                    neg_sim[guided_neg_sim > guided_sim] = -torch.inf
-                    scores = torch.cat([scores, neg_sim], dim=1)
-
-            # Normalize the scores and calculate the cross-entropy loss
-            scores = scores / self.temperature
-            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
-            loss_mbatch.backward()
-            losses.append(loss_mbatch.detach())
-
-        loss = sum(losses).requires_grad_()
+        loss = self.calculate_loss(reps, reps_guided)
+        loss.backward()
+        loss = loss.detach().requires_grad_()

-        self.cache = [[r.grad for r in rs] for rs in reps]  # Cache the gradients
+        self.cache = [[r.grad for r in rs] for rs in reps]  # e.g. 3 * bsz/mbsz * (mbsz, hdim)

         return loss

diff --git a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
index 544c73891..61adfb70b 100644
--- a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
+++ b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -213,28 +213,9 @@ def embed_minibatch_iter(

     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
         """Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
-        embeddings_a = torch.cat(reps[0])  # (bsz, hdim)
-        embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]])  # ((1 + nneg) * bsz, hdim)
-
-        batch_size = len(embeddings_a)
-        labels = torch.tensor(
-            range(batch_size), dtype=torch.long, device=embeddings_a.device
-        )  # (bsz, (1 + nneg) * bsz)  Example a[i] should match with b[i]
-        losses: list[torch.Tensor] = []
-        for b in tqdm.trange(
-            0,
-            batch_size,
-            self.mini_batch_size,
-            desc="Preparing caches",
-            disable=not self.show_progress_bar,
-        ):
-            e = b + self.mini_batch_size
-            scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
-            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
-            loss_mbatch.backward()
-            losses.append(loss_mbatch.detach())
-
-        loss = sum(losses).requires_grad_()
+        loss = self.calculate_loss(reps)
+        loss.backward()
+        loss = loss.detach().requires_grad_()

         self.cache = [[r.grad for r in rs] for rs in reps]  # e.g. 3 * bsz/mbsz * (mbsz, hdim)

diff --git a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
index be77b1fbf..20a3054d5 100644
--- a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
+++ b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
@@ -182,35 +182,11 @@ def embed_minibatch_iter(

     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
         """Calculate the symmetric loss and cache gradients."""
-        embeddings_a = torch.cat(reps[0])  # (bsz, hdim)
-        embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]])  # ((1 + nneg) * bsz, hdim)
-
-        batch_size = len(embeddings_a)
-        labels = torch.arange(batch_size, device=embeddings_a.device)
-
-        losses: list[torch.Tensor] = []
-        for b in tqdm.trange(
-            0,
-            batch_size,
-            self.mini_batch_size,
-            desc="Preparing caches",
-            disable=not self.show_progress_bar,
-        ):
-            e = min(b + self.mini_batch_size, batch_size)
-            scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
-            forward_loss: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e])
-
-            positive_scores = scores[:, b:e]
-            backward_loss: torch.Tensor = self.cross_entropy_loss(positive_scores.t(), labels[: len(positive_scores)])
-
-            loss_mbatch = (forward_loss + backward_loss) / 2
-            loss_mbatch.backward()
-            losses.append(loss_mbatch.detach())
-
-        loss = sum(losses) / len(losses)
-        loss = loss.requires_grad_()
+        loss = self.calculate_loss(reps)
+        loss.backward()
+        loss = loss.detach().requires_grad_()

-        self.cache = [[r.grad for r in rs] for rs in reps]
+        self.cache = [[r.grad for r in rs] for rs in reps]  # e.g. 3 * bsz/mbsz * (mbsz, hdim)

         return loss

diff --git a/sentence_transformers/losses/MatryoshkaLoss.py b/sentence_transformers/losses/MatryoshkaLoss.py
index 557574262..326c734d4 100644
--- a/sentence_transformers/losses/MatryoshkaLoss.py
+++ b/sentence_transformers/losses/MatryoshkaLoss.py
@@ -9,9 +9,7 @@
 from torch import Tensor, nn

 from sentence_transformers import SentenceTransformer
-from sentence_transformers.losses.CachedGISTEmbedLoss import CachedGISTEmbedLoss
-from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import CachedMultipleNegativesRankingLoss
-
+from sentence_transformers.losses import CachedMultipleNegativesSymmetricRankingLoss, CachedMultipleNegativesRankingLoss, CachedGISTEmbedLoss

 class ForwardDecorator:
     def __init__(self, fn) -> None:
@@ -52,6 +50,43 @@ def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:

         return output

+
+class CachedLossDecorator:
+    def __init__(self, fn, matryoshka_dims: list[int],
+                 matryoshka_weights: list[float | int],
+                 n_dims_per_step: int = -1) -> None:
+        self.fn = fn
+        self.matryoshka_dims = matryoshka_dims
+        self.matryoshka_weights = matryoshka_weights
+        self.n_dims_per_step = n_dims_per_step
+
+
+    def shrink(self, tensor: Tensor, dim: int) -> Tensor:
+        tensor_dim = tensor.shape[-1]
+        if dim > tensor_dim:
+            raise ValueError(
+                f"Dimension {dim} in matryoshka_dims cannot be greater than the model's embedding dimension: {tensor_dim}"
+            )
+        tensor = tensor[..., : dim]
+        tensor = F.normalize(tensor, p=2, dim=-1)
+        return tensor
+
+    def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
+        dim_indices = range(len(self.matryoshka_dims))
+        if self.n_dims_per_step > 0 and self.n_dims_per_step < len(dim_indices):
+            dim_indices = random.sample(dim_indices, self.n_dims_per_step)
+
+        loss = 0.0
+        for idx in dim_indices:
+            dim = self.matryoshka_dims[idx]
+            weight = self.matryoshka_weights[idx]
+
+            truncated = [[self.shrink(r, dim) for r in rs] for rs in reps]
+            loss += weight * self.fn(truncated, *args)
+
+        return loss
+
+
 class MatryoshkaLoss(nn.Module):
     def __init__(
         self,
@@ -123,10 +158,6 @@ def __init__(
         super().__init__()
         self.model = model
         self.loss = loss
-        if isinstance(loss, CachedMultipleNegativesRankingLoss):
-            warnings.warn("MatryoshkaLoss is not compatible with CachedMultipleNegativesRankingLoss.", stacklevel=2)
-        if isinstance(loss, CachedGISTEmbedLoss):
-            warnings.warn("MatryoshkaLoss is not compatible with CachedGISTEmbedLoss.", stacklevel=2)
         if matryoshka_weights is None:
             matryoshka_weights = [1] * len(matryoshka_dims)

@@ -135,7 +166,13 @@ def __init__(
         self.matryoshka_dims, self.matryoshka_weights = zip(*sorted(dims_weights, key=lambda x: x[0], reverse=True))
         self.n_dims_per_step = n_dims_per_step

+        if isinstance(loss, (CachedMultipleNegativesRankingLoss, CachedMultipleNegativesSymmetricRankingLoss, CachedGISTEmbedLoss)):
+            loss.calculate_loss = CachedLossDecorator(loss.calculate_loss, self.matryoshka_dims, self.matryoshka_weights, self.n_dims_per_step)
+
     def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
+        if isinstance(self.loss, (CachedMultipleNegativesRankingLoss, CachedMultipleNegativesSymmetricRankingLoss, CachedGISTEmbedLoss)):
+            return self.loss(sentence_features, labels)
+
         original_forward = self.model.forward
         try:
             decorated_forward = ForwardDecorator(original_forward)
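
For context, a minimal usage sketch of the combination this patch enables. The checkpoint name, matryoshka_dims, and mini_batch_size below are illustrative only and not part of the diff:

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.losses import CachedMultipleNegativesRankingLoss, MatryoshkaLoss

    # "all-MiniLM-L6-v2" produces 384-dim embeddings; any SentenceTransformer works here.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    inner_loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)
    loss = MatryoshkaLoss(model, inner_loss, matryoshka_dims=[384, 256, 128, 64])

    # With this patch, MatryoshkaLoss no longer warns for cached losses. Instead it wraps
    # inner_loss.calculate_loss with CachedLossDecorator, so each cached mini-batch loss is
    # evaluated on truncated, re-normalized embeddings for every requested dimension and the
    # weighted sum is backpropagated into the embedding gradient cache.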