From 2f91a64a348b0f745ab83f66acdad2a07082cc14 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Thu, 9 Nov 2023 16:40:26 -0800
Subject: [PATCH] Combine filters into one, to avoid datasets error (#729)

---
 llmfoundry/data/finetuning/tasks.py | 46 +++++++++++------------------
 1 file changed, 17 insertions(+), 29 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 67a27ac239..6ba6ad96c8 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -363,43 +363,31 @@ def dataset_mapper(example: Dict):
             desc='Tokenizing dataset',
         )
 
-        def filter_long_prompts(example: Dict) -> bool:
-            return len(example['input_ids']) < max_seq_len
+        pad_token_id = tokenizer.pad_token_id
 
-        prompt_length_filtered_dataset = tokenized_dataset.filter(
-            filter_long_prompts,
+        def filter_long_or_empty_examples(example: Dict) -> bool:
+            less_than_max_seq_len = len(example['input_ids']) < max_seq_len
+            non_empty_input = len(example['input_ids']) > 0
+            non_empty_labels = len(example['labels']) > 0
+            non_padding_response = any(
+                token_id != pad_token_id for token_id in example['labels'])
+            return (less_than_max_seq_len and non_empty_input and
+                    non_empty_labels and non_padding_response)
+
+        filtered_dataset = tokenized_dataset.filter(
+            filter_long_or_empty_examples,
             num_proc=num_cpus_to_use,
             desc='Filtering out long prompts',
         )
 
-        examples_removed = len(tokenized_dataset) - len(
-            prompt_length_filtered_dataset)
+        examples_removed = len(tokenized_dataset) - len(filtered_dataset)
         if examples_removed > 0:
             warnings.warn(
-                f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
+                f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+                +
+                'the prompt or response was empty, or the response was all padding tokens.'
             )
 
-        pad_token_id = tokenizer.pad_token_id
-
-        def filter_empty_examples(example: Dict) -> bool:
-            return len(example['input_ids']) > 0 and len(
-                example['labels']) > 0 and any(
-                    token_id != pad_token_id for token_id in example['labels'])
-
-        empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
-            filter_empty_examples,
-            num_proc=num_cpus_to_use,
-            desc='Filtering out empty examples')
-
-        log.debug('Done tokenizing and filtering examples.')
-
-        empty_examples_removed = len(prompt_length_filtered_dataset) - len(
-            empty_examples_dropped_dataset)
-        if empty_examples_removed > 0:
-            warnings.warn(
-                f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
-                + 'or the response was only padding tokens.')
-
         # Now local rank 0 indicates to the other ranks that it is done
         if dist.get_local_rank() == 0:
             log.debug('Local rank 0 finished data prep')
@@ -414,7 +402,7 @@ def filter_empty_examples(example: Dict) -> bool:
                 os.remove(signal_file_path)
 
         log.debug('All ranks finished data prep')
-        return empty_examples_dropped_dataset
+        return filtered_dataset
 
     def build_from_streaming(self, *args: Any,
                              **kwargs: Any) -> StreamingFinetuningDataset:
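
Note (reviewer illustration, not part of the patch): the change folds the two
predicates into a single Dataset.filter pass, so the intermediate dataset from
the first filter is never materialized (this is the `datasets` error the
subject line refers to). Below is a minimal, self-contained sketch of the
combined filter against the Hugging Face `datasets` library; the toy token
ids, `max_seq_len = 4`, and `pad_token_id = 0` are hypothetical values chosen
only to exercise each drop condition, not the real values from tasks.py.

    from datasets import Dataset

    max_seq_len = 4    # hypothetical; the real value comes from the train config
    pad_token_id = 0   # hypothetical; the real value is tokenizer.pad_token_id

    # Four toy examples: one valid, one with input longer than max_seq_len,
    # one with empty input_ids, and one whose labels are entirely padding.
    dataset = Dataset.from_dict({
        'input_ids': [[1, 2, 3], [1, 2, 3, 4, 5], [], [7, 8]],
        'labels': [[4, 5], [6], [9], [pad_token_id, pad_token_id]],
    })

    def filter_long_or_empty_examples(example: dict) -> bool:
        less_than_max_seq_len = len(example['input_ids']) < max_seq_len
        non_empty_input = len(example['input_ids']) > 0
        non_empty_labels = len(example['labels']) > 0
        non_padding_response = any(
            token_id != pad_token_id for token_id in example['labels'])
        return (less_than_max_seq_len and non_empty_input and
                non_empty_labels and non_padding_response)

    filtered_dataset = dataset.filter(filter_long_or_empty_examples)
    print(len(dataset), '->', len(filtered_dataset))  # 4 -> 1: only the first survives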