diff --git a/lm_eval/base.py b/lm_eval/base.py
index 16707d771b..63eb38efc8 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -652,8 +652,6 @@ class PromptSourceTask(Task):
     added default behavior for. If you want to add default behavior for a new metric,
     update the functions below. If you want to use one of the following metrics,
     *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
-
-    WARNING: ROUGE is WIP.
     """
 
     CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
@@ -765,7 +763,6 @@ def process_results(self, doc, results):
         # NOTE: In the future, target will be a list of strings.
         pred = results[0].strip()
         out = {}
-
         for metric in self.prompt.metadata.metrics:
             assert (
                 metric in self.CONFIGURED_PS_METRICS
@@ -773,8 +770,15 @@
             if metric == "BLEU":
                 out["bleu"] = (target, pred)
             if metric == "ROUGE":
-                print("WARNING: Skipping Rouge.")
-
+                # TODO: This computes all rouge sub-metrics. Find a generic
+                # way to handle user specified rouge sub-metrics to avoid extra
+                # compute.
+                rouge_scores = metrics.rouge(target, pred)
+                # Flatten rouge score dict.
+                rouge_scores = utils.flatten(rouge_scores)
+                # Merge all the rouge-type scores into the `out` dict.
+                out = {**out, **rouge_scores}
+                print(out)
         return out
 
     def higher_is_better(self):
@@ -788,7 +792,22 @@
             if metric == "BLEU":
                 out["bleu"] = True
             if metric == "ROUGE":
-                print("WARNING: Skipping Rouge.")
+                # TODO: Find a generic way to handle user specified rouge metrics.
+                out["rouge1_precision"] = True
+                out["rouge1_recall"] = True
+                out["rouge1_fmeasure"] = True
+
+                out["rouge2_precision"] = True
+                out["rouge2_recall"] = True
+                out["rouge2_fmeasure"] = True
+
+                out["rougeL_precision"] = True
+                out["rougeL_recall"] = True
+                out["rougeL_fmeasure"] = True
+
+                out["rougeLsum_precision"] = True
+                out["rougeLsum_recall"] = True
+                out["rougeLsum_fmeasure"] = True
         return out
 
     def aggregation(self):
@@ -802,7 +821,22 @@
             if metric == "BLEU":
                 out["bleu"] = metrics.bleu
             if metric == "ROUGE":
-                print("WARNING: Skipping Rouge.")
+                # TODO: Find a generic way to handle user specified rouge metrics.
+                out["rouge1_precision"] = mean
+                out["rouge1_recall"] = mean
+                out["rouge1_fmeasure"] = mean
+
+                out["rouge2_precision"] = mean
+                out["rouge2_recall"] = mean
+                out["rouge2_fmeasure"] = mean
+
+                out["rougeL_precision"] = mean
+                out["rougeL_recall"] = mean
+                out["rougeL_fmeasure"] = mean
+
+                out["rougeLsum_precision"] = mean
+                out["rougeLsum_recall"] = mean
+                out["rougeLsum_fmeasure"] = mean
         return out
 
 
diff --git a/lm_eval/metrics.py b/lm_eval/metrics.py
index 38f60ab52b..ba91e0c2ee 100644
--- a/lm_eval/metrics.py
+++ b/lm_eval/metrics.py
@@ -202,6 +202,15 @@ def rouge(
 
     :param pred: A single prediction `str`s.
     """
+
+    # Add newlines between sentences to correctly compute `rougeLsum`.
+    if "rougeLsum" in rouge_types:
+        # TODO: Adapt this to handle languages that do not support sentence endings by `.`.
+        # See GEM-metrics implementation with lang specific `nltk` tokenizers to
+        # split sentences.
+        pred = pred.replace(".", ".\n")
+        refs = [ref.replace(".", ".\n") for ref in refs]
+
     scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True)
     # ROUGE multi-ref jackknifing
     if len(refs) > 1:
diff --git a/lm_eval/utils.py b/lm_eval/utils.py
index e331283866..b508db96f8 100644
--- a/lm_eval/utils.py
+++ b/lm_eval/utils.py
@@ -146,6 +146,19 @@ def get_original(self, newarr):
         return res
 
 
+
+def flatten(d, parent_key='', sep='_'):
+    # From: https://stackoverflow.com/a/6027615
+    items = []
+    for k, v in d.items():
+        new_key = parent_key + sep + k if parent_key else k
+        if isinstance(v, collections.MutableMapping):
+            items.extend(flatten(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
 def positional_deprecated(fn):
     """
     A decorator to nudge users into passing only keyword args (`kwargs`) to the
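To make the flattening step concrete, here is a minimal, self-contained sketch of how the flattened ROUGE scores end up in the `out` dict returned by `process_results`. The nested shape assumed for `metrics.rouge(target, pred)` (one precision/recall/fmeasure sub-dict per ROUGE type) and the sample numbers are illustrative, not taken from the patch; the helper mirrors `utils.flatten` above but spells the ABC as `collections.abc.MutableMapping`, the non-deprecated import path (the bare `collections.MutableMapping` alias was removed in Python 3.10).

```python
import collections.abc


def flatten(d, parent_key="", sep="_"):
    # Mirrors the `utils.flatten` helper added in the patch, using
    # `collections.abc.MutableMapping` instead of the removed alias.
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, collections.abc.MutableMapping):
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


# Hypothetical shape of `metrics.rouge(target, pred)` output: one
# precision/recall/fmeasure sub-dict per ROUGE type (values are made up).
rouge_scores = {
    "rouge1": {"precision": 0.50, "recall": 0.40, "fmeasure": 0.44},
    "rougeLsum": {"precision": 0.30, "recall": 0.25, "fmeasure": 0.27},
}

# Merge the flattened per-type scores into `out`, as `process_results` does.
out = {"bleu": ("a reference", "a prediction")}
out = {**out, **flatten(rouge_scores)}
print(sorted(out))
# ['bleu', 'rouge1_fmeasure', 'rouge1_precision', 'rouge1_recall',
#  'rougeLsum_fmeasure', 'rougeLsum_precision', 'rougeLsum_recall']
```

The flattened keys are exactly the names registered in `higher_is_better` and `aggregation` above, which is why the three methods stay in sync.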
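The `lm_eval/metrics.py` change exists because `rouge_scorer` computes `rougeLsum` over newline-separated sentences, so sentence boundaries have to be marked with `"\n"` before scoring. A rough sketch of the effect, assuming the `rouge_score` package is installed and using made-up example strings:

```python
from rouge_score import rouge_scorer  # pip install rouge-score

ref = "The cat sat on the mat. It fell asleep."
pred = "A cat sat on a mat. It slept."

scorer = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True)

# Without newlines, rougeLsum treats each text as a single "sentence".
unsplit = scorer.score(ref, pred)["rougeLsum"]

# With the patch's "." -> ".\n" replacement, each sentence is scored
# separately, which is what summary-level ROUGE-L (rougeLsum) is defined over.
split = scorer.score(ref.replace(".", ".\n"), pred.replace(".", ".\n"))["rougeLsum"]

print(unsplit.fmeasure, split.fmeasure)
```

As the TODO in the patch notes, splitting on `.` only covers text whose sentences end with a period; a language-aware splitter such as the `nltk` sentence tokenizers used in GEM-metrics is the suggested follow-up.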