Commit

Use huggingface_hub helper function to split state dict (huggingface#31091)

* shard saving from hf hub

* index = None

* fix tests

* indent
SunMarc authored and zucchini-nlp committed Jun 14, 2024
1 parent 2002625 commit d4cd9d0
Showing 4 changed files with 24 additions and 11 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -117,7 +117,7 @@
"fugashi>=1.0",
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.23.0,<1.0",
"huggingface-hub>=0.23.2,<1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -24,7 +24,7 @@
"fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.23.0,<1.0",
"huggingface-hub": "huggingface-hub>=0.23.2,<1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
24 changes: 20 additions & 4 deletions src/transformers/modeling_utils.py
@@ -34,6 +34,7 @@
from zipfile import is_zipfile

import torch
+from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss, Identity
@@ -362,6 +363,10 @@ def shard_checkpoint(
weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
The name of the model save file.
"""
+logger.warning(
+    "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
+    "split_torch_state_dict_into_shards from huggingface_hub library"
+)
max_shard_size = convert_file_size_to_int(max_shard_size)

sharded_state_dicts = [{}]
@@ -2618,7 +2623,17 @@ def save_pretrained(
else:
weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME

-shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+state_dict_split = split_torch_state_dict_into_shards(
+    state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+)
+# Save index if sharded
+index = None
+if state_dict_split.is_sharded:
+    index = {
+        "metadata": state_dict_split.metadata,
+        "weight_map": state_dict_split.tensor_to_filename,
+    }

# Clean the folder from a previous save
for filename in os.listdir(save_directory):
@@ -2634,14 +2649,15 @@
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
-and filename not in shards.keys()
+and filename not in state_dict_split.filename_to_tensors.keys()
and is_main_process
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)

# Save the model
-for shard_file, shard in shards.items():
+for shard_file, tensors in state_dict_split.filename_to_tensors.items():
+    shard = {tensor: state_dict[tensor] for tensor in tensors}
# remake shard with onloaded parameters if necessary
if module_map:
if accelerate_version < version.parse("0.31"):
@@ -2680,7 +2696,7 @@ def save_pretrained(
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)

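For reference, here is a minimal sketch (not the transformers implementation) of the save flow this commit switches to, built only on the huggingface_hub helper and attributes used in the diff above; the save_sharded name, the safetensors-only filename pattern, and the index filename are illustrative assumptions.

import json
import os

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file


def save_sharded(state_dict, save_directory, max_shard_size="5GB"):
    # Split the state dict into shards no larger than max_shard_size.
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern="model{suffix}.safetensors",  # assumed pattern for this sketch
        max_shard_size=max_shard_size,
    )
    # Write each shard; filename_to_tensors maps a shard file to the names of the tensors it holds.
    for shard_file, tensor_names in state_dict_split.filename_to_tensors.items():
        shard = {name: state_dict[name] for name in tensor_names}
        save_file(shard, os.path.join(save_directory, shard_file))
    # Only sharded checkpoints need an index mapping each weight to its shard file.
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        index_file = os.path.join(save_directory, "model.safetensors.index.json")
        with open(index_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")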
7 changes: 2 additions & 5 deletions tests/test_modeling_utils.py
@@ -669,7 +669,7 @@ def test_checkpoint_sharding_local_bin(self):

with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
for max_size in ["50kB", "100kB", "200kB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)

# Get each shard file and its size
@@ -686,10 +686,7 @@ def test_checkpoint_sharding_local_bin(self):

# Check a file is bigger than max_size only when it has a single weight
for shard_file, size in shard_to_size.items():
if max_size.endswith("kiB"):
max_size_int = int(max_size[:-3]) * 2**10
else:
max_size_int = int(max_size[:-2]) * 10**3
max_size_int = int(max_size[:-2]) * 10**3
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
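With the binary "kiB" branch dropped, the test's byte limit is always derived from a decimal size string. Below is a minimal sketch of that check under stated assumptions: the check_shard_sizes helper, its tmp_dir argument, and the .bin filename filtering are illustrative, not the test's exact code.

import os

import torch


def check_shard_sizes(tmp_dir, max_size="50kB"):
    # Only decimal sizes remain in the test, so "50kB" -> 50 * 10**3 = 50_000 bytes.
    max_size_int = int(max_size[:-2]) * 10**3
    for filename in os.listdir(tmp_dir):
        if filename.endswith(".bin") and filename != "pytorch_model.bin":
            shard_file = os.path.join(tmp_dir, filename)
            size = os.path.getsize(shard_file)
            # pickle adds some overhead, so allow a small margin above the requested limit.
            if size >= max_size_int + 50000:
                # A shard may only exceed the limit when it holds a single oversized weight.
                state_dict = torch.load(shard_file)
                assert len(state_dict) == 1, f"{shard_file} is too big ({size} > {max_size_int})"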
