diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 758fe4d1fdf398..cfbd7c8d6aaf03 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -1065,6 +1065,23 @@ def test_cached_files_are_used_when_internet_is_down(self):
         # This check we did call the fake head request
         mock_head.assert_called()
 
+    @require_accelerate
+    @mark.accelerate_tests
+    def test_save_model_with_device_map_cpu(self):
+        model_id = "hf-internal-testing/tiny-random-gpt2"
+        inputs = torch.tensor([[1, 2, 3]])
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
+            output = model(inputs)[0]
+            model.save_pretrained(
+                tmp_dir, max_shard_size="200KB"
+            )  # model is 1.6MB, max shard size is allocated to cpu by default
+            saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map="cpu")
+            saved_model_output = saved_model(inputs)[0]
+
+            self.assertTrue(torch.allclose(output, saved_model_output))
+
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_accelerator
@@ -1083,9 +1100,9 @@ def test_save_offloaded_model(self):
 
         # check_models_equal requires onloaded tensors
         model_id = "hf-internal-testing/tiny-random-gpt2"
-        onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
+        onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu").to(f"{torch_device}:0")
         inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
-        cpu_output = onloaded_model(inputs)[0]
+        output = onloaded_model(inputs)[0]
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             offload_folder = os.path.join(tmp_dir, "offload")
@@ -1099,7 +1116,7 @@ def test_save_offloaded_model(self):
             saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
             postsaved_output = saved_model(inputs)[0]
 
-        self.assertTrue(torch.allclose(cpu_output, presaved_output, atol=1e-4))
+        self.assertTrue(torch.allclose(output, presaved_output, atol=1e-4))
         self.assertTrue(torch.allclose(presaved_output, postsaved_output))
 
     @require_safetensors