diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index c093a860960b60..79acf681c101fd 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -30,6 +30,7 @@
 from dataclasses import dataclass
 from functools import partial, wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from zipfile import is_zipfile
 
 import torch
 from packaging import version
@@ -516,8 +517,16 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             map_location = "meta"
         else:
             map_location = "cpu"
-
-        return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
+        extra_args = {}
+        # mmap can only be used with files serialized with zipfile-based format.
+        if (
+            isinstance(checkpoint_file, str)
+            and map_location != "meta"
+            and version.parse(torch.__version__) >= version.parse("2.1.0")
+            and is_zipfile(checkpoint_file)
+        ):
+            extra_args = {"mmap": True}
+        return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 23071b93d5fe3b..15c610563dd2fa 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -101,7 +101,7 @@
     from torch import nn
 
     from transformers import MODEL_MAPPING, AdaptiveEmbedding
-    from transformers.modeling_utils import no_init_weights
+    from transformers.modeling_utils import load_state_dict, no_init_weights
     from transformers.pytorch_utils import id_tensor_storage
 
 
@@ -536,6 +536,55 @@ class CopyClass(base_class):
                     ).item()
                     self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    def test_torch_save_load(self):
+        """Check that `load_state_dict` reloads identical tensors from zipfile and legacy `torch.save` files."""
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if config.__class__ not in MODEL_MAPPING:
+            return
+        base_class = MODEL_MAPPING[config.__class__]
+
+        if isinstance(base_class, tuple):
+            base_class = base_class[0]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            # make a copy of model class to not break future tests
+            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
+            class CopyClass(base_class):
+                pass
+
+            base_class_copy = CopyClass
+
+            # make sure that all keys are expected for test
+            base_class_copy._keys_to_ignore_on_load_missing = []
+
+            # make init deterministic, but make sure that
+            # non-initialized weights throw errors nevertheless
+            base_class_copy._init_weights = _mock_init_weights
+            base_class_copy.init_weights = _mock_all_init_weights
+
+            model = model_class(config)
+            state_dict = model.state_dict()
+
+            def check_equal(loaded):
+                for key in state_dict.keys():
+                    max_diff = torch.max(
+                        state_dict[key] ^ loaded[key]
+                        if isinstance(state_dict[key], torch.BoolTensor)
+                        else torch.abs(state_dict[key] - loaded[key])
+                    ).item()
+                    self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")
+
+            # round-trip the state dict through both the zipfile-based and the legacy serialization formats
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pytorch_model.bin")
+                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=True)
+                check_equal(load_state_dict(pt_checkpoint_path))
+                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=False)
+                check_equal(load_state_dict(pt_checkpoint_path))
+
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 