Torch 2.1 Resumption Support (mosaicml#2665)
* v1
* filter cuda rng
* remove unused import
* remove prints
* assert
mvpatel2000 authored Oct 25, 2023
1 parent 151fb45 commit 485f7e5
Showing 6 changed files with 13 additions and 3 deletions.
composer/core/state.py (1 addition, 0 deletions)

@@ -823,6 +823,7 @@ def _get_state_metadata(self) -> Dict[str, Any]:
"""
metadata_dict = {}
metadata_dict['composer_env_info'] = get_composer_env_dict()
metadata_dict['torch_version'] = torch.__version__
metadata_dict['device'] = self.device.name
metadata_dict['precision'] = self.precision.value
metadata_dict['world_size'] = dist.get_world_size()
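The new torch_version field records which torch produced a checkpoint. A minimal loader-side sketch of how that metadata could be consumed (this check is illustrative, not part of this commit; the helper name is ours):

import warnings

import torch


def warn_on_torch_mismatch(metadata: dict) -> None:
    # 'torch_version' is the key written by _get_state_metadata above.
    saved = metadata.get('torch_version')
    if saved is not None and saved != torch.__version__:
        warnings.warn(f'Checkpoint was saved with torch {saved} but is being '
                      f'loaded with torch {torch.__version__}; torch internals '
                      'such as the CUDA RNG state may not restore exactly.')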
composer/trainer/mosaic_fsdp_utils.py (0 additions, 1 deletion)

@@ -27,7 +27,6 @@
 from torch.distributed.utils import _replace_by_prefix

 from composer.core import Precision
-from composer.utils import dist, using_torch_2

 if TYPE_CHECKING:
     if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse(
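For context, the surrounding TYPE_CHECKING block gates imports on the installed torch version. A minimal sketch of that version-gating idiom, assuming only torch and packaging (the 2.1.0 cutoff here is illustrative):

from packaging import version

import torch

# version.parse also orders pre-releases correctly, e.g.
# '2.1.0.dev20230822' compares as less than '2.1.0'.
if version.parse(torch.__version__) >= version.parse('2.1.0'):
    pass  # torch >= 2.1: use the newer code path
else:
    pass  # torch < 2.1: keep the legacy code path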
composer/trainer/trainer.py (1 addition, 1 deletion)

@@ -2001,7 +2001,7 @@ def _train_loop(self) -> None:
         self._spin_dataloaders_to_cur_epoch()

         if self.state.timestamp.batch_in_epoch == 0 and self._rng_state is not None:
-            # only restore the rng state here if the step in the current epoch is zero.
+            # Only restore the rng state here if the step in the current epoch is zero.
             reproducibility.load_rng_state(self._rng_state)
             self._rng_state = None
composer/utils/reproducibility.py (9 additions, 1 deletion)

@@ -215,7 +215,15 @@ def load_rng_state(rng_state_dicts: List[Dict[str, Any]]):
     log.debug('Restoring the RNG state')

     if is_cuda_available and has_cuda_rng_state:
-        torch.cuda.set_rng_state(rng_state_dict['cuda'])
+        try:
+            torch.cuda.set_rng_state(rng_state_dict['cuda'])
+        except RuntimeError as e:
+            if 'RNG state is wrong size' in str(e):
+                warnings.warn('The CUDA RNG state could not be loaded from the checkpoint, '
+                              'likely because a different version of torch was used to save the '
+                              'checkpoint. Skipping loading the CUDA RNG state.')
+            else:
+                raise e

     if is_cuda_available and not has_cuda_rng_state:
         warnings.warn(
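The guard above matters because the size of the CUDA RNG state tensor can differ across torch releases, so a checkpoint written under one version may fail to restore under another; the commit title suggests the 2.0 to 2.1 transition as the motivating case. A self-contained sketch of the same pattern outside Composer (the helper name is ours):

import warnings

import torch


def restore_cuda_rng_state(saved_state: torch.Tensor) -> None:
    """Restore a saved CUDA RNG state, skipping it on a size mismatch."""
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.set_rng_state(saved_state)
    except RuntimeError as e:
        # torch raises 'RNG state is wrong size' when the saved ByteTensor
        # does not match the layout the running torch version expects.
        if 'RNG state is wrong size' in str(e):
            warnings.warn('Skipping CUDA RNG restore: saved state size does '
                          'not match this torch version.')
        else:
            raise


if torch.cuda.is_available():
    restore_cuda_rng_state(torch.cuda.get_rng_state())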
tests/test_state.py (1 addition, 0 deletions)

@@ -150,6 +150,7 @@ def test_composer_metadata_in_state_dict(tmp_path, request: pytest.FixtureRequest):
     assert expected_env_info_keys == actual_env_info_keys
     assert loaded_state_dict['metadata']['composer_env_info']['composer_version'] == composer.__version__

+    assert loaded_state_dict['metadata']['torch_version'] == torch.__version__
     assert loaded_state_dict['metadata']['device'] == 'cpu'
     assert loaded_state_dict['metadata']['precision'] == 'amp_fp16'
     assert loaded_state_dict['metadata']['world_size'] == 1
tests/trainer/test_fsdp_checkpoint.py (1 addition, 0 deletions)

@@ -338,6 +338,7 @@ def test_fsdp_mixed_with_sync(world_size, tmp_path: pathlib.Path, sync_module_states
                      r'ignore:MosaicMLLogger is not in the state_dict. Its state will not be restored.:UserWarning'))
 ])
 @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:.*The CUDA RNG state could not be loaded.*:UserWarning')
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
                     reason='requires PyTorch 1.13 or higher')
 def test_fsdp_load_old_checkpoint(world_size, tmp_path: pathlib.Path, precision: str, sharding_strategy: str,
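The new filterwarnings marker suppresses the warning added in composer/utils/reproducibility.py when old checkpoints are resumed under a newer torch. A hypothetical companion test (not part of this commit) could check that a wrong-sized state really does raise the RuntimeError the loader catches:

import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires CUDA')
def test_wrong_size_cuda_rng_state_raises():
    # Deliberately undersized; real CUDA RNG states are far larger.
    bad_state = torch.zeros(3, dtype=torch.uint8)
    with pytest.raises(RuntimeError, match='RNG state is wrong size'):
        torch.cuda.set_rng_state(bad_state)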
