Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inference api eval wrapper #494

Merged
merged 75 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from 70 commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
7e8511b
add subset num batches
bmosaicml Jul 14, 2023
059c43e
add subset num batches
bmosaicml Jul 14, 2023
75c455c
remove tiktoken
bmosaicml Jul 16, 2023
f028ad8
remove openai import
bmosaicml Jul 17, 2023
06fa54a
remove bad line
bmosaicml Jul 17, 2023
3a139b2
foo
bmosaicml Jul 17, 2023
56a2c88
add training callback
bmosaicml Aug 1, 2023
e16e86b
modify yamls
bmosaicml Aug 1, 2023
8341a76
implement train
bmosaicml Aug 1, 2023
6ff5cc5
fix indexing to get most recent eval result
bmosaicml Aug 1, 2023
06560d5
finish
bmosaicml Aug 2, 2023
9e07ece
Merge branch 'main' into enable_gauntlet_training
bmosaicml Aug 2, 2023
989f61a
finish
bmosaicml Aug 2, 2023
4c316f1
finish
bmosaicml Aug 2, 2023
7de1b8c
finish
bmosaicml Aug 2, 2023
8a77e88
finish
bmosaicml Aug 2, 2023
61d682a
Merge branch 'main' into enable_gauntlet_training
bmosaicml Aug 9, 2023
6b2116d
foo
bmosaicml Aug 9, 2023
33d3165
foo
bmosaicml Aug 9, 2023
85c2641
working on debugging changeS
bmosaicml Aug 12, 2023
1b3944f
[wip] removing logger dependency from model gauntlet
bmosaicml Aug 14, 2023
309570d
remove logger from eval
bmosaicml Aug 14, 2023
850bc8e
remove logger from eval
bmosaicml Aug 14, 2023
82cee97
remove logger from eval
bmosaicml Aug 14, 2023
fe2c141
Merge branch 'main' into enable_gauntlet_training
bmosaicml Aug 15, 2023
df170de
debug
bmosaicml Aug 16, 2023
c20ee09
debug
bmosaicml Aug 16, 2023
f23a1ad
debug
bmosaicml Aug 16, 2023
7865e83
debug
bmosaicml Aug 16, 2023
96210f0
fix
bmosaicml Aug 16, 2023
669a770
finish?
bmosaicml Aug 16, 2023
b269552
Merge branch 'main' into enable_gauntlet_training
bmosaicml Aug 16, 2023
6819b43
fix bug
bmosaicml Aug 16, 2023
03f80d9
merge main
bmosaicml Aug 16, 2023
a4f981a
fix bug
bmosaicml Aug 16, 2023
d18cef8
Revert "ignore empty outputs"
bmosaicml Aug 16, 2023
c882f01
fix pyright
bmosaicml Aug 16, 2023
9fd4ae1
fix pyright
bmosaicml Aug 16, 2023
f174bb6
update versions
bmosaicml Aug 17, 2023
c54509b
merge
bmosaicml Aug 30, 2023
1b80362
fix
bmosaicml Aug 30, 2023
c89c118
merge updates
bmosaicml Aug 30, 2023
e286dbd
remove info from yamls
bmosaicml Aug 30, 2023
79ca900
remove load in 8bit
bmosaicml Aug 30, 2023
f613b3b
Merge branch 'main' into add_openai_wrapper
bmosaicml Sep 13, 2023
2157759
address comments
bmosaicml Sep 13, 2023
4bc1534
address comments
bmosaicml Sep 14, 2023
b1bc1e3
address comments
bmosaicml Sep 14, 2023
984cdbd
add monkeypatch
bmosaicml Sep 14, 2023
c7daee0
Merge branch 'main' into add_openai_wrapper
bmosaicml Sep 14, 2023
3108ea5
add back in bsz
bmosaicml Sep 14, 2023
b5e5e2f
add back in bsz
bmosaicml Sep 14, 2023
0cb29ab
add openai reqs
bmosaicml Sep 14, 2023
57b684f
remove branch
bmosaicml Sep 14, 2023
29da49d
fix conditional import
bmosaicml Sep 15, 2023
1a9a77e
Update llmfoundry/models/inference_api_wrapper/interface.py
bmosaicml Sep 15, 2023
2bd933a
Update llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
bmosaicml Sep 15, 2023
77f5279
Update llmfoundry/models/inference_api_wrapper/interface.py
bmosaicml Sep 15, 2023
359c893
fix comments
bmosaicml Sep 15, 2023
4a41efd
fix comments
bmosaicml Sep 15, 2023
618ec6f
Merge branch 'main' into add_openai_wrapper
bmosaicml Sep 15, 2023
8f155fa
fix comments
bmosaicml Sep 15, 2023
3d4d0da
Merge branch 'add_openai_wrapper' of github.com:mosaicml/llm-foundry …
bmosaicml Sep 15, 2023
6cff092
Update tests/test_inference_api_eval_wrapper.py
dakinggg Sep 16, 2023
e7df76f
Update tests/test_inference_api_eval_wrapper.py
dakinggg Sep 16, 2023
dc7cef5
Update tests/test_inference_api_eval_wrapper.py
dakinggg Sep 16, 2023
02187d0
Merge branch 'main' into add_openai_wrapper
dakinggg Sep 16, 2023
7638244
Update tests/test_inference_api_eval_wrapper.py
dakinggg Sep 16, 2023
e11d316
Update tests/test_inference_api_eval_wrapper.py
dakinggg Sep 16, 2023
735f719
pyright ignore
bmosaicml Sep 16, 2023
f2b02dd
Update tests/test_inference_api_eval_wrapper.py
dakinggg Sep 16, 2023
9801af9
Update tests/test_inference_api_eval_wrapper.py
dakinggg Sep 16, 2023
e809971
Update tests/test_inference_api_eval_wrapper.py
dakinggg Sep 16, 2023
2bf35e5
Update tests/test_inference_api_eval_wrapper.py
dakinggg Sep 16, 2023
eaaebf2
Merge branch 'main' into add_openai_wrapper
bmosaicml Sep 18, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions llmfoundry/models/inference_api_wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.inference_api_wrapper.interface import \
InferenceAPIEvalWrapper
from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAITokenizerWrapper)

__all__ = [
'OpenAICausalLMEvalWrapper',
'OpenAIChatAPIEvalWrapper',
'OpenAITokenizerWrapper',
'InferenceAPIEvalWrapper',
]
110 changes: 110 additions & 0 deletions llmfoundry/models/inference_api_wrapper/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Optional

import torch
from composer.core.types import Batch
from composer.metrics import InContextLearningMetric
from composer.metrics.nlp import (InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
InContextLearningQAAccuracy,
LanguageCrossEntropy, LanguagePerplexity)
from composer.models import ComposerModel
from torchmetrics import Metric
from transformers import AutoTokenizer


class InferenceAPIEvalWrapper(ComposerModel):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
bmosaicml marked this conversation as resolved.
Show resolved Hide resolved
self.tokenizer = tokenizer
self.labels = None
# set up training and eval metrics
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
]
self.eval_metrics = {
metric.__class__.__name__: metric for metric in eval_metrics
}
super().__init__()

def get_metrics(self, is_train: bool = False):
if is_train:
raise NotImplementedError(
'You cannot use inference wrappers for training')
else:
metrics = self.eval_metrics

return metrics if metrics else {}

def get_next_token_logit_tensor(self,
prompt: str) -> Optional[torch.Tensor]:
raise NotImplementedError

def rebatch(self, batch: Batch):
# default is a no-op, but Chat API modifies these
return batch

def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
# If the batch mode is generate, we will generate a requested number of tokens using the underlying
# model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
# be returned from eval_forward
output_logits_batch = []
for tokens, cont_idxs in zip(batch['input_ids'],
batch['continuation_indices']):

seqlen = tokens.shape[0]
tokens = tokens.tolist()
cont_idxs = cont_idxs.tolist()
expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
output_logits = torch.nn.functional.one_hot(
torch.tensor(tokens[1:cont_idxs[0]]),
num_classes=self.tokenizer.vocab_size)
for i in range(len(expected_cont_tokens)):
# decode one token at a time
prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] +
expected_cont_tokens[0:i])
next_logit_tensor = self.get_next_token_logit_tensor(prompt)
if next_logit_tensor is None:
bmosaicml marked this conversation as resolved.
Show resolved Hide resolved
continue
output_logits = torch.cat(
[output_logits,
next_logit_tensor.reshape(1, -1)])
padding = torch.nn.functional.one_hot(
torch.full((seqlen - output_logits.shape[0],),
self.tokenizer.pad_token_id),
num_classes=self.tokenizer.vocab_size)
output_logits = torch.cat([output_logits, padding])
output_logits_batch.append(output_logits)

return torch.stack(output_logits_batch).to(batch['input_ids'].device)

def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
batch = self.rebatch(batch)
self.labels = batch.pop('labels')
self.labels[:, :-1] = self.labels[:, 1:].clone()
self.labels[:, -1] = -100
if isinstance(metric, InContextLearningMetric) and batch.get(
'mode', None) == 'icl_task':
assert self.labels is not None
metric.update(batch, outputs, self.labels)
else:
bmosaicml marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(
'Inference API wrapper only supports InContextLearningMetrics and mode=icl_task'
)

def forward(self):
bmosaicml marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(
"Inference API wrapper doesn't support forward")

def loss(self):
bmosaicml marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError("Inference API wrapper doesn't support loss")
Loading