
Feature/lambada evaluator (mosaicml#1845)
* new branch

* unittest multi gpu

* reimplement lambada

* update

* pyright

* change naming of file

Co-authored-by: Daniel King <[email protected]>
bmosaicml and dakinggg authored Jan 9, 2023
1 parent 5a51b79 commit 62bf8ba
Showing 9 changed files with 311 additions and 12 deletions.
6 changes: 3 additions & 3 deletions composer/core/data_spec.py
@@ -179,7 +179,7 @@ class DataSpec:
... )
Args:
dataloader (Iterable): The dataloader, which can be any iterable that yields batches.
dataloader (Union[Iterable, torch.utils.data.DataLoader]): The dataloader, which can be any iterable that yields batches.
num_samples (int, optional): The total number of samples in an epoch, across all ranks. This field is used by
the :class:`.Timestamp` (training progress tracker). If not specified, then ``len(dataloader.dataset)`` is
@@ -214,15 +214,15 @@ class DataSpec:

def __init__(
self,
dataloader: Iterable,
dataloader: Union[Iterable, torch.utils.data.DataLoader],
num_samples: Optional[int] = None,
num_tokens: Optional[int] = None,
device_transforms: Optional[Callable[[Batch], Batch]] = None,
split_batch: Optional[Callable[[Batch, int], Sequence[Batch]]] = None,
get_num_samples_in_batch: Optional[Callable[[Batch], int]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], int]] = None,
) -> None:
self.dataloader = dataloader
self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
self.num_tokens = num_tokens
self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
self.split_batch = _default_split_batch if split_batch is None else split_batch
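Illustrative sketch (not part of this diff) of how a DataSpec wraps a plain torch DataLoader under the widened annotation; the toy dataset and the per-batch sample count below are assumptions for demonstration only:

import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.core import DataSpec

# Toy dataset of 8 feature/label pairs (illustrative only).
dataset = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
dataloader = DataLoader(dataset, batch_size=2)

# DataSpec accepts either a plain iterable or a torch DataLoader.
spec = DataSpec(
    dataloader,
    get_num_samples_in_batch=lambda batch: batch[0].shape[0],  # samples per batch
)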
141 changes: 141 additions & 0 deletions composer/datasets/in_context_learning_evaluation.py
@@ -0,0 +1,141 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
# This code is based on the implementation in https://github.com/EleutherAI/lm-evaluation-harness/blob/8c048e266a22a1c85ccbdb0c209ac712e4f39989/lm_eval/base.py#L221-L330

from typing import Union

import torch
import transformers
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset

from composer.core import DataSpec
from composer.utils import dist
from composer.utils.file_helpers import get_file

__all__ = ['InContextLearningLMTaskDataset', 'get_lm_task_dataloader']


class InContextLearningLMTaskDataset(Dataset):
"""A dataset that construct batches for in-context learning language modeling evaluation
Args:
dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend
supported by :meth:`composer.utils.maybe_create_object_store_from_uri`.
tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to transform data into batches
max_seq_len (int): The sequence length expected by the model
eos_tok_id (int): The special token reserved for padding the ends of batches
destination_path (str): Temporary path to store downloaded datasets
"""

def __init__(
self,
dataset_uri: str,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
max_seq_len: int,
eos_tok_id: int,
destination_path: str = 'icl_lm_task.json',
):
get_file(dataset_uri, destination_path, overwrite=True)
dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
self.encoded_dataset = list(
dataset.map(lambda examples: {
'continuation': tokenizer(examples['continuation']),
'context': tokenizer(examples['context']),
}))
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self.eos_tok_id = eos_tok_id

def __getitem__(self, index):
return self.encoded_dataset[index]

def __len__(self):
return len(self.encoded_dataset)

def collate_fn(self, data):
inputs = []
continuation_indices = []
for data_pair in data:
context, continuation = data_pair['context'], data_pair['continuation']

context_enc = context['input_ids']
continuation_enc = continuation['input_ids']
continuation_span = torch.tensor(range(len(context_enc), len(context_enc) + len(continuation_enc)))

inp = torch.tensor(
(context_enc + continuation_enc
)[-(self.max_seq_len + 1):], # trim from the left if context + continuation are larger than max_seq_len
dtype=torch.long,
)
(inp_len,) = inp.shape

# pad length from seq to padding_length
inp = torch.cat(
[
inp, # [seq]
torch.LongTensor((self.max_seq_len - inp_len) * [self.eos_tok_id]),
],
dim=0,
)

inputs.append(inp)
continuation_indices.append(continuation_span)

batch = {
'input_ids': torch.stack(inputs),
'continuation_indices': continuation_indices,
'mode': 'lm_task',
'labels': torch.stack(inputs),
}

batch['attention_mask'] = ~(batch['input_ids'] == self.eos_tok_id)
return batch

def get_num_samples_in_batch(self, batch) -> int:
return batch['input_ids'].shape[0]

def update_metric(self, metric, batch, output_logits, labels):
metric.update(batch, output_logits, labels)


def get_lm_task_dataloader(dataset_uri: str, tokenizer: Union[transformers.PreTrainedTokenizer,
transformers.PreTrainedTokenizerFast], batch_size: int,
max_seq_len: int, eos_tok_id: int) -> DataSpec:
"""This constructs a dataloader capable of evaluating LLMs on in-context learning language modeling tasks, for example LAMBADA. An example usage is below:
>>> dl = get_lm_task_dataloader(dataset_uri, tokenizer, 2, max_seq_len=2048, eos_tok_id=tokenizer.eos_token_id)
>>> eval_evaluator = Evaluator(
... label="lambada",
... dataloader=dl,
... metric_names=['InContextLearningLMAccuracy']
... )
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_evaluator,
... optimizers=optimizer,
... max_duration="1ep",
... )
Args:
dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend
supported by :meth:`composer.utils.maybe_create_object_store_from_uri`.
tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to transform data into batches
batch_size (int): Size of a batch used for eval
max_seq_len (int): The sequence length expected by the model
eos_tok_id (int): The special token reserved for padding the ends of batches
Returns:
DataSpec: A :class:`.DataSpec` whose dataloader is used to perform in-context learning evaluation on the provided dataset.
"""
dataset = InContextLearningLMTaskDataset(dataset_uri, tokenizer, max_seq_len, eos_tok_id)
sampler = dist.get_sampler(dataset, drop_last=False, shuffle=True)
return DataSpec(DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
collate_fn=dataset.collate_fn,
),
get_num_samples_in_batch=dataset.get_num_samples_in_batch)
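Illustrative sketch (not part of this diff) of the JSON-lines layout the dataset above expects (one object with 'context' and 'continuation' per line) and of building the dataloader from it; the file name, example text, and tokenizer choice are assumptions:

import json

import transformers

from composer.datasets.in_context_learning_evaluation import get_lm_task_dataloader

# Hypothetical local dataset file: one {"context": ..., "continuation": ...} object per line.
with open('lambada_sample.json', 'w') as f:
    f.write(json.dumps({'context': 'With a sigh, she turned back to the', 'continuation': ' window'}) + '\n')

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
dl_spec = get_lm_task_dataloader(
    'lambada_sample.json',  # local path; s3:// URIs are also supported
    tokenizer,
    batch_size=1,
    max_seq_len=2048,
    eos_tok_id=tokenizer.eos_token_id,
)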
8 changes: 7 additions & 1 deletion composer/metrics/__init__.py
@@ -5,7 +5,8 @@

from composer.metrics.map import MAP
from composer.metrics.metrics import CrossEntropy, Dice, LossMetric, MIoU
from composer.metrics.nlp import BinaryF1Score, HFCrossEntropy, LanguageCrossEntropy, MaskedAccuracy, Perplexity
from composer.metrics.nlp import (BinaryF1Score, HFCrossEntropy, InContextLearningLMAccuracy, LanguageCrossEntropy,
MaskedAccuracy, Perplexity)

__all__ = [
'MAP',
@@ -18,4 +19,9 @@
'HFCrossEntropy',
'LanguageCrossEntropy',
'MaskedAccuracy',
'InContextLearningLMAccuracy',
]

METRIC_DEFAULT_CTORS = {
'InContextLearningLMAccuracy': InContextLearningLMAccuracy,
}
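Illustrative sketch (not part of this diff) of how a name-to-constructor registry like METRIC_DEFAULT_CTORS turns metric names into metric instances; this mirrors how add_eval_metrics in composer/models/huggingface.py (later in this commit) uses it:

from composer.metrics import METRIC_DEFAULT_CTORS

# Turn a list of metric names (e.g. an Evaluator's metric_names) into instances.
metric_names = ['InContextLearningLMAccuracy']  # assumed example input
metrics = {name: METRIC_DEFAULT_CTORS[name]() for name in metric_names}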
51 changes: 50 additions & 1 deletion composer/metrics/nlp.py
@@ -10,7 +10,10 @@

from composer.loss import soft_cross_entropy

__all__ = ['Perplexity', 'BinaryF1Score', 'HFCrossEntropy', 'LanguageCrossEntropy', 'MaskedAccuracy']
__all__ = [
'Perplexity', 'InContextLearningLMAccuracy', 'BinaryF1Score', 'HFCrossEntropy', 'LanguageCrossEntropy',
'MaskedAccuracy'
]


class MaskedAccuracy(Metric):
@@ -231,3 +234,49 @@ def compute(self) -> Tensor:
"""Returns torch.exp() of the LanguageCrossEntropyLoss."""
avg_loss = super().compute()
return torch.exp(avg_loss)


class InContextLearningLMAccuracy(Metric):
r"""Computes accuracy for In-context learning (ICL) language modeling (LM) tasks.
ICL LM tasks consist of some number of example language modeling tasks (referred to as the 'context'), followed by a test task where the model must correctly predict all the tokens
following tokens in some passage (referred to as the 'continuation').
For example, the model may be provided the context below and evaluated on its ability to correctly predict the continuation. Note: it doesn't matter
whether the model correctly predicts the context tokens.
Context: `The dog is->fuzzy\nthe water is->hot\nthe tree is->`
Continuation: `green`
Adds metric state variables:
correct (float): The number of examples where the model correctly predicted the whole continuation.
total (float): The number of total examples seen.
Args:
dist_sync_on_step (bool, optional): Synchronize metric state across processes at
each forward() before returning the value at the step. Default: ``False``.
"""

# Make torchmetrics call update only once
full_state_update = False

def __init__(self, dist_sync_on_step: bool = False):
# state from multiple processes
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')

def update(self, batch: dict, output_logits: torch.Tensor, labels: torch.Tensor):
targets = torch.roll(labels, shifts=-1)
targets[:, -1] = -100
for batch_idx, cont_idx in enumerate(batch['continuation_indices']):
cont_tok_pred = output_logits[batch_idx].index_select(dim=0, index=cont_idx - 1).argmax(dim=-1)
cont_tok_targ = targets[batch_idx].index_select(dim=0, index=cont_idx - 1)

self.correct += (cont_tok_pred == cont_tok_targ).all().int()
self.total += torch.tensor(1)

def compute(self):
assert isinstance(self.correct, Tensor)
assert isinstance(self.total, Tensor)
return self.correct.float() / self.total
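Toy sketch (not part of this diff) of the indexing update() performs: the prediction for each continuation token is read from the position just before it, and an example counts as correct only when every continuation token matches. The sequence, continuation positions, and fabricated logits below are assumptions for illustration:

import torch

from composer.metrics.nlp import InContextLearningLMAccuracy

vocab_size = 10
# One example: 4 context tokens followed by a 2-token continuation, no padding.
input_ids = torch.tensor([[3, 1, 4, 1, 5, 9]])
continuation_indices = [torch.tensor([4, 5])]  # positions of the continuation tokens

# Fabricate logits that put all mass on the true next token at every position,
# so the continuation is predicted perfectly.
logits = torch.zeros(1, 6, vocab_size)
for pos in range(5):
    logits[0, pos, input_ids[0, pos + 1]] = 1.0

metric = InContextLearningLMAccuracy()
metric.update({'continuation_indices': continuation_indices}, logits, input_ids)
print(metric.compute())  # tensor(1.) for this single, perfectly predicted example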
24 changes: 19 additions & 5 deletions composer/models/huggingface.py
@@ -5,6 +5,7 @@

from __future__ import annotations

import inspect
import json
import logging
import tempfile
@@ -16,6 +17,7 @@
import torch
from torchmetrics import Metric

from composer.metrics import METRIC_DEFAULT_CTORS, InContextLearningLMAccuracy
from composer.models.base import ComposerModel
from composer.utils import get_file
from composer.utils.import_helpers import MissingConditionalImportError, import_object
@@ -86,14 +88,14 @@ def __init__(self,

self.use_logits = use_logits

self.train_metrics = None
self.val_metrics = None
self.train_metrics: Optional[Dict] = None
self.val_metrics: Optional[Dict] = None

if metrics:
self.train_metrics = {metric.__class__.__name__: metric for metric in metrics}
self.val_metrics = {metric.__class__.__name__: metric for metric in metrics}

self.labels = None # set in eval_forward() if exists
self.labels: Optional[torch.Tensor] = None # set in eval_forward() if exists

@staticmethod
def hf_from_composer_checkpoint(
@@ -257,6 +259,7 @@ def hf_from_composer_checkpoint(
def forward(self, batch):
if isinstance(batch, dict) or isinstance(batch, UserDict):
# Further input validation is left to the huggingface forward call
batch = {k: v for k, v in batch.items() if k in inspect.getfullargspec(self.model.forward).args}
output = self.model(**batch) # type: ignore (thirdparty)
else:
raise ValueError(
@@ -273,7 +276,7 @@ def loss(self, outputs, batch):

def eval_forward(self, batch, outputs: Optional[Any] = None):
output = outputs if outputs else self.forward(batch)
if self.use_logits:
if self.use_logits or batch.get('mode', None) == 'lm_task':
self.labels = batch.pop('labels')
if self.config.use_return_dict:
output = output['logits']
@@ -296,7 +299,11 @@ def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]:
return metrics if metrics else {}

def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
metric.update(outputs, self.labels)
if isinstance(metric, InContextLearningLMAccuracy):
assert self.labels is not None
metric.update(batch, outputs, self.labels)
else:
metric.update(outputs, self.labels)

def get_metadata(self):
model_output = {}
@@ -335,3 +342,10 @@ def get_metadata(self):
'content': tokenizer_file_content
}
return {'model': model_output, 'tokenizer': tokenizer_output}

def add_eval_metrics(self, evaluator):
evaluator_metrics = {m: METRIC_DEFAULT_CTORS[m]() for m in evaluator.metric_names}
if self.val_metrics is not None:
self.val_metrics.update(evaluator_metrics)
else:
self.val_metrics = evaluator_metrics
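Illustrative sketch (not part of this diff) of the hook above: an Evaluator-like object names a metric, and add_eval_metrics materializes it from METRIC_DEFAULT_CTORS into val_metrics. The tiny GPT-2 config and the SimpleNamespace stand-in for composer.core.Evaluator are assumptions:

from types import SimpleNamespace

import transformers

from composer.models import HuggingFaceModel

config = transformers.GPT2Config(n_layer=1, n_head=1, n_embd=64)  # tiny model, illustrative
hf_model = HuggingFaceModel(transformers.GPT2LMHeadModel(config), use_logits=True)

# Stand-in for a composer.core.Evaluator: only metric_names is used here.
evaluator = SimpleNamespace(metric_names=['InContextLearningLMAccuracy'])
hf_model.add_eval_metrics(evaluator)
print(hf_model.val_metrics)  # {'InContextLearningLMAccuracy': InContextLearningLMAccuracy()}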
1 change: 0 additions & 1 deletion composer/trainer/trainer.py
@@ -1212,7 +1212,6 @@ def __init__(
ensure_evaluator(evaluator, default_metric_names=model_metric_names)
for evaluator in ensure_tuple(eval_dataloader)
]

# match metric names to model metrics
self.state.eval_metrics = {
evaluator.label: _filter_metrics(eval_metrics, evaluator.metric_names) for evaluator in evaluators