From 1bb592a99fd48705fec0b15bc1086be04b05126a Mon Sep 17 00:00:00 2001
From: Marcel Brunnbauer
Date: Wed, 27 Nov 2024 09:17:32 +0100
Subject: [PATCH] fix backward pass for cached losses

---
 .../losses/CachedGISTEmbedLoss.py                  |  8 ++++++--
 .../losses/CachedMultipleNegativesRankingLoss.py   |  8 +++++---
 ...CachedMultipleNegativesSymmetricRankingLoss.py  |  8 +++++---
 sentence_transformers/losses/MatryoshkaLoss.py     | 15 +++++++++++----
 4 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py
index 747b4c540..53e39a531 100644
--- a/sentence_transformers/losses/CachedGISTEmbedLoss.py
+++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -230,14 +230,15 @@ def embed_minibatch_iter(
     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
         """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
         loss = self.calculate_loss(reps, reps_guided)
-        loss.backward()
         loss = loss.detach().requires_grad_()
 
         self.cache = [[r.grad for r in rs] for rs in reps]
 
         return loss
 
-    def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
+    def calculate_loss(
+        self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]], with_backward: bool = False
+    ) -> Tensor:
         """Generalized function to calculate the cross-entropy loss without caching gradients."""
         if len(reps) != len(reps_guided):
             raise ValueError("reps and reps_guided must have the same length")
@@ -291,6 +292,9 @@ def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor
             # Normalize the scores and calculate the cross-entropy loss
             scores = scores / self.temperature
             loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
+            if with_backward:
+                loss_mbatch.backward()
+                loss_mbatch = loss_mbatch.detach()
             losses.append(loss_mbatch)
 
         loss = sum(losses)
diff --git a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
index dfdff2c82..74fe8ae64 100644
--- a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
+++ b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -213,15 +213,14 @@ def embed_minibatch_iter(
 
     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
         """Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
-        loss = self.calculate_loss(reps)
-        loss.backward()
+        loss = self.calculate_loss(reps, with_backward=True)
         loss = loss.detach().requires_grad_()
 
         self.cache = [[r.grad for r in rs] for rs in reps]  # e.g. 3 * bsz/mbsz * (mbsz, hdim)
 
         return loss
 
-    def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
+    def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
         """Calculate the cross-entropy loss. No need to cache the gradients."""
         embeddings_a = torch.cat(reps[0])  # (bsz, hdim)
         embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]])  # ((1 + nneg) * bsz, hdim)
@@ -241,6 +240,9 @@ def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
             e = b + self.mini_batch_size
             scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
             loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
+            if with_backward:
+                loss_mbatch.backward()
+                loss_mbatch = loss_mbatch.detach()
             losses.append(loss_mbatch)
 
         loss = sum(losses)
diff --git a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
index 20a3054d5..4b0995c96 100644
--- a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
+++ b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
@@ -182,15 +182,14 @@ def embed_minibatch_iter(
 
     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
         """Calculate the symmetric loss and cache gradients."""
-        loss = self.calculate_loss(reps)
-        loss.backward()
+        loss = self.calculate_loss(reps, with_backward=True)
         loss = loss.detach().requires_grad_()
 
         self.cache = [[r.grad for r in rs] for rs in reps]  # e.g. 3 * bsz/mbsz * (mbsz, hdim)
 
         return loss
 
-    def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
+    def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
         """Calculate the symmetric loss without caching gradients (for evaluation)."""
         embeddings_a = torch.cat(reps[0])  # (bsz, hdim)
         embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]])  # ((1 + nneg) * bsz, hdim)
@@ -214,6 +213,9 @@ def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
             backward_loss: torch.Tensor = self.cross_entropy_loss(positive_scores.t(), labels[: len(positive_scores)])
 
             loss_mbatch = (forward_loss + backward_loss) / 2
+            if with_backward:
+                loss_mbatch.backward()
+                loss_mbatch = loss_mbatch.detach()
             losses.append(loss_mbatch)
 
         loss = sum(losses) / len(losses)
diff --git a/sentence_transformers/losses/MatryoshkaLoss.py b/sentence_transformers/losses/MatryoshkaLoss.py
index fb71cc15d..33c66d5f9 100644
--- a/sentence_transformers/losses/MatryoshkaLoss.py
+++ b/sentence_transformers/losses/MatryoshkaLoss.py
@@ -81,7 +81,7 @@ def __init__(
         self.matryoshka_weights = matryoshka_weights
         self.n_dims_per_step = n_dims_per_step
 
-    def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
+    def __call__(self, reps: list[list[Tensor]], *args, **kwargs) -> Tensor:
         dim_indices = range(len(self.matryoshka_dims))
         if self.n_dims_per_step > 0 and self.n_dims_per_step < len(dim_indices):
             dim_indices = random.sample(dim_indices, self.n_dims_per_step)
@@ -91,9 +91,16 @@ def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
             dim = self.matryoshka_dims[idx]
             weight = self.matryoshka_weights[idx]
 
-            truncated = [[shrink(r, dim) for r in rs] for rs in reps]
-            loss += weight * self.fn(truncated, *args)
-
+            truncated = [[shrink(r, dim) for r in minibatch] for minibatch in reps]
+            # we need to detach the truncated embeddings,
+            # otherwise the first backward pass of the underlying function will clear the computation graph of the embedding truncation
+            detached = [[r.detach().requires_grad_() for r in minibatch] for minibatch in truncated]
+            loss += weight * self.fn(detached, *args, **kwargs)
+            # After computing the gradients in minibatches, we need to continue the backward pass through the truncation calculation
+            # the gradients must be multiplied with the weights because otherwise the matryoshka weights are not considered in the backward pass
+            for t_minibatch, d_minibatch in zip(truncated, detached):
+                for t, d in zip(t_minibatch, d_minibatch):
+                    t.backward(weight * d.grad)
         return loss
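
Illustration (not part of the patch): a minimal usage sketch, assuming the standard
sentence-transformers training API, of the combination this commit targets, namely
MatryoshkaLoss wrapping a cached loss. The model name and mini_batch_size below are
arbitrary placeholders.

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.losses import CachedMultipleNegativesRankingLoss, MatryoshkaLoss

    model = SentenceTransformer("all-MiniLM-L6-v2")

    # The cached loss computes its loss in mini-batches; with this patch it can also run
    # loss_mbatch.backward() per mini-batch when called with with_backward=True.
    inner_loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)

    # MatryoshkaLoss truncates the embeddings to several dimensionalities; the decorator it
    # applies to cached losses now detaches the truncated embeddings, forwards **kwargs to
    # the wrapped loss, and continues the backward pass through the truncation with the
    # matryoshka weights.
    loss = MatryoshkaLoss(model, inner_loss, matryoshka_dims=[384, 256, 128, 64])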