safetensor/mmap memory leak when per-layer weights are converted to other dtypes #34366
Looks that way to me. IMHO it looks like a pretty devastating one. 50%-100% RAM usage increase! All because you loaded the model as one data type, and then converted the weights to another (a perfectly normal thing to do). Randomly saw this issue and thought I'd take a small look at it. Not sure if I'm on the right track, but there are a few things I've noticed:
It might be, actually. Messing around with it, checking for memory usage right inside Module.convert shows memory growing for each parameter passed into the conversion. Based on that, I replaced MODEL_ID in the reproduction with `nlpcloud/instruct-gpt-j-fp16`:

# First pass outputs when converting `nlpcloud/instruct-gpt-j-fp16`
# Convert torch.float16->torch.bfloat16 before: 447.11 MiB, after: 1215.24 MiB, diff: 768.12 MiB
# Convert torch.float16->torch.float32 before: 445.13 MiB, after: 1597.51 MiB, diff: 1152.38 MiB
# Convert torch.bfloat16->torch.float16 before: 11986.68 MiB, after: 11986.80 MiB, diff: 0.12 MiB
# Convert torch.bfloat16->torch.float32 before: 11986.41 MiB, after: 12370.47 MiB, diff: 384.06 MiB
# Final memory usage after torch.float16->torch.bfloat16: 21952.75 MiB
# Final memory usage after torch.bfloat16->torch.float16: 11986.54 MiB

Since this model was uploaded with its weights in fp16, and I'm seeing the exact same behavior in reverse, I'm not convinced this is a model- or dtype-specific issue. My initial guess at what's going on is that loading the model's weights in their original tensor dtype vs. loading them in a different dtype goes through two different code paths:
tl;dr: Maybe the bug is due to differences in code between loading a pretrained model with the original dtype vs. with a new dtype (which needs to be converted)? The observations line up with this, at least. If I'm on the right track, hopefully it's something really obvious, but this seems like it might be an awful bug to hunt down. Not sure if safetensors has any involvement here, and not sure if the possibility of an underlying bug in pytorch can be ruled out yet either.
Updated the original test to manually release individual layer weights.
Based on this, the leaked memory is held somewhere between `model.model.layers` and `model.model` (see the [LEAK ...] annotations below):

import gc
import torch
from memory_profiler import memory_usage
from transformers import AutoModelForCausalLM
def mem(msg: str):
    gc.collect()
    m = memory_usage(-1, interval=0.1, timeout=1)[0]
    mm = f"{msg}. memory_usage: {m} MiB"
    print(mm)
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
mem("load model before")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
mem("load model after")
# print("model", model)
for i in range(0, 32):
mem("layer[{i}] to float32 before")
model.model.layers[i].to(torch.float32) # [LEAKING full layers[i] before .to() conversion]
mem("layer[{i}] to float32 after")
mem("del layer[{i}] before")
model.model.layers[i] = None # [LEAK NOT FIXED]
mem("del layer[{i}] after")
mem("del model.model.layers before")
del model.model.layers # [LEAK NOT FIXED]
mem("del model.model.layers after")
mem("del model.model before")
del model.model # [LEAK FIXED]
mem("del model.model after")
mem("del model before")
del model # [LEAK FIXED]
mem("del model after") |
fHey @Qubitium, thanks for the report and the reproducer. I hope that the following explanation can shred light on what's happening with the memory. In transformers, we use a torch trick (mmap + When we set the same
Note: the first forward pass will be a bit slower if you perform it on cpu, since we need to move the weights to the cpu. If you move the model to CUDA, you should also see the increase in memory. When we don't specify the same dtype as the original dtype, we actually have to convert the weights, so the weights have to be on the cpu and can't stay on disk. This is why, when you specify a dtype different from the one the checkpoint was saved in, you see the memory increase immediately at load time.
If you check the final memory for both, you will see that they are approximately the same.
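To make the two loading paths described above concrete, here is a minimal sketch (not the author's code; it reuses the gpt-j checkpoint from the earlier comment, and the exact numbers are machine-dependent):

```python
import gc
import torch
from memory_profiler import memory_usage
from transformers import AutoModelForCausalLM

MODEL_ID = "nlpcloud/instruct-gpt-j-fp16"  # checkpoint saved in float16

def rss():
    gc.collect()
    return memory_usage()[0]

# Path 1: torch_dtype matches the checkpoint dtype, so the weights can stay mmapped
# on disk and RSS right after from_pretrained stays low.
before = rss()
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="cpu")
print(f"same dtype:      +{rss() - before:.0f} MiB after load")
del model

# Path 2: torch_dtype differs from the checkpoint dtype, so every weight has to be
# read and converted on the cpu, and RSS grows by roughly the full converted model.
before = rss()
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32, device_map="cpu")
print(f"different dtype: +{rss() - before:.0f} MiB after load")
```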
Related PR for more context. As for why we are not able to release the memory, I don't have an answer yet.
@SunMarc Could you give a code example which demonstrates this? Specifically, what do you mean by the "final memory" in this situation? Do you mean the memory once the tensors are finally being accessed for something? If that is the case, unless I'm doing something wrong, my reproduction below appears to still be holding on to memory when I'm starting with bfloat16:

import gc
import torch
from memory_profiler import memory_usage
from transformers.models.llama.modeling_llama import LlamaForCausalLM
# FROM_DTYPE=bfloat16 leaks, all other values don't.
FROM_DTYPE = torch.bfloat16
TO_DTYPE = torch.float32
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_ID_1B = "meta-llama/Llama-3.2-1B-Instruct" # Smaller model = Actually can leak and not crash on my machine
def mem():
    gc.collect()
    m = memory_usage()[0]
    return m
m_before = mem()
print(f"load model before: {m_before:.2f} MiB")
model = LlamaForCausalLM.from_pretrained(
    MODEL_ID_1B,
    torch_dtype=FROM_DTYPE,
    device_map="cpu",
)
print(type(model))
m_after = mem()
print(f"load model after: {m_after:.2f} MiB")
print(f"{TO_DTYPE} layer before: {m_before:.2f} MiB, after: {m_after:.2f} MiB, diff: {m_after - m_before:.2f} MiB")
for i in range(0, len(model.model.layers)):
    m_before = mem()
    model.model.layers[i].to(TO_DTYPE)
    m_after = mem()
    print(f"{TO_DTYPE} layer before: {m_before:.2f} MiB, after: {m_after:.2f} MiB, diff: {m_after - m_before:.2f} MiB")
m_final = mem()
print(f"Memory utilization of converted model: {m_final:.2f} MiB")
# Do some ops to force the model to actually be realized in RAM
# This repro actually can't do a forward pass, since messing with the layers is messy
print("Dummy output", sum([torch.mean(p) for p in model.parameters()]))
m_final = mem()
print(f"Final memory usage of model: {m_final:.2f} MiB")
This might be an issue that specifically arises when you're directly messing around with internals rather than simply calling the usual top-level conversion. P.S.: I tried recreating this behavior with plain old pytorch by saving, loading, and tampering with weights like we're doing here, but I can't get this leak to happen with what I've written.
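For reference, here is roughly the kind of pure pytorch/safetensors experiment being described (a sketch with made-up shapes and file names, not the author's actual script; `assign=True` needs a recent pytorch):

```python
import gc
import torch
import torch.nn as nn
from memory_profiler import memory_usage
from safetensors.torch import save_file, load_file

def rss():
    gc.collect()
    return memory_usage()[0]

# Save a toy "sharded" checkpoint: a few Linear layers, one safetensors file each.
layers = nn.ModuleList([nn.Linear(4096, 4096, dtype=torch.bfloat16) for _ in range(4)])
for i, layer in enumerate(layers):
    save_file(layer.state_dict(), f"shard_{i}.safetensors")

# Reload, assigning the loaded tensors directly onto the parameters, then convert
# each layer in place and watch RSS, mimicking the per-layer .to() loop above.
fresh = nn.ModuleList([nn.Linear(4096, 4096, dtype=torch.bfloat16) for _ in range(4)])
for i, layer in enumerate(fresh):
    layer.load_state_dict(load_file(f"shard_{i}.safetensors"), assign=True)

for i, layer in enumerate(fresh):
    before = rss()
    layer.to(torch.float32)
    print(f"layer {i}: before {before:.1f} MiB, after {rss():.1f} MiB")
```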
That seems to be the case. Maybe the PR @SunMarc linked gives some hints. It could be due to an interaction with how the weights are passed in during loading. If there is a leak, hopefully it's something simple enough and entirely fixable in Python.
I think this is being sent as args to pytorch's loading code.
This logic makes sense; the main goal IMO should be trying to find what is being bound and where it's binding to. I'm having a hard time identifying the buffers in question, and I didn't get very far with my first attempts at tracing them. But I feel like it has to be either the underlying param.data tensor, or something that holds onto it, that is leaking. When I tried inspecting heap allocations, I didn't get anywhere conclusive. Do you have a specific part of the code you're suspecting? Ideally, we could narrow the search down there.
Just my naive thoughts on what might be going on.
We are currently bypassing this bug by doing the following monkeypatch in gptqmodel. Obviously the downside is that now all the model layer weights are not lazily mmapped but are loaded fully into memory on init.

from transformers import modeling_utils

def check_support_param_buffer_assignment(*args, **kwargs):
    return False
# Fix cpu memory leak.
# See https://github.com/huggingface/transformers/issues/34366
modeling_utils.check_support_param_buffer_assignment = check_support_param_buffer_assignment
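Presumably the patch has to be in place before `from_pretrained` runs, since that is when the check is consulted; a usage sketch (model id reused from the reproductions above, not part of the original workaround):

```python
import torch
from transformers import AutoModelForCausalLM, modeling_utils

# Apply the patch first, then load; every checkpoint weight is materialized in RAM
# instead of being lazily mmapped, which is the trade-off described above.
modeling_utils.check_support_param_buffer_assignment = lambda *args, **kwargs: False

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
```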
Hah, I'm pretty new to this as well. This seems like a pretty involved bug. A lot of the normal memory-leak detection strategies are falling short (probably because Pytorch is doing all of this stuff in C++ bindings, not Python, and the trace gets lost?). Good callout on the finalizers; maybe tracing weakrefs/finalizers/del calls would give a hint as to what is happening. I'm inclined to agree with your logic. The fact that deleting the model frees the memory strongly supports this. And if you were able to work around this issue by bypassing `check_support_param_buffer_assignment`, that points at the buffer-assignment path. If I can somehow reproduce this issue without transformers, that would also help isolate it. Yeah, I think cutting down to a much smaller model, deleting it, then tracing what happens to the model when it gets deleted would help narrow down the cause. How exactly to do that in Python, I'd have to figure out 🙃
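One way to do the weakref/finalizer tracing mentioned here might look like this (a sketch; the 1B model is used only to keep the experiment small):

```python
import gc
import weakref
import torch
from transformers import AutoModelForCausalLM

def watch_params(module: torch.nn.Module, tag: str) -> None:
    # Attach a finalizer to each parameter tensor; it prints once the tensor is collected.
    for name, p in module.named_parameters():
        weakref.finalize(p, print, f"[{tag}] parameter {name} was collected")

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16, device_map="cpu"
)
watch_params(model.model.layers[0], "layer0")
model.model.layers[0].to(torch.float32)
model.model.layers[0] = None
gc.collect()  # if nothing prints here, something is still referencing the old parameters
```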
Not sure if I'm on the right track, but I've noticed a couple of interesting things while hunting this down.
# Removing all the other modules from LlamaForCausalLM.model leads to no leak (but obviously breaks the model in other ways)
del model.model.embed_tokens
del model.model.norm
del model.model.rotary_emb
layers_to_keep = 1
for i in range(layers_to_keep, len(model.model.layers)):
    model.model.layers[i] = None
gc.collect()
print("Remaining modules:", model.model._modules) # Confirm we got rid of everything except our desired layers Actually, if you mess around with
# bfloat16 to float16. 800MB = 400MB float16 + 400MB bfloat16
before: 434.54 MiB, after: 1266.29 MiB, diff: 831.75 MiB
before: 1266.29 MiB, after: 2097.54 MiB, diff: 831.25 MiB
before: 2097.54 MiB, after: 2929.29 MiB, diff: 831.75 MiB
before: 2929.29 MiB, after: 3760.79 MiB, diff: 831.50 MiB
before: 3760.79 MiB, after: 4592.41 MiB, diff: 831.62 MiB
before: 4592.41 MiB, after: 5423.66 MiB, diff: 831.25 MiB
before: 5423.66 MiB, after: 6255.04 MiB, diff: 831.38 MiB
before: 6255.04 MiB, after: 7086.66 MiB, diff: 831.62 MiB
before: 7086.66 MiB, after: 7918.54 MiB, diff: 831.88 MiB
before: 7918.54 MiB, after: 8750.29 MiB, diff: 831.75 MiB
before: 8750.29 MiB, after: 9581.66 MiB, diff: 831.38 MiB
before: 9581.66 MiB, after: 10413.41 MiB, diff: 831.75 MiB
before: 10413.41 MiB, after: 11244.91 MiB, diff: 831.50 MiB
before: 11244.91 MiB, after: 12075.91 MiB, diff: 831.00 MiB
before: 12075.91 MiB, after: 12907.66 MiB, diff: 831.75 MiB
before: 12907.66 MiB, after: 13739.54 MiB, diff: 831.88 MiB
before: 13739.54 MiB, after: 14571.29 MiB, diff: 831.75 MiB
before: 14571.29 MiB, after: 15403.04 MiB, diff: 831.75 MiB
before: 15403.04 MiB, after: 16234.41 MiB, diff: 831.38 MiB
before: 16234.41 MiB, after: 17066.16 MiB, diff: 831.75 MiB
- before: 17066.16 MiB, after: 13129.53 MiB, diff: -3936.63 MiB # Free
before: 13129.53 MiB, after: 13961.03 MiB, diff: 831.50 MiB
before: 13961.03 MiB, after: 14792.65 MiB, diff: 831.62 MiB
before: 14792.65 MiB, after: 15554.78 MiB, diff: 762.12 MiB
before: 15554.78 MiB, after: 16385.53 MiB, diff: 830.75 MiB
before: 16385.53 MiB, after: 17217.53 MiB, diff: 832.00 MiB
before: 17217.53 MiB, after: 18048.90 MiB, diff: 831.38 MiB
before: 18048.90 MiB, after: 18880.40 MiB, diff: 831.50 MiB
before: 18880.40 MiB, after: 19711.53 MiB, diff: 831.12 MiB
before: 19711.53 MiB, after: 20540.40 MiB, diff: 828.88 MiB
before: 20540.40 MiB, after: 21371.90 MiB, diff: 831.50 MiB
- before: 21263.40 MiB, after: 17492.47 MiB, diff: -3770.93 MiB # Free

What we can see here is that the memory is being freed at some points, but it's not keeping up with the amount of new allocations.
# Tail output of https://gist.github.com/byi8220/81940a962f38244fb511f6457c617c3e
load model after: 15755.20 MiB
torch.float16 layer before: 393.54 MiB, after: 15755.20 MiB, diff: 15361.66 MiB # Original model is loaded into memory
torch.float16 layer before: 15755.20 MiB, after: 16164.08 MiB, diff: 408.88 MiB # +400 MB leak from conversion
torch.float16 layer before: 16164.08 MiB, after: 16580.08 MiB, diff: 416.00 MiB # +400 MB leak from conversion
torch.float16 layer before: 16580.08 MiB, after: 16995.83 MiB, diff: 415.75 MiB # +400 MB leak from conversion
torch.float16 layer before: 16995.83 MiB, after: 17411.95 MiB, diff: 416.12 MiB # +400 MB leak from conversion
Memory utilization of partially converted model: 17411.95 MiB
Memory utilization of fully converted model before: 17411.95 MiB, after: 15746.70 MiB, diff: -1665.26 MiB # Freed ~1600 MB
Final memory usage of model: 15746.70 MiB

These were all consistently done on my machine, with conversions of meta-llama/Llama-3.1-8B-Instruct from bfloat16 to float16. Feel free to confirm whether these results are consistent on your end. So what the hell does any of this mean? Honestly, I can't pinpoint a root cause. However, my theory based on these observations is that the bound state preventing otherwise "useless" data from being freed is being held on to by other layers/modules in the model. If that is true, then why that is the case is an open problem. Maybe the same memory buffer or manager object is responsible for multiple layers (and can only free once all its responsibilities are gone?), or maybe, since each layer depends on the previous one, it's holding references to stuff from dependents/dependencies? But then, if this is true, why can't I reproduce it in vanilla pytorch code? I have no idea, but at least it feels like a lead.
@Qubitium @SunMarc I might have a theory for why this is happening. Profiling the reproduction's heap usage with memray shows that ~15.0 GiB of heap memory comes from 4 large allocations made during model loading. The reproduction model (https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/tree/main) has sharded its weights into 4 safetensor shards. These could be the objects that are stuck in memory: even if an individual layer's old weights are no longer needed, the shard backing them can't be released while anything else still references it. Imo, this lines up with everything that has been observed so far, so this might be it. If so, I'm not exactly sure what a fix would be, besides having to split the model into smaller or more carefully constructed shards.
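The effect the shard theory relies on can be illustrated with plain torch, independent of transformers/safetensors: a small view is enough to keep an entire backing storage alive (a sketch; sizes are arbitrary and RSS numbers are machine-dependent):

```python
import gc
import torch
from memory_profiler import memory_usage

def rss():
    gc.collect()
    return memory_usage()[0]

base = rss()
big = torch.ones(1024, 1024, 1024, dtype=torch.uint8)  # ~1 GiB backing storage
print(f"after alloc:                   +{rss() - base:.0f} MiB")

view = big[:1]  # tiny view sharing the same storage
del big
print(f"after deleting the big tensor: +{rss() - base:.0f} MiB")  # still ~1 GiB resident

del view
print(f"after deleting the view:       +{rss() - base:.0f} MiB")  # now the storage can be freed
```

By analogy, if each layer's tensors are views into a per-shard mapping, converting or dropping one layer would not release anything as long as other layers still reference the same shard.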
@byi8220 I think you found it. It makes sense if we can find the code in torch/safetensors loading that does the mmap and verify that the module/layer weights are actually read from a single model-scoped (or per-safetensors-file) mmap slice. I had assumed, most likely wrongly now based on your tests, that each layer/module was a separate mmap slice.
Here is a screenshot containing the stack trace leading up to the allocated memory that I am talking about. Note that we see 4 allocations made by _load_pretrained_model, corresponding to the 4 safetensor files that are opened. (Ignore the broken typesetting in the pic below, seems to be a bug with saving screenshots.) If you are willing to install memray, you should be able to reproduce this profile yourself.
The last python call in this trace occurs at: https://github.com/huggingface/safetensors/blob/main/bindings/python/py_src/safetensors/torch.py#L313 This appears to be the python handle for opening a safetensor file. In my repro, I verified that these allocations do come through this call. Anything deeper and it goes into Rust code in safetensors, and the code at the very bottom of this trace is Pytorch's internal C++ library, ATen, which I believe ultimately instantiates a MapAllocator object in C++: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/MapAllocator.cpp#L64 I am completely unfamiliar with ATen, and not very familiar with POSIX file handling conventions, so I can't comment on its behavior down there. I know most of this is fairly high-level and handwavy, but I think it's strong evidence for what's happening. If you're not convinced, we could try to do more formal memory tracing to pinpoint every single mmap alloc/dealloc. I've never really done that before, so I don't know whether that's going to be a lot of work or not. If we're feeling extra lazy, we could probably just confirm this by constructing highly sharded model checkpoints, loading them, and seeing whether memory usage is correlated with sharding.
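For anyone who wants to reproduce the profile, memray also has a Python API, so a sketch along these lines should produce a similar trace (the 1B model keeps the capture small; `native_traces` records the native frames):

```python
import memray
import torch
from transformers import AutoModelForCausalLM

with memray.Tracker("repro_heap.bin", native_traces=True):
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16, device_map="cpu"
    )
    for layer in model.model.layers:
        layer.to(torch.float32)

# Then inspect the capture, e.g. with `memray flamegraph repro_heap.bin`.
```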
Just quickly tried out limiting the saved checkpoint shard sizes to 400MB (via max_shard_size when re-saving the checkpoint). As for a fix? I can't think of an elegant solution, but thoughts off the top of my head:
None of these feel particularly great to me.
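For reference, a sketch of the re-sharding experiment mentioned above (the local path is made up; `max_shard_size` is the standard `save_pretrained` argument):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16, device_map="cpu"
)
model.save_pretrained("llama-3.1-8b-400mb-shards", max_shard_size="400MB")
del model

# Reload from the re-sharded checkpoint and repeat the per-layer .to(torch.float32)
# loop from the earlier reproductions; if the shard theory holds, the amount of
# "stuck" memory should track the smaller shard size.
model = AutoModelForCausalLM.from_pretrained(
    "llama-3.1-8b-400mb-shards", torch_dtype=torch.bfloat16, device_map="cpu"
)
for layer in model.model.layers:
    layer.to(torch.float32)
```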
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Not stale.
System Info
While working on GPTQModel, which does GPTQ quantization of hf models by loading each layer onto the gpu, quantizing it, and then moving the layer back to the cpu for vram reduction, we noticed a huge cpu memory leak, equal in size to the layer weights in the source dtype, when converting a layer between dtypes. The layer stays in cpu memory, leaks, and we are unable to free it. The memory stays until the program ends. The leak happens whether we do the dtype conversion on the cpu or to the gpu.
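For context, the per-layer round trip looks roughly like this (an illustrative sketch, not GPTQModel's actual code; the leak shows up around the dtype conversion step):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16, device_map="cpu"
)
for layer in model.model.layers:
    layer.to("cuda", dtype=torch.float16)  # move one layer to the gpu for quantization work
    # ... quantize the layer here ...
    layer.to("cpu")                        # move it back to keep vram usage low
```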
Is this an internal memory leak, or are we doing something wrong, or do we have the wrong expectation of how transformers/torch handles cpu tensor memory?
Reproducing code is on cpu only (to gpu has the same bug). No gpu is necessary; just load the model as `bfloat16`, do dtype transitions, and observe the memory leak.

Who can help?
@ArthurZucker @SunMarc @MekkCyber
Information

Tasks

Reproduction
Env:
Run the above code and watch the cpu memory usage grow linearly after each loop by 1200MB instead of the expected 800MB (a 400MB leak per layer, equal to the size of the layer in `bfloat16` before conversion). gc() does not help.

Expected behavior
Constant memory usage, equal to the size of the model weights for the chosen dtype.
UPDATE: Looks like the leak is isolated to models/layers loaded as `torch.bfloat16`. No memory leak is observed if the model/layer is first loaded as `torch.float16` or `torch.float32` and then converted to other dtypes.