Skip to content

Commit

Permalink
remove torch
Browse files Browse the repository at this point in the history
  • Loading branch information
tcapelle committed Dec 10, 2024
1 parent 3b954ff commit 7af240c
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions weave/scorers/perplexity_scorer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Union

import numpy as np
import torch
import torch.nn.functional as F
from openai.types.chat import ChatCompletion

import weave
Expand Down Expand Up @@ -53,6 +51,20 @@ def score(self, output: Union[ChatCompletion, list]) -> dict:

class HuggingFacePerplexityScorer(Scorer):
"""A scorer that computes perplexity for Hugging Face outputs using log probabilities."""
def model_post_init(self, __context: Any) -> None:
"""
Initialize the model and tokenizer. Imports are performed here to ensure they're only
loaded when an instance of LlamaGuard is created.
"""
try:
import torch
except ImportError as e:
raise ImportError(
"The `transformers` and `torch` packages are required to use LlamaGuard. "
"Please install them by running `pip install transformers torch`."
) from e



@weave.op()
def score(self, output: dict) -> dict:
Expand All @@ -67,6 +79,9 @@ def score(self, output: dict) -> dict:
Returns:
dict: A dictionary containing the calculated perplexity.
"""
import torch
import torch.nn.functional as F

logits = output["logits"]
input_ids = output["input_ids"]

Expand Down

0 comments on commit 7af240c

Please sign in to comment.