From c2f5742d5d15e26b510bead331b35a82258b6c44 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Mon, 6 Nov 2023 14:01:42 -0800
Subject: [PATCH] Run HF dataset processing on local rank 0 first (#716)

---
 llmfoundry/data/finetuning/tasks.py | 40 ++++++++++++++++++++++++++---
 1 file changed, 37 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index edbfcc28c7..3673a48217 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -38,6 +38,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import datasets as hf_datasets
+from composer.utils import dist
 from omegaconf import DictConfig
 from streaming import StreamingDataset
 from transformers import PreTrainedTokenizerBase
@@ -332,6 +333,16 @@ def build_from_hf(
             preprocessing_fn = self.get_preprocessing_fn_from_str(
                 proto_preprocessing_fn, dataset_name)
 
+        signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'
+
+        # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
+        # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks
+        # can just read them.
+        if dist.get_local_rank() != 0:
+            log.debug('Waiting for local_rank 0 to finish data prep')
+            with dist.local_rank_zero_download_and_wait(signal_file_path):
+                pass
+
         dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)
 
         def dataset_mapper(example: Dict):
@@ -340,7 +351,8 @@ def dataset_mapper(example: Dict):
             return _tokenize_formatted_example(example, tokenizer)
 
         detected_cpu_count = os.cpu_count() or 1
-        num_cpus_to_use = max(1, detected_cpu_count - 4)
+        detected_cpus_with_margin = detected_cpu_count - 8
+        num_cpus_to_use = max(1, detected_cpus_with_margin)
 
         columns_to_remove = list(dataset[0].keys())
         tokenized_dataset = dataset.map(
@@ -348,10 +360,12 @@ def dataset_mapper(example: Dict):
             batched=False,
             remove_columns=columns_to_remove,
             num_proc=num_cpus_to_use,
+            desc='Tokenizing dataset',
         )
         prompt_length_filtered_dataset = tokenized_dataset.filter(
             lambda example: len(example['input_ids']) < max_seq_len,
             num_proc=num_cpus_to_use,
+            desc='Filtering out long prompts',
         )
 
         examples_removed = len(tokenized_dataset) - len(
@@ -361,10 +375,16 @@ def dataset_mapper(example: Dict):
                 f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
             )
 
+        pad_token_id = tokenizer.pad_token_id
         empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
             lambda example: len(example['input_ids']) > 0 and len(example[
-                'labels']) > 0 and any(token_id != tokenizer.pad_token_id
-                                       for token_id in example['labels']))
+                'labels']) > 0 and any(token_id != pad_token_id
+                                       for token_id in example['labels']),
+            num_proc=num_cpus_to_use,
+            desc='Filtering out empty examples')
+
+        log.debug('Done tokenizing and filtering examples.')
+
         empty_examples_removed = len(prompt_length_filtered_dataset) - len(
             empty_examples_dropped_dataset)
         if empty_examples_removed > 0:
@@ -372,6 +392,20 @@ def dataset_mapper(example: Dict):
                 f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
                 + 'or the response was only padding tokens.')
 
+        # Now local rank 0 indicates to the other ranks that it is done
+        if dist.get_local_rank() == 0:
+            log.debug('Local rank 0 finished data prep')
+            with open(signal_file_path, 'wb') as f:
+                f.write(b'local_rank0_completed_data_prep')
+
+        # All ranks sync up at this barrier, having completed data processing
+        dist.barrier()
+
+        # Last, local rank 0 cleans up the signal file
+        if dist.get_local_rank() == 0:
+            os.remove(signal_file_path)
+
+        log.debug('All ranks finished data prep')
         return empty_examples_dropped_dataset
 
     def build_from_streaming(self, *args: Any,
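
Note on the pattern: the signal-file handshake in this patch is a general way to let
local rank 0 populate a shared on-disk cache before the other ranks on the same node
read from it. A minimal, self-contained sketch of that handshake follows. It assumes
composer's dist helpers are initialized (e.g. when running under the composer
launcher), and `prepare_data` is a hypothetical stand-in for any cache-filling step
such as `hf_datasets.load_dataset`.

    import os

    from composer.utils import dist


    def local_rank_zero_first(prepare_data, signal_file_path=None):
        """Run `prepare_data` on local rank 0 first; other ranks wait, then run it warm."""
        if signal_file_path is None:
            signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'

        # Ranks other than local rank 0 block until the signal file appears.
        if dist.get_local_rank() != 0:
            with dist.local_rank_zero_download_and_wait(signal_file_path):
                pass

        # Every rank executes the step; the ranks that waited hit the warm cache.
        result = prepare_data()

        # Local rank 0 signals completion, all ranks sync, then rank 0 cleans up.
        if dist.get_local_rank() == 0:
            with open(signal_file_path, 'wb') as f:
                f.write(b'local_rank0_completed_data_prep')
        dist.barrier()
        if dist.get_local_rank() == 0:
            os.remove(signal_file_path)

        return result

Because the file name includes the node rank, each node coordinates independently even
on a shared filesystem; the final dist.barrier() is what brings all ranks back in step.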