From 5cdc40ef6e4f6409612e4f12a954dbfd7b15c3fd Mon Sep 17 00:00:00 2001 From: Yihong Wang Date: Fri, 22 Nov 2024 10:51:27 -0800 Subject: [PATCH] use local unitxt if it exits for evaluation detect local installed unitxt before fetching `unitxt/metric` from HF Signed-off-by: Yihong Wang --- lm_eval/tasks/unitxt/task.py | 40 ++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/lm_eval/tasks/unitxt/task.py b/lm_eval/tasks/unitxt/task.py index 339a3076c5..3385980da4 100644 --- a/lm_eval/tasks/unitxt/task.py +++ b/lm_eval/tasks/unitxt/task.py @@ -4,10 +4,11 @@ Addressing this need, we present Unitxt, an innovative library for customizable textual data preparation and evaluation tailored to generative language models. Unitxt natively integrates with common libraries like HuggingFace and LM-eval-harness and deconstructs processing flows into modular components, enabling easy customization and sharing between practitioners. These components encompass model-specific formats, task prompts, and many other comprehensive dataset processing definitions. The Unitxt-Catalog centralizes these components, fostering collaboration and exploration in modern textual data workflows. Beyond being a tool, Unitxt is a community-driven platform, empowering users to build, share, and advance their pipelines collaboratively. """ +import importlib.util from functools import partial -from typing import Optional +from typing import Any, Dict, Optional -import evaluate +import datasets from lm_eval.api.instance import Instance from lm_eval.api.task import ConfigurableTask @@ -24,13 +25,23 @@ } """ +def is_unitxt_installed() -> bool: + return importlib.util.find_spec("unitxt") is not None def score(items, metric): predictions, references = zip(*items) - evaluator = evaluate.load("unitxt/metric") - for reference in references: - reference["metrics"] = [metric] - results = evaluator.compute(predictions=predictions, references=references) + if is_unitxt_installed(): + from unitxt import evaluate + for reference in references: + reference["metrics"] = [metric] + results = evaluate(predictions,references) + else: + import evaluate + evaluator = evaluate.load("unitxt/metric") + for reference in references: + reference["metrics"] = [metric] + results = evaluator.compute(predictions=predictions, references=references) + return results[0]["score"]["global"]["score"] @@ -41,17 +52,30 @@ def __init__( self, config: Optional[dict] = None, ) -> None: + if config is None: + config = {} assert "recipe" in config, "Unitxt task must have a 'recipe' string." super().__init__( config={ "metadata": {"version": self.VERSION}, - "dataset_kwargs": {"trust_remote_code": True}, "dataset_name": config["recipe"], - "dataset_path": "unitxt/data", } ) + self.image_decoder = datasets.Image() self.metrics = self.dataset["test"][0]["metrics"] + def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None: + if is_unitxt_installed(): + from unitxt import load_dataset + + self.dataset = load_dataset(self.DATASET_NAME) + else: + self.dataset = datasets.load_dataset( + name=self.DATASET_NAME, + path="unitxt/data", + trust_remote_code=True, + ) + def has_training_docs(self): return "train" in self.dataset