From 88745155ae4759e4d83684306218ab5c8acd1be2 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 15:36:46 -0400 Subject: [PATCH 01/40] Initial integration --- lm_eval/base.py | 316 ++++++++++++++++++++++++++------------ lm_eval/evaluator.py | 105 ++++++++----- lm_eval/tasks/__init__.py | 77 +++++----- lm_eval/tasks/coqa.py | 99 +++++------- lm_eval/tasks/drop.py | 72 +++++---- lm_eval/tasks/race.py | 129 ++++++++-------- 6 files changed, 480 insertions(+), 318 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index 1ea798159f..c0207028a9 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -1,5 +1,6 @@ import abc from typing import Iterable + import numpy as np import random import re @@ -24,17 +25,17 @@ def __init__(self): @abstractmethod def loglikelihood(self, requests): """Compute log-likelihood of generating a continuation from a context. - Downstream tasks should attempt to use loglikelihood instead of other + Downstream tasks should attempt to use loglikelihood instead of other LM calls whenever possible. :param requests: list A list of pairs (context, continuation) context: str - Context string. Implementations of LM must be able to handle an + Context string. Implementations of LM must be able to handle an empty context string. continuation: str - The continuation over which log likelihood will be calculated. If - there is a word boundary, the space should be in the continuation. + The continuation over which log likelihood will be calculated. If + there is a word boundary, the space should be in the continuation. For example, context="hello" continuation=" world" is correct. :return: list A list of pairs (logprob, isgreedy) @@ -97,7 +98,7 @@ def greedy_until(self, requests): context: str Context string until: [str] - The string sequences to generate until. These string sequences + The string sequences to generate until. These string sequences may each span across multiple tokens, or may be part of one token. :return: list A list of strings continuation @@ -118,7 +119,6 @@ def set_cache_hook(self, cache_hook): class BaseLM(LM): - @property @abstractmethod def eot_token_id(self): @@ -145,13 +145,16 @@ def device(self): pass @abstractmethod - def tok_encode(self, string: str): pass - + def tok_encode(self, string: str): + pass + @abstractmethod - def tok_decode(self, tokens: Iterable[int]): pass + def tok_decode(self, tokens: Iterable[int]): + pass @abstractmethod - def _model_generate(self, context, max_length, eos_token_id): pass + def _model_generate(self, context, max_length, eos_token_id): + pass @abstractmethod def _model_call(self, inps): @@ -187,23 +190,30 @@ def loglikelihood_rolling(self, requests): # TODO: automatic batch size detection for vectorization loglikelihoods = [] - for string, in tqdm(requests): - rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows( - token_list=self.tok_encode(string), - prefix_token=self.eot_token_id, - max_seq_len=self.max_length, - context_len=1, - ))) + for (string,) in tqdm(requests): + rolling_token_windows = list( + map( + utils.make_disjoint_window, + utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.eot_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) rolling_token_windows = [(None,) + x for x in rolling_token_windows] # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for # that - string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True) - + string_nll = self._loglikelihood_tokens( + rolling_token_windows, disable_tqdm=True + ) + # discard is_greedy string_nll = [x[0] for x in string_nll] - + string_nll = sum(string_nll) loglikelihoods.append(string_nll) @@ -223,10 +233,12 @@ def _collate(x): toks = x[1] + x[2] return -len(toks), tuple(toks) - + # TODO: automatic (variable) batch size detection for vectorization reord = utils.Reorderer(requests, _collate) - for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size): + for chunk in utils.chunks( + tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size + ): inps = [] cont_toks_list = [] inplens = [] @@ -252,44 +264,60 @@ def _collate(x): # when too long to fit in context, truncate from the left inp = torch.tensor( - (context_enc + continuation_enc)[-(self.max_length+1):][:-1], - dtype=torch.long + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, ).to(self.device) - inplen, = inp.shape + (inplen,) = inp.shape cont = continuation_enc # since in _collate we make sure length is descending, the longest is always the first one. - padding_length = padding_length if padding_length is not None else inplen + padding_length = ( + padding_length if padding_length is not None else inplen + ) # pad length from seq to padding_length - inp = torch.cat([ - inp, # [seq] - torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq] - ], dim=0) + inp = torch.cat( + [ + inp, # [seq] + torch.zeros(padding_length - inplen, dtype=torch.long).to( + inp.device + ), # [padding_length - seq] + ], + dim=0, + ) inps.append(inp.unsqueeze(0)) # [1, padding_length] cont_toks_list.append(cont) inplens.append(inplen) batched_inps = torch.cat(inps, dim=0) # [batch, padding_length - multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu() # [batch, padding_length, vocab] + multi_logits = F.log_softmax( + self._model_call(batched_inps), dim=-1 + ).cpu() # [batch, padding_length, vocab] - for (cache_key, _, _), logits, inp, inplen, cont_toks \ - in zip(chunk, multi_logits, inps, inplens, cont_toks_list): + for (cache_key, _, _), logits, inp, inplen, cont_toks in zip( + chunk, multi_logits, inps, inplens, cont_toks_list + ): # Slice to original seq length contlen = len(cont_toks) - logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab] + logits = logits[inplen - contlen : inplen].unsqueeze( + 0 + ) # [1, seq, vocab] # Check if per-token argmax is exactly equal to continuation greedy_tokens = logits.argmax(dim=-1) - cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq] + cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze( + 0 + ) # [1, seq] max_equal = (greedy_tokens == cont_toks).all() # Obtain log-probs at the corresponding continuation token indices # last_token_slice = logits[:, -1, :].squeeze(0).tolist() - logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( + -1 + ) # [1, seq] # Answer: (log prob, is-exact-match) answer = (float(logits.sum()), bool(max_equal)) @@ -301,9 +329,9 @@ def _collate(x): res.append(answer) return reord.get_original(res) - + def greedy_until(self, requests): - # TODO: implement fully general `until` that handles untils that are + # TODO: implement fully general `until` that handles untils that are # multiple tokens or that span multiple tokens correctly # TODO: extract to TokenizedLM? @@ -312,29 +340,33 @@ def greedy_until(self, requests): def _collate(x): toks = self.tok_encode(x[0]) return len(toks), x[0] - + reord = utils.Reorderer(requests, _collate) for context, until in tqdm(reord.get_reordered()): if isinstance(until, str): until = [until] - primary_until, = self.tok_encode(until[0]) - - context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device) + (primary_until,) = self.tok_encode(until[0]) + + context_enc = torch.tensor( + [self.tok_encode(context)[self.max_gen_toks - self.max_length :]] + ).to(self.device) - cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until) + cont = self._model_generate( + context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until + ) - s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:]) + s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :]) for term in until: s = s.split(term)[0] - + # partial caching self.cache_hook.add_partial("greedy_until", (context, until), s) - + res.append(s) - + return reord.get_original(res) @@ -383,7 +415,7 @@ def __init__(self, data_dir=None, cache_dir=None, download_mode=None): self._fewshot_docs = None def download(self, data_dir=None, cache_dir=None, download_mode=None): - """ Downloads and returns the task dataset. + """Downloads and returns the task dataset. Override this method to download the dataset from a custom API. :param data_dir: str @@ -412,7 +444,7 @@ def download(self, data_dir=None, cache_dir=None, download_mode=None): name=self.DATASET_NAME, data_dir=data_dir, cache_dir=cache_dir, - download_mode=download_mode + download_mode=download_mode, ) @abstractmethod @@ -478,22 +510,22 @@ def doc_to_target(self, doc): @abstractmethod def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of + """Uses RequestFactory to construct Requests and returns an iterable of Requests which will be sent to the LM. :param doc: The document as returned from training_docs, validation_docs, or test_docs. :param ctx: str - The context string, generated by fewshot_context. This includes the natural + The context string, generated by fewshot_context. This includes the natural language description, as well as the few shot examples, and the question - part of the document for `doc`. + part of the document for `doc`. """ pass @abstractmethod def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of the metric for that one document :param doc: @@ -507,7 +539,7 @@ def process_results(self, doc, results): def aggregation(self): """ :returns: {str: [metric_score] -> float} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are functions that aggregate a list of metric scores """ pass @@ -516,22 +548,26 @@ def aggregation(self): def higher_is_better(self): """ :returns: {str: bool} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are whether a higher value of the submetric is better """ pass def fewshot_description(self): import warnings + warnings.warn( "`fewshot_description` will be removed in futures versions. Pass " "any custom descriptions to the `evaluate` function instead.", - DeprecationWarning) + DeprecationWarning, + ) return "" @utils.positional_deprecated - def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): - """ Returns a fewshot context string that is made up of a prepended description + def fewshot_context( + self, doc, num_fewshot, provide_description=None, rnd=None, description=None + ): + """Returns a fewshot context string that is made up of a prepended description (if provided), the `num_fewshot` number of examples, and an appended prompt example. :param doc: str @@ -548,7 +584,9 @@ def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, :returns: str The fewshot context. """ - assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`" + assert ( + rnd is not None + ), "A `random.Random` generator argument must be provided to `rnd`" assert not provide_description, ( "The `provide_description` arg will be removed in future versions. To prepend " "a custom description to the context, supply the corresponding string via the " @@ -556,7 +594,9 @@ def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, ) if provide_description is not None: # nudge people to not specify it at all - print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") + print( + "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict" + ) description = description + "\n\n" if description else "" @@ -569,7 +609,9 @@ def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, else: if self._fewshot_docs is None: self._fewshot_docs = list( - self.validation_docs() if self.has_validation_docs() else self.test_docs() + self.validation_docs() + if self.has_validation_docs() + else self.test_docs() ) fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) @@ -577,23 +619,90 @@ def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, # get rid of the doc that's the one we're evaluating, if it's in the fewshot fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] - labeled_examples = "\n\n".join( - [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex] - ) + "\n\n" + labeled_examples = ( + "\n\n".join( + [ + self.doc_to_text(doc) + self.doc_to_target(doc) + for doc in fewshotex + ] + ) + + "\n\n" + ) example = self.doc_to_text(doc) return description + labeled_examples + example -class MultipleChoiceTask(Task): +class PromptSourceTask(Task): + def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None): + super().__init__(data_dir, cache_dir, download_mode) + self.prompt = prompt def doc_to_target(self, doc): - return " " + doc['choices'][doc['gold']] + _, target = prompt.apply(doc) + return f" {target}" + + def doc_to_text(self, doc): + text, _ = prompt.apply(doc) + return text + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + _requests = [] + + if self.prompt.metadata.choices_in_prompt: + for answer_choice in prompt.get_fixed_answer_choices_list(): + ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}") + _requests.append(ll_answer_choice) + else: + # TODO(Albert): What is the stop symbol? Is it model specific? + ll_greedy, _ = rf.greedy_until(ctx, ["\nQ:"]) + _requests.append(ll_greedy) + + return _requests + + def process_results(self, doc, results): + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of + the metric for that one document + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param results: + The results of the requests created in construct_requests. + """ + raise NotImplementedError( + "Implement process results using the `prompt.metadata.metrics`. See below." + ) + if self.prompt.metadata.choices_in_prompt: + for result, answer_choice in zip( + prompt.get_fixed_answer_choices_list(), results + ): + pass + else: + continuation = results + + # Map metric name to HF metric. + # TODO(Albert): What is Other? + metric_names = prompt.metadata.metrics + + +class MultipleChoiceTask(Task): + def doc_to_target(self, doc): + return " " + doc["choices"][doc["gold"]] def construct_requests(self, doc, ctx): lls = [ - rf.loglikelihood(ctx, " {}".format(choice))[0] - for choice in doc['choices'] + rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"] ] return lls @@ -601,21 +710,21 @@ def construct_requests(self, doc, ctx): def process_results(self, doc, results): gold = doc["gold"] - acc = 1. if np.argmax(results) == gold else 0. + acc = 1.0 if np.argmax(results) == gold else 0.0 completion_len = np.array([float(len(i)) for i in doc["choices"]]) - acc_norm = 1. if np.argmax(results / completion_len) == gold else 0. + acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0 return { "acc": acc, "acc_norm": acc_norm, } - + def higher_is_better(self): return { "acc": True, "acc_norm": True, } - + def aggregation(self): return { "acc": mean, @@ -624,7 +733,6 @@ def aggregation(self): class PerplexityTask(Task, abc.ABC): - def has_training_docs(self): return False @@ -632,9 +740,15 @@ def fewshot_examples(self, k, rnd): assert k == 0 return [] - def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): - assert num_fewshot == 0, "The number of fewshot examples must be 0 for perplexity tasks." - assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`." + def fewshot_context( + self, doc, num_fewshot, provide_description=None, rnd=None, description=None + ): + assert ( + num_fewshot == 0 + ), "The number of fewshot examples must be 0 for perplexity tasks." + assert ( + rnd is not None + ), "A `random.Random` generator argument must be provided to `rnd`." assert not provide_description, ( "The `provide_description` arg will be removed in future versions. To prepend " "a custom description to the context, supply the corresponding string via the " @@ -642,7 +756,9 @@ def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, ) if provide_description is not None: # nudge people to not specify it at all - print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") + print( + "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict" + ) return "" @@ -665,7 +781,7 @@ def construct_requests(self, doc, ctx): return req def process_results(self, doc, results): - loglikelihood, = results + (loglikelihood,) = results words = self.count_words(doc) bytes_ = self.count_bytes(doc) return { @@ -687,23 +803,23 @@ def count_bytes(cls, doc): @classmethod def count_words(cls, doc): - """ Downstream tasks with custom word boundaries should override this! """ + """Downstream tasks with custom word boundaries should override this!""" return len(re.split(r"\s+", doc)) def hash_args(attr, args): dat = json.dumps([attr] + list(args)) - return hashlib.sha256(dat.encode('utf-8')).hexdigest() + return hashlib.sha256(dat.encode("utf-8")).hexdigest() class CacheHook: def __init__(self, cachinglm): - if cachinglm is None: + if cachinglm is None: self.dbdict = None return self.dbdict = cachinglm.dbdict - + def add_partial(self, attr, req, res): if self.dbdict is None: return @@ -733,7 +849,7 @@ def __getattr__(self, attr): def fn(requests): res = [] remaining_reqs = [] - + # figure out which ones are cached and which ones are new for req in requests: hsh = hash_args(attr, req) @@ -746,7 +862,7 @@ def fn(requests): else: res.append(None) remaining_reqs.append(req) - + # actually run the LM on the requests that do not have cached results rem_res = getattr(self.lm, attr)(remaining_reqs) @@ -764,41 +880,48 @@ def fn(requests): self.dbdict.commit() return res + return fn - + def get_cache_hook(self): return CacheHook(self) REQUEST_RETURN_LENGTHS = { - 'loglikelihood': 2, - 'greedy_until': None, - 'loglikelihood_rolling': None, + "loglikelihood": 2, + "greedy_until": None, + "loglikelihood_rolling": None, } class Request: def __init__(self, request_type, args, index=None): if request_type not in REQUEST_RETURN_LENGTHS.keys(): - raise NotImplementedError('The request type {} is not implemented!'.format(request_type)) + raise NotImplementedError( + "The request type {} is not implemented!".format(request_type) + ) self.request_type = request_type self.args = args self.index = index - + def __iter__(self): if REQUEST_RETURN_LENGTHS[self.request_type] is None: - raise IndexError('This request type does not return multiple arguments!') + raise IndexError("This request type does not return multiple arguments!") for i in range(REQUEST_RETURN_LENGTHS[self.request_type]): yield Request(self.request_type, self.args, i) - + def __getitem__(self, i): if REQUEST_RETURN_LENGTHS[self.request_type] is None: - raise IndexError('This request type does not return multiple arguments!') + raise IndexError("This request type does not return multiple arguments!") return Request(self.request_type, self.args, i) - + def __eq__(self, other): - return self.request_type == other.request_type and self.args == other.args and self.index == other.index + return ( + self.request_type == other.request_type + and self.args == other.args + and self.index == other.index + ) def __repr__(self): return f"Req_{self.request_type}{self.args}[{self.index}]\n" @@ -808,6 +931,7 @@ class RequestFactory: def __getattr__(self, attr): def fn(*args): return Request(attr, args) + return fn diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 22a8aca26e..66054cb210 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -6,21 +6,33 @@ import lm_eval.models import lm_eval.tasks import lm_eval.base +import promptsource import numpy as np + +from promptsource.templates import DatasetTemplates from lm_eval.utils import positional_deprecated, run_task_tests @positional_deprecated -def simple_evaluate(model, model_args=None, tasks=[], - num_fewshot=0, batch_size=None, device=None, - no_cache=False, limit=None, bootstrap_iters=100000, - description_dict=None, check_integrity=False): +def simple_evaluate( + model, + model_args=None, + tasks=[], + num_fewshot=0, + batch_size=None, + device=None, + no_cache=False, + limit=None, + bootstrap_iters=100000, + description_dict=None, + check_integrity=False, +): """Instantiate and evaluate a model on a list of tasks. :param model: Union[str, LM] Name of model or LM object, see lm_eval.models.get_model :param model_args: Optional[str] - String arguments for each model class, see LM.create_from_arg_string. + String arguments for each model class, see LM.create_from_arg_string. Ignored if `model` argument is a LM object. :param tasks: list[Union[str, Task]] List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. @@ -37,7 +49,7 @@ def simple_evaluate(model, model_args=None, tasks=[], :param bootstrap_iters: Number of iterations for bootstrap statistics :param description_dict: dict[str, str] - Dictionary of custom task descriptions of the form: `task_name: description` + Dictionary of custom task descriptions of the form: `task_name: description` :param check_integrity: bool Whether to run the relevant part of the test suite for the tasks :return @@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[], assert tasks != [], "No tasks specified" if isinstance(model, str): - if model_args is None: model_args = "" - lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, { - 'batch_size': batch_size, 'device': device - }) + if model_args is None: + model_args = "" + lm = lm_eval.models.get_model(model).create_from_arg_string( + model_args, {"batch_size": batch_size, "device": device} + ) else: assert isinstance(model, lm_eval.base.LM) lm = model if not no_cache: lm = lm_eval.base.CachingLM( - lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db' + lm, + "lm_cache/" + + model + + "_" + + model_args.replace("=", "-").replace(",", "_").replace("/", "-") + + ".db", ) - - task_dict = lm_eval.tasks.get_task_dict(tasks) + + task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks) if check_integrity: run_task_tests(task_list=tasks) @@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[], task_dict=task_dict, num_fewshot=num_fewshot, limit=limit, - description_dict=description_dict + description_dict=description_dict, ) # add info about the model and few shot config @@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[], "no_cache": no_cache, "limit": limit, "bootstrap_iters": bootstrap_iters, - "description_dict": description_dict + "description_dict": description_dict, } return results @positional_deprecated -def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None): +def evaluate( + lm, + task_dict, + provide_description=None, + num_fewshot=0, + limit=None, + bootstrap_iters=100000, + description_dict=None, +): """Instantiate and evaluate a model on a list of tasks. :param lm: obj @@ -108,7 +134,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, :param bootstrap_iters: Number of iterations for bootstrap statistics :param description_dict: dict[str, str] - Dictionary of custom task descriptions of the form: `task_name: description` + Dictionary of custom task descriptions of the form: `task_name: description` :return Dictionary of results """ @@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, assert not provide_description # not implemented. if provide_description is not None: # nudge people to not specify it at all - print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") + print( + "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict" + ) task_dict_items = [ (name, task) for name, task in task_dict.items() - if(task.has_validation_docs() or task.has_test_docs()) + if (task.has_validation_docs() or task.has_test_docs()) ] results = collections.defaultdict(dict) @@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, rnd.seed(42) rnd.shuffle(task_docs) - description = description_dict[task_name] if description_dict and task_name in description_dict else "" + description = ( + description_dict[task_name] + if description_dict and task_name in description_dict + else "" + ) for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)): docs[(task_name, doc_id)] = doc ctx = task.fewshot_context( - doc=doc, - num_fewshot=num_fewshot, - rnd=rnd, - description=description + doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description ) reqs = task.construct_requests(doc, ctx) if not isinstance(reqs, (list, tuple)): @@ -189,11 +218,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, print("Running", reqtype, "requests") resps = getattr(lm, reqtype)([req.args for req in reqs]) - resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)] + resps = [ + x if req.index is None else x[req.index] for x, req in zip(resps, reqs) + ] for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): process_res_queue[(task_name, doc_id)].append((i, resp)) - + vals = collections.defaultdict(list) # unpack results and sort back in order and return control to Task @@ -207,25 +238,29 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, metrics = task.process_results(doc, requests) for metric, value in metrics.items(): vals[(task_name, metric)].append(value) - + + task_name, prompt_name = task_name.split("+") + results[task_name]["task_name"] = task_name + results[task_name]["prompt_name"] = prompt_name + # aggregate results for (task_name, metric), items in vals.items(): task = task_dict[task_name] + results[task_name][metric] = task.aggregation()[metric](items) # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # so we run them less iterations. still looking for a cleaner way to do this stderr = lm_eval.metrics.stderr_for_metric( metric=task.aggregation()[metric], - bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters, + bootstrap_iters=min(bootstrap_iters, 1000) + if metric in ["bleu", "chrf", "ter"] + else bootstrap_iters, ) if stderr is not None: results[task_name][metric + "_stderr"] = stderr(items) - - return { - "results": dict(results), - "versions": dict(versions) - } + + return {"results": dict(results), "versions": dict(versions)} def make_table(result_dict): @@ -247,9 +282,9 @@ def make_table(result_dict): if m + "_stderr" in dic: se = dic[m + "_stderr"] - values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se]) + values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se]) else: - values.append([k, version, m, '%.4f' % v, '', '']) + values.append([k, version, m, "%.4f" % v, "", ""]) k = "" version = "" md_writer.value_matrix = values diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index 4e6a8b87fa..a68b7cab4f 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -1,3 +1,5 @@ +from promptsource.templates import DatasetTemplates + from pprint import pprint from typing import List, Union @@ -58,8 +60,8 @@ # 6 total gpt3_translation_benchmarks = { - "wmt14": ['en-fr', 'fr-en'], # French - "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian + "wmt14": ["en-fr", "fr-en"], # French + "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian } @@ -67,7 +69,7 @@ selected_translation_benchmarks = { **gpt3_translation_benchmarks, "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"), - "iwslt17": ['en-ar', 'ar-en'] # Arabic + "iwslt17": ["en-ar", "ar-en"], # Arabic } # 319 total @@ -91,7 +93,7 @@ "rte": glue.RTE, "qnli": glue.QNLI, "qqp": glue.QQP, - #"stsb": glue.STSB, # not implemented yet + # "stsb": glue.STSB, # not implemented yet "sst": glue.SST, "wnli": glue.WNLI, # SuperGLUE @@ -102,34 +104,26 @@ "record": superglue.ReCoRD, "wic": superglue.WordsInContext, "wsc": superglue.SGWinogradSchemaChallenge, - # Order by benchmark/genre? "coqa": coqa.CoQA, "drop": drop.DROP, "lambada": lambada.LAMBADA, "lambada_cloze": lambada_cloze.LAMBADA_cloze, - # multilingual lambada **lambada_multilingual.construct_tasks(), - "wikitext": wikitext.WikiText, # "cbt-cn": cbt.CBTCN, # disabled pending context length fix # "cbt-ne": cbt.CBTNE, # disabled pending context length fix - "piqa": piqa.PiQA, "prost": prost.PROST, "mc_taco": mc_taco.MCTACO, - # Science related - "pubmedqa" : pubmedqa.Pubmed_QA, - "sciq" : sciq.SciQ, - + "pubmedqa": pubmedqa.Pubmed_QA, + "sciq": sciq.SciQ, "qasper": qasper.QASPER, - - "qa4mre_2011" : qa4mre.QA4MRE_2011, - "qa4mre_2012" : qa4mre.QA4MRE_2012, - "qa4mre_2013" : qa4mre.QA4MRE_2013, - + "qa4mre_2011": qa4mre.QA4MRE_2011, + "qa4mre_2012": qa4mre.QA4MRE_2012, + "qa4mre_2013": qa4mre.QA4MRE_2013, "triviaqa": triviaqa.TriviaQA, "arc_easy": arc.ARCEasy, "arc_challenge": arc.ARCChallenge, @@ -140,7 +134,7 @@ "squad2": squad.SQuAD2, "race": race.RACE, # "naturalqs": naturalqs.NaturalQs, # not implemented yet - "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es + "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es "headqa_es": headqa.HeadQAEs, "headqa_en": headqa.HeadQAEn, "mathqa": mathqa.MathQA, @@ -150,21 +144,17 @@ "anli_r1": anli.ANLIRound1, "anli_r2": anli.ANLIRound2, "anli_r3": anli.ANLIRound3, - "ethics_cm": hendrycks_ethics.EthicsCM, "ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_justice": hendrycks_ethics.EthicsJustice, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_virtue": hendrycks_ethics.EthicsVirtue, - - "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, - "truthfulqa_gen": truthfulqa.TruthfulQAGeneration, - + "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, + "truthfulqa_gen": truthfulqa.TruthfulQAGeneration, # dialogue "mutual": mutual.MuTual, "mutual_plus": mutual.MuTualPlus, - # math "math_algebra": hendrycks_math.MathAlgebra, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability, @@ -175,7 +165,6 @@ "math_precalc": hendrycks_math.MathPrecalculus, "math_asdiv": asdiv.Asdiv, "gsm8k": gsm8k.GradeSchoolMath8K, - # arithmetic "arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2ds": arithmetic.Arithmetic2DMinus, @@ -189,22 +178,18 @@ "arithmetic_1dc": arithmetic.Arithmetic1DComposite, # TODO Perhaps make these groups of tasks # e.g. anli, arithmetic, openai_translations, harness_translations - # hendrycksTest (57 tasks) **hendrycks_test.create_all_tasks(), - # e.g. wmt14-fr-en **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks), # chef's selection, mostly wmt20 **translation.create_tasks_from_benchmarks(selected_translation_benchmarks), - # Word Scrambling and Manipulation Tasks "anagrams1": unscramble.Anagrams1, "anagrams2": unscramble.Anagrams2, "cycle_letters": unscramble.CycleLetters, "random_insertion": unscramble.RandomInsertion, "reversed_words": unscramble.ReversedWords, - # Pile "pile_arxiv": pile.PileArxiv, "pile_books3": pile.PileBooks3, @@ -228,7 +213,6 @@ "pile_ubuntu-irc": pile.PileUbuntuIrc, "pile_wikipedia": pile.PileWikipedia, "pile_youtubesubtitles": pile.PileYoutubeSubtitles, - # BLiMP "blimp_adjunct_island": blimp.BlimpAdjunctIsland, "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement, @@ -297,7 +281,6 @@ "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, - # Requires manual download of data. # "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2018": storycloze.StoryCloze2018, @@ -321,19 +304,43 @@ def get_task_name_from_object(task_object): for name, class_ in TASK_REGISTRY.items(): if class_ is task_object: return name - + # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting - return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__ + return ( + task_object.EVAL_HARNESS_NAME + if hasattr(task_object, "EVAL_HARNESS_NAME") + else type(task_object).__name__ + ) def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): task_name_dict = { task_name: get_task(task_name)() - for task_name in task_name_list if isinstance(task_name, str) + for task_name in task_name_list + if isinstance(task_name, str) } task_name_from_object_dict = { get_task_name_from_object(task_object): task_object - for task_object in task_name_list if not isinstance(task_object, str) + for task_object in task_name_list + if not isinstance(task_object, str) } assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) return {**task_name_dict, **task_name_from_object_dict} + + +def get_task_dict_promptsource(task_name_list: List[str]): + """Loads a task instance for each prompt written for that task.""" + task_name_dict = {} + + for task_name in task_name_list: + assert isinstance(task_name, str) + task_prompts = DatasetTemplates(task_name) + + for prompt_name in task_prompts.all_template_names: + prompt = task_prompts[prompt_name] + # NOTE: We choose a sep that can be easily split. + task_name_dict[f"{task_name}+{prompt_name}"] = get_task(task_name)( + prompt=prompt + ) + + return task_name_dict diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py index f6c9983384..f12c2c36de 100644 --- a/lm_eval/tasks/coqa.py +++ b/lm_eval/tasks/coqa.py @@ -51,44 +51,22 @@ def validation_docs(self): def test_docs(self): pass - def doc_to_text(self, doc): - # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1} - # and a question qi, the task is to predict the answer ai - doc_text = doc["story"] + '\n\n' - for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai - question = f"Q: {q}\n\n" - answer = f"A: {a}\n\n" if a is not None else "A:" - doc_text += question + answer - return doc_text - - @classmethod - def get_answers(cls, doc, turn_id): - # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers). - answers = [] - answer_forturn = doc["answers"]["input_text"][turn_id - 1] - answers.append(answer_forturn) - - additional_answers = doc.get("additional_answers") - if additional_answers: - for key in additional_answers: - additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1] - if additional_answer_for_turn.lower() not in map(str.lower, answers): - answers.append(additional_answer_for_turn) - return answers - - @classmethod - def get_answer_choice(self, raw_text): - # Function maps answers to CoQA answer categories - # ~ 1/5 of the CoQA answers are Yes/No - # ~ 2/3 of the CoQA answers are span-based - # (answers overlap with the passage ignoring punctuation and case mismatch) - if raw_text == "unknown": - return '0' - if squad_metrics.normalize_answer(raw_text) == "yes": - return '1' - if squad_metrics.normalize_answer(raw_text) == "no": - return '2' - return '3' # Not a yes/no question + # @classmethod + # def get_answers(cls, doc, turn_id): + # # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers). + # answers = [] + # answer_forturn = doc["answers"]["input_text"][turn_id - 1] + # answers.append(answer_forturn) + + # additional_answers = doc.get("additional_answers") + # if additional_answers: + # for key in additional_answers: + # additional_answer_for_turn = additional_answers[key]["input_text"][ + # turn_id - 1 + # ] + # if additional_answer_for_turn.lower() not in map(str.lower, answers): + # answers.append(additional_answer_for_turn) + # return answers @staticmethod def compute_scores(gold_list, pred): @@ -98,40 +76,38 @@ def compute_scores(gold_list, pred): em_sum = 0.0 if len(gold_list) > 1: for i in range(len(gold_list)): - gold_answers = gold_list[0:i] + gold_list[i + 1:] + gold_answers = gold_list[0:i] + gold_list[i + 1 :] # predictions compared against (n) golds and take maximum - em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers) + em_sum += max( + squad_metrics.compute_exact(a, pred) for a in gold_answers + ) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers) else: em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list) - return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))} - - def doc_to_target(self, doc, turnid=None): - # Default to prediction of last turn. - if turnid is None: - turnid = len(doc["questions"]["input_text"]) - raw_text = doc['answers']["input_text"][turnid - 1] - return " " + raw_text + return { + "em": em_sum / max(1, len(gold_list)), + "f1": f1_sum / max(1, len(gold_list)), + } def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of + """Uses RequestFactory to construct Requests and returns an iterable of Requests which will be sent to the LM. :param doc: The document as returned from training_docs, validation_docs, or test_docs. :param ctx: str - The context string, generated by fewshot_context. This includes the natural + The context string, generated by fewshot_context. This includes the natural language description, as well as the few shot examples, and the question - part of the document for `doc`. + part of the document for `doc`. """ - cont_request = rf.greedy_until(ctx, ['\nQ:']) + cont_request = rf.greedy_until(ctx, ["\nQ:"]) return cont_request def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of the metric for that one document :param doc: @@ -139,15 +115,18 @@ def process_results(self, doc, results): :param results: The results of the requests created in construct_requests. """ - turn_id = len(doc["questions"]["input_text"]) - gold_list = self.get_answers(doc, turn_id) - pred = results[0].strip().split('\n')[0] + target = self.doc_to_target(doc).strip() + pred = results[0].strip().split("\n")[0] + + # turn_id = len(doc["questions"]["input_text"]) + # gold_list = self.get_answers(doc, turn_id) - scores = self.compute_scores(gold_list, pred) + # TODO: Add HF metrics mapped from promptsource metadata. + scores = self.compute_scores([target], pred) return { - "f1": scores['f1'], - "em": scores['em'], + "f1": scores["f1"], + "em": scores["em"], } def higher_is_better(self): diff --git a/lm_eval/tasks/drop.py b/lm_eval/tasks/drop.py index 82ca790d81..c179a23aff 100644 --- a/lm_eval/tasks/drop.py +++ b/lm_eval/tasks/drop.py @@ -70,21 +70,26 @@ def _process_doc(self, doc): @classmethod def get_answers(cls, qa): def _flatten_validated_answers(validated_answers): - """ Flattens a dict of lists of validated answers. + """Flattens a dict of lists of validated answers. {"number": ['1', '8'], ...} -> [{"number": ['1'], ...}, {"number": ['8'], ...}] """ vas = [] for i in range(len(validated_answers["number"])): - vas.append({ - "number": validated_answers["number"][i], - "date": validated_answers["date"][i], - "spans": validated_answers["spans"][i], - }) + vas.append( + { + "number": validated_answers["number"][i], + "date": validated_answers["date"][i], + "spans": validated_answers["spans"][i], + } + ) return vas + answers = [] answers_set = set() - candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"]) + candidates = [qa["answer"]] + _flatten_validated_answers( + qa["validated_answers"] + ) for candidate in candidates: answer = cls.parse_answer(candidate) if answer in answers_set: @@ -100,15 +105,17 @@ def parse_answer(cls, answer): return (str(answer["number"]),) if answer["spans"] != []: return tuple(answer["spans"]) - return (" ".join([answer["date"]["day"], - answer["date"]["month"], - answer["date"]["year"]]).strip(),) + return ( + " ".join( + [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]] + ).strip(), + ) - def doc_to_text(self, doc): - return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:" + # def doc_to_text(self, doc): + # return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:" - def doc_to_target(self, doc): - return " " + ", ".join(doc["answers"][0]) + # def doc_to_target(self, doc): + # return " " + ", ".join(doc["answers"][0]) def construct_requests(self, doc, ctx): """Uses RequestFactory to construct Requests and returns an iterable of @@ -134,7 +141,13 @@ def process_results(self, doc, results): :param results: The results of the requests created in construct_requests. """ - preds, golds = results, doc["answers"] + + pred = results[0].strip() + target = self.doc_to_target(doc).strip() + + preds = [pred] + golds = [target] + max_em = 0 max_f1 = 0 for gold_answer in golds: @@ -142,10 +155,7 @@ def process_results(self, doc, results): if gold_answer[0].strip(): max_em = max(max_em, exact_match) max_f1 = max(max_f1, f1_score) - return { - "em": max_em, - "f1": max_f1 - } + return {"em": max_em, "f1": max_f1} def get_metrics(self, predicted, gold): """ @@ -158,7 +168,9 @@ def get_metrics(self, predicted, gold): predicted_bags = self._answer_to_bags(predicted) gold_bags = self._answer_to_bags(gold) - if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]): + if set(predicted_bags[0]) == set(gold_bags[0]) and len( + predicted_bags[0] + ) == len(gold_bags[0]): exact_match = 1.0 else: exact_match = 0.0 @@ -190,7 +202,9 @@ def _align_bags(self, predicted, gold): for gold_index, gold_item in enumerate(gold): for pred_index, pred_item in enumerate(predicted): if self._match_numbers_if_present(gold_item, pred_item): - scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item) + scores[gold_index, pred_index] = self._compute_f1( + pred_item, gold_item + ) row_ind, col_ind = linear_sum_assignment(-scores) max_scores = np.zeros([max(len(gold), len(predicted))]) @@ -256,7 +270,11 @@ def _tokenize(self, text): def _normalize(self, answer): tokens = [ - self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower())))) + self._white_space_fix( + self._remove_articles( + self._fix_number(self._remove_punc(token.lower())) + ) + ) for token in self._tokenize(answer) ] tokens = [token for token in tokens if token.strip()] @@ -269,10 +287,7 @@ def aggregation(self): A dictionary where keys are the names of submetrics and values are functions that aggregate a list of metrics """ - return { - "em": mean, - "f1": mean - } + return {"em": mean, "f1": mean} def higher_is_better(self): """ @@ -280,7 +295,4 @@ def higher_is_better(self): A dictionary where keys are the names of submetrics and values are whether a higher value of the submetric is better """ - return { - "em": True, - "f1": True - } + return {"em": True, "f1": True} diff --git a/lm_eval/tasks/race.py b/lm_eval/tasks/race.py index f19211793a..cd4e5490fd 100644 --- a/lm_eval/tasks/race.py +++ b/lm_eval/tasks/race.py @@ -40,7 +40,7 @@ class RACE(Task): DATASET_NAME = "high" cache = {} - letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3} + letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3} def has_training_docs(self): return True @@ -59,17 +59,27 @@ def _collate_data(self, set): # is shown that one document is made per passage. r = collections.defaultdict(list) - for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]: - r[item['article']].append(item) - - res = list(r.values() >> each(lambda x: { - 'article': x[0]['article'], - 'problems': x >> each(lambda y: { - 'question': y['question'], - 'answer': y['answer'], - 'options': y['options'], - }) - })) + for item in datasets.load_dataset( + path=self.DATASET_PATH, name=self.DATASET_NAME + )[set]: + r[item["article"]].append(item) + + res = list( + r.values() + >> each( + lambda x: { + "article": x[0]["article"], + "problems": x + >> each( + lambda y: { + "question": y["question"], + "answer": y["answer"], + "options": y["options"], + } + ), + } + ) + ) self.cache[set] = res return res @@ -85,49 +95,48 @@ def test_docs(self): @classmethod def get_answer_option(cls, problem): - answer = cls.letter_to_num[problem['answer']] - return problem['options'][answer] + answer = cls.letter_to_num[problem["answer"]] + return problem["options"][answer] @classmethod def last_problem(cls, doc): - return doc['problems'][-1] - - def doc_to_text(self, doc): - text = 'Article: ' + doc['article'] + '\n\n' - for problem in doc['problems'][:-1]: - if problem['question'][-6:] == ' _ .': - text += problem['question'][-5:] + self.get_answer_option(problem) + '\n' - else: - question = 'Question: ' + problem['question'] + '\n' - answer = 'Answer: ' + self.get_answer_option(problem) + '\n' - text += question + answer - text += self.last_problem(doc)['question'] - return text - - def doc_to_target(self, doc): - return " " + self.get_answer_option(self.last_problem(doc)) - - def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - problem = self.last_problem(doc) - ll_choices = [ - rf.loglikelihood(ctx, " " + problem['options'][i])[0] - for i in range(4) - ] - return ll_choices + return doc["problems"][-1] + + # def doc_to_text(self, doc): + # text = 'Article: ' + doc['article'] + '\n\n' + # for problem in doc['problems'][:-1]: + # if problem['question'][-6:] == ' _ .': + # text += problem['question'][-5:] + self.get_answer_option(problem) + '\n' + # else: + # question = 'Question: ' + problem['question'] + '\n' + # answer = 'Answer: ' + self.get_answer_option(problem) + '\n' + # text += question + answer + # text += self.last_problem(doc)['question'] + # return text + + # def doc_to_target(self, doc): + # return " " + self.get_answer_option(self.last_problem(doc)) + + # def construct_requests(self, doc, ctx): + # """Uses RequestFactory to construct Requests and returns an iterable of + # Requests which will be sent to the LM. + + # :param doc: + # The document as returned from training_docs, validation_docs, or test_docs. + # :param ctx: str + # The context string, generated by fewshot_context. This includes the natural + # language description, as well as the few shot examples, and the question + # part of the document for `doc`. + # """ + # problem = self.last_problem(doc) + # ll_choices = [ + # rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4) + # ] + # return ll_choices def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of the metric for that one document :param doc: @@ -135,28 +144,24 @@ def process_results(self, doc, results): :param results: The results of the requests created in construct_requests. """ - gold = self.letter_to_num[self.last_problem(doc)['answer']] + # + gold = self.letter_to_num[self.doc_to_target(doc)] + # gold = self.letter_to_num[self.last_problem(doc)["answer"]] pred = np.argmax(results) - return { - "acc": int(pred == gold) - } + return {"acc": int(pred == gold)} def aggregation(self): """ :returns: {str: [float] -> float} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are functions that aggregate a list of metrics """ - return { - "acc": mean - } + return {"acc": mean} def higher_is_better(self): """ :returns: {str: bool} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are whether a higher value of the submetric is better """ - return { - "acc": True - } + return {"acc": True} From 7d282b5f11e8424842fce36a1ebb51f76006f7ae Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 15:40:33 -0400 Subject: [PATCH 02/40] Add PromptSourceTask to the updated tasks --- lm_eval/tasks/coqa.py | 4 ++-- lm_eval/tasks/drop.py | 4 ++-- lm_eval/tasks/race.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py index f12c2c36de..58152c75ff 100644 --- a/lm_eval/tasks/coqa.py +++ b/lm_eval/tasks/coqa.py @@ -12,7 +12,7 @@ import inspect import transformers.data.metrics.squad_metrics as squad_metrics import lm_eval.datasets.coqa.coqa -from lm_eval.base import Task, rf, mean +from lm_eval.base import PromptSourceTask, rf, mean from itertools import zip_longest @@ -28,7 +28,7 @@ """ -class CoQA(Task): +class CoQA(PromptSourceTask): VERSION = 1 DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa) DATASET_NAME = None diff --git a/lm_eval/tasks/drop.py b/lm_eval/tasks/drop.py index c179a23aff..a075d2c68e 100644 --- a/lm_eval/tasks/drop.py +++ b/lm_eval/tasks/drop.py @@ -18,7 +18,7 @@ import string import lm_eval.datasets.drop.drop from scipy.optimize import linear_sum_assignment -from lm_eval.base import Task, rf +from lm_eval.base import PromptSourceTask, rf from lm_eval.metrics import mean @@ -37,7 +37,7 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE) -class DROP(Task): +class DROP(PromptSourceTask): VERSION = 1 DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop) DATASET_NAME = None diff --git a/lm_eval/tasks/race.py b/lm_eval/tasks/race.py index cd4e5490fd..2dc23fa1b4 100644 --- a/lm_eval/tasks/race.py +++ b/lm_eval/tasks/race.py @@ -12,7 +12,7 @@ import collections import datasets import numpy as np -from lm_eval.base import rf, Task +from lm_eval.base import PromptSourceTask, rf from lm_eval.metrics import mean @@ -34,7 +34,7 @@ def __rrshift__(self, other): return list(map(self.f, other)) -class RACE(Task): +class RACE(PromptSourceTask): VERSION = 1 DATASET_PATH = "race" DATASET_NAME = "high" From 9484eecc419af07c23359825eb6ffef5b7ba70b4 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Mon, 25 Apr 2022 15:58:28 -0400 Subject: [PATCH 03/40] Fix coqa --- lm_eval/base.py | 8 +++++--- lm_eval/evaluator.py | 39 ++++++++++++++++++++----------------- lm_eval/tasks/arithmetic.py | 5 +++-- lm_eval/tasks/coqa.py | 2 +- 4 files changed, 30 insertions(+), 24 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index c0207028a9..af98157172 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -1,6 +1,7 @@ import abc from typing import Iterable +import promptsource import numpy as np import random import re @@ -639,11 +640,12 @@ def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=Non self.prompt = prompt def doc_to_target(self, doc): - _, target = prompt.apply(doc) + _, target = self.prompt.apply(doc) return f" {target}" def doc_to_text(self, doc): - text, _ = prompt.apply(doc) + print(doc) + text, _ = self.prompt.apply(doc) return text def construct_requests(self, doc, ctx): @@ -660,7 +662,7 @@ def construct_requests(self, doc, ctx): _requests = [] if self.prompt.metadata.choices_in_prompt: - for answer_choice in prompt.get_fixed_answer_choices_list(): + for answer_choice in self.prompt.get_fixed_answer_choices_list(): ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}") _requests.append(ll_answer_choice) else: diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 66054cb210..57cf1a9d5b 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -169,8 +169,10 @@ def evaluate( docs = {} # get lists of each type of request - for task_name, task in task_dict_items: - versions[task_name] = task.VERSION + for task_prompt_name, task in task_dict_items: + print(f"TASK PROMPT NAME: {task_prompt_name}") + + versions[task_prompt_name] = task.VERSION # default to test doc, fall back to val doc if validation unavailable # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point if task.has_test_docs(): @@ -187,13 +189,13 @@ def evaluate( rnd.shuffle(task_docs) description = ( - description_dict[task_name] - if description_dict and task_name in description_dict + description_dict[task_prompt_name] + if description_dict and task_prompt_name in description_dict else "" ) for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)): - docs[(task_name, doc_id)] = doc + docs[(task_prompt_name, doc_id)] = doc ctx = task.fewshot_context( doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description ) @@ -204,7 +206,7 @@ def evaluate( requests[req.request_type].append(req) # i: index in requests for a single task instance # doc_id: unique id that we can get back to a doc using `docs` - requests_origin[req.request_type].append((i, task_name, doc, doc_id)) + requests_origin[req.request_type].append((i, task_prompt_name, doc, doc_id)) # all responses for each (task, doc) process_res_queue = collections.defaultdict(list) @@ -222,32 +224,33 @@ def evaluate( x if req.index is None else x[req.index] for x, req in zip(resps, reqs) ] - for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): - process_res_queue[(task_name, doc_id)].append((i, resp)) + for resp, (i, task_prompt_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): + process_res_queue[(task_prompt_name, doc_id)].append((i, resp)) vals = collections.defaultdict(list) # unpack results and sort back in order and return control to Task - for (task_name, doc_id), requests in process_res_queue.items(): + for (task_prompt_name, doc_id), requests in process_res_queue.items(): requests.sort(key=lambda x: x[0]) requests = [x[1] for x in requests] - task = task_dict[task_name] - doc = docs[(task_name, doc_id)] + task = task_dict[task_prompt_name] + doc = docs[(task_prompt_name, doc_id)] metrics = task.process_results(doc, requests) for metric, value in metrics.items(): - vals[(task_name, metric)].append(value) + vals[(task_prompt_name, metric)].append(value) + - task_name, prompt_name = task_name.split("+") - results[task_name]["task_name"] = task_name - results[task_name]["prompt_name"] = prompt_name # aggregate results - for (task_name, metric), items in vals.items(): + for (task_prompt_name, metric), items in vals.items(): + task_name, prompt_name = task_prompt_name.split("+") + results[task_prompt_name]["task_name"] = task_name + results[task_prompt_name]["prompt_name"] = prompt_name task = task_dict[task_name] - results[task_name][metric] = task.aggregation()[metric](items) + results[task_prompt_name][metric] = task.aggregation()[metric](items) # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # so we run them less iterations. still looking for a cleaner way to do this @@ -258,7 +261,7 @@ def evaluate( else bootstrap_iters, ) if stderr is not None: - results[task_name][metric + "_stderr"] = stderr(items) + results[task_prompt_name][metric + "_stderr"] = stderr(items) return {"results": dict(results), "versions": dict(versions)} diff --git a/lm_eval/tasks/arithmetic.py b/lm_eval/tasks/arithmetic.py index 8f783feb23..21da60a79e 100644 --- a/lm_eval/tasks/arithmetic.py +++ b/lm_eval/tasks/arithmetic.py @@ -58,10 +58,11 @@ def doc_to_target(self, doc): def construct_requests(self, doc, ctx): ll, is_prediction = rf.loglikelihood(ctx, doc["completion"]) - return is_prediction + return ll, is_prediction def process_results(self, doc, results): - is_prediction, = results + print(results) + results = results return { "acc": is_prediction } diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py index 58152c75ff..3feb0c6f1d 100644 --- a/lm_eval/tasks/coqa.py +++ b/lm_eval/tasks/coqa.py @@ -12,7 +12,7 @@ import inspect import transformers.data.metrics.squad_metrics as squad_metrics import lm_eval.datasets.coqa.coqa -from lm_eval.base import PromptSourceTask, rf, mean +from lm_eval.base import PromptSourceTask, Task, rf, mean from itertools import zip_longest From 9f38846153ddade1844ca5b6bbaf3a43268fbd85 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Mon, 25 Apr 2022 16:26:49 -0400 Subject: [PATCH 04/40] Fix task name to template creation --- lm_eval/tasks/__init__.py | 66 +++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index a68b7cab4f..af8d08957a 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -1,5 +1,4 @@ from promptsource.templates import DatasetTemplates - from pprint import pprint from typing import List, Union @@ -60,8 +59,8 @@ # 6 total gpt3_translation_benchmarks = { - "wmt14": ["en-fr", "fr-en"], # French - "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian + "wmt14": ['en-fr', 'fr-en'], # French + "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian } @@ -69,7 +68,7 @@ selected_translation_benchmarks = { **gpt3_translation_benchmarks, "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"), - "iwslt17": ["en-ar", "ar-en"], # Arabic + "iwslt17": ['en-ar', 'ar-en'] # Arabic } # 319 total @@ -93,7 +92,7 @@ "rte": glue.RTE, "qnli": glue.QNLI, "qqp": glue.QQP, - # "stsb": glue.STSB, # not implemented yet + #"stsb": glue.STSB, # not implemented yet "sst": glue.SST, "wnli": glue.WNLI, # SuperGLUE @@ -104,26 +103,34 @@ "record": superglue.ReCoRD, "wic": superglue.WordsInContext, "wsc": superglue.SGWinogradSchemaChallenge, + # Order by benchmark/genre? "coqa": coqa.CoQA, "drop": drop.DROP, "lambada": lambada.LAMBADA, "lambada_cloze": lambada_cloze.LAMBADA_cloze, + # multilingual lambada **lambada_multilingual.construct_tasks(), + "wikitext": wikitext.WikiText, # "cbt-cn": cbt.CBTCN, # disabled pending context length fix # "cbt-ne": cbt.CBTNE, # disabled pending context length fix + "piqa": piqa.PiQA, "prost": prost.PROST, "mc_taco": mc_taco.MCTACO, + # Science related - "pubmedqa": pubmedqa.Pubmed_QA, - "sciq": sciq.SciQ, + "pubmedqa" : pubmedqa.Pubmed_QA, + "sciq" : sciq.SciQ, + "qasper": qasper.QASPER, - "qa4mre_2011": qa4mre.QA4MRE_2011, - "qa4mre_2012": qa4mre.QA4MRE_2012, - "qa4mre_2013": qa4mre.QA4MRE_2013, + + "qa4mre_2011" : qa4mre.QA4MRE_2011, + "qa4mre_2012" : qa4mre.QA4MRE_2012, + "qa4mre_2013" : qa4mre.QA4MRE_2013, + "triviaqa": triviaqa.TriviaQA, "arc_easy": arc.ARCEasy, "arc_challenge": arc.ARCChallenge, @@ -134,7 +141,7 @@ "squad2": squad.SQuAD2, "race": race.RACE, # "naturalqs": naturalqs.NaturalQs, # not implemented yet - "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es + "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es "headqa_es": headqa.HeadQAEs, "headqa_en": headqa.HeadQAEn, "mathqa": mathqa.MathQA, @@ -144,17 +151,21 @@ "anli_r1": anli.ANLIRound1, "anli_r2": anli.ANLIRound2, "anli_r3": anli.ANLIRound3, + "ethics_cm": hendrycks_ethics.EthicsCM, "ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_justice": hendrycks_ethics.EthicsJustice, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_virtue": hendrycks_ethics.EthicsVirtue, - "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, - "truthfulqa_gen": truthfulqa.TruthfulQAGeneration, + + "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, + "truthfulqa_gen": truthfulqa.TruthfulQAGeneration, + # dialogue "mutual": mutual.MuTual, "mutual_plus": mutual.MuTualPlus, + # math "math_algebra": hendrycks_math.MathAlgebra, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability, @@ -165,6 +176,7 @@ "math_precalc": hendrycks_math.MathPrecalculus, "math_asdiv": asdiv.Asdiv, "gsm8k": gsm8k.GradeSchoolMath8K, + # arithmetic "arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2ds": arithmetic.Arithmetic2DMinus, @@ -178,18 +190,22 @@ "arithmetic_1dc": arithmetic.Arithmetic1DComposite, # TODO Perhaps make these groups of tasks # e.g. anli, arithmetic, openai_translations, harness_translations + # hendrycksTest (57 tasks) **hendrycks_test.create_all_tasks(), + # e.g. wmt14-fr-en **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks), # chef's selection, mostly wmt20 **translation.create_tasks_from_benchmarks(selected_translation_benchmarks), + # Word Scrambling and Manipulation Tasks "anagrams1": unscramble.Anagrams1, "anagrams2": unscramble.Anagrams2, "cycle_letters": unscramble.CycleLetters, "random_insertion": unscramble.RandomInsertion, "reversed_words": unscramble.ReversedWords, + # Pile "pile_arxiv": pile.PileArxiv, "pile_books3": pile.PileBooks3, @@ -213,6 +229,7 @@ "pile_ubuntu-irc": pile.PileUbuntuIrc, "pile_wikipedia": pile.PileWikipedia, "pile_youtubesubtitles": pile.PileYoutubeSubtitles, + # BLiMP "blimp_adjunct_island": blimp.BlimpAdjunctIsland, "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement, @@ -281,6 +298,7 @@ "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, + # Requires manual download of data. # "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2018": storycloze.StoryCloze2018, @@ -304,25 +322,19 @@ def get_task_name_from_object(task_object): for name, class_ in TASK_REGISTRY.items(): if class_ is task_object: return name - + # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting - return ( - task_object.EVAL_HARNESS_NAME - if hasattr(task_object, "EVAL_HARNESS_NAME") - else type(task_object).__name__ - ) + return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__ def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): task_name_dict = { task_name: get_task(task_name)() - for task_name in task_name_list - if isinstance(task_name, str) + for task_name in task_name_list if isinstance(task_name, str) } task_name_from_object_dict = { get_task_name_from_object(task_object): task_object - for task_object in task_name_list - if not isinstance(task_object, str) + for task_object in task_name_list if not isinstance(task_object, str) } assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) return {**task_name_dict, **task_name_from_object_dict} @@ -334,8 +346,14 @@ def get_task_dict_promptsource(task_name_list: List[str]): for task_name in task_name_list: assert isinstance(task_name, str) - task_prompts = DatasetTemplates(task_name) + # Static version of the Task Use this to get HF dataset path / name. + static_task_obj = get_task(task_name) + # Create the proper task name arg for DatasetTemplates. + sub_task = f"/{static_task_obj.DATASET_NAME}" if static_task_obj.DATASET_NAME else "" + ps_task_name = f"{static_task_obj.DATASET_PATH}{sub_task}" + + task_prompts = DatasetTemplates(ps_task_name) for prompt_name in task_prompts.all_template_names: prompt = task_prompts[prompt_name] # NOTE: We choose a sep that can be easily split. From 2bfa451846844599af80b60215bfbe580650fb5d Mon Sep 17 00:00:00 2001 From: jon-tow Date: Mon, 25 Apr 2022 18:15:58 -0400 Subject: [PATCH 05/40] Fix prompt source rank choice accuracy --- lm_eval/base.py | 28 +++++++++-------- lm_eval/evaluator.py | 9 ++---- lm_eval/tasks/coqa.py | 3 +- lm_eval/tasks/glue.py | 67 +++++++++-------------------------------- lm_eval/tasks/race.py | 70 +++++++++++++++++++++---------------------- scripts/write_out.py | 2 +- 6 files changed, 69 insertions(+), 110 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index af98157172..bc030675b7 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -644,7 +644,6 @@ def doc_to_target(self, doc): return f" {target}" def doc_to_text(self, doc): - print(doc) text, _ = self.prompt.apply(doc) return text @@ -661,13 +660,14 @@ def construct_requests(self, doc, ctx): """ _requests = [] - if self.prompt.metadata.choices_in_prompt: - for answer_choice in self.prompt.get_fixed_answer_choices_list(): + answer_choices_list = self.prompt.get_answer_choices_list(doc) + if answer_choices_list: + for answer_choice in answer_choices_list: ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}") _requests.append(ll_answer_choice) else: # TODO(Albert): What is the stop symbol? Is it model specific? - ll_greedy, _ = rf.greedy_until(ctx, ["\nQ:"]) + ll_greedy = rf.greedy_until(ctx, ["\nQ:"]) _requests.append(ll_greedy) return _requests @@ -682,20 +682,22 @@ def process_results(self, doc, results): :param results: The results of the requests created in construct_requests. """ - raise NotImplementedError( - "Implement process results using the `prompt.metadata.metrics`. See below." - ) - if self.prompt.metadata.choices_in_prompt: - for result, answer_choice in zip( - prompt.get_fixed_answer_choices_list(), results - ): - pass + # raise NotImplementedError( + # "Implement process results using the `prompt.metadata.metrics`. See below." + # ) + target = self.doc_to_target(doc).strip() + answer_choices_list = self.prompt.get_answer_choices_list(doc) + if answer_choices_list: + pred = answer_choices_list[np.argmax(results)] + return { + "acc": pred == target + } else: continuation = results # Map metric name to HF metric. # TODO(Albert): What is Other? - metric_names = prompt.metadata.metrics + #metric_names = prompt.metadata.metrics class MultipleChoiceTask(Task): diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 57cf1a9d5b..1a9fee8e97 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -241,15 +241,12 @@ def evaluate( for metric, value in metrics.items(): vals[(task_prompt_name, metric)].append(value) - - # aggregate results for (task_prompt_name, metric), items in vals.items(): task_name, prompt_name = task_prompt_name.split("+") results[task_prompt_name]["task_name"] = task_name results[task_prompt_name]["prompt_name"] = prompt_name - task = task_dict[task_name] - + task = task_dict[task_prompt_name] results[task_prompt_name][metric] = task.aggregation()[metric](items) # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap @@ -276,13 +273,13 @@ def make_table(result_dict): latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] values = [] - for k, dic in result_dict["results"].items(): version = result_dict["versions"][k] for m, v in dic.items(): if m.endswith("_stderr"): continue - + if "_name" in m: + continue if m + "_stderr" in dic: se = dic[m + "_stderr"] values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se]) diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py index 3feb0c6f1d..8a6cba72c8 100644 --- a/lm_eval/tasks/coqa.py +++ b/lm_eval/tasks/coqa.py @@ -30,7 +30,7 @@ class CoQA(PromptSourceTask): VERSION = 1 - DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa) + DATASET_PATH = "coqa" DATASET_NAME = None def has_training_docs(self): @@ -57,7 +57,6 @@ def test_docs(self): # answers = [] # answer_forturn = doc["answers"]["input_text"][turn_id - 1] # answers.append(answer_forturn) - # additional_answers = doc.get("additional_answers") # if additional_answers: # for key in additional_answers: diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index 410396d462..fdb49b8d29 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -14,7 +14,7 @@ Homepage: https://gluebenchmark.com/ """ import numpy as np -from lm_eval.base import rf, Task +from lm_eval.base import PromptSourceTask, rf, Task from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno from lm_eval.utils import general_detokenize @@ -286,7 +286,7 @@ def aggregation(self): } -class WNLI(Task): +class WNLI(PromptSourceTask): VERSION = 1 DATASET_PATH = "glue" DATASET_NAME = "wnli" @@ -301,37 +301,14 @@ def has_test_docs(self): return False def training_docs(self): - if self._training_docs is None: - self._training_docs = list(self.dataset["train"]) - return self._training_docs + # if self._training_docs is None: + # self._training_docs = list() + # return self._training_docs + return self.dataset["train"] def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: {} True or False?\nAnswer:".format( - doc["sentence1"], - doc["sentence2"], - ) - - def doc_to_target(self, doc): - # True = entailment - # False = not_entailment - return " {}".format({0: "False", 1: "True"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, " True") - ll_false, _ = rf.loglikelihood(ctx, " False") - return ll_true, ll_false - - def process_results(self, doc, results): - ll_true, ll_false = results - pred = ll_true > ll_false - gold = doc["label"] - return { - "acc": pred == gold - } - def higher_is_better(self): return { "acc": True @@ -343,7 +320,7 @@ def aggregation(self): } -class RTE(Task): +class RTE(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "rte" @@ -365,29 +342,13 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: {} True or False?\nAnswer:".format( - doc["sentence1"], - doc["sentence2"], - ) - - def doc_to_target(self, doc): - # 0 = entailment - # 1 = not_entailment - return " {}".format({0: "True", 1: "False"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, " True") - ll_false, _ = rf.loglikelihood(ctx, " False") - return ll_true, ll_false - - def process_results(self, doc, results): - ll_true, ll_false = results - pred = ll_false > ll_true - gold = doc["label"] - return { - "acc": pred == gold - } + # def process_results(self, doc, results): + # ll_true, ll_false = results + # pred = ll_false > ll_true + # gold = doc["label"] + # return { + # "acc": pred == gold + # } def higher_is_better(self): return { diff --git a/lm_eval/tasks/race.py b/lm_eval/tasks/race.py index 2dc23fa1b4..3645f357ab 100644 --- a/lm_eval/tasks/race.py +++ b/lm_eval/tasks/race.py @@ -51,47 +51,47 @@ def has_validation_docs(self): def has_test_docs(self): return True - def _collate_data(self, set): - if set in self.cache: - return self.cache[set] - # One big issue with HF's implementation of this dataset: it makes a - # separate document for each question; meanwhile, in the GPT3 paper it - # is shown that one document is made per passage. - - r = collections.defaultdict(list) - for item in datasets.load_dataset( - path=self.DATASET_PATH, name=self.DATASET_NAME - )[set]: - r[item["article"]].append(item) - - res = list( - r.values() - >> each( - lambda x: { - "article": x[0]["article"], - "problems": x - >> each( - lambda y: { - "question": y["question"], - "answer": y["answer"], - "options": y["options"], - } - ), - } - ) - ) - - self.cache[set] = res - return res + # def _collate_data(self, set): + # if set in self.cache: + # return self.cache[set] + # # One big issue with HF's implementation of this dataset: it makes a + # # separate document for each question; meanwhile, in the GPT3 paper it + # # is shown that one document is made per passage. + + # r = collections.defaultdict(list) + # for item in datasets.load_dataset( + # path=self.DATASET_PATH, name=self.DATASET_NAME + # )[set]: + # r[item["article"]].append(item) + + # res = list( + # r.values() + # >> each( + # lambda x: { + # "article": x[0]["article"], + # "problems": x + # >> each( + # lambda y: { + # "question": y["question"], + # "answer": y["answer"], + # "options": y["options"], + # } + # ), + # } + # ) + # ) + + # self.cache[set] = res + # return res def training_docs(self): - return self._collate_data("train") + return self.dataset["train"] def validation_docs(self): - return self._collate_data("validation") + return self.dataset["validation"] def test_docs(self): - return self._collate_data("test") + return self.dataset["test"] @classmethod def get_answer_option(cls, problem): diff --git a/scripts/write_out.py b/scripts/write_out.py index 2039d3934f..0f1ee354e8 100644 --- a/scripts/write_out.py +++ b/scripts/write_out.py @@ -30,7 +30,7 @@ def main(): task_names = tasks.ALL_TASKS else: task_names = args.tasks.split(",") - task_dict = tasks.get_task_dict(task_names) + task_dict = tasks.get_task_dict_promptsource(task_names) description_dict = {} if args.description_dict_path: From 34f591afde1e30632e38604887ce9df89b149c92 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Mon, 25 Apr 2022 19:59:39 -0400 Subject: [PATCH 06/40] Add multiple tasks --- lm_eval/tasks/__init__.py | 2 + lm_eval/tasks/anli.py | 49 +--------- lm_eval/tasks/coqa.py | 34 ++++--- lm_eval/tasks/drop.py | 87 ++++++----------- lm_eval/tasks/glue.py | 110 +++------------------- lm_eval/tasks/superglue.py | 186 ++----------------------------------- 6 files changed, 81 insertions(+), 387 deletions(-) diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index af8d08957a..d07886b7df 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -52,6 +52,7 @@ from . import asdiv from . import gsm8k from . import storycloze +from . import e2e_nlg_cleaned ######################################## # Translation tasks @@ -124,6 +125,7 @@ # Science related "pubmedqa" : pubmedqa.Pubmed_QA, "sciq" : sciq.SciQ, + "e2e_nlg_cleaned": e2e_nlg_cleaned.E2E_NLG_Cleaned, "qasper": qasper.QASPER, diff --git a/lm_eval/tasks/anli.py b/lm_eval/tasks/anli.py index 2f7a763b92..59eee05f9c 100644 --- a/lm_eval/tasks/anli.py +++ b/lm_eval/tasks/anli.py @@ -10,7 +10,7 @@ Homepage: "https://github.com/facebookresearch/anli" """ import numpy as np -from lm_eval.base import rf, Task +from lm_eval.base import rf, PromptSourceTask from lm_eval.metrics import mean @@ -30,7 +30,7 @@ """ -class ANLIBase(Task): +class ANLIBase(PromptSourceTask): VERSION = 0 DATASET_PATH = "anli" DATASET_NAME = None @@ -59,51 +59,6 @@ def test_docs(self): if self.has_test_docs(): return self.dataset["test_r" + str(self.SPLIT)] - def doc_to_text(self, doc): - # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning - # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly - # appended onto the question, with no "Answer:" or even a newline. Do we *really* - # want to do it exactly as OA did? - return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:' - - def doc_to_target(self, doc): - # True = entailment - # False = contradiction - # Neither = neutral - return " " + ["True", "Neither", "False"][doc['label']] - - def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - ll_true, _ = rf.loglikelihood(ctx, " True") - ll_neither, _ = rf.loglikelihood(ctx, " Neither") - ll_false, _ = rf.loglikelihood(ctx, " False") - return ll_true, ll_neither, ll_false - - def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of - the metric for that one document - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param results: - The results of the requests created in construct_requests. - """ - gold = doc["label"] - pred = np.argmax(results) - return { - "acc": pred == gold - } - def aggregation(self): """ :returns: {str: [float] -> float} diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py index 8a6cba72c8..ee4fa31e50 100644 --- a/lm_eval/tasks/coqa.py +++ b/lm_eval/tasks/coqa.py @@ -67,6 +67,7 @@ def test_docs(self): # answers.append(additional_answer_for_turn) # return answers + @staticmethod def compute_scores(gold_list, pred): # tests for exact match and on the normalised answer (compute_exact) @@ -90,19 +91,21 @@ def compute_scores(gold_list, pred): "f1": f1_sum / max(1, len(gold_list)), } - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. + def eos_token(self): + return "\n" - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - cont_request = rf.greedy_until(ctx, ["\nQ:"]) - return cont_request + # def construct_requests(self, doc, ctx): + # """Uses RequestFactory to construct Requests and returns an iterable of + # Requests which will be sent to the LM. + + # :param doc: + # The document as returned from training_docs, validation_docs, or test_docs. + # :param ctx: str + # The context string, generated by fewshot_context. This includes the natural + # language description, as well as the few shot examples, and the question + # part of the document for `doc`. + # """ + # return cont_request def process_results(self, doc, results): """Take a single document and the LM results and evaluates, returning a @@ -116,6 +119,13 @@ def process_results(self, doc, results): """ target = self.doc_to_target(doc).strip() pred = results[0].strip().split("\n")[0] + print("*" * 80) + print(f"DOC: {doc}") +# print(f"PS: {self.prompt.apply(doc)}") + print(f"TEXT: {self.doc_to_text(doc)}") + print(f"TARGET: {target} END TARGET") + print(pred) + print("*" * 80) # turn_id = len(doc["questions"]["input_text"]) # gold_list = self.get_answers(doc, turn_id) diff --git a/lm_eval/tasks/drop.py b/lm_eval/tasks/drop.py index a075d2c68e..6e8ce20740 100644 --- a/lm_eval/tasks/drop.py +++ b/lm_eval/tasks/drop.py @@ -39,7 +39,7 @@ class DROP(PromptSourceTask): VERSION = 1 - DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop) + DATASET_PATH = "drop" # inspect.getfile(lm_eval.datasets.drop.drop) DATASET_NAME = None def has_training_docs(self): @@ -52,51 +52,13 @@ def has_test_docs(self): return False def training_docs(self): - if self._training_docs is None: - self._training_docs = list(map(self._process_doc, self.dataset["train"])) - return self._training_docs + # if self._training_docs is None: + # self._training_docs = list() + # return self._training_docs + return self.dataset["train"] def validation_docs(self): - return map(self._process_doc, self.dataset["validation"]) - - def _process_doc(self, doc): - return { - "id": doc["query_id"], - "passage": doc["passage"], - "question": doc["question"], - "answers": self.get_answers(doc), - } - - @classmethod - def get_answers(cls, qa): - def _flatten_validated_answers(validated_answers): - """Flattens a dict of lists of validated answers. - {"number": ['1', '8'], ...} - -> [{"number": ['1'], ...}, {"number": ['8'], ...}] - """ - vas = [] - for i in range(len(validated_answers["number"])): - vas.append( - { - "number": validated_answers["number"][i], - "date": validated_answers["date"][i], - "spans": validated_answers["spans"][i], - } - ) - return vas - - answers = [] - answers_set = set() - candidates = [qa["answer"]] + _flatten_validated_answers( - qa["validated_answers"] - ) - for candidate in candidates: - answer = cls.parse_answer(candidate) - if answer in answers_set: - continue - answers_set.add(answer) - answers.append(answer) - return answers + return self.dataset["validation"] @classmethod def parse_answer(cls, answer): @@ -117,19 +79,21 @@ def parse_answer(cls, answer): # def doc_to_target(self, doc): # return " " + ", ".join(doc["answers"][0]) - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - conts = [rf.greedy_until(ctx, ["."])] - return conts + # def construct_requests(self, doc, ctx): + # """Uses RequestFactory to construct Requests and returns an iterable of + # Requests which will be sent to the LM. + + # :param doc: + # The document as returned from training_docs, validation_docs, or test_docs. + # :param ctx: str + # The context string, generated by fewshot_context. This includes the natural + # language description, as well as the few shot examples, and the question + # part of the document for `doc`. + # """ + # conts = [rf.greedy_until(ctx, ["."])] + # return conts + def eos_token(self): + return "." def process_results(self, doc, results): """Take a single document and the LM results and evaluates, returning a @@ -145,6 +109,15 @@ def process_results(self, doc, results): pred = results[0].strip() target = self.doc_to_target(doc).strip() + print("*" * 80) + print(f"DOC: {doc}") + print(f"PS: {self.prompt.apply(doc)}") + print(f"TEXT: {self.doc_to_text(doc)}") + print(f"TARGET: {target} END TARGET") + print(pred) + print("*" * 80) + + preds = [pred] golds = [target] diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index fdb49b8d29..f640ae1759 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -45,7 +45,7 @@ # Single-Sentence Tasks -class CoLA(Task): +class CoLA(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "cola" @@ -67,23 +67,20 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"]) - - def doc_to_target(self, doc): - return " {}".format({1: "yes", 0: "no"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, " yes") - ll_false, _ = rf.loglikelihood(ctx, " no") - return ll_true, ll_false + def process_results(self, doc, results): + answer_choices_list = self.prompt.get_answer_choices_list(doc) + pred = np.argmax(results) + target = answer_choices_list.index(self.doc_to_target(doc).strip()) + print("*" * 80) + print(f"DOC: {doc}") + print(f"TEXT: {self.doc_to_text(doc)}") + print(f"STRING TARGET: {self.doc_to_target(doc)} END TARGET") + print(f"TARGET: {target} END TARGET") + print(f"PRED: {pred}") + print("*" * 80) - def process_results(self, doc, results): - ll_true, ll_false = results - pred = ll_true > ll_false - gold = doc["label"] return { - "mcc": (gold, pred) + "mcc": (target, pred) } def higher_is_better(self): @@ -97,7 +94,7 @@ def aggregation(self): } -class SST(Task): +class SST(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "sst2" @@ -119,27 +116,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format( - general_detokenize(doc["sentence"]), - ) - - def doc_to_target(self, doc): - return " {}".format({1: "positive", 0: "negative"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_positive, _ = rf.loglikelihood(ctx, " positive") - ll_negative, _ = rf.loglikelihood(ctx, " negative") - return ll_positive, ll_negative - - def process_results(self, doc, results): - ll_positive, ll_negative = results - pred = ll_positive > ll_negative - gold = doc["label"] - return { - "acc": pred == gold - } - def higher_is_better(self): return { "acc": True @@ -154,7 +130,7 @@ def aggregation(self): # Inference Tasks -class MNLI(Task): +class MNLI(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "mnli" @@ -181,24 +157,6 @@ def test_docs(self): if self.has_test_docs(): return self.dataset["test_matched"] - def doc_to_text(self, doc): - return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format( - doc["premise"], - doc["hypothesis"].strip() + ('' if doc["hypothesis"].strip().endswith('.') else '.'), - ) - - def doc_to_target(self, doc): - # True = entailment - # False = contradiction - # Neither = neutral - return " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, " True") - ll_neither, _ = rf.loglikelihood(ctx, " Neither") - ll_false, _ = rf.loglikelihood(ctx, " False") - return ll_true, ll_neither, ll_false - def process_results(self, doc, results): gold = doc["label"] pred = np.argmax(results) @@ -251,22 +209,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format( - doc["question"], - doc["sentence"], - ) - - def doc_to_target(self, doc): - # True = entailment - # False = not entailment - return " {}".format({0: "yes", 1: "no"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_yes, _ = rf.loglikelihood(ctx, " yes") - ll_no, _ = rf.loglikelihood(ctx, " no") - return ll_yes, ll_no - def process_results(self, doc, results): ll_yes, ll_no = results pred = ll_no > ll_yes @@ -342,14 +284,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - # def process_results(self, doc, results): - # ll_true, ll_false = results - # pred = ll_false > ll_true - # gold = doc["label"] - # return { - # "acc": pred == gold - # } - def higher_is_better(self): return { "acc": True @@ -386,20 +320,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format( - general_detokenize(doc["sentence1"]), - general_detokenize(doc["sentence2"]), - ) - - def doc_to_target(self, doc): - return " {}".format(yesno(doc["label"])) - - def construct_requests(self, doc, ctx): - ll_yes, _ = rf.loglikelihood(ctx, " yes") - ll_no, _ = rf.loglikelihood(ctx, " no") - return ll_yes, ll_no - def process_results(self, doc, results): ll_yes, ll_no = results gold = doc["label"] diff --git a/lm_eval/tasks/superglue.py b/lm_eval/tasks/superglue.py index e4b9bfff6a..1ea6edde6b 100644 --- a/lm_eval/tasks/superglue.py +++ b/lm_eval/tasks/superglue.py @@ -12,7 +12,7 @@ import numpy as np import sklearn import transformers.data.metrics.squad_metrics as squad_metrics -from lm_eval.base import rf, Task +from lm_eval.base import rf, PromptSourceTask from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno from lm_eval.utils import general_detokenize @@ -32,7 +32,7 @@ """ -class BoolQ(Task): +class BoolQ(PromptSourceTask): VERSION = 1 DATASET_PATH = "super_glue" DATASET_NAME = "boolq" @@ -54,29 +54,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:" - - def doc_to_target(self, doc): - return " " + yesno(doc['label']) - - def construct_requests(self, doc, ctx): - - ll_yes, _ = rf.loglikelihood(ctx, ' yes') - ll_no, _ = rf.loglikelihood(ctx, ' no') - - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - - acc = 1. if (ll_yes > ll_no) == gold else 0. - - return { - "acc": acc - } - def higher_is_better(self): return { "acc": True @@ -88,7 +65,7 @@ def aggregation(self): } -class CommitmentBank(Task): +class CommitmentBank(PromptSourceTask): VERSION = 1 DATASET_PATH = "super_glue" DATASET_NAME = "cb" @@ -110,25 +87,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format( - doc["premise"], - doc["hypothesis"], - ) - - def doc_to_target(self, doc): - # True = entailment - # False = contradiction - # Neither = neutral - return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_true, _ = rf.loglikelihood(ctx, ' True') - ll_false, _ = rf.loglikelihood(ctx, ' False') - ll_neither, _ = rf.loglikelihood(ctx, ' Neither') - - return ll_true, ll_false, ll_neither - def process_results(self, doc, results): gold = doc["label"] pred = np.argmax(results) @@ -163,7 +121,7 @@ def aggregation(self): } -class Copa(Task): +class Copa(PromptSourceTask): VERSION = 0 DATASET_PATH = "super_glue" DATASET_NAME = "copa" @@ -185,28 +143,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - # Drop the period - connector = { - "cause": "because", - "effect": "therefore", - }[doc["question"]] - return doc["premise"].strip()[:-1] + f" {connector}" - - def doc_to_target(self, doc): - correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"] - # Connect the sentences - return " " + self.convert_choice(correct_choice) - - def construct_requests(self, doc, ctx): - choice1 = " " + self.convert_choice(doc["choice1"]) - choice2 = " " + self.convert_choice(doc["choice2"]) - - ll_choice1, _ = rf.loglikelihood(ctx, choice1) - ll_choice2, _ = rf.loglikelihood(ctx, choice2) - - return ll_choice1, ll_choice2 - def process_results(self, doc, results): gold = doc["label"] pred = np.argmax(results) @@ -231,7 +167,7 @@ def convert_choice(choice): return choice[0].lower() + choice[1:] -class MultiRC(Task): +class MultiRC(PromptSourceTask): VERSION = 1 DATASET_PATH = "super_glue" DATASET_NAME = "multirc" @@ -253,26 +189,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:" - - def doc_to_target(self, doc): - return " " + self.format_answer(answer=doc["answer"], label=doc["label"]) - - @staticmethod - def format_answer(answer, label): - label_str = "yes" if label else "no" - return f"{answer}\nIs the answer correct? {label_str}" - - def construct_requests(self, doc, ctx): - true_choice = self.format_answer(answer=doc["answer"], label=True) - false_choice = self.format_answer(answer=doc["answer"], label=False) - - ll_true_choice, _ = rf.loglikelihood(ctx, f' {true_choice}') - ll_false_choice, _ = rf.loglikelihood(ctx, f' {false_choice}') - - return ll_true_choice, ll_false_choice - def process_results(self, doc, results): ll_true_choice, ll_false_choice = results pred = ll_true_choice > ll_false_choice @@ -291,7 +207,7 @@ def aggregation(self): } -class ReCoRD(Task): +class ReCoRD(PromptSourceTask): VERSION = 0 DATASET_PATH = "super_glue" DATASET_NAME = "record" @@ -328,33 +244,13 @@ def _process_doc(cls, doc): "answers": sorted(list(set(doc["answers"]))), } - def doc_to_text(self, doc): - initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n") - text = initial_text + "\n\n" - for highlight in highlights: - text += f" - {highlight}.\n" - return text - - @classmethod - def format_answer(cls, query, entity): - return f' - {query}'.replace("@placeholder", entity) - - def doc_to_target(self, doc): - # We only output the first correct entity in a doc - return self.format_answer(query=doc["query"], entity=doc["answers"][0]) - - def construct_requests(self, doc, ctx): - requests = [ - rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity)) - for entity in doc["entities"] - ] - return requests - def process_results(self, doc, results): # ReCoRD's evaluation is actually deceptively simple: # - Pick the maximum likelihood prediction entity # - Evaluate the accuracy and token F1 PER EXAMPLE # - Average over all examples + + # TODO (jon-tow): Look at result max_idx = np.argmax(np.array([result[0] for result in results])) prediction = doc["entities"][max_idx] @@ -380,7 +276,7 @@ def aggregation(self): } -class WordsInContext(Task): +class WordsInContext(PromptSourceTask): VERSION = 0 DATASET_PATH = "super_glue" DATASET_NAME = "wic" @@ -402,33 +298,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \ - " two sentences above?\nAnswer:".format( - doc["sentence1"], - doc["sentence2"], - doc["sentence1"][doc["start1"]:doc["end1"]], - ) - - def doc_to_target(self, doc): - return " {}".format({0: "no", 1: "yes"}[doc["label"]]) - - def construct_requests(self, doc, ctx): - ll_yes, _ = rf.loglikelihood(ctx, ' yes') - ll_no, _ = rf.loglikelihood(ctx, ' no') - - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - - acc = 1. if (ll_yes > ll_no) == gold else 0. - - return { - "acc": acc - } - def higher_is_better(self): return { "acc": True @@ -440,7 +309,7 @@ def aggregation(self): } -class SGWinogradSchemaChallenge(Task): +class SGWinogradSchemaChallenge(PromptSourceTask): VERSION = 0 # Note: This implementation differs from Fig G.32 because this is the SuperGLUE, # binary version of the task. @@ -470,41 +339,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - raw_passage = doc["text"] - # NOTE: HuggingFace span indices are word-based not character-based. - pre = " ".join(raw_passage.split()[:doc["span2_index"]]) - post = raw_passage[len(pre) + len(doc["span2_text"]) + 1:] - passage = general_detokenize(pre + " *{}*".format(doc['span2_text']) + post) - noun = doc["span1_text"] - pronoun = doc["span2_text"] - text = ( - f"Passage: {passage}\n" - + f"Question: In the passage above, does the pronoun \"*{pronoun}*\" refer to \"*{noun}*\"?\n" - + "Answer:" - ) - return text - - def doc_to_target(self, doc): - return " " + yesno(doc['label']) - - def construct_requests(self, doc, ctx): - - ll_yes, _ = rf.loglikelihood(ctx, ' yes') - ll_no, _ = rf.loglikelihood(ctx, ' no') - - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - - acc = 1. if (ll_yes > ll_no) == gold else 0. - - return { - "acc": acc - } - def higher_is_better(self): return { "acc": True From 6ec93da2497bda2402b35960b977f89943d1cb8f Mon Sep 17 00:00:00 2001 From: jon-tow Date: Mon, 25 Apr 2022 20:01:05 -0400 Subject: [PATCH 07/40] Add `eos_token` property --- lm_eval/base.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index bc030675b7..9a7855c729 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -348,7 +348,8 @@ def _collate(x): if isinstance(until, str): until = [until] - (primary_until,) = self.tok_encode(until[0]) + # TODO: Come back to for generation `eos`. + primary_until = self.tok_encode(until[0])[0] context_enc = torch.tensor( [self.tok_encode(context)[self.max_gen_toks - self.max_length :]] @@ -616,7 +617,6 @@ def fewshot_context( ) fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) - # get rid of the doc that's the one we're evaluating, if it's in the fewshot fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] @@ -639,6 +639,9 @@ def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=Non super().__init__(data_dir, cache_dir, download_mode) self.prompt = prompt + def eos_token(self): + raise NotImplementedError() + def doc_to_target(self, doc): _, target = self.prompt.apply(doc) return f" {target}" @@ -659,7 +662,6 @@ def construct_requests(self, doc, ctx): part of the document for `doc`. """ _requests = [] - answer_choices_list = self.prompt.get_answer_choices_list(doc) if answer_choices_list: for answer_choice in answer_choices_list: @@ -667,8 +669,8 @@ def construct_requests(self, doc, ctx): _requests.append(ll_answer_choice) else: # TODO(Albert): What is the stop symbol? Is it model specific? - ll_greedy = rf.greedy_until(ctx, ["\nQ:"]) - _requests.append(ll_greedy) + cont_request = rf.greedy_until(ctx, [self.eos_token()]) + _requests.append(cont_request) return _requests @@ -694,6 +696,7 @@ def process_results(self, doc, results): } else: continuation = results + raise NotImplementedError() # Map metric name to HF metric. # TODO(Albert): What is Other? From 4ae2ab3747d2037d5ea87b9e180a2a297aa0cd7c Mon Sep 17 00:00:00 2001 From: jon-tow Date: Mon, 25 Apr 2022 20:16:01 -0400 Subject: [PATCH 08/40] Add `higher_is_better` & `aggregation` defaults to `PromptSourceTask` --- lm_eval/base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/lm_eval/base.py b/lm_eval/base.py index 9a7855c729..3388743d7c 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -701,6 +701,16 @@ def process_results(self, doc, results): # Map metric name to HF metric. # TODO(Albert): What is Other? #metric_names = prompt.metadata.metrics + + def higher_is_better(self): + return { + "acc": True + } + + def aggregation(self): + return { + "acc": mean, + } class MultipleChoiceTask(Task): From c93093b63430ad9b75391504a68118d0c90dd437 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 20:23:56 -0400 Subject: [PATCH 09/40] Removed the default option for an acc task --- lm_eval/tasks/superglue.py | 88 ++++++++++++-------------------------- 1 file changed, 27 insertions(+), 61 deletions(-) diff --git a/lm_eval/tasks/superglue.py b/lm_eval/tasks/superglue.py index 1ea6edde6b..5f4a51a48d 100644 --- a/lm_eval/tasks/superglue.py +++ b/lm_eval/tasks/superglue.py @@ -54,16 +54,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def higher_is_better(self): - return { - "acc": True - } - - def aggregation(self): - return { - "acc": mean - } - class CommitmentBank(PromptSourceTask): VERSION = 1 @@ -90,18 +80,12 @@ def validation_docs(self): def process_results(self, doc, results): gold = doc["label"] pred = np.argmax(results) - acc = 1. if pred == gold else 0. + acc = 1.0 if pred == gold else 0.0 + + return {"acc": acc, "f1": (pred, gold)} - return { - "acc": acc, - "f1": (pred, gold) - } - def higher_is_better(self): - return { - "acc": True, - "f1": True - } + return {"acc": True, "f1": True} @classmethod def cb_multi_fi(cls, items): @@ -113,7 +97,7 @@ def cb_multi_fi(cls, items): f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2) avg_f1 = mean([f11, f12, f13]) return avg_f1 - + def aggregation(self): return { "acc": mean, @@ -146,21 +130,15 @@ def validation_docs(self): def process_results(self, doc, results): gold = doc["label"] pred = np.argmax(results) - acc = 1. if pred == gold else 0. + acc = 1.0 if pred == gold else 0.0 + + return {"acc": acc} - return { - "acc": acc - } - def higher_is_better(self): - return { - "acc": True - } - + return {"acc": True} + def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} @staticmethod def convert_choice(choice): @@ -192,19 +170,13 @@ def validation_docs(self): def process_results(self, doc, results): ll_true_choice, ll_false_choice = results pred = ll_true_choice > ll_false_choice - return { - "acc": (pred, doc) - } - + return {"acc": (pred, doc)} + def higher_is_better(self): - return { - "acc": True - } - + return {"acc": True} + def aggregation(self): - return { - "acc": acc_all - } + return {"acc": acc_all} class ReCoRD(PromptSourceTask): @@ -255,8 +227,12 @@ def process_results(self, doc, results): prediction = doc["entities"][max_idx] gold_label_set = doc["answers"] - f1 = metric_max_over_ground_truths(squad_metrics.compute_f1, prediction, gold_label_set) - em = metric_max_over_ground_truths(squad_metrics.compute_exact, prediction, gold_label_set) + f1 = metric_max_over_ground_truths( + squad_metrics.compute_f1, prediction, gold_label_set + ) + em = metric_max_over_ground_truths( + squad_metrics.compute_exact, prediction, gold_label_set + ) return { "f1": f1, @@ -299,14 +275,10 @@ def validation_docs(self): return self.dataset["validation"] def higher_is_better(self): - return { - "acc": True - } + return {"acc": True} def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} class SGWinogradSchemaChallenge(PromptSourceTask): @@ -330,9 +302,7 @@ def training_docs(self): if self._training_docs is None: # GPT-3 Paper's format only uses positive examples for fewshot "training" self._training_docs = [ - doc for doc in - self.dataset["train"] - if doc["label"] + doc for doc in self.dataset["train"] if doc["label"] ] return self._training_docs @@ -340,11 +310,7 @@ def validation_docs(self): return self.dataset["validation"] def higher_is_better(self): - return { - "acc": True - } + return {"acc": True} def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} From ec2eb11736a364bf395d112241b0368b98091e2f Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 21:03:17 -0400 Subject: [PATCH 10/40] Add promptsource to setup.py --- setup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 692d090872..00a9be64c5 100644 --- a/setup.py +++ b/setup.py @@ -18,8 +18,9 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - python_requires='>=3.6', + python_requires=">=3.6", install_requires=[ + "promptsource", "black", "datasets==2.0.0", "click>=7.1", @@ -42,9 +43,9 @@ "openai==0.6.4", "jieba==0.42.1", "nagisa==0.2.7", - "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt" + "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt", ], dependency_links=[ "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt", - ] + ], ) From 31a019c2d60750f8c149ea1e9dcf357635fec7d0 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 22:02:24 -0400 Subject: [PATCH 11/40] Temp. commenting our e2e --- lm_eval/tasks/__init__.py | 66 +++++++++++++++++---------------------- setup.py | 1 + 2 files changed, 29 insertions(+), 38 deletions(-) diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index d07886b7df..5d0b646432 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -52,7 +52,8 @@ from . import asdiv from . import gsm8k from . import storycloze -from . import e2e_nlg_cleaned + +# from . import e2e_nlg_cleaned ######################################## # Translation tasks @@ -60,8 +61,8 @@ # 6 total gpt3_translation_benchmarks = { - "wmt14": ['en-fr', 'fr-en'], # French - "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian + "wmt14": ["en-fr", "fr-en"], # French + "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian } @@ -69,7 +70,7 @@ selected_translation_benchmarks = { **gpt3_translation_benchmarks, "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"), - "iwslt17": ['en-ar', 'ar-en'] # Arabic + "iwslt17": ["en-ar", "ar-en"], # Arabic } # 319 total @@ -93,7 +94,7 @@ "rte": glue.RTE, "qnli": glue.QNLI, "qqp": glue.QQP, - #"stsb": glue.STSB, # not implemented yet + # "stsb": glue.STSB, # not implemented yet "sst": glue.SST, "wnli": glue.WNLI, # SuperGLUE @@ -104,35 +105,27 @@ "record": superglue.ReCoRD, "wic": superglue.WordsInContext, "wsc": superglue.SGWinogradSchemaChallenge, - # Order by benchmark/genre? "coqa": coqa.CoQA, "drop": drop.DROP, "lambada": lambada.LAMBADA, "lambada_cloze": lambada_cloze.LAMBADA_cloze, - # multilingual lambada **lambada_multilingual.construct_tasks(), - "wikitext": wikitext.WikiText, # "cbt-cn": cbt.CBTCN, # disabled pending context length fix # "cbt-ne": cbt.CBTNE, # disabled pending context length fix - "piqa": piqa.PiQA, "prost": prost.PROST, "mc_taco": mc_taco.MCTACO, - # Science related - "pubmedqa" : pubmedqa.Pubmed_QA, - "sciq" : sciq.SciQ, - "e2e_nlg_cleaned": e2e_nlg_cleaned.E2E_NLG_Cleaned, - + "pubmedqa": pubmedqa.Pubmed_QA, + "sciq": sciq.SciQ, + # "e2e_nlg_cleaned": e2e_nlg_cleaned.E2E_NLG_Cleaned, "qasper": qasper.QASPER, - - "qa4mre_2011" : qa4mre.QA4MRE_2011, - "qa4mre_2012" : qa4mre.QA4MRE_2012, - "qa4mre_2013" : qa4mre.QA4MRE_2013, - + "qa4mre_2011": qa4mre.QA4MRE_2011, + "qa4mre_2012": qa4mre.QA4MRE_2012, + "qa4mre_2013": qa4mre.QA4MRE_2013, "triviaqa": triviaqa.TriviaQA, "arc_easy": arc.ARCEasy, "arc_challenge": arc.ARCChallenge, @@ -143,7 +136,7 @@ "squad2": squad.SQuAD2, "race": race.RACE, # "naturalqs": naturalqs.NaturalQs, # not implemented yet - "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es + "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es "headqa_es": headqa.HeadQAEs, "headqa_en": headqa.HeadQAEn, "mathqa": mathqa.MathQA, @@ -153,21 +146,17 @@ "anli_r1": anli.ANLIRound1, "anli_r2": anli.ANLIRound2, "anli_r3": anli.ANLIRound3, - "ethics_cm": hendrycks_ethics.EthicsCM, "ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_justice": hendrycks_ethics.EthicsJustice, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_virtue": hendrycks_ethics.EthicsVirtue, - - "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, - "truthfulqa_gen": truthfulqa.TruthfulQAGeneration, - + "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, + "truthfulqa_gen": truthfulqa.TruthfulQAGeneration, # dialogue "mutual": mutual.MuTual, "mutual_plus": mutual.MuTualPlus, - # math "math_algebra": hendrycks_math.MathAlgebra, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability, @@ -178,7 +167,6 @@ "math_precalc": hendrycks_math.MathPrecalculus, "math_asdiv": asdiv.Asdiv, "gsm8k": gsm8k.GradeSchoolMath8K, - # arithmetic "arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2ds": arithmetic.Arithmetic2DMinus, @@ -192,22 +180,18 @@ "arithmetic_1dc": arithmetic.Arithmetic1DComposite, # TODO Perhaps make these groups of tasks # e.g. anli, arithmetic, openai_translations, harness_translations - # hendrycksTest (57 tasks) **hendrycks_test.create_all_tasks(), - # e.g. wmt14-fr-en **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks), # chef's selection, mostly wmt20 **translation.create_tasks_from_benchmarks(selected_translation_benchmarks), - # Word Scrambling and Manipulation Tasks "anagrams1": unscramble.Anagrams1, "anagrams2": unscramble.Anagrams2, "cycle_letters": unscramble.CycleLetters, "random_insertion": unscramble.RandomInsertion, "reversed_words": unscramble.ReversedWords, - # Pile "pile_arxiv": pile.PileArxiv, "pile_books3": pile.PileBooks3, @@ -231,7 +215,6 @@ "pile_ubuntu-irc": pile.PileUbuntuIrc, "pile_wikipedia": pile.PileWikipedia, "pile_youtubesubtitles": pile.PileYoutubeSubtitles, - # BLiMP "blimp_adjunct_island": blimp.BlimpAdjunctIsland, "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement, @@ -300,7 +283,6 @@ "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, - # Requires manual download of data. # "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2018": storycloze.StoryCloze2018, @@ -324,19 +306,25 @@ def get_task_name_from_object(task_object): for name, class_ in TASK_REGISTRY.items(): if class_ is task_object: return name - + # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting - return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__ + return ( + task_object.EVAL_HARNESS_NAME + if hasattr(task_object, "EVAL_HARNESS_NAME") + else type(task_object).__name__ + ) def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): task_name_dict = { task_name: get_task(task_name)() - for task_name in task_name_list if isinstance(task_name, str) + for task_name in task_name_list + if isinstance(task_name, str) } task_name_from_object_dict = { get_task_name_from_object(task_object): task_object - for task_object in task_name_list if not isinstance(task_object, str) + for task_object in task_name_list + if not isinstance(task_object, str) } assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) return {**task_name_dict, **task_name_from_object_dict} @@ -352,7 +340,9 @@ def get_task_dict_promptsource(task_name_list: List[str]): # Static version of the Task Use this to get HF dataset path / name. static_task_obj = get_task(task_name) # Create the proper task name arg for DatasetTemplates. - sub_task = f"/{static_task_obj.DATASET_NAME}" if static_task_obj.DATASET_NAME else "" + sub_task = ( + f"/{static_task_obj.DATASET_NAME}" if static_task_obj.DATASET_NAME else "" + ) ps_task_name = f"{static_task_obj.DATASET_PATH}{sub_task}" task_prompts = DatasetTemplates(ps_task_name) diff --git a/setup.py b/setup.py index 00a9be64c5..d9e3a87923 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ python_requires=">=3.6", install_requires=[ "promptsource", + "jinja2", "black", "datasets==2.0.0", "click>=7.1", From e49cf8da0912debaa0e5759e3434466d39598fd7 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 22:24:36 -0400 Subject: [PATCH 12/40] SST with PS integration. (It was already done.) --- lm_eval/tasks/glue.py | 108 ++++++++++++------------------------------ 1 file changed, 30 insertions(+), 78 deletions(-) diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index f640ae1759..26008510c5 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -67,7 +67,7 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def process_results(self, doc, results): + def process_results(self, doc, results): answer_choices_list = self.prompt.get_answer_choices_list(doc) pred = np.argmax(results) target = answer_choices_list.index(self.doc_to_target(doc).strip()) @@ -79,19 +79,13 @@ def process_results(self, doc, results): print(f"PRED: {pred}") print("*" * 80) - return { - "mcc": (target, pred) - } + return {"mcc": (target, pred)} def higher_is_better(self): - return { - "mcc": True - } + return {"mcc": True} def aggregation(self): - return { - "mcc": matthews_corrcoef - } + return {"mcc": matthews_corrcoef} class SST(PromptSourceTask): @@ -116,16 +110,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def higher_is_better(self): - return { - "acc": True - } - - def aggregation(self): - return { - "acc": mean - } - # Inference Tasks @@ -160,19 +144,13 @@ def test_docs(self): def process_results(self, doc, results): gold = doc["label"] pred = np.argmax(results) - return { - "acc": pred == gold - } + return {"acc": pred == gold} def higher_is_better(self): - return { - "acc": True - } + return {"acc": True} def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} class MNLIMismatched(MNLI): @@ -213,19 +191,13 @@ def process_results(self, doc, results): ll_yes, ll_no = results pred = ll_no > ll_yes gold = doc["label"] - return { - "acc": pred == gold - } + return {"acc": pred == gold} def higher_is_better(self): - return { - "acc": True - } + return {"acc": True} def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} class WNLI(PromptSourceTask): @@ -252,14 +224,10 @@ def validation_docs(self): return self.dataset["validation"] def higher_is_better(self): - return { - "acc": True - } + return {"acc": True} def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} class RTE(PromptSourceTask): @@ -285,14 +253,10 @@ def validation_docs(self): return self.dataset["validation"] def higher_is_better(self): - return { - "acc": True - } + return {"acc": True} def aggregation(self): - return { - "acc": mean - } + return {"acc": mean} # Similarity and Paraphrase Tasks @@ -330,16 +294,10 @@ def process_results(self, doc, results): } def higher_is_better(self): - return { - "acc": True, - "f1": True - } + return {"acc": True, "f1": True} def aggregation(self): - return { - "acc": mean, - "f1": f1_score - } + return {"acc": mean, "f1": f1_score} class QQP(Task): @@ -388,16 +346,10 @@ def process_results(self, doc, results): } def higher_is_better(self): - return { - "acc": True, - "f1": True - } + return {"acc": True, "f1": True} def aggregation(self): - return { - "acc": mean, - "f1": f1_score - } + return {"acc": mean, "f1": f1_score} class STSB(Task): @@ -435,22 +387,22 @@ def doc_to_target(self, doc): return " {}".format(doc["label"]) def construct_requests(self, doc, ctx): - """ Uses RequestFactory to construct Requests and returns an iterable of + """Uses RequestFactory to construct Requests and returns an iterable of Requests which will be sent to the LM. :param doc: The document as returned from training_docs, validation_docs, or test_docs. :param ctx: str - The context string, generated by fewshot_context. This includes the natural + The context string, generated by fewshot_context. This includes the natural language description, as well as the few shot examples, and the question - part of the document for `doc`. + part of the document for `doc`. """ # TODO: implement evaluation. - raise NotImplementedError('Evaluation not implemented') - + raise NotImplementedError("Evaluation not implemented") + def process_results(self, doc, results): - """Take a single document and the LM results and evaluates, returning a - dict where keys are the names of submetrics and values are the values of + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of the metric for that one document :param doc: @@ -459,22 +411,22 @@ def process_results(self, doc, results): The results of the requests created in construct_requests. """ # TODO: implement evaluation. - raise NotImplementedError('Evaluation not implemented') + raise NotImplementedError("Evaluation not implemented") def aggregation(self): """ :returns: {str: [float] -> float} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are functions that aggregate a list of metrics """ # TODO: implement evaluation. - raise NotImplementedError('Evaluation not implemented') + raise NotImplementedError("Evaluation not implemented") def higher_is_better(self): """ :returns: {str: bool} - A dictionary where keys are the names of submetrics and values are + A dictionary where keys are the names of submetrics and values are whether a higher value of the submetric is better """ # TODO: implement evaluation. - raise NotImplementedError('Evaluation not implemented') + raise NotImplementedError("Evaluation not implemented") From 4c201b97dd4cf1dbfdde316afddbce7caea2bbe9 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 22:25:53 -0400 Subject: [PATCH 13/40] SST with PS integration. (It was already done.) --- lm_eval/tasks/glue.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index 26008510c5..7750c04065 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -71,14 +71,6 @@ def process_results(self, doc, results): answer_choices_list = self.prompt.get_answer_choices_list(doc) pred = np.argmax(results) target = answer_choices_list.index(self.doc_to_target(doc).strip()) - print("*" * 80) - print(f"DOC: {doc}") - print(f"TEXT: {self.doc_to_text(doc)}") - print(f"STRING TARGET: {self.doc_to_target(doc)} END TARGET") - print(f"TARGET: {target} END TARGET") - print(f"PRED: {pred}") - print("*" * 80) - return {"mcc": (target, pred)} def higher_is_better(self): @@ -141,17 +133,6 @@ def test_docs(self): if self.has_test_docs(): return self.dataset["test_matched"] - def process_results(self, doc, results): - gold = doc["label"] - pred = np.argmax(results) - return {"acc": pred == gold} - - def higher_is_better(self): - return {"acc": True} - - def aggregation(self): - return {"acc": mean} - class MNLIMismatched(MNLI): VERSION = 0 From b44aa554d7084de825c8c3660337881526c2f028 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 22:27:06 -0400 Subject: [PATCH 14/40] QNLI with PS integration. --- lm_eval/tasks/glue.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index 7750c04065..63ab25c78f 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -146,7 +146,7 @@ def test_docs(self): return self.dataset["test_mismatched"] -class QNLI(Task): +class QNLI(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "qnli" @@ -168,18 +168,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def process_results(self, doc, results): - ll_yes, ll_no = results - pred = ll_no > ll_yes - gold = doc["label"] - return {"acc": pred == gold} - - def higher_is_better(self): - return {"acc": True} - - def aggregation(self): - return {"acc": mean} - class WNLI(PromptSourceTask): VERSION = 1 @@ -196,20 +184,11 @@ def has_test_docs(self): return False def training_docs(self): - # if self._training_docs is None: - # self._training_docs = list() - # return self._training_docs return self.dataset["train"] def validation_docs(self): return self.dataset["validation"] - def higher_is_better(self): - return {"acc": True} - - def aggregation(self): - return {"acc": mean} - class RTE(PromptSourceTask): VERSION = 0 From 26e94211bc40c6b8d2a355bc381b72234a8eac25 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 22:28:17 -0400 Subject: [PATCH 15/40] MRPC with PS integration. --- lm_eval/tasks/glue.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index 63ab25c78f..c82192a8c6 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -222,7 +222,7 @@ def aggregation(self): # Similarity and Paraphrase Tasks -class MRPC(Task): +class MRPC(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "mrpc" @@ -244,21 +244,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - pred = ll_yes > ll_no - return { - "acc": pred == gold, - "f1": (gold, pred), - } - - def higher_is_better(self): - return {"acc": True, "f1": True} - - def aggregation(self): - return {"acc": mean, "f1": f1_score} - class QQP(Task): VERSION = 0 From b1a3c6e3a53d5dc6be23334aa77ba0e10f4ab2a0 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 22:31:00 -0400 Subject: [PATCH 16/40] QQP with PS integration. --- lm_eval/tasks/glue.py | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index c82192a8c6..2def8b0dca 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -245,7 +245,7 @@ def validation_docs(self): return self.dataset["validation"] -class QQP(Task): +class QQP(PromptSourceTask): VERSION = 0 DATASET_PATH = "glue" DATASET_NAME = "qqp" @@ -267,35 +267,6 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def doc_to_text(self, doc): - return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format( - doc["question1"], - doc["question2"], - ) - - def doc_to_target(self, doc): - return " {}".format(yesno(doc["label"])) - - def construct_requests(self, doc, ctx): - ll_yes, _ = rf.loglikelihood(ctx, " yes") - ll_no, _ = rf.loglikelihood(ctx, " no") - return ll_yes, ll_no - - def process_results(self, doc, results): - ll_yes, ll_no = results - gold = doc["label"] - pred = ll_yes > ll_no - return { - "acc": pred == gold, - "f1": (gold, pred), - } - - def higher_is_better(self): - return {"acc": True, "f1": True} - - def aggregation(self): - return {"acc": mean, "f1": f1_score} - class STSB(Task): VERSION = 0 From 0e578306dc8bd15e24a17aa848560d211de49b94 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Mon, 25 Apr 2022 22:47:22 -0400 Subject: [PATCH 17/40] Add `HANS` --- lm_eval/tasks/__init__.py | 2 ++ lm_eval/tasks/hans.py | 61 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 lm_eval/tasks/hans.py diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index 5d0b646432..87c2a97af4 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -52,6 +52,7 @@ from . import asdiv from . import gsm8k from . import storycloze +from . import hans # from . import e2e_nlg_cleaned @@ -146,6 +147,7 @@ "anli_r1": anli.ANLIRound1, "anli_r2": anli.ANLIRound2, "anli_r3": anli.ANLIRound3, + "hans": hans.HANS, "ethics_cm": hendrycks_ethics.EthicsCM, "ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_justice": hendrycks_ethics.EthicsJustice, diff --git a/lm_eval/tasks/hans.py b/lm_eval/tasks/hans.py new file mode 100644 index 0000000000..d740763870 --- /dev/null +++ b/lm_eval/tasks/hans.py @@ -0,0 +1,61 @@ +""" +Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference +https://arxiv.org/abs/1902.01007 + +A controlled evaluation set called HANS (Heuristic Analysis for NLI Systems), +which contains many examples where the heuristics fail. + +Homepage: https://github.com/tommccoy1/hans +""" +from lm_eval.base import PromptSourceTask + + +_CITATION = """ +@inproceedings{mccoy-etal-2019-right, + title = "Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference", + author = "McCoy, Tom and + Pavlick, Ellie and + Linzen, Tal", + booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics", + month = jul, + year = "2019", + address = "Florence, Italy", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/P19-1334", + doi = "10.18653/v1/P19-1334", + pages = "3428--3448", + abstract = "A machine learning system can score well on a given test set by relying on heuristics that are effective for frequent example types but break down in more challenging cases. We study this issue within natural language inference (NLI), the task of determining whether one sentence entails another. We hypothesize that statistical NLI models may adopt three fallible syntactic heuristics: the lexical overlap heuristic, the subsequence heuristic, and the constituent heuristic. To determine whether models have adopted these heuristics, we introduce a controlled evaluation set called HANS (Heuristic Analysis for NLI Systems), which contains many examples where the heuristics fail. We find that models trained on MNLI, including BERT, a state-of-the-art model, perform very poorly on HANS, suggesting that they have indeed adopted these heuristics. We conclude that there is substantial room for improvement in NLI systems, and that the HANS dataset can motivate and measure progress in this area.", +} +""" + + +class HANS(PromptSourceTask): + VERSION = 0 + DATASET_PATH = "hans" + DATASET_NAME = None + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return False + + def training_docs(self): + if self.has_training_docs(): + # We cache training documents in `self._training_docs` for faster + # few-shot processing. If the data is too large to fit in memory, + # return the training data as a generator instead of a list. + if self._training_docs is None: + self._training_docs = list(self.dataset["train"]) + return self._training_docs + + def validation_docs(self): + if self.has_validation_docs(): + return self.dataset["validation"] + + def test_docs(self): + if self.has_test_docs(): + return self.dataset["test"] From 0cdcc9891f5fbd182d2f30ce2f0c17bb0afc7330 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Mon, 25 Apr 2022 22:47:22 -0400 Subject: [PATCH 18/40] Turned off generation tasks for now. Changed process to look at the metrics. Only accuracy implemented. --- lm_eval/base.py | 50 +++++++++++++++++++++++--------- lm_eval/evaluator.py | 14 ++++++--- lm_eval/tasks/__init__.py | 2 ++ lm_eval/tasks/glue.py | 18 ++++++------ lm_eval/tasks/hans.py | 61 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 118 insertions(+), 27 deletions(-) create mode 100644 lm_eval/tasks/hans.py diff --git a/lm_eval/base.py b/lm_eval/base.py index 3388743d7c..b45ef80143 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -1,7 +1,7 @@ import abc from typing import Iterable -import promptsource +import promptsource import numpy as np import random import re @@ -642,6 +642,12 @@ def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=Non def eos_token(self): raise NotImplementedError() + def is_generation_task(self): + return ( + "BLEU" in self.prompt.metadata.metrics + or "ROUGE" in self.prompt.metadata.metrics + ) + def doc_to_target(self, doc): _, target = self.prompt.apply(doc) return f" {target}" @@ -663,11 +669,19 @@ def construct_requests(self, doc, ctx): """ _requests = [] answer_choices_list = self.prompt.get_answer_choices_list(doc) + + # We take a present answer_choices list to mean that we should apply the supplied + # metrics (hardcoded or accuracy atm) to the ranked choices. Otherwise, assume generation. + # Above we do something similar, but rely on the metrics requested (BLEU, ROUGE indicating generation). if answer_choices_list: + assert ( + not self.is_generation_task() + ), f"We expect this to be a ranked choice task; double check please." for answer_choice in answer_choices_list: ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}") _requests.append(ll_answer_choice) else: + assert False # TODO(Albert): What is the stop symbol? Is it model specific? cont_request = rf.greedy_until(ctx, [self.eos_token()]) _requests.append(cont_request) @@ -690,27 +704,35 @@ def process_results(self, doc, results): target = self.doc_to_target(doc).strip() answer_choices_list = self.prompt.get_answer_choices_list(doc) if answer_choices_list: + assert ( + not self.is_generation_task() + ), f"We expect this to be a ranked choice task; double check please." pred = answer_choices_list[np.argmax(results)] - return { - "acc": pred == target - } + out = {} + if "Accuracy" in self.prompt.metadata.metrics: + out["acc"] = pred == target + # TODO: Add metrics here. + return out else: - continuation = results - raise NotImplementedError() + raise NotImplementedError("Generation is not implemented yet.") # Map metric name to HF metric. # TODO(Albert): What is Other? - #metric_names = prompt.metadata.metrics - + # metric_names = prompt.metadata.metrics + def higher_is_better(self): - return { - "acc": True - } + out = {} + if "Accuracy" in self.prompt.metadata.metrics: + out["acc"] = True + + return out def aggregation(self): - return { - "acc": mean, - } + out = {} + if "Accuracy" in self.prompt.metadata.metrics: + out["acc"] = mean + + return out class MultipleChoiceTask(Task): diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 1a9fee8e97..362a7e63a1 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -170,8 +170,10 @@ def evaluate( # get lists of each type of request for task_prompt_name, task in task_dict_items: - print(f"TASK PROMPT NAME: {task_prompt_name}") - + if task.is_generation_task(): + print(f"WARNING: Skipping generation prompt {task.prompt.name}.") + continue + versions[task_prompt_name] = task.VERSION # default to test doc, fall back to val doc if validation unavailable # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point @@ -206,7 +208,9 @@ def evaluate( requests[req.request_type].append(req) # i: index in requests for a single task instance # doc_id: unique id that we can get back to a doc using `docs` - requests_origin[req.request_type].append((i, task_prompt_name, doc, doc_id)) + requests_origin[req.request_type].append( + (i, task_prompt_name, doc, doc_id) + ) # all responses for each (task, doc) process_res_queue = collections.defaultdict(list) @@ -224,7 +228,9 @@ def evaluate( x if req.index is None else x[req.index] for x, req in zip(resps, reqs) ] - for resp, (i, task_prompt_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): + for resp, (i, task_prompt_name, doc, doc_id) in zip( + resps, requests_origin[reqtype] + ): process_res_queue[(task_prompt_name, doc_id)].append((i, resp)) vals = collections.defaultdict(list) diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index 5d0b646432..87c2a97af4 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -52,6 +52,7 @@ from . import asdiv from . import gsm8k from . import storycloze +from . import hans # from . import e2e_nlg_cleaned @@ -146,6 +147,7 @@ "anli_r1": anli.ANLIRound1, "anli_r2": anli.ANLIRound2, "anli_r3": anli.ANLIRound3, + "hans": hans.HANS, "ethics_cm": hendrycks_ethics.EthicsCM, "ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_justice": hendrycks_ethics.EthicsJustice, diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index 2def8b0dca..a8b213e30a 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -67,17 +67,17 @@ def training_docs(self): def validation_docs(self): return self.dataset["validation"] - def process_results(self, doc, results): - answer_choices_list = self.prompt.get_answer_choices_list(doc) - pred = np.argmax(results) - target = answer_choices_list.index(self.doc_to_target(doc).strip()) - return {"mcc": (target, pred)} + # def process_results(self, doc, results): + # answer_choices_list = self.prompt.get_answer_choices_list(doc) + # pred = np.argmax(results) + # target = answer_choices_list.index(self.doc_to_target(doc).strip()) + # return {"mcc": (target, pred)} - def higher_is_better(self): - return {"mcc": True} + # def higher_is_better(self): + # return {"mcc": True} - def aggregation(self): - return {"mcc": matthews_corrcoef} + # def aggregation(self): + # return {"mcc": matthews_corrcoef} class SST(PromptSourceTask): diff --git a/lm_eval/tasks/hans.py b/lm_eval/tasks/hans.py new file mode 100644 index 0000000000..d740763870 --- /dev/null +++ b/lm_eval/tasks/hans.py @@ -0,0 +1,61 @@ +""" +Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference +https://arxiv.org/abs/1902.01007 + +A controlled evaluation set called HANS (Heuristic Analysis for NLI Systems), +which contains many examples where the heuristics fail. + +Homepage: https://github.com/tommccoy1/hans +""" +from lm_eval.base import PromptSourceTask + + +_CITATION = """ +@inproceedings{mccoy-etal-2019-right, + title = "Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference", + author = "McCoy, Tom and + Pavlick, Ellie and + Linzen, Tal", + booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics", + month = jul, + year = "2019", + address = "Florence, Italy", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/P19-1334", + doi = "10.18653/v1/P19-1334", + pages = "3428--3448", + abstract = "A machine learning system can score well on a given test set by relying on heuristics that are effective for frequent example types but break down in more challenging cases. We study this issue within natural language inference (NLI), the task of determining whether one sentence entails another. We hypothesize that statistical NLI models may adopt three fallible syntactic heuristics: the lexical overlap heuristic, the subsequence heuristic, and the constituent heuristic. To determine whether models have adopted these heuristics, we introduce a controlled evaluation set called HANS (Heuristic Analysis for NLI Systems), which contains many examples where the heuristics fail. We find that models trained on MNLI, including BERT, a state-of-the-art model, perform very poorly on HANS, suggesting that they have indeed adopted these heuristics. We conclude that there is substantial room for improvement in NLI systems, and that the HANS dataset can motivate and measure progress in this area.", +} +""" + + +class HANS(PromptSourceTask): + VERSION = 0 + DATASET_PATH = "hans" + DATASET_NAME = None + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return False + + def training_docs(self): + if self.has_training_docs(): + # We cache training documents in `self._training_docs` for faster + # few-shot processing. If the data is too large to fit in memory, + # return the training data as a generator instead of a list. + if self._training_docs is None: + self._training_docs = list(self.dataset["train"]) + return self._training_docs + + def validation_docs(self): + if self.has_validation_docs(): + return self.dataset["validation"] + + def test_docs(self): + if self.has_test_docs(): + return self.dataset["test"] From 422380cc12e34284d50a78742ec878ea7d277ff9 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 23:07:39 -0400 Subject: [PATCH 19/40] Print table across task and prompt --- lm_eval/evaluator.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 362a7e63a1..cae8ff5972 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -275,8 +275,16 @@ def make_table(result_dict): md_writer = MarkdownTableWriter() latex_writer = LatexTableWriter() - md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] - latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] + md_writer.headers = ["Task", "Prompt", "Version", "Metric", "Value", "", "Stderr"] + latex_writer.headers = [ + "Task", + "Prompt", + "Version", + "Metric", + "Value", + "", + "Stderr", + ] values = [] for k, dic in result_dict["results"].items(): @@ -288,9 +296,29 @@ def make_table(result_dict): continue if m + "_stderr" in dic: se = dic[m + "_stderr"] - values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se]) + values.append( + [ + dic["task_name"], + dic["prompt_name"], + version, + m, + "%.4f" % v, + "±", + "%.4f" % se, + ] + ) else: - values.append([k, version, m, "%.4f" % v, "", ""]) + values.append( + [ + dic["task_name"], + dic["prompt_name"], + version, + m, + "%.4f" % v, + "", + "", + ] + ) k = "" version = "" md_writer.value_matrix = values From b988137d6a4e29ca28d2146192a87db6fc3b63b4 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Mon, 25 Apr 2022 23:23:15 -0400 Subject: [PATCH 20/40] Fix `wsc` subset name --- lm_eval/tasks/superglue.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/lm_eval/tasks/superglue.py b/lm_eval/tasks/superglue.py index 5f4a51a48d..667dc54271 100644 --- a/lm_eval/tasks/superglue.py +++ b/lm_eval/tasks/superglue.py @@ -199,22 +199,13 @@ def training_docs(self): if self._training_docs is None: self._training_docs = [] for doc in self.dataset["train"]: - self._training_docs.append(self._process_doc(doc)) + self._training_docs.append(doc) return self._training_docs def validation_docs(self): # See: training_docs for doc in self.dataset["validation"]: - yield self._process_doc(doc) - - @classmethod - def _process_doc(cls, doc): - return { - "passage": doc["passage"], - "query": doc["query"], - "entities": sorted(list(set(doc["entities"]))), - "answers": sorted(list(set(doc["answers"]))), - } + yield doc def process_results(self, doc, results): # ReCoRD's evaluation is actually deceptively simple: @@ -286,7 +277,7 @@ class SGWinogradSchemaChallenge(PromptSourceTask): # Note: This implementation differs from Fig G.32 because this is the SuperGLUE, # binary version of the task. DATASET_PATH = "super_glue" - DATASET_NAME = "wsc" + DATASET_NAME = "wsc.fixed" def has_training_docs(self): return True From c26f1d4c9b0399f3d395038427a720792bcca065 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Mon, 25 Apr 2022 23:40:16 -0400 Subject: [PATCH 21/40] Update the token encode --- lm_eval/base.py | 7 ++--- lm_eval/evaluator.py | 6 ++--- lm_eval/models/gpt2.py | 59 +++++++++++++++++++++++++++++------------- lm_eval/tasks/coqa.py | 7 +++-- lm_eval/tasks/drop.py | 5 ++-- 5 files changed, 53 insertions(+), 31 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index b45ef80143..b50acd21e9 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -349,14 +349,16 @@ def _collate(x): until = [until] # TODO: Come back to for generation `eos`. - primary_until = self.tok_encode(until[0])[0] + primary_until = self.tok_encode(until[0]) context_enc = torch.tensor( [self.tok_encode(context)[self.max_gen_toks - self.max_length :]] ).to(self.device) cont = self._model_generate( - context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until + context_enc, + context_enc.shape[1] + self.max_gen_toks, + torch.tensor(primary_until), ) s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :]) @@ -681,7 +683,6 @@ def construct_requests(self, doc, ctx): ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}") _requests.append(ll_answer_choice) else: - assert False # TODO(Albert): What is the stop symbol? Is it model specific? cont_request = rf.greedy_until(ctx, [self.eos_token()]) _requests.append(cont_request) diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index cae8ff5972..60c123c451 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -170,9 +170,9 @@ def evaluate( # get lists of each type of request for task_prompt_name, task in task_dict_items: - if task.is_generation_task(): - print(f"WARNING: Skipping generation prompt {task.prompt.name}.") - continue + # if task.is_generation_task(): + # print(f"WARNING: Skipping generation prompt {task.prompt.name}.") + # continue versions[task_prompt_name] = task.VERSION # default to test doc, fall back to val doc if validation unavailable diff --git a/lm_eval/models/gpt2.py b/lm_eval/models/gpt2.py index a2214d39b1..2b16cda6ed 100644 --- a/lm_eval/models/gpt2.py +++ b/lm_eval/models/gpt2.py @@ -4,8 +4,15 @@ class HFLM(BaseLM): - - def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1): + def __init__( + self, + device="cuda", + pretrained="gpt2", + revision="main", + subfolder=None, + tokenizer=None, + batch_size=1, + ): super().__init__() assert isinstance(device, str) @@ -15,28 +22,47 @@ def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder= if device: self._device = torch.device(device) else: - self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) # TODO: update this to be less of a hack once subfolder is fixed in HF self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained( - pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "") + pretrained, + revision=revision + ("/" + subfolder if subfolder is not None else ""), ).to(self.device) self.gpt2.eval() # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2 self.tokenizer = transformers.AutoTokenizer.from_pretrained( - pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder) + pretrained if tokenizer is None else tokenizer, + revision=revision, + subfolder=subfolder, + ) - assert isinstance(self.tokenizer, ( - transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast, - transformers.T5Tokenizer, transformers.T5TokenizerFast, - )), "this tokenizer has not been checked for compatibility yet!" + assert isinstance( + self.tokenizer, + ( + transformers.GPT2Tokenizer, + transformers.GPT2TokenizerFast, + transformers.T5Tokenizer, + transformers.T5TokenizerFast, + ), + ), "this tokenizer has not been checked for compatibility yet!" self.vocab_size = self.tokenizer.vocab_size - if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)): - assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \ - self.tokenizer.encode('hello\n\nhello') + if isinstance( + self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast) + ): + assert self.tokenizer.encode("hello\n\nhello") == [ + 31373, + 198, + 198, + 31373, + ], self.tokenizer.encode("hello\n\nhello") # multithreading and batching self.batch_size_per_gpu = batch_size # todo: adaptive batch size @@ -75,7 +101,7 @@ def device(self): def tok_encode(self, string: str): return self.tokenizer.encode(string, add_special_tokens=False) - + def tok_decode(self, tokens): return self.tokenizer.decode(tokens) @@ -89,13 +115,10 @@ def _model_call(self, inps): """ with torch.no_grad(): return self.gpt2(inps)[0][:, :, :50257] - + def _model_generate(self, context, max_length, eos_token_id): return self.gpt2.generate( - context, - max_length=max_length, - eos_token_id=eos_token_id, - do_sample=False + context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False ) diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py index ee4fa31e50..b0a993e456 100644 --- a/lm_eval/tasks/coqa.py +++ b/lm_eval/tasks/coqa.py @@ -67,7 +67,6 @@ def test_docs(self): # answers.append(additional_answer_for_turn) # return answers - @staticmethod def compute_scores(gold_list, pred): # tests for exact match and on the normalised answer (compute_exact) @@ -92,7 +91,7 @@ def compute_scores(gold_list, pred): } def eos_token(self): - return "\n" + return "\nQ:" # def construct_requests(self, doc, ctx): # """Uses RequestFactory to construct Requests and returns an iterable of @@ -121,10 +120,10 @@ def process_results(self, doc, results): pred = results[0].strip().split("\n")[0] print("*" * 80) print(f"DOC: {doc}") -# print(f"PS: {self.prompt.apply(doc)}") + # print(f"PS: {self.prompt.apply(doc)}") print(f"TEXT: {self.doc_to_text(doc)}") print(f"TARGET: {target} END TARGET") - print(pred) + print(f"PRED: {pred} END PRED") print("*" * 80) # turn_id = len(doc["questions"]["input_text"]) diff --git a/lm_eval/tasks/drop.py b/lm_eval/tasks/drop.py index 6e8ce20740..f7b301493b 100644 --- a/lm_eval/tasks/drop.py +++ b/lm_eval/tasks/drop.py @@ -39,7 +39,7 @@ class DROP(PromptSourceTask): VERSION = 1 - DATASET_PATH = "drop" # inspect.getfile(lm_eval.datasets.drop.drop) + DATASET_PATH = "drop" # inspect.getfile(lm_eval.datasets.drop.drop) DATASET_NAME = None def has_training_docs(self): @@ -114,10 +114,9 @@ def process_results(self, doc, results): print(f"PS: {self.prompt.apply(doc)}") print(f"TEXT: {self.doc_to_text(doc)}") print(f"TARGET: {target} END TARGET") - print(pred) + print(f"PRED: {pred} END PRED") print("*" * 80) - preds = [pred] golds = [target] From c27e29e156a0beed0b1cac742b2a8fa6f9ad9fc9 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Tue, 26 Apr 2022 00:05:52 -0400 Subject: [PATCH 22/40] Force no caching while testing --- lm_eval/evaluator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 60c123c451..59be1cd71f 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -70,6 +70,8 @@ def simple_evaluate( assert isinstance(model, lm_eval.base.LM) lm = model + # TODO: Hard-code turning off cache while testing. Remove once testing is completed. + no_cache = True if not no_cache: lm = lm_eval.base.CachingLM( lm, From f39c27c266fd366031ce392f0565a1e679cfd3a9 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 26 Apr 2022 11:28:13 -0400 Subject: [PATCH 23/40] Rename task specific to --- lm_eval/base.py | 33 ++++++++++++++++++++++++++++++--- lm_eval/evaluator.py | 4 ++++ lm_eval/models/gpt2.py | 5 ++++- lm_eval/tasks/coqa.py | 2 +- lm_eval/tasks/drop.py | 2 +- lm_eval/tasks/glue.py | 3 +++ 6 files changed, 43 insertions(+), 6 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index b50acd21e9..0d90f3bc8e 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -641,8 +641,12 @@ def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=Non super().__init__(data_dir, cache_dir, download_mode) self.prompt = prompt - def eos_token(self): - raise NotImplementedError() + def end_of_generation_sequence(self): + """Denote where the generation should be split. + + For example, for coqa, this is '\nQ:' and for drop '.'. + """ + return None def is_generation_task(self): return ( @@ -650,6 +654,29 @@ def is_generation_task(self): or "ROUGE" in self.prompt.metadata.metrics ) + def invalid_doc_for_prompt(self, doc): + """Some prompts may not work for some documents. + + As of now, we skip particular prompts, s.t. we don't + overskip. If this turns out to be a problem for many prompts + we can instead make sure that apply returns 2 things. + + + """ + if ( + # generate_paraphrase for mrpc + ( + self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5" + or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f" + ) + and doc["label"] == 0 + ): + # This generation prompt assumes a positive example. We filter out the negative examples. + # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7 + # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88 + return True + return False + def doc_to_target(self, doc): _, target = self.prompt.apply(doc) return f" {target}" @@ -684,7 +711,7 @@ def construct_requests(self, doc, ctx): _requests.append(ll_answer_choice) else: # TODO(Albert): What is the stop symbol? Is it model specific? - cont_request = rf.greedy_until(ctx, [self.eos_token()]) + cont_request = rf.greedy_until(ctx, [self.end_of_generation_sequence()]) _requests.append(cont_request) return _requests diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 59be1cd71f..efeb5d2178 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -2,6 +2,7 @@ import itertools import pathlib import random + import lm_eval.metrics import lm_eval.models import lm_eval.tasks @@ -199,6 +200,9 @@ def evaluate( ) for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)): + if task.invalid_doc_for_prompt(doc): + continue + docs[(task_prompt_name, doc_id)] = doc ctx = task.fewshot_context( doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description diff --git a/lm_eval/models/gpt2.py b/lm_eval/models/gpt2.py index 2b16cda6ed..a9ce172db9 100644 --- a/lm_eval/models/gpt2.py +++ b/lm_eval/models/gpt2.py @@ -118,7 +118,10 @@ def _model_call(self, inps): def _model_generate(self, context, max_length, eos_token_id): return self.gpt2.generate( - context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False + context, + max_length=max_length, + eos_token_id=eos_token_id, + do_sample=False, ) diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py index b0a993e456..c95b1891c8 100644 --- a/lm_eval/tasks/coqa.py +++ b/lm_eval/tasks/coqa.py @@ -90,7 +90,7 @@ def compute_scores(gold_list, pred): "f1": f1_sum / max(1, len(gold_list)), } - def eos_token(self): + def end_of_generation_sequence(self): return "\nQ:" # def construct_requests(self, doc, ctx): diff --git a/lm_eval/tasks/drop.py b/lm_eval/tasks/drop.py index f7b301493b..3c38159273 100644 --- a/lm_eval/tasks/drop.py +++ b/lm_eval/tasks/drop.py @@ -92,7 +92,7 @@ def parse_answer(cls, answer): # """ # conts = [rf.greedy_until(ctx, ["."])] # return conts - def eos_token(self): + def end_of_generation_sequence(self): return "." def process_results(self, doc, results): diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index a8b213e30a..5eeb5cd279 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -236,6 +236,9 @@ def has_validation_docs(self): def has_test_docs(self): return False + def end_of_generation_sequence(self): + return "\n" + def training_docs(self): if self._training_docs is None: self._training_docs = list(self.dataset["train"]) From d4c0009315197a9ba91adb3b4989f4c1c2325905 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 26 Apr 2022 12:10:33 -0400 Subject: [PATCH 24/40] Added default behavior for bleu to the promtsourcetask class --- lm_eval/base.py | 60 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index 0d90f3bc8e..fe38fd1108 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -14,6 +14,7 @@ import torch import torch.nn.functional as F +from lm_eval import metrics from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte from lm_eval import utils from abc import abstractmethod @@ -637,6 +638,16 @@ def fewshot_context( class PromptSourceTask(Task): + """These are the metrics from promptsource that we have + added default behavior for. If you want to add default behavior for a new metric, + update the functions below. If you want to use one of the following metrics, + *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`. + + WARNING: ROUGE is WIP. + """ + + CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"]) + def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None): super().__init__(data_dir, cache_dir, download_mode) self.prompt = prompt @@ -737,29 +748,60 @@ def process_results(self, doc, results): ), f"We expect this to be a ranked choice task; double check please." pred = answer_choices_list[np.argmax(results)] out = {} - if "Accuracy" in self.prompt.metadata.metrics: - out["acc"] = pred == target + + for metric in self.prompt.metadata.metrics: + assert ( + metric in self.CONFIGURED_PS_METRICS + ), "Unexpected metric. Add it, or use a task-specific solution." + if metric == "Accuracy": + out["acc"] = pred == target # TODO: Add metrics here. return out else: - raise NotImplementedError("Generation is not implemented yet.") + # NOTE: In the future, target may be a list, not a string. + pred = results[0].strip() + out = {} + for metric in self.prompt.metadata.metrics: + assert ( + metric in self.CONFIGURED_PS_METRICS + ), "Unexpected metric. Add it, or use a task-specific solution." + if metric == "BLEU": + out["bleu"] = (target, pred) + if metric == "ROUGE": + print("WARNING: Skipping Rouge.") + + return out # Map metric name to HF metric. # TODO(Albert): What is Other? # metric_names = prompt.metadata.metrics def higher_is_better(self): out = {} - if "Accuracy" in self.prompt.metadata.metrics: - out["acc"] = True - + for metric in self.prompt.metadata.metrics: + assert ( + metric in self.CONFIGURED_PS_METRICS + ), "Unexpected metric. Add it, or use a task-specific solution." + if metric == "Accuracy": + out["acc"] = True + if metric == "BLEU": + out["bleu"] = True + if metric == "ROUGE": + print("WARNING: Skipping Rouge.") return out def aggregation(self): out = {} - if "Accuracy" in self.prompt.metadata.metrics: - out["acc"] = mean - + for metric in self.prompt.metadata.metrics: + assert ( + metric in self.CONFIGURED_PS_METRICS + ), "Unexpected metric. Add it, or use a task-specific solution." + if metric == "Accuracy": + out["acc"] = mean + if metric == "BLEU": + out["bleu"] = metrics.bleu + if metric == "ROUGE": + print("WARNING: Skipping Rouge.") return out From 1dcca55ca4022b045bfdfec6624e99b5c2ed5dd2 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 26 Apr 2022 12:25:46 -0400 Subject: [PATCH 25/40] Minor updates to documentation. --- lm_eval/base.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index fe38fd1108..cc231e898f 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -665,34 +665,28 @@ def is_generation_task(self): or "ROUGE" in self.prompt.metadata.metrics ) - def invalid_doc_for_prompt(self, doc): - """Some prompts may not work for some documents. - - As of now, we skip particular prompts, s.t. we don't - overskip. If this turns out to be a problem for many prompts - we can instead make sure that apply returns 2 things. - - - """ + def invalid_doc_for_prompt(self, doc) -> bool: + """Some prompts may not work for some documents.""" if ( # generate_paraphrase for mrpc + # This generation prompt assumes a positive example. We filter out the negative examples. + # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7 + # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88 ( self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5" or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f" ) and doc["label"] == 0 ): - # This generation prompt assumes a positive example. We filter out the negative examples. - # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7 - # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88 return True return False - def doc_to_target(self, doc): + def doc_to_target(self, doc) -> str: + """NOTE: In the future, this may return Union[str, List[str]].""" _, target = self.prompt.apply(doc) return f" {target}" - def doc_to_text(self, doc): + def doc_to_text(self, doc) -> str: text, _ = self.prompt.apply(doc) return text @@ -737,9 +731,6 @@ def process_results(self, doc, results): :param results: The results of the requests created in construct_requests. """ - # raise NotImplementedError( - # "Implement process results using the `prompt.metadata.metrics`. See below." - # ) target = self.doc_to_target(doc).strip() answer_choices_list = self.prompt.get_answer_choices_list(doc) if answer_choices_list: @@ -772,9 +763,6 @@ def process_results(self, doc, results): print("WARNING: Skipping Rouge.") return out - # Map metric name to HF metric. - # TODO(Albert): What is Other? - # metric_names = prompt.metadata.metrics def higher_is_better(self): out = {} From b2838b8d119cec0765b233ba619168556d8ddc10 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 26 Apr 2022 13:20:48 -0400 Subject: [PATCH 26/40] Rename task specific to --- lm_eval/base.py | 8 +++++--- lm_eval/tasks/coqa.py | 2 +- lm_eval/tasks/drop.py | 2 +- lm_eval/tasks/glue.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index cc231e898f..2a0cd04ed8 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -652,10 +652,12 @@ def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=Non super().__init__(data_dir, cache_dir, download_mode) self.prompt = prompt - def end_of_generation_sequence(self): - """Denote where the generation should be split. + def stopping_criteria(self): + """Denote where the generation should end. For example, for coqa, this is '\nQ:' and for drop '.'. + + By default, its None, meaning to generate up to max or EOT, whichever comes first. """ return None @@ -716,7 +718,7 @@ def construct_requests(self, doc, ctx): _requests.append(ll_answer_choice) else: # TODO(Albert): What is the stop symbol? Is it model specific? - cont_request = rf.greedy_until(ctx, [self.end_of_generation_sequence()]) + cont_request = rf.greedy_until(ctx, [self.stopping_criteria()]) _requests.append(cont_request) return _requests diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py index c95b1891c8..5043c10594 100644 --- a/lm_eval/tasks/coqa.py +++ b/lm_eval/tasks/coqa.py @@ -90,7 +90,7 @@ def compute_scores(gold_list, pred): "f1": f1_sum / max(1, len(gold_list)), } - def end_of_generation_sequence(self): + def stopping_criteria(self): return "\nQ:" # def construct_requests(self, doc, ctx): diff --git a/lm_eval/tasks/drop.py b/lm_eval/tasks/drop.py index 3c38159273..ff78c76b59 100644 --- a/lm_eval/tasks/drop.py +++ b/lm_eval/tasks/drop.py @@ -92,7 +92,7 @@ def parse_answer(cls, answer): # """ # conts = [rf.greedy_until(ctx, ["."])] # return conts - def end_of_generation_sequence(self): + def stopping_criteria(self): return "." def process_results(self, doc, results): diff --git a/lm_eval/tasks/glue.py b/lm_eval/tasks/glue.py index 5eeb5cd279..8914db88dd 100644 --- a/lm_eval/tasks/glue.py +++ b/lm_eval/tasks/glue.py @@ -236,7 +236,7 @@ def has_validation_docs(self): def has_test_docs(self): return False - def end_of_generation_sequence(self): + def stopping_criteria(self): return "\n" def training_docs(self): From 96ea7ddce14daaabe3da8ec3a9ce520e72c9c535 Mon Sep 17 00:00:00 2001 From: Tian Yun Date: Tue, 26 Apr 2022 13:39:36 -0400 Subject: [PATCH 27/40] Added stoppping criteria for generation --- lm_eval/models/gpt2.py | 35 +++++++++++++++++++++++++++++++++-- tests/test_gpt2.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 tests/test_gpt2.py diff --git a/lm_eval/models/gpt2.py b/lm_eval/models/gpt2.py index 2b16cda6ed..8cb804cfa7 100644 --- a/lm_eval/models/gpt2.py +++ b/lm_eval/models/gpt2.py @@ -116,10 +116,41 @@ def _model_call(self, inps): with torch.no_grad(): return self.gpt2(inps)[0][:, :, :50257] - def _model_generate(self, context, max_length, eos_token_id): + def _get_stopping_criteria(self, stopping_criteria_ids): + class MultitokenEOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_seq_id: torch.LongTensor, tokenizer): + self.eos_seq = tokenizer.decode(eos_seq_id) + self.eos_seq_id = eos_seq_id + self.eos_seq_len = len(eos_seq_id) + 1 + self.tokenizer = tokenizer + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + last_token_id = input_ids[0, -self.eos_seq_len:] + last_tokens = self.tokenizer.decode(last_token_id) + is_stopped = self.eos_seq in last_tokens + return is_stopped + + class EOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_token_id: torch.LongTensor): + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids[0,-1] == self.eos_token_id + + return transformers.StoppingCriteriaList([ + MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer), + EOSCriteria(stopping_criteria_ids) + ]) + + def _model_generate(self, context, max_length, stopping_criteria_ids): + stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) return self.gpt2.generate( - context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False + context, + max_length=max_length, + stopping_criteria=stopping_criteria, + do_sample=False, ) + # for backwards compatibility diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py new file mode 100644 index 0000000000..898d1b96ad --- /dev/null +++ b/tests/test_gpt2.py @@ -0,0 +1,33 @@ +import random +import lm_eval.models as models +import pytest +import torch +from transformers import StoppingCriteria + + +@pytest.mark.parametrize( + "eos_token,test_input,expected", + [ + ("not", "i like", "i like to say that I'm not"), + ("say that", "i like", "i like to say that"), + ("great", "big science is", "big science is a great"), + ("<|endoftext|>", "big science has", "big science has been done in the past, but it's not the same as the science of the") + ] +) +def test_stopping_criteria(eos_token, test_input, expected): + random.seed(42) + torch.random.manual_seed(42) + + device = "cuda" if torch.cuda.is_available() else "cpu" + gpt2 = models.get_model("gpt2")(device=device) + + context = torch.tensor([gpt2.tokenizer.encode(test_input)]) + stopping_criteria_ids = gpt2.tokenizer.encode(eos_token) + + generations = gpt2._model_generate( + context, + max_length=20, + stopping_criteria_ids=stopping_criteria_ids + ) + generations = gpt2.tokenizer.decode(generations[0]) + assert generations == expected From 61b5729293061d4d840b585a62b420e8b08790af Mon Sep 17 00:00:00 2001 From: Tian Yun Date: Tue, 26 Apr 2022 15:46:03 -0400 Subject: [PATCH 28/40] Modified stopping criteria for gpt2 --- lm_eval/models/gpt2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_eval/models/gpt2.py b/lm_eval/models/gpt2.py index 8cb804cfa7..2e73adf3a7 100644 --- a/lm_eval/models/gpt2.py +++ b/lm_eval/models/gpt2.py @@ -139,7 +139,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return transformers.StoppingCriteriaList([ MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer), - EOSCriteria(stopping_criteria_ids) + EOSCriteria(self.tokenizer.eos_token) ]) def _model_generate(self, context, max_length, stopping_criteria_ids): From 23a6ee7f174494ac2cab337bc9180547c126d6bf Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 26 Apr 2022 16:00:17 -0400 Subject: [PATCH 29/40] Changing stopping criteria for coqa --- lm_eval/tasks/coqa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_eval/tasks/coqa.py b/lm_eval/tasks/coqa.py index 5043c10594..0fbd23112e 100644 --- a/lm_eval/tasks/coqa.py +++ b/lm_eval/tasks/coqa.py @@ -91,7 +91,7 @@ def compute_scores(gold_list, pred): } def stopping_criteria(self): - return "\nQ:" + return "\n\n" # def construct_requests(self, doc, ctx): # """Uses RequestFactory to construct Requests and returns an iterable of From 941fe268bf7f0a9252b460a6655a8e478c01c3b7 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Tue, 26 Apr 2022 16:44:12 -0400 Subject: [PATCH 30/40] Add `PromptSourceTask` template --- templates/new_task.py | 93 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 templates/new_task.py diff --git a/templates/new_task.py b/templates/new_task.py new file mode 100644 index 0000000000..10ba6eb513 --- /dev/null +++ b/templates/new_task.py @@ -0,0 +1,93 @@ +# TODO: Remove all TODO comments once the implementation is complete. +""" +TODO: Add the Paper Title on this line. +TODO: Add the paper's PDF URL (preferrably from arXiv) on this line. + +TODO: Write a Short Description of the task. + +Homepage: TODO: Add the URL to the task's Homepage here. +""" +from lm_eval.base import PromptSourceTask + + +# TODO: Add the BibTeX citation for the task. +_CITATION = """ +""" + + +# TODO: Replace `NewTask` with the name of your Task. +class NewTask(PromptSourceTask): + VERSION = 0 + # TODO: Add the `DATASET_PATH` string. This will be the name of the `Task` + # dataset as denoted in HuggingFace `datasets`. + DATASET_PATH = "" + # TODO: Add the `DATASET_NAME` string. This is the name of a subset within + # `DATASET_PATH`. If there aren't specific subsets you need, leave this as `None`. + DATASET_NAME = None + + def has_training_docs(self): + # TODO: Fill in the return with `True` if the Task has training data; else `False`. + return False + + def has_validation_docs(self): + # TODO: Fill in the return with `True` if the Task has validation data; else `False`. + return False + + def has_test_docs(self): + # TODO: Fill in the return with `True` if the Task has test data; else `False`. + return False + + def stopping_criteria(self): + # TODO: Denote the string where the generation should be split. + # For example, for `coqa`, this is '\nQ:' and for `drop` '.'. + # NOTE: You may delete this function if the task does not required generation. + return None + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or + test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + # TODO: Construct your language model requests with the request factory, `rf`, + # and return them as an iterable. + return [] + + def process_results(self, doc, results): + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of + the metric for that one document + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param results: + The results of the requests created in construct_requests. + """ + # TODO: For each (sub)metric in the task evaluation, add a key-value pair + # with the metric name as key and the corresponding metric result as value + # for the current `doc`. + return {} + + def aggregation(self): + """ + :returns: {str: [metric_score] -> float} + A dictionary where keys are the names of submetrics and values are + functions that aggregate a list of metric scores + """ + # TODO: For each (sub)metric in the task evaluation, add a key-value pair + # with the metric name as key and an aggregation function as value which + # determines how to combine results from each document in the dataset. + # Check `lm_eval.metrics` to find built-in aggregation functions. + return {} + + def higher_is_better(self): + # TODO: For each (sub)metric in the task evaluation, add a key-value pair + # with the metric name as key and a `bool` value determining whether or + # not higher values of that metric are deemed better. + return {} \ No newline at end of file From b8c203cd052a3808ed469fba924bb88264c76eb6 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 26 Apr 2022 17:15:34 -0400 Subject: [PATCH 31/40] A dependency required this but it was not installed by default --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index d9e3a87923..94fb89cda2 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,8 @@ python_requires=">=3.6", install_requires=[ "promptsource", + "wrapt", + "nltk", "jinja2", "black", "datasets==2.0.0", From 9384ec917dd1e33ea2e406a682fcd411a3eab01b Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 26 Apr 2022 17:30:03 -0400 Subject: [PATCH 32/40] A dependency required this but it was not installed by default --- lm_eval/base.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index 2a0cd04ed8..57f0dcc71a 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -1,5 +1,5 @@ import abc -from typing import Iterable +from typing import Iterable, Optional import promptsource import numpy as np @@ -348,17 +348,25 @@ def _collate(x): for context, until in tqdm(reord.get_reordered()): if isinstance(until, str): until = [until] + max_length = None + elif isinstance(until, list) and len(until) == 2: + until, max_length = [until[0]], until[1] + elif isinstance(until, list): + max_length = None - # TODO: Come back to for generation `eos`. primary_until = self.tok_encode(until[0]) - context_enc = torch.tensor( [self.tok_encode(context)[self.max_gen_toks - self.max_length :]] ).to(self.device) + if max_length is not None: + max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks) + else: + max_length = context_enc.shape[1] + self.max_gen_toks + cont = self._model_generate( context_enc, - context_enc.shape[1] + self.max_gen_toks, + max_length, torch.tensor(primary_until), ) @@ -652,7 +660,7 @@ def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=Non super().__init__(data_dir, cache_dir, download_mode) self.prompt = prompt - def stopping_criteria(self): + def stopping_criteria(self) -> Optional[str]: """Denote where the generation should end. For example, for coqa, this is '\nQ:' and for drop '.'. @@ -661,6 +669,10 @@ def stopping_criteria(self): """ return None + def max_generation_length(self) -> Optional[int]: + """Denote where the max length of the generation if it is obvious from the task.""" + return None + def is_generation_task(self): return ( "BLEU" in self.prompt.metadata.metrics @@ -718,7 +730,9 @@ def construct_requests(self, doc, ctx): _requests.append(ll_answer_choice) else: # TODO(Albert): What is the stop symbol? Is it model specific? - cont_request = rf.greedy_until(ctx, [self.stopping_criteria()]) + cont_request = rf.greedy_until( + ctx, [self.stopping_criteria(), self.max_generation_length()] + ) _requests.append(cont_request) return _requests From 94218002b139a7e8a4847cd52d95f86ccc92c4eb Mon Sep 17 00:00:00 2001 From: jon-tow Date: Tue, 26 Apr 2022 18:09:31 -0400 Subject: [PATCH 33/40] Add multi-reference ROUGE support --- lm_eval/metrics.py | 61 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/lm_eval/metrics.py b/lm_eval/metrics.py index 05fad59ff3..38f60ab52b 100644 --- a/lm_eval/metrics.py +++ b/lm_eval/metrics.py @@ -1,8 +1,10 @@ +import typing import math from collections.abc import Iterable import numpy as np import sacrebleu +from rouge_score import rouge_scorer import sklearn.metrics import random @@ -184,6 +186,65 @@ def _sacreformat(refs, preds): return refs, preds + +def rouge( + refs: typing.List[str], + pred: str, + rouge_types: typing.List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"] +): + """ ROUGE with multi-reference support + + Implementation based on GEM-metrics: + https://github.com/GEM-benchmark/GEM-metrics/blob/431a8174bd6b3637e8d6118bfad2983e39e99733/gem_metrics/rouge.py + + :param refs: + A `list` of reference `str`s. + :param pred: + A single prediction `str`s. + """ + scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True) + # ROUGE multi-ref jackknifing + if len(refs) > 1: + cur_scores = [scorer.score(ref, pred) for ref in refs] + + # get best score for all leave-one-out sets + best_scores = [] + for leave in range(len(refs)): + cur_scores_leave_one = [ + cur_scores[s] for s in range(len(refs)) if s != leave + ] + best_scores.append( + { + rouge_type: max( + [s[rouge_type] for s in cur_scores_leave_one], + key=lambda s: s.fmeasure, + ) + for rouge_type in rouge_types + } + ) + # average the leave-one-out bests to produce the final score + score = { + rouge_type: rouge_scorer.scoring.Score( + np.mean([b[rouge_type].precision for b in best_scores]), + np.mean([b[rouge_type].recall for b in best_scores]), + np.mean([b[rouge_type].fmeasure for b in best_scores]), + ) + for rouge_type in rouge_types + } + else: + score = scorer.score(refs[0], pred) + # convert the named tuples to plain nested dicts + score = { + rouge_type: { + "precision": score[rouge_type].precision, + "recall": score[rouge_type].recall, + "fmeasure": score[rouge_type].fmeasure, + } + for rouge_type in rouge_types + } + return score + + # stderr stuff class _bootstrap_internal: From 716c87d659ca08d2c52bd5b0a503a7ad8623321f Mon Sep 17 00:00:00 2001 From: Tian Yun Date: Tue, 26 Apr 2022 18:12:53 -0400 Subject: [PATCH 34/40] Added t0, t5, mt5 --- lm_eval/models/__init__.py | 5 ++ lm_eval/models/t0.py | 161 +++++++++++++++++++++++++++++++++++++ lm_eval/models/t5.py | 161 +++++++++++++++++++++++++++++++++++++ 3 files changed, 327 insertions(+) create mode 100644 lm_eval/models/t0.py create mode 100644 lm_eval/models/t5.py diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py index a12f68a513..171db2e9e2 100644 --- a/lm_eval/models/__init__.py +++ b/lm_eval/models/__init__.py @@ -1,11 +1,16 @@ from . import gpt2 from . import gpt3 +from . import t5 +from . import t0 from . import dummy MODEL_REGISTRY = { "hf": gpt2.HFLM, "gpt2": gpt2.GPT2LM, "gpt3": gpt3.GPT3LM, + "t5": t5.T5LM, + "mt5": t5.T5LM, + "t0": t0.T0LM, "dummy": dummy.DummyLM, } diff --git a/lm_eval/models/t0.py b/lm_eval/models/t0.py new file mode 100644 index 0000000000..db7ba582a7 --- /dev/null +++ b/lm_eval/models/t0.py @@ -0,0 +1,161 @@ +import transformers +import torch +import torch.nn as nn +import torch.nn.functional as F +from lm_eval.base import LM +from lm_eval import utils +from tqdm import tqdm +import numpy as np +import math + +class T0LM(LM): + MAX_GEN_TOKS = 256 + MAX_INP_LENGTH = 512 + VOCAB_SIZE = 32100 + EOT_TOKEN_ID = 1 + + def __init__(self, device='cuda', parallelize=False, pretrained='t0', batch_size=1): + super().__init__() + if device: + self.device = torch.device(device) + else: + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + print(pretrained) + self.t0 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained) + self.t0.eval() + + if parallelize == "True": + print(parallelize) + self.t0.parallelize() + self.device = torch.device('cuda:0') + else: + self.t0.to(self.device) + + self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained) + self.max_length = self.MAX_INP_LENGTH + + self.batch_size = int(batch_size) + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config={}): + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + def loglikelihood(self, requests): + res = [] + for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)): + + inputs, targets = zip(*chunk) + + inputs_tok = self.tokenizer( + list(inputs), + max_length=self.max_length, + padding=True, + # truncation=True, + add_special_tokens=False, + return_tensors="pt" + ).to(self.device) + + for key in inputs_tok: + inputs_tok[key] = inputs_tok[key][:, -(self.max_length - 1) :] + + targets_tok = self.tokenizer( + list(targets), + max_length=self.MAX_GEN_TOKS, + padding=True, + # truncation=True, + add_special_tokens=False, + return_tensors="pt" + ).to(self.device) + + for key in targets_tok: + targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :] + + with torch.no_grad(): + outputs = self.t0( + **inputs_tok, + labels=targets_tok["input_ids"] + ) + + log_softmaxes = F.log_softmax(outputs.logits, dim=-1) + + output_iterator = zip( + chunk, + log_softmaxes, + targets_tok["input_ids"], + targets_tok["attention_mask"], + ) + for cache_key, log_softmax, target_tok, target_mask in output_iterator: + length = target_mask.sum() + log_softmax = log_softmax[:length] + target_tok = target_tok[:length] + greedy_tokens = log_softmax.argmax(dim=-1) + max_equal = (greedy_tokens == target_tok).all() + target_logits = torch.gather( + log_softmax, 1, target_tok.unsqueeze(-1) + ).squeeze(-1) + answer = (float(target_logits.sum()), bool(max_equal)) + + if cache_key is not None: + self.cache_hook.add_partial("loglikelihood", cache_key, answer) + + res.append(answer) + + return res + + def loglikelihood_rolling(self, requests): + raise NotImplementedError + + def _get_stopping_criteria(self, stopping_criteria_ids): + class MultitokenEOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_seq_id: torch.LongTensor, tokenizer): + self.eos_seq = tokenizer.decode(eos_seq_id) + self.eos_seq_id = eos_seq_id + self.eos_seq_len = len(eos_seq_id) + 1 + self.tokenizer = tokenizer + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + last_token_id = input_ids[0, -self.eos_seq_len:] + last_tokens = self.tokenizer.decode(last_token_id) + is_stopped = self.eos_seq in last_tokens + return is_stopped + + class EOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_token_id: torch.LongTensor): + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids[0,-1] == self.eos_token_id + + return transformers.StoppingCriteriaList([ + MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer), + EOSCriteria(self.tokenizer.eos_token) + ]) + + def greedy_until(self, requests): + + res = [] + + for context, until in tqdm(requests): + if isinstance(until, str): until = [until] + + context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids + + stopping_criteria_ids = self.tokenizer.encode(until[0]) + stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) + + cont = self.t0.generate( + context_enc, + max_length=self.MAX_GEN_TOKS, + stopping_criteria=stopping_criteria, + do_sample=False + ) + + s = self.tokenizer.decode(cont[0].tolist()) + + self.cache_hook.add_partial("greedy_until", (context, until), s) + + res.append(s) + + return res \ No newline at end of file diff --git a/lm_eval/models/t5.py b/lm_eval/models/t5.py new file mode 100644 index 0000000000..e6d99f870f --- /dev/null +++ b/lm_eval/models/t5.py @@ -0,0 +1,161 @@ +import transformers +import torch +import torch.nn as nn +import torch.nn.functional as F +from lm_eval.base import LM +from lm_eval import utils +from tqdm import tqdm +import numpy as np +import math + +class T5LM(LM): + MAX_GEN_TOKS = 256 + MAX_INP_LENGTH = 512 + VOCAB_SIZE = 32128 + EOT_TOKEN_ID = 1 + + def __init__(self, device='cuda', parallelize=False, pretrained='t5', batch_size=1): + super().__init__() + if device: + self.device = torch.device(device) + else: + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + print(pretrained) + self.t5 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained) + self.t5.eval() + + if parallelize == "True": + print(parallelize) + self.t5.parallelize() + self.device = torch.device('cuda:0') + else: + self.t5.to(self.device) + + self.tokenizer = transformers.T5TokenizerFast.from_pretrained(pretrained) + self.max_length = self.MAX_INP_LENGTH + + self.batch_size = int(batch_size) + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config={}): + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + def loglikelihood(self, requests): + res = [] + for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)): + + inputs, targets = zip(*chunk) + + inputs_tok = self.tokenizer( + list(inputs), + max_length=self.max_length, + padding=True, + # truncation=True, + add_special_tokens=False, + return_tensors="pt" + ).to(self.device) + + for key in inputs_tok: + inputs_tok[key] = inputs_tok[key][:, -(self.max_length - 1) :] + + targets_tok = self.tokenizer( + list(targets), + max_length=self.MAX_GEN_TOKS, + padding=True, + # truncation=True, + add_special_tokens=False, + return_tensors="pt" + ).to(self.device) + + for key in targets_tok: + targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :] + + with torch.no_grad(): + outputs = self.t5( + **inputs_tok, + labels=targets_tok["input_ids"] + ) + + log_softmaxes = F.log_softmax(outputs.logits, dim=-1) + + output_iterator = zip( + chunk, + log_softmaxes, + targets_tok["input_ids"], + targets_tok["attention_mask"], + ) + for cache_key, log_softmax, target_tok, target_mask in output_iterator: + length = target_mask.sum() + log_softmax = log_softmax[:length] + target_tok = target_tok[:length] + greedy_tokens = log_softmax.argmax(dim=-1) + max_equal = (greedy_tokens == target_tok).all() + target_logits = torch.gather( + log_softmax, 1, target_tok.unsqueeze(-1) + ).squeeze(-1) + answer = (float(target_logits.sum()), bool(max_equal)) + + if cache_key is not None: + self.cache_hook.add_partial("loglikelihood", cache_key, answer) + + res.append(answer) + + return res + + def loglikelihood_rolling(self, requests): + raise NotImplementedError + + def _get_stopping_criteria(self, stopping_criteria_ids): + class MultitokenEOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_seq_id: torch.LongTensor, tokenizer): + self.eos_seq = tokenizer.decode(eos_seq_id) + self.eos_seq_id = eos_seq_id + self.eos_seq_len = len(eos_seq_id) + 1 + self.tokenizer = tokenizer + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + last_token_id = input_ids[0, -self.eos_seq_len:] + last_tokens = self.tokenizer.decode(last_token_id) + is_stopped = self.eos_seq in last_tokens + return is_stopped + + class EOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_token_id: torch.LongTensor): + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids[0,-1] == self.eos_token_id + + return transformers.StoppingCriteriaList([ + MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer), + EOSCriteria(self.tokenizer.eos_token) + ]) + + def greedy_until(self, requests): + + res = [] + + for context, until in tqdm(requests): + if isinstance(until, str): until = [until] + + context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids + + stopping_criteria_ids = self.tokenizer.encode(until[0]) + stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) + + cont = self.t5.generate( + context_enc, + max_length=self.MAX_GEN_TOKS, + stopping_criteria=stopping_criteria, + do_sample=False + ) + + s = self.tokenizer.decode(cont[0].tolist()) + + self.cache_hook.add_partial("greedy_until", (context, until), s) + + res.append(s) + + return res \ No newline at end of file From 4f85bcf98d173af452791c128a7b50cfd5718b85 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 26 Apr 2022 18:33:51 -0400 Subject: [PATCH 35/40] Updated doc: If the answer choices is empty, then it is generation; else ranked choice. This will be the canonical approach when using PS. --- lm_eval/base.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index 57f0dcc71a..6ab40a3a6c 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -673,12 +673,6 @@ def max_generation_length(self) -> Optional[int]: """Denote where the max length of the generation if it is obvious from the task.""" return None - def is_generation_task(self): - return ( - "BLEU" in self.prompt.metadata.metrics - or "ROUGE" in self.prompt.metadata.metrics - ) - def invalid_doc_for_prompt(self, doc) -> bool: """Some prompts may not work for some documents.""" if ( @@ -718,18 +712,14 @@ def construct_requests(self, doc, ctx): _requests = [] answer_choices_list = self.prompt.get_answer_choices_list(doc) - # We take a present answer_choices list to mean that we should apply the supplied - # metrics (hardcoded or accuracy atm) to the ranked choices. Otherwise, assume generation. - # Above we do something similar, but rely on the metrics requested (BLEU, ROUGE indicating generation). if answer_choices_list: - assert ( - not self.is_generation_task() - ), f"We expect this to be a ranked choice task; double check please." + # If answer_choices_list, then this is a ranked choice prompt. for answer_choice in answer_choices_list: ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}") _requests.append(ll_answer_choice) else: - # TODO(Albert): What is the stop symbol? Is it model specific? + # If not, then this is a generation prompt. + # NOTE: In the future, target will be a list of strings. cont_request = rf.greedy_until( ctx, [self.stopping_criteria(), self.max_generation_length()] ) @@ -750,9 +740,11 @@ def process_results(self, doc, results): target = self.doc_to_target(doc).strip() answer_choices_list = self.prompt.get_answer_choices_list(doc) if answer_choices_list: - assert ( - not self.is_generation_task() - ), f"We expect this to be a ranked choice task; double check please." + # If answer_choices_list, then this is a ranked choice prompt. + # NOTE: In the future, target will be a list of strings. + # For now, we can assume there will be only 1 target, but its possible + # that this not the case so we should check for that. + pred = answer_choices_list[np.argmax(results)] out = {} @@ -765,7 +757,8 @@ def process_results(self, doc, results): # TODO: Add metrics here. return out else: - # NOTE: In the future, target may be a list, not a string. + # If not, then this is a generation prompt. + # NOTE: In the future, target will be a list of strings. pred = results[0].strip() out = {} From 21d897db56e8bf72bb0a2bdb3ace62cc7cba0bee Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 26 Apr 2022 18:45:57 -0400 Subject: [PATCH 36/40] Updated the requests so that its easier to understand. --- lm_eval/base.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index 6ab40a3a6c..16707d771b 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -345,25 +345,27 @@ def _collate(x): reord = utils.Reorderer(requests, _collate) - for context, until in tqdm(reord.get_reordered()): - if isinstance(until, str): - until = [until] - max_length = None - elif isinstance(until, list) and len(until) == 2: - until, max_length = [until[0]], until[1] - elif isinstance(until, list): - max_length = None + for context, request_args in tqdm(reord.get_reordered()): + stopping_criteria = request_args["stopping_criteria"] + max_generation_length = request_args["max_generation_length"] + assert isinstance(stopping_criteria, str) or stopping_criteria is None + assert ( + isinstance(max_generation_length, int) or max_generation_length is None + ) + + until = [stopping_criteria] primary_until = self.tok_encode(until[0]) context_enc = torch.tensor( [self.tok_encode(context)[self.max_gen_toks - self.max_length :]] ).to(self.device) - if max_length is not None: - max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks) - else: + if max_generation_length is None: max_length = context_enc.shape[1] + self.max_gen_toks - + else: + max_length = min( + max_generation_length, context_enc.shape[1] + self.max_gen_toks + ) cont = self._model_generate( context_enc, max_length, @@ -720,9 +722,11 @@ def construct_requests(self, doc, ctx): else: # If not, then this is a generation prompt. # NOTE: In the future, target will be a list of strings. - cont_request = rf.greedy_until( - ctx, [self.stopping_criteria(), self.max_generation_length()] - ) + request_args = { + "stopping_criteria": self.stopping_criteria(), + "max_generation_length": self.max_generation_length(), + } + cont_request = rf.greedy_until(ctx, request_args) _requests.append(cont_request) return _requests From a3a9a7c2d162e7ec59fa1c6bfc435e6214e5f797 Mon Sep 17 00:00:00 2001 From: Tian Yun Date: Tue, 26 Apr 2022 21:43:01 -0400 Subject: [PATCH 37/40] Added GPT-J --- lm_eval/models/__init__.py | 2 + lm_eval/models/gptj.py | 119 +++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 lm_eval/models/gptj.py diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py index 171db2e9e2..6b31a9e633 100644 --- a/lm_eval/models/__init__.py +++ b/lm_eval/models/__init__.py @@ -1,4 +1,5 @@ from . import gpt2 +from . import gptj from . import gpt3 from . import t5 from . import t0 @@ -7,6 +8,7 @@ MODEL_REGISTRY = { "hf": gpt2.HFLM, "gpt2": gpt2.GPT2LM, + "gptj": gptj.GPTJLM, "gpt3": gpt3.GPT3LM, "t5": t5.T5LM, "mt5": t5.T5LM, diff --git a/lm_eval/models/gptj.py b/lm_eval/models/gptj.py new file mode 100644 index 0000000000..398ae03053 --- /dev/null +++ b/lm_eval/models/gptj.py @@ -0,0 +1,119 @@ +import transformers +import torch +from lm_eval.base import BaseLM + + +class GPTJLM(BaseLM): + def __init__( + self, + device="cuda", + batch_size=1, + ): + super().__init__() + + assert isinstance(device, str) + assert isinstance(batch_size, int) + + if device: + self._device = torch.device(device) + else: + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + + pretrained = "EleutherAI/gpt-j-6B" + self.gptj = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device) + self.gptj.eval() + + # pretrained tokenizer for neo is broken for now so just hard-coding this to gptj + self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained) + self.vocab_size = self.tokenizer.vocab_size + + # multithreading and batching + self.batch_size_per_gpu = batch_size # todo: adaptive batch size + + # TODO: fix multi-gpu + # gpus = torch.cuda.device_count() + # if gpus > 1: + # self.gptj = nn.DataParallel(self.gptj) + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + try: + return self.gptj.config.n_ctx + except AttributeError: + # gptneoconfig doesn't have n_ctx apparently + return self.gptj.config.max_position_embeddings + + @property + def max_gen_toks(self): + return 256 + + @property + def batch_size(self): + # TODO: fix multi-gpu + return self.batch_size_per_gpu # * gpus + + @property + def device(self): + # TODO: fix multi-gpu + return self._device + + def tok_encode(self, string: str): + return self.tokenizer.encode(string, add_special_tokens=False) + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def _model_call(self, inps): + """ + inps: a torch tensor of shape [batch, sequence] + the size of sequence may vary from call to call + + returns: a torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model + """ + with torch.no_grad(): + return self.gptj(inps)[0][:, :, :50257] + + def _get_stopping_criteria(self, stopping_criteria_ids): + class MultitokenEOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_seq_id: torch.LongTensor, tokenizer): + self.eos_seq = tokenizer.decode(eos_seq_id) + self.eos_seq_id = eos_seq_id + self.eos_seq_len = len(eos_seq_id) + 1 + self.tokenizer = tokenizer + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + last_token_id = input_ids[0, -self.eos_seq_len:] + last_tokens = self.tokenizer.decode(last_token_id) + is_stopped = self.eos_seq in last_tokens + return is_stopped + + class EOSCriteria(transformers.StoppingCriteria): + def __init__(self, eos_token_id: torch.LongTensor): + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids[0,-1] == self.eos_token_id + + return transformers.StoppingCriteriaList([ + MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer), + EOSCriteria(self.tokenizer.eos_token) + ]) + + def _model_generate(self, context, max_length, stopping_criteria_ids): + stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) + return self.gptj.generate( + context, + max_length=max_length, + stopping_criteria=stopping_criteria, + do_sample=False, + ) From ff89667f61aee65eec298b73ca7d7253476fb9b4 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Tue, 26 Apr 2022 22:45:03 -0400 Subject: [PATCH 38/40] Add ROUGE metric to `PromptSourceTask` --- lm_eval/base.py | 48 +++++++++++++++++++++++++++++++++++++++------- lm_eval/metrics.py | 9 +++++++++ lm_eval/utils.py | 13 +++++++++++++ 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index 57f0dcc71a..667d77ec86 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -650,8 +650,6 @@ class PromptSourceTask(Task): added default behavior for. If you want to add default behavior for a new metric, update the functions below. If you want to use one of the following metrics, *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`. - - WARNING: ROUGE is WIP. """ CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"]) @@ -768,7 +766,6 @@ def process_results(self, doc, results): # NOTE: In the future, target may be a list, not a string. pred = results[0].strip() out = {} - for metric in self.prompt.metadata.metrics: assert ( metric in self.CONFIGURED_PS_METRICS @@ -776,8 +773,15 @@ def process_results(self, doc, results): if metric == "BLEU": out["bleu"] = (target, pred) if metric == "ROUGE": - print("WARNING: Skipping Rouge.") - + # TODO: This computes all rouge sub-metrics. Find a generic + # way to handle user specified rouge sub-metrics to avoid extra + # compute. + rouge_scores = metrics.rouge(target, pred) + # Flatten rouge score dict. + rouge_scores = utils.flatten(rouge_scores) + # Merge all the rouge-type scores into the `out` dict. + out = {**out, **rouge_scores} + print(out) return out def higher_is_better(self): @@ -791,7 +795,22 @@ def higher_is_better(self): if metric == "BLEU": out["bleu"] = True if metric == "ROUGE": - print("WARNING: Skipping Rouge.") + # TODO: Find a generic way to handle user specified rouge metrics. + out["rouge1_precision"] = True + out["rouge1_recall"] = True + out["rouge1_fmeasure"] = True + + out["rouge2_precision"] = True + out["rouge2_recall"] = True + out["rouge2_fmeasure"] = True + + out["rougeL_precision"] = True + out["rougeL_recall"] = True + out["rougeL_fmeasure"] = True + + out["rougeLsum_precision"] = True + out["rougeLsum_recall"] = True + out["rougeLsum_fmeasure"] = True return out def aggregation(self): @@ -805,7 +824,22 @@ def aggregation(self): if metric == "BLEU": out["bleu"] = metrics.bleu if metric == "ROUGE": - print("WARNING: Skipping Rouge.") + # TODO: Find a generic way to handle user specified rouge metrics. + out["rouge1_precision"] = mean + out["rouge1_recall"] = mean + out["rouge1_fmeasure"] = mean + + out["rouge2_precision"] = mean + out["rouge2_recall"] = mean + out["rouge2_fmeasure"] = mean + + out["rougeL_precision"] = mean + out["rougeL_recall"] = mean + out["rougeL_fmeasure"] = mean + + out["rougeLsum_precision"] = mean + out["rougeLsum_recall"] = mean + out["rougeLsum_fmeasure"] = mean return out diff --git a/lm_eval/metrics.py b/lm_eval/metrics.py index 38f60ab52b..ba91e0c2ee 100644 --- a/lm_eval/metrics.py +++ b/lm_eval/metrics.py @@ -202,6 +202,15 @@ def rouge( :param pred: A single prediction `str`s. """ + + # Add newlines between sentences to correctly compute `rougeLsum`. + if "rougeLsum" in rouge_types: + # TODO: Adapt this to handle languages that do not support sentence endings by `.`. + # See GEM-metrics implementation with lang specific `nltk` tokenizers to + # split sentences. + pred = pred.replace(".", ".\n") + refs = [ref.replace(".", ".\n") for ref in refs] + scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True) # ROUGE multi-ref jackknifing if len(refs) > 1: diff --git a/lm_eval/utils.py b/lm_eval/utils.py index e331283866..b508db96f8 100644 --- a/lm_eval/utils.py +++ b/lm_eval/utils.py @@ -146,6 +146,19 @@ def get_original(self, newarr): return res + +def flatten(d, parent_key='', sep='_'): + # From: https://stackoverflow.com/a/6027615 + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, collections.MutableMapping): + items.extend(flatten(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + def positional_deprecated(fn): """ A decorator to nudge users into passing only keyword args (`kwargs`) to the From 62a706fcf0bfb937390355b342482558f9a01291 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Tue, 26 Apr 2022 23:27:02 -0400 Subject: [PATCH 39/40] Add doc getter methods to template --- templates/new_task.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/templates/new_task.py b/templates/new_task.py index 10ba6eb513..fb3a3c5090 100644 --- a/templates/new_task.py +++ b/templates/new_task.py @@ -37,6 +37,41 @@ def has_test_docs(self): # TODO: Fill in the return with `True` if the Task has test data; else `False`. return False + def training_docs(self): + if self.has_training_docs(): + # We cache training documents in `self._training_docs` for faster + # few-shot processing. If the data is too large to fit in memory, + # return the training data as a generator instead of a list. + if self._training_docs is None: + # TODO: Return the training document generator from `self.dataset`. + # If you need to process the data, `map` over the documents with + # the custom procesing function, `self._process_doc`. E.g. + # `map(self._process_doc, self.dataset["validation"])` + # In most case you can leave this as is unless the dataset split is + # named differently than the default `"train"`. + self._training_docs = list(self.dataset["train"]) + return self._training_docs + + def validation_docs(self): + if self.has_validation_docs(): + # TODO: Return the validation document generator from `self.dataset`. + # If you need to process the data, `map` over the documents with the + # custom procesing function, `self._process_doc`. E.g. + # `map(self._process_doc, self.dataset["validation"])` + # In most case you can leave this as is unless the dataset split is + # named differently than the default `"validation"`. + return self.dataset["validation"] + + def test_docs(self): + if self.has_test_docs(): + # TODO: Return the test document generator from `self.dataset`. + # If you need to process the data, `map` over the documents with the + # custom processing function, `self._process_doc`. E.g. + # `map(self._process_doc, self.dataset["test"])` + # In most case you can leave this as is unless the dataset split is + # named differently than the default `"test"`. + return self.dataset["test"] + def stopping_criteria(self): # TODO: Denote the string where the generation should be split. # For example, for `coqa`, this is '\nQ:' and for `drop` '.'. From 60a07f106bcfb8c10976661ad12c715ee59f4511 Mon Sep 17 00:00:00 2001 From: Pawan Sasanka Ammanamanchi Date: Wed, 27 Apr 2022 11:29:26 +0530 Subject: [PATCH 40/40] Use eval-hackathon branch for installing prompt-source --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 94fb89cda2..c33c62b1c8 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ ], python_requires=">=3.6", install_requires=[ - "promptsource", + "promptsource @ git+https://github.com/bigscience-workshop/promptsource@eval-hackathon", "wrapt", "nltk", "jinja2",