From 7af240cbf88f99ba7ea9ebf653b3c951bf99b414 Mon Sep 17 00:00:00 2001
From: Thomas Capelle
Date: Tue, 10 Dec 2024 12:19:01 -0800
Subject: [PATCH] remove torch

---
 weave/scorers/perplexity_scorer.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/weave/scorers/perplexity_scorer.py b/weave/scorers/perplexity_scorer.py
index c097ee5922a..cc75943e2ce 100644
--- a/weave/scorers/perplexity_scorer.py
+++ b/weave/scorers/perplexity_scorer.py
@@ -1,8 +1,6 @@
-from typing import Union
+from typing import Any, Union
 
 import numpy as np
-import torch
-import torch.nn.functional as F
 from openai.types.chat import ChatCompletion
 
 import weave
@@ -53,6 +51,20 @@ def score(self, output: Union[ChatCompletion, list]) -> dict:
 
 class HuggingFacePerplexityScorer(Scorer):
     """A scorer that computes perplexity for Hugging Face outputs using log probabilities."""
 
+    def model_post_init(self, __context: Any) -> None:
+        """
+        Check that `torch` is available. The import is done here so the dependency is
+        only required when an instance of HuggingFacePerplexityScorer is created.
+        """
+        try:
+            import torch  # noqa: F401
+        except ImportError as e:
+            raise ImportError(
+                "The `torch` package is required to use HuggingFacePerplexityScorer. "
+                "Please install it by running `pip install torch`."
+            ) from e
+
+
     @weave.op()
     def score(self, output: dict) -> dict:
@@ -67,6 +79,9 @@ def score(self, output: dict) -> dict:
         Returns:
             dict: A dictionary containing the calculated perplexity.
         """
+        import torch
+        import torch.nn.functional as F
+
         logits = output["logits"]
         input_ids = output["input_ids"]