Add GradCache + MNRL: Go beyond GPU-memory limit for MNRL #1759

Merged: 8 commits, Jan 16, 2024
210 changes: 210 additions & 0 deletions sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -0,0 +1,210 @@
from __future__ import annotations
from contextlib import nullcontext
from functools import partial
import torch
from torch import nn, Tensor
from torch.utils.checkpoint import get_device_states, set_device_states
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import tqdm


class RandContext:
"""
Random-state context manager class. Reference: https://github.com/luyug/GradCache.

This class will back up the pytorch's random state during initialization. Then when the context is activated,
the class will set up the random state with the backed-up one.
"""

def __init__(self, *tensors):
self.fwd_cpu_state = torch.get_rng_state()
self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)

def __enter__(self):
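# Fork the global RNG so that restoring the backed-up state inside the context
# does not leak into (or get clobbered by) code running outside of it.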
self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True)
self._fork.__enter__()
torch.set_rng_state(self.fwd_cpu_state)
set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)

def __exit__(self, exc_type, exc_val, exc_tb):
self._fork.__exit__(exc_type, exc_val, exc_tb)
self._fork = None


def _backward_hook(
grad_output: Tensor,
sentence_features: Iterable[Dict[str, Tensor]],
loss_obj: CachedMultipleNegativesRankingLoss,
):
"""A backward hook to backpropagate the cached gradients mini-batch by mini-batch."""
assert loss_obj.cache is not None
assert loss_obj.random_states is not None
with torch.enable_grad():
for sentence_feature, grad, random_states in zip(sentence_features, loss_obj.cache, loss_obj.random_states):
for (reps_mb, _), grad_mb in zip(
loss_obj.embed_minibatch_iter(
sentence_feature=sentence_feature,
with_grad=True,
copy_random_state=False,
random_states=random_states,
),
grad,
):
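# `grad_mb` holds the cached d(loss)/d(reps_mb) from step (2). Backpropagating the
# surrogate scalar dot(reps_mb, grad_mb) pushes exactly these gradients through the
# encoder via the chain rule, and multiplying by `grad_output` accounts for any gradient
# flowing into the loss itself (e.g. an AMP loss scale).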
surrogate = torch.dot(reps_mb.flatten(), grad_mb.flatten()) * grad_output
surrogate.backward()


class CachedMultipleNegativesRankingLoss(nn.Module):
"""
Boosted version of MultipleNegativesRankingLoss (https://arxiv.org/pdf/1705.00652.pdf) by GradCache (https://arxiv.org/pdf/2101.06983.pdf).

Contrastive learning (here our MNRL loss) with in-batch negatives is usually hard to scale to large batch sizes due to (GPU) memory limitations.
Batch-scaling methods like gradient accumulation do not help either, because the in-batch negatives make the data points within
the same batch non-independent, so the batch cannot be broken down into mini-batches. GradCache is a smart way to solve this problem.
It achieves the goal by dividing the computation into two stages, embedding and loss calculation, both of which can be run in mini-batches.
As a result, a constant amount of memory (e.g. enough for batch size 32) can now process much larger batches (e.g. 65536).

In detail:
(1) It first does a quick embedding step without gradients/computation graphs to get all the embeddings;
(2) Then it calculates the loss, backpropagates up to the embeddings and caches the gradients wrt. the embeddings;
(3) Finally it runs a second embedding step with gradients/computation graphs and connects the cached gradients into the backward chain.

Notes: All steps are done with mini-batches. In the original implementation of GradCache, (2) is not done in mini-batches and
requires a lot of memory when the batch size is large. One drawback is speed: GradCache sacrifices around 20% computation time according to the paper.

Example:

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("distilbert-base-uncased")
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1024) # Here we can try much larger batch sizes!
train_loss = losses.CachedMultipleNegativesRankingLoss(model=model, mini_batch_size=32)
"""

def __init__(
self,
model: SentenceTransformer,
scale: float = 20.0,
similarity_fct: Callable[[Tensor, Tensor], Tensor] = util.cos_sim,
mini_batch_size: int = 32,
show_progress_bar: bool = False,
):
"""
:param model: SentenceTransformer model
:param scale: Output of similarity function is multiplied by scale value
:param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
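:param mini_batch_size: size of the mini-batches used for the embedding and loss-calculation steps; this is what bounds the memory usage, while the (much larger) outer batch size only determines how many in-batch negatives each example sees
:param show_progress_bar: whether to show a tqdm progress bar over the mini-batches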
"""
super(CachedMultipleNegativesRankingLoss, self).__init__()
self.model = model
self.scale = scale
self.similarity_fct = similarity_fct
self.cross_entropy_loss = nn.CrossEntropyLoss()
self.mini_batch_size = mini_batch_size
self.cache: Optional[List[List[Tensor]]] = None
self.random_states: Optional[List[List[RandContext]]] = None
self.show_progress_bar = show_progress_bar

def embed_minibatch(
self,
sentence_feature: Dict[str, Tensor],
begin: int,
end: int,
with_grad: bool,
copy_random_state: bool,
random_state: Optional[RandContext] = None,
) -> Tuple[Tensor, Optional[RandContext]]:
"""Do forward pass on a minibatch of the input features and return corresponding embeddings."""
grad_context = nullcontext if with_grad else torch.no_grad
random_state_context = nullcontext() if random_state is None else random_state
sentence_feature_minibatch = {k: v[begin:end] for k, v in sentence_feature.items()}
with random_state_context:
with grad_context():
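# Snapshot the RNG state right before the forward pass so that step (3) can replay
# exactly the same dropout masks etc. when re-embedding this mini-batch with gradients.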
random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None
reps = self.model(sentence_feature_minibatch)["sentence_embedding"] # (mbsz, hdim)
return reps, random_state

def embed_minibatch_iter(
self,
sentence_feature: Dict[str, Tensor],
with_grad: bool,
copy_random_state: bool,
random_states: Optional[List[RandContext]] = None,
) -> Iterator[Tuple[Tensor, Optional[RandContext]]]:
"""Do forward pass on all the minibatches of the input features and yield corresponding embeddings."""
input_ids: Tensor = sentence_feature["input_ids"]
bsz, _ = input_ids.shape
for i, b in enumerate(
tqdm.trange(
0,
bsz,
self.mini_batch_size,
desc="Embed mini-batches",
disable=not self.show_progress_bar,
)
):
e = b + self.mini_batch_size
reps, random_state = self.embed_minibatch(
sentence_feature=sentence_feature,
begin=b,
end=e,
with_grad=with_grad,
copy_random_state=copy_random_state,
random_state=None if random_states is None else random_states[i],
)
yield reps, random_state # reps: (mbsz, hdim)

def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]]) -> Tensor:
"""Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)

batch_size = len(embeddings_a)
labels = torch.tensor(
range(batch_size), dtype=torch.long, device=embeddings_a.device
) # (bsz,) Example: a[i] should match with b[i]
losses: List[torch.Tensor] = []
for b in tqdm.trange(
0,
batch_size,
self.mini_batch_size,
desc="Preparing caches",
disable=not self.show_progress_bar,
):
e = b + self.mini_batch_size
scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
loss_mbatch.backward()
losses.append(loss_mbatch.detach())

loss = sum(losses).requires_grad_()

self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)

return loss

def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
# Step (1): A quick embedding step without gradients/computation graphs to get all the embeddings
reps = []
self.random_states = [] # Copy random states to guarantee exact reproduction of the embeddings during the second forward pass, i.e. step (3)
for sentence_feature in sentence_features:
reps_mbs = []
random_state_mbs = []
for reps_mb, random_state in self.embed_minibatch_iter(
sentence_feature=sentence_feature,
with_grad=False,
copy_random_state=True,
):
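# Detach and re-mark as a leaf tensor that requires grad: step (2) will write the
# cached gradients into `reps_mb.grad` without building a graph back to the model.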
reps_mbs.append(reps_mb.detach().requires_grad_())
random_state_mbs.append(random_state)
reps.append(reps_mbs)
self.random_states.append(random_state_mbs)

# Step (2): Calculate the loss, backward up to the embeddings and cache the gradients wrt. to the embeddings
loss = self.calculate_loss_and_cache_gradients(reps)

# Step (3): A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain
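# The returned loss is a detached scalar with requires_grad=True. When the training loop
# calls loss.backward(), the hook below receives grad_output (1.0, or e.g. an AMP scale
# factor), re-runs the embedding forward pass with gradients and chains the cached
# gradients into the model parameters.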
loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self))
return loss

def get_config_dict(self):
return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}
2 changes: 2 additions & 0 deletions sentence_transformers/losses/__init__.py
@@ -5,6 +5,7 @@
from .TripletLoss import TripletDistanceMetric, TripletLoss
from .MarginMSELoss import MarginMSELoss
from .MSELoss import MSELoss
from .CachedMultipleNegativesRankingLoss import CachedMultipleNegativesRankingLoss
from .ContrastiveLoss import SiameseDistanceMetric, ContrastiveLoss
from .ContrastiveTensionLoss import (
ContrastiveTensionLoss,
@@ -32,6 +33,7 @@
"MSELoss",
"ContrastiveLoss",
"SiameseDistanceMetric",
"CachedMultipleNegativesRankingLoss",
"ContrastiveTensionLoss",
"ContrastiveTensionLossInBatchNegatives",
"ContrastiveTensionDataLoader",
148 changes: 148 additions & 0 deletions tests/test_cmnrl.py
@@ -0,0 +1,148 @@
from contextlib import nullcontext
from typing import List
import pytest
from sentence_transformers import SentenceTransformer, InputExample, losses
import tqdm
from transformers import set_seed
import torch
from torch.optim import Adam


@pytest.mark.parametrize(
["train_samples_mnrl", "train_samples_cmnrl", "same_grad", "scaler", "precision"],
[
(
[
InputExample(texts=[q, p, n])
for q, p, n in zip(
["aaa", "bbb", "ccc", "ddd", "eee"],
["aas", "bbs", "ccs", "dds", "ees"],
["xxx", "yyy", "zzz", "kkk", "fff"],
)
],
[
InputExample(texts=[q, p, n])
for q, p, n in zip(
["aaa", "bbb", "ccc", "ddd", "eee"],
["aas", "bbs", "ccs", "dds", "ees"],
["xxx", "yyy", "zzz", "kkk", "fff"],
)
],
True,
1.0,
1e-6,
),
(
[
InputExample(texts=[q, p, n])
for q, p, n in zip(
["adsa", "czx", "dsada"],
["b", "fas", "xcz"],
["c", "yyy", "asdas"],
)
],
[
InputExample(texts=[q, p, n])
for q, p, n in zip(
["aaa", "bbb", "ccc", "ddd", "eee"],
["aas", "bbs", "ccs", "dds", "ees"],
["xxx", "yyy", "zzz", "kkk", "fff"],
)
],
False,
1.0,
1e-6,
),
(
[
InputExample(texts=[q, p, n])
for q, p, n in zip(
["aaa", "bbb", "ccc", "ddd", "eee"],
["aas", "bbs", "ccs", "dds", "ees"],
["xxx", "yyy", "zzz", "kkk", "fff"],
)
],
[
InputExample(texts=[q, p, n])
for q, p, n in zip(
["aaa", "bbb", "ccc", "ddd", "eee"],
["aas", "bbs", "ccs", "dds", "ees"],
["xxx", "yyy", "zzz", "kkk", "fff"],
)
],
True,
1000.0,
1e-3,
),
],
)
def test_cmnrl_same_grad(
train_samples_mnrl: List[InputExample],
train_samples_cmnrl: List[InputExample],
same_grad: bool,
scaler: float,
precision: float,
):
# Given:
sbert = SentenceTransformer("distilbert-base-uncased")
sbert.to("cpu")
optimizer = Adam(sbert.parameters())
# train_samples_mnrl
# train_samples_cmnrl
# same_grad
# scaler # This simulates AMP scenarios
# precision
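# The test checks that the cached loss (CMNRL) produces the same loss value and the same
# parameter gradients as plain MNRL when both see identical data, and different gradients
# otherwise. Multiplying the losses by `scaler` exercises the grad_output != 1 path of
# the backward hook, as happens with AMP gradient scaling.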

# When:
# First run with MNRL
set_seed(42)
optimizer.zero_grad()
loss_mnrl = losses.MultipleNegativesRankingLoss(sbert)
loss_mnrl_value: torch.Tensor = loss_mnrl.forward(*sbert.smart_batching_collate(train_samples_mnrl)) * scaler
loss_mnrl_value.backward()
grad_expected = {name: p.grad.clone() for name, p in loss_mnrl.named_parameters() if p.grad is not None}

# Then run with this cached version:
set_seed(42)
optimizer.zero_grad()
loss_cmnrl = losses.CachedMultipleNegativesRankingLoss(sbert, mini_batch_size=2)
loss_cmnrl_value = loss_cmnrl.forward(*sbert.smart_batching_collate(train_samples_cmnrl)) * scaler
loss_cmnrl_value.backward()
grad = {name: p.grad.clone() for name, p in loss_cmnrl.named_parameters() if p.grad is not None}

# Then:
if same_grad:
assert pytest.approx(loss_mnrl_value.item()) == loss_cmnrl_value.item()
else:
assert pytest.approx(loss_mnrl_value.item()) != loss_cmnrl_value.item()

nclose = 0
for name in tqdm.tqdm(grad_expected):
nclose += torch.allclose(grad[name], grad_expected[name], precision, precision)

if same_grad:
assert nclose == len(grad_expected)
else:
assert nclose != len(grad_expected)


@pytest.mark.parametrize("use_rand_context", [True, False])
def test_rand_context_working(use_rand_context: bool):
# Given:
from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import (
RandContext,
)

a = torch.Tensor(1)
b = torch.Tensor(1)
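# RandContext snapshots the CPU/GPU RNG state at construction and restores it on entry,
# so sampling inside the context should reproduce `expected`; without it the draw differs.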
random_state = RandContext(a, b) if use_rand_context else nullcontext()
expected = torch.rand(1000)
precision = 1e-6

# When:
with random_state:
# Then:
if use_rand_context:
assert torch.allclose(torch.rand(1000), expected, precision, precision)
else:
assert not torch.allclose(torch.rand(1000), expected, precision, precision)