Skip to content

Commit

Permalink
fix backward pass for cached losses
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Brunnbauer committed Dec 3, 2024
1 parent e7641c8 commit 1bb592a
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 12 deletions.
8 changes: 6 additions & 2 deletions sentence_transformers/losses/CachedGISTEmbedLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,15 @@ def embed_minibatch_iter(
def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
"""Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
loss = self.calculate_loss(reps, reps_guided)
loss.backward()
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps]

return loss

def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
def calculate_loss(
self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]], with_backward: bool = False
) -> Tensor:
"""Generalized function to calculate the cross-entropy loss without caching gradients."""
if len(reps) != len(reps_guided):
raise ValueError("reps and reps_guided must have the same length")
Expand Down Expand Up @@ -291,6 +292,9 @@ def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor
# Normalize the scores and calculate the cross-entropy loss
scores = scores / self.temperature
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
if with_backward:
loss_mbatch.backward()
loss_mbatch = loss_mbatch.detach()
losses.append(loss_mbatch)

loss = sum(losses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,14 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
"""Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
loss = self.calculate_loss(reps)
loss.backward()
loss = self.calculate_loss(reps, with_backward=True)
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

return loss

def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
"""Calculate the cross-entropy loss. No need to cache the gradients."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)
Expand All @@ -241,6 +240,9 @@ def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
e = b + self.mini_batch_size
scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
if with_backward:
loss_mbatch.backward()
loss_mbatch = loss_mbatch.detach()
losses.append(loss_mbatch)

loss = sum(losses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,14 @@ def embed_minibatch_iter(

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
"""Calculate the symmetric loss and cache gradients."""
loss = self.calculate_loss(reps)
loss.backward()
loss = self.calculate_loss(reps, with_backward=True)
loss = loss.detach().requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

return loss

def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
"""Calculate the symmetric loss without caching gradients (for evaluation)."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)
Expand All @@ -214,6 +213,9 @@ def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
backward_loss: torch.Tensor = self.cross_entropy_loss(positive_scores.t(), labels[: len(positive_scores)])

loss_mbatch = (forward_loss + backward_loss) / 2
if with_backward:
loss_mbatch.backward()
loss_mbatch = loss_mbatch.detach()
losses.append(loss_mbatch)

loss = sum(losses) / len(losses)
Expand Down
15 changes: 11 additions & 4 deletions sentence_transformers/losses/MatryoshkaLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
self.matryoshka_weights = matryoshka_weights
self.n_dims_per_step = n_dims_per_step

def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
def __call__(self, reps: list[list[Tensor]], *args, **kwargs) -> Tensor:
dim_indices = range(len(self.matryoshka_dims))
if self.n_dims_per_step > 0 and self.n_dims_per_step < len(dim_indices):
dim_indices = random.sample(dim_indices, self.n_dims_per_step)
Expand All @@ -91,9 +91,16 @@ def __call__(self, reps: list[list[Tensor]], *args) -> Tensor:
dim = self.matryoshka_dims[idx]
weight = self.matryoshka_weights[idx]

truncated = [[shrink(r, dim) for r in rs] for rs in reps]
loss += weight * self.fn(truncated, *args)

truncated = [[shrink(r, dim) for r in minibatch] for minibatch in reps]
# we need to detach the truncated embeddings,
# otherwise the first backward pass of the underlying function will clear the computation graph of the embedding truncation
detached = [[r.detach().requires_grad_() for r in minibatch] for minibatch in truncated]
loss += weight * self.fn(detached, *args, **kwargs)
# After computing the gradients in minibatches, we need to continue the backward pass through the truncation calculation
# the gradients must be multipied with the weights because otherwise the matryoshka weights are not considered in the backward pass
for t_minibatch, d_minibatch in zip(truncated, detached):
for t, d in zip(t_minibatch, d_minibatch):
t.backward(weight * d.grad)
return loss


Expand Down

0 comments on commit 1bb592a

Please sign in to comment.