diff --git a/lm_eval/tasks/opengptx/all_tasks_registry.py b/lm_eval/tasks/opengptx/all_tasks_registry.py
index b05b47e5ec..d1289199a6 100644
--- a/lm_eval/tasks/opengptx/all_tasks_registry.py
+++ b/lm_eval/tasks/opengptx/all_tasks_registry.py
@@ -1,4 +1,5 @@
 # OpenGPT-X tasks
+from . import flores200
 from . import german_europarl_ppl
 from . import german_ler_ppl
 from . import germanquad
@@ -17,6 +18,9 @@
 from . import xquad
 from . import xnli
 
+########################################
+# Translation tasks
+########################################
 
 TASK_REGISTRY_TMP = {
     # OpenGPT-X tasks
@@ -40,6 +44,8 @@
     "xstance_fr": x_stance.XStanceFR,
     **xquad.construct_tasks(),
     **xnli.construct_tasks(),
+    **flores200.construct_lang_tasks(),
+    **flores200.construct_trans_tasks(),
 }
 
 # add a prefix to tasks implemented by OpenGPT-X
diff --git a/lm_eval/tasks/opengptx/flores200.py b/lm_eval/tasks/opengptx/flores200.py
new file mode 100644
index 0000000000..cc09e02779
--- /dev/null
+++ b/lm_eval/tasks/opengptx/flores200.py
@@ -0,0 +1,425 @@
+"""
+NOTE: This file implements FLORES-200 perplexity and translation tasks, see
+https://github.com/facebookresearch/flores/tree/main/flores200.
+"""
+from itertools import permutations
+from lm_eval import metrics
+from lm_eval.base import Task, rf
+from lm_eval.tasks.translation import code_to_language
+
+
+class FloresBase(Task):
+    VERSION = 0
+    DATASET_PATH = "facebook/flores"
+    DATASET_NAME = None
+
+    def __init__(self, lang: str = None):
+        self.DATASET_NAME = self.language = lang
+        super().__init__()
+
+    def has_validation_docs(self):
+        return True
+
+    def has_test_docs(self):
+        return True
+
+    def has_training_docs(self):
+        return False
+
+    def test_docs(self):
+        return self.dataset["devtest"]
+
+    def validation_docs(self):
+        return self.dataset["dev"]
+
+    def doc_to_text(self, doc):
+        return doc["sentence"]
+
+    def doc_to_target(self, doc):
+        # The full sentence is scored via loglikelihood_rolling; no separate target.
+        return None
+
+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["sentence"]
+
+    def construct_requests(self, doc, ctx):
+        """Uses RequestFactory to construct Requests and returns an iterable of
+        Requests which will be sent to the LM.
+
+        :param doc:
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param ctx: str
+            The context string, generated by fewshot_context. This includes the natural
+            language description, as well as the few shot examples, and the question
+            part of the document for `doc`.
+ """ + return [rf.loglikelihood_rolling(ctx)] + + def process_results(self, doc, results): + ll = results[0] + return { + "ppl": ll, + } + + def aggregation(self): + """ + :returns: {str: [float] -> float} + A dictionary where keys are the names of submetrics and values are + functions that aggregate a list of metrics + """ + return { + "ppl": metrics.perplexity, + } + + def higher_is_better(self): + """ + :returns: {str: bool} + A dictionary where keys are the names of submetrics and values are + whether a higher value of the submetric is better + """ + return { + "ppl": False, + } + + def __str__(self): + return f"Flores200 Perplexity Task for {self.language}" + + +def create_ppl_task(lang, version=0): + class PPLTask(FloresBase): + VERSION = version + + def __init__(self): + super().__init__(lang) + + return PPLTask + + +def construct_lang_tasks(): + return {f"flores200-lang-{lang}": create_ppl_task(lang) for lang in _LANGUAGES} + + +class FloresTranslationTask(Task): + DATASET_PATH = "facebook/flores" + + def __init__(self, language_pair: str = None): + self.DATASET_NAME = self.language_pair = language_pair + super().__init__() + self.src_code, self.tgt_code = language_pair.split("-") + self.src_lang = code_to_language(self.src_code[:3]) + self.tgt_lang = code_to_language(self.tgt_code[:3]) + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return True + + def has_training_docs(self): + return False + + def test_docs(self): + return self.dataset["devtest"] + + def validation_docs(self): + return self.dataset["dev"] + + def doc_to_text(self, doc): + return ( + f"{self.src_lang} phrase: " + + doc[f"sentence_{self.src_code}"] + + f"\n{self.tgt_lang} phrase:" + ) + + def should_decontaminate(self): + return True + + def doc_to_decontamination_query(self, doc): + return doc[f"sentence_{self.src_code}"] + + def doc_to_target(self, doc): + return doc[f"sentence_{self.tgt_code}"] + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + return [ + rf.greedy_until(ctx, {"until": ["\n"]}), + ] + + def process_results(self, doc, results): + # These metrics are corpus-level not sentence level, so we'll hide the + # translation results in this dict and compute the corpus score in the + # aggregate method. 
+
+        pred = results
+        ref_pred = (self.doc_to_target(doc), pred)
+        return {
+            "bleu": ref_pred,
+            "chrf": ref_pred,
+            "ter": ref_pred,
+        }
+
+    def aggregation(self):
+        """
+        :returns: {str: [float] -> float}
+            A dictionary where keys are the names of submetrics and values are
+            functions that aggregate a list of metrics
+        """
+        return {
+            "bleu": metrics.bleu,
+            "chrf": metrics.chrf,
+            "ter": metrics.ter,
+        }
+
+    def higher_is_better(self):
+        """
+        :returns: {str: bool}
+            A dictionary where keys are the names of submetrics and values are
+            whether a higher value of the submetric is better
+        """
+        return {
+            "bleu": True,
+            "chrf": True,
+            "ter": False,
+        }
+
+    def __str__(self):
+        return f"Flores200 Translation Task from {self.src_lang} to {self.tgt_lang}"
+
+
+def create_translation_task(language_pair, version=0):
+    class TranslationTask(FloresTranslationTask):
+        VERSION = version
+
+        def __init__(self):
+            super().__init__(language_pair)
+
+    return TranslationTask
+
+
+def construct_trans_tasks():
+    """Create flores200 translation tasks for every ordered pair of languages in _LANGUAGES."""
+    return {
+        f"flores200-trans-{src}-{tgt}": create_translation_task(f"{src}-{tgt}")
+        for src, tgt in permutations(_LANGUAGES, 2)
+    }
+
+
+_LANGUAGES = [
+    # "ace_Arab",
+    # "ace_Latn",
+    # "acm_Arab",
+    # "acq_Arab",
+    # "aeb_Arab",
+    # "afr_Latn",
+    # "ajp_Arab",
+    # "aka_Latn",
+    # "als_Latn",
+    # "amh_Ethi",
+    # "apc_Arab",
+    # "arb_Arab",
+    # "arb_Latn",
+    # "ars_Arab",
+    # "ary_Arab",
+    # "arz_Arab",
+    # "asm_Beng",
+    # "ast_Latn",
+    # "awa_Deva",
+    # "ayr_Latn",
+    # "azb_Arab",
+    # "azj_Latn",
+    # "bak_Cyrl",
+    # "bam_Latn",
+    # "ban_Latn",
+    # "bel_Cyrl",
+    # "bem_Latn",
+    # "ben_Beng",
+    # "bho_Deva",
+    # "bjn_Arab",
+    # "bjn_Latn",
+    # "bod_Tibt",
+    # "bos_Latn",
+    # "bug_Latn",
+    # "bul_Cyrl",
+    # "cat_Latn",
+    # "ceb_Latn",
+    # "ces_Latn",
+    # "cjk_Latn",
+    # "ckb_Arab",
+    # "crh_Latn",
+    # "cym_Latn",
+    # "dan_Latn",
+    "deu_Latn",
+    # "dik_Latn",
+    # "dyu_Latn",
+    # "dzo_Tibt",
+    # "ell_Grek",
+    "eng_Latn",
+    # "epo_Latn",
+    # "est_Latn",
+    # "eus_Latn",
+    # "ewe_Latn",
+    # "fao_Latn",
+    # "fij_Latn",
+    # "fin_Latn",
+    # "fon_Latn",
+    "fra_Latn",
+    # "fur_Latn",
+    # "fuv_Latn",
+    # "gaz_Latn",
+    # "gla_Latn",
+    # "gle_Latn",
+    # "glg_Latn",
+    # "grn_Latn",
+    # "guj_Gujr",
+    # "hat_Latn",
+    # "hau_Latn",
+    # "heb_Hebr",
+    # "hin_Deva",
+    # "hne_Deva",
+    # "hrv_Latn",
+    # "hun_Latn",
+    # "hye_Armn",
+    # "ibo_Latn",
+    # "ilo_Latn",
+    # "ind_Latn",
+    # "isl_Latn",
+    "ita_Latn",
+    # "jav_Latn",
+    # "jpn_Jpan",
+    # "kab_Latn",
+    # "kac_Latn",
+    # "kam_Latn",
+    # "kan_Knda",
+    # "kas_Arab",
+    # "kas_Deva",
+    # "kat_Geor",
+    # "kaz_Cyrl",
+    # "kbp_Latn",
+    # "kea_Latn",
+    # "khk_Cyrl",
+    # "khm_Khmr",
+    # "kik_Latn",
+    # "kin_Latn",
+    # "kir_Cyrl",
+    # "kmb_Latn",
+    # "kmr_Latn",
+    # "knc_Arab",
+    # "knc_Latn",
+    # "kon_Latn",
+    # "kor_Hang",
+    # "lao_Laoo",
+    # "lij_Latn",
+    # "lim_Latn",
+    # "lin_Latn",
+    # "lit_Latn",
+    # "lmo_Latn",
+    # "ltg_Latn",
+    # "ltz_Latn",
+    # "lua_Latn",
+    # "lug_Latn",
+    # "luo_Latn",
+    # "lus_Latn",
+    # "lvs_Latn",
+    # "mag_Deva",
+    # "mai_Deva",
+    # "mal_Mlym",
+    # "mar_Deva",
+    # "min_Arab",
+    # "min_Latn",
+    # "mkd_Cyrl",
+    # "mlt_Latn",
+    # "mni_Beng",
+    # "mos_Latn",
+    # "mri_Latn",
+    # "mya_Mymr",
+    # "nld_Latn",
+    # "nno_Latn",
+    # "nob_Latn",
+    # "npi_Deva",
+    # "nso_Latn",
+    # "nus_Latn",
+    # "nya_Latn",
+    # "oci_Latn",
+    # "ory_Orya",
+    # "pag_Latn",
+    # "pan_Guru",
+    # "pap_Latn",
+    # "pbt_Arab",
+    # "pes_Arab",
+    # "plt_Latn",
+    # "pol_Latn",
+    # "por_Latn",
"prs_Arab", + # "quy_Latn", + # "ron_Latn", + # "run_Latn", + # "rus_Cyrl", + # "sag_Latn", + # "san_Deva", + # "sat_Olck", + # "scn_Latn", + # "shn_Mymr", + # "sin_Sinh", + # "slk_Latn", + # "slv_Latn", + # "smo_Latn", + # "sna_Latn", + # "snd_Arab", + # "som_Latn", + # "sot_Latn", + "spa_Latn", + # "srd_Latn", + # "srp_Cyrl", + # "ssw_Latn", + # "sun_Latn", + # "swe_Latn", + # "swh_Latn", + # "szl_Latn", + # "tam_Taml", + # "taq_Latn", + # "taq_Tfng", + # "tat_Cyrl", + # "tel_Telu", + # "tgk_Cyrl", + # "tgl_Latn", + # "tha_Thai", + # "tir_Ethi", + # "tpi_Latn", + # "tsn_Latn", + # "tso_Latn", + # "tuk_Latn", + # "tum_Latn" + # "tur_Latn", + # "twi_Latn", + # "tzm_Tfng", + # "uig_Arab", + # "ukr_Cyrl", + # "umb_Latn", + # "urd_Arab", + # "uzn_Latn", + # "vec_Latn", + # "vie_Latn", + # "war_Latn", + # "wol_Latn", + # "xho_Latn", + # "ydd_Hebr", + # "yor_Latn", + # "yue_Hant", + # "zho_Hans", + # "zho_Hant", + # "zsm_Latn", + # "zul_Latn", +]