Commit 5499919
Merge pull request #2 from cjlovering/master
Pulling eval harness updates
StellaAthena authored Apr 27, 2022
2 parents 6caa0af + 18af502 commit 5499919
Showing 22 changed files with 1,667 additions and 1,019 deletions.
470 changes: 369 additions & 101 deletions lm_eval/base.py

Large diffs are not rendered by default.

181 changes: 128 additions & 53 deletions lm_eval/evaluator.py
@@ -2,25 +2,38 @@
import itertools
import pathlib
import random

import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import promptsource
import numpy as np

from promptsource.templates import DatasetTemplates
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None, check_integrity=False):
def simple_evaluate(
model,
model_args=None,
tasks=[],
num_fewshot=0,
batch_size=None,
device=None,
no_cache=False,
limit=None,
bootstrap_iters=100000,
description_dict=None,
check_integrity=False,
):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model
:param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string.
String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object.
:param tasks: list[Union[str, Task]]
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
@@ -37,7 +50,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:return
@@ -49,20 +62,28 @@ def simple_evaluate(model, model_args=None, tasks=[],
assert tasks != [], "No tasks specified"

if isinstance(model, str):
if model_args is None: model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
'batch_size': batch_size, 'device': device
})
if model_args is None:
model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device}
)
else:
assert isinstance(model, lm_eval.base.LM)
lm = model

# TODO: Hard-code turning off cache while testing. Remove once testing is completed.
no_cache = True
if not no_cache:
lm = lm_eval.base.CachingLM(
lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
lm,
"lm_cache/"
+ model
+ "_"
+ model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+ ".db",
)
task_dict = lm_eval.tasks.get_task_dict(tasks)

task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)

if check_integrity:
run_task_tests(task_list=tasks)
@@ -72,7 +93,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict=task_dict,
num_fewshot=num_fewshot,
limit=limit,
description_dict=description_dict
description_dict=description_dict,
)

# add info about the model and few shot config
@@ -85,14 +106,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
"no_cache": no_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"description_dict": description_dict
"description_dict": description_dict,
}

return results
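
# Editor's note: a minimal usage sketch of the updated simple_evaluate entry
# point, not part of the diff. The model name, model arguments, and task name
# below are illustrative assumptions; the real registered names come from
# lm_eval.models.get_model and the promptsource-backed task registry.
#
#     import lm_eval.evaluator
#
#     results = lm_eval.evaluator.simple_evaluate(
#         model="gpt2",                  # resolved via lm_eval.models.get_model
#         model_args="pretrained=gpt2",  # parsed by LM.create_from_arg_string
#         tasks=["anli"],                # task names or Task objects (assumed name)
#         num_fewshot=0,
#         batch_size=8,
#         device="cuda",
#     )
#     print(results["results"])   # per-(task, prompt) metrics
#     print(results["versions"])  # task versions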


@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None):
def evaluate(
lm,
task_dict,
provide_description=None,
num_fewshot=0,
limit=None,
bootstrap_iters=100000,
description_dict=None,
):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
@@ -108,7 +137,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
Dictionary of custom task descriptions of the form: `task_name: description`
:return
Dictionary of results
"""
@@ -118,12 +147,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
assert not provide_description # not implemented.
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)

task_dict_items = [
(name, task)
for name, task in task_dict.items()
if(task.has_validation_docs() or task.has_test_docs())
if (task.has_validation_docs() or task.has_test_docs())
]

results = collections.defaultdict(dict)
@@ -141,8 +172,12 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
docs = {}

# get lists of each type of request
for task_name, task in task_dict_items:
versions[task_name] = task.VERSION
for task_prompt_name, task in task_dict_items:
# if task.is_generation_task():
# print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
# continue

versions[task_prompt_name] = task.VERSION
# default to test doc, fall back to val doc if validation unavailable
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
if task.has_test_docs():
@@ -158,15 +193,19 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
rnd.seed(42)
rnd.shuffle(task_docs)

description = description_dict[task_name] if description_dict and task_name in description_dict else ""
description = (
description_dict[task_prompt_name]
if description_dict and task_prompt_name in description_dict
else ""
)

for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
docs[(task_name, doc_id)] = doc
if task.invalid_doc_for_prompt(doc):
continue

docs[(task_prompt_name, doc_id)] = doc
ctx = task.fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)):
@@ -175,7 +214,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
requests[req.request_type].append(req)
# i: index in requests for a single task instance
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append((i, task_name, doc, doc_id))
requests_origin[req.request_type].append(
(i, task_prompt_name, doc, doc_id)
)

# all responses for each (task, doc)
process_res_queue = collections.defaultdict(list)
@@ -189,43 +230,49 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,

print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
resps = [
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]

for resp, (i, task_prompt_name, doc, doc_id) in zip(
resps, requests_origin[reqtype]
):
process_res_queue[(task_prompt_name, doc_id)].append((i, resp))

for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))

vals = collections.defaultdict(list)

# unpack results and sort back in order and return control to Task
for (task_name, doc_id), requests in process_res_queue.items():
for (task_prompt_name, doc_id), requests in process_res_queue.items():
requests.sort(key=lambda x: x[0])
requests = [x[1] for x in requests]

task = task_dict[task_name]
doc = docs[(task_name, doc_id)]
task = task_dict[task_prompt_name]
doc = docs[(task_prompt_name, doc_id)]

metrics = task.process_results(doc, requests)
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
vals[(task_prompt_name, metric)].append(value)

# aggregate results
for (task_name, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric] = task.aggregation()[metric](items)
for (task_prompt_name, metric), items in vals.items():
task_name, prompt_name = task_prompt_name.split("+")
results[task_prompt_name]["task_name"] = task_name
results[task_prompt_name]["prompt_name"] = prompt_name
task = task_dict[task_prompt_name]
results[task_prompt_name][metric] = task.aggregation()[metric](items)

# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(
metric=task.aggregation()[metric],
bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
bootstrap_iters=min(bootstrap_iters, 1000)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)

return {
"results": dict(results),
"versions": dict(versions)
}
results[task_prompt_name][metric + "_stderr"] = stderr(items)

return {"results": dict(results), "versions": dict(versions)}
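
# Editor's note: because the aggregation loop above keys everything by the
# combined "task+prompt" string and then splits it into task_name and
# prompt_name, the returned dictionary has roughly the shape sketched below.
# The task name, prompt name, metric, and numbers are placeholders, not real
# output from this branch.
#
#     example = {
#         "results": {
#             "anli+GPT-3 style": {
#                 "task_name": "anli",
#                 "prompt_name": "GPT-3 style",
#                 "acc": 0.3340,
#                 "acc_stderr": 0.0047,
#             },
#         },
#         "versions": {
#             "anli+GPT-3 style": 0,
#         },
#     }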


def make_table(result_dict):
@@ -234,22 +281,50 @@ def make_table(result_dict):

md_writer = MarkdownTableWriter()
latex_writer = LatexTableWriter()
md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
md_writer.headers = ["Task", "Prompt", "Version", "Metric", "Value", "", "Stderr"]
latex_writer.headers = [
"Task",
"Prompt",
"Version",
"Metric",
"Value",
"",
"Stderr",
]

values = []

for k, dic in result_dict["results"].items():
version = result_dict["versions"][k]
for m, v in dic.items():
if m.endswith("_stderr"):
continue

if "_name" in m:
continue
if m + "_stderr" in dic:
se = dic[m + "_stderr"]
values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
values.append(
[
dic["task_name"],
dic["prompt_name"],
version,
m,
"%.4f" % v,
"±",
"%.4f" % se,
]
)
else:
values.append([k, version, m, '%.4f' % v, '', ''])
values.append(
[
dic["task_name"],
dic["prompt_name"],
version,
m,
"%.4f" % v,
"",
"",
]
)
k = ""
version = ""
md_writer.value_matrix = values
[Diff truncated: the remainder of make_table and the other 20 changed files are not shown here.]
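
# Editor's note: a plausible end-to-end use of this module, given the
# make_table headers shown above. This is a sketch only; it assumes
# make_table returns the rendered table as a string, and the model and task
# names are illustrative, not taken from the diff.
#
#     from lm_eval.evaluator import make_table, simple_evaluate
#
#     results = simple_evaluate(model="gpt2", model_args="pretrained=gpt2", tasks=["anli"])
#     print(make_table(results))  # Task / Prompt / Version / Metric / Value / Stderr columns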
