diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py index 2d59ff4e13..c2d24e9c11 100644 --- a/composer/core/data_spec.py +++ b/composer/core/data_spec.py @@ -179,7 +179,7 @@ class DataSpec: ... ) Args: - dataloader (Iterable): The dataloader, which can be any iterable that yields batches. + dataloader (Union[Iterable, torch.utils.data.DataLoader]): The dataloader, which can be any iterable that yields batches. num_samples (int, optional): The total number of samples in an epoch, across all ranks. This field is used by the :class:`.Timestamp` (training progress tracker). If not specified, then ``len(dataloader.dataset)`` is @@ -214,7 +214,7 @@ class DataSpec: def __init__( self, - dataloader: Iterable, + dataloader: Union[Iterable, torch.utils.data.DataLoader], num_samples: Optional[int] = None, num_tokens: Optional[int] = None, device_transforms: Optional[Callable[[Batch], Batch]] = None, @@ -222,7 +222,7 @@ def __init__( get_num_samples_in_batch: Optional[Callable[[Batch], int]] = None, get_num_tokens_in_batch: Optional[Callable[[Batch], int]] = None, ) -> None: - self.dataloader = dataloader + self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader self.num_tokens = num_tokens self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms self.split_batch = _default_split_batch if split_batch is None else split_batch diff --git a/composer/datasets/in_context_learning_evaluation.py b/composer/datasets/in_context_learning_evaluation.py new file mode 100644 index 0000000000..26580b9720 --- /dev/null +++ b/composer/datasets/in_context_learning_evaluation.py @@ -0,0 +1,141 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 +# This code is based on the implementation in https://github.com/EleutherAI/lm-evaluation-harness/blob/8c048e266a22a1c85ccbdb0c209ac712e4f39989/lm_eval/base.py#L221-L330 + +from typing import Union + +import torch +import transformers +from datasets import load_dataset +from torch.utils.data import DataLoader, Dataset + +from composer.core import DataSpec +from composer.utils import dist +from composer.utils.file_helpers import get_file + +__all__ = ['InContextLearningLMTaskDataset', 'get_lm_task_dataloader'] + + +class InContextLearningLMTaskDataset(Dataset): + """A dataset that construct batches for in-context learning language modeling evaluation + + Args: + dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend + supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to transform data into batches + batch_size (int): Size of a batch used for eval + max_seq_len (int): The sequence length expected by the model + eos_tok_id (int): The special token reserved for padding the ends of batches + destination_path (str): Temporary path to store downloaded datasets + """ + + def __init__( + self, + dataset_uri: str, + tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], + max_seq_len: int, + eos_tok_id: int, + destination_path: str = 'icl_lm_task.json', + ): + get_file(dataset_uri, destination_path, overwrite=True) + dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False) + self.encoded_dataset = list( + dataset.map(lambda examples: { + 'continuation': tokenizer(examples['continuation']), + 'context': tokenizer(examples['context']), + })) + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + self.eos_tok_id = eos_tok_id + + def __getitem__(self, index): + return self.encoded_dataset[index] + + def __len__(self): + return len(self.encoded_dataset) + + def collate_fn(self, data): + inputs = [] + continuation_indices = [] + for data_pair in data: + context, continuation = data_pair['context'], data_pair['continuation'] + + context_enc = context['input_ids'] + continuation_enc = continuation['input_ids'] + continuation_span = torch.tensor(range(len(context_enc), len(context_enc) + len(continuation_enc))) + + inp = torch.tensor( + (context_enc + continuation_enc + )[-(self.max_seq_len + 1):], # trim from the left if context + continuation are larger than max_seq_len + dtype=torch.long, + ) + (inp_len,) = inp.shape + + # pad length from seq to padding_length + inp = torch.cat( + [ + inp, # [seq] + torch.LongTensor((self.max_seq_len - inp_len) * [self.eos_tok_id]), + ], + dim=0, + ) + + inputs.append(inp) + continuation_indices.append(continuation_span) + + batch = { + 'input_ids': torch.stack(inputs), + 'continuation_indices': continuation_indices, + 'mode': 'lm_task', + 'labels': torch.stack(inputs), + } + + batch['attention_mask'] = ~(batch['input_ids'] == self.eos_tok_id) + return batch + + def get_num_samples_in_batch(self, batch) -> int: + return batch['input_ids'].shape[0] + + def update_metric(self, metric, batch, output_logits, labels): + metric.update(batch, output_logits, labels) + + +def get_lm_task_dataloader(dataset_uri: str, tokenizer: Union[transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast], batch_size: int, + max_seq_len: int, eos_tok_id: int) -> DataSpec: + """This constructs a dataloader capable of evaluating LLMs on in-context learning language modeling tasks, for example LAMBADA. An example usage is below: + + >>> dl = get_lm_task_dataloader(dataset_uri, tokenizer, 2, max_seq_len=2048, eos_tok_id=tokenizer.eos_token_id) + >>> eval_evaluator = Evaluator( + ... label="lambada", + ... dataloader=dl, + ... metric_names=['InContextLearningLMAccuracy'] + ... ) + >>> trainer = Trainer( + ... model=model, + ... train_dataloader=train_dataloader, + ... eval_dataloader=eval_evaluator, + ... optimizers=optimizer, + ... max_duration="1ep", + ... ) + + Args: + dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend + supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. + tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to transform data into batches + batch_size (int): Size of a batch used for eval + max_seq_len (int): The sequence length expected by the model + eos_tok_id (int): The special token reserved for padding the ends of batches + + Returns: + DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided. + """ + dataset = InContextLearningLMTaskDataset(dataset_uri, tokenizer, max_seq_len, eos_tok_id) + sampler = dist.get_sampler(dataset, drop_last=False, shuffle=True) + return DataSpec(DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + collate_fn=dataset.collate_fn, + ), + get_num_samples_in_batch=dataset.get_num_samples_in_batch) diff --git a/composer/metrics/__init__.py b/composer/metrics/__init__.py index bff3f7b1c3..4b366e69b0 100644 --- a/composer/metrics/__init__.py +++ b/composer/metrics/__init__.py @@ -5,7 +5,8 @@ from composer.metrics.map import MAP from composer.metrics.metrics import CrossEntropy, Dice, LossMetric, MIoU -from composer.metrics.nlp import BinaryF1Score, HFCrossEntropy, LanguageCrossEntropy, MaskedAccuracy, Perplexity +from composer.metrics.nlp import (BinaryF1Score, HFCrossEntropy, InContextLearningLMAccuracy, LanguageCrossEntropy, + MaskedAccuracy, Perplexity) __all__ = [ 'MAP', @@ -18,4 +19,9 @@ 'HFCrossEntropy', 'LanguageCrossEntropy', 'MaskedAccuracy', + 'InContextLearningLMAccuracy', ] + +METRIC_DEFAULT_CTORS = { + 'InContextLearningLMAccuracy': InContextLearningLMAccuracy, +} diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index fe9271a719..334ccee5cc 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -10,7 +10,10 @@ from composer.loss import soft_cross_entropy -__all__ = ['Perplexity', 'BinaryF1Score', 'HFCrossEntropy', 'LanguageCrossEntropy', 'MaskedAccuracy'] +__all__ = [ + 'Perplexity', 'InContextLearningLMAccuracy', 'BinaryF1Score', 'HFCrossEntropy', 'LanguageCrossEntropy', + 'MaskedAccuracy' +] class MaskedAccuracy(Metric): @@ -231,3 +234,49 @@ def compute(self) -> Tensor: """Returns torch.exp() of the LanguageCrossEntropyLoss.""" avg_loss = super().compute() return torch.exp(avg_loss) + + +class InContextLearningLMAccuracy(Metric): + r"""Computes accuracy for In-context learning (ICL) language modeling (LM) tasks. + + ICL LM tasks consist of some number of example language modeling tasks (referred to as the 'context'), followed by a test task where the model must correctly predict all the tokens + following tokens in some passage (referred to as the 'continuation'). + + For example, the model may be provided the context below and evaluated on its ability to correctly predict the continuation. Note: it doesn't matter + whether the model correctly predicts the context tokens. + + Context: `The dog is->fuzzy\nthe water is->hot\nthe tree is->` + Continuation: `green` + + Adds metric state variables: + correct (float): The number of examples where the model correctly predicted the whole continuation. + total (float): The number of total examples seen. + + Args: + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: ``False``. + """ + + # Make torchmetrics call update only once + full_state_update = False + + def __init__(self, dist_sync_on_step: bool = False): + # state from multiple processes + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum') + self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum') + + def update(self, batch: dict, output_logits: torch.Tensor, labels: torch.Tensor): + targets = torch.roll(labels, shifts=-1) + targets[:, -1] = -100 + for batch_idx, cont_idx in enumerate(batch['continuation_indices']): + cont_tok_pred = output_logits[batch_idx].index_select(dim=0, index=cont_idx - 1).argmax(dim=-1) + cont_tok_targ = targets[batch_idx].index_select(dim=0, index=cont_idx - 1) + + self.correct += (cont_tok_pred == cont_tok_targ).all().int() + self.total += torch.tensor(1) + + def compute(self): + assert isinstance(self.correct, Tensor) + assert isinstance(self.total, Tensor) + return self.correct.float() / self.total diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py index f7e043d6ac..3043f671e9 100644 --- a/composer/models/huggingface.py +++ b/composer/models/huggingface.py @@ -5,6 +5,7 @@ from __future__ import annotations +import inspect import json import logging import tempfile @@ -16,6 +17,7 @@ import torch from torchmetrics import Metric +from composer.metrics import METRIC_DEFAULT_CTORS, InContextLearningLMAccuracy from composer.models.base import ComposerModel from composer.utils import get_file from composer.utils.import_helpers import MissingConditionalImportError, import_object @@ -86,14 +88,14 @@ def __init__(self, self.use_logits = use_logits - self.train_metrics = None - self.val_metrics = None + self.train_metrics: Optional[Dict] = None + self.val_metrics: Optional[Dict] = None if metrics: self.train_metrics = {metric.__class__.__name__: metric for metric in metrics} self.val_metrics = {metric.__class__.__name__: metric for metric in metrics} - self.labels = None # set in eval_forward() if exists + self.labels: Optional[torch.Tensor] = None # set in eval_forward() if exists @staticmethod def hf_from_composer_checkpoint( @@ -257,6 +259,7 @@ def hf_from_composer_checkpoint( def forward(self, batch): if isinstance(batch, dict) or isinstance(batch, UserDict): # Further input validation is left to the huggingface forward call + batch = {k: v for k, v in batch.items() if k in inspect.getfullargspec(self.model.forward).args} output = self.model(**batch) # type: ignore (thirdparty) else: raise ValueError( @@ -273,7 +276,7 @@ def loss(self, outputs, batch): def eval_forward(self, batch, outputs: Optional[Any] = None): output = outputs if outputs else self.forward(batch) - if self.use_logits: + if self.use_logits or batch.get('mode', None) == 'lm_task': self.labels = batch.pop('labels') if self.config.use_return_dict: output = output['logits'] @@ -296,7 +299,11 @@ def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]: return metrics if metrics else {} def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None: - metric.update(outputs, self.labels) + if isinstance(metric, InContextLearningLMAccuracy): + assert self.labels is not None + metric.update(batch, outputs, self.labels) + else: + metric.update(outputs, self.labels) def get_metadata(self): model_output = {} @@ -335,3 +342,10 @@ def get_metadata(self): 'content': tokenizer_file_content } return {'model': model_output, 'tokenizer': tokenizer_output} + + def add_eval_metrics(self, evaluator): + evaluator_metrics = {m: METRIC_DEFAULT_CTORS[m]() for m in evaluator.metric_names} + if self.val_metrics is not None: + self.val_metrics.update(evaluator_metrics) + else: + self.val_metrics = evaluator_metrics diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 4153b63700..c7973ec4ac 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1212,7 +1212,6 @@ def __init__( ensure_evaluator(evaluator, default_metric_names=model_metric_names) for evaluator in ensure_tuple(eval_dataloader) ] - # match metric names to model metrics self.state.eval_metrics = { evaluator.label: _filter_metrics(eval_metrics, evaluator.metric_names) for evaluator in evaluators diff --git a/tests/datasets/local_data/lambada_small.jsonl b/tests/datasets/local_data/lambada_small.jsonl new file mode 100644 index 0000000000..8aef66f319 --- /dev/null +++ b/tests/datasets/local_data/lambada_small.jsonl @@ -0,0 +1,16 @@ +{"context": "With Tristran's next step he was standing beside a lake, and the candlelight shone brightly on the water; and then he was walking through the mountains, through lonely crags, where the candlelight was reflected in the eyes of the creatures of the high snows; and then he was walking through the clouds, which, while not entirely substantial, still supported his weight in comfort; and then, holding tightly to his candle, he was underground, and the candlelight glinted back at him from the wet cave walls; now he was in the mountains once more; and then he was on a road through wild forest, and he glimpsed a chariot being pulled by two goats, being driven by a woman in a red dress who looked, for the glimpse he got of her, the way Boadicea was drawn in his history books; and another step and he was in a leafy glen, and he could hear the chuckle of water as it splashed and sang its way into a small brook.\n\nHe took another step, but he was still in the", "continuation": " glen"} +{"context": "Todd replied: No I thought you looked familiar but I can’t recall. The stranger told Todd: I’m Enoch; we met in your dream. Todd looked back again, this time he realized it really was Enoch; Todd stopped on the side of the road, leaned back and tried to see if he was dreaming. When Enoch said: No Todd you’re not", "continuation": " dreaming"} +{"context": "The Librarian thumbed through the bundle of pages, stopping on the final sheet and began reading, “It is our conclusion that much of the work that is currently done in the Library can be out-sourced to contractors, particularly non-skill specific work such as shelving, stacking...”\nLucy gulped and Gillian began to open her mouth to protest again, but the Librarian carried on regardless, his voice becoming louder in order to drown out any potentially dissenting voices, “... blah, blah, blah. It is our recommendation that a downsizing of the non-essential and part-time members of staff would bring instant economy of scale benefits and would allow for the implementation of a new middle management structure.”\n“You mean sacrifice the troops to pay for the generals,” said", "continuation": " Gillian"} +{"context": "He was small, even for a dwarf, and his poor taste in sorcerous robes contrasted awkwardly with D’jebee’s elegant attire; her long, diaphanous gown and his chemical-stained, star-spangled robe clashed almost as much as her vacuous expression alongside his own visage, alive as it was with cunning and a twisted intelligence.\n\nD’jebee sighed with boredom.\n\n‘What is it, my love?’ Poldanyelz oozed with ersatz concern.\n\n‘I’m bored,’ D’jebee complained undiplomatically. ‘No one ever comes here. I never see anyone except you.’\n\nA shuffling from the main arch alerted her to the inaccuracy of her", "continuation": " statement"} +{"context": "But I digresseth, for it was the manner of Corgley’s death which did draw suche greate attention from ye populace, for he was walking as normale down the center streete, when suddenly a gashe of large proportiones did appear upon his neck and arms, and he did vanish in a fountain of bloode, from which his bodie could not be founde. The terrible power of ye Daemon of Gorey’s Hollow has become apparent, and it seems as if his power doth extend beyond the Hollow itselfe.\n\nReinhouer flipped ahead three more pages, to the next entry:\n\nThree more have met their endes at my request at the handes of ye Daemon of Gorey’s", "continuation": " Hollow"} +{"context": "There had been himself, Gillian Dawson, the assistant librarian in charge of acquisitions, Chege Gomez, who did something unspecific in the archive section, and a new employee whose name Art had yet to discover, but who turned out to be called Lucy something and who stacked shelves and ran various errands for the deputy librarian, all nervously seated around in a circle, in the surprisingly comfortable chairs in the Librarian’s spacious office, awaiting the arrival of the Main Man. The delay allowed plenty of time for idle speculation.\n“What do you think he wants?” asked Chege, fidgeting nervously in his chair like a cat on hot bricks.\n“It’s not going to be good news,” said", "continuation": " Gillian"} +{"context": "The girls all stopped what they were doing, which had been preparing-in various states of distress or, in Sophronia's case, delight-to try their own versions of the somersault, and began patting about for handkerchiefs.\n\n\"What did I tell you yesterday? A lady always has her handkerchief on her person. A handkerchief is endlessly useful. Not only is it a communication device, but it can also be dropped as a distraction, scented with various perfumes and noxious gases for discombobulation, used to wipe the forehead of a gentleman, or even bandage a wound, and, of course, you may dab at the eyes or nose if it is still", "continuation": " clean"} +{"context": "\"Oh yes,\" he added, seeing Laurence's surprise, \"the Tsar decided he would not throw a good army after bad, in the end, and perhaps that he did not want to spend the rest of his life as a French prisoner; there is an armistice, and they are negotiating a treaty in Warsaw, the two Emperors, as the best of friends.\" He gave a bark of laughter. \"So you see, they may not bother getting us out; by the end of this month I may be a citoyen myself.\"\n\nHe had only just escaped the final destruction of Prince Hohenlohe's corps, having been ordered to Danzig by courier to secure the fortress against just such a", "continuation": " siege"} +{"context": "I realize that this probably isn’t the best time but I want to hear Nikolai’s explanations, and to be honest I’m exhausted and I could use a little rest, which Nikolai’s tale will give us, before we go any further.”\n“Of course your…” Tares caught himself in time, “Slade. We will listen to what the necromancer has to say if such is your wish.”\n“I do wish. Now Nikolai, you better start talking.” Nikolai could see that whatever tentative trust had been building between Slade and himself had been partially eroded, thanks to", "continuation": " Tares"} +{"context": "The marines could perceive the young people to be friendly (the marines themselves were only 19-year-olds), but, at a distance of about 30 yards, it was as if a warning shot rang out—some of the marines made as if to reach for their weapons. Mills started from the midst of the group—dark spots under his arms marked where the perspiration had soaked into his shirt.\n\"None of us know how we got here,\" Ray explained, taking on a leadership role. He had, after all, saved John Henry's life, so, although he was younger than John Henry, he had a certain serendipitous good fortune going for him which plucked him out to be", "continuation": " leader"} +{"context": "We’re supposed to be staking out this Gordo’s,” he jerked his impossibly cleft chin at the downtown Gordo’s “restaurant” before which they stood, “so I might as well get something here.”\nA man dressed as a ninja and wreathed in flame ran past, trailing smoke.\n“A foul omen!” said Goodspeed.\n“That happens,” said Duke. “A lot.”\n“The spirits in this place are dark indeed.”\n“That’s the Hamwiches,” said", "continuation": " Duke"} +{"context": "And the moon would have told him more, and perhaps she did, but the moon became the glimmer of moonlight on water far below him, and then he became aware of a small spider walking across his face, and of a crick in his neck, and he raised a hand and brushed the spider carefully from his cheek, and the morning sun was in his eyes and the world was gold and green.\n\n\"You were dreaming,\" said a young woman's voice from somewhere above him. The voice was gentle and oddly accented. He could hear leaves rustle in the copper beech tree overhead.\n\n\"Yes,\" he said, to whoever was in the tree, \"I was", "continuation": " dreaming"} +{"context": "My eyes fly open, and I feel they're engulfed in the Ult L-E as I recite a poem as if someone else is controlling me,\n“Though the clouds darken the sun,\nand the rain becomes tainted,\nalways know there will be\na love that will not die.\nThough hope seems a distant memory,\nand human machines walk the land,\nknow no one can destroy\na love that will not die.”\n\n“What are you babbling about?” the Rogue asks.\nI surface from my unconscious state, and I sit up, stand, walk to the PPK, pick up the gun, and aim it for the", "continuation": " Rogue"} +{"context": "I had no doubt had we been inside, and not standing in front of the house when the van with the rest of the guys in the band pulled up, that they would have interrupted something a lot more intimate than us kissing.\n\nSomeone honked the horn and Jet pulled away. He left a little bite to remember him by and now, instead of being angry, those oh-so-pretty eyes with that gold halo just looked sad.\n\n\"Bye, Ayd.\"\n\nI had to hold back tears. I put his shaking fingers to my mouth, like maybe I could hold him there, keep him with me forever, and whispered back, \"Bye,", "continuation": " Jet"} +{"context": "I look over at her, and they are gone.\n\nTHE FIRST NIGHT IT HAPPENED\n\nThe first night it happened, I followed them into the strip mall parking lot. They were all stuffed into a silver-gray Honda-all thousand of them. This was back in November. Charlie had only been dead two months then.\n\nOne minute I was sitting on the side of a country road, taking shots of Smirnoff and counting my tips before I went back to the store to close, the next minute I was in the middle of a science fiction movie, complete with a jet-powered Honda Civic and a thousand translucent zombielike beings who looked like", "continuation": " Charlie"} +{"context": "Oh, it was petty of her, she knew, to fault them for their infinite ambition when she was the one left filing papers-again-but Alice also knew without any doubt at all that each of them would happily stab the other in the back and trample all over the bleeding body to get ahead. Like some other people...\n\nAs she gathered up her papers and retreated to her attic, Alice wondered again how she could have been so wrong about Ella. Of all her friends, she would never have expected her to be the one to let her down-Cassie, in an episode of single-minded selfishness, perhaps; Flora, out of thoughtlessness; but", "continuation": " Ella"} diff --git a/tests/datasets/test_in_context_learning_eval_datasets.py b/tests/datasets/test_in_context_learning_eval_datasets.py new file mode 100644 index 0000000000..442ee538fe --- /dev/null +++ b/tests/datasets/test_in_context_learning_eval_datasets.py @@ -0,0 +1,48 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +from torch.utils.data import DataLoader + +from composer.core import Evaluator +from composer.datasets.in_context_learning_evaluation import get_lm_task_dataloader +from composer.loggers import InMemoryLogger +from composer.models.gpt2 import create_gpt2 +from composer.trainer import Trainer +from tests.common import device, world_size + + +@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) +def test_lm_task_dataloader(dataset_uri, tiny_gpt2_tokenizer): + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + + tokenizer = tiny_gpt2_tokenizer + dataset_uri = f'{local_data}/{dataset_uri}' + dl = get_lm_task_dataloader(dataset_uri, tokenizer, 2, max_seq_len=2048, eos_tok_id=tokenizer.eos_token_id) + + assert isinstance(dl.dataloader, DataLoader) # pyright + assert 'input_ids' in next(dl.dataloader._get_iterator()) + assert 'attention_mask' in next(dl.dataloader._get_iterator()) + assert 'continuation_indices' in next(dl.dataloader._get_iterator()) + assert 'labels' in next(dl.dataloader._get_iterator()) + assert 'mode' in next(dl.dataloader._get_iterator()) + + +@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) +@world_size(1, 2) +@device('cpu', 'gpu') +def test_lm_task_evaluation(device, world_size, dataset_uri, tiny_gpt2_tokenizer): + in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger + local_data = os.path.join(os.path.dirname(__file__), 'local_data') + dataset_uri = f'{local_data}/{dataset_uri}' + tokenizer = tiny_gpt2_tokenizer + dl = get_lm_task_dataloader(dataset_uri, tokenizer, 2, max_seq_len=2048, eos_tok_id=tokenizer.eos_token_id) + evaluator = Evaluator(label='lambada', dataloader=dl, metric_names=['InContextLearningLMAccuracy']) + model = create_gpt2(use_pretrained=False, pretrained_model_name='EleutherAI/gpt-neo-125M') + model.add_eval_metrics(evaluator) + trainer = Trainer(model=model, max_duration='1ep', loggers=in_memory_logger) + trainer.eval(eval_dataloader=evaluator) + assert 'metrics/lambada/InContextLearningLMAccuracy' in in_memory_logger.data.keys() + assert in_memory_logger.data['metrics/lambada/InContextLearningLMAccuracy'][0][1].item() == 0 diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index 95075097ce..42065c2dc2 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -7,7 +7,8 @@ import torch from torch.nn.functional import cross_entropy -from composer.metrics.nlp import BinaryF1Score, HFCrossEntropy, LanguageCrossEntropy, MaskedAccuracy, Perplexity +from composer.metrics.nlp import (BinaryF1Score, HFCrossEntropy, InContextLearningLMAccuracy, LanguageCrossEntropy, + MaskedAccuracy, Perplexity) @pytest.mark.parametrize('ignore_index', [-100]) @@ -200,3 +201,28 @@ def test_perplexity(): perplexity = perplexity_metric.compute() assert torch.equal(torch.exp(ce), perplexity) + + +def test_in_context_learning_lm_accuracy(tiny_gpt2_tokenizer): + contexts = ['The dog is', 'I love to eat', 'I hate', 'The weather is'] + continuations = [' furry', ' pie', ' long lines', ' snowy'] + pad = tiny_gpt2_tokenizer.pad_token_id + inputs = [ + tiny_gpt2_tokenizer(context)['input_ids'] + tiny_gpt2_tokenizer(continuation)['input_ids'] + for context, continuation in zip(contexts, continuations) + ] + inputs = torch.tensor([input + [pad] * (2048 - len(input)) for input in inputs]) + + cont_idxs = [] + for context, continuation in zip(contexts, continuations): + start = len(tiny_gpt2_tokenizer(context)['input_ids']) + end = start + len(tiny_gpt2_tokenizer(continuation)['input_ids']) + cont_idxs.append(torch.tensor(list(range(start, end)))) + + batch = {'continuation_indices': cont_idxs, 'labels': inputs, 'input_ids': inputs} + logits = torch.nn.functional.one_hot(torch.roll(inputs, shifts=-1), num_classes=pad + 1) + logits[2] = logits[1].clone() # make one of the answers incorrect + metric = InContextLearningLMAccuracy() + metric.update(batch, logits, batch['labels']) + + assert metric.compute() == 0.75