diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py
index ae70d71b277..06e87774fa2 100644
--- a/tests/trace/test_evaluate.py
+++ b/tests/trace/test_evaluate.py
@@ -152,3 +152,79 @@ def score(self, target, output):
             "mean": pytest.approx(0, abs=1),
         },
     }
+
+
+@pytest.mark.asyncio
+async def test_basic_evaluation_with_mixed_scorer_styles(client):
+    @weave.op
+    def fn_scorer_with_old_style(col_a, col_b, model_output, target):
+        return col_a + col_b == model_output == target
+
+    @weave.op
+    def fn_scorer_with_new_style(col_a, col_b, output, target):
+        return col_a + col_b == output == target
+
+    class ClassScorerWithOldStyle(weave.Scorer):
+        @weave.op
+        def score(self, col_a, col_b, model_output, target):
+            return col_a + col_b == model_output == target
+
+    class ClassScorerWithNewStyle(weave.Scorer):
+        @weave.op
+        def score(self, a, b, output, c):
+            return a + b == output == c
+
+    dataset = [
+        {"col_a": 1, "col_b": 2, "target": 3},
+        {"col_a": 1, "col_b": 2, "target": 3},
+        {"col_a": 1, "col_b": 2, "target": 3},
+    ]
+    evaluation = Evaluation(
+        dataset=dataset,
+        scorers=[
+            fn_scorer_with_old_style,
+            fn_scorer_with_new_style,
+            ClassScorerWithOldStyle(),
+            ClassScorerWithNewStyle(
+                column_map={
+                    "a": "col_a",
+                    "b": "col_b",
+                    "c": "target",
+                }
+            ),
+        ],
+    )
+
+    @weave.op
+    def model(col_a, col_b):
+        return col_a + col_b
+
+    result = await evaluation.evaluate(model)
+    assert result.pop("model_latency").get("mean") == pytest.approx(0, abs=1)
+    assert result == {
+        # Should now be `output` even if there are scorers that use `model_output`
+        "output": {"mean": 3.0},
+        "fn_scorer_with_old_style": {"true_count": 3, "true_fraction": 1.0},
+        "fn_scorer_with_new_style": {"true_count": 3, "true_fraction": 1.0},
+        "ClassScorerWithOldStyle": {"true_count": 3, "true_fraction": 1.0},
+        "ClassScorerWithNewStyle": {"true_count": 3, "true_fraction": 1.0},
+    }
+
+    predict_and_score_calls = list(evaluation.predict_and_score.calls())
+    assert len(predict_and_score_calls) == 3
+    outputs = [c.output for c in predict_and_score_calls]
+    assert all(o.pop("model_latency") == pytest.approx(0, abs=1) for o in outputs)
+    assert all(
+        o
+        == {
+            # Should now be `output` even if there are scorers that use `model_output`
+            "output": 3.0,
+            "scores": {
+                "fn_scorer_with_old_style": True,
+                "fn_scorer_with_new_style": True,
+                "ClassScorerWithOldStyle": True,
+                "ClassScorerWithNewStyle": True,
+            },
+        }
+        for o in outputs
+    )
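
Note (not part of the diff): the test above pins down two behaviors at once. First, the evaluation summary and the per-row `predict_and_score` output key on `output` even when legacy scorers still accept `model_output`. Second, `column_map` on a `weave.Scorer` remaps scorer argument names to dataset columns, which is how `ClassScorerWithNewStyle` scores rows despite using the names `a`, `b`, and `c`. Below is a minimal standalone sketch of that `column_map` usage, assuming the same public `weave.Evaluation`/`weave.Scorer` API the test exercises; the project name and the `ExactMatch` scorer are hypothetical, for illustration only:

    import asyncio

    import weave
    from weave import Evaluation, Scorer


    class ExactMatch(Scorer):
        @weave.op
        def score(self, expected, output):
            # "expected" is not a dataset column; column_map maps it to "answer".
            return expected == output


    @weave.op
    def model(question):
        return "42"


    dataset = [{"question": "What is 6 x 7?", "answer": "42"}]
    evaluation = Evaluation(
        dataset=dataset,
        scorers=[ExactMatch(column_map={"expected": "answer"})],
    )

    weave.init("column-map-demo")  # hypothetical project name
    print(asyncio.run(evaluation.evaluate(model)))

Run against this one-row dataset, the summary should report `ExactMatch` with a true_count of 1, mirroring the `true_count`/`true_fraction` shape asserted in the test.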