Skip to content

Commit

Permalink
Download symlink once (mosaicml#3043)
Browse files Browse the repository at this point in the history
* download symlink once

* lint
  • Loading branch information
mvpatel2000 authored Feb 21, 2024
1 parent 7855cc7 commit c2a2b7b
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,23 +1431,26 @@ def __init__(
assert latest_remote_file_name is not None
if self.state.fsdp_sharded_state_dict_enabled:
ar_object_store = maybe_create_object_store_from_uri(save_folder)
- # Symlink is on object store.
+ # Symlink is on object store
  if ar_object_store is not None:
- with tempfile.TemporaryDirectory() as temp_dir:
- local_symlink_file = str(Path(temp_dir) / Path('autoresume.symlink'))
- formatted_latest_remote_file_name = format_name_with_dist(latest_remote_file_name,
- self.state.run_name) + '.symlink'
- rank0_formatted_latest_remote_file_name = dist.all_gather_object(
- formatted_latest_remote_file_name)[0]
- try:
- ar_object_store.download_object(rank0_formatted_latest_remote_file_name, local_symlink_file)
- with open(local_symlink_file, 'r') as f:
- real_path = f.read()
- log.debug(f'Read path {real_path} from symlink file')
- autoresume_checkpoint_path = ar_object_store.get_uri(real_path)
- except FileNotFoundError:
- autoresume_checkpoint_path = None
- # Symlink is local.
+ autoresume_checkpoint_path = None
+ if dist.get_global_rank() == 0:
+ with tempfile.TemporaryDirectory() as temp_dir:
+ local_symlink_file = str(Path(temp_dir) / Path('autoresume.symlink'))
+ symlink_file_name = format_name_with_dist(latest_remote_file_name,
+ self.state.run_name) + '.symlink'
+ try:
+ ar_object_store.download_object(symlink_file_name, local_symlink_file)
+ with open(local_symlink_file, 'r') as f:
+ real_path = f.read()
+ log.debug(f'Read path {real_path} from symlink file')
+ autoresume_checkpoint_path = ar_object_store.get_uri(real_path)
+ except FileNotFoundError:
+ pass
+ autoresume_path_list = [autoresume_checkpoint_path]
+ dist.broadcast_object_list(autoresume_path_list)
+ autoresume_checkpoint_path = autoresume_path_list[0]
+ # Symlink is local
  else:
  save_latest_filename = format_name_with_dist(save_latest_filename, self.state.run_name)
  rank0_save_latest_filename = dist.all_gather_object(save_latest_filename)[0]
Expand All @@ -1460,7 +1463,7 @@ def __init__(
  else:
  autoresume_checkpoint_path = None

- # Standard non-elastic codepath for autoresume.
+ # Standard non-elastic codepath for autoresume
else:
autoresume_checkpoint_path = self._get_autoresume_checkpoint(
save_folder=save_folder,
Expand Down

0 comments on commit c2a2b7b

Please sign in to comment.