diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 01cd0fcc9b..246c1b393b 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1431,23 +1431,26 @@ def __init__( assert latest_remote_file_name is not None if self.state.fsdp_sharded_state_dict_enabled: ar_object_store = maybe_create_object_store_from_uri(save_folder) - # Symlink is on object store. + # Symlink is on object store if ar_object_store is not None: - with tempfile.TemporaryDirectory() as temp_dir: - local_symlink_file = str(Path(temp_dir) / Path('autoresume.symlink')) - formatted_latest_remote_file_name = format_name_with_dist(latest_remote_file_name, - self.state.run_name) + '.symlink' - rank0_formatted_latest_remote_file_name = dist.all_gather_object( - formatted_latest_remote_file_name)[0] - try: - ar_object_store.download_object(rank0_formatted_latest_remote_file_name, local_symlink_file) - with open(local_symlink_file, 'r') as f: - real_path = f.read() - log.debug(f'Read path {real_path} from symlink file') - autoresume_checkpoint_path = ar_object_store.get_uri(real_path) - except FileNotFoundError: - autoresume_checkpoint_path = None - # Symlink is local. + autoresume_checkpoint_path = None + if dist.get_global_rank() == 0: + with tempfile.TemporaryDirectory() as temp_dir: + local_symlink_file = str(Path(temp_dir) / Path('autoresume.symlink')) + symlink_file_name = format_name_with_dist(latest_remote_file_name, + self.state.run_name) + '.symlink' + try: + ar_object_store.download_object(symlink_file_name, local_symlink_file) + with open(local_symlink_file, 'r') as f: + real_path = f.read() + log.debug(f'Read path {real_path} from symlink file') + autoresume_checkpoint_path = ar_object_store.get_uri(real_path) + except FileNotFoundError: + pass + autoresume_path_list = [autoresume_checkpoint_path] + dist.broadcast_object_list(autoresume_path_list) + autoresume_checkpoint_path = autoresume_path_list[0] + # Symlink is local else: save_latest_filename = format_name_with_dist(save_latest_filename, self.state.run_name) rank0_save_latest_filename = dist.all_gather_object(save_latest_filename)[0] @@ -1460,7 +1463,7 @@ def __init__( else: autoresume_checkpoint_path = None - # Standard non-elastic codepath for autoresume. + # Standard non-elastic codepath for autoresume else: autoresume_checkpoint_path = self._get_autoresume_checkpoint( save_folder=save_folder,