From 75c9ef37c93a8716f2ea0b3908bad90be013e3a5 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 30 Sep 2024 13:38:09 -0700 Subject: [PATCH] reapply --- llmfoundry/data/finetuning/dataloader.py | 56 +++++++++--------------- llmfoundry/data/finetuning/tasks.py | 2 +- llmfoundry/models/hf/hf_base.py | 2 +- llmfoundry/utils/builders.py | 2 +- 4 files changed, 23 insertions(+), 39 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 69051a2d51..612b8d6385 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -575,42 +575,26 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: # Since we don't know exactly what the extension will be, since it is one of a list # use a signal file to wait for instead of the desired file - signal_file_path = os.path.join( - finetune_dir, - f'.node_{dist.get_node_rank()}_local_rank0_completed', - ) - if dist.get_local_rank() == 0: - try: - get_file(path=name, destination=destination, overwrite=True) - except FileNotFoundError as e: - if extension == SUPPORTED_EXTENSIONS[-1]: - files_searched = [ - f'{name}/{split}{ext}' for ext in SUPPORTED_EXTENSIONS - ] - raise FileNotFoundError( - f'Could not find a file with any of ' + \ - f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \ - f'at {files_searched}', - ) from e - else: - log.debug( - f'Could not find {name}, looking for another extension', - ) - continue - - os.makedirs(os.path.dirname(signal_file_path), exist_ok=True) - with open(signal_file_path, 'wb') as f: - f.write(b'local_rank0_completed_download') - - # Avoid the collective call until the local rank zero has finished trying to download the dataset - # so that we don't timeout for large downloads. This syncs all processes on the node - with dist.local_rank_zero_download_and_wait(signal_file_path): - # Then, wait to ensure every node has finished trying to download the dataset - dist.barrier() - - # clean up signal file - if dist.get_local_rank() == 0: - os.remove(signal_file_path) + with dist.busy_wait_for_local_rank_zero(finetune_dir): + if dist.get_local_rank() == 0: + try: + get_file(path=name, destination=destination, overwrite=True) + except FileNotFoundError as e: + if extension == SUPPORTED_EXTENSIONS[-1]: + files_searched = [ + f'{name}/{split}{ext}' + for ext in SUPPORTED_EXTENSIONS + ] + raise FileNotFoundError( + f'Could not find a file with any of ' + \ + f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \ + f'at {files_searched}', + ) from e + else: + log.debug( + f'Could not find {name}, looking for another extension', + ) + continue dist.barrier() break return finetune_dir diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index e8f6484ef2..e099ffe14a 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -877,7 +877,7 @@ def build_from_hf( if tokenizer is None: raise ValueError('A tokenizer must be provided.') - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed' + signal_file_path = dist.get_node_signal_file_name() # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks diff --git a/llmfoundry/models/hf/hf_base.py b/llmfoundry/models/hf/hf_base.py index d193e1067f..2ec9bbaa98 100644 --- a/llmfoundry/models/hf/hf_base.py +++ b/llmfoundry/models/hf/hf_base.py @@ -356,7 +356,7 @@ def build_inner_model( f'init_device="{init_device}" must be either "cpu" or "meta".', ) - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed' + signal_file_path = dist.get_node_signal_file_name() if dist.get_local_rank() == 0: with open(signal_file_path, 'wb') as f: f.write(b'local_rank0_completed_download') diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 687b21b46d..ae04b68ee5 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -494,7 +494,7 @@ def build_tokenizer( os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' os.environ['TOKENIZERS_PARALLELISM'] = 'false' - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup' + signal_file_path = dist.get_node_signal_file_name() if dist.is_available() and dist.is_initialized( ) and dist.get_world_size() > 1: