diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py index 53e39a531..d76d3338e 100644 --- a/sentence_transformers/losses/CachedGISTEmbedLoss.py +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -229,7 +229,7 @@ def embed_minibatch_iter( def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor: """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings.""" - loss = self.calculate_loss(reps, reps_guided) + loss = self.calculate_loss(reps, reps_guided, with_backward=True) loss = loss.detach().requires_grad_() self.cache = [[r.grad for r in rs] for rs in reps]