Reduce system memory usage during checkpoint loading/saving (#8694)
* avoid duplicate optimizer state dict

Signed-off-by: jiemingz <[email protected]>

* load checkpoint directly to GPU

Signed-off-by: jiemingz <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tensorstore import

Signed-off-by: jiemingz <[email protected]>

* fix isort

Signed-off-by: jiemingz <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix loading with torch dist ckpt

Signed-off-by: jiemingz <[email protected]>

---------

Signed-off-by: jiemingz <[email protected]>
Co-authored-by: jiemingz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
4 people authored Mar 28, 2024
1 parent 417aa51 commit e64b222
Showing 2 changed files with 20 additions and 9 deletions.
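For a sense of the host memory at stake in the first change (avoiding a duplicate optimizer state dict), here is a rough, illustrative sketch; it is not part of the commit. The 12 bytes per parameter assume fp32 master weights plus Adam's exp_avg and exp_avg_sq, and the numbers ignore how the distributed optimizer shards this state across data-parallel ranks.

# Illustrative accounting only (assumption, not NeMo code): holding a second full
# copy of the optimizer state while building the sharded state dict roughly doubles
# the peak host memory used during checkpoint saving.
def optimizer_state_bytes(num_params: int, bytes_per_param: int = 12) -> int:
    # fp32 master weights (4 B) + Adam exp_avg (4 B) + exp_avg_sq (4 B) per parameter
    return num_params * bytes_per_param

one_copy_gb = optimizer_state_bytes(7_000_000_000) / 1e9
print(f"one copy: ~{one_copy_gb:.0f} GB, two copies during save: ~{2 * one_copy_gb:.0f} GB")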
nemo/collections/nlp/parts/nlp_overrides.py: 24 changes (17 additions, 7 deletions)
@@ -92,6 +92,7 @@
     make_sharded_optimizer_tensor,
     optim_state_to_sharding_state,
 )
+from megatron.core.dist_checkpointing.strategies import tensorstore
 from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
 from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
 from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer
@@ -254,7 +255,7 @@ def configure_ddp(self):
         else:
             super().configure_ddp()
 
-    def optimizer_sharded_state_dict(self):
+    def optimizer_sharded_state_dict(self, unsharded_optim_state=None):
         """
         Sharded state dictionary for an MainParamsOptimizerWrapper.
         Used to save and load the optimizer state when training with distributed_checkpoint.
@@ -274,7 +275,7 @@
         }
 
         if isinstance(optimizer, MegatronDistributedFusedAdam):
-            return optimizer.sharded_state_dict(model_sharded_state_dict)
+            return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state)
         elif not isinstance(optimizer, MainParamsOptimizerWrapper):
             # Regular optimizer, e.g. Adam or FusedAdam
             init_optimizer_states(optimizer)
@@ -337,9 +338,14 @@ def save_checkpoint(
             hasattr(self.lightning_module, 'sharded_state_dict')
             and self.lightning_module.sharded_state_dict() is not None
         ):
+            assert (
+                len(checkpoint['optimizer_states']) == 1
+            ), "Currently only support checkpointing 1 distributed optimizer per time!"
             # converts the optimizer states to their sharded equivalents
-            checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
-
+            sharded_optim_state = self.optimizer_sharded_state_dict(
+                unsharded_optim_state=checkpoint['optimizer_states'][0]
+            )
+            checkpoint['optimizer_states'] = [sharded_optim_state]
             # dist_checkpointing expects a directory so we will name the directory
             # using the path with the file extension removed
             checkpoint_dir = ckpt_to_dir(filepath)
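For orientation, the hunk above ends just before the actual write. A hedged sketch of that next step follows; the dist_checkpointing.save call is assumed from the Megatron-Core API, the helper below is a stand-in for NeMo's ckpt_to_dir rather than a quote of the source, and the actual code also handles filesystem and rank coordination that is omitted here.

# Hedged sketch (assumed API usage): write a checkpoint dict whose tensors have been
# replaced by their sharded equivalents into a directory named after the .ckpt file.
from pathlib import Path
from typing import Any, Dict

from megatron.core import dist_checkpointing


def save_sharded_checkpoint(sharded_checkpoint: Dict[str, Any], filepath: str) -> None:
    checkpoint_dir = Path(filepath).with_suffix('')  # ".../step=1000.ckpt" -> ".../step=1000"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    dist_checkpointing.save(sharded_state_dict=sharded_checkpoint, checkpoint_dir=str(checkpoint_dir))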
@@ -437,9 +443,13 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
             checkpoint['state_dict'] = sharded_state_dict
             checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]
 
-            checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path)
-
-            checkpoint = self._fix_tensors_device(checkpoint)
+            if self.torch_dist_ckpt:
+                sharded_strategy = ('torch_dist', 1)
+            else:
+                sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
+            checkpoint = dist_checkpointing.load(
+                sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_path, sharded_strategy=sharded_strategy
+            )
 
             return checkpoint
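Taken on its own, the new load path amounts to the following; this is a hedged, standalone sketch of the same calls the hunk uses, not a NeMo entry point. With load_directly_on_device=True the TensorStore strategy materializes each shard straight into its target device tensor, which is consistent with dropping the separate _fix_tensors_device pass above.

# Hedged sketch of the strategy selection shown above (assumed usage of the
# megatron.core.dist_checkpointing API).
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.strategies import tensorstore


def load_sharded_checkpoint(sharded_state_dict, checkpoint_dir: str, torch_dist_format: bool = False):
    if torch_dist_format:
        # checkpoint written with the torch_dist backend, format version 1
        sharded_strategy = ('torch_dist', 1)
    else:
        # read each shard directly into its preallocated device tensor instead of
        # staging the whole checkpoint in host memory first
        sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
    return dist_checkpointing.load(
        sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_dir, sharded_strategy=sharded_strategy
    )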

nemo/core/optim/distributed_adam.py: 5 changes (3 additions, 2 deletions)
@@ -549,8 +549,9 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
         # Handle any remaining dtype conversions
         super()._check_params_shard_dtypes(params_buckets)
 
-    def sharded_state_dict(self, model_sharded_state_dict):
-        optimizer_state_dict = self.state_dict()
+    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
+        if optimizer_state_dict is None:
+            optimizer_state_dict = self.state_dict()
 
         id_to_sharded_param_map = get_param_id_to_sharded_param_map(
             model_sharded_state_dict=model_sharded_state_dict, optim_params_iter=self.parameters(),
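A short usage sketch of the new optional argument follows; it is hedged, the wrapper and its names are placeholders rather than NeMo APIs, and `optimizer` is assumed to be a MegatronDistributedFusedAdam instance.

# Hedged usage sketch: pass in a state dict the trainer already holds so the optimizer
# does not materialize a second full copy via self.state_dict().
from typing import Any, Dict, Optional


def build_sharded_optimizer_state(
    optimizer,  # assumed: a MegatronDistributedFusedAdam instance
    model_sharded_state_dict: Dict[str, Any],
    existing_state: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    if existing_state is not None:
        # e.g. Lightning's checkpoint['optimizer_states'][0], already collected upstream
        return optimizer.sharded_state_dict(model_sharded_state_dict, existing_state)
    # pre-change behaviour: the optimizer gathers its own state
    return optimizer.sharded_state_dict(model_sharded_state_dict)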
