use local unitxt if it exists for evaluation
detect locally installed unitxt before fetching
`unitxt/metric` from HF

Signed-off-by: Yihong Wang <[email protected]>
yhwang committed Nov 22, 2024
1 parent dcce372 commit 5cdc40e
40 changes: 32 additions & 8 deletions lm_eval/tasks/unitxt/task.py
@@ -4,10 +4,11 @@
 Addressing this need, we present Unitxt, an innovative library for customizable textual data preparation and evaluation tailored to generative language models. Unitxt natively integrates with common libraries like HuggingFace and LM-eval-harness and deconstructs processing flows into modular components, enabling easy customization and sharing between practitioners. These components encompass model-specific formats, task prompts, and many other comprehensive dataset processing definitions. The Unitxt-Catalog centralizes these components, fostering collaboration and exploration in modern textual data workflows. Beyond being a tool, Unitxt is a community-driven platform, empowering users to build, share, and advance their pipelines collaboratively.
 """

+import importlib.util
 from functools import partial
-from typing import Optional
+from typing import Any, Dict, Optional

-import evaluate
+import datasets

 from lm_eval.api.instance import Instance
 from lm_eval.api.task import ConfigurableTask
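
For context, a minimal standalone sketch (not part of the commit) of the check that the new is_unitxt_installed() helper below performs: importlib.util.find_spec locates a package on sys.path without executing it, so the probe is cheap and side-effect free.

    import importlib.util

    # find_spec returns a ModuleSpec when the package can be located and
    # None otherwise; the module body itself is never executed.
    if importlib.util.find_spec("unitxt") is not None:
        print("local unitxt found; evaluation can run in-process")
    else:
        print("no local unitxt; fall back to the remote unitxt/metric module")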
@@ -24,13 +25,23 @@
 }
 """

+def is_unitxt_installed() -> bool:
+    return importlib.util.find_spec("unitxt") is not None
+
 def score(items, metric):
     predictions, references = zip(*items)
-    evaluator = evaluate.load("unitxt/metric")
-    for reference in references:
-        reference["metrics"] = [metric]
-    results = evaluator.compute(predictions=predictions, references=references)
+    if is_unitxt_installed():
+        from unitxt import evaluate
+        for reference in references:
+            reference["metrics"] = [metric]
+        results = evaluate(predictions, references)
+    else:
+        import evaluate
+        evaluator = evaluate.load("unitxt/metric")
+        for reference in references:
+            reference["metrics"] = [metric]
+        results = evaluator.compute(predictions=predictions, references=references)

     return results[0]["score"]["global"]["score"]
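
The branch above follows a local-first pattern: use the installed unitxt package when present, otherwise fetch the community metric module through the HF evaluate hub. A generic sketch of the same pattern, with an illustrative helper name (load_metric_backend is not part of the commit):

    import importlib.util

    def load_metric_backend():
        """Prefer the local unitxt package; fall back to the HF Hub module."""
        if importlib.util.find_spec("unitxt") is not None:
            from unitxt import evaluate  # runs fully in-process

            return evaluate
        import evaluate as hf_evaluate  # the HF `evaluate` client library

        # Downloads and runs the community metric module from the Hub.
        return hf_evaluate.load("unitxt/metric").compute

Note that the two backends are invoked differently in score() above: local unitxt's evaluate takes positional (predictions, references), while the Hub module's compute takes keyword arguments.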


@@ -41,17 +52,30 @@ def __init__(
         self,
         config: Optional[dict] = None,
     ) -> None:
         if config is None:
             config = {}
         assert "recipe" in config, "Unitxt task must have a 'recipe' string."
         super().__init__(
             config={
                 "metadata": {"version": self.VERSION},
                 "dataset_kwargs": {"trust_remote_code": True},
                 "dataset_name": config["recipe"],
                 "dataset_path": "unitxt/data",
             }
         )
         self.image_decoder = datasets.Image()
         self.metrics = self.dataset["test"][0]["metrics"]

+    def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
+        if is_unitxt_installed():
+            from unitxt import load_dataset
+
+            self.dataset = load_dataset(self.DATASET_NAME)
+        else:
+            self.dataset = datasets.load_dataset(
+                name=self.DATASET_NAME,
+                path="unitxt/data",
+                trust_remote_code=True,
+            )
+
     def has_training_docs(self):
         return "train" in self.dataset

