Skip to content

Commit

Permalink
Merge pull request #3 from cjlovering/add-rouge
Browse files Browse the repository at this point in the history
Add `ROUGE` metric to `PromptSourceTask`
  • Loading branch information
cjlovering authored Apr 27, 2022
2 parents d40a7ce + ff89667 commit 448bcbf
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 7 deletions.
48 changes: 41 additions & 7 deletions lm_eval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,6 @@ class PromptSourceTask(Task):
added default behavior for. If you want to add default behavior for a new metric,
update the functions below. If you want to use one of the following metrics,
*and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
WARNING: ROUGE is WIP.
"""

CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
Expand Down Expand Up @@ -765,16 +763,22 @@ def process_results(self, doc, results):
# NOTE: In the future, target will be a list of strings.
pred = results[0].strip()
out = {}

for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "BLEU":
out["bleu"] = (target, pred)
if metric == "ROUGE":
print("WARNING: Skipping Rouge.")

# TODO: This computes all rouge sub-metrics. Find a generic
# way to handle user specified rouge sub-metrics to avoid extra
# compute.
rouge_scores = metrics.rouge(target, pred)
# Flatten rouge score dict.
rouge_scores = utils.flatten(rouge_scores)
# Merge all the rouge-type scores into the `out` dict.
out = {**out, **rouge_scores}
print(out)
return out

def higher_is_better(self):
Expand All @@ -788,7 +792,22 @@ def higher_is_better(self):
if metric == "BLEU":
out["bleu"] = True
if metric == "ROUGE":
print("WARNING: Skipping Rouge.")
# TODO: Find a generic way to handle user specified rouge metrics.
out["rouge1_precision"] = True
out["rouge1_recall"] = True
out["rouge1_fmeasure"] = True

out["rouge2_precision"] = True
out["rouge2_recall"] = True
out["rouge2_fmeasure"] = True

out["rougeL_precision"] = True
out["rougeL_recall"] = True
out["rougeL_fmeasure"] = True

out["rougeLsum_precision"] = True
out["rougeLsum_recall"] = True
out["rougeLsum_fmeasure"] = True
return out

def aggregation(self):
Expand All @@ -802,7 +821,22 @@ def aggregation(self):
if metric == "BLEU":
out["bleu"] = metrics.bleu
if metric == "ROUGE":
print("WARNING: Skipping Rouge.")
# TODO: Find a generic way to handle user specified rouge metrics.
out["rouge1_precision"] = mean
out["rouge1_recall"] = mean
out["rouge1_fmeasure"] = mean

out["rouge2_precision"] = mean
out["rouge2_recall"] = mean
out["rouge2_fmeasure"] = mean

out["rougeL_precision"] = mean
out["rougeL_recall"] = mean
out["rougeL_fmeasure"] = mean

out["rougeLsum_precision"] = mean
out["rougeLsum_recall"] = mean
out["rougeLsum_fmeasure"] = mean
return out


Expand Down
9 changes: 9 additions & 0 deletions lm_eval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,15 @@ def rouge(
:param pred:
A single prediction `str`s.
"""

# Add newlines between sentences to correctly compute `rougeLsum`.
if "rougeLsum" in rouge_types:
# TODO: Adapt this to handle languages that do not support sentence endings by `.`.
# See GEM-metrics implementation with lang specific `nltk` tokenizers to
# split sentences.
pred = pred.replace(".", ".\n")
refs = [ref.replace(".", ".\n") for ref in refs]

scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True)
# ROUGE multi-ref jackknifing
if len(refs) > 1:
Expand Down
13 changes: 13 additions & 0 deletions lm_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,19 @@ def get_original(self, newarr):

return res


def flatten(d, parent_key='', sep='_'):
# From: https://stackoverflow.com/a/6027615
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)


def positional_deprecated(fn):
"""
A decorator to nudge users into passing only keyword args (`kwargs`) to the
Expand Down

0 comments on commit 448bcbf

Please sign in to comment.