🐞 Describe the Bug
I noticed that checkpoint saving is suspiciously slow in some tests, so I decided to investigate.
Checkpoint saving should be bottlenecked by hardware (disk write speed), but it turns out it isn't. Something in torch.save / safetensors makes the serialization itself slow.
On H100 NVMe we should get more than 2 GiB/s of write speed, but torch.save and safetensors.torch.save_file deliver less than a third of that.
I think the impact isn't too big for distributed checkpoints, because multiple processes save at the same time, but it definitely makes exports and conversions slower than they should be.
This problem is hard to solve because it lives in other libraries, but maybe we can find some kind of hack to speed things up, like using multiple serialization processes.
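Just as an illustration, such a hack could look something like the sketch below: split the state dict into shards and serialize them concurrently. The save_sharded helper is hypothetical (it is not part of Fast-LLM), and it uses a thread pool rather than processes to avoid pickling large CPU tensors across process boundaries; whether it actually helps depends on how much of the serialization releases the GIL.

import concurrent.futures
import pathlib

import safetensors.torch
import torch


def save_sharded(state_dict: dict[str, torch.Tensor], directory: pathlib.Path, num_workers: int = 4) -> None:
    # Hypothetical helper: split the state dict into num_workers shards and
    # write each shard to its own safetensors file from a separate thread.
    directory.mkdir(parents=True, exist_ok=True)
    names = list(state_dict)
    shards = [{name: state_dict[name] for name in names[i::num_workers]} for i in range(num_workers)]
    with concurrent.futures.ThreadPoolExecutor(num_workers) as pool:
        futures = [
            pool.submit(safetensors.torch.save_file, shard, directory / f"state_dict_{i}.safetensors")
            for i, shard in enumerate(shards)
            if shard
        ]
        for future in futures:
            future.result()  # re-raise any serialization error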
🔄 Steps to Reproduce
Third-party library benchmarks
Running this code:
import contextlib
import pathlib
import shutil
import time

import numpy as np
import safetensors.torch
import torch

size = 2**30
a = torch.ones(size, dtype=torch.uint8, device="cuda")
b = a.cpu()
c = b.numpy()

dir = pathlib.Path("tmp")
# dir = pathlib.Path("/mnt/workspace/tmp/ckpt")
if dir.is_dir():
    shutil.rmtree(dir)
dir.mkdir(exist_ok=True)


@contextlib.contextmanager
def measure_time(name):
    start = time.time()
    yield
    stop = time.time()
    print(f"{name}: {stop - start:.3f}s, {size / 2**20 / (stop - start):.3f} MiB/s")


with measure_time("torch save from gpu"):
    torch.save(a, (dir / "torch.pt").open("wb"))
with measure_time("torch save from cpu"):
    torch.save(b, (dir / "torch1.pt").open("wb"))
with measure_time("numpy save"):
    np.save((dir / "np").open("wb"), c)
with measure_time("safetensors save from gpu"):
    safetensors.torch.save_file({"a": a}, dir / "c.safetensors")
with measure_time("safetensors save from cpu"):
safetensors.torch.save_file({"a":a},dir/"c1.safetensors")
withmeasure_time("safetensors serialize in-memory from cpu"):
d=safetensors.torch.save({"a":b})
withmeasure_time("safetensors serialize in-memory from gpu"):
d=safetensors.torch.save({"a":a})
withmeasure_time("Write to disk"):
with (dir/"d.safetensors").open("wb") asf:
f.write(d)
We get:
torch save from gpu: 1.338s, 765.159 MiB/s
torch save from cpu: 1.641s, 623.839 MiB/s
numpy save: 0.536s, 1909.272 MiB/s
safetensors save from gpu: 1.558s, 657.230 MiB/s
safetensors save from cpu: 2.307s, 443.877 MiB/s
safetensors serialize in-memory from cpu: 1.658s, 617.791 MiB/s
safetensors serialize in-memory from gpu: 2.092s, 489.483 MiB/s
Write to disk: 0.497s, 2062.189 MiB/s
I also tried remote storage (the commented-out path); it's a bit slower (~1.5 GiB/s) but follows the same pattern.
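Another way to confirm that disk isn't the bottleneck (not part of the benchmark above, just a suggestion) is to time torch.save against an in-memory buffer, mirroring the in-memory safetensors measurements:

import io
import time

import torch

size = 2**30
b = torch.ones(size, dtype=torch.uint8)  # 1 GiB CPU tensor, same as in the benchmark above

buffer = io.BytesIO()
start = time.time()
torch.save(b, buffer)  # serialization only, no disk I/O
stop = time.time()
print(f"torch serialize in-memory: {stop - start:.3f}s, {size / 2**20 / (stop - start):.3f} MiB/s")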
Fast-LLM benchmarks
Running the Mistral-7B 4-node benchmark on H100 nodes, with added checkpoint and export:
2024-10-25 16:14:10,223 [Rank 00] Saving checkpoint at iteration 100
2024-10-25 16:14:17,987 [Rank 00] Saved checkpoint to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/checkpoints/100
2024-10-25 16:14:17,989 [Rank 00] Saving export at iteration 100
2024-10-25 16:14:17,999 [Rank 00] Saving tensors to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict_0.safetensors
2024-10-25 16:14:31,599 [Rank 00] Saving tensors to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict_1.safetensors
2024-10-25 16:14:40,670 [Rank 00] Saving index to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict.safetensors.index.json
2024-10-25 16:14:40,679 [Rank 00] Saved export to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100
The checkpoint is saved in 7.76 s and each shard is 2.53 GiB, so the write speed is 334 MiB/s per process, or about 2670 MiB/s per node (8 processes per node). This seems reasonable, but I don't know what speed we should be able to get in theory.
The first export file is saved in 13.60 s and is 8.01 GiB, so the write speed is 603 MiB/s (the second file gets 618 MiB/s). This is 3-4x slower than the raw disk write speed measured above.
📜 Environment Information
DGX-H100, saving locally (NVMe).
Distributed checkpoint files are just a single big tensor, so the safetensors and torch formats are overkill. We can get away with a pure binary format like numpy's, which has near-optimal write speed and can be converted to/from torch without a copy. I got 987 MiB/s with this, including the GPU-to-CPU transfer, which is still way faster than safetensors. We could do better by buffering to overlap the transfer and the write, but it doesn't matter much because we'll want async checkpoints anyway.
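A minimal sketch of what that could look like, assuming the shard dtype is one numpy supports (bf16 shards would need to be viewed as raw bytes first, since numpy has no bfloat16); the save_shard / load_shard names are placeholders, not the Fast-LLM API:

import pathlib

import numpy as np
import torch


def save_shard(shard: torch.Tensor, path: pathlib.Path) -> None:
    # .numpy() on a contiguous CPU tensor is a zero-copy view, so the only
    # costs are the device-to-host transfer and the raw disk write.
    with path.open("wb") as f:
        np.save(f, shard.detach().cpu().numpy())


def load_shard(path: pathlib.Path) -> torch.Tensor:
    # torch.from_numpy wraps the loaded buffer without copying it.
    with path.open("rb") as f:
        return torch.from_numpy(np.load(f))


# Usage would be something like: save_shard(shard, checkpoint_dir / "shard_0.npy")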
So I propose a new format for distributed checkpoints in v0.2:
Remove the metadata from saved shards, keep only the tensor itself.
Use metadata.yaml exclusively for checkpoint metadata.
Use numpy format to save distributed shards.
This means dropping backward compatibility, but we were already planning to do that. Going forward, I think we should always prevent loading distributed checkpoints saved with a different version, and enforce the state_dict format for long-term storage.
I'd rather address this in 0.3 and put #25 in 0.2, mostly because #26 affects the internal checkpoint format, whereas #25 is user-facing and improves usability significantly. Let's discuss next week.