diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index c6f5af15ca..b966c918c5 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -148,10 +148,25 @@ def _get_write_mode(name: str) -> str: raise ValueError(f'{name} does not end with a valid tarfile extension.') +def _is_rng_key(key: str, value: tuple) -> bool: + """Check if the key is an RNG key. + + We expect the RNG key to be of the form 'rng.{rank}.cuda|torch|python|numpy'. + This function ensures that we don't accidentally pick up other keys. + """ + starts_with_rng = key.startswith('rng') + ends_with_expected = key.endswith(('cuda', 'torch', 'python', 'numpy')) + three_parts = isinstance(value, tuple) and len(value) == 3 + if starts_with_rng and ends_with_expected and three_parts: + return True + + return False + + def _get_num_ranks_that_saved_rng(metadata: Metadata): rng_inds = [] for field_name, field_value in metadata.planner_data.items(): - if 'rng' in field_name: + if _is_rng_key(field_name, field_value): _, rng_rank_index, _ = field_value rng_inds.append(rng_rank_index) rng_inds = set(rng_inds) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 82629d245b..c2e4929535 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -35,6 +35,7 @@ _COMPOSER_STATES_FILENAME, PartialFilePath, _ensure_valid_checkpoint, + _is_rng_key, _write_checkpoint_file, glob_filter, ) @@ -130,6 +131,23 @@ def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0): assert all(keys_in) or not any(keys_in) +@pytest.mark.parametrize( + 'key,value,expected_result', + [ + ('rng.0.cuda', ('rng', '0', 'cuda'), True), + ('rng.0.torch', ('rng', '0', 'torch'), True), + ('rng.0.numpy', ('rng', '0', 'numpy'), True), + ('rng.0.python', ('rng', '0', 'python'), True), + ('rng.0', ('rng', '0'), False), + ('test.test.rng', ('test', 'test', 'rng'), False), + ('test.rng.test', ('test', 'rng', 'test'), False), + ('test.notatuple.test', 0, False), + ], +) +def test_is_rng_key(key: str, value: tuple, expected_result: bool): + assert _is_rng_key(key, value) == expected_result + + @pytest.mark.parametrize( 'remove_field_paths,filter_params', [