
safetensor/mmap memory leak when per-layer weights are converted to other dtypes #34366

Open · Qubitium opened this issue Oct 24, 2024 · 18 comments

@Qubitium (Contributor) commented Oct 24, 2024

System Info

While working on GPTQModel, which performs GPTQ quantization of HF models by loading each layer onto the GPU, quantizing it, and then moving the layer back to CPU to reduce VRAM usage, we noticed a huge CPU memory leak equal to the size of the layer weights in the original dtype whenever a layer is converted to another dtype. The layer stays in CPU memory, leaks, and we are unable to free it; the memory is only released when the program ends. The leak happens whether we do the dtype conversion on CPU or while moving to GPU.

Is this an internal memory leak, or are we doing something wrong or holding the wrong expectations about how transformers/torch handles CPU tensor memory?

The reproduction code below runs on CPU only (converting to GPU shows the same bug). No GPU is necessary: just load the model as bfloat16, perform the dtype conversions, and observe the memory leak.

Who can help?

@ArthurZucker @SunMarc @MekkCyber

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Env:

AMD Zen3
Ubuntu: 22.04

Name: torch
Version: 2.5.0

Name: accelerate
Version: 1.0.1

Name: transformers
Version: 4.45.2

import gc

import torch
from memory_profiler import memory_usage
from transformers import AutoModelForCausalLM

def mem(msg: str):
    gc.collect()
    m = memory_usage()[0]
    mm = f"{msg}. memory_usage: {m} MiB"
    print(mm)

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

mem("load model before")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16, 
    device_map="cpu",
)
mem("load model after")
print("model", model)

for i in range(0, 32):
    mem("to float32 before") # ref point: each layer is ~400MB in bfloat16
    model.model.layers[i].to(torch.float32)
    mem("to float32 after").  # <--- +1200MB ram == leak 400MB (1200MB - 800MB (float32)). 

Run the above code and watch the CPU memory usage grow by 1200MB per loop iteration instead of the expected 800MB (a 400MB leak per layer, equal to the size of the layer in bfloat16 before conversion). gc.collect() does not help.

Expected behavior

Constant memory usage, equal to the size of the model weights in the target dtype.

UPDATE: It looks like the leak is isolated to models/layers loaded as torch.bfloat16. No memory leak is observed if the model/layer is first loaded as torch.float16 or torch.float32 and then converted to other dtypes.

@Qubitium Qubitium added the bug label Oct 24, 2024
@Qubitium Qubitium changed the title Cpu memory leak when layer/weights loaded in non-fp32 datatypes are converted to other dtypes. Cpu memory leak when layer/weights loaded in bfloat16 datatypes are converted to other dtypes. Oct 24, 2024
@ArthurZucker ArthurZucker added Core: Modeling Internals of the library; Models. contributions-welcome labels Oct 24, 2024
@byi8220 (Contributor) commented Oct 24, 2024

Is this an internal memory leak

Looks that way to me. IMHO it looks like a pretty devastating one. 50%-100% RAM usage increase! All because you loaded the model as one data type, and then converted the weights to another (a perfectly normal thing to do).

Randomly saw this issue and thought I'd take a small look at it. Not sure if I'm on the right track, but there are a few things I've noticed:

  1. I'm not sure if it's an issue with Tensor.to() itself*, since a first attempt at a PyTorch-only reproduction behaved as expected.
  2. The model used in the repro, meta-llama/Llama-3.1-8B-Instruct, specifies bfloat16 as its torch_dtype in config.json and has safetensor parameters in bf16.
  3. This issue persists if I replace the torch_dtype with "auto".
  4. This issue persists even if I run the model once (putting the entire thing in memory, if it wasn't already) before the conversions.

*It might be, actually. Messing around with it, checking memory usage right in Module.convert shows each parameter passed into the convert fn leaving with 2x the memory we expect. I'm still not sure why I can only observe this happening with transformers.

Based on that, I replaced MODEL_ID in the reproduction with nlpcloud/instruct-gpt-j-fp16 instead.

# First pass outputs when converting `nlpcloud/instruct-gpt-j-fp16`
# Convert torch.float16->torch.bfloat16 before: 447.11 MiB, after: 1215.24 MiB, diff: 768.12 MiB
# Convert torch.float16->torch.float32 before: 445.13 MiB, after: 1597.51 MiB, diff: 1152.38 MiB
# Convert torch.bfloat16->torch.float16 before: 11986.68 MiB, after: 11986.80 MiB, diff: 0.12 MiB
# Convert torch.bfloat16->torch.float32 before: 11986.41 MiB, after: 12370.47 MiB, diff: 384.06 MiB

# Final memory usage after torch.float16->torch.bfloat16: 21952.75 MiB
# Final memory usage after torch.bfloat16->torch.float16: 11986.54 MiB

Since this model was uploaded with its weights in fp16, and I'm seeing the exact same behavior in reverse, I'm not convinced this is a model- or dtype-specific issue.

My initial guess as to what's going on is that loading the model's weights in their original dtype vs loading them in a different dtype goes through two different code paths:

  1. When specifying a torch_dtype that is not the model's original type specified in its config.json, everything is eagerly converted/loaded into memory and the model works as intended from then on.
  2. When specifying torch_dtype as the original type, something is somehow holding onto the original weight tensors and doesn't let go of them. (Maybe check where they are first loaded into memory?)

tl;dr: Maybe the bug is due to differences in code between loading a pretrained model with the original dtype vs with a new dtype (which needs to be converted)?

The observations line up with this at least. If I'm on the right track hopefully it's something really obvious, but this seems like it might be an awful bug to hunt down.

Not sure if safetensors has any involvement here and not sure if the possibility of an underlying bug in pytorch can be ruled out yet either.

@Qubitium (Contributor Author) commented Oct 25, 2024

Updated the original test to manually release individual layer weights (model.model.layers[i] set to None) after the dtype conversion, then release model.model.layers, and finally release the entire model after the loop ends.

  1. [LEAK] Releasing the model.model.layers[i] does not release the leaked memory.
  2. [LEAK] Releasing the model.model.layers object does not release the leaked memory. Wow.
  3. [NO LEAK] Releasing the entire model.model object does release the leaked memory.
  4. [NO LEAK] Releasing the entire model object does release the leaked memory.

Based on this, the leak appears to be something in the model.model scope (outside model.model.layers) holding on to layer weights beyond the layers[] objects?

import gc

import torch
from memory_profiler import memory_usage
from transformers import AutoModelForCausalLM

def mem(msg: str):
    gc.collect()
    m = memory_usage(-1, interval=0.1, timeout=1)[0]
    mm = f"{msg}. memory_usage: {m} MiB"
    print(mm)

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
mem("load model before")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
mem("load model after")
# print("model", model)

for i in range(0, 32):
    mem("layer[{i}] to float32 before")
    model.model.layers[i].to(torch.float32) # [LEAKING full layers[i] before .to() conversion]
    mem("layer[{i}] to float32 after")

    mem("del layer[{i}] before")
    model.model.layers[i] = None # [LEAK NOT FIXED]
    mem("del layer[{i}] after")

mem("del model.model.layers before")
del model.model.layers # [LEAK NOT FIXED]
mem("del model.model.layers after")

mem("del model.model before")
del model.model # [LEAK FIXED]
mem("del model.model after")

mem("del model before")
del model # [LEAK FIXED]
mem("del model after")

@SunMarc (Member) commented Oct 25, 2024

Hey @Qubitium, thanks for the report and the reproducer. I hope that the following explanation can shed light on what's happening with the memory.

In transformers, we use a torch trick (mmap + assign=True) to make loading really fast! Basically, the weights are mmapped and we assign them to the parameters/buffers in the model (no copy involved). I'll leave this PyTorch doc for reference.
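
Here is a rough sketch of that trick in plain PyTorch, so it's clear what is (and isn't) copied. This is not transformers' exact code; it assumes torch >= 2.1 and uses an example checkpoint path:

import torch
import torch.nn as nn

# Save an example checkpoint (stand-in for a real model's state dict).
src = nn.Linear(1024, 1024, dtype=torch.bfloat16)
torch.save(src.state_dict(), "linear.pt")

# torch.load(..., mmap=True) maps the file instead of copying it into RAM;
# load_state_dict(..., assign=True) makes the module's parameters point
# directly at those mmapped tensors instead of copying into new allocations.
state_dict = torch.load("linear.pt", mmap=True, weights_only=True)
dst = nn.Linear(1024, 1024, dtype=torch.bfloat16)
dst.load_state_dict(state_dict, assign=True)  # no copy: params stay file-backed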

When we set the same torch_dtype as the original model, the weights are not really loaded onto the cpu. This is why, when we set torch_dtype = bfloat16, the first output is only 1GB: the weights are not on the cpu yet. This is also why loading with the same dtype is really fast.

to float32 before. memory_usage: 1002.46875 MiB

Note: The first forward pass will be a bit slower if you perform it on cpu, since we need to move the weights to the cpu first. If you move the model to CUDA, you should also see the memory increase.

When we don't specify the same dtype as the original dtype, we actually have to convert the weights, so the weights have to be on the cpu and can't stay on disk. This is why, when you specify torch_dtype = float16, the first output is around 16GB, and why we see increments of 400MB when a layer gets converted.

to float32 before. memory_usage: 16299.421875 MiB

If you check the final memory for both, you will see that they are approximately the same.

to float32 after. memory_usage: 31472.51171875 MiB -> bfloat16
to float32 after. memory_usage: 30131.54296875 MiB -> float16

Related PR for more context

As for why we are not able to release the memory, I don't have an answer yet.

@byi8220 (Contributor) commented Oct 25, 2024

@SunMarc Could you give a code example which demonstrates this? Specifically,

If you check the final memory for both, you will see that they are approximately the same.

What do you mean by the "final memory" in this situation? Do you mean the memory once the tensors are finally being accessed for something?

If this is the case, unless I'm doing something wrong, my reproduction below appears to still be holding on to memory when I'm starting with bfloat16:

import gc
import torch
from memory_profiler import memory_usage
from transformers.models.llama.modeling_llama import LlamaForCausalLM

# FROM_DTYPE=bfloat16 leaks, all other values don't. 
FROM_DTYPE = torch.bfloat16
TO_DTYPE = torch.float32
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_ID_1B = "meta-llama/Llama-3.2-1B-Instruct" # Smaller model = Actually can leak and not crash on my machine

def mem():
    gc.collect()
    m = memory_usage()[0]
    return m

m_before = mem()
print(f"load model before: {m_before:.2f} MiB")

model = LlamaForCausalLM.from_pretrained(
    MODEL_ID_1B,
    torch_dtype=FROM_DTYPE, 
    device_map="cpu",
)

print(type(model))

m_after = mem()

print(f"load model after: {m_after:.2f} MiB")
print(f"{TO_DTYPE} layer before: {m_before:.2f} MiB, after: {m_after:.2f} MiB, diff: {m_after - m_before:.2f} MiB")

for i in range(0, len(model.model.layers)):
    m_before = mem()
    model.model.layers[i].to(TO_DTYPE)
    m_after = mem() 
    print(f"{TO_DTYPE} layer before: {m_before:.2f} MiB, after: {m_after:.2f} MiB, diff: {m_after - m_before:.2f} MiB")

m_final = mem()
print(f"Memory utilization of converted model: {m_final:.2f} MiB")

# Do some ops to force the model to actually be realized in RAM
# This repro actually can't do a forward pass, since messing with the layers is messy
print("Dummy output", sum([torch.mean(p) for p in model.parameters()]))

m_final = mem()
print(f"Final memory usage of model: {m_final:.2f} MiB")

As for why we are not able to release the memory, I don't have an answer yet.

This might be an issue that specifically arises when you're directly messing around with internals. Simply calling model.to(torch.float32) on a loaded model seems to behave well.

P.S.: I tried recreating this behavior with plain old pytorch by saving, loading, and tampering with weights like we're doing here, but I can't get this leak to happen with what I've written.
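
For concreteness, a plain-PyTorch attempt along these lines might look like the sketch below (hypothetical layer sizes and a temp save path, using torch.load(mmap=True) plus load_state_dict(assign=True) to mimic the transformers loading path); it does not reproduce the leak for me:

import gc
import os
import tempfile

import torch
import torch.nn as nn
from memory_profiler import memory_usage

# Sketch of a transformers-free repro attempt: mmap-load a bf16 state dict with
# assign=True, convert half the modules to float32, and watch process RSS.
model = nn.Sequential(*[nn.Linear(2048, 2048, dtype=torch.bfloat16) for _ in range(8)])
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), path)

fresh = nn.Sequential(*[nn.Linear(2048, 2048, dtype=torch.bfloat16) for _ in range(8)])
fresh.load_state_dict(torch.load(path, mmap=True, weights_only=True), assign=True)

for i in range(4):
    fresh[i].to(torch.float32)
    gc.collect()
    print(f"after converting layer {i}: {memory_usage()[0]:.1f} MiB")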

@byi8220 (Contributor) commented Oct 25, 2024

@Qubitium

Based on this, the leakage is between model.model and model.model.layers scope holding on to layer weights beyond the layers[] object?

That seems to be the case. Maybe the PR @SunMarc linked gives some hints. It could be due to an interaction caused by passing in assign_to_params_buffers, which IIUC finds itself in https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L2295

If there is a leak, hopefully it's something simple enough and entirely fixable in python.

@Qubitium

This comment was marked as outdated.

@byi8220 (Contributor) commented Oct 25, 2024

@Qubitium

Not even sure where assign_to_params_buffers is used.

I think this is being sent as args to pytorch's _load_from_state_dict function (Yes, accessing another codebase's private functions, very evil).

level maybe whatever low level buffers are binding to the model context even if the buffer is actually on the layer level

This logic makes sense; the main goal IMO should be to find what is being bound and where it's binding to. I'm having a hard time identifying the buffers in question, and I didn't get very far trying to objgraph the layers/parameters.

But I feel like it has to be either the underlying param.data tensor, or something that holds onto it, that is leaking. When I tried inspecting heap allocations in torch.nn.Module, it seemed that every call to Module.to() ended up with double the expected memory usage.

Do you have a specific part of the code you're suspecting?

Ideally, we could objgraph the entirety of model.model and try to find something in that mess that's out of line, if we're pretty confident it's scoped to something with the lifetime of the model itself. (Maybe something like deleting all the layers and investigating what's left?)

@Qubitium (Contributor Author) commented Oct 26, 2024

@byi8220

Do you have a specific part of the code you're suspecting?

  1. We know that if mmap buffers are used, then pytorch must eventually call unmap. Every language I have seen with mmap requires an explicit unmap; there is no such thing as auto-release by itself, unless helpers such as ContextManagers are involved.
  2. So any object holding an mmap must hook into gc() via finalizer-related code so the unmap can happen.
  3. If 1/2 are true, then the mapped tensors must either call the unmap code manually before gc or hook into a Python object finalizer.
  4. Python's ContextManager + with is semi-automatic (magic) code that can also act like a finally/finalizer to me.
  5. If a gc finalizer is used to unmap, where is the code? If a ContextManager is used, how is the unmap operation hooked? Is the ContextManager being held within model.model after the scope of the with block ends?

Just my naive thoughts on what may cause the leaks, and I think it is directly related to SunMarc's referenced PR. I have never looked at the Python gc or finalizers, or inside PyTorch code, so I'm just curious how everything is glued together magically in the Python world. (A sketch of the finalizer idea from points 2/3 follows below.)
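
To make points 2/3 concrete, here is a minimal pure-Python sketch (not PyTorch's actual mechanism, just an illustration with the standard mmap and weakref modules) of how an mmapped buffer can be tied to an object's lifetime with a gc-driven finalizer:

import mmap
import tempfile
import weakref

class MappedWeights:
    # Toy stand-in for a tensor backed by an mmapped file.
    def __init__(self, path):
        self._f = open(path, "rb")
        self._mm = mmap.mmap(self._f.fileno(), 0, access=mmap.ACCESS_READ)
        # weakref.finalize acts as the "unmap hook": it runs when this object
        # is garbage collected (or at interpreter exit) and closes the mapping.
        weakref.finalize(self, self._close, self._mm, self._f)

    @staticmethod
    def _close(mm, f):
        mm.close()  # munmap happens here
        f.close()

with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"\0" * 1024)  # stand-in for a safetensors shard

w = MappedWeights(tmp.name)
del w  # dropping the last reference triggers the finalizer -> munmap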

@Qubitium (Contributor Author)

We are currently bypassing this bug with the following monkeypatch in gptqmodel. The obvious downside is that the model layer weights are no longer lazily mmapped but are all loaded into memory on init.

from transformers import modeling_utils

def check_support_param_buffer_assignment(*args, **kwargs):
    return False


# Fix cpu memory leak.
# See https://github.com/huggingface/transformers/issues/34366
modeling_utils.check_support_param_buffer_assignment = check_support_param_buffer_assignment
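
For anyone else applying this: the patch has to be in place before from_pretrained runs, since that is where the check is consulted. A minimal sketch using the repro's model ID:

import torch
from transformers import AutoModelForCausalLM, modeling_utils

# Disable the mmap + assign fast path so weights are copied into regular
# cpu tensors at load time (workaround for this issue).
modeling_utils.check_support_param_buffer_assignment = lambda *args, **kwargs: False

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
# Per-layer .to(dtype) conversions no longer leak, at the cost of eager loading.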

@byi8220 (Contributor) commented Oct 26, 2024

I have never looked at the Python gc or finalizers, or inside PyTorch code, so I'm just curious how everything is glued together magically in the Python world.

Hah, I'm pretty new to this as well. This seems like a pretty involved bug. A lot of the normal memory leak detection strategies are falling short (probably because Pytorch is doing all of this stuff in C++ bindings, not python, and the trace gets lost?).

Good callout on the finalizers, maybe tracing weakrefs/finalizes/del calls would give a hint as to what is happening.

I'm inclined to agree with your logic. The fact that deleting the model frees the memory strongly supports this.

And if you were able to work around this issue by bypassing check_support_param_buffer_assignment, that suggests either 1. the function check_support_param_buffer_assignment or transformers._load_state_dict_into_model is somehow binding the state (not sure how, but not ruling it out), or, what I think is more likely, 2. the way transformers calls module._load_from_state_dict(*args) (pytorch code) is leading to this.

If I can somehow reproduce this issue without using transformers at all, this would at least absolve transformers as the root cause of this issue and I could raise a bug there.

Yeah, I think cutting down to a much smaller model, deleting it, then tracing what happens to the model when it gets deleted would help narrow down the cause. How exactly to do that in python, I'd have to figure out 🙃

@byi8220 (Contributor) commented Oct 28, 2024

Not sure if I'm on the right track, but I've noticed a couple of interesting things while hunting this down.

  1. Firstly, although deleting the model layers in llama 3.1 8B instruct didn't free the memory, I managed to get the leak to go away if we delete all of the other modules within model:
# Removing all the other modules from LlamaForCausalLM.model leads to no leak (but obviously breaks the model in other ways)
del model.model.embed_tokens
del model.model.norm 
del model.model.rotary_emb
layers_to_keep = 1
for i in range(layers_to_keep, len(model.model.layers)):
    model.model.layers[i] = None
gc.collect()
print("Remaining modules:", model.model._modules) # Confirm we got rid of everything except our desired layers

Actually, if you mess around with layers_to_keep, you can get this model to properly free memory even without deleting the other modules.

  2. Secondly, even if you don't delete any layers, sometimes the model does free memory, but less than what it is leaking. If we use your original reproduction to convert from bfloat16 to float16, we get the following memory allocation pattern:
# bfloat16 to float16. 800MB = 400MB float16 + 400MB bfloat16
before: 434.54 MiB, after: 1266.29 MiB, diff: 831.75 MiB
before: 1266.29 MiB, after: 2097.54 MiB, diff: 831.25 MiB
before: 2097.54 MiB, after: 2929.29 MiB, diff: 831.75 MiB
before: 2929.29 MiB, after: 3760.79 MiB, diff: 831.50 MiB
before: 3760.79 MiB, after: 4592.41 MiB, diff: 831.62 MiB
before: 4592.41 MiB, after: 5423.66 MiB, diff: 831.25 MiB
before: 5423.66 MiB, after: 6255.04 MiB, diff: 831.38 MiB
before: 6255.04 MiB, after: 7086.66 MiB, diff: 831.62 MiB
before: 7086.66 MiB, after: 7918.54 MiB, diff: 831.88 MiB
before: 7918.54 MiB, after: 8750.29 MiB, diff: 831.75 MiB
before: 8750.29 MiB, after: 9581.66 MiB, diff: 831.38 MiB
before: 9581.66 MiB, after: 10413.41 MiB, diff: 831.75 MiB
before: 10413.41 MiB, after: 11244.91 MiB, diff: 831.50 MiB
before: 11244.91 MiB, after: 12075.91 MiB, diff: 831.00 MiB
before: 12075.91 MiB, after: 12907.66 MiB, diff: 831.75 MiB
before: 12907.66 MiB, after: 13739.54 MiB, diff: 831.88 MiB
before: 13739.54 MiB, after: 14571.29 MiB, diff: 831.75 MiB
before: 14571.29 MiB, after: 15403.04 MiB, diff: 831.75 MiB
before: 15403.04 MiB, after: 16234.41 MiB, diff: 831.38 MiB
before: 16234.41 MiB, after: 17066.16 MiB, diff: 831.75 MiB
- before: 17066.16 MiB, after: 13129.53 MiB, diff: -3936.63 MiB # Free
before: 13129.53 MiB, after: 13961.03 MiB, diff: 831.50 MiB
before: 13961.03 MiB, after: 14792.65 MiB, diff: 831.62 MiB
before: 14792.65 MiB, after: 15554.78 MiB, diff: 762.12 MiB
before: 15554.78 MiB, after: 16385.53 MiB, diff: 830.75 MiB
before: 16385.53 MiB, after: 17217.53 MiB, diff: 832.00 MiB
before: 17217.53 MiB, after: 18048.90 MiB, diff: 831.38 MiB
before: 18048.90 MiB, after: 18880.40 MiB, diff: 831.50 MiB
before: 18880.40 MiB, after: 19711.53 MiB, diff: 831.12 MiB
before: 19711.53 MiB, after: 20540.40 MiB, diff: 828.88 MiB
before: 20540.40 MiB, after: 21371.90 MiB, diff: 831.50 MiB
- before: 21263.40 MiB, after: 17492.47 MiB, diff: -3770.93 MiB # Free

What we can see here is that the memory is being freed at some points, but it's not keeping up with the amount of new allocations.

  3. Finally, if after we perform the partial conversions that cause the leak we then convert the entire model, the leak appears to go away.
# Tail output of https://gist.github.com/byi8220/81940a962f38244fb511f6457c617c3e
load model after: 15755.20 MiB
torch.float16 layer before: 393.54 MiB, after: 15755.20 MiB, diff: 15361.66 MiB # Original model is loaded into memory
torch.float16 layer before: 15755.20 MiB, after: 16164.08 MiB, diff: 408.88 MiB # +400 MB leak from conversion
torch.float16 layer before: 16164.08 MiB, after: 16580.08 MiB, diff: 416.00 MiB # +400 MB leak from conversion
torch.float16 layer before: 16580.08 MiB, after: 16995.83 MiB, diff: 415.75 MiB # +400 MB leak from conversion
torch.float16 layer before: 16995.83 MiB, after: 17411.95 MiB, diff: 416.12 MiB # +400 MB leak from conversion
Memory utilization of partially converted model: 17411.95 MiB
Memory utilization of fully converted model before: 17411.95 MiB, after: 15746.70 MiB, diff: -1665.26 MiB # Freed ~1600 MB
Final memory usage of model: 15746.70 MiB

These were all consistently done on my machine, with conversions of meta-llama/Llama-3.1-8B-Instruct from bfloat16 to float16. Feel free to confirm if these results are consistent on your end.

So what the hell does any of this mean? Honestly, I can't pinpoint a root cause. However, my theory based on these observations is that the bound state preventing otherwise "useless" data from being freed is being held on to by other layers/modules in the model.

If that is true, then why that is the case is an open problem. Maybe the same memory buffer or manager object is responsible for multiple layers (and can only free once all its responsibilities are gone?), or maybe, since each layer depends on the previous, it's holding some references to stuff from dependents/dependencies? But then if this is true, why can't I reproduce it in vanilla pytorch code?

I have no idea, but at least it feels like a lead.

@byi8220 (Contributor) commented Oct 29, 2024

@Qubitium @SunMarc I might have a theory for why this is happening.

Profiling the reproduction's heap usage with memray shows that ~15.0 GiB of heap memory comes from 4 mmap allocations at at::MapAllocator::MapAllocator(at::WithFd, c10::basic_string_view<char>, int, int, unsigned long).

The reproduction model (https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/tree/main) has sharded its weights into 4 safetensor shards. These could be the objects that are stuck in memory. E.g. even if model.model.layers[0] is converted and no longer references memory in model-00001-of-00004.safetensors, if other objects are referencing that file (such as model.embed_tokens), the entire file must remain in memory.

Imo, this lines up with everything that has been observed so far, so this might be it. If so, I'm not exactly sure what a fix would be, besides having to split the model into smaller or more carefully constructed shards.
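
As a toy analogy of that theory (plain tensors that share one backing storage, the way two layers share one shard; no safetensors involved):

import torch

# "backing" plays the role of a single mmapped shard; layer_a and layer_b are
# two layers whose weights happen to live in that same shard.
backing = torch.zeros(2 * 1024 * 1024, dtype=torch.bfloat16)  # ~4 MiB
layer_a = backing[: 1024 * 1024]   # view into the shard
layer_b = backing[1024 * 1024 :]   # another view into the same shard
del backing

layer_a = layer_a.to(torch.float32)  # new storage for layer_a...
# ...but the whole original backing storage stays alive, because layer_b is
# still a view of it:
print(layer_b.untyped_storage().nbytes())  # full shard size, not half
del layer_b  # only now can the shared backing be freed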

@Qubitium (Contributor Author) commented Oct 29, 2024

@byi8220 I think you found it. It makes sense if we can find the code in torch/safetensors loading that does the mmap, and verify that the module/layer weights are actually reading from a single model-scoped (or per-safetensors-file) mmap slice. I had assumed, most likely wrongly now based on your tests, that each layer/module was a separate mmap slice.

@byi8220 (Contributor) commented Oct 29, 2024

@Qubitium

It makes sense, if we can find the code in torch/safetensor loading that does the mmap

Here is a screenshot containing the stack trace leading up to the allocated memory that I am talking about:

[Screenshot: TreeApp_2024-10-29T11_39_24_931288 — memray stack trace leading to the mmap allocations]

Note that we see 4 allocations made by _load_pretrained_model, corresponding to the 4 safetensor files that are opened. (Ignore the broken typesetting in the pic below; it seems to be a bug with saving screenshots.)

[Screenshot: TreeApp_numalloc — the 4 allocations made by _load_pretrained_model]

If you are willing to install memray, you can run and confirm you see the same results.

pip install memray
memray run -o memray-results.bin --native [program.py]
memray tree memray-results.bin

The last python call in this trace occurs at: https://github.com/huggingface/safetensors/blob/main/bindings/python/py_src/safetensors/torch.py#L313

This appears to be the python handle for opening a safetensor file. In my repro, I verified the filename argument corresponded with the safetensor shards.

Anything deeper and it goes into Rust code in safetensors, and the code at the very bottom of this trace is Pytorch's internal C++ library, ATen, which I believe ultimately instantiates a MapAllocator object in C++: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/MapAllocator.cpp#L64

I am completely unfamiliar with ATen, and not very familiar with POSIX file handling conventions, so I can't comment on its behavior down there.

I know most of this is fairly high-level and handwavy, but I think it's strong evidence for what's happening. If you're not convinced we could try to do more formal memory tracing to pinpoint every single mmap alloc/dealloc. I've never really done that before, so I don't know if that's going to be quite a lot of work or not.

If we're feeling extra lazy, we could probably just confirm this by constructing highly sharded model checkpoints, loading them, and seeing memory usage is correlated with sharding.

@byi8220 (Contributor) commented Oct 29, 2024

Just quickly tried out limiting the saved checkpoint shard sizes to 400MB with model.save_pretrained("memleak/model", max_shard_size="400MB"), and then tried the repro on that saved local model. In that case memory use is as expected, although performance seems to suffer from having so many shards. So it almost certainly has to do with the safetensor files.
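
For reference, the re-sharding experiment above as a sketch (local path and shard size are just examples):

import torch
from transformers import AutoModelForCausalLM

# Re-save the checkpoint in small shards so each layer's weights end up in
# (roughly) their own safetensors file, then reload from the local copy.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
model.save_pretrained("memleak/model", max_shard_size="400MB")
del model

model = AutoModelForCausalLM.from_pretrained(
    "memleak/model",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
# Per-layer dtype conversions now release memory shard-by-shard as expected,
# at the cost of many more shard files and slower loading.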

As for a fix? I can't think of an elegant solution, but thoughts off the top of my head:

  1. Give users more control over checkpoint sharding, so that their relevant weights are close to each other. (Doable, but user unfriendly. The least intrusive option by far, since it doesn't change any existing functionality)
  2. Have some bookkeeping and python level tracking. (This seems extremely complicated to implement.)
  3. Some very low level memory mapping or file level magic (Likely way out of scope for transformers, we're probably getting into system level pytorch internals)

None of these feel particularly great to me.

@Qubitium Qubitium changed the title Cpu memory leak when layer/weights loaded in bfloat16 datatypes are converted to other dtypes. safetensor/mmap memory leak when per-layer weights are converted to other dtypes Nov 3, 2024
@Qubitium

This comment was marked as off-topic.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Qubitium (Contributor Author)

Not stale.
