Skip to content

Commit

Permalink
lint pass 2
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Dec 12, 2024
1 parent fab271b commit f8beb08
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 20 deletions.
2 changes: 1 addition & 1 deletion weave/builtin_objects/builtin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
_BUILTIN_REGISTRY: dict[str, type[weave.Object]] = {}


def register_builtin(cls: type[weave.Object]):
def register_builtin(cls: type[weave.Object]) -> None:
if not issubclass(cls, weave.Object):
raise ValueError(f"Object {cls} is not a subclass of weave.Object")

Expand Down
6 changes: 3 additions & 3 deletions weave/builtin_objects/models/CompletionModel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Any, Optional

import litellm

Expand All @@ -8,11 +8,11 @@

class LiteLLMCompletionModel(weave.Model):
model: str
messages_template: list[dict[str, str]] = None
messages_template: list[dict[str, str]]
response_format: Optional[dict] = None

@weave.op()
def predict(self, **kwargs) -> str:
def predict(self, **kwargs: Any) -> str:
messages: list[dict] = [
{**m, "content": m["content"].format(**kwargs)}
for m in self.messages_template
Expand Down
7 changes: 4 additions & 3 deletions weave/builtin_objects/scorers/LLMJudgeScorer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Any

import litellm

Expand All @@ -10,11 +11,11 @@

class LLMJudgeScorer(weave.Scorer):
model: str
system_prompt: str = None
response_format: dict = None
system_prompt: str
response_format: dict

@weave.op()
def score(self, inputs, output) -> str:
def score(self, inputs: dict, output: Any) -> str:
user_prompt = json.dumps(
{
"inputs": inputs,
Expand Down
3 changes: 3 additions & 0 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,9 @@ def completions_create(
def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
from weave.trace_server.server_side_object_saver import RunAsUser

if req.wb_user_id is None:
raise ValueError("User ID is required")

runner = RunAsUser(ch_server_dump=self.model_dump())
# TODO: handle errors here
res = runner.run_call_method(
Expand Down
55 changes: 42 additions & 13 deletions weave/trace_server/server_side_object_saver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import multiprocessing
import typing
from typing import Any, Callable, Tuple
from typing import Any, Callable, Tuple, TypedDict

import weave
from weave.trace import autopatch
Expand All @@ -16,6 +16,11 @@
)


class ScoreCallResult(TypedDict):
feedback_id: str
scorer_call_id: str


class RunAsUser:
"""Executes a function in a separate process for memory isolation.
Expand Down Expand Up @@ -62,7 +67,7 @@ def run(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
Raises:
Exception: If the function execution fails in the child process
"""
result_queue = multiprocessing.Queue()
result_queue: multiprocessing.Queue[Tuple[str, Any]] = multiprocessing.Queue()

process = multiprocessing.Process(
target=self._process_runner, args=(func, args, kwargs, result_queue)
Expand Down Expand Up @@ -97,7 +102,7 @@ def run_save_object(
Raises:
Exception: If the save operation fails in the child process
"""
result_queue = multiprocessing.Queue()
result_queue: multiprocessing.Queue[Tuple[str, str]] = multiprocessing.Queue()

process = multiprocessing.Process(
target=self._save_object,
Expand Down Expand Up @@ -171,7 +176,7 @@ def run_call_method(
method_name: str,
args: dict[str, Any],
) -> str:
result_queue = multiprocessing.Queue()
result_queue: multiprocessing.Queue[Tuple[str, Any]] = multiprocessing.Queue()

process = multiprocessing.Process(
target=self._call_method,
Expand Down Expand Up @@ -237,8 +242,10 @@ def _call_method(
except Exception as e:
result_queue.put(("error", str(e))) # Put any errors in the queue

def run_score_call(self, req: tsi.ScoreCallReq) -> str:
result_queue = multiprocessing.Queue()
def run_score_call(self, req: tsi.ScoreCallReq) -> ScoreCallResult:
result_queue: multiprocessing.Queue[Tuple[str, ScoreCallResult | str]] = (
multiprocessing.Queue()
)

process = multiprocessing.Process(
target=self._score_call,
Expand All @@ -252,12 +259,15 @@ def run_score_call(self, req: tsi.ScoreCallReq) -> str:
if status == "error":
raise Exception(f"Process execution failed: {result}")

return result
if isinstance(result, dict):
return result
else:
raise Exception(f"Unexpected result: {result}")

def _score_call(
self,
req: tsi.ScoreCallReq,
result_queue: multiprocessing.Queue,
result_queue: multiprocessing.Queue[Tuple[str, ScoreCallResult | str]],
) -> None:
try:
from weave.trace.weave_client import Call
Expand Down Expand Up @@ -303,13 +313,16 @@ def _score_call(
autopatch.reset_autopatch()
client._flush()
ic.reset()
scorer_call_id = apply_scorer_res["score_call"].id
if not scorer_call_id:
raise ValueError("Scorer call ID is required")
result_queue.put(
(
"success",
{
"feedback_id": apply_scorer_res["feedback_id"],
"scorer_call_id": apply_scorer_res["score_call"].id,
},
ScoreCallResult(
feedback_id=apply_scorer_res["feedback_id"],
scorer_call_id=scorer_call_id,
),
)
) # Put the result in the queue
except Exception as e:
Expand Down Expand Up @@ -344,41 +357,57 @@ def __init__(
self,
internal_trace_server: tsi.TraceServerInterface,
id_converter: external_to_internal_trace_server_adapter.IdConverter,
user_id: str,
user_id: str | None,
):
super().__init__(internal_trace_server, id_converter)
self._user_id = user_id

def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes:
if self._user_id is None:
raise ValueError("User ID is required")
req.start.wb_user_id = self._user_id
return super().call_start(req)

def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
if self._user_id is None:
raise ValueError("User ID is required")
req.wb_user_id = self._user_id
return super().calls_delete(req)

def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
if self._user_id is None:
raise ValueError("User ID is required")
req.wb_user_id = self._user_id
return super().call_update(req)

def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
if self._user_id is None:
raise ValueError("User ID is required")
req.wb_user_id = self._user_id
return super().feedback_create(req)

def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes:
if self._user_id is None:
raise ValueError("User ID is required")
req.wb_user_id = self._user_id
return super().cost_create(req)

def actions_execute_batch(
self, req: tsi.ActionsExecuteBatchReq
) -> tsi.ActionsExecuteBatchRes:
if self._user_id is None:
raise ValueError("User ID is required")
req.wb_user_id = self._user_id
return super().actions_execute_batch(req)

def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
if self._user_id is None:
raise ValueError("User ID is required")
req.wb_user_id = self._user_id
return super().call_method(req)

def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
if self._user_id is None:
raise ValueError("User ID is required")
req.wb_user_id = self._user_id
return super().score_call(req)
6 changes: 6 additions & 0 deletions weave/trace_server/sqlite_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,12 @@ def table_query_stream(
results = self.table_query(req)
yield from results.rows

def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
raise NotImplementedError("call_method is not implemented for local sqlite")

def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
raise NotImplementedError("score_call is not implemented for local sqlite")


def get_type(val: Any) -> str:
if val == None:
Expand Down

0 comments on commit f8beb08

Please sign in to comment.