Skip to content

Commit

Permalink
Merge pull request #11 from yhwang/detect-unitxt-in-evaluate
Browse files Browse the repository at this point in the history
use local unitxt if it exists for evaluation
  • Loading branch information
ruivieira authored Nov 27, 2024
2 parents 18058a9 + a240b7d commit 509ef85
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions lm_eval/tasks/unitxt/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
Addressing this need, we present Unitxt, an innovative library for customizable textual data preparation and evaluation tailored to generative language models. Unitxt natively integrates with common libraries like HuggingFace and LM-eval-harness and deconstructs processing flows into modular components, enabling easy customization and sharing between practitioners. These components encompass model-specific formats, task prompts, and many other comprehensive dataset processing definitions. The Unitxt-Catalog centralizes these components, fostering collaboration and exploration in modern textual data workflows. Beyond being a tool, Unitxt is a community-driven platform, empowering users to build, share, and advance their pipelines collaboratively.
"""

import importlib.util
from functools import partial
from typing import Optional
from typing import Any, Dict, Optional

import evaluate
import datasets

from lm_eval.api.instance import Instance
from lm_eval.api.task import ConfigurableTask
Expand All @@ -24,13 +25,19 @@
}
"""

def is_unitxt_installed() -> bool:
    """Report whether the optional ``unitxt`` package is importable."""
    return importlib.util.find_spec("unitxt") is not None


def score(items, metric):
    """Aggregate ``(prediction, reference)`` pairs into a single global score.

    Args:
        items: Iterable of ``(prediction, reference)`` tuples; each reference
            is a dict consumed by unitxt (a ``"metrics"`` key is set on it).
        metric: Name of the unitxt metric attached to every reference.

    Returns:
        The global score reported by unitxt over all items.

    Raises:
        ImportError: If the optional ``unitxt`` package is not installed.
            (ImportError subclasses Exception, so existing callers that
            caught the previous generic Exception still work.)
    """
    # Guard clause: fail fast before mutating any reference dicts.
    if not is_unitxt_installed():
        raise ImportError(
            "Please install unitxt via 'pip install unitxt'. "
            "For more information see: https://www.unitxt.ai/"
        )

    from unitxt import evaluate

    predictions, references = zip(*items)
    for reference in references:
        reference["metrics"] = [metric]
    results = evaluate(predictions, references)
    return results[0]["score"]["global"]["score"]


Expand All @@ -41,17 +48,26 @@ def __init__(
self,
config: Optional[dict] = None,
) -> None:
if config is None:
config = {}
assert "recipe" in config, "Unitxt task must have a 'recipe' string."
super().__init__(
config={
"metadata": {"version": self.VERSION},
"dataset_kwargs": {"trust_remote_code": True},
"dataset_name": config["recipe"],
"dataset_path": "unitxt/data",
}
)
self.image_decoder = datasets.Image()
self.metrics = self.dataset["test"][0]["metrics"]

def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
    """Load this task's dataset through the locally installed ``unitxt``.

    Overrides the base-class download so the recipe stored in
    ``self.DATASET_NAME`` is resolved by unitxt directly rather than via
    the remote HuggingFace dataset wrapper.

    Args:
        dataset_kwargs: Accepted only for interface compatibility with the
            base class; unitxt's loader does not take them, so they are
            deliberately ignored here.

    Raises:
        ImportError: If the optional ``unitxt`` package is not installed.
            (ImportError subclasses Exception, so callers that caught the
            previous generic Exception still work.)
    """
    # EAFP: attempt the import and translate the failure into an
    # actionable message, instead of probing for the module first.
    try:
        from unitxt import load_dataset
    except ImportError as err:
        raise ImportError(
            "Please install unitxt via 'pip install unitxt'. "
            "For more information see: https://www.unitxt.ai/"
        ) from err

    self.dataset = load_dataset(self.DATASET_NAME)

def has_training_docs(self):
    """Return True when the loaded dataset exposes a 'train' split."""
    available_splits = self.dataset
    return "train" in available_splits

Expand Down

0 comments on commit 509ef85

Please sign in to comment.