From ce867790ceccead60e95862b19e0c827837764da Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 11 Dec 2024 18:22:35 +0000 Subject: [PATCH 1/3] fix loading with only state dict and config --- src/transformers/modeling_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f349847b1fd7a1..257776b0835be0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4017,8 +4017,8 @@ def from_pretrained( loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: loaded_state_dict_keys = list(state_dict.keys()) - - if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())): + #TODO: find better condition + if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())) and pretrained_model_name_or_path is not None: # In case some weights need to be kept in float32 and accelerate is not installed, # we later on want to take the path where state_dict is not None, that is the one # that do not require accelerate. @@ -4674,7 +4674,7 @@ def _find_mismatched_keys( ) # For GGUF models `state_dict` is never set to None as the state dict is always small - if gguf_path: + if gguf_path or low_cpu_mem_usage: fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, From b35207432fabb78ac387df5ac6c8ddcfcef2e235 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 11 Dec 2024 18:31:35 +0000 Subject: [PATCH 2/3] style --- src/transformers/modeling_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 257776b0835be0..b2846f3dc92dac 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4017,8 +4017,12 @@ def from_pretrained( loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: loaded_state_dict_keys = list(state_dict.keys()) - #TODO: find better condition - if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())) and pretrained_model_name_or_path is not None: + # TODO: find better condition + if ( + gguf_path is None + and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())) + and pretrained_model_name_or_path is not None + ): # In case some weights need to be kept in float32 and accelerate is not installed, # we later on want to take the path where state_dict is not None, that is the one # that do not require accelerate. From 84eb5d7852dafe2da1674f3853e812d75b769398 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 12 Dec 2024 15:38:41 +0000 Subject: [PATCH 3/3] add tests --- src/transformers/modeling_utils.py | 1 - tests/utils/test_modeling_utils.py | 20 ++++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b2846f3dc92dac..99b502e3ff4c18 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4017,7 +4017,6 @@ def from_pretrained( loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: loaded_state_dict_keys = list(state_dict.keys()) - # TODO: find better condition if ( gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 458ddeee5ff8be..31c0d01af776ac 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1750,6 +1750,26 @@ def test_save_and_load_config_with_custom_generation(self): new_model.generate(random_ids, max_new_tokens=3) self.assertTrue(len(w) == 0) + def test_load_model_with_state_dict_only(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + state_dict = model.state_dict() + config = model.config + + model_loaded = BertModel.from_pretrained( + pretrained_model_name_or_path=None, config=config, state_dict=state_dict + ) + self.assertTrue(check_models_equal(model, model_loaded)) + + def test_load_model_with_state_dict_only_low_cpu_mem_usage(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + state_dict = model.state_dict() + config = model.config + + model_loaded = BertModel.from_pretrained( + pretrained_model_name_or_path=None, config=config, state_dict=state_dict, low_cpu_mem_usage=True + ) + self.assertTrue(check_models_equal(model, model_loaded)) + @slow @require_torch