From c2d397f0912b16d5a2aff92b67782fdfc29e7f75 Mon Sep 17 00:00:00 2001
From: Marcel Brunnbauer
Date: Wed, 4 Dec 2024 15:21:34 +0100
Subject: [PATCH] fix missing backward flag

---
 sentence_transformers/losses/CachedGISTEmbedLoss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py
index 53e39a531..d76d3338e 100644
--- a/sentence_transformers/losses/CachedGISTEmbedLoss.py
+++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -229,7 +229,7 @@ def embed_minibatch_iter(
 
     def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
         """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
-        loss = self.calculate_loss(reps, reps_guided)
+        loss = self.calculate_loss(reps, reps_guided, with_backward=True)
         loss = loss.detach().requires_grad_()
 
         self.cache = [[r.grad for r in rs] for rs in reps]
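
Why the flag matters: calculate_loss only runs backward() on the minibatch
losses when with_backward=True is passed, matching the sibling cached losses
such as CachedMultipleNegativesRankingLoss. Without it, the detached
embedding tensors never receive gradients, so the self.cache built from
r.grad holds only None entries and the second backward pass through the
model has nothing to propagate.

Below is a minimal, self-contained sketch of that gradient-caching pattern,
assuming a toy quadratic loss; toy_loss and the tensor shapes are
hypothetical stand-ins, not the library's implementation.

    import torch
    from torch import Tensor

    def toy_loss(reps: list[Tensor], with_backward: bool = False) -> Tensor:
        # Stand-in for calculate_loss: optionally runs backward here so
        # that .grad is populated on the detached embedding tensors.
        loss = torch.stack([r.pow(2).sum() for r in reps]).mean()
        if with_backward:
            loss.backward()  # fills r.grad for every rep
            loss = loss.detach()
        return loss

    # Embeddings detached from the encoder but tracking their own gradients,
    # like the minibatch embeddings the cached losses collect.
    reps = [torch.randn(4, 8).requires_grad_() for _ in range(3)]

    loss = toy_loss(reps, with_backward=True)
    loss = loss.detach().requires_grad_()

    # Without with_backward=True above, every r.grad would still be None
    # and this cache would be useless for the second backward pass.
    cache = [r.grad for r in reps]
    assert all(g is not None for g in cache)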