diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py
index 002ed34fee3..990bc921fd6 100644
--- a/tests/trace/test_evaluate.py
+++ b/tests/trace/test_evaluate.py
@@ -107,6 +107,23 @@ async def infer(self, input) -> str:
     assert result == expected_eval_result
 
 
+def test_evaluate_model_with__call__(client):  # Model that also defines __call__
+    class EvalModel(Model):
+        @weave.op()
+        async def infer(self, input) -> str:
+            return eval(input)  # test-only: eval() is unsafe on untrusted input
+
+        def __call__(self, *args, **kwargs):
+            return self.infer(*args, **kwargs)  # forward the call to infer unchanged
+
+    evaluation = Evaluation(
+        dataset=dataset_rows,
+        scorers=[score],
+    )
+    result = asyncio.run(evaluation.evaluate(EvalModel()))  # instance, not a function
+    assert result == expected_eval_result  # same expectation as the test above
+
+
 def test_score_as_class(client):
     class MyScorer(weave.Scorer):
         @weave.op()
diff --git a/weave/flow/eval.py b/weave/flow/eval.py
index 5f4a961f904..193459a9e11 100644
--- a/weave/flow/eval.py
+++ b/weave/flow/eval.py
@@ -5,6 +5,7 @@
 import time
 import traceback
 from collections.abc import Coroutine
+from inspect import isfunction  # lets predict_and_score detect plain functions
 from typing import Any, Callable, Literal, Optional, Union, cast
 
 from pydantic import PrivateAttr
@@ -165,9 +166,12 @@ async def predict_and_score(
 
         model_self = None
         model_predict: Union[Callable, Model]
-        if callable(model):
+        if callable(model) and isfunction(model):  # only plain Python functions
             model_predict = model
         else:
+            if not isinstance(model, Model):  # non-function that is not a Model
+                raise ValueError(INVALID_MODEL_ERROR)
+
             model_self = model
             model_predict = get_infer_method(model)
 