Speed up checkpoint serialization #26

Open
jlamypoirier opened this issue Oct 25, 2024 · 3 comments
Labels: enhancement (New feature or request)

@jlamypoirier (Collaborator)

🐞 Describe the Bug

I noticed checkpoint saving is suspiciously slow in some tests, so I decided to investigate.

Checkpoint saving should be bottlenecked by hardware (disk write speed), but it turns out it isn't. Something in torch.save / safetensors makes the serialization itself slow.

On an H100 node with local NVMe storage, we should get more than 2 GiB/s of write speed, but torch.save and safetensors.torch.save_file achieve less than a third of that.

I think the impact isn't too big for distributed checkpoints because multiple processes are saving at the same time, but it definitely makes exports and conversions slower than they should be.

This is hard to fix directly because the slowdown is inside third-party libraries, but maybe we can find some kind of workaround to speed things up, like using multiple serialization processes.
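For illustration, here is a rough sketch of the multi-process idea, assuming Linux fork semantics (so workers see the parent's CPU tensor without copying it) and a numpy-compatible dtype like the uint8 buffers in the benchmark below; the function names, chunk count, and multi-file output are illustrative only, not a concrete proposal:

import multiprocessing as mp
import pathlib

import numpy as np
import torch

_shard = None  # set before forking so workers inherit the buffer via copy-on-write


def _write_chunk(args):
    index, num_chunks, out_dir = args
    chunk = torch.chunk(_shard, num_chunks)[index]
    # .numpy() on a CPU tensor shares storage, so no extra copy happens here.
    np.save(pathlib.Path(out_dir) / f"chunk_{index}.npy", chunk.numpy())


def parallel_save(shard, out_dir, num_chunks=4):
    global _shard
    _shard = shard.cpu()
    with mp.get_context("fork").Pool(num_chunks) as pool:
        pool.map(_write_chunk, [(i, num_chunks, out_dir) for i in range(num_chunks)])

Whether this helps depends on where the time actually goes: if the bottleneck is per-byte work inside the serializer rather than the disk, splitting the work across processes should recover most of the raw write bandwidth.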

🔄 Steps to Reproduce

Third-party library benchmarks

Running this code:

import contextlib
import pathlib
import shutil
import time

import numpy as np
import safetensors.torch
import torch

# 1 GiB benchmark buffer: on GPU, as a CPU copy, and as a numpy view of the CPU copy.
size = 2**30
a = torch.ones(size, dtype=torch.uint8, device="cuda")
b = a.cpu()
c = b.numpy()

dir = pathlib.Path("tmp")
# dir = pathlib.Path("/mnt/workspace/tmp/ckpt")  # remote storage variant

if dir.is_dir():
    shutil.rmtree(dir)
dir.mkdir(exist_ok=True)


@contextlib.contextmanager
def measure_time(name):
    # Time the enclosed block and report the effective throughput in MiB/s.
    start = time.time()
    yield
    stop = time.time()
    print(f"{name}: {stop - start:.3f}s, {size / 2**20 / (stop - start):.3f} MiB/s")


with measure_time("torch save from gpu"):
    torch.save(a, (dir / "torch.pt").open("wb"))

with measure_time("torch save from cpu"):
    torch.save(b, (dir / "torch1.pt").open("wb"))

with measure_time("numpy save"):
    np.save((dir / "np").open("wb"), c)

with measure_time("safetensors save from gpu"):
    safetensors.torch.save_file({"a": a}, dir / "c.safetensors")

with measure_time("safetensors save from cpu"):
    safetensors.torch.save_file({"a": b}, dir / "c1.safetensors")

with measure_time("safetensors serialize in-memory from cpu"):
    d = safetensors.torch.save({"a": b})

with measure_time("safetensors serialize in-memory from gpu"):
    d = safetensors.torch.save({"a": a})

with measure_time("Write to disk"):
    with (dir / "d.safetensors").open("wb") as f:
        f.write(d)

We get:

torch save from gpu: 1.338s, 765.159 MiB/s
torch save from cpu: 1.641s, 623.839 MiB/s
numpy save: 0.536s, 1909.272 MiB/s
safetensors save from gpu: 1.558s, 657.230 MiB/s
safetensors save from cpu: 2.307s, 443.877 MiB/s
safetensors serialize in-memory from cpu: 1.658s, 617.791 MiB/s
safetensors serialize in-memory from gpu: 2.092s, 489.483 MiB/s
Write to disk: 0.497s, 2062.189 MiB/s

I also tried with remote storage (the commented-out path); it's a bit slower (~1.5 GiB/s) but follows the same pattern.

Fast-LLM benchmarks

Running the Mistral-7B 4-node benchmark on H100 nodes, with an added checkpoint and export:

2024-10-25 16:14:10,223 [Rank 00] Saving checkpoint at iteration 100
2024-10-25 16:14:17,987 [Rank 00] Saved checkpoint to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/checkpoints/100
2024-10-25 16:14:17,989 [Rank 00] Saving export at iteration 100
2024-10-25 16:14:17,999 [Rank 00] Saving tensors to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict_0.safetensors
2024-10-25 16:14:31,599 [Rank 00] Saving tensors to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict_1.safetensors
2024-10-25 16:14:40,670 [Rank 00] Saving index to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict.safetensors.index.json
2024-10-25 16:14:40,679 [Rank 00] Saved export to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100

The checkpoint is saved in 7.76 s and each shard is 2.53 GiB, so write speed is 334 MiB/s per process, or about 2670 MiB/s per node. This seems reasonable, but I don't know what speed we should be able to reach in theory.

The first export file is saved in 13.60 s and is 8.01 GiB, so write speed is 603 MiB/s (the second file reaches 618 MiB/s). This is at least 4x slower than it should be.
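For reference, a quick sanity check of these figures from the log timestamps and file sizes above (assuming 8 saving processes per node on a DGX-H100):

checkpoint_time = 7.76      # s, 16:14:10.223 -> 16:14:17.987
shard_size = 2.53 * 1024    # MiB per process
print(shard_size / checkpoint_time)      # ~334 MiB/s per process
print(shard_size / checkpoint_time * 8)  # ~2670 MiB/s per node

export_time = 13.60         # s, 16:14:17.999 -> 16:14:31.599
export_size = 8.01 * 1024   # MiB
print(export_size / export_time)         # ~603 MiB/s for the first export file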

📜 Environment Information

DGX-H100, saving locally (NVMe).

@jlamypoirier added the 'bug' label Oct 25, 2024
@jlamypoirier self-assigned this Oct 25, 2024
@jlamypoirier added the 'help wanted' and 'enhancement' labels and removed the 'bug' and 'help wanted' labels Oct 25, 2024
@jlamypoirier changed the title from 'Checkpoint serialization is slow' to 'Speed up checkpoint serialization' Oct 25, 2024
@jlamypoirier (Collaborator, Author)

Distributed checkpoint files are just a single big tensor, so the safetensors and torch formats are overkill. We can get away with a pure binary format like numpy's, which has near-optimal write speed and can be converted to/from torch without a copy. I got 987 MiB/s with this, including the CPU-GPU conversion, which is still way faster than safetensors. We could do better by buffering to overlap the transfer and the save, but it doesn't matter much because we'll want async checkpoints anyway.

So I propose the new format for distributed checkpoints in v0.2:

  • Remove the metadata from saved shards, keep only the tensor itself.
  • Use metadata.yaml exclusively for checkpoint metadata.
  • Use numpy format to save distributed shards.

This means dropping backward compatibility, but we were already planning to do that. For the future, I think we should always prevent loading distributed checkpoints saved with a different version, and enforce the state_dict format for long-term storage.
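
A minimal sketch of what the numpy-based shard save/load could look like (function names are illustrative, and the uint8 reinterpretation assumes the real dtype and shape are recorded in metadata.yaml):

import numpy as np
import torch


def save_shard(shard, path):
    # Reinterpret as raw bytes so dtypes numpy lacks (e.g. bfloat16) still work;
    # .numpy() on the CPU tensor shares storage, so the write needs no extra copy.
    np.save(path, shard.cpu().contiguous().view(torch.uint8).numpy())


def load_shard(path, dtype, device="cuda"):
    # torch.from_numpy shares the loaded buffer; the only copy is the device move.
    return torch.from_numpy(np.load(path)).view(dtype).to(device)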

@jlamypoirier mentioned this issue Oct 25, 2024
@tscholak added this to the 0.2.0 milestone Oct 25, 2024
@tscholak (Collaborator) commented Oct 26, 2024

I'd rather address this in 0.3 and put #25 in 0.2, mostly because #26 affects the internal checkpoint format, whereas #25 is user-facing and improves usability significantly. Let's discuss next week.

@tscholak (Collaborator) commented Nov 4, 2024

Re: speed. Loading can be sped up, too: https://x.com/_philschmid/status/1853339965612060797
