Use mmap option to load_state_dict (#28331)
Weiming Zhao authored Jan 10, 2024
Commit 701298d (1 parent: 0f2f0c6)
Showing 2 changed files with 60 additions and 3 deletions.
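For background, the mmap option of torch.load (available since PyTorch 2.1) memory-maps the checkpoint's tensor storages from disk instead of reading the whole file into RAM up front, which lowers peak memory while loading. A minimal sketch of the difference, using a hypothetical toy checkpoint:

import torch
from torch import nn

# Hypothetical toy model; any state dict saved by torch.save behaves the same way.
torch.save(nn.Linear(4, 4).state_dict(), "toy_model.bin")  # zipfile-based format by default

# Eager load: all tensor data is read into memory immediately.
state_dict = torch.load("toy_model.bin", map_location="cpu", weights_only=True)

# Memory-mapped load (PyTorch >= 2.1): tensor storages stay backed by the file
# on disk and pages are faulted in only when the tensors are actually touched.
state_dict_mmap = torch.load("toy_model.bin", map_location="cpu", weights_only=True, mmap=True)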
src/transformers/modeling_utils.py (11 additions, 2 deletions)
@@ -30,6 +30,7 @@
 from dataclasses import dataclass
 from functools import partial, wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from zipfile import is_zipfile
 
 import torch
 from packaging import version
@@ -516,8 +517,16 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             map_location = "meta"
         else:
             map_location = "cpu"
-
-        return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
+        extra_args = {}
+        # mmap can only be used with files serialized with zipfile-based format.
+        if (
+            isinstance(checkpoint_file, str)
+            and map_location != "meta"
+            and version.parse(torch.__version__) >= version.parse("2.1.0")
+            and is_zipfile(checkpoint_file)
+        ):
+            extra_args = {"mmap": True}
+        return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
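The is_zipfile guard works because torch.save has defaulted to a zipfile-based container since PyTorch 1.6, while checkpoints written with the legacy serializer are not zip archives and cannot be memory-mapped. A small illustration of the distinction (file names are made up for the example):

import torch
from zipfile import is_zipfile

tensors = {"weight": torch.ones(2, 2)}

torch.save(tensors, "new_format.bin")  # default: zipfile-based serialization
torch.save(tensors, "old_format.bin", _use_new_zipfile_serialization=False)  # legacy format

print(is_zipfile("new_format.bin"))  # True  -> load_state_dict can pass mmap=True
print(is_zipfile("old_format.bin"))  # False -> falls back to a regular load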
tests/test_modeling_common.py (49 additions, 1 deletion)
@@ -101,7 +101,7 @@
     from torch import nn
 
     from transformers import MODEL_MAPPING, AdaptiveEmbedding
-    from transformers.modeling_utils import no_init_weights
+    from transformers.modeling_utils import load_state_dict, no_init_weights
     from transformers.pytorch_utils import id_tensor_storage
 
 
@@ -536,6 +536,54 @@ class CopyClass(base_class):
                     ).item()
                 self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    def test_torch_save_load(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if config.__class__ not in MODEL_MAPPING:
+            return
+        base_class = MODEL_MAPPING[config.__class__]
+
+        if isinstance(base_class, tuple):
+            base_class = base_class[0]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            # make a copy of model class to not break future tests
+            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
+            class CopyClass(base_class):
+                pass
+
+            base_class_copy = CopyClass
+
+            # make sure that all keys are expected for test
+            base_class_copy._keys_to_ignore_on_load_missing = []
+
+            # make init deterministic, but make sure that
+            # non-initialized weights throw errors nevertheless
+            base_class_copy._init_weights = _mock_init_weights
+            base_class_copy.init_weights = _mock_all_init_weights
+
+            model = model_class(config)
+            state_dict = model.state_dict()
+
+            def check_equal(loaded):
+                for key in state_dict.keys():
+                    max_diff = torch.max(
+                        state_dict[key] ^ loaded[key]
+                        if isinstance(state_dict[key], torch.BoolTensor)
+                        else torch.abs(state_dict[key] - loaded[key])
+                    ).item()
+                    self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")
+
+            # check that certain keys didn't get saved with the model
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pytorch_model.bin")
+                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=True)
+                check_equal(load_state_dict(pt_checkpoint_path))
+                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=False)
+                check_equal(load_state_dict(pt_checkpoint_path))
+
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
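The new test saves the same state dict in both serialization formats and checks that load_state_dict round-trips it either way. Outside the test harness, the same check looks roughly like the sketch below (BertModel and the tiny config are only placeholders for whichever model the mixin is parametrized with):

import os
import tempfile

import torch
from transformers import BertConfig, BertModel
from transformers.modeling_utils import load_state_dict

# Hypothetical tiny model so the example runs quickly.
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
state_dict = BertModel(config).state_dict()

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "pytorch_model.bin")
    for use_zipfile in (True, False):
        torch.save(state_dict, path, _use_new_zipfile_serialization=use_zipfile)
        loaded = load_state_dict(path)  # mmap is only used for the zipfile variant
        assert all(torch.equal(state_dict[k], loaded[k]) for k in state_dict)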
