From f8beb084bd8b0b54de1e4336a37a39987746eb63 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 12 Dec 2024 14:47:16 -0800 Subject: [PATCH] lint pass 2 --- weave/builtin_objects/builtin_registry.py | 2 +- .../builtin_objects/models/CompletionModel.py | 6 +- .../builtin_objects/scorers/LLMJudgeScorer.py | 7 ++- .../clickhouse_trace_server_batched.py | 3 + .../trace_server/server_side_object_saver.py | 55 ++++++++++++++----- weave/trace_server/sqlite_trace_server.py | 6 ++ 6 files changed, 59 insertions(+), 20 deletions(-) diff --git a/weave/builtin_objects/builtin_registry.py b/weave/builtin_objects/builtin_registry.py index 2464d0699fc..15809626dc2 100644 --- a/weave/builtin_objects/builtin_registry.py +++ b/weave/builtin_objects/builtin_registry.py @@ -5,7 +5,7 @@ _BUILTIN_REGISTRY: dict[str, type[weave.Object]] = {} -def register_builtin(cls: type[weave.Object]): +def register_builtin(cls: type[weave.Object]) -> None: if not issubclass(cls, weave.Object): raise ValueError(f"Object {cls} is not a subclass of weave.Object") diff --git a/weave/builtin_objects/models/CompletionModel.py b/weave/builtin_objects/models/CompletionModel.py index 6a3ee8f0e94..808a764ca69 100644 --- a/weave/builtin_objects/models/CompletionModel.py +++ b/weave/builtin_objects/models/CompletionModel.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional import litellm @@ -8,11 +8,11 @@ class LiteLLMCompletionModel(weave.Model): model: str - messages_template: list[dict[str, str]] = None + messages_template: list[dict[str, str]] response_format: Optional[dict] = None @weave.op() - def predict(self, **kwargs) -> str: + def predict(self, **kwargs: Any) -> str: messages: list[dict] = [ {**m, "content": m["content"].format(**kwargs)} for m in self.messages_template diff --git a/weave/builtin_objects/scorers/LLMJudgeScorer.py b/weave/builtin_objects/scorers/LLMJudgeScorer.py index 3514df9a7be..480ef7a47f7 100644 --- a/weave/builtin_objects/scorers/LLMJudgeScorer.py +++ b/weave/builtin_objects/scorers/LLMJudgeScorer.py @@ -1,4 +1,5 @@ import json +from typing import Any import litellm @@ -10,11 +11,11 @@ class LLMJudgeScorer(weave.Scorer): model: str - system_prompt: str = None - response_format: dict = None + system_prompt: str + response_format: dict @weave.op() - def score(self, inputs, output) -> str: + def score(self, inputs: dict, output: Any) -> str: user_prompt = json.dumps( { "inputs": inputs, diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index fc2ff26c3b3..f7cf8e808d8 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1530,6 +1530,9 @@ def completions_create( def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: from weave.trace_server.server_side_object_saver import RunAsUser + if req.wb_user_id is None: + raise ValueError("User ID is required") + runner = RunAsUser(ch_server_dump=self.model_dump()) # TODO: handle errors here res = runner.run_call_method( diff --git a/weave/trace_server/server_side_object_saver.py b/weave/trace_server/server_side_object_saver.py index 1d80f98b302..969b54868e3 100644 --- a/weave/trace_server/server_side_object_saver.py +++ b/weave/trace_server/server_side_object_saver.py @@ -1,6 +1,6 @@ import multiprocessing import typing -from typing import Any, Callable, Tuple +from typing import Any, Callable, Tuple, TypedDict import weave from weave.trace import autopatch @@ -16,6 +16,11 @@ ) +class ScoreCallResult(TypedDict): + feedback_id: str + scorer_call_id: str + + class RunAsUser: """Executes a function in a separate process for memory isolation. @@ -62,7 +67,7 @@ def run(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: Raises: Exception: If the function execution fails in the child process """ - result_queue = multiprocessing.Queue() + result_queue: multiprocessing.Queue[Tuple[str, Any]] = multiprocessing.Queue() process = multiprocessing.Process( target=self._process_runner, args=(func, args, kwargs, result_queue) @@ -97,7 +102,7 @@ def run_save_object( Raises: Exception: If the save operation fails in the child process """ - result_queue = multiprocessing.Queue() + result_queue: multiprocessing.Queue[Tuple[str, str]] = multiprocessing.Queue() process = multiprocessing.Process( target=self._save_object, @@ -171,7 +176,7 @@ def run_call_method( method_name: str, args: dict[str, Any], ) -> str: - result_queue = multiprocessing.Queue() + result_queue: multiprocessing.Queue[Tuple[str, Any]] = multiprocessing.Queue() process = multiprocessing.Process( target=self._call_method, @@ -237,8 +242,10 @@ def _call_method( except Exception as e: result_queue.put(("error", str(e))) # Put any errors in the queue - def run_score_call(self, req: tsi.ScoreCallReq) -> str: - result_queue = multiprocessing.Queue() + def run_score_call(self, req: tsi.ScoreCallReq) -> ScoreCallResult: + result_queue: multiprocessing.Queue[Tuple[str, ScoreCallResult | str]] = ( + multiprocessing.Queue() + ) process = multiprocessing.Process( target=self._score_call, @@ -252,12 +259,15 @@ def run_score_call(self, req: tsi.ScoreCallReq) -> str: if status == "error": raise Exception(f"Process execution failed: {result}") - return result + if isinstance(result, dict): + return result + else: + raise Exception(f"Unexpected result: {result}") def _score_call( self, req: tsi.ScoreCallReq, - result_queue: multiprocessing.Queue, + result_queue: multiprocessing.Queue[Tuple[str, ScoreCallResult | str]], ) -> None: try: from weave.trace.weave_client import Call @@ -303,13 +313,16 @@ def _score_call( autopatch.reset_autopatch() client._flush() ic.reset() + scorer_call_id = apply_scorer_res["score_call"].id + if not scorer_call_id: + raise ValueError("Scorer call ID is required") result_queue.put( ( "success", - { - "feedback_id": apply_scorer_res["feedback_id"], - "scorer_call_id": apply_scorer_res["score_call"].id, - }, + ScoreCallResult( + feedback_id=apply_scorer_res["feedback_id"], + scorer_call_id=scorer_call_id, + ), ) ) # Put the result in the queue except Exception as e: @@ -344,41 +357,57 @@ def __init__( self, internal_trace_server: tsi.TraceServerInterface, id_converter: external_to_internal_trace_server_adapter.IdConverter, - user_id: str, + user_id: str | None, ): super().__init__(internal_trace_server, id_converter) self._user_id = user_id def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: + if self._user_id is None: + raise ValueError("User ID is required") req.start.wb_user_id = self._user_id return super().call_start(req) def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: + if self._user_id is None: + raise ValueError("User ID is required") req.wb_user_id = self._user_id return super().calls_delete(req) def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: + if self._user_id is None: + raise ValueError("User ID is required") req.wb_user_id = self._user_id return super().call_update(req) def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: + if self._user_id is None: + raise ValueError("User ID is required") req.wb_user_id = self._user_id return super().feedback_create(req) def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: + if self._user_id is None: + raise ValueError("User ID is required") req.wb_user_id = self._user_id return super().cost_create(req) def actions_execute_batch( self, req: tsi.ActionsExecuteBatchReq ) -> tsi.ActionsExecuteBatchRes: + if self._user_id is None: + raise ValueError("User ID is required") req.wb_user_id = self._user_id return super().actions_execute_batch(req) def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: + if self._user_id is None: + raise ValueError("User ID is required") req.wb_user_id = self._user_id return super().call_method(req) def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + if self._user_id is None: + raise ValueError("User ID is required") req.wb_user_id = self._user_id return super().score_call(req) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index b68af1ac238..0cd4f7db5ca 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1221,6 +1221,12 @@ def table_query_stream( results = self.table_query(req) yield from results.rows + def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: + raise NotImplementedError("call_method is not implemented for local sqlite") + + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + raise NotImplementedError("score_call is not implemented for local sqlite") + def get_type(val: Any) -> str: if val == None: