
CUDA OOM error with supposedly good enough specs according to memory stats output #2071

Open
sionhan opened this issue Nov 26, 2024 · 4 comments

Comments

@sionhan

sionhan commented Nov 26, 2024

Hello!

I am currently trying to LoRA fine-tune the Llama 3.1 Nemotron 70B Instruct LLM by slightly tweaking the Llama 3.1 70B LoRA configs.
According to the memory stats reported by torchtune, it should need around 18 GiB per GPU, which 4090s ought to handle with some room to spare for additional overhead, if I understood correctly. However, I still get a CUDA OOM despite lowering every possible option to minimize VRAM requirements in the fine-tuning recipe.

Fine-tuning config parameters
I copied the Llama 3.1 70B LoRA configs and adapted them to the Nemotron HF model. The config mentions it can be run on 8 GPUs.
I used the original Llama 3.1 tokenizer.model, since I assumed the tokenizer is the same after reading the config.json of the Nemotron HF model.

  • Every memory-reducing parameter set to True: torch.compile, activation offloading, ...
  • Saving the merged weights, not just the adapter weights
  • Batch size 1, gradient accumulation 1, 1 epoch
  • LoRA rank lowered from 16 to 4
  • A 200-entry chat dataset with short user prompts and GPT outputs, so I could lower the max seq len to 512
  • The optimizer is still AdamW. I read that there might be a problem with torch's AdamW optimizer, so I tried regular Adam, but got the same result.
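For illustration, these overrides map roughly onto the following YAML (a sketch modeled on the stock torchtune llama3_1/70B_lora config; the component and file paths below are assumptions and may not match my_config.yaml exactly):

# Illustrative sketch only -- keys follow the stock 70B LoRA config,
# pointed at the Nemotron checkpoint; exact names and paths may differ.
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_70b
  lora_rank: 4                       # lowered from 16

tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ./Llama-3.1-Nemotron-70B-Instruct/original/tokenizer.model
  max_seq_len: 512                   # short chat samples, so cap the sequence length

batch_size: 1
gradient_accumulation_steps: 1
epochs: 1

optimizer:
  _component_: torch.optim.AdamW     # also tried torch.optim.Adam, same OOM

# Memory-saving knobs
enable_activation_checkpointing: True
enable_activation_offloading: True
compile: True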

Environment
Cuda: 12.4
Torch: 2.5.1
Torchtune: 0.4.0
Specs: 8 x 4090 instance
RAM: 192 GiB
Set PYTORCH_CUDA_ALLOC_CONF to expandable_segments:True.
I also ran nvidia-smi to check the GPU loads, and everything looked correct (memory equally distributed, all GPUs with active processes).

Command run
tune run --nproc_per_node 8 lora_finetune_distributed --config ./my_config.yaml

Error message (basically a CUDA OOM error on every GPU)

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 2315079729. Local seed is seed + rank = 2315079729 + 0
Writing logs to Llama-3.1-Nemotron-70B-Instruct/lora-llama3_1-finetune-output/log_1732589837.txt
INFO:torchtune.utils._logging:FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...
INFO:torchtune.utils._logging:Compiling model layers with torch.compile...
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint took 396.81 secs
INFO:torchtune.utils._logging:Memory stats after model init:
        GPU peak memory allocation: 18.49 GiB
        GPU peak memory reserved: 18.69 GiB
        GPU peak memory active: 18.49 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Compiling loss with torch.compile...
INFO:torchtune.utils._logging:Loss is initialized.
Generating train split: 200 examples [00:00, 1096.69 examples/s]
[rank1]: Traceback (most recent call last):
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/recipes/lora_finetune_distributed.py", line 942, in <module>
[rank1]:     sys.exit(recipe_main())
[rank1]:              ^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank1]:     sys.exit(recipe_main(conf))
[rank1]:              ^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/recipes/lora_finetune_distributed.py", line 937, in recipe_main
[rank1]:     recipe.train()
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/recipes/lora_finetune_distributed.py", line 807, in train
[rank1]:     logits = self._model(**batch)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank1]:     return inner()
[rank1]:            ^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1769, in inner
[rank1]:     args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
[rank1]:                          ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/_composable/fsdp/_fsdp_state.py", line 66, in fsdp_hook_wrapper
[rank1]:     return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/_composable/fsdp/_fsdp_state.py", line 232, in _pre_forward
[rank1]:     args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 301, in pre_forward
[rank1]:     self.wait_for_unshard()
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 257, in wait_for_unshard
[rank1]:     foreach_all_gather_copy_out(
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/_composable/fsdp/_fsdp_collectives.py", line 265, in foreach_all_gather_copy_out
[rank1]:     fsdp_param.init_all_gather_outputs(
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 409, in init_all_gather_outputs
[rank1]:     torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
[rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.96 GiB. GPU 1 has a total capacity of 23.43 GiB of which 210.06 MiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 22.41 GiB is allocated by PyTorch, and 206.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank2], [rank4], [rank5], [rank6], [rank7]: identical tracebacks, each ending in the same torch.OutOfMemoryError (tried to allocate 1.96 GiB; 23.43 GiB total capacity, 210.06 MiB free, 22.41 GiB allocated by PyTorch).
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
[rank3]: identical traceback and torch.OutOfMemoryError as above.
W1125 19:04:39.212000 2476 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2609 closing signal SIGTERM
W1125 19:04:39.216000 2476 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2610 closing signal SIGTERM
W1125 19:04:39.217000 2476 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2611 closing signal SIGTERM
W1125 19:04:39.218000 2476 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2612 closing signal SIGTERM
W1125 19:04:39.221000 2476 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2613 closing signal SIGTERM
W1125 19:04:39.222000 2476 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2615 closing signal SIGTERM
W1125 19:04:39.222000 2476 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2616 closing signal SIGTERM
E1125 19:04:47.474000 2476 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 5 (pid: 2614) of binary: /usr/bin/python3.12
Traceback (most recent call last):
  File "/usr/local/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/usr/local/lib/python3.12/dist-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/usr/local/lib/python3.12/dist-packages/torchtune/_cli/run.py", line 206, in _run_cmd
    self._run_distributed(args, is_builtin=is_builtin)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torchtune/_cli/run.py", line 95, in _run_distributed
    run(args)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/usr/local/lib/python3.12/dist-packages/recipes/lora_finetune_distributed.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-11-25_19:04:39
  host      : bc5153ad6aec
  rank      : 5 (local_rank: 5)
  exitcode  : 1 (pid: 2614)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
@joecummings
Contributor

I'll take a look! Can you provide your full config definition, too?

@joecummings
Contributor

Just updating with my initial findings. I managed to get it to work two ways under 24 GB of VRAM:

  1. With fsdp_cpu_offload=True, which offloads gradients and parameters to CPU when not in use. It's very effective, but slowwwww. You can see in this screenshot that peak memory allocation is ~11 GiB, but the tok/sec is ~50.
[Screenshot: 2024-11-26, 1:06 PM]
  2. With QLoRA (just set model._component_=torchtune.models.llama3_1.qlora_llama3_1_70b). Peak memory allocation is ~13 GiB, and tok/sec is ~200.
[Screenshot: 2024-11-26, 1:14 PM]

The problem is that even though the model can be loaded in ~18-19 GiB of memory, the activations, gradients, and optimizer states still take up some memory. I'm going to do some full profiling to see exactly what the culprit is here and whether there's anything we can do for a vanilla LoRA config.
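For anyone following along, a minimal sketch of how those two options map onto the config YAML (assuming the standard key layout of the stock 70B LoRA config; both can also be passed as key=value overrides to tune run):

# Option 1: offload sharded parameters/gradients to CPU (lowest memory, slow)
fsdp_cpu_offload: True

# Option 2: QLoRA -- quantize the frozen base weights (low memory, much faster)
model:
  _component_: torchtune.models.llama3_1.qlora_llama3_1_70b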

@felipemello1
Contributor

felipemello1 commented Nov 26, 2024

Hey @sionhan, when dealing with large models + LoRA, two things will consume most of your memory: activations and model parameters.

The optimizer state and gradients won't matter much because you are not finetuning many parameters.

To tackle the model, QLoRA will be your best friend, because we quantize all the base weights. The only other flag that helps with this is fsdp_cpu_offload, like Joe shared, but it is slow.

To tackle the activations, the main things you can do are what you already did: activation checkpointing, activation offloading, and compile. If that doesn't work, then your other knob is to reduce tokenizer.max_seq_len. Reducing the rank and the number of fine-tuned layers helps a bit, but I don't think it is worth going below rank=8.

Make sure to set dataset.packed=True for higher tokens per second.
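A sketch of that packing knob in config form (key names as in the stock torchtune LoRA configs; packed datasets generally need tokenizer.max_seq_len set so samples can be packed to a fixed length):

dataset:
  packed: True       # pack short chat samples together for higher tok/sec
tokenizer:
  max_seq_len: 512   # packed samples are assembled up to this length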

@joecummings
Contributor

joecummings commented Nov 26, 2024

Aha! I found one more possible way to do it. It turns out that in our default config we are not sharding the token embeddings and the output projection, which are very large. If we add this, it should come in right under the 24 GiB threshold!!! The tok/sec is still a little slower, but not too bad.

[Screenshot: 2024-11-26, 7:03 PM]

I'll update our LoRA recipe accordingly (#2072), and then you should be able to install the nightlies or build from source to take advantage of this. Note: I ran these experiments with all the other tricks activated as well - activation offloading, compile, packed=True, max_seq_len=512.

Also, a caveat: it's very close to the 24 GiB limit, and I artificially constrained my memory to test this; the only true test is whether it actually works on your 4090s. LMK how it goes!
