diff --git a/weave/scorers/relevance_scorer.py b/weave/scorers/relevance_scorer.py
index 50aaa2b8650..ee37e852283 100644
--- a/weave/scorers/relevance_scorer.py
+++ b/weave/scorers/relevance_scorer.py
@@ -1,22 +1,13 @@
 import json
 import os
 from typing import Any, Optional
-
+import numpy as np
 from pydantic import PrivateAttr
 
 import weave
 from weave.scorers.base_scorer import Scorer
 from weave.scorers.llm_utils import download_model, scorer_model_paths, set_device
 
-try:
-    import torch
-    from transformers import pipeline
-except ImportError:
-    import_failed = True
-    print(
-        "The `transformers` package is required to use the RelevanceScorer, please run `pip install transformers`"
-    )
-
 RELEVANCE_INSTRUCTIONS = """You are an expert evaluator assessing the relevance of LLM-generated outputs relative to their input context. Your goal is to provide a single relevance score and classification based on comprehensive analysis.
 
 Relevance measures how effectively a generated output addresses its input context across three core dimensions:
@@ -71,7 +62,7 @@
 """
 
 
-class RelevanceScorer(Scorer):
+class OldRelevanceScorer(Scorer):
     """
     Use wandb/relevance_scorer to check if the model output is relevant.
 
@@ -80,15 +71,22 @@ class RelevanceScorer(Scorer):
         device: The device to use for inference. Defaults to `None`, which will use `cuda` if available.
     """
 
-    device: str = None
     model_name_or_path: str = None
     base_url: Optional[str] = None
+    device: str = None
     _classifier: Any = PrivateAttr()
     _tokenizer: Any = PrivateAttr()
     _id2label: dict[int, str] = PrivateAttr()
     _system_prompt: str = PrivateAttr()
 
     def model_post_init(self, __context: Any) -> None:
         if self.base_url:
             print(f"Using external API at {self.base_url} for scoring.")
             return  # Skip local model loading if base_url is provided
+        try:
+            import torch
+            from transformers import pipeline
+        except ImportError:
+            raise ImportError(
+                "The `transformers` package is required to use the RelevanceScorer, please run `pip install transformers`"
+            )
@@ -218,3 +216,144 @@ def score(
             prompt=input, completion=output, context=context, chat_history=chat_history
         )
         return self.score_messages(messages)
+
+
+class RelevanceScorer(Scorer):
+    """
+    A scorer that evaluates the relevance of model outputs relative to input queries and context.
+
+    This scorer uses a fine-tuned model to analyze whether outputs are semantically relevant to their
+    input queries and context. It processes text in chunks and returns both binary relevance flags
+    and detailed span-level scores.
+
+    Args:
+        model_name_or_path (str): Path or name of model weights to load
+        base_url (Optional[str]): Optional URL for external API scoring instead of local model
+        device (str): Device to run model on, defaults to "cpu"
+        threshold (float): Threshold for relevance classification, defaults to 0.7
+        return_all_scores (bool): Whether to return detailed span-level scores, defaults to False
+
+    Returns:
+        dict: A dictionary containing:
+            - flagged (bool): Whether the output was flagged as irrelevant
+            - extras (dict): Contains:
+                - score (float): Overall relevance score
+                - all_spans (list, optional): If return_all_scores=True, includes list of relevant
+                  text spans and their scores
+
+    Example:
+        >>> scorer = RelevanceScorer(model_name_or_path="path/to/model")
+        >>> result = scorer.score(
+        ...     query="What is the capital of France?",
+        ...     documents=["Paris is the capital of France."]
+        ... )
+        >>> print(result)
+        {
+            'flagged': False,
+            'extras': {
+                'score': 0.92,
+                'all_spans': [  # Only included if return_all_scores=True
+                    {'text': 'Paris is the capital of France', 'scores': 0.92}
+                ]
+            }
+        }
+    """
+
+    model_name_or_path: Optional[str] = None
+    base_url: Optional[str] = None
+    device: str = "cpu"
+    threshold: float = 0.7
+    return_all_scores: bool = False
+    _model: Any = PrivateAttr()
+    _tokenizer: Any = PrivateAttr()
+
+    def model_post_init(self, __context: Any) -> None:
+        """Initialize the model, tokenizer and device after pydantic initialization."""
+        try:
+            import torch
+            from transformers import AutoModelForTokenClassification, AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The `transformers` and `torch` packages are required to use the RelevanceScorer, please run `pip install transformers torch`"
+            )
+        # Resolve the device before loading so `device_map` receives the final value
+        self.device = set_device(self.device)
+        self._model = AutoModelForTokenClassification.from_pretrained(
+            self.model_name_or_path, device_map=self.device
+        )
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
+        self._model.eval()
+
+    def _score_document(
+        self, query: str, document: str, threshold: float
+    ) -> tuple[list[dict[str, Any]], int, int]:
+        """Score a single document and return its relevant spans plus token counts."""
+        import torch
+
+        with torch.no_grad():
+            input_text = query + f" {self._tokenizer.sep_token} " + document
+            model_inputs = self._tokenizer(
+                input_text,
+                truncation=True,
+                padding=False,
+                return_tensors="pt",
+                return_special_tokens_mask=True,
+            )
+
+            model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
+
+            special_tokens_mask = model_inputs.pop("special_tokens_mask")
+            # Exclude special tokens and token id 2 (model-specific) from the relevance mask
+            combined_mask = ~(
+                (model_inputs["input_ids"] == 2).bool() | special_tokens_mask.bool()
+            )
+
+            results = self._model(**model_inputs)
+
+            logits = results.logits[0].detach()
+            probabilities = torch.nn.functional.softmax(logits, dim=-1).detach()
+
+            pred_mask = (probabilities[:, 1] > threshold).cpu().numpy().astype(int)
+            label_mask = (pred_mask & combined_mask.cpu().numpy()).flatten()
+
+            positive_probs = probabilities[:, 1].cpu().numpy()
+            # Zero-pad the mask and diff it: +1 marks a span start, -1 one past a span end
+            transitions = np.diff(np.concatenate([[0], label_mask, [0]]))
+            starts = np.where(transitions == 1)[0]
+            ends = np.where(transitions == -1)[0]
+
+            spans_with_probs = []
+            token_ids = model_inputs["input_ids"].cpu().numpy()[0]
+
+            for start, end in zip(starts, ends):
+                span_text = self._tokenizer.decode(token_ids[start:end])
+                span_prob = positive_probs[start:end].mean()
+                spans_with_probs.append({"text": span_text, "scores": float(span_prob)})
+
+            return spans_with_probs, int(label_mask.sum()), int(len(label_mask))
+
+    @weave.op
+    def score(self, query: str, documents: list[str]) -> dict[str, Any]:
+        """Score multiple documents and compute a token-length-weighted average relevance."""
+        all_spans = []
+        total_weighted_score = 0.0
+        total_length = 0
+
+        for doc in documents:
+            spans, relevant_tokens, total_tokens = self._score_document(
+                query, doc, self.threshold
+            )
+            all_spans.extend(spans)
+
+            if total_tokens > 0:
+                # Weight each document's relevant-token ratio by its token count
+                doc_score = relevant_tokens / total_tokens
+                total_weighted_score += doc_score * total_tokens
+                total_length += total_tokens
+
+        final_score = total_weighted_score / total_length if total_length > 0 else 0.0
+        # Flag the output as irrelevant when the weighted score falls below the threshold
+        output = {"flagged": final_score < self.threshold}
+        output["extras"] = {"score": final_score}
+        if self.return_all_scores:
+            output["extras"]["all_spans"] = all_spans
+        return output
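
Reviewer note on the span extraction in `_score_document`: zero-padding the token-level relevance mask and taking `np.diff` turns each contiguous run of relevant tokens into a `+1` at its start and a `-1` one position past its end. A minimal standalone sketch of that trick (the mask values are invented for illustration):

```python
import numpy as np

# Hypothetical token-level relevance mask: 1 = predicted relevant.
label_mask = np.array([0, 1, 1, 0, 1, 1, 1, 0])

# Zero-pad both ends so runs touching a boundary still produce a transition,
# then diff: +1 marks a run start, -1 marks one past a run end.
transitions = np.diff(np.concatenate([[0], label_mask, [0]]))
starts = np.where(transitions == 1)[0]  # array([1, 4])
ends = np.where(transitions == -1)[0]   # array([3, 7])

for start, end in zip(starts, ends):
    print(f"relevant span covers token indices [{start}, {end})")
# relevant span covers token indices [1, 3)
# relevant span covers token indices [4, 7)
```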
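For completeness, a hedged usage sketch of the new scorer, assuming `transformers` and `torch` are installed; the model path below is a placeholder, not the actual fine-tuned checkpoint name:

```python
from weave.scorers.relevance_scorer import RelevanceScorer

# Placeholder path; substitute the real relevance checkpoint.
scorer = RelevanceScorer(
    model_name_or_path="path/to/relevance-model",
    device="cpu",
    threshold=0.7,
    return_all_scores=True,
)

result = scorer.score(
    query="What is the capital of France?",
    documents=["Paris is the capital of France."],
)
print(result["flagged"], result["extras"]["score"])
```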