diff --git a/lm_eval/base.py b/lm_eval/base.py index 1ea798159f..63eb38efc8 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -1,5 +1,7 @@ import abc -from typing import Iterable +from typing import Iterable, Optional + +import promptsource import numpy as np import random import re @@ -12,6 +14,7 @@ import torch import torch.nn.functional as F +from lm_eval import metrics from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte from lm_eval import utils from abc import abstractmethod @@ -24,17 +27,17 @@ def __init__(self): @abstractmethod def loglikelihood(self, requests): """Compute log-likelihood of generating a continuation from a context. - Downstream tasks should attempt to use loglikelihood instead of other + Downstream tasks should attempt to use loglikelihood instead of other LM calls whenever possible. :param requests: list A list of pairs (context, continuation) context: str - Context string. Implementations of LM must be able to handle an + Context string. Implementations of LM must be able to handle an empty context string. continuation: str - The continuation over which log likelihood will be calculated. If - there is a word boundary, the space should be in the continuation. + The continuation over which log likelihood will be calculated. If + there is a word boundary, the space should be in the continuation. For example, context="hello" continuation=" world" is correct. :return: list A list of pairs (logprob, isgreedy) @@ -97,7 +100,7 @@ def greedy_until(self, requests): context: str Context string until: [str] - The string sequences to generate until. These string sequences + The string sequences to generate until. These string sequences may each span across multiple tokens, or may be part of one token. 
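To make the request convention above concrete, a minimal illustrative sketch (editor's addition, not part of the diff; the strings are examples only):

```python
# Each loglikelihood request is a (context, continuation) pair; the space at a
# word boundary belongs to the continuation, and an empty context must be accepted.
requests = [
    ("The capital of France is", " Paris"),  # leading space on the continuation
    ("", "Hello world"),                     # empty-context case
]
# Each request resolves to (logprob, isgreedy): the summed log-probability of the
# continuation given the context, and whether greedy decoding reproduces it exactly.
```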
:return: list A list of strings continuation @@ -118,7 +121,6 @@ def set_cache_hook(self, cache_hook): class BaseLM(LM): - @property @abstractmethod def eot_token_id(self): @@ -145,13 +147,16 @@ def device(self): pass @abstractmethod - def tok_encode(self, string: str): pass - + def tok_encode(self, string: str): + pass + @abstractmethod - def tok_decode(self, tokens: Iterable[int]): pass + def tok_decode(self, tokens: Iterable[int]): + pass @abstractmethod - def _model_generate(self, context, max_length, eos_token_id): pass + def _model_generate(self, context, max_length, eos_token_id): + pass @abstractmethod def _model_call(self, inps): @@ -187,23 +192,30 @@ def loglikelihood_rolling(self, requests): # TODO: automatic batch size detection for vectorization loglikelihoods = [] - for string, in tqdm(requests): - rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows( - token_list=self.tok_encode(string), - prefix_token=self.eot_token_id, - max_seq_len=self.max_length, - context_len=1, - ))) + for (string,) in tqdm(requests): + rolling_token_windows = list( + map( + utils.make_disjoint_window, + utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.eot_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) rolling_token_windows = [(None,) + x for x in rolling_token_windows] # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for # that - string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True) - + string_nll = self._loglikelihood_tokens( + rolling_token_windows, disable_tqdm=True + ) + # discard is_greedy string_nll = [x[0] for x in string_nll] - + string_nll = sum(string_nll) loglikelihoods.append(string_nll) @@ -223,10 +235,12 @@ def _collate(x): toks = x[1] + x[2] return -len(toks), tuple(toks) - + # TODO: automatic (variable) batch size detection for vectorization reord = utils.Reorderer(requests, _collate) - for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size): + for chunk in utils.chunks( + tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size + ): inps = [] cont_toks_list = [] inplens = [] @@ -252,44 +266,60 @@ def _collate(x): # when too long to fit in context, truncate from the left inp = torch.tensor( - (context_enc + continuation_enc)[-(self.max_length+1):][:-1], - dtype=torch.long + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, ).to(self.device) - inplen, = inp.shape + (inplen,) = inp.shape cont = continuation_enc # since in _collate we make sure length is descending, the longest is always the first one. 
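As a quick check on the truncation rule above, a toy walk-through (editor's sketch, not from the diff):

```python
# The slice keeps the last max_length + 1 tokens of context + continuation and then
# drops the final token, which is only ever a prediction target, never model input.
max_length = 5
context_enc = [10, 11, 12, 13]
continuation_enc = [14, 15, 16]

inp = (context_enc + continuation_enc)[-(max_length + 1):][:-1]
assert inp == [11, 12, 13, 14, 15]
```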
- padding_length = padding_length if padding_length is not None else inplen + padding_length = ( + padding_length if padding_length is not None else inplen + ) # pad length from seq to padding_length - inp = torch.cat([ - inp, # [seq] - torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq] - ], dim=0) + inp = torch.cat( + [ + inp, # [seq] + torch.zeros(padding_length - inplen, dtype=torch.long).to( + inp.device + ), # [padding_length - seq] + ], + dim=0, + ) inps.append(inp.unsqueeze(0)) # [1, padding_length] cont_toks_list.append(cont) inplens.append(inplen) batched_inps = torch.cat(inps, dim=0) # [batch, padding_length - multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu() # [batch, padding_length, vocab] + multi_logits = F.log_softmax( + self._model_call(batched_inps), dim=-1 + ).cpu() # [batch, padding_length, vocab] - for (cache_key, _, _), logits, inp, inplen, cont_toks \ - in zip(chunk, multi_logits, inps, inplens, cont_toks_list): + for (cache_key, _, _), logits, inp, inplen, cont_toks in zip( + chunk, multi_logits, inps, inplens, cont_toks_list + ): # Slice to original seq length contlen = len(cont_toks) - logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab] + logits = logits[inplen - contlen : inplen].unsqueeze( + 0 + ) # [1, seq, vocab] # Check if per-token argmax is exactly equal to continuation greedy_tokens = logits.argmax(dim=-1) - cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq] + cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze( + 0 + ) # [1, seq] max_equal = (greedy_tokens == cont_toks).all() # Obtain log-probs at the corresponding continuation token indices # last_token_slice = logits[:, -1, :].squeeze(0).tolist() - logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( + -1 + ) # [1, seq] # Answer: (log prob, is-exact-match) answer = (float(logits.sum()), bool(max_equal)) @@ -301,9 +331,9 @@ def _collate(x): res.append(answer) return reord.get_original(res) - + def greedy_until(self, requests): - # TODO: implement fully general `until` that handles untils that are + # TODO: implement fully general `until` that handles untils that are # multiple tokens or that span multiple tokens correctly # TODO: extract to TokenizedLM? 
@@ -312,29 +342,46 @@ def greedy_until(self, requests): def _collate(x): toks = self.tok_encode(x[0]) return len(toks), x[0] - + reord = utils.Reorderer(requests, _collate) - for context, until in tqdm(reord.get_reordered()): - if isinstance(until, str): - until = [until] + for context, request_args in tqdm(reord.get_reordered()): + stopping_criteria = request_args["stopping_criteria"] + max_generation_length = request_args["max_generation_length"] - primary_until, = self.tok_encode(until[0]) - - context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device) + assert isinstance(stopping_criteria, str) or stopping_criteria is None + assert ( + isinstance(max_generation_length, int) or max_generation_length is None + ) - cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until) + until = [stopping_criteria] + primary_until = self.tok_encode(until[0]) + context_enc = torch.tensor( + [self.tok_encode(context)[self.max_gen_toks - self.max_length :]] + ).to(self.device) - s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:]) + if max_generation_length is None: + max_length = context_enc.shape[1] + self.max_gen_toks + else: + max_length = min( + max_generation_length, context_enc.shape[1] + self.max_gen_toks + ) + cont = self._model_generate( + context_enc, + max_length, + torch.tensor(primary_until), + ) + + s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :]) for term in until: s = s.split(term)[0] - + # partial caching self.cache_hook.add_partial("greedy_until", (context, until), s) - + res.append(s) - + return reord.get_original(res) @@ -383,7 +430,7 @@ def __init__(self, data_dir=None, cache_dir=None, download_mode=None): self._fewshot_docs = None def download(self, data_dir=None, cache_dir=None, download_mode=None): - """ Downloads and returns the task dataset. + """Downloads and returns the task dataset. Override this method to download the dataset from a custom API. :param data_dir: str @@ -412,7 +459,7 @@ def download(self, data_dir=None, cache_dir=None, download_mode=None): name=self.DATASET_NAME, data_dir=data_dir, cache_dir=cache_dir, - download_mode=download_mode + download_mode=download_mode, ) @abstractmethod @@ -478,22 +525,22 @@ def doc_to_target(self, doc): @abstractmethod def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of + """Uses RequestFactory to construct Requests and returns an iterable of Requests which will be sent to the LM. :param doc: The document as returned from training_docs, validation_docs, or test_docs. :param ctx: str - The context string, generated by fewshot_context. This includes the natural + The context string, generated by fewshot_context. This includes the natural language description, as well as the few shot examples, and the question - part of the document for `doc`. + part of the document for `doc`. 
""" pass @abstractmethod def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of the metric for that one document :param doc: @@ -507,7 +554,7 @@ def process_results(self, doc, results): def aggregation(self): """ :returns: {str: [metric_score] -> float} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are functions that aggregate a list of metric scores """ pass @@ -516,22 +563,26 @@ def aggregation(self): def higher_is_better(self): """ :returns: {str: bool} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are whether a higher value of the submetric is better """ pass def fewshot_description(self): import warnings + warnings.warn( "`fewshot_description` will be removed in futures versions. Pass " "any custom descriptions to the `evaluate` function instead.", - DeprecationWarning) + DeprecationWarning, + ) return "" @utils.positional_deprecated - def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): - """ Returns a fewshot context string that is made up of a prepended description + def fewshot_context( + self, doc, num_fewshot, provide_description=None, rnd=None, description=None + ): + """Returns a fewshot context string that is made up of a prepended description (if provided), the `num_fewshot` number of examples, and an appended prompt example. :param doc: str @@ -548,7 +599,9 @@ def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, :returns: str The fewshot context. """ - assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`" + assert ( + rnd is not None + ), "A `random.Random` generator argument must be provided to `rnd`" assert not provide_description, ( "The `provide_description` arg will be removed in future versions. 
To prepend " "a custom description to the context, supply the corresponding string via the " @@ -556,7 +609,9 @@ def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, ) if provide_description is not None: # nudge people to not specify it at all - print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") + print( + "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict" + ) description = description + "\n\n" if description else "" @@ -569,31 +624,229 @@ def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, else: if self._fewshot_docs is None: self._fewshot_docs = list( - self.validation_docs() if self.has_validation_docs() else self.test_docs() + self.validation_docs() + if self.has_validation_docs() + else self.test_docs() ) fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) - # get rid of the doc that's the one we're evaluating, if it's in the fewshot fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] - labeled_examples = "\n\n".join( - [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex] - ) + "\n\n" + labeled_examples = ( + "\n\n".join( + [ + self.doc_to_text(doc) + self.doc_to_target(doc) + for doc in fewshotex + ] + ) + + "\n\n" + ) example = self.doc_to_text(doc) return description + labeled_examples + example -class MultipleChoiceTask(Task): +class PromptSourceTask(Task): + """These are the metrics from promptsource that we have + added default behavior for. If you want to add default behavior for a new metric, + update the functions below. If you want to use one of the following metrics, + *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`. + """ + + CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"]) + + def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None): + super().__init__(data_dir, cache_dir, download_mode) + self.prompt = prompt + + def stopping_criteria(self) -> Optional[str]: + """Denote where the generation should end. + + For example, for coqa, this is '\nQ:' and for drop '.'. + + By default, its None, meaning to generate up to max or EOT, whichever comes first. + """ + return None + + def max_generation_length(self) -> Optional[int]: + """Denote where the max length of the generation if it is obvious from the task.""" + return None + + def invalid_doc_for_prompt(self, doc) -> bool: + """Some prompts may not work for some documents.""" + if ( + # generate_paraphrase for mrpc + # This generation prompt assumes a positive example. We filter out the negative examples. 
+ # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7 + # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88 + ( + self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5" + or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f" + ) + and doc["label"] == 0 + ): + return True + return False + + def doc_to_target(self, doc) -> str: + """NOTE: In the future, this may return Union[str, List[str]].""" + _, target = self.prompt.apply(doc) + return f" {target}" + + def doc_to_text(self, doc) -> str: + text, _ = self.prompt.apply(doc) + return text + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + _requests = [] + answer_choices_list = self.prompt.get_answer_choices_list(doc) + + if answer_choices_list: + # If answer_choices_list, then this is a ranked choice prompt. + for answer_choice in answer_choices_list: + ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}") + _requests.append(ll_answer_choice) + else: + # If not, then this is a generation prompt. + # NOTE: In the future, target will be a list of strings. + request_args = { + "stopping_criteria": self.stopping_criteria(), + "max_generation_length": self.max_generation_length(), + } + cont_request = rf.greedy_until(ctx, request_args) + _requests.append(cont_request) + + return _requests + + def process_results(self, doc, results): + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of + the metric for that one document + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param results: + The results of the requests created in construct_requests. + """ + target = self.doc_to_target(doc).strip() + answer_choices_list = self.prompt.get_answer_choices_list(doc) + if answer_choices_list: + # If answer_choices_list, then this is a ranked choice prompt. + # NOTE: In the future, target will be a list of strings. + # For now, we can assume there will be only 1 target, but its possible + # that this not the case so we should check for that. + + pred = answer_choices_list[np.argmax(results)] + out = {} + + for metric in self.prompt.metadata.metrics: + assert ( + metric in self.CONFIGURED_PS_METRICS + ), "Unexpected metric. Add it, or use a task-specific solution." + if metric == "Accuracy": + out["acc"] = pred == target + # TODO: Add metrics here. + return out + else: + # If not, then this is a generation prompt. + # NOTE: In the future, target will be a list of strings. + pred = results[0].strip() + out = {} + for metric in self.prompt.metadata.metrics: + assert ( + metric in self.CONFIGURED_PS_METRICS + ), "Unexpected metric. Add it, or use a task-specific solution." + if metric == "BLEU": + out["bleu"] = (target, pred) + if metric == "ROUGE": + # TODO: This computes all rouge sub-metrics. Find a generic + # way to handle user specified rouge sub-metrics to avoid extra + # compute. 
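For orientation, a hedged sketch of how a `PromptSourceTask` drives a promptsource template and opts into the new generation controls; the dataset, template name, and values below are illustrative rather than taken from the diff:

```python
from promptsource.templates import DatasetTemplates
from lm_eval.base import PromptSourceTask

# Illustrative lookup; "super_glue/rte" and the template name are placeholders.
prompts = DatasetTemplates("super_glue/rte")
prompt = prompts["some template name"]

doc = {"premise": "...", "hypothesis": "...", "label": 0}      # schematic document
text, target = prompt.apply(doc)               # backs doc_to_text / doc_to_target
choices = prompt.get_answer_choices_list(doc)  # non-empty -> ranked-choice prompt
                                               # empty     -> generation prompt


# A generation-style task would override the new hooks, e.g.:
class MyGenerationTask(PromptSourceTask):
    def stopping_criteria(self):
        return "\nQ:"  # illustrative stop sequence, as in the coqa note above

    def max_generation_length(self):
        return 64      # illustrative cap


# construct_requests() then packs these into the greedy_until request args:
# {"stopping_criteria": "\nQ:", "max_generation_length": 64}
```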
+ rouge_scores = metrics.rouge(target, pred) + # Flatten rouge score dict. + rouge_scores = utils.flatten(rouge_scores) + # Merge all the rouge-type scores into the `out` dict. + out = {**out, **rouge_scores} + print(out) + return out + + def higher_is_better(self): + out = {} + for metric in self.prompt.metadata.metrics: + assert ( + metric in self.CONFIGURED_PS_METRICS + ), "Unexpected metric. Add it, or use a task-specific solution." + if metric == "Accuracy": + out["acc"] = True + if metric == "BLEU": + out["bleu"] = True + if metric == "ROUGE": + # TODO: Find a generic way to handle user specified rouge metrics. + out["rouge1_precision"] = True + out["rouge1_recall"] = True + out["rouge1_fmeasure"] = True + + out["rouge2_precision"] = True + out["rouge2_recall"] = True + out["rouge2_fmeasure"] = True + + out["rougeL_precision"] = True + out["rougeL_recall"] = True + out["rougeL_fmeasure"] = True + + out["rougeLsum_precision"] = True + out["rougeLsum_recall"] = True + out["rougeLsum_fmeasure"] = True + return out + def aggregation(self): + out = {} + for metric in self.prompt.metadata.metrics: + assert ( + metric in self.CONFIGURED_PS_METRICS + ), "Unexpected metric. Add it, or use a task-specific solution." + if metric == "Accuracy": + out["acc"] = mean + if metric == "BLEU": + out["bleu"] = metrics.bleu + if metric == "ROUGE": + # TODO: Find a generic way to handle user specified rouge metrics. + out["rouge1_precision"] = mean + out["rouge1_recall"] = mean + out["rouge1_fmeasure"] = mean + + out["rouge2_precision"] = mean + out["rouge2_recall"] = mean + out["rouge2_fmeasure"] = mean + + out["rougeL_precision"] = mean + out["rougeL_recall"] = mean + out["rougeL_fmeasure"] = mean + + out["rougeLsum_precision"] = mean + out["rougeLsum_recall"] = mean + out["rougeLsum_fmeasure"] = mean + return out + + +class MultipleChoiceTask(Task): def doc_to_target(self, doc): - return " " + doc['choices'][doc['gold']] + return " " + doc["choices"][doc["gold"]] def construct_requests(self, doc, ctx): lls = [ - rf.loglikelihood(ctx, " {}".format(choice))[0] - for choice in doc['choices'] + rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"] ] return lls @@ -601,21 +854,21 @@ def construct_requests(self, doc, ctx): def process_results(self, doc, results): gold = doc["gold"] - acc = 1. if np.argmax(results) == gold else 0. + acc = 1.0 if np.argmax(results) == gold else 0.0 completion_len = np.array([float(len(i)) for i in doc["choices"]]) - acc_norm = 1. if np.argmax(results / completion_len) == gold else 0. + acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0 return { "acc": acc, "acc_norm": acc_norm, } - + def higher_is_better(self): return { "acc": True, "acc_norm": True, } - + def aggregation(self): return { "acc": mean, @@ -624,7 +877,6 @@ def aggregation(self): class PerplexityTask(Task, abc.ABC): - def has_training_docs(self): return False @@ -632,9 +884,15 @@ def fewshot_examples(self, k, rnd): assert k == 0 return [] - def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): - assert num_fewshot == 0, "The number of fewshot examples must be 0 for perplexity tasks." - assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`." + def fewshot_context( + self, doc, num_fewshot, provide_description=None, rnd=None, description=None + ): + assert ( + num_fewshot == 0 + ), "The number of fewshot examples must be 0 for perplexity tasks." 
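The ROUGE handling above enumerates every sub-metric by hand; a hedged sketch of the generic form the TODOs point toward (an illustrative refactor, not part of the diff):

```python
from lm_eval.metrics import mean

# Derive the flattened ROUGE sub-metric names instead of listing them one by one.
ROUGE_SUBMETRICS = [
    f"{rouge_type}_{part}"
    for rouge_type in ("rouge1", "rouge2", "rougeL", "rougeLsum")
    for part in ("precision", "recall", "fmeasure")
]

aggregations = {name: mean for name in ROUGE_SUBMETRICS}
higher_is_better = {name: True for name in ROUGE_SUBMETRICS}
```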
+ assert ( + rnd is not None + ), "A `random.Random` generator argument must be provided to `rnd`." assert not provide_description, ( "The `provide_description` arg will be removed in future versions. To prepend " "a custom description to the context, supply the corresponding string via the " @@ -642,7 +900,9 @@ def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, ) if provide_description is not None: # nudge people to not specify it at all - print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") + print( + "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict" + ) return "" @@ -665,7 +925,7 @@ def construct_requests(self, doc, ctx): return req def process_results(self, doc, results): - loglikelihood, = results + (loglikelihood,) = results words = self.count_words(doc) bytes_ = self.count_bytes(doc) return { @@ -687,23 +947,23 @@ def count_bytes(cls, doc): @classmethod def count_words(cls, doc): - """ Downstream tasks with custom word boundaries should override this! """ + """Downstream tasks with custom word boundaries should override this!""" return len(re.split(r"\s+", doc)) def hash_args(attr, args): dat = json.dumps([attr] + list(args)) - return hashlib.sha256(dat.encode('utf-8')).hexdigest() + return hashlib.sha256(dat.encode("utf-8")).hexdigest() class CacheHook: def __init__(self, cachinglm): - if cachinglm is None: + if cachinglm is None: self.dbdict = None return self.dbdict = cachinglm.dbdict - + def add_partial(self, attr, req, res): if self.dbdict is None: return @@ -733,7 +993,7 @@ def __getattr__(self, attr): def fn(requests): res = [] remaining_reqs = [] - + # figure out which ones are cached and which ones are new for req in requests: hsh = hash_args(attr, req) @@ -746,7 +1006,7 @@ def fn(requests): else: res.append(None) remaining_reqs.append(req) - + # actually run the LM on the requests that do not have cached results rem_res = getattr(self.lm, attr)(remaining_reqs) @@ -764,41 +1024,48 @@ def fn(requests): self.dbdict.commit() return res + return fn - + def get_cache_hook(self): return CacheHook(self) REQUEST_RETURN_LENGTHS = { - 'loglikelihood': 2, - 'greedy_until': None, - 'loglikelihood_rolling': None, + "loglikelihood": 2, + "greedy_until": None, + "loglikelihood_rolling": None, } class Request: def __init__(self, request_type, args, index=None): if request_type not in REQUEST_RETURN_LENGTHS.keys(): - raise NotImplementedError('The request type {} is not implemented!'.format(request_type)) + raise NotImplementedError( + "The request type {} is not implemented!".format(request_type) + ) self.request_type = request_type self.args = args self.index = index - + def __iter__(self): if REQUEST_RETURN_LENGTHS[self.request_type] is None: - raise IndexError('This request type does not return multiple arguments!') + raise IndexError("This request type does not return multiple arguments!") for i in range(REQUEST_RETURN_LENGTHS[self.request_type]): yield Request(self.request_type, self.args, i) - + def __getitem__(self, i): if REQUEST_RETURN_LENGTHS[self.request_type] is None: - raise IndexError('This request type does not return multiple arguments!') + raise IndexError("This request type does not return multiple arguments!") return Request(self.request_type, self.args, i) - + def __eq__(self, other): - return self.request_type == other.request_type and self.args == other.args and self.index == other.index + return ( + 
self.request_type == other.request_type + and self.args == other.args + and self.index == other.index + ) def __repr__(self): return f"Req_{self.request_type}{self.args}[{self.index}]\n" @@ -808,6 +1075,7 @@ class RequestFactory: def __getattr__(self, attr): def fn(*args): return Request(attr, args) + return fn diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 22a8aca26e..efeb5d2178 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -2,25 +2,38 @@ import itertools import pathlib import random + import lm_eval.metrics import lm_eval.models import lm_eval.tasks import lm_eval.base +import promptsource import numpy as np + +from promptsource.templates import DatasetTemplates from lm_eval.utils import positional_deprecated, run_task_tests @positional_deprecated -def simple_evaluate(model, model_args=None, tasks=[], - num_fewshot=0, batch_size=None, device=None, - no_cache=False, limit=None, bootstrap_iters=100000, - description_dict=None, check_integrity=False): +def simple_evaluate( + model, + model_args=None, + tasks=[], + num_fewshot=0, + batch_size=None, + device=None, + no_cache=False, + limit=None, + bootstrap_iters=100000, + description_dict=None, + check_integrity=False, +): """Instantiate and evaluate a model on a list of tasks. :param model: Union[str, LM] Name of model or LM object, see lm_eval.models.get_model :param model_args: Optional[str] - String arguments for each model class, see LM.create_from_arg_string. + String arguments for each model class, see LM.create_from_arg_string. Ignored if `model` argument is a LM object. :param tasks: list[Union[str, Task]] List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. @@ -37,7 +50,7 @@ def simple_evaluate(model, model_args=None, tasks=[], :param bootstrap_iters: Number of iterations for bootstrap statistics :param description_dict: dict[str, str] - Dictionary of custom task descriptions of the form: `task_name: description` + Dictionary of custom task descriptions of the form: `task_name: description` :param check_integrity: bool Whether to run the relevant part of the test suite for the tasks :return @@ -49,20 +62,28 @@ def simple_evaluate(model, model_args=None, tasks=[], assert tasks != [], "No tasks specified" if isinstance(model, str): - if model_args is None: model_args = "" - lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, { - 'batch_size': batch_size, 'device': device - }) + if model_args is None: + model_args = "" + lm = lm_eval.models.get_model(model).create_from_arg_string( + model_args, {"batch_size": batch_size, "device": device} + ) else: assert isinstance(model, lm_eval.base.LM) lm = model + # TODO: Hard-code turning off cache while testing. Remove once testing is completed. 
+ no_cache = True if not no_cache: lm = lm_eval.base.CachingLM( - lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db' + lm, + "lm_cache/" + + model + + "_" + + model_args.replace("=", "-").replace(",", "_").replace("/", "-") + + ".db", ) - - task_dict = lm_eval.tasks.get_task_dict(tasks) + + task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks) if check_integrity: run_task_tests(task_list=tasks) @@ -72,7 +93,7 @@ def simple_evaluate(model, model_args=None, tasks=[], task_dict=task_dict, num_fewshot=num_fewshot, limit=limit, - description_dict=description_dict + description_dict=description_dict, ) # add info about the model and few shot config @@ -85,14 +106,22 @@ def simple_evaluate(model, model_args=None, tasks=[], "no_cache": no_cache, "limit": limit, "bootstrap_iters": bootstrap_iters, - "description_dict": description_dict + "description_dict": description_dict, } return results @positional_deprecated -def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None): +def evaluate( + lm, + task_dict, + provide_description=None, + num_fewshot=0, + limit=None, + bootstrap_iters=100000, + description_dict=None, +): """Instantiate and evaluate a model on a list of tasks. :param lm: obj @@ -108,7 +137,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, :param bootstrap_iters: Number of iterations for bootstrap statistics :param description_dict: dict[str, str] - Dictionary of custom task descriptions of the form: `task_name: description` + Dictionary of custom task descriptions of the form: `task_name: description` :return Dictionary of results """ @@ -118,12 +147,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, assert not provide_description # not implemented. 
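A hedged usage sketch of the promptsource-aware entry point (model and task names are illustrative; note that this diff hard-codes `no_cache = True` while testing):

```python
from lm_eval import evaluator

# Illustrative arguments; any registered model / task names would do.
results = evaluator.simple_evaluate(
    model="gpt2",
    model_args="pretrained=gpt2",
    tasks=["rte"],  # expanded into one task instance per promptsource prompt
    num_fewshot=0,
    batch_size=1,
)
print(evaluator.make_table(results))  # the table now carries a "Prompt" column
```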
if provide_description is not None: # nudge people to not specify it at all - print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") + print( + "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict" + ) task_dict_items = [ (name, task) for name, task in task_dict.items() - if(task.has_validation_docs() or task.has_test_docs()) + if (task.has_validation_docs() or task.has_test_docs()) ] results = collections.defaultdict(dict) @@ -141,8 +172,12 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, docs = {} # get lists of each type of request - for task_name, task in task_dict_items: - versions[task_name] = task.VERSION + for task_prompt_name, task in task_dict_items: + # if task.is_generation_task(): + # print(f"WARNING: Skipping generation prompt {task.prompt.name}.") + # continue + + versions[task_prompt_name] = task.VERSION # default to test doc, fall back to val doc if validation unavailable # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point if task.has_test_docs(): @@ -158,15 +193,19 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, rnd.seed(42) rnd.shuffle(task_docs) - description = description_dict[task_name] if description_dict and task_name in description_dict else "" + description = ( + description_dict[task_prompt_name] + if description_dict and task_prompt_name in description_dict + else "" + ) for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)): - docs[(task_name, doc_id)] = doc + if task.invalid_doc_for_prompt(doc): + continue + + docs[(task_prompt_name, doc_id)] = doc ctx = task.fewshot_context( - doc=doc, - num_fewshot=num_fewshot, - rnd=rnd, - description=description + doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description ) reqs = task.construct_requests(doc, ctx) if not isinstance(reqs, (list, tuple)): @@ -175,7 +214,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, requests[req.request_type].append(req) # i: index in requests for a single task instance # doc_id: unique id that we can get back to a doc using `docs` - requests_origin[req.request_type].append((i, task_name, doc, doc_id)) + requests_origin[req.request_type].append( + (i, task_prompt_name, doc, doc_id) + ) # all responses for each (task, doc) process_res_queue = collections.defaultdict(list) @@ -189,43 +230,49 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, print("Running", reqtype, "requests") resps = getattr(lm, reqtype)([req.args for req in reqs]) - resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)] + resps = [ + x if req.index is None else x[req.index] for x, req in zip(resps, reqs) + ] + + for resp, (i, task_prompt_name, doc, doc_id) in zip( + resps, requests_origin[reqtype] + ): + process_res_queue[(task_prompt_name, doc_id)].append((i, resp)) - for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): - process_res_queue[(task_name, doc_id)].append((i, resp)) - vals = collections.defaultdict(list) # unpack results and sort back in order and return control to Task - for (task_name, doc_id), requests in process_res_queue.items(): + for (task_prompt_name, doc_id), requests in process_res_queue.items(): requests.sort(key=lambda x: x[0]) requests = [x[1] for x in requests] - task = task_dict[task_name] - doc = docs[(task_name, doc_id)] 
+ task = task_dict[task_prompt_name] + doc = docs[(task_prompt_name, doc_id)] metrics = task.process_results(doc, requests) for metric, value in metrics.items(): - vals[(task_name, metric)].append(value) - + vals[(task_prompt_name, metric)].append(value) + # aggregate results - for (task_name, metric), items in vals.items(): - task = task_dict[task_name] - results[task_name][metric] = task.aggregation()[metric](items) + for (task_prompt_name, metric), items in vals.items(): + task_name, prompt_name = task_prompt_name.split("+") + results[task_prompt_name]["task_name"] = task_name + results[task_prompt_name]["prompt_name"] = prompt_name + task = task_dict[task_prompt_name] + results[task_prompt_name][metric] = task.aggregation()[metric](items) # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # so we run them less iterations. still looking for a cleaner way to do this stderr = lm_eval.metrics.stderr_for_metric( metric=task.aggregation()[metric], - bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters, + bootstrap_iters=min(bootstrap_iters, 1000) + if metric in ["bleu", "chrf", "ter"] + else bootstrap_iters, ) if stderr is not None: - results[task_name][metric + "_stderr"] = stderr(items) - - return { - "results": dict(results), - "versions": dict(versions) - } + results[task_prompt_name][metric + "_stderr"] = stderr(items) + + return {"results": dict(results), "versions": dict(versions)} def make_table(result_dict): @@ -234,22 +281,50 @@ def make_table(result_dict): md_writer = MarkdownTableWriter() latex_writer = LatexTableWriter() - md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] - latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] + md_writer.headers = ["Task", "Prompt", "Version", "Metric", "Value", "", "Stderr"] + latex_writer.headers = [ + "Task", + "Prompt", + "Version", + "Metric", + "Value", + "", + "Stderr", + ] values = [] - for k, dic in result_dict["results"].items(): version = result_dict["versions"][k] for m, v in dic.items(): if m.endswith("_stderr"): continue - + if "_name" in m: + continue if m + "_stderr" in dic: se = dic[m + "_stderr"] - values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se]) + values.append( + [ + dic["task_name"], + dic["prompt_name"], + version, + m, + "%.4f" % v, + "±", + "%.4f" % se, + ] + ) else: - values.append([k, version, m, '%.4f' % v, '', '']) + values.append( + [ + dic["task_name"], + dic["prompt_name"], + version, + m, + "%.4f" % v, + "", + "", + ] + ) k = "" version = "" md_writer.value_matrix = values diff --git a/lm_eval/metrics.py b/lm_eval/metrics.py index 05fad59ff3..ba91e0c2ee 100644 --- a/lm_eval/metrics.py +++ b/lm_eval/metrics.py @@ -1,8 +1,10 @@ +import typing import math from collections.abc import Iterable import numpy as np import sacrebleu +from rouge_score import rouge_scorer import sklearn.metrics import random @@ -184,6 +186,74 @@ def _sacreformat(refs, preds): return refs, preds + +def rouge( + refs: typing.List[str], + pred: str, + rouge_types: typing.List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"] +): + """ ROUGE with multi-reference support + + Implementation based on GEM-metrics: + https://github.com/GEM-benchmark/GEM-metrics/blob/431a8174bd6b3637e8d6118bfad2983e39e99733/gem_metrics/rouge.py + + :param refs: + A `list` of reference `str`s. + :param pred: + A single prediction `str`s. + """ + + # Add newlines between sentences to correctly compute `rougeLsum`. 
+ if "rougeLsum" in rouge_types: + # TODO: Adapt this to handle languages that do not support sentence endings by `.`. + # See GEM-metrics implementation with lang specific `nltk` tokenizers to + # split sentences. + pred = pred.replace(".", ".\n") + refs = [ref.replace(".", ".\n") for ref in refs] + + scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True) + # ROUGE multi-ref jackknifing + if len(refs) > 1: + cur_scores = [scorer.score(ref, pred) for ref in refs] + + # get best score for all leave-one-out sets + best_scores = [] + for leave in range(len(refs)): + cur_scores_leave_one = [ + cur_scores[s] for s in range(len(refs)) if s != leave + ] + best_scores.append( + { + rouge_type: max( + [s[rouge_type] for s in cur_scores_leave_one], + key=lambda s: s.fmeasure, + ) + for rouge_type in rouge_types + } + ) + # average the leave-one-out bests to produce the final score + score = { + rouge_type: rouge_scorer.scoring.Score( + np.mean([b[rouge_type].precision for b in best_scores]), + np.mean([b[rouge_type].recall for b in best_scores]), + np.mean([b[rouge_type].fmeasure for b in best_scores]), + ) + for rouge_type in rouge_types + } + else: + score = scorer.score(refs[0], pred) + # convert the named tuples to plain nested dicts + score = { + rouge_type: { + "precision": score[rouge_type].precision, + "recall": score[rouge_type].recall, + "fmeasure": score[rouge_type].fmeasure, + } + for rouge_type in rouge_types + } + return score + + # stderr stuff class _bootstrap_internal: diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py index a12f68a513..6b31a9e633 100644 --- a/lm_eval/models/__init__.py +++ b/lm_eval/models/__init__.py @@ -1,11 +1,18 @@ from . import gpt2 +from . import gptj from . import gpt3 +from . import t5 +from . import t0 from . 
import dummy MODEL_REGISTRY = { "hf": gpt2.HFLM, "gpt2": gpt2.GPT2LM, + "gptj": gptj.GPTJLM, "gpt3": gpt3.GPT3LM, + "t5": t5.T5LM, + "mt5": t5.T5LM, + "t0": t0.T0LM, "dummy": dummy.DummyLM, } diff --git a/lm_eval/models/gpt2.py b/lm_eval/models/gpt2.py index a2214d39b1..2e73adf3a7 100644 --- a/lm_eval/models/gpt2.py +++ b/lm_eval/models/gpt2.py @@ -4,8 +4,15 @@ class HFLM(BaseLM): - - def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1): + def __init__( + self, + device="cuda", + pretrained="gpt2", + revision="main", + subfolder=None, + tokenizer=None, + batch_size=1, + ): super().__init__() assert isinstance(device, str) @@ -15,28 +22,47 @@ def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder= if device: self._device = torch.device(device) else: - self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) # TODO: update this to be less of a hack once subfolder is fixed in HF self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained( - pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "") + pretrained, + revision=revision + ("/" + subfolder if subfolder is not None else ""), ).to(self.device) self.gpt2.eval() # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2 self.tokenizer = transformers.AutoTokenizer.from_pretrained( - pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder) + pretrained if tokenizer is None else tokenizer, + revision=revision, + subfolder=subfolder, + ) - assert isinstance(self.tokenizer, ( - transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast, - transformers.T5Tokenizer, transformers.T5TokenizerFast, - )), "this tokenizer has not been checked for compatibility yet!" + assert isinstance( + self.tokenizer, + ( + transformers.GPT2Tokenizer, + transformers.GPT2TokenizerFast, + transformers.T5Tokenizer, + transformers.T5TokenizerFast, + ), + ), "this tokenizer has not been checked for compatibility yet!" 
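The `_model_generate` rework later in this file stops on a multi-token sequence via `transformers.StoppingCriteria`; a self-contained sketch of the same idea (an editor's illustration, not the diff's exact classes):

```python
import torch
import transformers


class StopOnSequence(transformers.StoppingCriteria):
    """Stop generation once the decoded tail of the output contains `stop_str`."""

    def __init__(self, stop_str: str, tokenizer):
        self.stop_str = stop_str
        self.tokenizer = tokenizer
        # Decode one extra token so the stop string is visible even if it is
        # tokenized differently inside the generated text.
        self.window = len(tokenizer.encode(stop_str)) + 1

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        tail = self.tokenizer.decode(input_ids[0, -self.window:])
        return self.stop_str in tail


tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
out = model.generate(
    inputs.input_ids,
    max_length=inputs.input_ids.shape[1] + 32,
    stopping_criteria=transformers.StoppingCriteriaList(
        [StopOnSequence("\nQ:", tokenizer)]
    ),
    do_sample=False,
)
print(tokenizer.decode(out[0][inputs.input_ids.shape[1]:]))
```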
self.vocab_size = self.tokenizer.vocab_size - if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)): - assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \ - self.tokenizer.encode('hello\n\nhello') + if isinstance( + self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast) + ): + assert self.tokenizer.encode("hello\n\nhello") == [ + 31373, + 198, + 198, + 31373, + ], self.tokenizer.encode("hello\n\nhello") # multithreading and batching self.batch_size_per_gpu = batch_size # todo: adaptive batch size @@ -75,7 +101,7 @@ def device(self): def tok_encode(self, string: str): return self.tokenizer.encode(string, add_special_tokens=False) - + def tok_decode(self, tokens): return self.tokenizer.decode(tokens) @@ -89,14 +115,42 @@ def _model_call(self, inps): """ with torch.no_grad(): return self.gpt2(inps)[0][:, :, :50257] - - def _model_generate(self, context, max_length, eos_token_id): + + def _get_stopping_criteria(self, stopping_criteria_ids): + class MultitokenEOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_seq_id: torch.LongTensor, tokenizer): + self.eos_seq = tokenizer.decode(eos_seq_id) + self.eos_seq_id = eos_seq_id + self.eos_seq_len = len(eos_seq_id) + 1 + self.tokenizer = tokenizer + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + last_token_id = input_ids[0, -self.eos_seq_len:] + last_tokens = self.tokenizer.decode(last_token_id) + is_stopped = self.eos_seq in last_tokens + return is_stopped + + class EOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_token_id: torch.LongTensor): + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids[0,-1] == self.eos_token_id + + return transformers.StoppingCriteriaList([ + MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer), + EOSCriteria(self.tokenizer.eos_token) + ]) + + def _model_generate(self, context, max_length, stopping_criteria_ids): + stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) return self.gpt2.generate( - context, - max_length=max_length, - eos_token_id=eos_token_id, - do_sample=False + context, + max_length=max_length, + stopping_criteria=stopping_criteria, + do_sample=False, ) + # for backwards compatibility diff --git a/lm_eval/models/gptj.py b/lm_eval/models/gptj.py new file mode 100644 index 0000000000..398ae03053 --- /dev/null +++ b/lm_eval/models/gptj.py @@ -0,0 +1,119 @@ +import transformers +import torch +from lm_eval.base import BaseLM + + +class GPTJLM(BaseLM): + def __init__( + self, + device="cuda", + batch_size=1, + ): + super().__init__() + + assert isinstance(device, str) + assert isinstance(batch_size, int) + + if device: + self._device = torch.device(device) + else: + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + + pretrained = "EleutherAI/gpt-j-6B" + self.gptj = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device) + self.gptj.eval() + + # pretrained tokenizer for neo is broken for now so just hard-coding this to gptj + self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained) + self.vocab_size = self.tokenizer.vocab_size + + # multithreading and batching + self.batch_size_per_gpu = batch_size # todo: adaptive batch size + + # TODO: fix multi-gpu + # gpus = torch.cuda.device_count() + # if gpus > 1: + # self.gptj = 
nn.DataParallel(self.gptj) + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + try: + return self.gptj.config.n_ctx + except AttributeError: + # gptneoconfig doesn't have n_ctx apparently + return self.gptj.config.max_position_embeddings + + @property + def max_gen_toks(self): + return 256 + + @property + def batch_size(self): + # TODO: fix multi-gpu + return self.batch_size_per_gpu # * gpus + + @property + def device(self): + # TODO: fix multi-gpu + return self._device + + def tok_encode(self, string: str): + return self.tokenizer.encode(string, add_special_tokens=False) + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def _model_call(self, inps): + """ + inps: a torch tensor of shape [batch, sequence] + the size of sequence may vary from call to call + + returns: a torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model + """ + with torch.no_grad(): + return self.gptj(inps)[0][:, :, :50257] + + def _get_stopping_criteria(self, stopping_criteria_ids): + class MultitokenEOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_seq_id: torch.LongTensor, tokenizer): + self.eos_seq = tokenizer.decode(eos_seq_id) + self.eos_seq_id = eos_seq_id + self.eos_seq_len = len(eos_seq_id) + 1 + self.tokenizer = tokenizer + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + last_token_id = input_ids[0, -self.eos_seq_len:] + last_tokens = self.tokenizer.decode(last_token_id) + is_stopped = self.eos_seq in last_tokens + return is_stopped + + class EOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_token_id: torch.LongTensor): + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids[0,-1] == self.eos_token_id + + return transformers.StoppingCriteriaList([ + MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer), + EOSCriteria(self.tokenizer.eos_token) + ]) + + def _model_generate(self, context, max_length, stopping_criteria_ids): + stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) + return self.gptj.generate( + context, + max_length=max_length, + stopping_criteria=stopping_criteria, + do_sample=False, + ) diff --git a/lm_eval/models/t0.py b/lm_eval/models/t0.py new file mode 100644 index 0000000000..db7ba582a7 --- /dev/null +++ b/lm_eval/models/t0.py @@ -0,0 +1,161 @@ +import transformers +import torch +import torch.nn as nn +import torch.nn.functional as F +from lm_eval.base import LM +from lm_eval import utils +from tqdm import tqdm +import numpy as np +import math + +class T0LM(LM): + MAX_GEN_TOKS = 256 + MAX_INP_LENGTH = 512 + VOCAB_SIZE = 32100 + EOT_TOKEN_ID = 1 + + def __init__(self, device='cuda', parallelize=False, pretrained='t0', batch_size=1): + super().__init__() + if device: + self.device = torch.device(device) + else: + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + print(pretrained) + self.t0 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained) + self.t0.eval() + + if parallelize == "True": + print(parallelize) + self.t0.parallelize() + self.device = torch.device('cuda:0') + else: + self.t0.to(self.device) + + self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained) + self.max_length = self.MAX_INP_LENGTH + 
+ self.batch_size = int(batch_size) + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config={}): + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + def loglikelihood(self, requests): + res = [] + for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)): + + inputs, targets = zip(*chunk) + + inputs_tok = self.tokenizer( + list(inputs), + max_length=self.max_length, + padding=True, + # truncation=True, + add_special_tokens=False, + return_tensors="pt" + ).to(self.device) + + for key in inputs_tok: + inputs_tok[key] = inputs_tok[key][:, -(self.max_length - 1) :] + + targets_tok = self.tokenizer( + list(targets), + max_length=self.MAX_GEN_TOKS, + padding=True, + # truncation=True, + add_special_tokens=False, + return_tensors="pt" + ).to(self.device) + + for key in targets_tok: + targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :] + + with torch.no_grad(): + outputs = self.t0( + **inputs_tok, + labels=targets_tok["input_ids"] + ) + + log_softmaxes = F.log_softmax(outputs.logits, dim=-1) + + output_iterator = zip( + chunk, + log_softmaxes, + targets_tok["input_ids"], + targets_tok["attention_mask"], + ) + for cache_key, log_softmax, target_tok, target_mask in output_iterator: + length = target_mask.sum() + log_softmax = log_softmax[:length] + target_tok = target_tok[:length] + greedy_tokens = log_softmax.argmax(dim=-1) + max_equal = (greedy_tokens == target_tok).all() + target_logits = torch.gather( + log_softmax, 1, target_tok.unsqueeze(-1) + ).squeeze(-1) + answer = (float(target_logits.sum()), bool(max_equal)) + + if cache_key is not None: + self.cache_hook.add_partial("loglikelihood", cache_key, answer) + + res.append(answer) + + return res + + def loglikelihood_rolling(self, requests): + raise NotImplementedError + + def _get_stopping_criteria(self, stopping_criteria_ids): + class MultitokenEOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_seq_id: torch.LongTensor, tokenizer): + self.eos_seq = tokenizer.decode(eos_seq_id) + self.eos_seq_id = eos_seq_id + self.eos_seq_len = len(eos_seq_id) + 1 + self.tokenizer = tokenizer + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + last_token_id = input_ids[0, -self.eos_seq_len:] + last_tokens = self.tokenizer.decode(last_token_id) + is_stopped = self.eos_seq in last_tokens + return is_stopped + + class EOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_token_id: torch.LongTensor): + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids[0,-1] == self.eos_token_id + + return transformers.StoppingCriteriaList([ + MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer), + EOSCriteria(self.tokenizer.eos_token) + ]) + + def greedy_until(self, requests): + + res = [] + + for context, until in tqdm(requests): + if isinstance(until, str): until = [until] + + context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids + + stopping_criteria_ids = self.tokenizer.encode(until[0]) + stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) + + cont = self.t0.generate( + context_enc, + max_length=self.MAX_GEN_TOKS, + stopping_criteria=stopping_criteria, + do_sample=False + ) + + s = self.tokenizer.decode(cont[0].tolist()) + + 
self.cache_hook.add_partial("greedy_until", (context, until), s) + + res.append(s) + + return res \ No newline at end of file diff --git a/lm_eval/models/t5.py b/lm_eval/models/t5.py new file mode 100644 index 0000000000..e6d99f870f --- /dev/null +++ b/lm_eval/models/t5.py @@ -0,0 +1,161 @@ +import transformers +import torch +import torch.nn as nn +import torch.nn.functional as F +from lm_eval.base import LM +from lm_eval import utils +from tqdm import tqdm +import numpy as np +import math + +class T5LM(LM): + MAX_GEN_TOKS = 256 + MAX_INP_LENGTH = 512 + VOCAB_SIZE = 32128 + EOT_TOKEN_ID = 1 + + def __init__(self, device='cuda', parallelize=False, pretrained='t5', batch_size=1): + super().__init__() + if device: + self.device = torch.device(device) + else: + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + print(pretrained) + self.t5 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained) + self.t5.eval() + + if parallelize == "True": + print(parallelize) + self.t5.parallelize() + self.device = torch.device('cuda:0') + else: + self.t5.to(self.device) + + self.tokenizer = transformers.T5TokenizerFast.from_pretrained(pretrained) + self.max_length = self.MAX_INP_LENGTH + + self.batch_size = int(batch_size) + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config={}): + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + def loglikelihood(self, requests): + res = [] + for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)): + + inputs, targets = zip(*chunk) + + inputs_tok = self.tokenizer( + list(inputs), + max_length=self.max_length, + padding=True, + # truncation=True, + add_special_tokens=False, + return_tensors="pt" + ).to(self.device) + + for key in inputs_tok: + inputs_tok[key] = inputs_tok[key][:, -(self.max_length - 1) :] + + targets_tok = self.tokenizer( + list(targets), + max_length=self.MAX_GEN_TOKS, + padding=True, + # truncation=True, + add_special_tokens=False, + return_tensors="pt" + ).to(self.device) + + for key in targets_tok: + targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :] + + with torch.no_grad(): + outputs = self.t5( + **inputs_tok, + labels=targets_tok["input_ids"] + ) + + log_softmaxes = F.log_softmax(outputs.logits, dim=-1) + + output_iterator = zip( + chunk, + log_softmaxes, + targets_tok["input_ids"], + targets_tok["attention_mask"], + ) + for cache_key, log_softmax, target_tok, target_mask in output_iterator: + length = target_mask.sum() + log_softmax = log_softmax[:length] + target_tok = target_tok[:length] + greedy_tokens = log_softmax.argmax(dim=-1) + max_equal = (greedy_tokens == target_tok).all() + target_logits = torch.gather( + log_softmax, 1, target_tok.unsqueeze(-1) + ).squeeze(-1) + answer = (float(target_logits.sum()), bool(max_equal)) + + if cache_key is not None: + self.cache_hook.add_partial("loglikelihood", cache_key, answer) + + res.append(answer) + + return res + + def loglikelihood_rolling(self, requests): + raise NotImplementedError + + def _get_stopping_criteria(self, stopping_criteria_ids): + class MultitokenEOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_seq_id: torch.LongTensor, tokenizer): + self.eos_seq = tokenizer.decode(eos_seq_id) + self.eos_seq_id = eos_seq_id + self.eos_seq_len = len(eos_seq_id) + 1 + self.tokenizer = tokenizer + + def __call__(self, input_ids: 
torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + last_token_id = input_ids[0, -self.eos_seq_len:] + last_tokens = self.tokenizer.decode(last_token_id) + is_stopped = self.eos_seq in last_tokens + return is_stopped + + class EOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_token_id: torch.LongTensor): + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids[0,-1] == self.eos_token_id + + return transformers.StoppingCriteriaList([ + MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer), + EOSCriteria(self.tokenizer.eos_token) + ]) + + def greedy_until(self, requests): + + res = [] + + for context, until in tqdm(requests): + if isinstance(until, str): until = [until] + + context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids + + stopping_criteria_ids = self.tokenizer.encode(until[0]) + stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) + + cont = self.t5.generate( + context_enc, + max_length=self.MAX_GEN_TOKS, + stopping_criteria=stopping_criteria, + do_sample=False + ) + + s = self.tokenizer.decode(cont[0].tolist()) + + self.cache_hook.add_partial("greedy_until", (context, until), s) + + res.append(s) + + return res \ No newline at end of file diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index 4e6a8b87fa..87c2a97af4 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -1,3 +1,4 @@ +from promptsource.templates import DatasetTemplates from pprint import pprint from typing import List, Union @@ -51,6 +52,9 @@ from . import asdiv from . import gsm8k from . import storycloze +from . import hans + +# from . import e2e_nlg_cleaned ######################################## # Translation tasks @@ -58,8 +62,8 @@ # 6 total gpt3_translation_benchmarks = { - "wmt14": ['en-fr', 'fr-en'], # French - "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian + "wmt14": ["en-fr", "fr-en"], # French + "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian } @@ -67,7 +71,7 @@ selected_translation_benchmarks = { **gpt3_translation_benchmarks, "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"), - "iwslt17": ['en-ar', 'ar-en'] # Arabic + "iwslt17": ["en-ar", "ar-en"], # Arabic } # 319 total @@ -91,7 +95,7 @@ "rte": glue.RTE, "qnli": glue.QNLI, "qqp": glue.QQP, - #"stsb": glue.STSB, # not implemented yet + # "stsb": glue.STSB, # not implemented yet "sst": glue.SST, "wnli": glue.WNLI, # SuperGLUE @@ -102,34 +106,27 @@ "record": superglue.ReCoRD, "wic": superglue.WordsInContext, "wsc": superglue.SGWinogradSchemaChallenge, - # Order by benchmark/genre? 
"coqa": coqa.CoQA, "drop": drop.DROP, "lambada": lambada.LAMBADA, "lambada_cloze": lambada_cloze.LAMBADA_cloze, - # multilingual lambada **lambada_multilingual.construct_tasks(), - "wikitext": wikitext.WikiText, # "cbt-cn": cbt.CBTCN, # disabled pending context length fix # "cbt-ne": cbt.CBTNE, # disabled pending context length fix - "piqa": piqa.PiQA, "prost": prost.PROST, "mc_taco": mc_taco.MCTACO, - # Science related - "pubmedqa" : pubmedqa.Pubmed_QA, - "sciq" : sciq.SciQ, - + "pubmedqa": pubmedqa.Pubmed_QA, + "sciq": sciq.SciQ, + # "e2e_nlg_cleaned": e2e_nlg_cleaned.E2E_NLG_Cleaned, "qasper": qasper.QASPER, - - "qa4mre_2011" : qa4mre.QA4MRE_2011, - "qa4mre_2012" : qa4mre.QA4MRE_2012, - "qa4mre_2013" : qa4mre.QA4MRE_2013, - + "qa4mre_2011": qa4mre.QA4MRE_2011, + "qa4mre_2012": qa4mre.QA4MRE_2012, + "qa4mre_2013": qa4mre.QA4MRE_2013, "triviaqa": triviaqa.TriviaQA, "arc_easy": arc.ARCEasy, "arc_challenge": arc.ARCChallenge, @@ -140,7 +137,7 @@ "squad2": squad.SQuAD2, "race": race.RACE, # "naturalqs": naturalqs.NaturalQs, # not implemented yet - "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es + "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es "headqa_es": headqa.HeadQAEs, "headqa_en": headqa.HeadQAEn, "mathqa": mathqa.MathQA, @@ -150,21 +147,18 @@ "anli_r1": anli.ANLIRound1, "anli_r2": anli.ANLIRound2, "anli_r3": anli.ANLIRound3, - + "hans": hans.HANS, "ethics_cm": hendrycks_ethics.EthicsCM, "ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_justice": hendrycks_ethics.EthicsJustice, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_virtue": hendrycks_ethics.EthicsVirtue, - - "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, - "truthfulqa_gen": truthfulqa.TruthfulQAGeneration, - + "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, + "truthfulqa_gen": truthfulqa.TruthfulQAGeneration, # dialogue "mutual": mutual.MuTual, "mutual_plus": mutual.MuTualPlus, - # math "math_algebra": hendrycks_math.MathAlgebra, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability, @@ -175,7 +169,6 @@ "math_precalc": hendrycks_math.MathPrecalculus, "math_asdiv": asdiv.Asdiv, "gsm8k": gsm8k.GradeSchoolMath8K, - # arithmetic "arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2ds": arithmetic.Arithmetic2DMinus, @@ -189,22 +182,18 @@ "arithmetic_1dc": arithmetic.Arithmetic1DComposite, # TODO Perhaps make these groups of tasks # e.g. anli, arithmetic, openai_translations, harness_translations - # hendrycksTest (57 tasks) **hendrycks_test.create_all_tasks(), - # e.g. 
wmt14-fr-en **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks), # chef's selection, mostly wmt20 **translation.create_tasks_from_benchmarks(selected_translation_benchmarks), - # Word Scrambling and Manipulation Tasks "anagrams1": unscramble.Anagrams1, "anagrams2": unscramble.Anagrams2, "cycle_letters": unscramble.CycleLetters, "random_insertion": unscramble.RandomInsertion, "reversed_words": unscramble.ReversedWords, - # Pile "pile_arxiv": pile.PileArxiv, "pile_books3": pile.PileBooks3, @@ -228,7 +217,6 @@ "pile_ubuntu-irc": pile.PileUbuntuIrc, "pile_wikipedia": pile.PileWikipedia, "pile_youtubesubtitles": pile.PileYoutubeSubtitles, - # BLiMP "blimp_adjunct_island": blimp.BlimpAdjunctIsland, "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement, @@ -297,7 +285,6 @@ "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, - # Requires manual download of data. # "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2018": storycloze.StoryCloze2018, @@ -321,19 +308,51 @@ def get_task_name_from_object(task_object): for name, class_ in TASK_REGISTRY.items(): if class_ is task_object: return name - + # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting - return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__ + return ( + task_object.EVAL_HARNESS_NAME + if hasattr(task_object, "EVAL_HARNESS_NAME") + else type(task_object).__name__ + ) def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): task_name_dict = { task_name: get_task(task_name)() - for task_name in task_name_list if isinstance(task_name, str) + for task_name in task_name_list + if isinstance(task_name, str) } task_name_from_object_dict = { get_task_name_from_object(task_object): task_object - for task_object in task_name_list if not isinstance(task_object, str) + for task_object in task_name_list + if not isinstance(task_object, str) } assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) return {**task_name_dict, **task_name_from_object_dict} + + +def get_task_dict_promptsource(task_name_list: List[str]): + """Loads a task instance for each prompt written for that task.""" + task_name_dict = {} + + for task_name in task_name_list: + assert isinstance(task_name, str) + + # Static version of the Task Use this to get HF dataset path / name. + static_task_obj = get_task(task_name) + # Create the proper task name arg for DatasetTemplates. + sub_task = ( + f"/{static_task_obj.DATASET_NAME}" if static_task_obj.DATASET_NAME else "" + ) + ps_task_name = f"{static_task_obj.DATASET_PATH}{sub_task}" + + task_prompts = DatasetTemplates(ps_task_name) + for prompt_name in task_prompts.all_template_names: + prompt = task_prompts[prompt_name] + # NOTE: We choose a sep that can be easily split. 
+ task_name_dict[f"{task_name}+{prompt_name}"] = get_task(task_name)( + prompt=prompt + ) + + return task_name_dict diff --git a/lm_eval/tasks/anli.py b/lm_eval/tasks/anli.py index 2f7a763b92..59eee05f9c 100644 --- a/lm_eval/tasks/anli.py +++ b/lm_eval/tasks/anli.py @@ -10,7 +10,7 @@ Homepage: "https://github.com/facebookresearch/anli" """ import numpy as np -from lm_eval.base import rf, Task +from lm_eval.base import rf, PromptSourceTask from lm_eval.metrics import mean @@ -30,7 +30,7 @@ """ -class ANLIBase(Task): +class ANLIBase(PromptSourceTask): VERSION = 0 DATASET_PATH = "anli" DATASET_NAME = None @@ -59,51 +59,6 @@ def test_docs(self): if self.has_test_docs(): return self.dataset["test_r" + str(self.SPLIT)] - def doc_to_text(self, doc): - # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning - # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly - # appended onto the question, with no "Answer:" or even a newline. Do we *really* - # want to do it exactly as OA did? - return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:' - - def doc_to_target(self, doc): - # True = entailment - # False = contradiction - # Neither = neutral - return " " + ["True", "Neither", "False"][doc['label']] - - def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - ll_true, _ = rf.loglikelihood(ctx, " True") - ll_neither, _ = rf.loglikelihood(ctx, " Neither") - ll_false, _ = rf.loglikelihood(ctx, " False") - return ll_true, ll_neither, ll_false - - def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of - the metric for that one document - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param results: - The results of the requests created in construct_requests. 
-        """
-        gold = doc["label"]
-        pred = np.argmax(results)
-        return {
-            "acc": pred == gold
-        }
-
     def aggregation(self):
         """
        :returns: {str: [float] -> float}
diff --git a/lm_eval/tasks/arithmetic.py b/lm_eval/tasks/arithmetic.py
index 8f783feb23..21da60a79e 100644
--- a/lm_eval/tasks/arithmetic.py
+++ b/lm_eval/tasks/arithmetic.py
@@ -58,10 +58,11 @@ def doc_to_target(self, doc):
 
     def construct_requests(self, doc, ctx):
         ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
-        return is_prediction
+        return ll, is_prediction
 
     def process_results(self, doc, results):
-        is_prediction, = results
+        # construct_requests now returns both requests, so unpack (loglikelihood, is_greedy).
+        ll, is_prediction = results
         return {
             "acc": is_prediction
         }
diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py
index f6c9983384..0fbd23112e 100644
--- a/lm_eval/tasks/coqa.py
+++ b/lm_eval/tasks/coqa.py
@@ -12,7 +12,7 @@
 import inspect
 import transformers.data.metrics.squad_metrics as squad_metrics
 import lm_eval.datasets.coqa.coqa
-from lm_eval.base import Task, rf, mean
+from lm_eval.base import PromptSourceTask, Task, rf, mean
 from itertools import zip_longest
 
 
@@ -28,9 +28,9 @@
 """
 
 
-class CoQA(Task):
+class CoQA(PromptSourceTask):
     VERSION = 1
-    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
+    DATASET_PATH = "coqa"
     DATASET_NAME = None
 
     def has_training_docs(self):
@@ -51,44 +51,21 @@ def validation_docs(self):
     def test_docs(self):
         pass
 
-    def doc_to_text(self, doc):
-        # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
-        # and a question qi, the task is to predict the answer ai
-        doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):  # omit target answer ai
-            question = f"Q: {q}\n\n"
-            answer = f"A: {a}\n\n" if a is not None else "A:"
-            doc_text += question + answer
-        return doc_text
-
-    @classmethod
-    def get_answers(cls, doc, turn_id):
-        # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
-        answers = []
-        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
-        answers.append(answer_forturn)
-
-        additional_answers = doc.get("additional_answers")
-        if additional_answers:
-            for key in additional_answers:
-                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
-                if additional_answer_for_turn.lower() not in map(str.lower, answers):
-                    answers.append(additional_answer_for_turn)
-        return answers
-
-    @classmethod
-    def get_answer_choice(self, raw_text):
-        # Function maps answers to CoQA answer categories
-        # ~ 1/5 of the CoQA answers are Yes/No
-        # ~ 2/3 of the CoQA answers are span-based
-        # (answers overlap with the passage ignoring punctuation and case mismatch)
-        if raw_text == "unknown":
-            return '0'
-        if squad_metrics.normalize_answer(raw_text) == "yes":
-            return '1'
-        if squad_metrics.normalize_answer(raw_text) == "no":
-            return '2'
-        return '3'  # Not a yes/no question
+    # @classmethod
+    # def get_answers(cls, doc, turn_id):
+    #     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
+ # answers = [] + # answer_forturn = doc["answers"]["input_text"][turn_id - 1] + # answers.append(answer_forturn) + # additional_answers = doc.get("additional_answers") + # if additional_answers: + # for key in additional_answers: + # additional_answer_for_turn = additional_answers[key]["input_text"][ + # turn_id - 1 + # ] + # if additional_answer_for_turn.lower() not in map(str.lower, answers): + # answers.append(additional_answer_for_turn) + # return answers @staticmethod def compute_scores(gold_list, pred): @@ -98,40 +75,40 @@ def compute_scores(gold_list, pred): em_sum = 0.0 if len(gold_list) > 1: for i in range(len(gold_list)): - gold_answers = gold_list[0:i] + gold_list[i + 1:] + gold_answers = gold_list[0:i] + gold_list[i + 1 :] # predictions compared against (n) golds and take maximum - em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers) + em_sum += max( + squad_metrics.compute_exact(a, pred) for a in gold_answers + ) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers) else: em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list) - return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))} + return { + "em": em_sum / max(1, len(gold_list)), + "f1": f1_sum / max(1, len(gold_list)), + } - def doc_to_target(self, doc, turnid=None): - # Default to prediction of last turn. - if turnid is None: - turnid = len(doc["questions"]["input_text"]) - raw_text = doc['answers']["input_text"][turnid - 1] - return " " + raw_text + def stopping_criteria(self): + return "\n\n" - def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. + # def construct_requests(self, doc, ctx): + # """Uses RequestFactory to construct Requests and returns an iterable of + # Requests which will be sent to the LM. - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - cont_request = rf.greedy_until(ctx, ['\nQ:']) - return cont_request + # :param doc: + # The document as returned from training_docs, validation_docs, or test_docs. + # :param ctx: str + # The context string, generated by fewshot_context. This includes the natural + # language description, as well as the few shot examples, and the question + # part of the document for `doc`. + # """ + # return cont_request def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of the metric for that one document :param doc: @@ -139,15 +116,25 @@ def process_results(self, doc, results): :param results: The results of the requests created in construct_requests. 
""" - turn_id = len(doc["questions"]["input_text"]) - gold_list = self.get_answers(doc, turn_id) - pred = results[0].strip().split('\n')[0] - - scores = self.compute_scores(gold_list, pred) + target = self.doc_to_target(doc).strip() + pred = results[0].strip().split("\n")[0] + print("*" * 80) + print(f"DOC: {doc}") + # print(f"PS: {self.prompt.apply(doc)}") + print(f"TEXT: {self.doc_to_text(doc)}") + print(f"TARGET: {target} END TARGET") + print(f"PRED: {pred} END PRED") + print("*" * 80) + + # turn_id = len(doc["questions"]["input_text"]) + # gold_list = self.get_answers(doc, turn_id) + + # TODO: Add HF metrics mapped from promptsource metadata. + scores = self.compute_scores([target], pred) return { - "f1": scores['f1'], - "em": scores['em'], + "f1": scores["f1"], + "em": scores["em"], } def higher_is_better(self): diff --git a/lm_eval/tasks/drop.py b/lm_eval/tasks/drop.py index 82ca790d81..ff78c76b59 100644 --- a/lm_eval/tasks/drop.py +++ b/lm_eval/tasks/drop.py @@ -18,7 +18,7 @@ import string import lm_eval.datasets.drop.drop from scipy.optimize import linear_sum_assignment -from lm_eval.base import Task, rf +from lm_eval.base import PromptSourceTask, rf from lm_eval.metrics import mean @@ -37,9 +37,9 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE) -class DROP(Task): +class DROP(PromptSourceTask): VERSION = 1 - DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop) + DATASET_PATH = "drop" # inspect.getfile(lm_eval.datasets.drop.drop) DATASET_NAME = None def has_training_docs(self): @@ -52,46 +52,13 @@ def has_test_docs(self): return False def training_docs(self): - if self._training_docs is None: - self._training_docs = list(map(self._process_doc, self.dataset["train"])) - return self._training_docs + # if self._training_docs is None: + # self._training_docs = list() + # return self._training_docs + return self.dataset["train"] def validation_docs(self): - return map(self._process_doc, self.dataset["validation"]) - - def _process_doc(self, doc): - return { - "id": doc["query_id"], - "passage": doc["passage"], - "question": doc["question"], - "answers": self.get_answers(doc), - } - - @classmethod - def get_answers(cls, qa): - def _flatten_validated_answers(validated_answers): - """ Flattens a dict of lists of validated answers. 
- {"number": ['1', '8'], ...} - -> [{"number": ['1'], ...}, {"number": ['8'], ...}] - """ - vas = [] - for i in range(len(validated_answers["number"])): - vas.append({ - "number": validated_answers["number"][i], - "date": validated_answers["date"][i], - "spans": validated_answers["spans"][i], - }) - return vas - answers = [] - answers_set = set() - candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"]) - for candidate in candidates: - answer = cls.parse_answer(candidate) - if answer in answers_set: - continue - answers_set.add(answer) - answers.append(answer) - return answers + return self.dataset["validation"] @classmethod def parse_answer(cls, answer): @@ -100,29 +67,33 @@ def parse_answer(cls, answer): return (str(answer["number"]),) if answer["spans"] != []: return tuple(answer["spans"]) - return (" ".join([answer["date"]["day"], - answer["date"]["month"], - answer["date"]["year"]]).strip(),) + return ( + " ".join( + [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]] + ).strip(), + ) - def doc_to_text(self, doc): - return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:" + # def doc_to_text(self, doc): + # return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:" - def doc_to_target(self, doc): - return " " + ", ".join(doc["answers"][0]) + # def doc_to_target(self, doc): + # return " " + ", ".join(doc["answers"][0]) - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. + # def construct_requests(self, doc, ctx): + # """Uses RequestFactory to construct Requests and returns an iterable of + # Requests which will be sent to the LM. - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - conts = [rf.greedy_until(ctx, ["."])] - return conts + # :param doc: + # The document as returned from training_docs, validation_docs, or test_docs. + # :param ctx: str + # The context string, generated by fewshot_context. This includes the natural + # language description, as well as the few shot examples, and the question + # part of the document for `doc`. + # """ + # conts = [rf.greedy_until(ctx, ["."])] + # return conts + def stopping_criteria(self): + return "." def process_results(self, doc, results): """Take a single document and the LM results and evaluates, returning a @@ -134,7 +105,21 @@ def process_results(self, doc, results): :param results: The results of the requests created in construct_requests. 
""" - preds, golds = results, doc["answers"] + + pred = results[0].strip() + target = self.doc_to_target(doc).strip() + + print("*" * 80) + print(f"DOC: {doc}") + print(f"PS: {self.prompt.apply(doc)}") + print(f"TEXT: {self.doc_to_text(doc)}") + print(f"TARGET: {target} END TARGET") + print(f"PRED: {pred} END PRED") + print("*" * 80) + + preds = [pred] + golds = [target] + max_em = 0 max_f1 = 0 for gold_answer in golds: @@ -142,10 +127,7 @@ def process_results(self, doc, results): if gold_answer[0].strip(): max_em = max(max_em, exact_match) max_f1 = max(max_f1, f1_score) - return { - "em": max_em, - "f1": max_f1 - } + return {"em": max_em, "f1": max_f1} def get_metrics(self, predicted, gold): """ @@ -158,7 +140,9 @@ def get_metrics(self, predicted, gold): predicted_bags = self._answer_to_bags(predicted) gold_bags = self._answer_to_bags(gold) - if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]): + if set(predicted_bags[0]) == set(gold_bags[0]) and len( + predicted_bags[0] + ) == len(gold_bags[0]): exact_match = 1.0 else: exact_match = 0.0 @@ -190,7 +174,9 @@ def _align_bags(self, predicted, gold): for gold_index, gold_item in enumerate(gold): for pred_index, pred_item in enumerate(predicted): if self._match_numbers_if_present(gold_item, pred_item): - scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item) + scores[gold_index, pred_index] = self._compute_f1( + pred_item, gold_item + ) row_ind, col_ind = linear_sum_assignment(-scores) max_scores = np.zeros([max(len(gold), len(predicted))]) @@ -256,7 +242,11 @@ def _tokenize(self, text): def _normalize(self, answer): tokens = [ - self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower())))) + self._white_space_fix( + self._remove_articles( + self._fix_number(self._remove_punc(token.lower())) + ) + ) for token in self._tokenize(answer) ] tokens = [token for token in tokens if token.strip()] @@ -269,10 +259,7 @@ def aggregation(self): A dictionary where keys are the names of submetrics and values are functions that aggregate a list of metrics """ - return { - "em": mean, - "f1": mean - } + return {"em": mean, "f1": mean} def higher_is_better(self): """ @@ -280,7 +267,4 @@ def higher_is_better(self): A dictionary where keys are the names of submetrics and values are whether a higher value of the submetric is better """ - return { - "em": True, - "f1": True - } + return {"em": True, "f1": True} diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index 410396d462..8914db88dd 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -14,7 +14,7 @@ Homepage: https://gluebenchmark.com/ """ import numpy as np -from lm_eval.base import rf, Task +from lm_eval.base import PromptSourceTask, rf, Task from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno from lm_eval.utils import general_detokenize @@ -45,7 +45,7 @@ # Single-Sentence Tasks -class CoLA(Task): +class CoLA(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "cola" @@ -67,37 +67,20 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"]) - - def doc_to_target(self, doc): - return " {}".format({1: "yes", 0: "no"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, " yes") - ll_false, _ = rf.loglikelihood(ctx, " no") - return ll_true, ll_false - - def 
process_results(self, doc, results): - ll_true, ll_false = results - pred = ll_true > ll_false - gold = doc["label"] - return { - "mcc": (gold, pred) - } + # def process_results(self, doc, results): + # answer_choices_list = self.prompt.get_answer_choices_list(doc) + # pred = np.argmax(results) + # target = answer_choices_list.index(self.doc_to_target(doc).strip()) + # return {"mcc": (target, pred)} - def higher_is_better(self): - return { - "mcc": True - } + # def higher_is_better(self): + # return {"mcc": True} - def aggregation(self): - return { - "mcc": matthews_corrcoef - } + # def aggregation(self): + # return {"mcc": matthews_corrcoef} -class SST(Task): +class SST(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "sst2" @@ -119,42 +102,11 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format( - general_detokenize(doc["sentence"]), - ) - - def doc_to_target(self, doc): - return " {}".format({1: "positive", 0: "negative"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_positive, _ = rf.loglikelihood(ctx, " positive") - ll_negative, _ = rf.loglikelihood(ctx, " negative") - return ll_positive, ll_negative - - def process_results(self, doc, results): - ll_positive, ll_negative = results - pred = ll_positive > ll_negative - gold = doc["label"] - return { - "acc": pred == gold - } - - def higher_is_better(self): - return { - "acc": True - } - - def aggregation(self): - return { - "acc": mean - } - # Inference Tasks -class MNLI(Task): +class MNLI(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "mnli" @@ -181,41 +133,6 @@ def test_docs(self): if self.has_test_docs(): return self.dataset["test_matched"] - def doc_to_text(self, doc): - return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format( - doc["premise"], - doc["hypothesis"].strip() + ('' if doc["hypothesis"].strip().endswith('.') else '.'), - ) - - def doc_to_target(self, doc): - # True = entailment - # False = contradiction - # Neither = neutral - return " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, " True") - ll_neither, _ = rf.loglikelihood(ctx, " Neither") - ll_false, _ = rf.loglikelihood(ctx, " False") - return ll_true, ll_neither, ll_false - - def process_results(self, doc, results): - gold = doc["label"] - pred = np.argmax(results) - return { - "acc": pred == gold - } - - def higher_is_better(self): - return { - "acc": True - } - - def aggregation(self): - return { - "acc": mean - } - class MNLIMismatched(MNLI): VERSION = 0 @@ -229,7 +146,7 @@ def test_docs(self): return self.dataset["test_mismatched"] -class QNLI(Task): +class QNLI(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "qnli" @@ -251,42 +168,8 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format( - doc["question"], - doc["sentence"], - ) - - def doc_to_target(self, doc): - # True = entailment - # False = not entailment - return " {}".format({0: "yes", 1: "no"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_yes, _ = rf.loglikelihood(ctx, " yes") - ll_no, _ = rf.loglikelihood(ctx, " no") - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - 
pred = ll_no > ll_yes - gold = doc["label"] - return { - "acc": pred == gold - } - - def higher_is_better(self): - return { - "acc": True - } - - def aggregation(self): - return { - "acc": mean - } - -class WNLI(Task): +class WNLI(PromptSourceTask): VERSION = 1 DATASET_PATH = "glue" DATASET_NAME = "wnli" @@ -301,49 +184,13 @@ def has_test_docs(self): return False def training_docs(self): - if self._training_docs is None: - self._training_docs = list(self.dataset["train"]) - return self._training_docs + return self.dataset["train"] def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: {} True or False?\nAnswer:".format( - doc["sentence1"], - doc["sentence2"], - ) - - def doc_to_target(self, doc): - # True = entailment - # False = not_entailment - return " {}".format({0: "False", 1: "True"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, " True") - ll_false, _ = rf.loglikelihood(ctx, " False") - return ll_true, ll_false - - def process_results(self, doc, results): - ll_true, ll_false = results - pred = ll_true > ll_false - gold = doc["label"] - return { - "acc": pred == gold - } - - def higher_is_better(self): - return { - "acc": True - } - - def aggregation(self): - return { - "acc": mean - } - -class RTE(Task): +class RTE(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "rte" @@ -365,45 +212,17 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: {} True or False?\nAnswer:".format( - doc["sentence1"], - doc["sentence2"], - ) - - def doc_to_target(self, doc): - # 0 = entailment - # 1 = not_entailment - return " {}".format({0: "True", 1: "False"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, " True") - ll_false, _ = rf.loglikelihood(ctx, " False") - return ll_true, ll_false - - def process_results(self, doc, results): - ll_true, ll_false = results - pred = ll_false > ll_true - gold = doc["label"] - return { - "acc": pred == gold - } - def higher_is_better(self): - return { - "acc": True - } + return {"acc": True} def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} # Similarity and Paraphrase Tasks -class MRPC(Task): +class MRPC(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "mrpc" @@ -417,6 +236,9 @@ def has_validation_docs(self): def has_test_docs(self): return False + def stopping_criteria(self): + return "\n" + def training_docs(self): if self._training_docs is None: self._training_docs = list(self.dataset["train"]) @@ -425,43 +247,8 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format( - general_detokenize(doc["sentence1"]), - general_detokenize(doc["sentence2"]), - ) - - def doc_to_target(self, doc): - return " {}".format(yesno(doc["label"])) - - def construct_requests(self, doc, ctx): - ll_yes, _ = rf.loglikelihood(ctx, " yes") - ll_no, _ = rf.loglikelihood(ctx, " no") - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - pred = ll_yes > ll_no - return { - "acc": pred == gold, - "f1": (gold, pred), - } - def higher_is_better(self): - return { - "acc": True, - "f1": True - } - - def aggregation(self): - return { - "acc": mean, - "f1": f1_score - } 
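Across glue.py, the per-task doc_to_text / doc_to_target / construct_requests / process_results overrides are deleted because the PromptSourceTask base class is expected to derive the prompt text, target, requests, and metrics from the promptsource template instead. As a rough sketch of the resulting pattern (assuming lm_eval.base.PromptSourceTask accepts a promptsource prompt at construction time, as the rest of this diff implies), a migrated GLUE-style task only has to declare its dataset and split accessors:

    # Hypothetical illustration of the migration pattern applied in this diff; not part of the diff itself.
    from lm_eval.base import PromptSourceTask


    class ExamplePromptedTask(PromptSourceTask):
        VERSION = 0
        DATASET_PATH = "glue"   # HF `datasets` path; illustrative choice
        DATASET_NAME = "mrpc"   # HF `datasets` subset; illustrative choice

        def has_training_docs(self):
            return True

        def has_validation_docs(self):
            return True

        def has_test_docs(self):
            return False

        def training_docs(self):
            # Cache the training split for few-shot sampling, as other tasks in this diff do.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

        def validation_docs(self):
            return self.dataset["validation"]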
- - -class QQP(Task): +class QQP(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "qqp" @@ -483,41 +270,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format( - doc["question1"], - doc["question2"], - ) - - def doc_to_target(self, doc): - return " {}".format(yesno(doc["label"])) - - def construct_requests(self, doc, ctx): - ll_yes, _ = rf.loglikelihood(ctx, " yes") - ll_no, _ = rf.loglikelihood(ctx, " no") - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - pred = ll_yes > ll_no - return { - "acc": pred == gold, - "f1": (gold, pred), - } - - def higher_is_better(self): - return { - "acc": True, - "f1": True - } - - def aggregation(self): - return { - "acc": mean, - "f1": f1_score - } - class STSB(Task): VERSION = 0 @@ -554,22 +306,22 @@ def doc_to_target(self, doc): return " {}".format(doc["label"]) def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of + """Uses RequestFactory to construct Requests and returns an iterable of Requests which will be sent to the LM. :param doc: The document as returned from training_docs, validation_docs, or test_docs. :param ctx: str - The context string, generated by fewshot_context. This includes the natural + The context string, generated by fewshot_context. This includes the natural language description, as well as the few shot examples, and the question - part of the document for `doc`. + part of the document for `doc`. """ # TODO: implement evaluation. - raise NotImplementedError('Evaluation not implemented') - + raise NotImplementedError("Evaluation not implemented") + def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of the metric for that one document :param doc: @@ -578,22 +330,22 @@ def process_results(self, doc, results): The results of the requests created in construct_requests. """ # TODO: implement evaluation. - raise NotImplementedError('Evaluation not implemented') + raise NotImplementedError("Evaluation not implemented") def aggregation(self): """ :returns: {str: [float] -> float} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are functions that aggregate a list of metrics """ # TODO: implement evaluation. - raise NotImplementedError('Evaluation not implemented') + raise NotImplementedError("Evaluation not implemented") def higher_is_better(self): """ :returns: {str: bool} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are whether a higher value of the submetric is better """ # TODO: implement evaluation. 
- raise NotImplementedError('Evaluation not implemented') + raise NotImplementedError("Evaluation not implemented") diff --git a/lm_eval/tasks/hans.py b/lm_eval/tasks/hans.py new file mode 100644 index 0000000000..d740763870 --- /dev/null +++ b/lm_eval/tasks/hans.py @@ -0,0 +1,61 @@ +""" +Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference +https://arxiv.org/abs/1902.01007 + +A controlled evaluation set called HANS (Heuristic Analysis for NLI Systems), +which contains many examples where the heuristics fail. + +Homepage: https://github.com/tommccoy1/hans +""" +from lm_eval.base import PromptSourceTask + + +_CITATION = """ +@inproceedings{mccoy-etal-2019-right, + title = "Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference", + author = "McCoy, Tom and + Pavlick, Ellie and + Linzen, Tal", + booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics", + month = jul, + year = "2019", + address = "Florence, Italy", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/P19-1334", + doi = "10.18653/v1/P19-1334", + pages = "3428--3448", + abstract = "A machine learning system can score well on a given test set by relying on heuristics that are effective for frequent example types but break down in more challenging cases. We study this issue within natural language inference (NLI), the task of determining whether one sentence entails another. We hypothesize that statistical NLI models may adopt three fallible syntactic heuristics: the lexical overlap heuristic, the subsequence heuristic, and the constituent heuristic. To determine whether models have adopted these heuristics, we introduce a controlled evaluation set called HANS (Heuristic Analysis for NLI Systems), which contains many examples where the heuristics fail. We find that models trained on MNLI, including BERT, a state-of-the-art model, perform very poorly on HANS, suggesting that they have indeed adopted these heuristics. We conclude that there is substantial room for improvement in NLI systems, and that the HANS dataset can motivate and measure progress in this area.", +} +""" + + +class HANS(PromptSourceTask): + VERSION = 0 + DATASET_PATH = "hans" + DATASET_NAME = None + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return False + + def training_docs(self): + if self.has_training_docs(): + # We cache training documents in `self._training_docs` for faster + # few-shot processing. If the data is too large to fit in memory, + # return the training data as a generator instead of a list. 
+ if self._training_docs is None: + self._training_docs = list(self.dataset["train"]) + return self._training_docs + + def validation_docs(self): + if self.has_validation_docs(): + return self.dataset["validation"] + + def test_docs(self): + if self.has_test_docs(): + return self.dataset["test"] diff --git a/lm_eval/tasks/race.py b/lm_eval/tasks/race.py index f19211793a..3645f357ab 100644 --- a/lm_eval/tasks/race.py +++ b/lm_eval/tasks/race.py @@ -12,7 +12,7 @@ import collections import datasets import numpy as np -from lm_eval.base import rf, Task +from lm_eval.base import PromptSourceTask, rf from lm_eval.metrics import mean @@ -34,13 +34,13 @@ def __rrshift__(self, other): return list(map(self.f, other)) -class RACE(Task): +class RACE(PromptSourceTask): VERSION = 1 DATASET_PATH = "race" DATASET_NAME = "high" cache = {} - letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3} + letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3} def has_training_docs(self): return True @@ -51,83 +51,92 @@ def has_validation_docs(self): def has_test_docs(self): return True - def _collate_data(self, set): - if set in self.cache: - return self.cache[set] - # One big issue with HF's implementation of this dataset: it makes a - # separate document for each question; meanwhile, in the GPT3 paper it - # is shown that one document is made per passage. - - r = collections.defaultdict(list) - for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]: - r[item['article']].append(item) - - res = list(r.values() >> each(lambda x: { - 'article': x[0]['article'], - 'problems': x >> each(lambda y: { - 'question': y['question'], - 'answer': y['answer'], - 'options': y['options'], - }) - })) - - self.cache[set] = res - return res + # def _collate_data(self, set): + # if set in self.cache: + # return self.cache[set] + # # One big issue with HF's implementation of this dataset: it makes a + # # separate document for each question; meanwhile, in the GPT3 paper it + # # is shown that one document is made per passage. 
+ + # r = collections.defaultdict(list) + # for item in datasets.load_dataset( + # path=self.DATASET_PATH, name=self.DATASET_NAME + # )[set]: + # r[item["article"]].append(item) + + # res = list( + # r.values() + # >> each( + # lambda x: { + # "article": x[0]["article"], + # "problems": x + # >> each( + # lambda y: { + # "question": y["question"], + # "answer": y["answer"], + # "options": y["options"], + # } + # ), + # } + # ) + # ) + + # self.cache[set] = res + # return res def training_docs(self): - return self._collate_data("train") + return self.dataset["train"] def validation_docs(self): - return self._collate_data("validation") + return self.dataset["validation"] def test_docs(self): - return self._collate_data("test") + return self.dataset["test"] @classmethod def get_answer_option(cls, problem): - answer = cls.letter_to_num[problem['answer']] - return problem['options'][answer] + answer = cls.letter_to_num[problem["answer"]] + return problem["options"][answer] @classmethod def last_problem(cls, doc): - return doc['problems'][-1] - - def doc_to_text(self, doc): - text = 'Article: ' + doc['article'] + '\n\n' - for problem in doc['problems'][:-1]: - if problem['question'][-6:] == ' _ .': - text += problem['question'][-5:] + self.get_answer_option(problem) + '\n' - else: - question = 'Question: ' + problem['question'] + '\n' - answer = 'Answer: ' + self.get_answer_option(problem) + '\n' - text += question + answer - text += self.last_problem(doc)['question'] - return text - - def doc_to_target(self, doc): - return " " + self.get_answer_option(self.last_problem(doc)) - - def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - problem = self.last_problem(doc) - ll_choices = [ - rf.loglikelihood(ctx, " " + problem['options'][i])[0] - for i in range(4) - ] - return ll_choices + return doc["problems"][-1] + + # def doc_to_text(self, doc): + # text = 'Article: ' + doc['article'] + '\n\n' + # for problem in doc['problems'][:-1]: + # if problem['question'][-6:] == ' _ .': + # text += problem['question'][-5:] + self.get_answer_option(problem) + '\n' + # else: + # question = 'Question: ' + problem['question'] + '\n' + # answer = 'Answer: ' + self.get_answer_option(problem) + '\n' + # text += question + answer + # text += self.last_problem(doc)['question'] + # return text + + # def doc_to_target(self, doc): + # return " " + self.get_answer_option(self.last_problem(doc)) + + # def construct_requests(self, doc, ctx): + # """Uses RequestFactory to construct Requests and returns an iterable of + # Requests which will be sent to the LM. + + # :param doc: + # The document as returned from training_docs, validation_docs, or test_docs. + # :param ctx: str + # The context string, generated by fewshot_context. This includes the natural + # language description, as well as the few shot examples, and the question + # part of the document for `doc`. 
+ # """ + # problem = self.last_problem(doc) + # ll_choices = [ + # rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4) + # ] + # return ll_choices def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of the metric for that one document :param doc: @@ -135,28 +144,24 @@ def process_results(self, doc, results): :param results: The results of the requests created in construct_requests. """ - gold = self.letter_to_num[self.last_problem(doc)['answer']] + # + gold = self.letter_to_num[self.doc_to_target(doc)] + # gold = self.letter_to_num[self.last_problem(doc)["answer"]] pred = np.argmax(results) - return { - "acc": int(pred == gold) - } + return {"acc": int(pred == gold)} def aggregation(self): """ :returns: {str: [float] -> float} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are functions that aggregate a list of metrics """ - return { - "acc": mean - } + return {"acc": mean} def higher_is_better(self): """ :returns: {str: bool} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are whether a higher value of the submetric is better """ - return { - "acc": True - } + return {"acc": True} diff --git a/lm_eval/tasks/superglue.py b/lm_eval/tasks/superglue.py index e4b9bfff6a..667dc54271 100644 --- a/lm_eval/tasks/superglue.py +++ b/lm_eval/tasks/superglue.py @@ -12,7 +12,7 @@ import numpy as np import sklearn import transformers.data.metrics.squad_metrics as squad_metrics -from lm_eval.base import rf, Task +from lm_eval.base import rf, PromptSourceTask from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno from lm_eval.utils import general_detokenize @@ -32,7 +32,7 @@ """ -class BoolQ(Task): +class BoolQ(PromptSourceTask): VERSION = 1 DATASET_PATH = "super_glue" DATASET_NAME = "boolq" @@ -54,41 +54,8 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:" - - def doc_to_target(self, doc): - return " " + yesno(doc['label']) - def construct_requests(self, doc, ctx): - - ll_yes, _ = rf.loglikelihood(ctx, ' yes') - ll_no, _ = rf.loglikelihood(ctx, ' no') - - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - - acc = 1. if (ll_yes > ll_no) == gold else 0. - - return { - "acc": acc - } - - def higher_is_better(self): - return { - "acc": True - } - - def aggregation(self): - return { - "acc": mean - } - - -class CommitmentBank(Task): +class CommitmentBank(PromptSourceTask): VERSION = 1 DATASET_PATH = "super_glue" DATASET_NAME = "cb" @@ -110,40 +77,15 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: {}. 
True, False or Neither?\nAnswer:".format( - doc["premise"], - doc["hypothesis"], - ) - - def doc_to_target(self, doc): - # True = entailment - # False = contradiction - # Neither = neutral - return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, ' True') - ll_false, _ = rf.loglikelihood(ctx, ' False') - ll_neither, _ = rf.loglikelihood(ctx, ' Neither') - - return ll_true, ll_false, ll_neither - def process_results(self, doc, results): gold = doc["label"] pred = np.argmax(results) - acc = 1. if pred == gold else 0. + acc = 1.0 if pred == gold else 0.0 + + return {"acc": acc, "f1": (pred, gold)} - return { - "acc": acc, - "f1": (pred, gold) - } - def higher_is_better(self): - return { - "acc": True, - "f1": True - } + return {"acc": True, "f1": True} @classmethod def cb_multi_fi(cls, items): @@ -155,7 +97,7 @@ def cb_multi_fi(cls, items): f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2) avg_f1 = mean([f11, f12, f13]) return avg_f1 - + def aggregation(self): return { "acc": mean, @@ -163,7 +105,7 @@ def aggregation(self): } -class Copa(Task): +class Copa(PromptSourceTask): VERSION = 0 DATASET_PATH = "super_glue" DATASET_NAME = "copa" @@ -185,53 +127,25 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - # Drop the period - connector = { - "cause": "because", - "effect": "therefore", - }[doc["question"]] - return doc["premise"].strip()[:-1] + f" {connector}" - - def doc_to_target(self, doc): - correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"] - # Connect the sentences - return " " + self.convert_choice(correct_choice) - - def construct_requests(self, doc, ctx): - choice1 = " " + self.convert_choice(doc["choice1"]) - choice2 = " " + self.convert_choice(doc["choice2"]) - - ll_choice1, _ = rf.loglikelihood(ctx, choice1) - ll_choice2, _ = rf.loglikelihood(ctx, choice2) - - return ll_choice1, ll_choice2 - def process_results(self, doc, results): gold = doc["label"] pred = np.argmax(results) - acc = 1. if pred == gold else 0. + acc = 1.0 if pred == gold else 0.0 + + return {"acc": acc} - return { - "acc": acc - } - def higher_is_better(self): - return { - "acc": True - } - + return {"acc": True} + def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} @staticmethod def convert_choice(choice): return choice[0].lower() + choice[1:] -class MultiRC(Task): +class MultiRC(PromptSourceTask): VERSION = 1 DATASET_PATH = "super_glue" DATASET_NAME = "multirc" @@ -253,45 +167,19 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:" - - def doc_to_target(self, doc): - return " " + self.format_answer(answer=doc["answer"], label=doc["label"]) - - @staticmethod - def format_answer(answer, label): - label_str = "yes" if label else "no" - return f"{answer}\nIs the answer correct? 
{label_str}" - - def construct_requests(self, doc, ctx): - true_choice = self.format_answer(answer=doc["answer"], label=True) - false_choice = self.format_answer(answer=doc["answer"], label=False) - - ll_true_choice, _ = rf.loglikelihood(ctx, f' {true_choice}') - ll_false_choice, _ = rf.loglikelihood(ctx, f' {false_choice}') - - return ll_true_choice, ll_false_choice - def process_results(self, doc, results): ll_true_choice, ll_false_choice = results pred = ll_true_choice > ll_false_choice - return { - "acc": (pred, doc) - } - + return {"acc": (pred, doc)} + def higher_is_better(self): - return { - "acc": True - } - + return {"acc": True} + def aggregation(self): - return { - "acc": acc_all - } + return {"acc": acc_all} -class ReCoRD(Task): +class ReCoRD(PromptSourceTask): VERSION = 0 DATASET_PATH = "super_glue" DATASET_NAME = "record" @@ -311,56 +199,31 @@ def training_docs(self): if self._training_docs is None: self._training_docs = [] for doc in self.dataset["train"]: - self._training_docs.append(self._process_doc(doc)) + self._training_docs.append(doc) return self._training_docs def validation_docs(self): # See: training_docs for doc in self.dataset["validation"]: - yield self._process_doc(doc) - - @classmethod - def _process_doc(cls, doc): - return { - "passage": doc["passage"], - "query": doc["query"], - "entities": sorted(list(set(doc["entities"]))), - "answers": sorted(list(set(doc["answers"]))), - } - - def doc_to_text(self, doc): - initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n") - text = initial_text + "\n\n" - for highlight in highlights: - text += f" - {highlight}.\n" - return text - - @classmethod - def format_answer(cls, query, entity): - return f' - {query}'.replace("@placeholder", entity) - - def doc_to_target(self, doc): - # We only output the first correct entity in a doc - return self.format_answer(query=doc["query"], entity=doc["answers"][0]) - - def construct_requests(self, doc, ctx): - requests = [ - rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity)) - for entity in doc["entities"] - ] - return requests + yield doc def process_results(self, doc, results): # ReCoRD's evaluation is actually deceptively simple: # - Pick the maximum likelihood prediction entity # - Evaluate the accuracy and token F1 PER EXAMPLE # - Average over all examples + + # TODO (jon-tow): Look at result max_idx = np.argmax(np.array([result[0] for result in results])) prediction = doc["entities"][max_idx] gold_label_set = doc["answers"] - f1 = metric_max_over_ground_truths(squad_metrics.compute_f1, prediction, gold_label_set) - em = metric_max_over_ground_truths(squad_metrics.compute_exact, prediction, gold_label_set) + f1 = metric_max_over_ground_truths( + squad_metrics.compute_f1, prediction, gold_label_set + ) + em = metric_max_over_ground_truths( + squad_metrics.compute_exact, prediction, gold_label_set + ) return { "f1": f1, @@ -380,7 +243,7 @@ def aggregation(self): } -class WordsInContext(Task): +class WordsInContext(PromptSourceTask): VERSION = 0 DATASET_PATH = "super_glue" DATASET_NAME = "wic" @@ -402,50 +265,19 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \ - " two sentences above?\nAnswer:".format( - doc["sentence1"], - doc["sentence2"], - doc["sentence1"][doc["start1"]:doc["end1"]], - ) - - def doc_to_target(self, doc): - return " {}".format({0: "no", 1: 
"yes"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_yes, _ = rf.loglikelihood(ctx, ' yes') - ll_no, _ = rf.loglikelihood(ctx, ' no') - - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - - acc = 1. if (ll_yes > ll_no) == gold else 0. - - return { - "acc": acc - } - def higher_is_better(self): - return { - "acc": True - } + return {"acc": True} def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} -class SGWinogradSchemaChallenge(Task): +class SGWinogradSchemaChallenge(PromptSourceTask): VERSION = 0 # Note: This implementation differs from Fig G.32 because this is the SuperGLUE, # binary version of the task. DATASET_PATH = "super_glue" - DATASET_NAME = "wsc" + DATASET_NAME = "wsc.fixed" def has_training_docs(self): return True @@ -461,56 +293,15 @@ def training_docs(self): if self._training_docs is None: # GPT-3 Paper's format only uses positive examples for fewshot "training" self._training_docs = [ - doc for doc in - self.dataset["train"] - if doc["label"] + doc for doc in self.dataset["train"] if doc["label"] ] return self._training_docs def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - raw_passage = doc["text"] - # NOTE: HuggingFace span indices are word-based not character-based. - pre = " ".join(raw_passage.split()[:doc["span2_index"]]) - post = raw_passage[len(pre) + len(doc["span2_text"]) + 1:] - passage = general_detokenize(pre + " *{}*".format(doc['span2_text']) + post) - noun = doc["span1_text"] - pronoun = doc["span2_text"] - text = ( - f"Passage: {passage}\n" - + f"Question: In the passage above, does the pronoun \"*{pronoun}*\" refer to \"*{noun}*\"?\n" - + "Answer:" - ) - return text - - def doc_to_target(self, doc): - return " " + yesno(doc['label']) - - def construct_requests(self, doc, ctx): - - ll_yes, _ = rf.loglikelihood(ctx, ' yes') - ll_no, _ = rf.loglikelihood(ctx, ' no') - - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - - acc = 1. if (ll_yes > ll_no) == gold else 0. 
-
-        return {
-            "acc": acc
-        }
-
     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}
 
     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}
diff --git a/lm_eval/utils.py b/lm_eval/utils.py
index e331283866..b508db96f8 100644
--- a/lm_eval/utils.py
+++ b/lm_eval/utils.py
@@ -146,6 +146,19 @@ def get_original(self, newarr):
 
         return res
 
+
+def flatten(d, parent_key='', sep='_'):
+    # From: https://stackoverflow.com/a/6027615
+    items = []
+    for k, v in d.items():
+        new_key = parent_key + sep + k if parent_key else k
+        if isinstance(v, collections.abc.MutableMapping):
+            items.extend(flatten(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
 def positional_deprecated(fn):
     """
     A decorator to nudge users into passing only keyword args (`kwargs`) to the
diff --git a/scripts/write_out.py b/scripts/write_out.py
index 2039d3934f..0f1ee354e8 100644
--- a/scripts/write_out.py
+++ b/scripts/write_out.py
@@ -30,7 +30,7 @@ def main():
         task_names = tasks.ALL_TASKS
     else:
         task_names = args.tasks.split(",")
-    task_dict = tasks.get_task_dict(task_names)
+    task_dict = tasks.get_task_dict_promptsource(task_names)
 
     description_dict = {}
     if args.description_dict_path:
diff --git a/setup.py b/setup.py
index 692d090872..c33c62b1c8 100644
--- a/setup.py
+++ b/setup.py
@@ -18,8 +18,12 @@
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
     ],
-    python_requires='>=3.6',
+    python_requires=">=3.6",
     install_requires=[
+        "promptsource @ git+https://github.com/bigscience-workshop/promptsource@eval-hackathon",
+        "wrapt",
+        "nltk",
+        "jinja2",
         "black",
         "datasets==2.0.0",
         "click>=7.1",
@@ -42,9 +46,9 @@
         "openai==0.6.4",
         "jieba==0.42.1",
         "nagisa==0.2.7",
-        "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt"
+        "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
     ],
     dependency_links=[
         "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
-    ]
+    ],
 )
diff --git a/templates/new_task.py b/templates/new_task.py
new file mode 100644
index 0000000000..fb3a3c5090
--- /dev/null
+++ b/templates/new_task.py
@@ -0,0 +1,128 @@
+# TODO: Remove all TODO comments once the implementation is complete.
+"""
+TODO: Add the Paper Title on this line.
+TODO: Add the paper's PDF URL (preferably from arXiv) on this line.
+
+TODO: Write a Short Description of the task.
+
+Homepage: TODO: Add the URL to the task's Homepage here.
+"""
+from lm_eval.base import PromptSourceTask
+
+
+# TODO: Add the BibTeX citation for the task.
+_CITATION = """
+"""
+
+
+# TODO: Replace `NewTask` with the name of your Task.
+class NewTask(PromptSourceTask):
+    VERSION = 0
+    # TODO: Add the `DATASET_PATH` string. This will be the name of the `Task`
+    # dataset as denoted in HuggingFace `datasets`.
+    DATASET_PATH = ""
+    # TODO: Add the `DATASET_NAME` string. This is the name of a subset within
+    # `DATASET_PATH`. If there aren't specific subsets you need, leave this as `None`.
+    DATASET_NAME = None
+
+    def has_training_docs(self):
+        # TODO: Fill in the return with `True` if the Task has training data; else `False`.
+        return False
+
+    def has_validation_docs(self):
+        # TODO: Fill in the return with `True` if the Task has validation data; else `False`.
+        return False
+
+    def has_test_docs(self):
+        # TODO: Fill in the return with `True` if the Task has test data; else `False`.
+ return False + + def training_docs(self): + if self.has_training_docs(): + # We cache training documents in `self._training_docs` for faster + # few-shot processing. If the data is too large to fit in memory, + # return the training data as a generator instead of a list. + if self._training_docs is None: + # TODO: Return the training document generator from `self.dataset`. + # If you need to process the data, `map` over the documents with + # the custom procesing function, `self._process_doc`. E.g. + # `map(self._process_doc, self.dataset["validation"])` + # In most case you can leave this as is unless the dataset split is + # named differently than the default `"train"`. + self._training_docs = list(self.dataset["train"]) + return self._training_docs + + def validation_docs(self): + if self.has_validation_docs(): + # TODO: Return the validation document generator from `self.dataset`. + # If you need to process the data, `map` over the documents with the + # custom procesing function, `self._process_doc`. E.g. + # `map(self._process_doc, self.dataset["validation"])` + # In most case you can leave this as is unless the dataset split is + # named differently than the default `"validation"`. + return self.dataset["validation"] + + def test_docs(self): + if self.has_test_docs(): + # TODO: Return the test document generator from `self.dataset`. + # If you need to process the data, `map` over the documents with the + # custom processing function, `self._process_doc`. E.g. + # `map(self._process_doc, self.dataset["test"])` + # In most case you can leave this as is unless the dataset split is + # named differently than the default `"test"`. + return self.dataset["test"] + + def stopping_criteria(self): + # TODO: Denote the string where the generation should be split. + # For example, for `coqa`, this is '\nQ:' and for `drop` '.'. + # NOTE: You may delete this function if the task does not required generation. + return None + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or + test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + # TODO: Construct your language model requests with the request factory, `rf`, + # and return them as an iterable. + return [] + + def process_results(self, doc, results): + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of + the metric for that one document + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param results: + The results of the requests created in construct_requests. + """ + # TODO: For each (sub)metric in the task evaluation, add a key-value pair + # with the metric name as key and the corresponding metric result as value + # for the current `doc`. 
+ return {} + + def aggregation(self): + """ + :returns: {str: [metric_score] -> float} + A dictionary where keys are the names of submetrics and values are + functions that aggregate a list of metric scores + """ + # TODO: For each (sub)metric in the task evaluation, add a key-value pair + # with the metric name as key and an aggregation function as value which + # determines how to combine results from each document in the dataset. + # Check `lm_eval.metrics` to find built-in aggregation functions. + return {} + + def higher_is_better(self): + # TODO: For each (sub)metric in the task evaluation, add a key-value pair + # with the metric name as key and a `bool` value determining whether or + # not higher values of that metric are deemed better. + return {} \ No newline at end of file diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py new file mode 100644 index 0000000000..898d1b96ad --- /dev/null +++ b/tests/test_gpt2.py @@ -0,0 +1,33 @@ +import random +import lm_eval.models as models +import pytest +import torch +from transformers import StoppingCriteria + + +@pytest.mark.parametrize( + "eos_token,test_input,expected", + [ + ("not", "i like", "i like to say that I'm not"), + ("say that", "i like", "i like to say that"), + ("great", "big science is", "big science is a great"), + ("<|endoftext|>", "big science has", "big science has been done in the past, but it's not the same as the science of the") + ] +) +def test_stopping_criteria(eos_token, test_input, expected): + random.seed(42) + torch.random.manual_seed(42) + + device = "cuda" if torch.cuda.is_available() else "cpu" + gpt2 = models.get_model("gpt2")(device=device) + + context = torch.tensor([gpt2.tokenizer.encode(test_input)]) + stopping_criteria_ids = gpt2.tokenizer.encode(eos_token) + + generations = gpt2._model_generate( + context, + max_length=20, + stopping_criteria_ids=stopping_criteria_ids + ) + generations = gpt2.tokenizer.decode(generations[0]) + assert generations == expected
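To tie the pieces together, here is a rough usage sketch of the promptsource-backed task loading added in lm_eval/tasks/__init__.py above. get_task_dict_promptsource and the "<task_name>+<prompt_name>" key convention come from this diff; the concrete task name and the set of prompt names are illustrative assumptions that depend on which templates the pinned promptsource branch ships.

    # Illustrative only: enumerate one task instance per promptsource prompt.
    from lm_eval import tasks

    task_dict = tasks.get_task_dict_promptsource(["anli_r1"])
    for key, task in task_dict.items():
        # Keys follow the "<task_name>+<prompt_name>" convention chosen in
        # get_task_dict_promptsource, so they can be split on the "+" separator.
        task_name, prompt_name = key.split("+", 1)
        print(task_name, "|", prompt_name, "|", type(task).__name__)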