Combine filters into one, to avoid datasets error (#729)
dakinggg authored Nov 10, 2023
1 parent d2ddb83 commit 2f91a64
Showing 1 changed file with 17 additions and 29 deletions.
46 changes: 17 additions & 29 deletions llmfoundry/data/finetuning/tasks.py
@@ -363,43 +363,31 @@ def dataset_mapper(example: Dict):
             desc='Tokenizing dataset',
         )
 
-        def filter_long_prompts(example: Dict) -> bool:
-            return len(example['input_ids']) < max_seq_len
+        pad_token_id = tokenizer.pad_token_id
 
-        prompt_length_filtered_dataset = tokenized_dataset.filter(
-            filter_long_prompts,
+        def filter_long_or_empty_examples(example: Dict) -> bool:
+            less_than_max_seq_len = len(example['input_ids']) < max_seq_len
+            non_empty_input = len(example['input_ids']) > 0
+            non_empty_labels = len(example['labels']) > 0
+            non_padding_response = any(
+                token_id != pad_token_id for token_id in example['labels'])
+            return (less_than_max_seq_len and non_empty_input and
+                    non_empty_labels and non_padding_response)
+
+        filtered_dataset = tokenized_dataset.filter(
+            filter_long_or_empty_examples,
             num_proc=num_cpus_to_use,
             desc='Filtering out long prompts',
         )
 
-        examples_removed = len(tokenized_dataset) - len(
-            prompt_length_filtered_dataset)
+        examples_removed = len(tokenized_dataset) - len(filtered_dataset)
         if examples_removed > 0:
             warnings.warn(
-                f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
+                f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+                +
+                'the prompt or response was empty, or the response was all padding tokens.'
             )
 
-        pad_token_id = tokenizer.pad_token_id
-
-        def filter_empty_examples(example: Dict) -> bool:
-            return len(example['input_ids']) > 0 and len(
-                example['labels']) > 0 and any(
-                    token_id != pad_token_id for token_id in example['labels'])
-
-        empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
-            filter_empty_examples,
-            num_proc=num_cpus_to_use,
-            desc='Filtering out empty examples')
-
         log.debug('Done tokenizing and filtering examples.')
 
-        empty_examples_removed = len(prompt_length_filtered_dataset) - len(
-            empty_examples_dropped_dataset)
-        if empty_examples_removed > 0:
-            warnings.warn(
-                f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
-                + 'or the response was only padding tokens.')
-
         # Now local rank 0 indicates to the other ranks that it is done
         if dist.get_local_rank() == 0:
             log.debug('Local rank 0 finished data prep')
@@ -414,7 +402,7 @@ def filter_empty_examples(example: Dict) -> bool:
             os.remove(signal_file_path)
 
         log.debug('All ranks finished data prep')
-        return empty_examples_dropped_dataset
+        return filtered_dataset
 
     def build_from_streaming(self, *args: Any,
                              **kwargs: Any) -> StreamingFinetuningDataset:
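For context, here is a minimal, self-contained sketch of the single-pass filter pattern this commit adopts, using the Hugging Face datasets API. The toy rows and the max_seq_len and pad_token_id values below are illustrative assumptions, not values from the repo:

    # A runnable toy demonstration of the combined filter; the data and
    # constants are illustrative assumptions, not taken from llm-foundry.
    from typing import Dict

    from datasets import Dataset

    max_seq_len = 8    # hypothetical sequence-length cap
    pad_token_id = 0   # hypothetical pad token id

    toy = Dataset.from_dict({
        'input_ids': [[5, 6, 7], [], [1, 2, 3, 4, 5, 6, 7, 8, 9]],
        'labels': [[5, 6], [7], [0]],
    })

    def filter_long_or_empty_examples(example: Dict) -> bool:
        # One pass checks all four conditions: length cap, non-empty
        # prompt, non-empty response, and at least one non-pad label.
        less_than_max_seq_len = len(example['input_ids']) < max_seq_len
        non_empty_input = len(example['input_ids']) > 0
        non_empty_labels = len(example['labels']) > 0
        non_padding_response = any(
            token_id != pad_token_id for token_id in example['labels'])
        return (less_than_max_seq_len and non_empty_input and
                non_empty_labels and non_padding_response)

    filtered = toy.filter(filter_long_or_empty_examples)
    print(len(toy) - len(filtered))  # -> 2 (one empty prompt, one too long)

Folding all four conditions into one predicate means datasets materializes a single filtered dataset rather than chaining two .filter passes over intermediate results, which appears to be what triggered the datasets error referenced in the commit title.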
