Skip to content

Commit

Permalink
mv DeepSpeedEngine param_names dict init post _configure_distributed_…
Browse files Browse the repository at this point in the history
…model (#4803)

In some backends, when params are being moved from host to device, they
might changed their python object id(), which uses a the key in the
param_names dictionary. in such case this dict might become invalid.

Co-authored-by: Michael Wyatt <[email protected]>
  • Loading branch information
nelyahu and mrwyattii authored Dec 15, 2023
1 parent faa00b1 commit 449e454
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,6 @@ def __init__(
# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)

# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}

self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
Expand All @@ -261,6 +258,9 @@ def __init__(
# Configure distributed model
self._configure_distributed_model(model)

# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}

self._get_model_parameters()

see_memory_usage(f"DeepSpeed Engine: After configure distributed model")
Expand Down

0 comments on commit 449e454

Please sign in to comment.