From 04c7c176d7f70ec4b43c8c2a0327ff8d193f5c1d Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Fri, 24 May 2024 18:51:51 +0800
Subject: [PATCH] [tests] make `test_model_parallelism` device-agnostic
 (#30844)

* enable on xpu

* fix style

* add comment and mps
---
 tests/test_modeling_common.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 20f5cf1ca2d713..30010cde9116dc 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -76,6 +76,7 @@
     require_safetensors,
     require_torch,
     require_torch_gpu,
+    require_torch_multi_accelerator,
     require_torch_multi_gpu,
     require_torch_sdpa,
     slow,
@@ -3009,8 +3010,11 @@ def check_device_map_is_respected(self, model, device_map):
             param_device = device_map[param_name]
             if param_device in ["cpu", "disk"]:
                 self.assertEqual(param.device, torch.device("meta"))
+            elif param_device in ["mps"]:
+                self.assertEqual(param.device, torch.device("mps"))
             else:
-                self.assertEqual(param.device, torch.device(param_device))
+                # when loaded with device_map, `param_device` are integer values for cuda/xpu/npu/mlu
+                self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}"))

     @require_accelerate
     @mark.accelerate_tests
@@ -3129,7 +3133,7 @@ def test_cpu_offload(self):

     @require_accelerate
     @mark.accelerate_tests
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_model_parallelism(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -3155,7 +3159,6 @@ def test_model_parallelism(self):
             new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)

             torch.manual_seed(0)
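
Note (not part of the patch): the new `else` branch relies on `hf_device_map` values being integer accelerator indices when a model is loaded with a `device_map`, so the expected device is rebuilt as `f"{torch_device}:{param_device}"` (for example `cuda:1` or `xpu:0`), while `cpu`/`disk` entries are expected on the meta device and `mps` is matched as-is. The snippet below is a minimal standalone sketch of that check, not the library's test code; it assumes a machine with at least one CUDA/XPU-style accelerator and `accelerate` installed, and the tiny checkpoint name is only illustrative.

# Standalone sketch of the device-agnostic placement check (assumptions: an
# accelerator is available, `accelerate` is installed, and the checkpoint name
# below is only an example).
import torch
from transformers import AutoModel
from transformers.testing_utils import torch_device  # resolves to "cuda", "xpu", "npu", ...

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert", device_map="auto")

for name, device in model.hf_device_map.items():
    # Accelerator entries are integer indices; offloaded modules are "cpu"/"disk",
    # Apple GPUs are "mps".
    if device in ("cpu", "disk"):
        continue  # the test itself expects these weights on the meta device
    expected = torch.device("mps") if device == "mps" else torch.device(f"{torch_device}:{device}")
    submodule = model.get_submodule(name)  # get_submodule("") returns the model itself
    for param in submodule.parameters():
        assert param.device == expected, f"{name or '<root>'} is on {param.device}, expected {expected}"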