Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix backward pass for cached losses #3114

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions sentence_transformers/losses/CachedGISTEmbedLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,16 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
"""Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
loss = self.calculate_loss(reps, reps_guided)
loss.backward()
loss = self.calculate_loss(reps, reps_guided, with_backward=True)
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps]

return loss

def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
def calculate_loss(
self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]], with_backward: bool = False
) -> Tensor:
"""Generalized function to calculate the cross-entropy loss without caching gradients."""
if len(reps) != len(reps_guided):
raise ValueError("reps and reps_guided must have the same length")
Expand Down Expand Up @@ -291,6 +292,9 @@ def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor
# Normalize the scores and calculate the cross-entropy loss
scores = scores / self.temperature
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
if with_backward:
loss_mbatch.backward()
loss_mbatch = loss_mbatch.detach()
losses.append(loss_mbatch)

loss = sum(losses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,14 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
"""Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
loss = self.calculate_loss(reps)
loss.backward()
loss = self.calculate_loss(reps, with_backward=True)
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

return loss

def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
"""Calculate the cross-entropy loss. No need to cache the gradients."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)
Expand All @@ -241,6 +240,9 @@ def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
e = b + self.mini_batch_size
scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
if with_backward:
loss_mbatch.backward()
loss_mbatch = loss_mbatch.detach()
losses.append(loss_mbatch)

loss = sum(losses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,14 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
"""Calculate the symmetric loss and cache gradients."""
loss = self.calculate_loss(reps)
loss.backward()
loss = self.calculate_loss(reps, with_backward=True)
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

return loss

def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
"""Calculate the symmetric loss without caching gradients (for evaluation)."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)
Expand All @@ -214,6 +213,9 @@ def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
backward_loss: torch.Tensor = self.cross_entropy_loss(positive_scores.t(), labels[: len(positive_scores)])

loss_mbatch = (forward_loss + backward_loss) / 2
if with_backward:
loss_mbatch.backward()
loss_mbatch = loss_mbatch.detach()
losses.append(loss_mbatch)

loss = sum(losses) / len(losses)
Expand Down
21 changes: 17 additions & 4 deletions sentence_transformers/losses/MatryoshkaLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Iterable
from typing import Any

import torch
import torch.nn.functional as F
from torch import Tensor, nn

Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(
self.matryoshka_weights = matryoshka_weights
self.n_dims_per_step = n_dims_per_step

def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
def __call__(self, reps: list[list[Tensor]], *args, **kwargs) -> Tensor:
dim_indices = range(len(self.matryoshka_dims))
if self.n_dims_per_step > 0 and self.n_dims_per_step < len(dim_indices):
dim_indices = random.sample(dim_indices, self.n_dims_per_step)
Expand All @@ -91,9 +92,21 @@ def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
dim = self.matryoshka_dims[idx]
weight = self.matryoshka_weights[idx]

truncated = [[shrink(r, dim) for r in rs] for rs in reps]
loss += weight * self.fn(truncated, *args)

truncated = [[shrink(r, dim) for r in minibatch] for minibatch in reps]
compute_gradients = torch.is_grad_enabled()
# we need to detach the truncated embeddings,
# otherwise the first backward pass of the underlying function will clear the computation graph of the embedding truncation
if compute_gradients:
matryoshka_reps = [[r.detach().requires_grad_() for r in minibatch] for minibatch in truncated]
else:
matryoshka_reps = truncated
loss += weight * self.fn(matryoshka_reps, *args, **kwargs)
# After computing the gradients in minibatches, we need to continue the backward pass through the truncation calculation
# the gradients must be multipied with the weights because otherwise the matryoshka weights are not considered in the backward pass
if compute_gradients:
for t_minibatch, d_minibatch in zip(truncated, matryoshka_reps):
for t, d in zip(t_minibatch, d_minibatch):
t.backward(weight * d.grad)
return loss


Expand Down
Loading