Skip to content

Commit

Permalink
Reapply #1389 (#1561)
Browse files: browse the repository at this point in the history
  • Loading branch information
dakinggg authored Sep 30, 2024
1 parent bdc58b3 commit 30cdd67
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 39 deletions.
56 changes: 20 additions & 36 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,42 +575,26 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:

# Since we don't know exactly what the extension will be (it is one of a list),
# use a signal file to wait for instead of the desired file.
signal_file_path = os.path.join(
finetune_dir,
f'.node_{dist.get_node_rank()}_local_rank0_completed',
)
if dist.get_local_rank() == 0:
try:
get_file(path=name, destination=destination, overwrite=True)
except FileNotFoundError as e:
if extension == SUPPORTED_EXTENSIONS[-1]:
files_searched = [
f'{name}/{split}{ext}' for ext in SUPPORTED_EXTENSIONS
]
raise FileNotFoundError(
f'Could not find a file with any of ' + \
f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
f'at {files_searched}',
) from e
else:
log.debug(
f'Could not find {name}, looking for another extension',
)
continue

os.makedirs(os.path.dirname(signal_file_path), exist_ok=True)
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')

# Avoid the collective call until the local rank zero has finished trying to download the dataset
# so that we don't time out for large downloads. This syncs all processes on the node.
with dist.local_rank_zero_download_and_wait(signal_file_path):
# Then, wait to ensure every node has finished trying to download the dataset
dist.barrier()

# clean up signal file
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
with dist.busy_wait_for_local_rank_zero(finetune_dir):
if dist.get_local_rank() == 0:
try:
get_file(path=name, destination=destination, overwrite=True)
except FileNotFoundError as e:
if extension == SUPPORTED_EXTENSIONS[-1]:
files_searched = [
f'{name}/{split}{ext}'
for ext in SUPPORTED_EXTENSIONS
]
raise FileNotFoundError(
f'Could not find a file with any of ' + \
f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
f'at {files_searched}',
) from e
else:
log.debug(
f'Could not find {name}, looking for another extension',
)
continue
dist.barrier()
break
return finetune_dir
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def build_from_hf(
if tokenizer is None:
raise ValueError('A tokenizer must be provided.')

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'
signal_file_path = dist.get_node_signal_file_name()

# Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
# Once local rank 0 is done, the datasets are all cached on disk, and all other ranks
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/models/hf/hf_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def build_inner_model(
f'init_device="{init_device}" must be either "cpu" or "meta".',
)

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
signal_file_path = dist.get_node_signal_file_name()
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def build_tokenizer(
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'
signal_file_path = dist.get_node_signal_file_name()

if dist.is_available() and dist.is_initialized(
) and dist.get_world_size() > 1:
Expand Down

0 comments on commit 30cdd67

Please sign in to comment.