From 84cdf7c276198aae1f41b91ffb7929dd749e0bfa Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 11 Oct 2024 16:53:10 -0700 Subject: [PATCH] test scores --- tests/trace/test_scores.py | 41 +++++++++++++++++++++++++++++++++++++ weave/flow/eval.py | 2 +- weave/trace/weave_client.py | 3 ++- 3 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 tests/trace/test_scores.py diff --git a/tests/trace/test_scores.py b/tests/trace/test_scores.py new file mode 100644 index 00000000000..d98d069c362 --- /dev/null +++ b/tests/trace/test_scores.py @@ -0,0 +1,41 @@ +from concurrent.futures import Future + +import weave +from weave.trace.weave_client import get_ref +from weave.trace_server import trace_server_interface as tsi + + +def test_send_score_call(client): + @weave.op + def my_op(x: int) -> int: + return x + 1 + + @weave.op + def my_score(input_x: int, model_output: int) -> int: + return {"in_range": input_x < model_output} + + call_res, call = my_op.call(1) + assert call_res == 2 + score_res, score_call = my_score.call(1, call_res) + assert score_res == {"in_range": True} + res_fut = client._send_score_call(call, score_call) + assert isinstance(res_fut, Future) + res = res_fut.result() + assert isinstance(res, str) + + query_res = client.server.calls_query( + tsi.CallsQueryReq( + project_id=client._project_id(), + include_feedback=True, + ) + ) + calls = query_res.calls + + assert len(calls) == 2 + feedback = calls[0].summary["weave"]["feedback"][0] + assert feedback["feedback_type"] == "wandb.score.1" + assert feedback["weave_ref"] == get_ref(call).uri() + assert feedback["payload"]["name"] == "my_score" + assert feedback["payload"]["op_ref"] == get_ref(my_score).uri() + assert feedback["payload"]["call_ref"] == get_ref(score_call).uri() + assert feedback["payload"]["results"] == score_res diff --git a/weave/flow/eval.py b/weave/flow/eval.py index f2a4010b6fb..47380b77767 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -266,7 +266,7 @@ async def predict_and_score( result, score_call = await async_call_op(score_fn, **score_args) wc = get_weave_client() if wc: - wc._link_score_call, model_call, score_call + wc._send_score_call, model_call, score_call else: # I would not expect this path to be hit, but keeping it for diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index a925a8e9fac..47b04e8a284 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -1059,7 +1059,7 @@ def query_costs( return res.results @trace_sentry.global_trace_sentry.watch() - def _link_score_call(self, predict_call: Call, score_call: Call) -> Future[str]: + def _send_score_call(self, predict_call: Call, score_call: Call) -> Future[str]: """(Private) Adds a score to a call. This is particularly useful for adding evaluation metrics to a call. """ @@ -1087,6 +1087,7 @@ def _link_score_call(self, predict_call: Call, score_call: Call) -> Future[str]: scorer_op_ref_uri=scorer_op_ref_uri, ) + @trace_sentry.global_trace_sentry.watch() def _add_score( self, *,