Memory leak in SentenceTransformer.encode during the first ~10000 predictions #1795
Comments
@Dobiasd Hi, did you find any solution? I'm hitting this issue too, but I don't notice any upper bound; it always occupies as much memory as it can. =.= |
No, I did not find a solution. My workaround is to have a memory limit on the affected Kubernetes pods and to regularly have them be restarted. |
Hey @Dobiasd, have you tried to see if setting the inner torch models to eval() and wrapping the encoding in torch.no_grad() helps? |
@JoanFM No, I have not yet tried those things. Can you help me by showing me how to do them? |
Try this:

```python
import random
import string

import psutil
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
model = model.eval()

def random_string(length: int) -> str:
    return ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))

with torch.no_grad():
    print('iteration,memory_usage_in_MiB', flush=True)
    for iteration in range(99999999):
        model.encode([random_string(12345) for _ in range(200)])
        memory_usage_in_MiB = psutil.Process().memory_info().rss / (1024 * 1024)
        print(f'{iteration},{memory_usage_in_MiB}', flush=True)
```

In theory, only one of these changes may be enough. |
Thanks a lot! I just tested with your version, but sadly it's still leaking (output). 😐 |
Hey @Dobiasd, with this change, did you get the OOM in Kubernetes? |
No, so far I only tested with the minimal example (dockerized) as shown in my original post. |
But did you get an OOM? |
The memory consumption grows and grows, as shown in the linked output. |
Having the same issue when training the model with teacher embeddings. After a while it grows so much that the container kills the process. |
Having the same issue. I used psutil to check the memory info, and I found the memory leak may occur in BertModel. Unfortunately, I have no idea how to determine which line leads to the memory leak. I hope someone can help us! |
Looks like this happens when you pass an array of data for encoding. If you call model.encode many times with one element only (by using an outer loop) there's no memory spike. At least that's what my tests show. |
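(A minimal sketch of the one-element-at-a-time pattern described above, not from the thread; the model name and placeholder texts are assumptions:)

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
texts = ["some example sentence"] * 200  # placeholder data

# Call encode once per element instead of passing the whole list in one call.
embeddings = np.vstack([model.encode([text]) for text in texts])
print(embeddings.shape)
```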
I recently stumbled upon this problem and, based on @rossbg's advice, I partitioned the data before calling encode.

Batch function:

```python
def batched(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]
```

Encode process:

```python
import numpy as np
import tqdm
from sentence_transformers import SentenceTransformer

# create SentenceTransformer and set max_seq_length
embedding_model = SentenceTransformer("indobenchmark/indobert-large-p2")
embedding_model.max_seq_length = 512

# prepare dataset and calculate total iteration for tqdm
dataset = []  # Total data: 20336
embedding_chunks = []
max_batch = np.ceil(len(dataset) / 128)

# process each batch
for cb in tqdm.tqdm(batched(dataset, 128), total=max_batch):
    embedding_chunks.append(embedding_model.encode(cb, batch_size=128))

# stack all embeddings into one
all_embeddings = np.vstack(embedding_chunks)
```

Memory usage on Colab using a V100 GPU:

Based on my experiments, you can fine-tune the batch size accordingly.

Ref: |
I tested your method; it's still growing. Code:

```python
# memleak.py
import random
import string

import psutil
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

def random_string(length: int) -> str:
    return ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))

print('iteration,memory_usage_in_MiB', flush=True)
for iteration in range(99999999):
    a = [model.encode(random_string(12345)) for _ in range(200)]  # <- CHANGE HERE
    memory_usage_in_MiB = psutil.Process().memory_info().rss / (1024 * 1024)
    print(f'{iteration:02d}, {memory_usage_in_MiB:.2f}', flush=True)
```

Output:
|
I have seen the same growing memory issue with the same model, 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'. |
Hello! Although I can reproduce the results from this issue, the issue disappears if we change the type of input to just a bunch of words using NLTK:

```python
import random
import string

import psutil
from nltk.corpus import words
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

def random_words(length: int) -> str:
    return " ".join(random.sample(words.words(), k=length))

print('iteration,memory_usage_in_MiB', flush=True)
for iteration in range(20):
    model.encode([random_words(2000) for _ in range(200)])
    memory_usage_in_MiB = psutil.Process().memory_info().rss / (1024 * 1024)
    print(f'{iteration},{memory_usage_in_MiB}', flush=True)
```

In short, I'm struggling to see a real memory leak at this time. I'd love for you to prove me wrong, though - I would love to reduce memory issues for my users.
|
When using 'sentence-transformers/all-MiniLM-L6-v2', still using random_string(), the issue disappears too. It's very interesting. |
@tomaarsen Ah, I guess the "leak" disappears when using this fixed set of words (instead of random strings) because the set of possible tokens is limited that way. 👍 To give some context: I ran into this problem while we were processing chat messages from a large online (international) user base. These users tend to produce so many different words (including typos, etc.) that the memory usage does not stop growing at a reasonable amount. |
It may be that memory grew for me because I was using random strings with very long words. |
Interesting. I would love to avoid the memory issues with these odd edge cases as well. I remember a similar case where someone tried to do sentence segmentation on Wikipedia edits, but it would sometimes stop working - it ended up being caused by someone who edited a sequence of "aaaaaaa..." with a length of 10k, and the segmenter couldn't handle that 😄 |
Any news on this? When encoding text in a loop, I get a CUDA out-of-memory error on the GPU at some point, even if I reset torch's cache:
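(The cache-reset snippet itself wasn't preserved in this extract; a minimal sketch of what is presumably meant, assuming gc.collect() plus torch.cuda.empty_cache() between encode calls:)

```python
import gc

import torch

# attempted cleanup between encode calls
gc.collect()
torch.cuda.empty_cache()
```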
|
@0xtotem
|
Hi @tomaarsen, my GPU has 16 GB of VRAM and I am processing text that may contain code. For context, I'm encoding pieces in a loop and storing the result on disk (cache) for later usage. |
Hey @tomaarsen, I'm having this issue too. I'm using an A100 PCIe with 80 GB of VRAM with the Alibaba-NLP/gte-large-en-v1.5 model.

My data consists of subtitles, so it could be a tokenization issue. I haven't checked the data for cleanliness or rating; these are the first 128-ish rows. (I split them based on number of chars; each file should be around 600k.)

```python
import json

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)

with open('./test.json') as handler:
    data_one = json.load(handler)
with open('./test2.json') as handler:
    data_two = json.load(handler)

batches = [data_one, data_two]
for batch in batches:
    embeddings = model.encode([episode["caption"] for episode in batch], batch_size=32, show_progress_bar=True)
```

Usually, the first iteration works perfectly; it OOMs when it reaches the second file. If we switch the order of the files, same result: it works on the first file and OOMs on the second one.

I've tried torch.cuda.empty_cache, gc.collect, running it with no_grad, and running each batch in a subprocess, but still no luck. Especially on the subprocesses: normally, memory frees up when I kill the interpreter, but this doesn't happen with subprocesses (loading the model in a subprocess, encoding, and exiting the subprocess), and I OOM.

I'm ready to try fixes, thanks Tom! |
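(Not from the thread: a minimal sketch of the subprocess-isolation pattern described in the previous comment, i.e. loading the model in a child process, encoding, and letting the process exit so its memory is released; the model name and data are placeholders:)

```python
import multiprocessing as mp

def encode_batch(texts, queue):
    # Load the model inside the child process so that its memory
    # (including the CUDA context) is released when the process exits.
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
    queue.put(model.encode(texts, batch_size=32))

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    queue = ctx.Queue()
    proc = ctx.Process(target=encode_batch, args=(["placeholder caption"] * 64, queue))
    proc.start()
    embeddings = queue.get()  # fetch the result before join to avoid blocking on a full queue
    proc.join()
    print(embeddings.shape)
```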
I reduced the batch_size to 16 and reduced the maximum number of characters per iteration to 300k. OOMed on the 3rd batch. Oh? It works when I skip batches 3, 4, and 5. It OOMs on batch 10 again. Edit: To note, after an OOM, I need to restart the interpreter. Otherwise, it'll continue to OOM even when it otherwise wouldn't (on a freshly started interpreter). Could be a cleanup issue? |
+1 |
@TheOnlyWayUp I experimented with your code, and it seems like you're experiencing an OOM because of the really large maximum sequence length of 8192 tokens that the gte model allows. In Sentence Transformers, each batch is tokenized to the largest sequence in the batch, until the max_seq_length after which we truncate. In test1.json with a batch size of 32, you end up with these shapes for the inputs:
test2.json has these:
and test3.json has this one:
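(The shape listings themselves aren't reproduced in this extract; a sketch of how the per-batch tokenized shapes can be inspected, reusing the file and key names from the earlier comment:)

```python
import json

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)

with open('./test.json') as handler:
    data = [episode["caption"] for episode in json.load(handler)]

batch_size = 32
for start in range(0, len(data), batch_size):
    batch = data[start:start + batch_size]
    # Tokenization pads to the longest sequence in the batch and truncates at max_seq_length.
    features = model.tokenize(batch)
    print(features["input_ids"].shape)
```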
As you can imagine, each of these has wildly different memory requirements, and the largest of these shapes are what trigger the OOM. I can't explain why it happens somewhat arbitrarily, though.

Either way, if you want to use the model with the massive sequence length, then you can create some extremely long dummy text and encode with that. That way you'll know for sure that the tokenizer is reaching the maximum sequence length. Then you can tune your batch size to the highest it'll go without an OOM. Because your batch is as large as it can possibly be, this should be your upper limit of memory usage, e.g.:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)

data = ["a " * 100_000] * 1000
embeddings = model.encode(data, batch_size=32, show_progress_bar=True)
print(embeddings.shape)
```

For me, this gives:

My real recommendation is to reduce the maximum sequence length used by the model. Keep in mind that it's currently trying to compress 8192 tokens into an embedding of 1024 values: it'll always lose a lot of context. I'm rather confident that you'll get roughly equivalent performance with a sequence length of 2048 or even 512 (plus, it'll be MUCH faster):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
model.max_seq_length = 2048

data = ["a " * 100_000] * 1000
embeddings = model.encode(data, batch_size=32, show_progress_bar=True)
print(embeddings.shape)
```
> after an OOM, I need to restart the interpreter. Else, it'll continue to OOM even when it otherwise wouldn't

This is standard when getting a CUDA OOM: you always have to restart the interpreter then.
|
I just realized that this could be the issue: the number of tokens can suddenly be bigger than what you've seen before. You can verify whether this is the case by feeding the model a batch of very long texts at your batch size and seeing if that results in an OOM. E.g.:

```python
data = ["a " * 100_000] * 1000
embeddings = model.encode(data, batch_size=32, show_progress_bar=True)
```
|
@Dobiasd After some more digging, it looks like https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 uses a tokenizer that may be responsible for this. I think I'll open an issue on https://github.com/huggingface/tokenizers to see if this is expected behaviour or a bug.
|
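(Not from the thread: one way to probe that hypothesis is to call the model's tokenizer alone on the same kind of random strings and watch the process RSS; if it grows similarly, the tokenizer is the likely culprit. A sketch:)

```python
import random
import string

import psutil
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
tokenizer = model.tokenizer  # the underlying Hugging Face tokenizer

def random_string(length: int) -> str:
    return ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))

print('iteration,memory_usage_in_MiB', flush=True)
for iteration in range(50):
    tokenizer([random_string(12345) for _ in range(200)], truncation=True, max_length=128)
    memory_usage_in_MiB = psutil.Process().memory_info().rss / (1024 * 1024)
    print(f'{iteration},{memory_usage_in_MiB}', flush=True)
```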
Hey @tomaarsen, thank you for the detailed response! The code snippets are especially helpful; I'll use them to tune my batch sizes in the future.

The model card had a PyTorch example, so I tried it just to check. This is what I stuck to using:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_path = 'Alibaba-NLP/gte-large-en-v1.5'
device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, unpad_inputs=True, use_memory_efficient_attention=True).to(device)

def embed(texts):
    with torch.inference_mode():
        # Tokenize the input texts
        batch_dict = tokenizer(texts, max_length=8192, padding=True, truncation=True, return_tensors='pt').to(device)
        outputs = model(**batch_dict)
        embeddings = outputs.last_hidden_state[:, 0]
        return embeddings
```

The snippet took my footprint from 78 GB to 27 GB. And no OOMs, so I was able to embed my dataset in its entirety (even with the weird shapes). It might be the options (memory_efficient_attention; I disabled mixed precision) and the use of torch.inference_mode(). Does this line up with everything mentioned in this thread? The VRAM drop was a pleasant surprise. Thanks! |
@TheOnlyWayUp Since Sentence Transformers v3.0.0 it's possible to pass kwargs to the AutoModel used internally in Sentence Transformers, so I think you can reproduce your performance above with:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "Alibaba-NLP/gte-large-en-v1.5",
    trust_remote_code=True,
    model_kwargs={"unpad_inputs": True, "use_memory_efficient_attention": True},
)
```

(But also, if it's not broke, don't fix it 😄 Just sharing the word about the new model_kwargs option.)
|
The following minimal example repeatedly calls SentenceTransformer.encode on random strings of fixed length (12345) and a fixed number of strings (200), and it records the memory usage. For the first ~50 calls (~10000 predictions), the memory usage grows enormously.
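The exact script isn't reproduced in this extract; the following reconstruction is pieced together from the variants quoted in the comments above:

```python
import random
import string

import psutil
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

def random_string(length: int) -> str:
    return ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))

print('iteration,memory_usage_in_MiB', flush=True)
for iteration in range(99999999):
    model.encode([random_string(12345) for _ in range(200)])
    memory_usage_in_MiB = psutil.Process().memory_info().rss / (1024 * 1024)
    print(f'{iteration},{memory_usage_in_MiB}', flush=True)
```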
Output:
The larger the input strings, the higher the memory usage. But it always stops at this point.
I'm using no GPU, and the behavior can be reproduced with the following Dockerfile:

Is this a memory leak or intended behavior?