relevance scorer
tcapelle committed Dec 10, 2024
1 parent 2d908a7 commit 30f5f7f
Showing 1 changed file with 151 additions and 12 deletions.
163 changes: 151 additions & 12 deletions weave/scorers/relevance_scorer.py
@@ -1,22 +1,13 @@
import json
import os
from typing import Any, Optional

import numpy as np
from pydantic import PrivateAttr

import weave
from weave.scorers.base_scorer import Scorer
from weave.scorers.llm_utils import download_model, scorer_model_paths, set_device

try:
    import torch
    from transformers import pipeline
except ImportError:
    import_failed = True
    print(
        "The `transformers` package is required to use the RelevanceScorer, please run `pip install transformers`"
    )

RELEVANCE_INSTRUCTIONS = """You are an expert evaluator assessing the relevance of LLM-generated outputs relative to their input context.
Your goal is to provide a single relevance score and classification based on comprehensive analysis.
Relevance measures how effectively a generated output addresses its input context across three core dimensions:
@@ -71,7 +62,7 @@
"""


class RelevanceScorer(Scorer):
class OldRelevanceScorer(Scorer):
"""
Use wandb/relevance_scorer to check if the model output is relevant.
Expand All @@ -80,15 +71,22 @@ class RelevanceScorer(Scorer):
device: The device to use for inference. Defaults to `None`, which will use `cuda` if available.
"""

    device: str = None
    model_name_or_path: str = None
    base_url: Optional[str] = None
    device: str = None
    _classifier: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _id2label: dict[int, str] = PrivateAttr()
    _system_prompt: str = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        try:
            import torch
            from transformers import pipeline
        except ImportError:
            print(
                "The `transformers` package is required to use the RelevanceScorer, please run `pip install transformers`"
            )
        if self.base_url:
            print(f"Using external API at {self.base_url} for scoring.")
            return  # Skip local model loading if base_url is provided
@@ -218,3 +216,144 @@ def score(
            prompt=input, completion=output, context=context, chat_history=chat_history
        )
        return self.score_messages(messages)

class RelevanceScorer(Scorer):
"""
A scorer that evaluates the relevance of model outputs relative to input queries and context.
This scorer uses a fine-tuned model to analyze whether outputs are semantically relevant to their
input queries and context. It processes text in chunks and returns both binary relevance flags
and detailed span-level scores.
Args:
model_name_or_path (str): Path or name of model weights to load
base_url (Optional[str]): Optional URL for external API scoring instead of local model
device (str): Device to run model on, defaults to "cpu"
threshold (float): Threshold for relevance classification, defaults to 0.7
return_all_scores (bool): Whether to return detailed span-level scores, defaults to False
debug (bool): Enable debug logging, defaults to False
Returns:
dict: A dictionary containing:
- flagged (bool): Whether the output was flagged as irrelevant
- extras (dict): Contains:
- score (float): Overall relevance score
- all_spans (list, optional): If return_all_scores=True, includes list of relevant
text spans and their scores
Example:
>>> scorer = RelevanceScorer(model_name_or_path="path/to/model")
>>> result = scorer.score(
... query="What is the capital of France?",
... documents=["Paris is the capital of France."]
... )
>>> print(result)
{
'flagged': False,
'extras': {
'score': 0.92,
'all_spans': [ # Only included if return_all_scores=True
{'text': 'Paris is the capital of France', 'scores': 0.92}
]
}
}
"""
model_name_or_path: str = None
base_url: Optional[str] = None
device: str = "cpu"
threshold: float = 0.7
return_all_scores: bool = False
_model: Any = PrivateAttr()
_tokenizer: Any = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        """Initialize the model, tokenizer and device after pydantic initialization."""
        try:
            import torch
            from transformers import AutoModelForTokenClassification, AutoTokenizer
        except ImportError:
            print(
                "The `transformers` and `torch` packages are required to use the RelevanceScorer, please run `pip install transformers torch`"
            )
        self._model = AutoModelForTokenClassification.from_pretrained(
            self.model_name_or_path, device_map=self.device
        )
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        self._model.eval()
        self.device = set_device(self.device)

    def _score_document(
        self,
        query: str,
        document: str,
        threshold: float) -> tuple[list[dict[str, Any]], int, int]:
        """Score a single document."""
        import torch

        with torch.no_grad():
            input_text = query + f" {self._tokenizer.sep_token} " + document
            model_inputs = self._tokenizer(
                input_text,
                truncation=True,
                padding=False,
                return_tensors="pt",
                return_special_tokens_mask=True,
            )

            model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}

            special_tokens_mask = model_inputs.pop("special_tokens_mask")
            combined_mask = ~((model_inputs["input_ids"] == 2).bool() | special_tokens_mask.bool())
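            # In addition to the tokenizer's special tokens, token id 2 is masked out here;
            # which token id 2 maps to (e.g. a pad or eos token) depends on the specific
            # tokenizer checkpoint and is assumed here rather than looked up from the tokenizer.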

            results = self._model(**model_inputs)

            logits = results.logits[0].detach()
            probabilities = torch.nn.functional.softmax(logits, dim=-1).detach()

            pred_mask = (probabilities[:, 1] > threshold).cpu().numpy().astype(int)
            label_mask = (pred_mask & combined_mask.cpu().numpy()).flatten()

            positive_probs = probabilities[:, 1].cpu().numpy()
            transitions = np.diff(np.concatenate([[0], label_mask, [0]]))
            starts = np.where(transitions == 1)[0]
            ends = np.where(transitions == -1)[0]
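            # The diff of the zero-padded binary mask marks span boundaries. For example
            # (illustrative values), label_mask [0, 1, 1, 0, 1] pads to [0, 0, 1, 1, 0, 1, 0],
            # whose diff is [0, 1, 0, -1, 1, -1]; starts = [1, 4] and ends = [3, 5] then
            # describe the contiguous relevant spans as half-open [start, end) token ranges.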

            spans_with_probs = []
            token_ids = model_inputs["input_ids"].cpu().numpy()[0]

            for start, end in zip(starts, ends):
                span_text = self._tokenizer.decode(token_ids[start:end])
                span_prob = positive_probs[start:end].mean()
                spans_with_probs.append({
                    "text": span_text,
                    "scores": float(span_prob)
                })

        return spans_with_probs, int(label_mask.sum()), int(len(label_mask))

    @weave.op
    def score(
        self,
        query: str,
        documents: list[str]) -> dict[str, Any]:
        """Score multiple documents and compute weighted average relevance."""
        all_spans = []
        total_weighted_score = 0.0
        total_length = 0

        for doc in documents:
            spans, relevant_tokens, total_tokens = self._score_document(query, doc, self.threshold)

            all_spans.extend(spans)

            if total_tokens > 0:
                doc_score = relevant_tokens / total_tokens
                doc_weight = total_tokens
                total_weighted_score += doc_score * doc_weight
                total_length += total_tokens

        final_score = total_weighted_score / total_length if total_length > 0 else 0.0
        output = {"flagged": final_score > self.threshold}
        output["extras"] = {"score": final_score}
        if self.return_all_scores:
            output["extras"]["all_spans"] = all_spans
        return output
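# Illustration of the token-weighted averaging in `score` (hypothetical numbers, not
# taken from this commit): a 10-token document with 8 relevant tokens and a 30-token
# document with 0 relevant tokens give
# final_score = (0.8 * 10 + 0.0 * 30) / (10 + 30) = 0.2,
# so longer documents contribute proportionally more to the overall score.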
