diff --git a/tests/conftest.py b/tests/conftest.py index 2f6732bb843..451de14e17e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -334,6 +334,10 @@ def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: req.wb_user_id = self._user_id return super().call_method(req) + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + req.wb_user_id = self._user_id + return super().score_call(req) + # https://docs.pytest.org/en/7.1.x/example/simple.html#pytest-current-test-environment-variable def get_test_name(): diff --git a/tests/trace/builtin_objects/test_builtin_scorer.py b/tests/trace/builtin_objects/test_builtin_scorer.py index 9e06cdfaddb..69ebbd12295 100644 --- a/tests/trace/builtin_objects/test_builtin_scorer.py +++ b/tests/trace/builtin_objects/test_builtin_scorer.py @@ -6,7 +6,7 @@ # 5. Remote Create, Remote Direct Score import weave from weave.builtin_objects.scorers.LLMJudgeScorer import LLMJudgeScorer -from weave.trace.weave_client import Call, WeaveClient +from weave.trace.weave_client import ApplyScorerResult, Call, WeaveClient from weave.trace_server import trace_server_interface as tsi scorer_args = { @@ -63,12 +63,23 @@ def simple_op(question: str) -> str: return res, call -def assert_expected_outcome(target_call: Call, scorer_res_as_dict: dict): - assert scorer_res_as_dict["score_call"].output == expected_score +def assert_expected_outcome( + target_call: Call, scorer_res: ApplyScorerResult | tsi.ScoreCallRes +): + scorer_output = None + feedback_id = None + if isinstance(scorer_res, tsi.ScoreCallRes): + scorer_output = scorer_res.score_call.output + feedback_id = scorer_res.feedback_id + else: + scorer_output = scorer_res["score_call"].output + feedback_id = scorer_res["feedback_id"] + + assert scorer_output == expected_score feedbacks = list(target_call.feedback) assert len(feedbacks) == 1 assert feedbacks[0].payload["output"] == expected_score - assert feedbacks[0].id == scorer_res_as_dict["feedback_id"] + assert feedbacks[0].id == feedback_id def do_remote_score( @@ -120,7 +131,7 @@ def test_scorer_local_create_remote_use(client: WeaveClient): res, call = make_simple_call() publish_ref = weave.publish(scorer) remote_score_res = do_remote_score(client, call, publish_ref) - assert_expected_outcome(call, remote_score_res.model_dump()) + assert_expected_outcome(call, remote_score_res) def test_scorer_remote_create_local_use(client: WeaveClient): @@ -136,4 +147,4 @@ def test_scorer_remote_create_remote_use(client: WeaveClient): obj_ref = make_remote_scorer(client) res, call = make_simple_call() remote_score_res = do_remote_score(client, call, obj_ref) - assert_expected_outcome(call, remote_score_res.model_dump()) + assert_expected_outcome(call, remote_score_res) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 295eed67377..fc2ff26c3b3 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1537,6 +1537,19 @@ def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: ) return tsi.CallMethodRes.model_validate(res) + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + from weave.trace_server.server_side_object_saver import RunAsUser + + runner = RunAsUser(ch_server_dump=self.model_dump()) + res = runner.run_score_call(req) + + return tsi.ScoreCallRes( + feedback_id=res["feedback_id"], + score_call=self.call_read( + tsi.CallReadReq(project_id=req.project_id, id=res["scorer_call_id"]) + ).call, + ) + # Private Methods @property def ch_client(self) -> CHClient: diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 61e10a6fff2..fe38a2b1d91 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -384,3 +384,11 @@ def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: raise ValueError("wb_user_id cannot be None") req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id) return self._ref_apply(self._internal_trace_server.call_method, req) + + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + original_user_id = req.wb_user_id + if original_user_id is None: + raise ValueError("wb_user_id cannot be None") + req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id) + return self._ref_apply(self._internal_trace_server.score_call, req) diff --git a/weave/trace_server/server_side_object_saver.py b/weave/trace_server/server_side_object_saver.py index d8fa2706034..1d80f98b302 100644 --- a/weave/trace_server/server_side_object_saver.py +++ b/weave/trace_server/server_side_object_saver.py @@ -9,7 +9,11 @@ from weave.trace.weave_init import InitializedClient from weave.trace_server import external_to_internal_trace_server_adapter from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.refs_internal import InternalObjectRef, parse_internal_uri +from weave.trace_server.refs_internal import ( + InternalCallRef, + InternalObjectRef, + parse_internal_uri, +) class RunAsUser: @@ -233,6 +237,84 @@ def _call_method( except Exception as e: result_queue.put(("error", str(e))) # Put any errors in the queue + def run_score_call(self, req: tsi.ScoreCallReq) -> str: + result_queue = multiprocessing.Queue() + + process = multiprocessing.Process( + target=self._score_call, + args=(req, result_queue), + ) + + process.start() + status, result = result_queue.get() + process.join() + + if status == "error": + raise Exception(f"Process execution failed: {result}") + + return result + + def _score_call( + self, + req: tsi.ScoreCallReq, + result_queue: multiprocessing.Queue, + ) -> None: + try: + from weave.trace.weave_client import Call + from weave.trace_server.clickhouse_trace_server_batched import ( + ClickHouseTraceServer, + ) + + client = WeaveClient( + "_SERVER_", + req.project_id, + UserInjectingExternalTraceServer( + ClickHouseTraceServer(**self.ch_server_dump), + id_converter=IdConverter(), + user_id=req.wb_user_id, + ), + False, + ) + + ic = InitializedClient(client) + autopatch.autopatch() + + target_call_ref = parse_internal_uri(req.call_ref) + if not isinstance(target_call_ref, InternalCallRef): + raise ValueError("Invalid call reference") + target_call = client.get_call(target_call_ref.id)._val + if not isinstance(target_call, Call): + raise ValueError("Invalid call reference") + scorer_ref = parse_internal_uri(req.scorer_ref) + if not isinstance(scorer_ref, InternalObjectRef): + raise ValueError("Invalid scorer reference") + scorer = weave.ref( + ObjectRef( + entity="_SERVER_", + project=scorer_ref.project_id, + name=scorer_ref.name, + _digest=scorer_ref.version, + ).uri() + ).get() + if not isinstance(scorer, weave.Scorer): + raise ValueError("Invalid scorer reference") + apply_scorer_res = target_call._apply_scorer(scorer) + + autopatch.reset_autopatch() + client._flush() + ic.reset() + result_queue.put( + ( + "success", + { + "feedback_id": apply_scorer_res["feedback_id"], + "scorer_call_id": apply_scorer_res["score_call"].id, + }, + ) + ) # Put the result in the queue + except Exception as e: + result_queue.put(("error", str(e))) # Put any errors in the queue + class IdConverter(external_to_internal_trace_server_adapter.IdConverter): def ext_to_int_project_id(self, project_id: str) -> str: @@ -296,3 +378,7 @@ def actions_execute_batch( def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: req.wb_user_id = self._user_id return super().call_method(req) + + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + req.wb_user_id = self._user_id + return super().score_call(req) diff --git a/weave/trace_server/todo.md b/weave/trace_server/todo.md index 2e3682833ad..a75b3e75179 100644 --- a/weave/trace_server/todo.md +++ b/weave/trace_server/todo.md @@ -12,12 +12,12 @@ * Server seems to get stuck when it crashes * Accidentally checked in parallelization disabling * Create a Dummy Scorer via API -> See in object explorer - * Should be pretty straight forward at this point + * [x] Should be pretty straight forward at this point * Invoke the Scorer against the previous call via API -> see in traces AND in feedback - * Should be mostly straight forward (the Scorer API itself is a bit wonky) + * [x] Should be mostly straight forward (the Scorer API itself is a bit wonky) * Important Proof of system: * [x] create the same dummy model locally & invoke -> notice no version change - * Run locally against the call -> notice that there are no extra objects + * [x] Run locally against the call -> notice that there are no extra objects * [ ]Should de-ref inputs if they contain refs * [ ] Refactor the entire "base model" system to conform to this new way of doing things (leaf models) * [ ] Might get hairy with nested refs - consider implications @@ -26,4 +26,4 @@ * [ ] scorers should have a client-spec, not a specific client * [ ] How to model a scorers's stub (input, output, context, reference(s), etc...) * [ ] How to handle output types from scorers (boolean, number, reason, etc...) - \ No newline at end of file + * [ ]Investigate why the tests are running so slowly \ No newline at end of file diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 649aa3c429c..02edeb4b909 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -880,6 +880,18 @@ class CallMethodRes(BaseModel): output: Any +class ScoreCallReq(BaseModel): + project_id: str + call_ref: str + scorer_ref: str + wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) + + +class ScoreCallRes(BaseModel): + feedback_id: str + score_call: CallSchema + + class TraceServerInterface(Protocol): def ensure_project_exists( self, entity: str, project: str @@ -925,6 +937,7 @@ def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: ... # Execute API def call_method(self, req: CallMethodReq) -> CallMethodRes: ... + def score_call(self, req: ScoreCallReq) -> ScoreCallRes: ... # Action API def actions_execute_batch(