+
- if len(dataset_column_names) > 3:
+ if len(dataset_column_names) > 10:
dataset_column_names_str += ", ..."
required_arg_names = [
param.name
for param in score_signature.parameters.values()
if param.default == inspect.Parameter.empty
]
- required_arg_names.remove("model_output")
+ required_arg_names.remove("output")
message = textwrap.dedent(
f"""
Call error: {e}
+ If using the `Scorer` weave class, you can set the `scorer.column_map`
+ attribute to map scorer argument names to dataset columns.
+
+ For example, if the `score` method expects "output", "input" and "ground_truth" and we have a dataset
+ with columns "question" and "answer", `column_map` can be used to map the non-output parameters like so:
+ {{"input": "question", "ground_truth": "answer"}}
+
+ scorer argument names: {score_arg_names}
+ dataset keys: {example.keys()}
+ scorer.column_map: {getattr(scorer, 'column_map', '{}')}
+
Options for resolving:
- a. change {scorer_name} argument names to match a subset of dataset column names ({dataset_column_names_str})
- b. change dataset column names to match expected {scorer_name} argument names: {required_arg_names}
+ a. if using the `Scorer` weave class, you can set the `scorer.column_map` attribute to map scorer argument names to dataset column names or
+ b. change the argument names in the scoring function of {scorer_name} to match a subset of dataset column names: ({dataset_column_names_str}) or
+ c. change dataset column names to match expected {scorer_name} argument names: {required_arg_names}
"""
)
raise OpCallError(message)
scores[scorer_name] = result
return {
- "model_output": model_output,
+ "output": model_output,
"scores": scores,
"model_latency": model_latency,
}
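For context, a minimal sketch of the `column_map` fix this error message suggests. The scorer class, dataset columns, and argument names are hypothetical, chosen to mirror the example in the message above:

```python
import weave
from weave.scorers import Scorer


class GroundTruthScorer(Scorer):  # hypothetical scorer, for illustration only
    @weave.op()
    def score(self, input: str, ground_truth: str, output: str) -> dict:
        # "output" is supplied by the evaluation; the remaining arguments
        # must resolve to dataset columns, directly or via column_map.
        return {"correct": output.strip() == ground_truth.strip()}


# The dataset has columns "question" and "answer", so map the
# non-output arguments as the error message describes:
scorer = GroundTruthScorer(
    column_map={"input": "question", "ground_truth": "answer"}
)
```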
@@ -341,7 +441,7 @@ async def eval_example(example: dict) -> dict:
except Exception as e:
print("Predict and score failed")
traceback.print_exc()
- return {"model_output": None, "scores": {}}
+ return {"output": None, "scores": {}}
return eval_row
n_complete = 0
@@ -358,7 +458,7 @@ async def eval_example(example: dict) -> dict:
# f"Evaluating... {duration:.2f}s [{n_complete} / {len(self.dataset.rows)} complete]" # type:ignore
# )
if eval_row is None:
- eval_row = {"model_output": None, "scores": {}}
+ eval_row = {"output": None, "scores": {}}
else:
eval_row["scores"] = eval_row.get("scores", {})
for scorer in self.scorers or []:
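The hunks above complete the `model_output` → `output` rename in evaluation results. A sketch of the implied row shape (values are illustrative):

```python
# One successful evaluation row after the rename:
eval_row = {
    "output": "Paris",  # previously keyed as "model_output"
    "scores": {"GroundTruthScorer": {"correct": True}},  # hypothetical scorer name
    "model_latency": 0.42,
}

# The fallback row when predict-and-score fails:
failed_row = {"output": None, "scores": {}}
```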
diff --git a/weave/flow/scorer.py b/weave/flow/scorer.py
index e69f3afeb3f..86df3d6a055 100644
--- a/weave/flow/scorer.py
+++ b/weave/flow/scorer.py
@@ -1,158 +1,12 @@
-from collections import defaultdict
-from numbers import Number
-from typing import Any, Callable, Optional, Sequence, Tuple, Union
-
-import numpy as np
-from pydantic import BaseModel
-
-import weave
-from weave.flow.obj import Object
-from weave.trace.isinstance import weave_isinstance
-from weave.trace.op import Op, as_op, is_op
-
-
-class Scorer(Object):
- def score(self, target: Any, model_output: Any) -> Any:
- raise NotImplementedError
-
- @weave.op()
- def summarize(self, score_rows: list) -> Optional[dict]:
- return auto_summarize(score_rows)
-
-
-def stderr(data: Sequence[Union[int, float]]) -> float:
- if len(data) > 1:
- sample_variance = np.var(data, ddof=1)
- return float(np.sqrt(sample_variance / len(data)))
- else:
- return 0
-
-
-def auto_summarize(data: list) -> Optional[dict[str, Any]]:
- """Automatically summarize a list of (potentially nested) dicts.
-
- Computes:
- - avg for numeric cols
- - count and fraction for boolean cols
- - other col types are ignored
-
- If col is all None, result is None
-
- Returns:
- dict of summary stats, with structure matching input dict structure.
- """
- if not data:
- return {}
- data = [x for x in data if x is not None]
-
- if not data:
- return None
-
- val = data[0]
-
- if isinstance(val, bool):
- return {
- "true_count": (true_count := sum(1 for x in data if x)),
- "true_fraction": true_count / len(data),
- }
- elif isinstance(val, Number):
- return {"mean": np.mean(data).item()}
- elif isinstance(val, dict):
- result = {}
- all_keys = set().union(*[x.keys() for x in data if isinstance(x, dict)])
- for k in all_keys:
- if (
- summary := auto_summarize(
- [x.get(k) for x in data if isinstance(x, dict)]
- )
- ) is not None:
- if k in summary:
- result.update(summary)
- else:
- result[k] = summary
- if not result:
- return None
- return result
- elif isinstance(val, BaseModel):
- return auto_summarize([x.model_dump() for x in data])
- return None
-
-
-def get_scorer_attributes(
- scorer: Union[Callable, Op, Scorer],
-) -> Tuple[str, Callable, Callable]:
- if weave_isinstance(scorer, Scorer):
- scorer_name = scorer.name
- if scorer_name is None:
- scorer_name = scorer.__class__.__name__
- try:
- score_fn = scorer.score
- summarize_fn = scorer.summarize # type: ignore
- except AttributeError:
- raise ValueError(
- f"Scorer {scorer_name} must implement score and summarize methods. Did you forget to wrap with @weave.op()?"
- )
- elif callable(scorer):
- if is_op(scorer):
- scorer = as_op(scorer)
- scorer_name = scorer.name
- else:
- scorer_name = scorer.__name__
- score_fn = scorer
- summarize_fn = auto_summarize # type: ignore
- else:
- raise ValueError(f"Unknown scorer type: {scorer}")
- return (scorer_name, score_fn, summarize_fn) # type: ignore
-
-
-def p_r_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
- # if any denom is zero, then zero. could use NaN instead...
- precision: float = 0
- if tp or fp:
- precision = tp / (tp + fp)
- recall: float = 0
- if tp or fn:
- recall = tp / (tp + fn)
- f1: float = 0
- if precision or recall:
- f1 = 2 * (precision * recall) / (precision + recall)
- return precision, recall, f1
-
-
-class MultiTaskBinaryClassificationF1(Scorer):
- class_names: list[str]
-
- @weave.op()
- def summarize(self, score_rows: list) -> Optional[dict]:
- result = {}
- cols = transpose(score_rows)
-
- for class_name in self.class_names:
- col = cols[class_name]
- tp = sum(r["correct"] and not r["negative"] for r in col)
- fp = sum(not r["correct"] and not r["negative"] for r in col)
- fn = sum(not r["correct"] and r["negative"] for r in col)
- precision, recall, f1 = p_r_f1(tp, fp, fn)
- result[class_name] = {"f1": f1, "precision": precision, "recall": recall}
-
- return result
-
- @weave.op()
- def score(self, target: dict, model_output: Optional[dict]) -> dict:
- result = {}
- for class_name in self.class_names:
- class_label = target.get(class_name)
- class_model_output = model_output.get(class_name) if model_output else None
- result[class_name] = {
- "correct": class_label == class_model_output,
- "negative": not class_model_output,
- }
- return result
-
-
-def transpose(rows: list[dict]) -> dict[str, list]:
- cols = defaultdict(list)
- for row in rows:
- for k, v in row.items():
- cols[k].append(v)
- return dict(cols)
+# Keeping this file for now to avoid breaking changes.
+# In future, users should import all scoring functionality from weave.scorers
+import warnings
+
+from weave.scorers import *
+
+warnings.warn(
+ "Importing from weave.flow.scorer is deprecated. "
+ "Please import from weave.scorers in the future.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/weave/scorers/__init__.py b/weave/scorers/__init__.py
new file mode 100644
index 00000000000..941f48e7b13
--- /dev/null
+++ b/weave/scorers/__init__.py
@@ -0,0 +1,55 @@
+from weave.scorers.base_scorer import (
+ Scorer,
+ auto_summarize,
+ get_scorer_attributes,
+)
+from weave.scorers.classification_scorer import (
+ MultiTaskBinaryClassificationF1,
+ transpose,
+)
+from weave.scorers.hallucination_scorer import HallucinationFreeScorer
+from weave.scorers.json_scorer import ValidJSONScorer
+from weave.scorers.llm_scorer import (
+ InstructorLLMScorer,
+ LLMScorer,
+)
+from weave.scorers.llm_utils import (
+ create,
+ embed,
+)
+from weave.scorers.moderation_scorer import OpenAIModerationScorer
+from weave.scorers.pydantic_scorer import PydanticScorer
+from weave.scorers.ragas_scorer import (
+ ContextEntityRecallScorer,
+ ContextRelevancyScorer,
+)
+from weave.scorers.similarity_scorer import EmbeddingSimilarityScorer
+from weave.scorers.string_scorer import (
+ LevenshteinScorer,
+ StringMatchScorer,
+)
+from weave.scorers.summarization_scorer import SummarizationScorer
+from weave.scorers.xml_scorer import ValidXMLScorer
+
+__all__ = [
+ "auto_summarize",
+ "create",
+ "embed",
+ "ContextEntityRecallScorer",
+ "ContextRelevancyScorer",
+ "EmbeddingSimilarityScorer",
+ "get_scorer_attributes",
+ "HallucinationFreeScorer",
+ "InstructorLLMScorer",
+ "ValidJSONScorer",
+ "LevenshteinScorer",
+ "LLMScorer",
+ "MultiTaskBinaryClassificationF1",
+ "OpenAIModerationScorer",
+ "PydanticScorer",
+ "Scorer",
+ "StringMatchScorer",
+ "SummarizationScorer",
+ "transpose",
+ "ValidXMLScorer",
+]
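With the new package in place, scorers are imported from `weave.scorers` and plugged into an evaluation as before. A sketch with a hypothetical model and dataset (some of the re-exported scorers may additionally require optional dependencies such as `openai` or `instructor`):

```python
import weave
from weave.scorers import ValidJSONScorer, ValidXMLScorer


@weave.op()
def model(question: str) -> str:  # hypothetical model
    return '{"answer": 42}'


evaluation = weave.Evaluation(
    dataset=[{"question": "What is 6 * 7?"}],
    scorers=[ValidJSONScorer(), ValidXMLScorer()],
)
# import asyncio; asyncio.run(evaluation.evaluate(model))  # runs model + scorers over the dataset
```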
diff --git a/weave/scorers/base_scorer.py b/weave/scorers/base_scorer.py
new file mode 100644
index 00000000000..a0eec1ac09c
--- /dev/null
+++ b/weave/scorers/base_scorer.py
@@ -0,0 +1,109 @@
+from numbers import Number
+from typing import Any, Callable, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from pydantic import BaseModel, Field
+
+import weave
+from weave.flow.obj import Object
+from weave.trace.isinstance import weave_isinstance
+from weave.trace.op import Op, as_op, is_op
+
+
+class Scorer(Object):
+ column_map: Optional[dict[str, str]] = Field(
+ default=None,
+ description="A mapping from column names in the dataset to the names expected by the scorer",
+ )
+
+ def score(self, input: Any, target: Any, output: Any) -> Any:
+ raise NotImplementedError
+
+ @weave.op()
+ def summarize(self, score_rows: list) -> Optional[dict]:
+ return auto_summarize(score_rows)
+
+
+def stderr(data: Sequence[Union[int, float]]) -> float:
+ if len(data) > 1:
+ sample_variance = np.var(data, ddof=1)
+ return float(np.sqrt(sample_variance / len(data)))
+ else:
+ return 0
+
+
+def auto_summarize(data: list) -> Optional[dict[str, Any]]:
+ """Automatically summarize a list of (potentially nested) dicts.
+
+ Computes:
+ - avg for numeric cols
+ - count and fraction for boolean cols
+ - other col types are ignored
+
+ If col is all None, result is None
+
+ Returns:
+ dict of summary stats, with structure matching input dict structure.
+ """
+ if not data:
+ return {}
+ data = [x for x in data if x is not None]
+
+ if not data:
+ return None
+
+ val = data[0]
+
+ if isinstance(val, bool):
+ return {
+ "true_count": (true_count := sum(1 for x in data if x)),
+ "true_fraction": true_count / len(data),
+ }
+ elif isinstance(val, Number):
+ return {"mean": np.mean(data).item()}
+ elif isinstance(val, dict):
+ result = {}
+ all_keys = set().union(*[x.keys() for x in data if isinstance(x, dict)])
+ for k in all_keys:
+ if (
+ summary := auto_summarize(
+ [x.get(k) for x in data if isinstance(x, dict)]
+ )
+ ) is not None:
+ if k in summary:
+ result.update(summary)
+ else:
+ result[k] = summary
+ if not result:
+ return None
+ return result
+ elif isinstance(val, BaseModel):
+ return auto_summarize([x.model_dump() for x in data])
+ return None
+
+
+def get_scorer_attributes(
+ scorer: Union[Callable, Op, Scorer],
+) -> Tuple[str, Callable, Callable]:
+ if weave_isinstance(scorer, Scorer):
+ scorer_name = scorer.name
+ if scorer_name is None:
+ scorer_name = scorer.__class__.__name__
+ try:
+ score_fn = scorer.score
+ summarize_fn = scorer.summarize # type: ignore
+ except AttributeError:
+ raise ValueError(
+ f"Scorer {scorer_name} must implement score and summarize methods. Did you forget to wrap with @weave.op()?"
+ )
+ elif callable(scorer):
+ if is_op(scorer):
+ scorer = as_op(scorer)
+ scorer_name = scorer.name
+ else:
+ scorer_name = scorer.__name__
+ score_fn = scorer
+ summarize_fn = auto_summarize # type: ignore
+ else:
+ raise ValueError(f"Unknown scorer type: {scorer}")
+ return (scorer_name, score_fn, summarize_fn) # type: ignore
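To make `auto_summarize`'s recursion concrete, a small worked example (the result follows directly from the code above; `None` rows are filtered out first, and dict key order may vary since keys come from a set):

```python
from weave.scorers import auto_summarize

rows = [
    {"correct": True, "latency": 0.5},
    {"correct": False, "latency": 0.7},
    None,  # dropped before summarizing
]

print(auto_summarize(rows))
# {'correct': {'true_count': 1, 'true_fraction': 0.5}, 'latency': {'mean': 0.6}}
```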
diff --git a/weave/scorers/classification_scorer.py b/weave/scorers/classification_scorer.py
new file mode 100644
index 00000000000..7c6cb1207c3
--- /dev/null
+++ b/weave/scorers/classification_scorer.py
@@ -0,0 +1,58 @@
+from collections import defaultdict
+from typing import Optional, Tuple
+
+import weave
+from weave.scorers.base_scorer import Scorer
+
+
+def p_r_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
+ # if any denom is zero, then zero. could use NaN instead...
+ precision: float = 0
+ if tp or fp:
+ precision = tp / (tp + fp)
+ recall: float = 0
+ if tp or fn:
+ recall = tp / (tp + fn)
+ f1: float = 0
+ if precision or recall:
+ f1 = 2 * (precision * recall) / (precision + recall)
+ return precision, recall, f1
+
+
+class MultiTaskBinaryClassificationF1(Scorer):
+ class_names: list[str]
+
+ @weave.op()
+ def summarize(self, score_rows: list) -> Optional[dict]:
+ result = {}
+ cols = transpose(score_rows)
+
+ for class_name in self.class_names:
+ col = cols[class_name]
+ tp = sum(r["correct"] and not r["negative"] for r in col)
+ fp = sum(not r["correct"] and not r["negative"] for r in col)
+ fn = sum(not r["correct"] and r["negative"] for r in col)
+ precision, recall, f1 = p_r_f1(tp, fp, fn)
+ result[class_name] = {"f1": f1, "precision": precision, "recall": recall}
+
+ return result
+
+ @weave.op()
+ def score(self, target: dict, output: Optional[dict]) -> dict:
+ result = {}
+ for class_name in self.class_names:
+ class_label = target.get(class_name)
+ class_output = output.get(class_name) if output else None
+ result[class_name] = {
+ "correct": class_label == class_output,
+ "negative": not class_output,
+ }
+ return result
+
+
+def transpose(rows: list[dict]) -> dict[str, list]:
+ cols = defaultdict(list)
+ for row in rows:
+ for k, v in row.items():
+ cols[k].append(v)
+ return dict(cols)
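A small worked example of the class above; the numbers follow directly from `score` and `p_r_f1`:

```python
from weave.scorers import MultiTaskBinaryClassificationF1

scorer = MultiTaskBinaryClassificationF1(class_names=["fruit", "color"])

row = scorer.score(
    target={"fruit": "apple", "color": "red"},
    output={"fruit": "apple", "color": None},
)
# {"fruit": {"correct": True, "negative": False},
#  "color": {"correct": False, "negative": True}}

print(scorer.summarize([row]))
# fruit: tp=1, fp=0, fn=0 -> precision=recall=f1=1.0
# color: tp=0, fp=0, fn=1 -> precision=recall=f1=0.0
```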
diff --git a/weave/scorers/hallucination_scorer.py b/weave/scorers/hallucination_scorer.py
new file mode 100644
index 00000000000..1aee2012134
--- /dev/null
+++ b/weave/scorers/hallucination_scorer.py
@@ -0,0 +1,160 @@
+from typing import List
+
+from pydantic import BaseModel, Field
+
+import weave
+from weave.scorers.llm_scorer import InstructorLLMScorer
+from weave.scorers.llm_utils import OPENAI_DEFAULT_MODEL, create
+from weave.scorers.utils import stringify
+
+DEFAULT_HALLUCINATION_SYSTEM_PROMPT = """
+Given some <input_data> from a user and an <output>