From 12c7a870a9468a8d79d2267e5e9e3567b67a947e Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Tue, 6 Feb 2024 17:18:30 +0100
Subject: [PATCH] Revert "[WIP] Hard error when ignoring tensors." (#28898)

Revert "[WIP] Hard error when ignoring tensors. (#27484)"

This reverts commit 2da28c4b41bba23969a8afe97c3dfdcbc47a57dc.
---
 src/transformers/modeling_utils.py | 108 ++++-------------------------
 tests/test_modeling_utils.py       |  20 ------
 2 files changed, 15 insertions(+), 113 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 71b8ac979ab7b8..dd19189332cf1e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -29,7 +29,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from zipfile import is_zipfile
 
 import torch
@@ -570,65 +570,6 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules
 
 
-def _end_ptr(tensor: torch.Tensor) -> int:
-    # extract the end of the pointer if the tensor is a slice of a bigger tensor
-    if tensor.nelement():
-        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
-    else:
-        stop = tensor.data_ptr()
-    return stop
-
-
-def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
-    filtered_tensors = []
-    for shared in tensors:
-        if len(shared) < 2:
-            filtered_tensors.append(shared)
-            continue
-
-        areas = []
-        for name in shared:
-            tensor = state_dict[name]
-            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
-        areas.sort()
-
-        _, last_stop, last_name = areas[0]
-        filtered_tensors.append({last_name})
-        for start, stop, name in areas[1:]:
-            if start >= last_stop:
-                filtered_tensors.append({name})
-            else:
-                filtered_tensors[-1].add(name)
-            last_stop = stop
-    disjoint_tensors = []
-    shared_tensors = []
-    for tensors in filtered_tensors:
-        if len(tensors) == 1:
-            disjoint_tensors.append(tensors.pop())
-        else:
-            shared_tensors.append(tensors)
-    return shared_tensors, disjoint_tensors
-
-
-def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
-    shared_tensors = []
-    identical = []
-    for shared in tensors:
-        if len(shared) < 2:
-            continue
-
-        areas = collections.defaultdict(set)
-        for name in shared:
-            tensor = state_dict[name]
-            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
-            areas[area].add(name)
-        if len(areas) == 1:
-            identical.append(shared)
-        else:
-            shared_tensors.append(shared)
-    return shared_tensors, identical
-
-
 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
@@ -2441,8 +2382,6 @@ def save_pretrained(
             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
             warn_names = set()
-            error_names = set()
-            to_delete_names = set()
             for names in shared_ptrs.values():
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
@@ -2453,42 +2392,25 @@ def save_pretrained(
                         if matches_pattern and name in state_dict:
                             found += 1
                             if found < len(names):
-                                to_delete_names.add(name)
-            # We are entering a place where the weights and the transformers configuration do NOT match.
-            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
-            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
-            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
-            for name in disjoint_names:
-                state_dict[name] = state_dict[name].clone()
-
-            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-            # If the link between tensors was done at runtime then `from_pretrained` will not get
-            # the key back leading to random tensor. A proper warning will be shown
-            # during reload (if applicable), but since the file is not necessarily compatible with
-            # the config, better show a proper warning.
-            shared_names, identical_names = _find_identical(shared_names, state_dict)
-            # delete tensors that have identical storage
-            for inames in identical_names:
-                known = inames.intersection(to_delete_names)
-                for name in known:
-                    del state_dict[name]
-                unknown = sorted(inames.difference(to_delete_names))
-                for name in unknown[1:]:
-                    del state_dict[name]
-                    warn_names.add(name)
-
-            error_names.update(shared_names)
-
+                                del state_dict[name]
+
+                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+                # If the link between tensors was done at runtime then `from_pretrained` will not get
+                # the key back leading to random tensor. A proper warning will be shown
+                # during reload (if applicable), but since the file is not necessarily compatible with
+                # the config, better show a proper warning.
+                found = 0
+                for name in names:
+                    if name in state_dict:
+                        found += 1
+                        if found > 1:
+                            del state_dict[name]
+                            warn_names.add(name)
             if len(warn_names) > 0:
                 logger.warning_once(
                     f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
                 )
-            if len(error_names) > 0:
-                raise RuntimeError(
-                    f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
-                )
-
 
         # Shard the model if it is too big.
         if not _hf_peft_config_loaded:
             weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index f7878cb68d803d..cef56822dc3e95 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -257,26 +257,6 @@ def test_model_from_pretrained_subfolder(self):
 
         self.assertTrue(check_models_equal(model, model_loaded))
 
-    def test_model_manually_shared_disjointed_tensors_optimum(self):
-        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
-        model = BertModel(config)
-
-        # Let's fuse qkv
-        attn = model.encoder.layer[0].attention.self
-        q = attn.query.weight
-        k = attn.key.weight
-        v = attn.value.weight
-        # Force some shared storage
-        qkv = torch.stack([q, k, v], dim=0)
-        attn.query.weight = torch.nn.Parameter(qkv[0])
-        attn.key.weight = torch.nn.Parameter(qkv[1])
-        attn.value.weight = torch.nn.Parameter(qkv[2])
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model.save_pretrained(tmp_dir)
-            model_loaded = BertModel.from_pretrained(tmp_dir)
-
-        self.assertTrue(check_models_equal(model, model_loaded))
-
     def test_model_from_pretrained_subfolder_sharded(self):
         config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
         model = BertModel(config)
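
For context, the reverted `_end_ptr` / `_find_disjoint` helpers decided whether tensors that share one storage actually overlap in memory or merely occupy disjoint byte ranges of it (the fused-qkv case the deleted test exercised), so that disjoint slices could be cloned and saved instead of triggering the hard error. A minimal sketch of that pointer check, using plain torch only; `end_ptr`, `qkv`, `q`, `k`, `v` and `spans` are illustrative names, not part of this patch:

import torch


def end_ptr(tensor: torch.Tensor) -> int:
    # One-past-the-end data address of the tensor, mirroring the removed `_end_ptr` helper.
    if tensor.nelement():
        return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
    return tensor.data_ptr()


# Fuse q/k/v into one buffer, as the deleted test did: three views of a single storage.
qkv = torch.randn(3, 4, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

# Sort the (start, end) address ranges and check whether any neighbouring ranges overlap.
spans = sorted((t.data_ptr(), end_ptr(t), name) for name, t in [("q", q), ("k", k), ("v", v)])
overlaps = any(spans[i + 1][0] < spans[i][1] for i in range(len(spans) - 1))
print(overlaps)  # False: shared storage but disjoint ranges, which `_find_disjoint` treated as safe to clone.

After this revert, `save_pretrained` no longer makes that distinction: duplicate keys pointing at shared storage are simply dropped and reported through the `warn_names` warning, and reloading relies on the tie being re-created at runtime.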