Skip to content

Commit

Permalink
test scores
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Oct 11, 2024
1 parent ad0998c commit 84cdf7c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 2 deletions.
41 changes: 41 additions & 0 deletions tests/trace/test_scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from concurrent.futures import Future

import weave
from weave.trace.weave_client import get_ref
from weave.trace_server import trace_server_interface as tsi


def test_send_score_call(client):
    """Check that ``client._send_score_call`` records a score as feedback.

    Runs a traced op and a traced scorer, sends the scorer call as a score
    for the op call, then queries calls with feedback included and checks
    the feedback entry's type, target ref, and payload.

    Args:
        client: pytest fixture providing the weave client under test.
    """

    @weave.op
    def my_op(x: int) -> int:
        return x + 1

    @weave.op
    def my_score(input_x: int, model_output: int) -> dict:
        # The scorer yields a mapping of named results, not a bare number.
        return {"in_range": input_x < model_output}

    # `.call(...)` returns both the op's result and the recorded call.
    call_res, call = my_op.call(1)
    assert call_res == 2
    score_res, score_call = my_score.call(1, call_res)
    assert score_res == {"in_range": True}

    # Sending the score hands back a Future; resolve it so the feedback is
    # available before querying.
    res_fut = client._send_score_call(call, score_call)
    assert isinstance(res_fut, Future)
    res = res_fut.result()
    assert isinstance(res, str)

    query_res = client.server.calls_query(
        tsi.CallsQueryReq(
            project_id=client._project_id(),
            include_feedback=True,
        )
    )
    calls = query_res.calls

    # One call for `my_op` and one for `my_score`.
    assert len(calls) == 2
    # NOTE(review): assumes the scored `my_op` call is returned first --
    # confirm the default ordering of `calls_query`.
    feedback = calls[0].summary["weave"]["feedback"][0]
    assert feedback["feedback_type"] == "wandb.score.1"
    assert feedback["weave_ref"] == get_ref(call).uri()

    payload = feedback["payload"]
    assert payload["name"] == "my_score"
    assert payload["op_ref"] == get_ref(my_score).uri()
    assert payload["call_ref"] == get_ref(score_call).uri()
    assert payload["results"] == score_res
2 changes: 1 addition & 1 deletion weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ async def predict_and_score(
result, score_call = await async_call_op(score_fn, **score_args)
wc = get_weave_client()
if wc:
wc._link_score_call, model_call, score_call
wc._send_score_call, model_call, score_call

else:
# I would not expect this path to be hit, but keeping it for
Expand Down
3 changes: 2 additions & 1 deletion weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ def query_costs(
return res.results

@trace_sentry.global_trace_sentry.watch()
def _link_score_call(self, predict_call: Call, score_call: Call) -> Future[str]:
def _send_score_call(self, predict_call: Call, score_call: Call) -> Future[str]:
"""(Private) Adds a score to a call. This is particularly useful
for adding evaluation metrics to a call.
"""
Expand Down Expand Up @@ -1087,6 +1087,7 @@ def _link_score_call(self, predict_call: Call, score_call: Call) -> Future[str]:
scorer_op_ref_uri=scorer_op_ref_uri,
)

@trace_sentry.global_trace_sentry.watch()
def _add_score(
self,
*,
Expand Down

0 comments on commit 84cdf7c

Please sign in to comment.