Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A faster and more memory-efficient implementation of zero_to_fp32 #6658

Merged
merged 12 commits into from
Nov 18, 2024

Conversation

xu-song
Copy link
Contributor

@xu-song xu-song commented Oct 23, 2024

It is a faster and more memory-efficient implementation of zero_to_fp32.

The previous version double the memory usage, which cause cpu OOM for very large models (e.g. llama 405B).

# XXX: memory usage doubles here
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
0).narrow(0, 0, unpartitioned_numel).view(shape)

How does it work?

  1. Lazy loading: Load checkpoint with mmap=True, thus the weights are mmaped rather than loading all the storages into memory.
  2. Lazy merge: GatheredTensor contains the mmaped weights and tensor offset. It is a memory-efficient pseudo tensor. Only when tensor.contiguous() is called, it starts to load related weights to memory and merge into a single tensor.
  3. Release memory in time: Save checkpoints shard by shard, and release the memory once a shard is saved.

Throughout the process, only one shard of tensors are keeped in memory.

How much benefit in speed and memory ?

Experiments were conducted on a linux host with 1TB of memory. Here is a detailed comparision

world size peak memory(GB) elapsed time(h:mm:ss)
llama3-8B(old->new) 8 90 -> 41 0:02:17 -> 0:01:10
llama2-13B(old->new) 8 146 -> 54 0:02:30 -> 0:01:47
llama2-70B(old->new) 16 789 -> 159 0:20:47 -> 0:20:45
qwen1.5-110B(old->new) 32 OOM -> 217 ? -> 0:34:21
llama3-405B(old->new) 192 OOM -> 262 ? -> 2:09:59

You can reproduce with the following scripts

# 1. install requirments
apt-get install time
# 2. prepare zero-3 checkpoints
# 3. convert zero to fp32 checkpoints
/usr/bin/time -v python zero_to_fp32.py . output_dir/ --safe_serialization
  • memory: Theoretically, this PR reduces the memory cost from 2M to (1/n)M, where M is the memory cost of the full weights, n is num_shards.
  • speed: The speed gain mainly comes from avoiding extra tensor copying. The benifit may be slight.

Impl history

  • v1 : a hf_hub compatible approach.
    It has been discarded due to the controversial implementation of data_ptr().
  • v2: a simple approach with torch.empty

@tjruwase tjruwase requested review from tohtana and removed request for awan-10 October 23, 2024 16:31
@tjruwase
Copy link
Contributor

@xylian86, FYI

@tjruwase
Copy link
Contributor

@xu-song, just to clarify, we greatly appreciate this PR. The memory and speed benefits are very useful. My only concern are the HF_Hub related changes, so hopefully those can be clarified.

Can you please add the observed speed and memory benefits of this optimizations? Such details are generally useful for readers to better appreciate the value. Thanks!

@xu-song
Copy link
Contributor Author

xu-song commented Oct 25, 2024

@xu-song, just to clarify, we greatly appreciate this PR. The memory and speed benefits are very useful. My only concern are the HF_Hub related changes, so hopefully those can be clarified.

Can you please add the observed speed and memory benefits of this optimizations? Such details are generally useful for readers to better appreciate the value. Thanks!

@tjruwase Is there any alternative approach to sharding torch state_dict?

If any, the compatible feature to huggingface_hub.split_torch_state_dict_into_shards can be discarded.

@tjruwase
Copy link
Contributor

@tjruwase Is there any alternative approach to sharding torch state_dict?

If any, the compatible feature to huggingface_hub.split_torch_state_dict_into_shards can be discarded.

Sorry, but I am a bit confused about the objective of this PR. The goal of zero_to_fp32 is to create a consolidated checkpoint state from the sharded checkpoints of ZeRO-* training, so I don't understand why state_dict sharding is a consideration here.

It seems that there are two parts of this PR.

  1. Speed and memory optimizations
  2. HF_hub compatibility involving state_dict sharding

Am I correct?

@xu-song
Copy link
Contributor Author

xu-song commented Oct 26, 2024

1. Yes, HF_hub compatibility involves state_dict sharding.

It seems that there are two parts of this PR.

  1. Speed and memory optimizations
  2. HF_hub compatibility involving state_dict sharding

Am I correct?

2. Besides, our implementation exactly follows the the goal of zero_to_fp32. As the document says

Note: this approach may not work if your application doesn’t have sufficient free CPU memory and you may need to use the offline approach using the zero_to_fp32.py script that is saved with the checkpoint.

By default, get_fp32_state_dict_from_zero_checkpoint return state_dict with torch.Tensor (large-memory).
If OOM, use zero_to_fp32.py (memory-efficient).

3. new impl

v1
To save memory, we pass pseudo tensor to split_torch_state_dict_into_shards.

state_dict_split = split_torch_state_dict_into_shards(state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size)

v2

-        state_dict_split = split_torch_state_dict_into_shards(state_dict,
+        mock_state_dict = {name: torch.empty(tensor.shape, dtype=tensor.dtype) for name, tensor in state_dict.items()}
+        state_dict_split = split_torch_state_dict_into_shards(mock_state_dict,
                                                              filename_pattern=filename_pattern,
                                                              max_shard_size=max_shard_size)

Convert pseudo tensor to torch.tensor before callingsplit_torch_state_dict_into_shards.
It is memory-free and perfectly works for split_torch_state_dict_into_shards.

@xu-song xu-song closed this Oct 27, 2024
@xu-song xu-song reopened this Oct 27, 2024
@tjruwase
Copy link
Contributor

tjruwase commented Oct 28, 2024

Besides, our implementation exactly follows the the goal of zero_to_fp32. As the document says

Apologies for not being clear. The reason that I referred to that doc was to show that zero_to_fp32 is meant to generate a consolidated checkpoint and not a sharded one. Does this clarify the intention of zero_to_fp32?

@@ -483,6 +530,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
- ``checkpoint_dir``: path to the desired checkpoint folder
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
- ``exclude_frozen_parameters``: exclude frozen parameters
- ``lazy_merge``: a more memory-efficient feature
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Provide brief description of why more memory-efficient, and perhaps mention important usage concepts like pseudo tensors and contiguous().

Also, please add a unit test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be a good place for a unit test:

state_dict = get_fp32_state_dict_from_zero_checkpoint(filename, tag="checkpoint")

Copy link
Contributor Author

@xu-song xu-song Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, i am working on it.

@xu-song
Copy link
Contributor Author

xu-song commented Oct 29, 2024

zero_to_fp32 is meant to generate a consolidated checkpoint and not a sharded one.

@tjruwase Sorry, you may confuse the objective of zero_to_fp32 and get_fp32_state_dict_from_zero_checkpoint. zero_to_fp32can generate sharded checkpoints, but get_fp32_state_dict_from_zero_checkpoint has nothing to do with sharding and checkpointing.

if your application doesn’t have sufficient free CPU memory and you may need to use the offline approach using the zero_to_fp32.py script that is saved with the checkpoint.

If your point is no-sharding is allowed, then ignore and close this pr.

@tjruwase
Copy link
Contributor

If your point is no-sharding is allowed, then ignore and close this pr.

No, my point is not no-sharding is allowed. Like I said earlier, this PR is very useful, and we are appreciative. I notice that new PR does not contain HF_hub compatibility logic, and so my concern is addressed. Thanks for that.

@xu-song xu-song requested a review from loadams as a code owner November 6, 2024 06:15
@xu-song
Copy link
Contributor Author

xu-song commented Nov 6, 2024

A unit test and more comments have been added. Thanks

@xu-song
Copy link
Contributor Author

xu-song commented Nov 7, 2024

formatting and unit test issue have been resolved.

@tjruwase
Copy link
Contributor

tjruwase commented Nov 9, 2024

formatting and unit test issue have been resolved.
@xu-song, apologies that our CI issues is delayed merge. Thanks for your patience.

@loadams loadams enabled auto-merge November 18, 2024 18:16
@loadams loadams added this pull request to the merge queue Nov 18, 2024
Merged via the queue into microsoft:master with commit dd40269 Nov 18, 2024
10 checks passed
@NicholasCao
Copy link

I have a question that's a bit off-topic. Why is it necessary to convert zero to fp32 instead of directly converting it to fp16 or bf16? I don't understand this. Is there a zero_to_bf16.py script available?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants