Skip to content

Commit

Permalink
[tests] make test_model_parallelism device-agnostic (#30844)
Browse files Browse the repository at this point in the history
* enable on xpu

* fix style

* add comment and mps
  • Loading branch information
faaany authored May 24, 2024
1 parent 42d8dd8 commit 04c7c17
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
require_safetensors,
require_torch,
require_torch_gpu,
require_torch_multi_accelerator,
require_torch_multi_gpu,
require_torch_sdpa,
slow,
Expand Down Expand Up @@ -3009,8 +3010,11 @@ def check_device_map_is_respected(self, model, device_map):
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
self.assertEqual(param.device, torch.device("meta"))
elif param_device in ["mps"]:
self.assertEqual(param.device, torch.device("mps"))
else:
self.assertEqual(param.device, torch.device(param_device))
# when loaded with device_map, `param_device` is an integer value for cuda/xpu/npu/mlu
self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}"))

@require_accelerate
@mark.accelerate_tests
Expand Down Expand Up @@ -3129,7 +3133,7 @@ def test_cpu_offload(self):

@require_accelerate
@mark.accelerate_tests
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_model_parallelism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand All @@ -3155,7 +3159,6 @@ def test_model_parallelism(self):
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})

self.check_device_map_is_respected(new_model, new_model.hf_device_map)

torch.manual_seed(0)
Expand Down

0 comments on commit 04c7c17

Please sign in to comment.