From 485f7e539d9f34426d1ca6395701ef48d26c67c5 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 24 Oct 2023 22:30:36 -0700
Subject: [PATCH] Torch 2.1 Resumption Support (#2665)

* v1

* filter cuda rng

* remove unused import

* remove prints

* assert
---
 composer/core/state.py                |  1 +
 composer/trainer/mosaic_fsdp_utils.py |  1 -
 composer/trainer/trainer.py           |  2 +-
 composer/utils/reproducibility.py     | 10 +++++++++-
 tests/test_state.py                   |  1 +
 tests/trainer/test_fsdp_checkpoint.py |  1 +
 6 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/composer/core/state.py b/composer/core/state.py
index 69cd1df8af..dbdba40170 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -823,6 +823,7 @@ def _get_state_metadata(self) -> Dict[str, Any]:
         """
         metadata_dict = {}
         metadata_dict['composer_env_info'] = get_composer_env_dict()
+        metadata_dict['torch_version'] = torch.__version__
         metadata_dict['device'] = self.device.name
         metadata_dict['precision'] = self.precision.value
         metadata_dict['world_size'] = dist.get_world_size()
diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index 40ca0ae160..08061978e1 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -27,7 +27,6 @@
 from torch.distributed.utils import _replace_by_prefix
 
 from composer.core import Precision
-from composer.utils import dist, using_torch_2
 
 if TYPE_CHECKING:
     if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse(
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 78c3f21bd4..2063b8f341 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2001,7 +2001,7 @@ def _train_loop(self) -> None:
             self._spin_dataloaders_to_cur_epoch()
 
             if self.state.timestamp.batch_in_epoch == 0 and self._rng_state is not None:
-                # only restore the rng state here if the step in the current epoch is zero.
+                # Only restore the rng state here if the step in the current epoch is zero.
                 reproducibility.load_rng_state(self._rng_state)
                 self._rng_state = None
 
diff --git a/composer/utils/reproducibility.py b/composer/utils/reproducibility.py
index e6555616ac..0895b530d9 100644
--- a/composer/utils/reproducibility.py
+++ b/composer/utils/reproducibility.py
@@ -215,7 +215,15 @@ def load_rng_state(rng_state_dicts: List[Dict[str, Any]]):
         log.debug('Restoring the RNG state')
 
         if is_cuda_available and has_cuda_rng_state:
-            torch.cuda.set_rng_state(rng_state_dict['cuda'])
+            try:
+                torch.cuda.set_rng_state(rng_state_dict['cuda'])
+            except RuntimeError as e:
+                if 'RNG state is wrong size' in str(e):
+                    warnings.warn('The CUDA RNG state could not be loaded from the checkpoint, '
+                                  'likely because a different version of torch was used to save the '
+                                  'checkpoint. Skipping loading the CUDA RNG state.')
+                else:
+                    raise e
 
         if is_cuda_available and not has_cuda_rng_state:
             warnings.warn(
diff --git a/tests/test_state.py b/tests/test_state.py
index 2660cc25ab..1734e08227 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -150,6 +150,7 @@ def test_composer_metadata_in_state_dict(tmp_path, request: pytest.FixtureReques
 
     assert expected_env_info_keys == actual_env_info_keys
     assert loaded_state_dict['metadata']['composer_env_info']['composer_version'] == composer.__version__
+    assert loaded_state_dict['metadata']['torch_version'] == torch.__version__
     assert loaded_state_dict['metadata']['device'] == 'cpu'
     assert loaded_state_dict['metadata']['precision'] == 'amp_fp16'
     assert loaded_state_dict['metadata']['world_size'] == 1
diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index 50e62e9381..8d3568ca05 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -338,6 +338,7 @@ def test_fsdp_mixed_with_sync(world_size, tmp_path: pathlib.Path, sync_module_st
                      r'ignore:MosaicMLLogger is not in the state_dict. Its state will not be restored.:UserWarning'))
 ])
 @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:.*The CUDA RNG state could not be loaded.*:UserWarning')
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
                     reason='requires PyTorch 1.13 or higher')
 def test_fsdp_load_old_checkpoint(world_size, tmp_path: pathlib.Path, precision: str, sharding_strategy: str,
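
The heart of this patch is the try/except around torch.cuda.set_rng_state in
composer/utils/reproducibility.py: a CUDA RNG state saved under a different
torch version can fail to load with 'RNG state is wrong size', and resumption
now warns and continues instead of crashing. The companion state.py change
records torch.__version__ in checkpoint metadata so such mismatches can be
diagnosed later. Below is a minimal, self-contained sketch of that fallback;
the helper name try_load_cuda_rng_state and the deliberately wrong-size state
tensor are illustrative assumptions, not composer APIs.

# Standalone sketch (not part of the patch) of the fallback added above.
import warnings

import torch


def try_load_cuda_rng_state(cuda_rng_state: torch.Tensor) -> None:
    """Best-effort CUDA RNG restore, mirroring the patched load_rng_state."""
    try:
        torch.cuda.set_rng_state(cuda_rng_state)
    except RuntimeError as e:
        # A state tensor saved under a different torch version can have the
        # wrong length; skip it so checkpoint resumption can proceed.
        if 'RNG state is wrong size' in str(e):
            warnings.warn('The CUDA RNG state could not be loaded from the checkpoint, '
                          'likely because a different version of torch was used to save the '
                          'checkpoint. Skipping loading the CUDA RNG state.')
        else:
            raise e


if torch.cuda.is_available():
    # Initialize CUDA so the size check fires immediately rather than at lazy init.
    torch.cuda.init()
    # A uint8 tensor of the wrong length stands in for a stale checkpointed state.
    try_load_cuda_rng_state(torch.zeros(16, dtype=torch.uint8))  # warns, does not raise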