Add custom stopping criteria to ICL generate tasks (mosaicml#2800)
* add custom gen kwargs and stopping on eos token

* modify test

* modify test

* finish

* finish

* finish

* finish

* finish pr

* implement early stop

* add test

* fix bug

* bug fix

* add keys

* diff split

* fix typo

* fix precommit

* fix precommit

* fix precommit

* fix precommit

* fix precommit

* fix precommit

* fix conditional import

* add nlp metrics

* remove code gen changes

* fix nits

---------

Co-authored-by: Daniel King <[email protected]>
bmosaicml and dakinggg authored Jan 15, 2024
1 parent d497d8f commit 1bc8d0a
Showing 5 changed files with 184 additions and 80 deletions.
167 changes: 93 additions & 74 deletions composer/datasets/in_context_learning_evaluation.py
@@ -16,6 +16,7 @@

from composer.core import DataSpec
from composer.core.data_spec import _default_split_batch, _split_list
from composer.datasets.utils import stop_sequences_criteria
from composer.utils import MissingConditionalImportError, dist, get_file

if TYPE_CHECKING:
@@ -139,21 +140,21 @@ def _read_dataset(self, dataset: Dataset) -> List[Dict[str, str]]:
})
return result

def __init__(
self,
dataset_uri: str,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
destination_path: str,
question_prelimiter: str,
fewshot_random_seed: int,
cot_delimiter: str = '',
):
def __init__(self,
dataset_uri: str,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
destination_path: str,
question_prelimiter: str,
fewshot_random_seed: int,
cot_delimiter: str = '',
early_stopping_criteria: Optional[List[str]] = None,
do_normalization: bool = True):
if tokenizer.eos_token_id is None:
raise ValueError('`InContextLearningQATaskDataset` tokenizer must have non-null `eos_token_id`')
try:
@@ -166,6 +167,8 @@ def __init__(
if dist.get_local_rank() == 0:
get_file(dataset_uri, destination_path, overwrite=True)
dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
self.early_stopping_criteria = early_stopping_criteria
self.do_normalization = do_normalization
self.samples = self._read_dataset(dataset) # pyright: ignore[reportGeneralTypeIssues]
self.samples = strip_data(self.samples)
self.tokenizer = tokenizer
@@ -301,17 +304,26 @@ def collate_fn(self, data):
# We will search for the answer within the portion of the model response
# beginning with `cot_delimiter`
cot_delimiter = sample['cot_delimiter']

stopping_criteria = None
if self.early_stopping_criteria:
if stop_sequences_criteria is None: # pyright: ignore [reportUnnecessaryComparison]
raise MissingConditionalImportError(extra_deps_group='nlp',
conda_package='transformers',
conda_channel='conda-forge')
stopping_criteria = stop_sequences_criteria(self.tokenizer, self.early_stopping_criteria, len(inputs))
batch = {
'input_ids': torch.stack(inputs),
'mode': 'generate',
'labels': answers,
'cot_delimiter': cot_delimiter,
'generation_length': self.max_answer_length,
'stopping_criteria': self.early_stopping_criteria,
'do_normalization': self.do_normalization,
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id
'stopping_criteria': stopping_criteria,
'eos_token_id': self.tokenizer.eos_token_id,
}
}

@@ -325,7 +337,9 @@ def split_batch(self, batch: Any, microbatch_size: int):
# Don't split kwargs that don't change
# Normally split torch tensors
# List split lists of strings
no_split = ['mode', 'generation_length', 'generation_kwargs', 'cot_delimiter']
no_split = [
'mode', 'generation_length', 'generation_kwargs', 'cot_delimiter', 'do_normalization', 'stopping_criteria'
]
normal_split = ['input_ids', 'attention_mask']
list_split = ['labels']
chunked = {}
@@ -341,7 +355,7 @@ def split_batch(self, batch: Any, microbatch_size: int):
raise ValueError(f'Unexpected key {k}')
num_chunks = len(chunked['input_ids'])
for k, v in batch.items():
if isinstance(v, (int, float, str, bool, dict)):
if k in no_split:
chunked[k] = [v] * num_chunks
return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]

@@ -1169,23 +1183,24 @@ def split_batch(self, batch: Any, microbatch_size: int):


def build_icl_dataloader(
icl_task_type: str,
dataset_uri: str,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
batch_size: int,
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
prompt_string: str, # e.g. 'translate english to french:'
example_delimiter: str, # e.g. '\n'
continuation_delimiter: str, # e.g. ''
destination_path: str,
question_prelimiter: str = '', # e.g. 'Question: '
cot_delimiter: str = '',
fewshot_random_seed: int = 1234,
pass_at_k: int = 1,
generations_per_sample: int = 1,
) -> DataSpec:
icl_task_type: str,
dataset_uri: str,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
batch_size: int,
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
prompt_string: str, # e.g. 'translate english to french:'
example_delimiter: str, # e.g. '\n'
continuation_delimiter: str, # e.g. ''
destination_path: str,
question_prelimiter: str = '', # e.g. 'Question: '
cot_delimiter: str = '',
fewshot_random_seed: int = 1234,
pass_at_k: int = 1,
generations_per_sample: int = 1,
early_stopping_criteria: Optional[List[str]] = None,
do_normalization: bool = True) -> DataSpec:
if icl_task_type == 'multiple_choice':
dataset = InContextLearningMultipleChoiceTaskDataset(dataset_uri,
tokenizer,
@@ -1236,7 +1251,9 @@ def build_icl_dataloader(
destination_path=destination_path,
question_prelimiter=question_prelimiter,
fewshot_random_seed=fewshot_random_seed,
cot_delimiter=cot_delimiter)
cot_delimiter=cot_delimiter,
early_stopping_criteria=early_stopping_criteria,
do_normalization=do_normalization)
effective_batchsize = batch_size
elif icl_task_type == 'code_evaluation':
dataset = InContextLearningCodeEvalDataset(dataset_uri,
@@ -1335,7 +1352,9 @@ def get_icl_task_dataloader(
pass_at_k: int = 1,
generations_per_sample: int = 1,
cot_delimiter: str = '',
has_categories: bool = False) -> Union[DataSpec, Dict[str, DataSpec]]:
has_categories: bool = False,
early_stopping_criteria: Optional[List[str]] = None,
do_normalization: bool = True) -> Union[DataSpec, Dict[str, DataSpec]]:
"""This constructs a dataloader (or dataloaders if has_categories is True) capable of evaluating LLMs on in-context learning language modeling tasks, for example LAMBADA. An example usage is below:
>>> dl = get_icl_task_dataloader(
@@ -1388,41 +1407,41 @@ def get_icl_task_dataloader(
categories = sorted(output_files.keys())
for category in categories:
partition_uri = output_files[category]
result_dls[category] = build_icl_dataloader(
icl_task_type,
partition_uri,
tokenizer,
batch_size,
max_seq_len,
pad_tok_id,
num_fewshot,
prompt_string,
example_delimiter,
continuation_delimiter,
partition_uri + '_tmp',
question_prelimiter,
cot_delimiter,
fewshot_random_seed,
pass_at_k,
generations_per_sample,
)
result_dls[category] = build_icl_dataloader(icl_task_type,
partition_uri,
tokenizer,
batch_size,
max_seq_len,
pad_tok_id,
num_fewshot,
prompt_string,
example_delimiter,
continuation_delimiter,
partition_uri + '_tmp',
question_prelimiter,
cot_delimiter,
fewshot_random_seed,
pass_at_k,
generations_per_sample,
early_stopping_criteria=early_stopping_criteria,
do_normalization=do_normalization)
return result_dls
else:
return build_icl_dataloader(
icl_task_type,
dataset_uri,
tokenizer,
batch_size,
max_seq_len,
pad_tok_id,
num_fewshot,
prompt_string,
example_delimiter,
continuation_delimiter,
destination_path,
question_prelimiter,
cot_delimiter,
fewshot_random_seed,
pass_at_k,
generations_per_sample,
)
return build_icl_dataloader(icl_task_type,
dataset_uri,
tokenizer,
batch_size,
max_seq_len,
pad_tok_id,
num_fewshot,
prompt_string,
example_delimiter,
continuation_delimiter,
destination_path,
question_prelimiter,
cot_delimiter,
fewshot_random_seed,
pass_at_k,
generations_per_sample,
early_stopping_criteria=early_stopping_criteria,
do_normalization=do_normalization)
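Taken together, the new dataloader arguments can be exercised as in the minimal sketch below. This is an illustration rather than code from the commit: the dataset path, tokenizer choice, delimiters, and stop strings are assumptions made for the example.

# Hypothetical usage sketch of the new kwargs; paths, tokenizer, and stop strings are illustrative.
from transformers import AutoTokenizer

from composer.datasets.in_context_learning_evaluation import get_icl_task_dataloader

tokenizer = AutoTokenizer.from_pretrained('gpt2')

dl = get_icl_task_dataloader(
    icl_task_type='question_answering',
    dataset_uri='/path/to/qa_dataset.jsonl',        # placeholder URI
    tokenizer=tokenizer,
    batch_size=8,
    max_seq_len=1024,
    pad_tok_id=tokenizer.eos_token_id,
    num_fewshot=5,
    prompt_string='',
    example_delimiter='\n',
    continuation_delimiter=' ',
    destination_path='/tmp/qa_dataset_local.jsonl',  # placeholder local copy
    question_prelimiter='Question: ',
    # New in this commit: stop generation at either string, and optionally skip answer normalization.
    early_stopping_criteria=['\n\n', 'Question:'],
    do_normalization=True,
)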
60 changes: 59 additions & 1 deletion composer/datasets/utils.py
@@ -6,7 +6,7 @@
import logging
import textwrap
import warnings
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -166,3 +166,61 @@ def add_vision_dataset_transform(dataset: VisionDataset, transform: Callable, is
else:
dataset.transform = transforms.Compose([dataset.transform, transform])
log.warning(transform_added_logstring)


try:
import transformers

class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence.
Slightly modified from: https://github.com/EleutherAI/lm-evaluation-harness/blob/78545d42f2ca95c6fe0ed220d456eeb94f4485e9/lm_eval/utils.py#L614-L649
"""

def __init__(
self,
stop_sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
batch_size: int,
) -> None:
self.done_tracker = [False] * batch_size
self.stop_sequence = stop_sequence
self.stop_sequence_ids = tokenizer.encode(stop_sequence, add_special_tokens=False)

# we look back for 2 more tokens than it takes to encode our stop sequence
# because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
# and we don't want to mistakenly not stop a generation because our
# (string) stop sequence was output in a different tokenization

# NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
# and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
self.stop_sequence_id_len = len(self.stop_sequence_ids) + 2
self.tokenizer = tokenizer

def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:]

lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if i >= len(lookback_tokens_batch):
# The last batch of a dataset may be smaller than `batch_size`
# Automatically set those indices in the done_tracker to True
# since those indices don't show up in the current batch
self.done_tracker[i] = True
break
elif not done:
self.done_tracker[i] = self.stop_sequence in lookback_tokens_batch[i]
return False not in self.done_tracker

def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
stop_sequences: List[str],
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList([
*[MultiTokenEOSCriteria(sequence, tokenizer, batch_size) for sequence in stop_sequences],
])

except ImportError as e:
stop_sequences_criteria = None # pyright: ignore [reportGeneralTypeIssues]
MultiTokenEOSCriteria = None # pyright: ignore [reportGeneralTypeIssues]
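
Outside of the dataloader, the helper simply returns a transformers.StoppingCriteriaList, so it can be handed straight to generate. A minimal sketch follows, assuming a GPT-2 model, a single-example batch, and a '\n\n' stop string chosen only for illustration.

# Illustrative standalone use of stop_sequences_criteria; model, prompt, and stop string are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from composer.datasets.utils import stop_sequences_criteria

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

prompt = 'Question: What color is the sky?\nAnswer:'
inputs = tokenizer(prompt, return_tensors='pt')

# Build a StoppingCriteriaList that halts generation once '\n\n' appears (batch size 1 here).
criteria = stop_sequences_criteria(tokenizer, ['\n\n'], 1)

with torch.no_grad():
    output = model.generate(**inputs,
                            max_new_tokens=32,
                            stopping_criteria=criteria,
                            pad_token_id=tokenizer.eos_token_id)
# Decode only the newly generated tokens.
print(tokenizer.decode(output[0][inputs['input_ids'].shape[1]:]))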
14 changes: 12 additions & 2 deletions composer/metrics/nlp.py
@@ -270,13 +270,23 @@ def update(self, outputs: List[str], labels: List[List[str]], batch: Optional[Di
if batch is None:
batch = {}
cot_delimiter = batch.get('cot_delimiter', '')
do_normalization = batch.get('do_normalization', True)
stopping_criteria = batch.get('stopping_criteria', None)
for sample_output, sample_labels in zip(outputs, labels):
final_answer = sample_output

if stopping_criteria is not None and len(stopping_criteria) > 0:
final_answer = re.split('|'.join(stopping_criteria), final_answer)[0]

if cot_delimiter is not None and len(cot_delimiter) > 0:
final_answer = final_answer.split(cot_delimiter)[-1]

cleaned_final_answer = self.normalize_answer(final_answer)
cleaned_sample_labels = {self.normalize_answer(label) for label in sample_labels}
if do_normalization:
cleaned_final_answer = self.normalize_answer(final_answer)
cleaned_sample_labels = {self.normalize_answer(label) for label in sample_labels}
else:
cleaned_final_answer = final_answer
cleaned_sample_labels = set(sample_labels)

if any(cleaned_final_answer.startswith(label) for label in cleaned_sample_labels):
self.correct += torch.tensor(1.0)
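For reference, the truncation the updated metric performs on each model output can be shown in isolation; the sample string, delimiter, and stop sequences below are invented for the example.

# Sketch of the metric's truncation logic with made-up values.
import re

stopping_criteria = ['\n\n', 'Question:']
cot_delimiter = ' ### '
sample_output = 'chain of thought ### Paris\n\nQuestion: what is 2 + 2?'

# Keep only the text produced before the first stop sequence...
final_answer = re.split('|'.join(stopping_criteria), sample_output)[0]
# ...then keep only what follows the chain-of-thought delimiter.
final_answer = final_answer.split(cot_delimiter)[-1]
print(final_answer)  # -> 'Paris'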
17 changes: 17 additions & 0 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -18,6 +18,7 @@
from composer.datasets.in_context_learning_evaluation import (InContextLearningCodeEvalDataset,
_get_fewshot_sample_idxs, _make_padded_input,
get_icl_task_dataloader)
from composer.datasets.utils import MultiTokenEOSCriteria
from composer.loggers import InMemoryLogger
from composer.metrics import (InContextLearningCodeEvalAccuracy, InContextLearningLMAccuracy,
InContextLearningMultipleChoiceAccuracy, InContextLearningQAAccuracy)
@@ -66,6 +67,22 @@ def test_fewshot_sample_idxs_randomness():
assert rng_1_sample_2 != rng_3_sample_2


def test_stop_sequences_criteria(tiny_gpt2_tokenizer):
pytest.importorskip('transformers')
eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids']
seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
seq1 = [50257] * (len(seq2) - len(seq1)) + seq1
input_ids = torch.tensor([seq1, seq2])
assert not eos_criteria(input_ids, None)

eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
input_ids = torch.tensor([seq1, seq2])
assert eos_criteria(input_ids, None)


def test_batch_padding_logic(tiny_gpt2_tokenizer):
continuation = tiny_gpt2_tokenizer(' dog' * 2000)['input_ids']
context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids']
6 changes: 3 additions & 3 deletions tests/metrics/test_nlp_metrics.py
@@ -238,12 +238,12 @@ def test_in_context_learning_qa_accuracy():

def test_in_context_learning_qa_cot_accuracy():
outputs = [
'chain of thought ### Correct but then some more text', 'Incorrect',
'chain of thought ### the CORREct with weird casing and spacing',
'chain of thought ### Correct but then some more text\n\nanother chain of thought ### Incorrect answer this time',
'Incorrect', 'chain of thought ### the CORREct with weird casing and spacing',
'incorrect chain of thought delimiter ## Correct but wrong delimiter'
]
labels = [['Correct'], ['blah', 'blah2'], ['blah', 'correct'], ['correct']]
batch = {'cot_delimiter': ' ### ', 'labels': labels}
batch = {'cot_delimiter': ' ### ', 'labels': labels, 'do_normalization': True, 'stopping_criteria': '\n\n'}
metric = InContextLearningQAAccuracy()
metric.update(outputs, labels, batch)

