From 51949926881214499a10540e1a18d374e95af80f Mon Sep 17 00:00:00 2001 From: Thomas Capelle Date: Tue, 10 Dec 2024 11:17:55 -0800 Subject: [PATCH] return_all_scores --- weave/scorers/relevance_scorer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/weave/scorers/relevance_scorer.py b/weave/scorers/relevance_scorer.py index ee37e852283..18b370377e4 100644 --- a/weave/scorers/relevance_scorer.py +++ b/weave/scorers/relevance_scorer.py @@ -230,7 +230,6 @@ class RelevanceScorer(Scorer): base_url (Optional[str]): Optional URL for external API scoring instead of local model device (str): Device to run model on, defaults to "cpu" threshold (float): Threshold for relevance classification, defaults to 0.7 - return_all_scores (bool): Whether to return detailed span-level scores, defaults to False debug (bool): Enable debug logging, defaults to False Returns: @@ -262,7 +261,6 @@ class RelevanceScorer(Scorer): base_url: Optional[str] = None device: str = "cpu" threshold: float = 0.7 - return_all_scores: bool = False _model: Any = PrivateAttr() _tokenizer: Any = PrivateAttr() @@ -325,7 +323,7 @@ def _score_document( span_prob = positive_probs[start:end].mean() spans_with_probs.append({ "text": span_text, - "scores": float(span_prob) + "score": float(span_prob) }) return spans_with_probs, int(label_mask.sum()), int(len(label_mask)) @@ -334,7 +332,9 @@ def _score_document( def score( self, query: str, - documents: list[str]) -> tuple[list[dict[str, Any]], float]: + documents: list[str], + return_all_scores: bool = False + ) -> tuple[list[dict[str, Any]], float]: """Score multiple documents and compute weighted average relevance.""" all_spans = [] total_weighted_score = 0.0 @@ -354,6 +354,6 @@ def score( final_score = total_weighted_score / total_length if total_length > 0 else 0.0 output = {"flagged": final_score > self.threshold} output['extras'] = {'score': final_score} - if self.return_all_scores: + if return_all_scores: output['extras']['all_spans'] = all_spans return output