diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 1e6b1c9c6f6..27aeedc5a8d 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1540,7 +1540,7 @@ def get_state_dict_offloaded_model(model: nn.Module):
                 placeholders.add(name + f".{key}")
                 continue
             params = module_state_dict[key]
-            state_dict[name + f".{key}"] = params
+            state_dict[name + f".{key}"] = params.to("cpu")  # move buffers to cpu
     for key in placeholders.copy():
         if key in state_dict:
             placeholders.remove(key)
@@ -1923,7 +1923,7 @@ def align_module_device(module: torch.nn.Module, execution_device: Optional[torc
                 module._hf_hook.execution_device = original_device
 
     elif execution_device is not None:
-        devices = {name: param.device for name, param in module.named_parameters()}
+        devices = {name: param.device for name, param in module.named_parameters(recurse=False)}
         try:
             for name in devices:
                 set_module_tensor_to_device(module, name, execution_device)
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 41ce475c6de..4d01c7ef514 100644
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -43,6 +43,7 @@
    convert_file_size_to_int,
    find_tied_parameters,
    get_balanced_memory,
+    get_state_dict_offloaded_model,
    infer_auto_device_map,
    load_checkpoint_in_model,
    load_state_dict,
@@ -66,6 +67,15 @@ def forward(self, x):
         return self.linear2(self.batchnorm(self.linear1(x)))
 
 
+class NestedModelForTest(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = ModelForTest()
+
+    def forward(self, x):
+        return self.model(x)
+
+
 class LinearWithNonPersistentBuffers(nn.Module):
     def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
@@ -788,6 +798,19 @@ def test_convert_file_size(self):
         with self.assertRaises(ValueError):
             convert_file_size_to_int("-1GB")
 
+    def test_get_state_dict_offloaded_model(self):
+        for model_cls in (ModelForTest, NestedModelForTest):
+            model = model_cls()
+            execution_device = torch.device(torch_device)
+            original_state_dict = model.state_dict()
+
+            cpu_offload(model, execution_device=execution_device)
+            state_dict = get_state_dict_offloaded_model(model)
+
+            assert original_state_dict.keys() == state_dict.keys()
+            for key in original_state_dict:
+                assert torch.equal(original_state_dict[key], state_dict[key])
+
     def test_align_module_device_simple(self):
         model = ModelForTest()
         execution_device = torch.device(torch_device)
@@ -834,3 +857,13 @@ def test_align_module_device_offloaded(self):
         assert model.linear1.weight.device == offload_device
         assert model.batchnorm.weight.device == offload_device
         assert model.linear2.weight.device == offload_device
+
+    def test_align_module_device_offloaded_nested(self):
+        model = NestedModelForTest()
+        execution_device = torch.device(torch_device)
+        align_device = torch.device("cpu")
+        cpu_offload(model, execution_device=execution_device)
+        for module in model.modules():
+            with align_module_device(module, align_device):
+                for param in module.parameters(recurse=False):
+                    assert param.device == align_device