diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py
index 3af77fc3186416..2601c92cfb76df 100644
--- a/tests/models/bert/test_modeling_bert.py
+++ b/tests/models/bert/test_modeling_bert.py
@@ -600,24 +600,6 @@ def test_model_from_pretrained(self):
             model = BertModel.from_pretrained(model_name)
             self.assertIsNotNone(model)

-    @slow
-    def test_save_and_load_low_cpu_mem_usage(self):
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            for model_class in self.all_model_classes:
-                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-                model_to_save = model_class(config)
-
-                model_to_save.save_pretrained(tmpdirname)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    low_cpu_mem_usage=True,
-                )
-
-                # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta. If there are
-                # any unloaded or untied parameters, then trying to move it to device=torch_device will throw an error.
-                model.to(torch_device)
-
     @slow
     @require_torch_accelerator
     def test_torchscript_device_change(self):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 32f6abcbe3aad1..c969c7c0d33b08 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -435,6 +435,23 @@ class CopyClass(model_class):
                 max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
                 self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

+    def test_save_and_load_low_cpu_mem_usage(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            for model_class in self.all_model_classes:
+                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+                model_to_save = model_class(config)
+
+                model_to_save.save_pretrained(tmpdirname)
+
+                model = model_class.from_pretrained(
+                    tmpdirname,
+                    low_cpu_mem_usage=True,
+                )
+
+                # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta. If there are
+                # any unloaded or untied parameters, then trying to move it to device=torch_device will throw an error.
+                model.to(torch_device)
+
     def test_fast_init_context_manager(self):
         # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
         class MyClass(PreTrainedModel):