Skip to content

Commit

Permalink
lint pass 3
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Dec 12, 2024
1 parent f8beb08 commit af69e4a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 68 deletions.
24 changes: 0 additions & 24 deletions tests/trace/builtin_objects/backend_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -325,30 +325,6 @@
"\n",
"score_res"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"feedback_id='0193ba50-631e-7192-aa0c-d9332c474c17' score_call=CallSchema(id='0193ba50-62e7-7f21-bc8c-e8f46fedbdb4', project_id='UHJvamVjdEludGVybmFsSWQ6NDA1NzYyOTQ=', op_name='weave:///timssweeney/remote_model_demo_4/op/LLMJudgeScorer.score:LSxb3VBdL8YmPr9vqYhxsMe74D8C04dJL1IKQ61Ke7M', display_name=None, trace_id='0193ba50-62e7-7f21-bc8c-e8eb057791c0', parent_id=None, started_at=datetime.datetime(2024, 12, 12, 10, 0, 50, 663992, tzinfo=TzInfo(UTC)), attributes={'weave': {'client_version': '0.51.25-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}, inputs={'self': 'weave:///timssweeney/remote_model_demo_4/object/LLMJudgeScorer:uCL086uULzE1HKLFn8YIezCG98HiqayaAp3d1R9ktA0', 'inputs': {'self': 'weave:///timssweeney/remote_model_demo_4/object/LiteLLMCompletionModel:KBsfUswVpEHFYmZuJjmhM2YH4EttkRZJSoH0Z0ZaNRY', 'kwargs': {'user_input': 'Hello, my name is Charles and I am 40 years old.'}}, 'output': {'name': 'Charles', 'age': 40}}, ended_at=datetime.datetime(2024, 12, 12, 10, 0, 50, 689418, tzinfo=TzInfo(UTC)), exception='{\"type\": \"TypeError\", \"message\": \"Object of type ObjectRef is not JSON serializable\", \"traceback\": [{\"filename\": \"<string>\", \"line_number\": 1, \"function_name\": \"<module>\", \"text\": \"\"}, {\"filename\": \"/Users/timothysweeney/.pyenv/versions/3.10.8/lib/python3.10/multiprocessing/spawn.py\", \"line_number\": 116, \"function_name\": \"spawn_main\", \"text\": \"exitcode = _main(fd, parent_sentinel)\"}, {\"filename\": \"/Users/timothysweeney/.pyenv/versions/3.10.8/lib/python3.10/multiprocessing/spawn.py\", \"line_number\": 129, \"function_name\": \"_main\", \"text\": \"return self._bootstrap(parent_sentinel)\"}, {\"filename\": \"/Users/timothysweeney/.pyenv/versions/3.10.8/lib/python3.10/multiprocessing/process.py\", \"line_number\": 314, \"function_name\": \"_bootstrap\", \"text\": \"self.run()\"}, {\"filename\": \"/Users/timothysweeney/.pyenv/versions/3.10.8/lib/python3.10/multiprocessing/process.py\", \"line_number\": 108, \"function_name\": \"run\", \"text\": \"self._target(*self._args, **self._kwargs)\"}, {\"filename\": \"/Users/timothysweeney/Workspace/github/wandb/core/services/weave-python/weave-public/weave/trace_server/server_side_object_saver.py\", \"line_number\": 301, \"function_name\": \"_score_call\", \"text\": \"apply_scorer_res = target_call._apply_scorer(scorer)\"}, {\"filename\": \"/Users/timothysweeney/Workspace/github/wandb/core/services/weave-python/weave-public/weave/trace/weave_client.py\", \"line_number\": 497, \"function_name\": \"_apply_scorer\", \"text\": \"_, score_call = scorer_op.call(**score_args)\"}, {\"filename\": \"/Users/timothysweeney/Workspace/github/wandb/core/services/weave-python/weave-public/weave/trace/op.py\", \"line_number\": 372, \"function_name\": \"call\", \"text\": \"return _do_call(\"}, {\"filename\": \"/Users/timothysweeney/Workspace/github/wandb/core/services/weave-python/weave-public/weave/trace/op.py\", \"line_number\": 432, \"function_name\": \"_do_call\", \"text\": \"execute_result = _execute_op(\"}, {\"filename\": \"/Users/timothysweeney/Workspace/github/wandb/core/services/weave-python/weave-public/weave/trace/op.py\", \"line_number\": 331, \"function_name\": \"_execute_op\", \"text\": \"res = func(*args, **kwargs)\"}, {\"filename\": \"/Users/timothysweeney/Workspace/github/wandb/core/services/weave-python/weave-public/weave/builtin_objects/scorers/LLMJudgeScorer.py\", \"line_number\": 18, \"function_name\": \"score\", \"text\": \"user_prompt = json.dumps(\"}, {\"filename\": \"/Users/timothysweeney/.pyenv/versions/3.10.8/lib/python3.10/json/__init__.py\", \"line_number\": 231, \"function_name\": \"dumps\", \"text\": \"return _default_encoder.encode(obj)\"}, {\"filename\": \"/Users/timothysweeney/.pyenv/versions/3.10.8/lib/python3.10/json/encoder.py\", \"line_number\": 199, \"function_name\": \"encode\", \"text\": \"chunks = self.iterencode(o, _one_shot=True)\"}, {\"filename\": \"/Users/timothysweeney/.pyenv/versions/3.10.8/lib/python3.10/json/encoder.py\", \"line_number\": 257, \"function_name\": \"iterencode\", \"text\": \"return _iterencode(o, 0)\"}, {\"filename\": \"/Users/timothysweeney/.pyenv/versions/3.10.8/lib/python3.10/json/encoder.py\", \"line_number\": 179, \"function_name\": \"default\", \"text\": \"raise TypeError(f\\'Object of type {o.__class__.__name__} \\'\"}]}', output=None, summary={'weave': {'status': <TraceStatus.ERROR: 'error'>, 'trace_name': 'LLMJudgeScorer.score', 'latency_ms': 25}}, wb_user_id='VXNlcjo2Mzg4Nw==', wb_run_id=None, deleted_at=None)\n"
]
}
],
"source": [
"print(score_res)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
2 changes: 1 addition & 1 deletion weave/builtin_objects/builtin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def register_builtin(cls: type[weave.Object]) -> None:
if not issubclass(cls, weave.Object):
raise ValueError(f"Object {cls} is not a subclass of weave.Object")
raise TypeError(f"Object {cls} is not a subclass of weave.Object")

if cls.__name__ in _BUILTIN_REGISTRY:
raise ValueError(f"Object {cls} already registered")
Expand Down
69 changes: 26 additions & 43 deletions weave/trace_server/server_side_object_saver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import multiprocessing
import typing
from typing import Any, Callable, Tuple, TypedDict
from typing import Any, Callable, TypedDict

import weave
from weave.trace import autopatch
Expand All @@ -21,6 +21,18 @@ class ScoreCallResult(TypedDict):
scorer_call_id: str


class RunSaveObjectException(Exception):
pass


class RunCallMethodException(Exception):
pass


class RunScoreCallException(Exception):
pass


class RunAsUser:
"""Executes a function in a separate process for memory isolation.
Expand All @@ -35,7 +47,7 @@ def __init__(self, ch_server_dump: dict[str, Any]):
@staticmethod
def _process_runner(
func: Callable[..., Any],
args: Tuple[Any, ...],
args: tuple[Any, ...],
kwargs: dict[str, Any],
result_queue: multiprocessing.Queue,
) -> None:
Expand All @@ -53,35 +65,6 @@ def _process_runner(
except Exception as e:
result_queue.put(("error", str(e)))

def run(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Run the provided function in a separate process.
Args:
func: The function to execute
*args: Positional arguments to pass to the function
**kwargs: Keyword arguments to pass to the function
Returns:
The result of the function execution
Raises:
Exception: If the function execution fails in the child process
"""
result_queue: multiprocessing.Queue[Tuple[str, Any]] = multiprocessing.Queue()

process = multiprocessing.Process(
target=self._process_runner, args=(func, args, kwargs, result_queue)
)

process.start()
status, result = result_queue.get()
process.join()

if status == "error":
raise Exception(f"Process execution failed: {result}")

return result

def run_save_object(
self,
new_obj: Any,
Expand All @@ -102,7 +85,7 @@ def run_save_object(
Raises:
Exception: If the save operation fails in the child process
"""
result_queue: multiprocessing.Queue[Tuple[str, str]] = multiprocessing.Queue()
result_queue: multiprocessing.Queue[tuple[str, str]] = multiprocessing.Queue()

process = multiprocessing.Process(
target=self._save_object,
Expand All @@ -120,7 +103,7 @@ def run_save_object(
process.join()

if status == "error":
raise Exception(f"Process execution failed: {result}")
raise RunSaveObjectException(f"Process execution failed: {result}")

return result

Expand Down Expand Up @@ -176,7 +159,7 @@ def run_call_method(
method_name: str,
args: dict[str, Any],
) -> str:
result_queue: multiprocessing.Queue[Tuple[str, Any]] = multiprocessing.Queue()
result_queue: multiprocessing.Queue[tuple[str, Any]] = multiprocessing.Queue()

process = multiprocessing.Process(
target=self._call_method,
Expand All @@ -188,7 +171,7 @@ def run_call_method(
process.join()

if status == "error":
raise Exception(f"Process execution failed: {result}")
raise RunCallMethodException(f"Process execution failed: {result}")

return result

Expand Down Expand Up @@ -243,7 +226,7 @@ def _call_method(
result_queue.put(("error", str(e))) # Put any errors in the queue

def run_score_call(self, req: tsi.ScoreCallReq) -> ScoreCallResult:
result_queue: multiprocessing.Queue[Tuple[str, ScoreCallResult | str]] = (
result_queue: multiprocessing.Queue[tuple[str, ScoreCallResult | str]] = (
multiprocessing.Queue()
)

Expand All @@ -257,17 +240,17 @@ def run_score_call(self, req: tsi.ScoreCallReq) -> ScoreCallResult:
process.join()

if status == "error":
raise Exception(f"Process execution failed: {result}")
raise RunScoreCallException(f"Process execution failed: {result}")

if isinstance(result, dict):
return result
else:
raise Exception(f"Unexpected result: {result}")
raise RunScoreCallException(f"Unexpected result: {result}")

def _score_call(
self,
req: tsi.ScoreCallReq,
result_queue: multiprocessing.Queue[Tuple[str, ScoreCallResult | str]],
result_queue: multiprocessing.Queue[tuple[str, ScoreCallResult | str]],
) -> None:
try:
from weave.trace.weave_client import Call
Expand All @@ -291,13 +274,13 @@ def _score_call(

target_call_ref = parse_internal_uri(req.call_ref)
if not isinstance(target_call_ref, InternalCallRef):
raise ValueError("Invalid call reference")
raise TypeError("Invalid call reference")
target_call = client.get_call(target_call_ref.id)._val
if not isinstance(target_call, Call):
raise ValueError("Invalid call reference")
raise TypeError("Invalid call reference")
scorer_ref = parse_internal_uri(req.scorer_ref)
if not isinstance(scorer_ref, InternalObjectRef):
raise ValueError("Invalid scorer reference")
raise TypeError("Invalid scorer reference")
scorer = weave.ref(
ObjectRef(
entity="_SERVER_",
Expand All @@ -307,7 +290,7 @@ def _score_call(
).uri()
).get()
if not isinstance(scorer, weave.Scorer):
raise ValueError("Invalid scorer reference")
raise TypeError("Invalid scorer reference")
apply_scorer_res = target_call._apply_scorer(scorer)

autopatch.reset_autopatch()
Expand Down

0 comments on commit af69e4a

Please sign in to comment.