
Add GradCache + MNRL: Go beyond GPU-memory limit for MNRL #1759

Merged: 8 commits into UKPLab:master on Jan 16, 2024

Conversation

@kwang2049 (Member) commented on Nov 16, 2022

What
Use the GradCache technique to let the MNRL loss go beyond the GPU memory limit. This PR adds a corresponding loss class called CachedMultipleNegativesRankingLoss (CMNRL), along with a test.

One can simply replace MNRL with CMNRL and train as usual. The batch size can then be set very large, e.g. 65536. I am running an experiment on MSMARCO (batch size 65536, learning rate 6e-4, warmup steps 8, mini-batch size 64): GPU memory usage is only about 16 GB, but training is also slowed down by roughly 2x.
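
For illustration, a minimal usage sketch (the class and argument names, CachedMultipleNegativesRankingLoss and mini_batch_size, follow this PR and may differ slightly in the released API; the model name and data here are placeholders):

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# NOTE: model name and training data are placeholders for illustration only.
model = SentenceTransformer("distilbert-base-uncased")
train_examples = [
    InputExample(texts=["what is python", "Python is a programming language."]),
    InputExample(texts=["capital of france", "Paris is the capital of France."]),
]

# The DataLoader batch size can now be much larger than plain MNRL could fit
# in GPU memory; gradients are computed over chunks of `mini_batch_size`.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1024)
train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)

model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=8)
```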

How it works (also documented in the class docstring):
(1) It first runs a quick embedding step without gradients/computation graphs to get all the embeddings;
(2) Calculate the loss, backpropagate up to the embeddings, and cache the gradients w.r.t. the embeddings;
(3) Run a second embedding step with gradients/computation graphs and connect the cached gradients into the backward chain.

A small difference from the original implementation is that here all three steps are done with mini-batches, whereas the original does so only for (1) and (3). I found this to be essential for scaling up to very large batch sizes such as 65536.
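
To make the three steps concrete, here is a toy PyTorch sketch of the idea (not the actual code in this PR; `enc`, `batches`, and `loss_fn` are hypothetical placeholders for an encoder, a list of mini-batches, and an in-batch-negatives loss):

```python
import torch

def cached_loss_backward(enc, batches, loss_fn):
    # (1) Embed all mini-batches quickly, without building computation graphs.
    with torch.no_grad():
        embs = [enc(b) for b in batches]

    # (2) Compute the loss on the full batch of embeddings and backpropagate
    #     only up to the embeddings, caching d(loss)/d(embedding).
    #     (In this PR this step is also chunked into mini-batches; it is done
    #     in one go here for brevity.)
    embs_req = [e.detach().requires_grad_() for e in embs]
    loss = loss_fn(torch.cat(embs_req))
    loss.backward()
    cached_grads = [e.grad for e in embs_req]

    # (3) Re-embed each mini-batch with gradients enabled and inject the cached
    #     gradients into the backward pass, so the encoder parameters receive
    #     the same gradients as full-batch training would produce.
    for b, grad in zip(batches, cached_grads):
        enc(b).backward(gradient=grad)

    return loss.detach()
```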

Why
MNRL, or in-batch-negative training, is usually hard to scale to large batch sizes, since the data points within a batch are not independent of one another. However, recent work shows that large batch sizes are key to the performance of dense representations trained this way. This PR addresses that problem and lets MNRL go beyond the GPU-memory limit by working around the chain rule and using mini-batches during training.

@kwang2049 requested a review from nreimers on November 16, 2022 at 14:54
@kwang2049 (Member, Author) commented on Nov 21, 2022

Sanity check
By running https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_bi-encoder_mnrl.py with MNRL and with the new CMNRL (i.e. keeping all other settings and changing only the loss class):

  • MNRL: 5h30min,
    • NDCG@1: 0.2182| NDCG@3: 0.3213| NDCG@5: 0.3569| NDCG@10: 0.3923| NDCG@100: 0.4481| NDCG@1000: 0.4622| MAP@1: 0.2119| MAP@3: 0.2928| MAP@5: 0.3128| MAP@10: 0.3278| MAP@100: 0.3392| MAP@1000: 0.3397| Recall@1: 0.2119| Recall@3: 0.3970| Recall@5: 0.4823| Recall@10: 0.5897| Recall@100: 0.8485| Recall@1000: 0.9576| P@1: 0.2182| P@3: 0.1369| P@5: 0.1001| P@10: 0.0615| P@100: 0.0090| P@1000: 0.0010| MRR@1: 0.2182| MRR@3: 0.2998| MRR@5: 0.3197| MRR@10: 0.3342| MRR@100: 0.3452| MRR@1000: 0.3457
  • CMNRL: 15h15min,
    • NDCG@1: 0.2195| NDCG@3: 0.3208| NDCG@5: 0.3580| NDCG@10: 0.3919| NDCG@100: 0.4479| NDCG@1000: 0.4622| MAP@1: 0.2130| MAP@3: 0.2928| MAP@5: 0.3137| MAP@10: 0.3280| MAP@100: 0.3394| MAP@1000: 0.3400| Recall@1: 0.2130| Recall@3: 0.3947| Recall@5: 0.4841| Recall@10: 0.5870| Recall@100: 0.8475| Recall@1000: 0.9580| P@1: 0.2195| P@3: 0.1363| P@5: 0.1006| P@10: 0.0612| P@100: 0.0089| P@1000: 0.0010| MRR@1: 0.2195| MRR@3: 0.3001| MRR@5: 0.3206| MRR@10: 0.3346| MRR@100: 0.3455| MRR@1000: 0.3460

So this implementation yields nearly identical results. However, it takes ~2.77x the time. The GPU is an NVIDIA Tesla V100 SXM3 32 GB and the CUDA version is 11.6.

@tomaarsen (Collaborator)

Hello!

This looks fascinating! Your results are promising indeed. I would be curious to experiment with benchmarking MNRL vs CMNRL with a certain memory budget. CMNRL should then allow for a higher batch size & then also (ideally) result in superior performance, at the cost of training time.

I will try to invest some time into this work at the start of the new year.

  • Tom Aarsen

@kwang2049 (Member, Author)

That would be awesome. Looking forward to your news!

@tomaarsen (Collaborator) commented on Jan 15, 2024

Hello!

I've taken care of the merge conflicts and ran some experiments locally, showing that I could indeed increase the batch size. We all know that batch size is crucial for embedding models, so I have one additional question: if this is indeed implemented correctly, does this approximate the result of a higher batch size, or is it computationally equivalent to having a larger batch size?
I suppose a follow-up question is: why would anyone ever use Multiple Negative Ranking Loss over Cached Multiple Negative Ranking Loss moving forward? Perhaps training time, or hardware good enough that an even higher batch size doesn't help?

Looking forward to your opinions. Also, I appreciate your inclusion of the tests! I think this is almost ready to merge :)

  • Tom Aarsen

@kwang2049 (Member, Author)

Thanks for your efforts!

For question 1: please refer to Fig. 4 in https://arxiv.org/pdf/2010.08191.pdf for previous experiments on the influence of batch size when using random negatives. And CMNRL is computationally equivalent to simply switching to a GPU with much more memory (cf. the GradCache paper).

For question 2: the main disadvantage is efficiency. As the numbers I reported show, it took ~2.77x the training time.
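
To illustrate the "computationally equivalent" point, here is a small hypothetical check (not part of this PR) that the cached-gradient trick reproduces the parameter gradients of ordinary full-batch backpropagation on a toy encoder:

```python
import torch

torch.manual_seed(0)
enc = torch.nn.Linear(8, 4)                             # toy "encoder"
x = torch.randn(16, 8)                                  # one big batch of 16 examples
loss_fn = lambda e: (e @ e.T).logsumexp(dim=-1).mean()  # stand-in in-batch loss

# Reference: ordinary full-batch forward + backward.
loss_fn(enc(x)).backward()
ref_grad = enc.weight.grad.clone()
enc.zero_grad()

# Gradient caching with mini-batches of 4.
with torch.no_grad():
    embs = [enc(b) for b in x.split(4)]        # (1) embed without graphs
embs = [e.requires_grad_() for e in embs]
loss_fn(torch.cat(embs)).backward()            # (2) cache grads w.r.t. embeddings
for b, e in zip(x.split(4), embs):
    enc(b).backward(gradient=e.grad)           # (3) inject cached grads

print(torch.allclose(enc.weight.grad, ref_grad, atol=1e-6))  # expected: True
```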

@tomaarsen (Collaborator)

Thank you for clarifying! For reference, the 2.77x roughly matches my findings: I experienced about a 2.45x increase in training time. That could still be totally worth it to reach 4k or 8k negatives. I'll be merging this, and I'll include it as a key feature in the upcoming v2.3.0 release.

The release should be soon: essentially as soon as Nils gives me access to update sbert.net. Thanks a bunch for this very valuable addition - I'll be sharing it with some model builders that might be interested in it!

  • Tom Aarsen

@tomaarsen merged commit 009869d into UKPLab:master on Jan 16, 2024
9 checks passed
@kwang2049 (Member, Author)

Thanks for your help! Glad to hear that the PR makes a contribution :). Looking forward to future feedback from the community.
