From 0292b9b48db27cce603517b55ec2e1fa36629bab Mon Sep 17 00:00:00 2001 From: Marcel Brunnbauer Date: Tue, 10 Dec 2024 11:24:20 +0100 Subject: [PATCH] don't perform backward pass in evaluation mode --- sentence_transformers/losses/MatryoshkaLoss.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sentence_transformers/losses/MatryoshkaLoss.py b/sentence_transformers/losses/MatryoshkaLoss.py index 33c66d5f9..4dfe0acb4 100644 --- a/sentence_transformers/losses/MatryoshkaLoss.py +++ b/sentence_transformers/losses/MatryoshkaLoss.py @@ -4,6 +4,7 @@ from collections.abc import Iterable from typing import Any +import torch import torch.nn.functional as F from torch import Tensor, nn @@ -92,15 +93,20 @@ def __call__(self, reps: list[list[Tensor]], *args, **kwargs) -> Tensor: weight = self.matryoshka_weights[idx] truncated = [[shrink(r, dim) for r in minibatch] for minibatch in reps] + compute_gradients = torch.is_grad_enabled() # we need to detach the truncated embeddings, # otherwise the first backward pass of the underlying function will clear the computation graph of the embedding truncation - detached = [[r.detach().requires_grad_() for r in minibatch] for minibatch in truncated] - loss += weight * self.fn(detached, *args, **kwargs) + if compute_gradients: + matryoshka_reps = [[r.detach().requires_grad_() for r in minibatch] for minibatch in truncated] + else: + matryoshka_reps = truncated + loss += weight * self.fn(matryoshka_reps, *args, **kwargs) # After computing the gradients in minibatches, we need to continue the backward pass through the truncation calculation # the gradients must be multipied with the weights because otherwise the matryoshka weights are not considered in the backward pass - for t_minibatch, d_minibatch in zip(truncated, detached): - for t, d in zip(t_minibatch, d_minibatch): - t.backward(weight * d.grad) + if compute_gradients: + for t_minibatch, d_minibatch in zip(truncated, matryoshka_reps): + for t, d in zip(t_minibatch, d_minibatch): + t.backward(weight * d.grad) return loss