From 31ea664dbf5356de2da10ccdf8bd533a4d48c27f Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 16 Jan 2024 15:08:02 -0500 Subject: [PATCH] Add save_ignore_keys (#2868) * comment * add it * debug * add the keys * debug * debug * remove print statement * docs and tests * fix tests --------- Co-authored-by: Daniel King --- composer/callbacks/checkpoint_saver.py | 32 ++++++++++++++++++--- composer/trainer/trainer.py | 23 +++++++++++++++ composer/utils/checkpoint.py | 15 ++++++++-- tests/trainer/test_checkpoint.py | 39 ++++++++++++++++++++++++++ tests/utils/test_autolog_hparams.py | 1 + 5 files changed, 104 insertions(+), 6 deletions(-) diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index d3a9dfafc7..62fe55bb75 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -199,10 +199,6 @@ class CheckpointSaver(Callback): # noqa: D101 progress). It should return ``True`` if a checkpoint should be saved given the current state and event. - weights_only (bool): If ``True``, save only the model weights instead of the entire training state. - This parameter must be ``False`` when using DeepSpeed. Default: ``False``. - - num_checkpoints_to_keep (int, optional): The number of checkpoints to keep locally. The oldest checkpoints are removed first. Set to ``-1`` to keep all checkpoints locally. Default: ``-1``. @@ -214,6 +210,31 @@ class CheckpointSaver(Callback): # noqa: D101 This parameter only controls how many checkpoints are kept locally; checkpoints are not deleted from remote file systems. + weights_only (bool): If ``True``, save only the model weights instead of the entire training state. + This parameter must be ``False`` when using DeepSpeed. Default: ``False``. + + ignore_keys (List[str] | (Dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint, + which, when provided, will be ignored from the state_dict before a checkpoint is saved. Each path is a list + of strings specifying the keys to index into ``state_dict`` joined together with `/` as a separator (as PyTorch + uses `.` in parameter names). If a prefix is provided, all children are also ignored (see Example 2). + See :mod:`composer.core.state` for the structure of state_dict. + + Example 1: ``save_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]`` would ignore + layer 1 weights and bias. + + Example 2: ``save_ignore_keys = ["state/model/*"]`` would ignore the entire model, which would have the same + effect as the previous example if there was only 1 layer. + + Example 3: ``save_ignore_keys = ["state/model/layer*.weights"]`` would ignore all weights in the model. + + Example 4: ``save_ignore_keys = ["state/rank_zero_seed", "rng"]`` would reset all randomness when + saving the checkpoint. + + If a callable, it should take one argument which is the state_dict. The callable is free to arbitrarily modify + the state_dict before it is loaded. + + (default: ``None``) + Attributes: saved_checkpoints (List[Tuple[Timestamp, List[pathlib.Path]]]): The checkpoint timestamps and filepaths. @@ -243,6 +264,7 @@ def __init__( overwrite: bool = False, num_checkpoints_to_keep: int = -1, weights_only: bool = False, + ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None, ): folder = str(folder) filename = str(filename) @@ -267,6 +289,7 @@ def __init__( self.all_saved_checkpoints_to_timestamp: Dict[str, Timestamp] = {} self.num_checkpoints_to_keep = num_checkpoints_to_keep self.weights_only = weights_only + self.ignore_keys = ignore_keys self.start_batch = None @@ -363,6 +386,7 @@ def _save_checkpoint(self, state: State, logger: Logger): state=state, filename=filename_with_placeholders, weights_only=self.weights_only, + ignore_keys=self.ignore_keys, ) log.debug(f'Checkpoint locally saved to {saved_path}') diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 71f27fa08b..2b9c9731a5 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -700,6 +700,27 @@ class Trainer: state. This parameter has no effect if ``save_folder`` is ``None``. (default: ``False``) .. seealso:: :class:`~.CheckpointSaver` + save_ignore_keys (List[str] | (Dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint, + which, when provided, will be ignored from the state_dict before a checkpoint is saved. Each path is a list + of strings specifying the keys to index into ``state_dict`` joined together with `/` as a separator (as PyTorch + uses `.` in parameter names). If a prefix is provided, all children are also ignored (see Example 2). + See :mod:`composer.core.state` for the structure of state_dict. + + Example 1: ``save_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]`` would ignore + layer 1 weights and bias. + + Example 2: ``save_ignore_keys = ["state/model/*"]`` would ignore the entire model, which would have the same + effect as the previous example if there was only 1 layer. + + Example 3: ``save_ignore_keys = ["state/model/layer*.weights"]`` would ignore all weights in the model. + + Example 4: ``save_ignore_keys = ["state/rank_zero_seed", "rng"]`` would reset all randomness when + saving the checkpoint. + + If a callable, it should take one argument which is the state_dict. The callable is free to arbitrarily modify + the state_dict before it is loaded. + + (default: ``None``) save_num_checkpoints_to_keep (int, optional): The number of checkpoints to keep locally. The oldest checkpoints are removed first. Set to ``-1`` to keep all checkpoints locally. (default: ``-1``) @@ -866,6 +887,7 @@ def __init__( save_overwrite: bool = False, save_interval: Union[str, int, Time, Callable[[State, Event], bool]] = '1ep', save_weights_only: bool = False, + save_ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None, save_num_checkpoints_to_keep: int = -1, save_metrics: bool = False, @@ -1150,6 +1172,7 @@ def __init__( latest_remote_file_name=latest_remote_file_name, overwrite=save_overwrite, weights_only=save_weights_only, + ignore_keys=save_ignore_keys, save_interval=save_interval, num_checkpoints_to_keep=save_num_checkpoints_to_keep, ) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 339e628f03..63e87f57fe 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -16,7 +16,7 @@ import warnings from importlib import import_module from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import torch from packaging import version @@ -938,6 +938,7 @@ def _save_checkpoint( save_filename: str, *, weights_only: bool = False, + ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None, ) -> Union[str, None]: # noqa: D103 is_deepspeed = is_model_deepspeed(state.model) @@ -957,6 +958,15 @@ def _save_checkpoint( 'rng': reproducibility.get_rng_state(), } + if ignore_keys: + # Filter provided list of key paths + if not callable(ignore_keys): + ignore_keys = glob_filter(ignore_keys) + # Call function to modify state_dict + ignore_keys(state_dict) + # Ensure state exists + state_dict['state'] = state_dict.get('state', {}) + if state.fsdp_sharded_state_dict_enabled: # To load optimizer states with 2.0 <= torch < 2.1.3 , the optimizer state must be at the top # level of the state dict because the load_sharded_optimizer_state_dict function @@ -1087,9 +1097,10 @@ def save_checkpoint( filename: str = 'ep{epoch}-ba{batch}-rank{rank}', *, weights_only: bool = False, + ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None, ) -> Union[str, None]: # noqa: D103 save_filename = get_save_filename(state, filename) - return _save_checkpoint(state, save_filename, weights_only=weights_only) + return _save_checkpoint(state, save_filename, weights_only=weights_only, ignore_keys=ignore_keys) save_checkpoint.__doc__ = f"""Checkpoint the training ``state``. diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index edebae455c..166d031b4e 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -285,6 +285,7 @@ def test_checkpoint_saver_properly_constructed(self, save_folder: str, expected_ 'weights_only': False, 'save_interval': '1ep', 'num_checkpoints_to_keep': -1, + 'ignore_keys': None, } expected_folder = expected_path.rstrip('/') if expected_path != '' else '.' mock_checkpoint_saver.assert_called_once_with(folder=expected_folder, **rest_of_checkpoint_saver_kwargs) @@ -790,6 +791,44 @@ def test_load_ignore_keys(self, load_ignore_keys, weights_equal, callbacks_equal assert trainer_1_rng_state is not None deep_compare(trainer_1_rng_state, trainer_2._rng_state) + @pytest.mark.parametrize('save_ignore_keys,weights_equal,callbacks_equal,rng_equal', [ + ['*', False, False, False], + ['state/model/*', False, True, True], + ['state/callbacks/*', True, False, True], + ['rng', True, True, False], + ]) + @pytest.mark.filterwarnings('ignore:.* is not in the state_dict.*:UserWarning') + def test_save_ignore_keys(self, save_ignore_keys, weights_equal, callbacks_equal, rng_equal): + + trainer_1 = self.get_trainer(save_folder='first', save_ignore_keys=[save_ignore_keys]) + trainer_1.fit() + trainer_1_rng_state = reproducibility.get_rng_state() + trainer_1.close() + + last_checkpoint = os.path.join('first', 'ep2.pt') + trainer_2 = self.get_trainer(load_path=last_checkpoint) + + # Check weights loaded properly + with contextlib.nullcontext() if weights_equal else pytest.raises(AssertionError): + self._assert_weights_equivalent( + trainer_1.state.model, + trainer_2.state.model, + ) + + # Check callbacks state + stateful_callbacks_equal = self._stateful_callbacks_equal( + trainer_1.state.callbacks, + trainer_2.state.callbacks, + ) + if callbacks_equal: + assert stateful_callbacks_equal + else: + assert not stateful_callbacks_equal + + if rng_equal: + assert trainer_1_rng_state is not None + deep_compare(trainer_1_rng_state, trainer_2._rng_state) + @pytest.mark.remote @device('cpu') @pytest.mark.parametrize('load_weights_only', [True, False]) diff --git a/tests/utils/test_autolog_hparams.py b/tests/utils/test_autolog_hparams.py index 7804d7bd80..773fbd2299 100644 --- a/tests/utils/test_autolog_hparams.py +++ b/tests/utils/test_autolog_hparams.py @@ -164,6 +164,7 @@ def test_extract_hparams_trainer(): 'save_overwrite': False, 'save_interval': '1ep', 'save_weights_only': False, + 'save_ignore_keys': None, 'save_num_checkpoints_to_keep': -1, 'save_metrics': False,