Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix matryoshka cached mnr #3065

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def calculate_loss(self, reps: list[list[Tensor]]) -> Tensor:
loss = sum(losses)
return loss

def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor, return_hook: bool = False) -> Tensor:
# Step (1): A quick embedding step without gradients/computation graphs to get all the embeddings
reps = []
self.random_states = [] # Copy random states to guarantee exact reproduction of the embeddings during the second forward pass, i.e. step (3)
Expand All @@ -287,11 +287,13 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
loss = self.calculate_loss_and_cache_gradients(reps)

# Step (3): A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain
loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self))
hook_handle = loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self))
else:
# If grad is not enabled (e.g. in evaluation), then we don't have to worry about the gradients or backward hook
loss = self.calculate_loss(reps)

if torch.is_grad_enabled() and return_hook:
return loss, hook_handle
return loss

def get_config_dict(self) -> dict[str, Any]:
Expand Down
45 changes: 44 additions & 1 deletion sentence_transformers/losses/MatryoshkaLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import random
import warnings
import torch
from collections.abc import Iterable
from typing import Any
from functools import partial

import torch.nn.functional as F
from torch import Tensor, nn
Expand Down Expand Up @@ -52,6 +54,39 @@ def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
return output


def _backward_hook(
grad_output: Tensor,
sentence_features: Iterable[dict[str, Tensor]],
loss_obj: CachedMultipleNegativesRankingLoss,
dim: int,
loss_obj_cache,
loss_obj_random_states
) -> None:
"""Customized from CachedMultipleNegativesRankingLoss."""
loss_obj.cache = loss_obj_cache
loss_obj.random_states = loss_obj_random_states
assert loss_obj.cache is not None
assert loss_obj.random_states is not None
original_forward = loss_obj.model.forward
decorated_forward = ForwardDecorator(original_forward)
decorated_forward.set_dim(dim)
loss_obj.model.forward = decorated_forward
with torch.enable_grad():
for sentence_feature, grad, random_states in zip(sentence_features, loss_obj.cache, loss_obj.random_states):
for (reps_mb, _), grad_mb in zip(
loss_obj.embed_minibatch_iter(
sentence_feature=sentence_feature,
with_grad=True,
copy_random_state=False,
random_states=random_states,
),
grad,
):
surrogate = torch.dot(reps_mb.flatten(), grad_mb.flatten()) * grad_output
surrogate.backward()
loss_obj.model.forward = original_forward


class MatryoshkaLoss(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -150,7 +185,15 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
dim = self.matryoshka_dims[idx]
weight = self.matryoshka_weights[idx]
decorated_forward.set_dim(dim)
loss += weight * self.loss(sentence_features, labels)

if torch.is_grad_enabled() and isinstance(self.loss, CachedMultipleNegativesRankingLoss):
loss_part, hook = self.loss(sentence_features, labels, return_hook=True)
# register our customized hook instead
hook.remove()
loss_part.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self.loss, dim=dim, loss_obj_cache=self.loss.cache, loss_obj_random_states=self.loss.random_states))
else:
loss_part = self.loss(sentence_features, labels)
loss += weight * loss_part
finally:
self.model.forward = original_forward
return loss
Expand Down
Loading