Run HF dataset processing on local rank 0 first (#716)
dakinggg authored Nov 6, 2023
1 parent ffb58f1 commit c2f5742
Showing 1 changed file with 37 additions and 3 deletions.
llmfoundry/data/finetuning/tasks.py
@@ -38,6 +38,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import datasets as hf_datasets
+from composer.utils import dist
 from omegaconf import DictConfig
 from streaming import StreamingDataset
 from transformers import PreTrainedTokenizerBase
@@ -332,6 +333,16 @@ def build_from_hf(
             preprocessing_fn = self.get_preprocessing_fn_from_str(
                 proto_preprocessing_fn, dataset_name)
 
+        signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'
+
+        # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
+        # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks
+        # can just read them.
+        if dist.get_local_rank() != 0:
+            log.debug('Waiting for local_rank 0 to finish data prep')
+            with dist.local_rank_zero_download_and_wait(signal_file_path):
+                pass
+
         dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)
 
         def dataset_mapper(example: Dict):
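
The hunk above is the first half of a signal-file handshake: every local rank other than 0 blocks before load_dataset, so only local rank 0 downloads and tokenizes, and the remaining ranks later read the populated HF cache. Below is a minimal, self-contained sketch of the same pattern, with a hypothetical prepare_fn standing in for the dataset prep; the dist helpers are Composer's own (the diff itself uses them), and the code assumes it runs under a distributed launcher.

import os

from composer.utils import dist


def prep_on_local_rank_zero_first(prepare_fn, signal_file_path: str):
    # Ranks other than local rank 0 block here until the signal file appears.
    if dist.get_local_rank() != 0:
        with dist.local_rank_zero_download_and_wait(signal_file_path):
            pass

    # Every rank runs the same prep; ranks != 0 hit the cache that
    # local rank 0 already populated.
    result = prepare_fn()

    # Local rank 0 signals completion, all ranks sync, then rank 0 cleans up.
    if dist.get_local_rank() == 0:
        with open(signal_file_path, 'wb') as f:
            f.write(b'done')
    dist.barrier()
    if dist.get_local_rank() == 0:
        os.remove(signal_file_path)
    return result
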
@@ -340,18 +351,21 @@ def dataset_mapper(example: Dict):
             return _tokenize_formatted_example(example, tokenizer)
 
         detected_cpu_count = os.cpu_count() or 1
-        num_cpus_to_use = max(1, detected_cpu_count - 4)
+        detected_cpus_with_margin = detected_cpu_count - 8
+        num_cpus_to_use = max(1, detected_cpus_with_margin)
 
         columns_to_remove = list(dataset[0].keys())
         tokenized_dataset = dataset.map(
             dataset_mapper,
             batched=False,
             remove_columns=columns_to_remove,
             num_proc=num_cpus_to_use,
+            desc='Tokenizing dataset',
         )
         prompt_length_filtered_dataset = tokenized_dataset.filter(
             lambda example: len(example['input_ids']) < max_seq_len,
             num_proc=num_cpus_to_use,
+            desc='Filtering out long prompts',
         )
 
         examples_removed = len(tokenized_dataset) - len(
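
This hunk widens the CPU headroom left for other work (8 spare cores instead of 4) and labels the multiprocess progress bars via the desc argument that datasets.map and datasets.filter accept. A runnable toy example of the same knobs; the dataset and mapper here are made up purely for illustration.

import os

from datasets import Dataset

ds = Dataset.from_dict({'text': ['a', 'bb', 'ccc']})

detected_cpu_count = os.cpu_count() or 1
num_cpus_to_use = max(1, detected_cpu_count - 8)  # leave headroom, as in the commit

ds = ds.map(
    lambda example: {'n_chars': len(example['text'])},
    batched=False,
    num_proc=min(num_cpus_to_use, len(ds)),  # toy dataset; don't exceed the row count
    desc='Counting characters',  # shown alongside the progress bar
)
print(ds[0])  # {'text': 'a', 'n_chars': 1}
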
@@ -361,17 +375,37 @@ def dataset_mapper(example: Dict):
                 f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
             )
 
+        pad_token_id = tokenizer.pad_token_id
         empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
             lambda example: len(example['input_ids']) > 0 and len(example[
-                'labels']) > 0 and any(token_id != tokenizer.pad_token_id
-                                       for token_id in example['labels']))
+                'labels']) > 0 and any(token_id != pad_token_id
+                                       for token_id in example['labels']),
+            num_proc=num_cpus_to_use,
+            desc='Filtering out empty examples')
 
+        log.debug('Done tokenizing and filtering examples.')
+
         empty_examples_removed = len(prompt_length_filtered_dataset) - len(
             empty_examples_dropped_dataset)
         if empty_examples_removed > 0:
             warnings.warn(
                 f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
                 + 'or the response was only padding tokens.')
 
+        # Now local rank 0 indicates to the other ranks that it is done
+        if dist.get_local_rank() == 0:
+            log.debug('Local rank 0 finished data prep')
+            with open(signal_file_path, 'wb') as f:
+                f.write(b'local_rank0_completed_data_prep')
+
+        # All ranks sync up at this barrier, having completed data processing
+        dist.barrier()
+
+        # Last, local rank 0 cleans up the signal file
+        if dist.get_local_rank() == 0:
+            os.remove(signal_file_path)
+
+        log.debug('All ranks finished data prep')
         return empty_examples_dropped_dataset
 
     def build_from_streaming(self, *args: Any,
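One detail worth noting in the last hunk: tokenizer.pad_token_id is hoisted into a local pad_token_id before the filter call. A plausible reason (my reading, not stated in the commit) is that the filter is now multiprocessed, so the lambda gets serialized for each worker, and closing over a plain int is cheaper and more stable than closing over the whole tokenizer object. A self-contained sketch of the same filter on toy data, assuming pad token id 0:

from datasets import Dataset

ds = Dataset.from_dict({
    'input_ids': [[5, 6], [7, 8]],
    'labels': [[0, 0], [0, 9]],  # first row: labels are all padding
})

pad_token_id = 0  # stands in for tokenizer.pad_token_id

filtered = ds.filter(
    lambda example: len(example['input_ids']) > 0 and len(example[
        'labels']) > 0 and any(token_id != pad_token_id
                               for token_id in example['labels']),
    num_proc=1,
    desc='Filtering out empty examples')
print(len(filtered))  # 1: the all-padding row is dropped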
