From 1bc8d0aace1dd27c9021cc55a6cd4b5f2fb31085 Mon Sep 17 00:00:00 2001 From: Jeremy D <115047575+bmosaicml@users.noreply.github.com> Date: Mon, 15 Jan 2024 16:38:44 -0500 Subject: [PATCH] Add custom stopping criteria to ICL generate tasks (#2800) * add custome gen kwargs and stopping on eos token * modify test * modify test * finish * finish * finish * finish * finish pr * implement early stop * add tesT * fix bug * bug fix * add keys * diff split * fix typo * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix conditional import * add nlp metrics * remove code gen changes * fix nits --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- .../in_context_learning_evaluation.py | 167 ++++++++++-------- composer/datasets/utils.py | 60 ++++++- composer/metrics/nlp.py | 14 +- .../test_in_context_learning_datasets.py | 17 ++ tests/metrics/test_nlp_metrics.py | 6 +- 5 files changed, 184 insertions(+), 80 deletions(-) diff --git a/composer/datasets/in_context_learning_evaluation.py b/composer/datasets/in_context_learning_evaluation.py index c514356524..294bb1b2ba 100644 --- a/composer/datasets/in_context_learning_evaluation.py +++ b/composer/datasets/in_context_learning_evaluation.py @@ -16,6 +16,7 @@ from composer.core import DataSpec from composer.core.data_spec import _default_split_batch, _split_list +from composer.datasets.utils import stop_sequences_criteria from composer.utils import MissingConditionalImportError, dist, get_file if TYPE_CHECKING: @@ -139,21 +140,21 @@ def _read_dataset(self, dataset: Dataset) -> List[Dict[str, str]]: }) return result - def __init__( - self, - dataset_uri: str, - tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, - example_delimiter: str, - continuation_delimiter: str, - destination_path: str, - question_prelimiter: str, - fewshot_random_seed: int, - cot_delimiter: str = '', - ): + def __init__(self, + dataset_uri: str, + tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + prompt_string: str, + example_delimiter: str, + continuation_delimiter: str, + destination_path: str, + question_prelimiter: str, + fewshot_random_seed: int, + cot_delimiter: str = '', + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True): if tokenizer.eos_token_id is None: raise ValueError('`InContextLearningQATaskDataset` tokenizer must have non-null `eos_token_id`') try: @@ -166,6 +167,8 @@ def __init__( if dist.get_local_rank() == 0: get_file(dataset_uri, destination_path, overwrite=True) dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) + self.early_stopping_criteria = early_stopping_criteria + self.do_normalization = do_normalization self.samples = self._read_dataset(dataset) # pyright: ignore[reportGeneralTypeIssues] self.samples = strip_data(self.samples) self.tokenizer = tokenizer @@ -301,17 +304,26 @@ def collate_fn(self, data): # We will search for the answer within the portion of the model response # beginning with `cot_delimiter` cot_delimiter = sample['cot_delimiter'] - + stopping_criteria = None + if self.early_stopping_criteria: + if stop_sequences_criteria is None: # pyright: ignore [reportUnnecessaryComparison] + raise MissingConditionalImportError(extra_deps_group='nlp', + conda_package='transformers', + conda_channel='conda-forge') + stopping_criteria = stop_sequences_criteria(self.tokenizer, self.early_stopping_criteria, len(inputs)) batch = { 'input_ids': torch.stack(inputs), 'mode': 'generate', 'labels': answers, 'cot_delimiter': cot_delimiter, 'generation_length': self.max_answer_length, + 'stopping_criteria': self.early_stopping_criteria, + 'do_normalization': self.do_normalization, 'generation_kwargs': { 'pad_token_id': self.pad_tok_id, 'use_cache': True, - 'eos_token_id': self.tokenizer.eos_token_id + 'stopping_criteria': stopping_criteria, + 'eos_token_id': self.tokenizer.eos_token_id, } } @@ -325,7 +337,9 @@ def split_batch(self, batch: Any, microbatch_size: int): # Don't split kwargs that don't change # Normally split torch tensors # List split lists of strings - no_split = ['mode', 'generation_length', 'generation_kwargs', 'cot_delimiter'] + no_split = [ + 'mode', 'generation_length', 'generation_kwargs', 'cot_delimiter', 'do_normalization', 'stopping_criteria' + ] normal_split = ['input_ids', 'attention_mask'] list_split = ['labels'] chunked = {} @@ -341,7 +355,7 @@ def split_batch(self, batch: Any, microbatch_size: int): raise ValueError(f'Unexpected key {k}') num_chunks = len(chunked['input_ids']) for k, v in batch.items(): - if isinstance(v, (int, float, str, bool, dict)): + if k in no_split: chunked[k] = [v] * num_chunks return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)] @@ -1169,23 +1183,24 @@ def split_batch(self, batch: Any, microbatch_size: int): def build_icl_dataloader( - icl_task_type: str, - dataset_uri: str, - tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], - batch_size: int, - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, # e.g. 'translate english to french:' - example_delimiter: str, # e.g. '\n' - continuation_delimiter: str, # e.g. '' - destination_path: str, - question_prelimiter: str = '', # e.g. 'Question: ' - cot_delimiter: str = '', - fewshot_random_seed: int = 1234, - pass_at_k: int = 1, - generations_per_sample: int = 1, -) -> DataSpec: + icl_task_type: str, + dataset_uri: str, + tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], + batch_size: int, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + prompt_string: str, # e.g. 'translate english to french:' + example_delimiter: str, # e.g. '\n' + continuation_delimiter: str, # e.g. '' + destination_path: str, + question_prelimiter: str = '', # e.g. 'Question: ' + cot_delimiter: str = '', + fewshot_random_seed: int = 1234, + pass_at_k: int = 1, + generations_per_sample: int = 1, + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True) -> DataSpec: if icl_task_type == 'multiple_choice': dataset = InContextLearningMultipleChoiceTaskDataset(dataset_uri, tokenizer, @@ -1236,7 +1251,9 @@ def build_icl_dataloader( destination_path=destination_path, question_prelimiter=question_prelimiter, fewshot_random_seed=fewshot_random_seed, - cot_delimiter=cot_delimiter) + cot_delimiter=cot_delimiter, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization) effective_batchsize = batch_size elif icl_task_type == 'code_evaluation': dataset = InContextLearningCodeEvalDataset(dataset_uri, @@ -1335,7 +1352,9 @@ def get_icl_task_dataloader( pass_at_k: int = 1, generations_per_sample: int = 1, cot_delimiter: str = '', - has_categories: bool = False) -> Union[DataSpec, Dict[str, DataSpec]]: + has_categories: bool = False, + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True) -> Union[DataSpec, Dict[str, DataSpec]]: """This constructs a dataloader (or dataloaders if has_categories is True) capable of evaluating LLMs on in-context learning language modeling tasks, for example LAMBADA. An example usage is below: >>> dl = get_icl_task_dataloader( @@ -1388,41 +1407,41 @@ def get_icl_task_dataloader( categories = sorted(output_files.keys()) for category in categories: partition_uri = output_files[category] - result_dls[category] = build_icl_dataloader( - icl_task_type, - partition_uri, - tokenizer, - batch_size, - max_seq_len, - pad_tok_id, - num_fewshot, - prompt_string, - example_delimiter, - continuation_delimiter, - partition_uri + '_tmp', - question_prelimiter, - cot_delimiter, - fewshot_random_seed, - pass_at_k, - generations_per_sample, - ) + result_dls[category] = build_icl_dataloader(icl_task_type, + partition_uri, + tokenizer, + batch_size, + max_seq_len, + pad_tok_id, + num_fewshot, + prompt_string, + example_delimiter, + continuation_delimiter, + partition_uri + '_tmp', + question_prelimiter, + cot_delimiter, + fewshot_random_seed, + pass_at_k, + generations_per_sample, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization) return result_dls else: - return build_icl_dataloader( - icl_task_type, - dataset_uri, - tokenizer, - batch_size, - max_seq_len, - pad_tok_id, - num_fewshot, - prompt_string, - example_delimiter, - continuation_delimiter, - destination_path, - question_prelimiter, - cot_delimiter, - fewshot_random_seed, - pass_at_k, - generations_per_sample, - ) + return build_icl_dataloader(icl_task_type, + dataset_uri, + tokenizer, + batch_size, + max_seq_len, + pad_tok_id, + num_fewshot, + prompt_string, + example_delimiter, + continuation_delimiter, + destination_path, + question_prelimiter, + cot_delimiter, + fewshot_random_seed, + pass_at_k, + generations_per_sample, + early_stopping_criteria=early_stopping_criteria, + do_normalization=do_normalization) diff --git a/composer/datasets/utils.py b/composer/datasets/utils.py index 9f6b2aac4e..431a860900 100644 --- a/composer/datasets/utils.py +++ b/composer/datasets/utils.py @@ -6,7 +6,7 @@ import logging import textwrap import warnings -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -166,3 +166,61 @@ def add_vision_dataset_transform(dataset: VisionDataset, transform: Callable, is else: dataset.transform = transforms.Compose([dataset.transform, transform]) log.warning(transform_added_logstring) + + +try: + import transformers + + class MultiTokenEOSCriteria(transformers.StoppingCriteria): + """Criteria to stop on the specified multi-token sequence. + Slightly modified from: https://github.com/EleutherAI/lm-evaluation-harness/blob/78545d42f2ca95c6fe0ed220d456eeb94f4485e9/lm_eval/utils.py#L614-L649 + """ + + def __init__( + self, + stop_sequence: str, + tokenizer: transformers.PreTrainedTokenizer, + batch_size: int, + ) -> None: + self.done_tracker = [False] * batch_size + self.stop_sequence = stop_sequence + self.stop_sequence_ids = tokenizer.encode(stop_sequence, add_special_tokens=False) + + # we look back for 2 more tokens than it takes to encode our stop sequence + # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']` + # and we don't want to mistakenly not stop a generation because our + # (string) stop sequence was output in a different tokenization + + # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model, + # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized + self.stop_sequence_id_len = len(self.stop_sequence_ids) + 2 + self.tokenizer = tokenizer + + def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool: + # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence + lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:] + + lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) + for i, done in enumerate(self.done_tracker): + if i >= len(lookback_tokens_batch): + # The last batch of a dataset may be smaller than `batch_size` + # Automatically set those indices in the done_tracker to True + # since those indices don't show up in the current batch + self.done_tracker[i] = True + break + elif not done: + self.done_tracker[i] = self.stop_sequence in lookback_tokens_batch[i] + return False not in self.done_tracker + + def stop_sequences_criteria( + tokenizer: transformers.PreTrainedTokenizer, + stop_sequences: List[str], + batch_size: int, + ) -> transformers.StoppingCriteriaList: + return transformers.StoppingCriteriaList([ + *[MultiTokenEOSCriteria(sequence, tokenizer, batch_size) for sequence in stop_sequences], + ]) + +except ImportError as e: + stop_sequences_criteria = None # pyright: ignore [reportGeneralTypeIssues] + MultiTokenEOSCriteria = None # pyright: ignore [reportGeneralTypeIssues] diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index a3111db0aa..dd4d665678 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -270,13 +270,23 @@ def update(self, outputs: List[str], labels: List[List[str]], batch: Optional[Di if batch is None: batch = {} cot_delimiter = batch.get('cot_delimiter', '') + do_normalization = batch.get('do_normalization', True) + stopping_criteria = batch.get('stopping_criteria', None) for sample_output, sample_labels in zip(outputs, labels): final_answer = sample_output + + if stopping_criteria is not None and len(stopping_criteria) > 0: + final_answer = re.split('|'.join(stopping_criteria), final_answer)[0] + if cot_delimiter is not None and len(cot_delimiter) > 0: final_answer = final_answer.split(cot_delimiter)[-1] - cleaned_final_answer = self.normalize_answer(final_answer) - cleaned_sample_labels = {self.normalize_answer(label) for label in sample_labels} + if do_normalization: + cleaned_final_answer = self.normalize_answer(final_answer) + cleaned_sample_labels = {self.normalize_answer(label) for label in sample_labels} + else: + cleaned_final_answer = final_answer + cleaned_sample_labels = set(sample_labels) if any(cleaned_final_answer.startswith(label) for label in cleaned_sample_labels): self.correct += torch.tensor(1.0) diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py index 2e9a461fcf..ec7df306d6 100644 --- a/tests/datasets/test_in_context_learning_datasets.py +++ b/tests/datasets/test_in_context_learning_datasets.py @@ -18,6 +18,7 @@ from composer.datasets.in_context_learning_evaluation import (InContextLearningCodeEvalDataset, _get_fewshot_sample_idxs, _make_padded_input, get_icl_task_dataloader) +from composer.datasets.utils import MultiTokenEOSCriteria from composer.loggers import InMemoryLogger from composer.metrics import (InContextLearningCodeEvalAccuracy, InContextLearningLMAccuracy, InContextLearningMultipleChoiceAccuracy, InContextLearningQAAccuracy) @@ -66,6 +67,22 @@ def test_fewshot_sample_idxs_randomness(): assert rng_1_sample_2 != rng_3_sample_2 +def test_stop_sequences_criteria(tiny_gpt2_tokenizer): + pytest.importorskip('transformers') + eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2) + seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids'] + seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] + seq1 = [50257] * (len(seq2) - len(seq1)) + seq1 + input_ids = torch.tensor([seq1, seq2]) + assert not eos_criteria(input_ids, None) + + eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2) + seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] + seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] + input_ids = torch.tensor([seq1, seq2]) + assert eos_criteria(input_ids, None) + + def test_batch_padding_logic(tiny_gpt2_tokenizer): continuation = tiny_gpt2_tokenizer(' dog' * 2000)['input_ids'] context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index 75735fd839..9a3fa6760d 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -238,12 +238,12 @@ def test_in_context_learning_qa_accuracy(): def test_in_context_learning_qa_cot_accuracy(): outputs = [ - 'chain of thought ### Correct but then some more text', 'Incorrect', - 'chain of thought ### the CORREct with weird casing and spacing', + 'chain of thought ### Correct but then some more text\n\nanother chain of thought ### Incorrect answer this time', + 'Incorrect', 'chain of thought ### the CORREct with weird casing and spacing', 'incorrect chain of thought delimiter ## Correct but wrong delimiter' ] labels = [['Correct'], ['blah', 'blah2'], ['blah', 'correct'], ['correct']] - batch = {'cot_delimiter': ' ### ', 'labels': labels} + batch = {'cot_delimiter': ' ### ', 'labels': labels, 'do_normalization': True, 'stopping_criteria': '\n\n'} metric = InContextLearningQAAccuracy() metric.update(outputs, labels, batch)