diff --git a/dev_docs/BaseObjectClasses.md b/dev_docs/BaseObjectClasses.md index a571af49755..20a104e2e90 100644 --- a/dev_docs/BaseObjectClasses.md +++ b/dev_docs/BaseObjectClasses.md @@ -116,7 +116,7 @@ curl -X POST 'https://trace.wandb.ai/obj/create' \ "project_id": "user/project", "object_id": "my_config", "val": {...}, - "set_base_object_class": "MyConfig" + "set_leaf_object_class": "MyConfig" } }' @@ -162,7 +162,7 @@ Run `make synchronize-base-object-schemas` to ensure the frontend TypeScript typ 4. Now, each use case uses different parts: 1. `Python Writing`. Users can directly import these classes and use them as normal Pydantic models, which get published with `weave.publish`. The python client correct builds the requisite payload. 2. `Python Reading`. Users can `weave.ref().get()` and the weave python SDK will return the instance with the correct type. Note: we do some special handling such that the returned object is not a WeaveObject, but literally the exact pydantic class. - 3. `HTTP Writing`. In cases where the client/user does not want to add the special type information, users can publish base objects by setting the `set_base_object_class` setting on `POST obj/create` to the name of the class. The weave server will validate the object against the schema, update the metadata fields, and store the object. + 3. `HTTP Writing`. In cases where the client/user does not want to add the special type information, users can publish objects by setting the `set_leaf_object_class` setting on `POST obj/create` to the name of the class. The weave server will validate the object against the schema, update the metadata fields, and store the object. 4. `HTTP Reading`. When querying for objects, the server will return the object with the correct type if the `base_object_class` metadata field is set. 5. `Frontend`. 
The frontend will read the zod schema from `weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts` and use that to provide compile time type safety when using `useBaseObjectInstances` and runtime type safety when using `useCreateBaseObjectInstance`. * Note: it is critical that all techniques produce the same digest for the same data - which is tested in the tests. This way versions are not thrashed by different clients/users. @@ -185,7 +185,7 @@ graph TD subgraph "Trace Server" subgraph "HTTP API" - R --> |validates using| HW["POST obj/create
set_base_object_class"] + R --> |validates using| HW["POST obj/create
set_leaf_object_class"] HW --> DB[(Weave Object Store)] HR["POST objs/query
base_object_classes"] --> |Filters base_object_class| DB end @@ -203,7 +203,7 @@ graph TD Z --> |import| UBI["useBaseObjectInstances"] Z --> |import| UCI["useCreateBaseObjectInstance"] UBI --> |Filters base_object_class| HR - UCI --> |set_base_object_class| HW + UCI --> |set_leaf_object_class| HW UI[React UI] --> UBI UI --> UCI end diff --git a/pyproject.toml b/pyproject.toml index f34757b315e..447368bbd0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,9 @@ test = [ "pillow", "filelock", "httpx", + + # Bultins + "litellm>=1.49.1", ] [project.scripts] diff --git a/tests/conftest.py b/tests/conftest.py index 85e9b53c36b..febf5305f4b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -330,6 +330,14 @@ def actions_execute_batch( req.wb_user_id = self._user_id return super().actions_execute_batch(req) + def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: + req.wb_user_id = self._user_id + return super().call_method(req) + + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + req.wb_user_id = self._user_id + return super().score_call(req) + # https://docs.pytest.org/en/7.1.x/example/simple.html#pytest-current-test-environment-variable def get_test_name(): diff --git a/tests/trace/builtin_objects/backend_models.ipynb b/tests/trace/builtin_objects/backend_models.ipynb new file mode 100644 index 00000000000..3247b5a8978 --- /dev/null +++ b/tests/trace/builtin_objects/backend_models.ipynb @@ -0,0 +1,351 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"WF_TRACE_SERVER_URL\"] = \"http://127.0.01:6345\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import weave" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logged in as Weights & Biases 
user: timssweeney.\n", + "View Weave data at https://wandb.ai/timssweeney/remote_model_demo_4/weave\n" + ] + } + ], + "source": [ + "client = weave.init(\"remote_model_demo_4\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🍩 https://wandb.ai/timssweeney/remote_model_demo_4/r/call/0193ba4f-de91-79a2-9028-67c3c500703e\n", + "weave:///timssweeney/remote_model_demo_4/object/LiteLLMCompletionModel:KBsfUswVpEHFYmZuJjmhM2YH4EttkRZJSoH0Z0ZaNRY\n", + "{'name': 'Fred', 'age': 30}\n" + ] + } + ], + "source": [ + "# Demonstrates creating a model in python\n", + "\n", + "from weave.builtin_objects.models.CompletionModel import LiteLLMCompletionModel\n", + "\n", + "model = LiteLLMCompletionModel(\n", + " model=\"gpt-4o\",\n", + " messages_template=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"Please extract the name and age from the following text\",\n", + " },\n", + " {\"role\": \"user\", \"content\": \"{user_input}\"},\n", + " ],\n", + " response_format={\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": {\n", + " \"name\": \"Person\",\n", + " \"schema\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"age\": {\"type\": \"integer\"},\n", + " \"name\": {\"type\": \"string\"},\n", + " },\n", + " },\n", + " },\n", + " },\n", + ")\n", + "\n", + "res = model.predict(user_input=\"Hello, my name is Fred and I am 30 years old.\")\n", + "\n", + "print(model.ref.uri())\n", + "\n", + "print(res)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CallMethodRes(call_id='0193ba4f-fc13-79c2-b217-03e6fdd7e7c4', output={'name': 'Charles', 'age': 40})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Demonstrates calling a model created in python\n", + "\n", + "from 
weave.trace_server.trace_server_interface import CallMethodReq\n", + "\n", + "call_res = client.server.call_method(\n", + " CallMethodReq(\n", + " project_id=client._project_id(),\n", + " object_ref=model.ref.uri(),\n", + " method_name=\"predict\",\n", + " args={\"user_input\": \"Hello, my name is Charles and I am 40 years old.\"},\n", + " )\n", + ")\n", + "\n", + "call_res" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ObjCreateRes(digest='k85wXnWLVxpHujpohAqNBIirXZSM6XRSOSk84n1XR84')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Demonstrates creating a model in the UI - notice the digest match\n", + "\n", + "from weave.trace_server.trace_server_interface import ObjCreateReq\n", + "\n", + "obj_res = client.server.obj_create(\n", + " ObjCreateReq.model_validate(\n", + " {\n", + " \"obj\": {\n", + " \"project_id\": client._project_id(),\n", + " \"object_id\": \"LiteLLMCompletionModel\",\n", + " \"val\": {\n", + " \"model\": \"gpt-4o\",\n", + " \"messages_template\": [\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"Please extract the name and age from the following text!\",\n", + " },\n", + " {\"role\": \"user\", \"content\": \"{user_input}\"},\n", + " ],\n", + " \"response_format\": {\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": {\n", + " \"name\": \"Person\",\n", + " \"schema\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"age\": {\"type\": \"integer\"},\n", + " \"name\": {\"type\": \"string\"},\n", + " },\n", + " },\n", + " },\n", + " },\n", + " },\n", + " \"set_leaf_object_class\": \"LiteLLMCompletionModel\",\n", + " }\n", + " }\n", + " )\n", + ")\n", + "\n", + "obj_res" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🍩 
https://wandb.ai/timssweeney/remote_model_demo_4/r/call/0193ba50-1be0-70d2-84cd-dcf8fac3ff09\n" + ] + }, + { + "data": { + "text/plain": [ + "{'name': 'Fred', 'age': 30}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Demonstrates fetching a model in python that was created in the UI\n", + "\n", + "from weave.trace.refs import ObjectRef\n", + "\n", + "gotten_model = weave.ref(\n", + " ObjectRef(\n", + " entity=client.entity,\n", + " project=client.project,\n", + " name=\"LiteLLMCompletionModel\",\n", + " _digest=obj_res.digest,\n", + " ).uri()\n", + ").get()\n", + "\n", + "res = gotten_model.predict(user_input=\"Hello, my name is Fred and I am 30 years old.\")\n", + "res" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LiteLLMCompletionModel(name=None, description=None, model='gpt-4o', messages_template=WeaveList([{'role': 'system', 'content': 'Please extract the name and age from the following text!'}, {'role': 'user', 'content': '{user_input}'}]), response_format=WeaveDict({'type': 'json_schema', 'json_schema': {'name': 'Person', 'schema': {'type': 'object', 'properties': {'age': {'type': 'integer'}, 'name': {'type': 'string'}}}}}))" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gotten_model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Part 2: Scoring:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ScoreCallRes(feedback_id='0193ba55-d466-74a3-a4de-da0a456b08a7', score_call=CallSchema(id='0193ba55-cb43-7c61-a712-f7249e6dfe4f', project_id='UHJvamVjdEludGVybmFsSWQ6NDA1NzYyOTQ=', op_name='weave:///timssweeney/remote_model_demo_4/op/LLMJudgeScorer.score:LSxb3VBdL8YmPr9vqYhxsMe74D8C04dJL1IKQ61Ke7M', 
display_name=None, trace_id='0193ba55-cb43-7c61-a712-f71512a66d3b', parent_id=None, started_at=datetime.datetime(2024, 12, 12, 10, 6, 45, 59887, tzinfo=TzInfo(UTC)), attributes={'weave': {'client_version': '0.51.25-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}, inputs={'self': 'weave:///timssweeney/remote_model_demo_4/object/LLMJudgeScorer:uCL086uULzE1HKLFn8YIezCG98HiqayaAp3d1R9ktA0', 'inputs': {'kwargs': {'user_input': 'Hello, my name is Charles and I am 40 years old.'}}, 'output': {'name': 'Charles', 'age': 40}}, ended_at=datetime.datetime(2024, 12, 12, 10, 6, 47, 368348, tzinfo=TzInfo(UTC)), exception=None, output={'is_correct': True}, summary={'usage': {'gpt-4o-2024-08-06': {'prompt_tokens': 91, 'completion_tokens': 6, 'requests': 1, 'total_tokens': 97, 'completion_tokens_details': {'audio_tokens': 0, 'reasoning_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}}, 'weave': {'status': , 'trace_name': 'LLMJudgeScorer.score', 'latency_ms': 2308}}, wb_user_id='VXNlcjo2Mzg4Nw==', wb_run_id=None, deleted_at=None))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import weave\n", + "from weave.trace.refs import CallRef\n", + "from weave.trace_server import trace_server_interface as tsi\n", + "\n", + "obj_create_res = client.server.obj_create(\n", + " tsi.ObjCreateReq.model_validate(\n", + " {\n", + " \"obj\": {\n", + " \"project_id\": client._project_id(),\n", + " \"object_id\": \"CorrectnessJudge\",\n", + " \"val\": {\n", + " \"model\": \"gpt-4o\",\n", + " \"system_prompt\": \"You are a judge that scores the correctness of a response.\",\n", + " \"response_format\": {\n", + " 
\"type\": \"json_schema\",\n", + " \"json_schema\": {\n", + " \"name\": \"Correctness\",\n", + " \"schema\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"is_correct\": {\"type\": \"boolean\"},\n", + " },\n", + " },\n", + " },\n", + " },\n", + " },\n", + " \"set_leaf_object_class\": \"LLMJudgeScorer\",\n", + " }\n", + " }\n", + " )\n", + ")\n", + "client._flush()\n", + "scorer_ref = weave.ObjectRef(\n", + " entity=client._project_id().split(\"/\")[0],\n", + " project=client._project_id().split(\"/\")[1],\n", + " name=\"CorrectnessJudge\",\n", + " _digest=obj_create_res.digest,\n", + ")\n", + "\n", + "call_ref = CallRef(\n", + " entity=client._project_id().split(\"/\")[0],\n", + " project=client._project_id().split(\"/\")[1],\n", + " id=call_res.call_id,\n", + ")\n", + "\n", + "score_res = client.server.score_call(\n", + " tsi.ScoreCallReq.model_validate(\n", + " {\n", + " \"project_id\": client._project_id(),\n", + " \"call_ref\": call_ref.uri(),\n", + " \"scorer_ref\": scorer_ref.uri(),\n", + " }\n", + " )\n", + ")\n", + "\n", + "score_res" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "wandb-weave", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/trace/builtin_objects/test_builtin_model.py b/tests/trace/builtin_objects/test_builtin_model.py new file mode 100644 index 00000000000..fa28ae0755f --- /dev/null +++ b/tests/trace/builtin_objects/test_builtin_model.py @@ -0,0 +1,149 @@ +import weave +from weave.builtin_objects.models.CompletionModel import LiteLLMCompletionModel +from weave.trace.refs import ObjectRef +from weave.trace.weave_client import WeaveClient +from weave.trace_server import 
trace_server_interface as tsi + +model_args = { + "model": "gpt-4o", + "messages_template": [{"role": "user", "content": "{input}"}], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Person", + "schema": { + "type": "object", + "properties": { + "age": {"type": "integer"}, + "name": {"type": "string"}, + }, + }, + }, + }, +} + +input_text = "My name is Carlos and I am 42 years old." + +expected_result = {"age": 42, "name": "Carlos"} + + +def test_model_publishing_alignment(client: WeaveClient): + model = LiteLLMCompletionModel(**model_args) + publish_ref = weave.publish(model) + + obj_create_res = client.server.obj_create( + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": "LiteLLMCompletionModel", + "val": model_args, + "set_leaf_object_class": "LiteLLMCompletionModel", + } + } + ) + ) + + assert obj_create_res.digest == publish_ref.digest + + gotten_model = weave.ref(publish_ref.uri()).get() + assert isinstance(gotten_model, LiteLLMCompletionModel) + + +def test_model_local_create_local_use(client: WeaveClient): + model = LiteLLMCompletionModel(**model_args) + predict_result = model.predict(input=input_text) + assert predict_result == expected_result + + +def test_model_local_create_remote_use(client: WeaveClient): + model = LiteLLMCompletionModel(**model_args) + publish_ref = weave.publish(model) + remote_call_res = client.server.call_method( + tsi.CallMethodReq.model_validate( + { + "project_id": client._project_id(), + "object_ref": publish_ref.uri(), + "method_name": "predict", + "args": {"input": input_text}, + } + ) + ) + assert remote_call_res.output == expected_result + + remote_call_read = client.server.call_read( + tsi.CallReadReq.model_validate( + { + "project_id": client._project_id(), + "id": remote_call_res.call_id, + } + ) + ) + assert remote_call_read.call.output == expected_result + + +def test_model_remote_create_local_use(client: WeaveClient): + obj_create_res = 
client.server.obj_create( + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": "LiteLLMCompletionModel", + "val": model_args, + "set_leaf_object_class": "LiteLLMCompletionModel", + } + } + ) + ) + obj_ref = ObjectRef( + entity=client._project_id().split("/")[0], + project=client._project_id().split("/")[1], + name="LiteLLMCompletionModel", + _digest=obj_create_res.digest, + ) + fetched = obj_ref.get() + assert isinstance(fetched, LiteLLMCompletionModel) + predict_res = fetched.predict(input=input_text) + assert predict_res == expected_result + + +def test_model_remote_create_remote_use(client: WeaveClient): + obj_create_res = client.server.obj_create( + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": "LiteLLMCompletionModel", + "val": model_args, + "set_leaf_object_class": "LiteLLMCompletionModel", + } + } + ) + ) + obj_ref = ObjectRef( + entity=client._project_id().split("/")[0], + project=client._project_id().split("/")[1], + name="LiteLLMCompletionModel", + _digest=obj_create_res.digest, + ) + obj_call_res = client.server.call_method( + tsi.CallMethodReq.model_validate( + { + "project_id": client._project_id(), + "object_ref": obj_ref.uri(), + "method_name": "predict", + "args": {"input": input_text}, + } + ) + ) + assert obj_call_res.output == expected_result + + remote_call_read = client.server.call_read( + tsi.CallReadReq.model_validate( + { + "project_id": client._project_id(), + "id": obj_call_res.call_id, + } + ) + ) + assert remote_call_read.call.output == expected_result diff --git a/tests/trace/builtin_objects/test_builtin_scorer.py b/tests/trace/builtin_objects/test_builtin_scorer.py new file mode 100644 index 00000000000..d41bfc3fbd8 --- /dev/null +++ b/tests/trace/builtin_objects/test_builtin_scorer.py @@ -0,0 +1,152 @@ +# Tests: +# 1. Publishing alignment & class alignment +# 2. Local Create, Local Direct Score +# 3. 
Local Create, Remote Direct Score +# 4. Remote Create, Local Direct Score +# 5. Remote Create, Remote Direct Score +from __future__ import annotations + +import weave +from weave.builtin_objects.scorers.LLMJudgeScorer import LLMJudgeScorer +from weave.trace.weave_client import ApplyScorerResult, Call, WeaveClient +from weave.trace_server import trace_server_interface as tsi + +scorer_args = { + "model": "gpt-4o", + "system_prompt": "You are a judge that scores the correctness of a response.", + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Correctness", + "schema": { + "type": "object", + "properties": { + "is_correct": {"type": "boolean"}, + }, + }, + }, + }, +} + +score_input = {"inputs": {"question": "What color is the sky?"}, "output": "blue"} + +expected_score = {"is_correct": True} + + +def test_scorer_publishing_alignment(client: WeaveClient): + model = LLMJudgeScorer(**scorer_args) + publish_ref = weave.publish(model, name="CorrectnessJudge") + + obj_create_res = client.server.obj_create( + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": "CorrectnessJudge", + "val": scorer_args, + "set_leaf_object_class": "LLMJudgeScorer", + } + } + ) + ) + + assert obj_create_res.digest == publish_ref.digest + + gotten_model = weave.ref(publish_ref.uri()).get() + assert isinstance(gotten_model, LLMJudgeScorer) + + +def make_simple_call(): + @weave.op + def simple_op(question: str) -> str: + return "blue" + + res, call = simple_op.call("What color is the sky?") + return res, call + + +def assert_expected_outcome( + target_call: Call, scorer_res: ApplyScorerResult | tsi.ScoreCallRes +): + scorer_output = None + feedback_id = None + if isinstance(scorer_res, tsi.ScoreCallRes): + scorer_output = scorer_res.score_call.output + feedback_id = scorer_res.feedback_id + else: + scorer_output = scorer_res["score_call"].output + feedback_id = scorer_res["feedback_id"] + + assert scorer_output == 
expected_score + feedbacks = list(target_call.feedback) + assert len(feedbacks) == 1 + assert feedbacks[0].payload["output"] == expected_score + assert feedbacks[0].id == feedback_id + + +def do_remote_score( + client: WeaveClient, target_call: Call, scorer_ref: weave.ObjectRef +): + return client.server.score_call( + tsi.ScoreCallReq.model_validate( + { + "project_id": client._project_id(), + "call_ref": target_call.ref.uri(), + "scorer_ref": scorer_ref.uri(), + } + ) + ) + + +def make_remote_scorer(client: WeaveClient): + obj_create_res = client.server.obj_create( + tsi.ObjCreateReq.model_validate( + { + "obj": { + "project_id": client._project_id(), + "object_id": "CorrectnessJudge", + "val": scorer_args, + "set_leaf_object_class": "LLMJudgeScorer", + } + } + ) + ) + client._flush() + obj_ref = weave.ObjectRef( + entity=client._project_id().split("/")[0], + project=client._project_id().split("/")[1], + name="CorrectnessJudge", + _digest=obj_create_res.digest, + ) + return obj_ref + + +def test_scorer_local_create_local_use(client: WeaveClient): + scorer = LLMJudgeScorer(**scorer_args) + res, call = make_simple_call() + apply_scorer_res = call._apply_scorer(scorer) + assert_expected_outcome(call, apply_scorer_res) + + +def test_scorer_local_create_remote_use(client: WeaveClient): + scorer = LLMJudgeScorer(**scorer_args) + res, call = make_simple_call() + publish_ref = weave.publish(scorer) + remote_score_res = do_remote_score(client, call, publish_ref) + assert_expected_outcome(call, remote_score_res) + + +def test_scorer_remote_create_local_use(client: WeaveClient): + obj_ref = make_remote_scorer(client) + fetched = weave.ref(obj_ref.uri()).get() + assert isinstance(fetched, LLMJudgeScorer) + res, call = make_simple_call() + apply_scorer_res = call._apply_scorer(fetched) + assert_expected_outcome(call, apply_scorer_res) + + +def test_scorer_remote_create_remote_use(client: WeaveClient): + obj_ref = make_remote_scorer(client) + res, call = make_simple_call() + 
remote_score_res = do_remote_score(client, call, obj_ref) + assert_expected_outcome(call, remote_score_res) diff --git a/tests/trace/test_base_object_classes.py b/tests/trace/test_base_object_classes.py index a264941f7b0..98b5dd8458e 100644 --- a/tests/trace/test_base_object_classes.py +++ b/tests/trace/test_base_object_classes.py @@ -139,7 +139,7 @@ def test_interface_creation(client): "project_id": client._project_id(), "object_id": nested_obj_id, "val": nested_obj.model_dump(), - "set_base_object_class": "TestOnlyNestedBaseObject", + "set_leaf_object_class": "TestOnlyNestedBaseObject", } } ) @@ -164,7 +164,7 @@ def test_interface_creation(client): "project_id": client._project_id(), "object_id": top_level_obj_id, "val": top_obj.model_dump(), - "set_base_object_class": "TestOnlyExample", + "set_leaf_object_class": "TestOnlyExample", } } ) @@ -271,7 +271,7 @@ def test_digest_equality(client): "project_id": client._project_id(), "object_id": nested_obj_id, "val": nested_obj.model_dump(), - "set_base_object_class": "TestOnlyNestedBaseObject", + "set_leaf_object_class": "TestOnlyNestedBaseObject", } } ) @@ -300,7 +300,7 @@ def test_digest_equality(client): "project_id": client._project_id(), "object_id": top_level_obj_id, "val": top_obj.model_dump(), - "set_base_object_class": "TestOnlyExample", + "set_leaf_object_class": "TestOnlyExample", } } ) @@ -322,7 +322,7 @@ def test_schema_validation(client): "object_id": "nested_obj", # Incorrect schema, should raise! 
"val": {"a": 2}, - "set_base_object_class": "TestOnlyNestedBaseObject", + "set_leaf_object_class": "TestOnlyNestedBaseObject", } } ) @@ -340,7 +340,7 @@ def test_schema_validation(client): "_class_name": "TestOnlyNestedBaseObject", "_bases": ["BaseObject", "BaseModel"], }, - "set_base_object_class": "TestOnlyNestedBaseObject", + "set_leaf_object_class": "TestOnlyNestedBaseObject", } } ) @@ -359,7 +359,7 @@ def test_schema_validation(client): "_class_name": "TestOnlyNestedBaseObject", "_bases": ["BaseObject", "BaseModel"], }, - "set_base_object_class": "TestOnlyExample", + "set_leaf_object_class": "TestOnlyExample", } } ) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index a3315266f65..058475ebef3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -70,7 +70,7 @@ export const CallPage: FC<{ }; export const useShowRunnableUI = () => { - return false; + return true; // Uncomment to re-enable // const viewerInfo = useViewerInfo(); // return viewerInfo.loading ? 
false : viewerInfo.userInfo?.admin; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx index 0b4e9374a26..e27f8d06117 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx @@ -13,7 +13,7 @@ import {NotApplicable} from '../../../Browse2/NotApplicable'; import {SmallRef} from '../../../Browse2/SmallRef'; import {StyledDataGrid} from '../../StyledDataGrid'; // Import the StyledDataGrid component import { - TraceObjSchemaForBaseObjectClass, + TraceObjSchemaForObjectClass, useBaseObjectInstances, } from '../wfReactInterface/baseObjectClassQuery'; import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants'; @@ -61,7 +61,7 @@ const useRunnableFeedbacksForCall = (call: CallSchema) => { const useRunnableFeedbackTypeToLatestActionRef = ( call: CallSchema, - actionSpecs: Array> + actionSpecs: Array> ): Record => { return useMemo(() => { return _.fromPairs( @@ -92,7 +92,7 @@ type GroupedRowType = { }; const useTableRowsForRunnableFeedbacks = ( - actionSpecs: Array>, + actionSpecs: Array>, runnableFeedbacks: Feedback[], runnableFeedbackTypeToLatestActionRef: Record ): GroupedRowType[] => { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardListingPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardListingPage.tsx index 52d6795806c..4d3d37caa57 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardListingPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardListingPage.tsx @@ -12,7 +12,7 @@ import {SimplePageLayout} from '../common/SimplePageLayout'; import {ObjectVersionsTable} from 
'../ObjectVersionsPage'; import { useBaseObjectInstances, - useCreateBaseObjectInstance, + useCreateLeafObjectInstance, } from '../wfReactInterface/baseObjectClassQuery'; import {sanitizeObjectId} from '../wfReactInterface/traceServerDirectClient'; import { @@ -162,7 +162,7 @@ const generateLeaderboardId = () => { }; const useCreateLeaderboard = (entity: string, project: string) => { - const createLeaderboardInstance = useCreateBaseObjectInstance('Leaderboard'); + const createLeaderboardInstance = useCreateLeafObjectInstance('Leaderboard'); const createLeaderboard = async () => { const objectId = sanitizeObjectId(generateLeaderboardId()); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardPage.tsx index 6fac8eaa599..fa1881462a0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardPage.tsx @@ -26,7 +26,7 @@ import {LeaderboardObjectVal} from '../../views/Leaderboard/types/leaderboardCon import {SimplePageLayout} from '../common/SimplePageLayout'; import { useBaseObjectInstances, - useCreateBaseObjectInstance, + useCreateLeafObjectInstance, } from '../wfReactInterface/baseObjectClassQuery'; import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; import {LeaderboardConfigEditor} from './LeaderboardConfigEditor'; @@ -131,7 +131,7 @@ const useUpdateLeaderboard = ( project: string, objectId: string ) => { - const createLeaderboard = useCreateBaseObjectInstance('Leaderboard'); + const createLeaderboard = useCreateLeafObjectInstance('Leaderboard'); const updateLeaderboard = async (leaderboardVal: LeaderboardObjectVal) => { return await createLeaderboard({ diff --git 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx index a478437facb..3f84196c6c7 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx @@ -2,7 +2,7 @@ import {Box} from '@material-ui/core'; import React, {FC, useCallback, useEffect, useState} from 'react'; import {z} from 'zod'; -import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; +import {createLeafObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; import {TraceServerClient} from '../wfReactInterface/traceServerClient'; import {sanitizeObjectId} from '../wfReactInterface/traceServerDirectClient'; import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; @@ -86,7 +86,7 @@ export const onAnnotationScorerSave = async ( ) => { const jsonSchemaType = convertTypeToJsonSchemaType(data.Type.type); const typeExtras = convertTypeExtrasToJsonSchema(data); - return createBaseObjectInstance(client, 'AnnotationSpec', { + return createLeafObjectInstance(client, 'AnnotationSpec', { obj: { project_id: projectIdFromParts({entity, project}), object_id: sanitizeObjectId(data.Name), diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/LLMJudgeScorerForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/LLMJudgeScorerForm.tsx index 64823e9d551..31195a5cb10 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/LLMJudgeScorerForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/LLMJudgeScorerForm.tsx @@ -11,7 +11,7 @@ import React, {FC, useCallback, useState} from 'react'; import {z} from 'zod'; import {LlmJudgeActionSpecSchema} 
from '../wfReactInterface/baseObjectClasses.zod'; -import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; +import {createLeafObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; import {ActionSpecSchema} from '../wfReactInterface/generatedBaseObjectClasses.zod'; import {TraceServerClient} from '../wfReactInterface/traceServerClient'; import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks'; @@ -185,7 +185,7 @@ export const onLLMJudgeScorerSave = async ( config: judgeAction, }); - return createBaseObjectInstance(client, 'ActionSpec', { + return createLeafObjectInstance(client, 'ActionSpec', { obj: { project_id: projectIdFromParts({entity, project}), object_id: objectId, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.test.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.test.ts index 9918ae7f285..f61f56afa7d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.test.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.test.ts @@ -2,7 +2,7 @@ import {expectType} from 'tsd'; import { useBaseObjectInstances, - useCreateBaseObjectInstance, + useCreateLeafObjectInstance, } from './baseObjectClassQuery'; import { TestOnlyExample, @@ -74,7 +74,7 @@ describe('Type Tests', () => { it('useCreateCollectionObject return type matches expected structure', () => { type CreateCollectionObjectReturn = ReturnType< - typeof useCreateBaseObjectInstance<'TestOnlyExample'> + typeof useCreateLeafObjectInstance<'TestOnlyExample'> >; // Define the expected type structure diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts index 
6ceb39daa70..9125217f614 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts @@ -13,22 +13,22 @@ import { } from './traceServerClientTypes'; import {Loadable} from './wfDataModelHooksInterface'; -type BaseObjectClassRegistry = typeof baseObjectClassRegistry; -type BaseObjectClassRegistryKeys = keyof BaseObjectClassRegistry; -type BaseObjectClassType = z.infer< - BaseObjectClassRegistry[C] +type ObjectClassRegistry = typeof baseObjectClassRegistry; // TODO: Add more here - not just bases! +type ObjectClassRegistryKeys = keyof ObjectClassRegistry; +type ObjectClassType = z.infer< + ObjectClassRegistry[C] >; -export type TraceObjSchemaForBaseObjectClass< - C extends BaseObjectClassRegistryKeys -> = TraceObjSchema, C>; +export type TraceObjSchemaForObjectClass< + C extends ObjectClassRegistryKeys +> = TraceObjSchema, C>; -export const useBaseObjectInstances = ( +export const useBaseObjectInstances = ( baseObjectClassName: C, req: TraceObjQueryReq -): Loadable>> => { +): Loadable>> => { const [objects, setObjects] = useState< - Array> + Array> >([]); const getTsClient = useGetTraceServerClientContext(); const client = getTsClient(); @@ -56,11 +56,11 @@ export const useBaseObjectInstances = ( return {result: objects, loading}; }; -const getBaseObjectInstances = async ( +const getBaseObjectInstances = async ( client: TraceServerClient, baseObjectClassName: C, req: TraceObjQueryReq -): Promise, C>>> => { +): Promise, C>>> => { const knownObjectClass = baseObjectClassRegistry[baseObjectClassName]; if (!knownObjectClass) { console.warn(`Unknown object class: ${baseObjectClassName}`); @@ -86,47 +86,47 @@ const getBaseObjectInstances = async ( .map( ({obj, parsed}) => ({...obj, val: parsed.data} as TraceObjSchema< - BaseObjectClassType, + ObjectClassType, C >) ); }; -export const 
useCreateBaseObjectInstance = < - C extends BaseObjectClassRegistryKeys, - T = BaseObjectClassType +export const useCreateLeafObjectInstance = < + C extends ObjectClassRegistryKeys, + T = ObjectClassType >( - baseObjectClassName: C + leafObjectClassName: C ): ((req: TraceObjCreateReq) => Promise) => { const getTsClient = useGetTraceServerClientContext(); const client = getTsClient(); return (req: TraceObjCreateReq) => - createBaseObjectInstance(client, baseObjectClassName, req); + createLeafObjectInstance(client, leafObjectClassName, req); }; -export const createBaseObjectInstance = async < - C extends BaseObjectClassRegistryKeys, - T = BaseObjectClassType +export const createLeafObjectInstance = async < + C extends ObjectClassRegistryKeys, + T = ObjectClassType >( client: TraceServerClient, - baseObjectClassName: C, + leafObjectClassName: C, req: TraceObjCreateReq ): Promise => { if ( - req.obj.set_base_object_class != null && - req.obj.set_base_object_class !== baseObjectClassName + req.obj.set_leaf_object_class != null && + req.obj.set_leaf_object_class !== leafObjectClassName ) { throw new Error( - `set_base_object_class must match baseObjectClassName: ${baseObjectClassName}` + `set_leaf_object_class must match leafObjectClassName: ${leafObjectClassName}` ); } - const knownBaseObjectClass = baseObjectClassRegistry[baseObjectClassName]; - if (!knownBaseObjectClass) { - throw new Error(`Unknown object class: ${baseObjectClassName}`); + const knownObjectClass = baseObjectClassRegistry[leafObjectClassName]; + if (!knownObjectClass) { + throw new Error(`Unknown object class: ${leafObjectClassName}`); } - const verifiedObject = knownBaseObjectClass.safeParse(req.obj.val); + const verifiedObject = knownObjectClass.safeParse(req.obj.val); if (!verifiedObject.success) { throw new Error( @@ -134,13 +134,13 @@ export const createBaseObjectInstance = async < ); } - const reqWithBaseObjectClass: TraceObjCreateReq = { + const reqWithLeafObjectClass: TraceObjCreateReq = { 
...req, obj: { ...req.obj, - set_base_object_class: baseObjectClassName, + set_leaf_object_class: leafObjectClassName, }, }; - return client.objCreate(reqWithBaseObjectClass); + return client.objCreate(reqWithLeafObjectClass); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index c396962f0fb..520926532c1 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -243,7 +243,7 @@ export type TraceObjCreateReq = { project_id: string; object_id: string; val: T; - set_base_object_class?: string; + set_leaf_object_class?: string; }; }; diff --git a/weave/builtin_objects/builtin_registry.py b/weave/builtin_objects/builtin_registry.py new file mode 100644 index 00000000000..41f7fb82c91 --- /dev/null +++ b/weave/builtin_objects/builtin_registry.py @@ -0,0 +1,23 @@ +import weave +from weave.builtin_objects.models.CompletionModel import LiteLLMCompletionModel +from weave.builtin_objects.scorers.LLMJudgeScorer import LLMJudgeScorer + +_BUILTIN_REGISTRY: dict[str, type[weave.Object]] = {} + + +def register_builtin(cls: type[weave.Object]) -> None: + if not issubclass(cls, weave.Object): + raise TypeError(f"Object {cls} is not a subclass of weave.Object") + + if cls.__name__ in _BUILTIN_REGISTRY: + raise ValueError(f"Object {cls} already registered") + + _BUILTIN_REGISTRY[cls.__name__] = cls + + +def get_builtin(name: str) -> type[weave.Object]: + return _BUILTIN_REGISTRY[name] + + +register_builtin(LiteLLMCompletionModel) +register_builtin(LLMJudgeScorer) diff --git a/weave/builtin_objects/models/CompletionModel.py b/weave/builtin_objects/models/CompletionModel.py new file mode 100644 index 00000000000..808a764ca69 --- 
/dev/null +++ b/weave/builtin_objects/models/CompletionModel.py @@ -0,0 +1,27 @@ +import json +from typing import Any, Optional + +import litellm + +import weave + + +class LiteLLMCompletionModel(weave.Model): + model: str + messages_template: list[dict[str, str]] + response_format: Optional[dict] = None + + @weave.op() + def predict(self, **kwargs: Any) -> str: + messages: list[dict] = [ + {**m, "content": m["content"].format(**kwargs)} + for m in self.messages_template + ] + + res = litellm.completion( + model=self.model, + messages=messages, + response_format=self.response_format, + ) + + return json.loads(res.choices[0].message.content) diff --git a/weave/builtin_objects/scorers/LLMJudgeScorer.py b/weave/builtin_objects/scorers/LLMJudgeScorer.py new file mode 100644 index 00000000000..480ef7a47f7 --- /dev/null +++ b/weave/builtin_objects/scorers/LLMJudgeScorer.py @@ -0,0 +1,37 @@ +import json +from typing import Any + +import litellm + +import weave + +# TODO: Questions +# Should "reasoning" be built into the system itself? 
+ + +class LLMJudgeScorer(weave.Scorer): + model: str + system_prompt: str + response_format: dict + + @weave.op() + def score(self, inputs: dict, output: Any) -> str: + user_prompt = json.dumps( + { + "inputs": inputs, + "output": output, + } + ) + + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + res = litellm.completion( + model=self.model, + messages=messages, + response_format=self.response_format, + ) + + return json.loads(res.choices[0].message.content) diff --git a/weave/trace/serialize.py b/weave/trace/serialize.py index 63210f9abee..9b0b6923464 100644 --- a/weave/trace/serialize.py +++ b/weave/trace/serialize.py @@ -245,6 +245,8 @@ def _load_custom_obj_files( def from_json(obj: Any, project_id: str, server: TraceServerInterface) -> Any: + from weave.builtin_objects.builtin_registry import get_builtin + if isinstance(obj, list): return [from_json(v, project_id, server) for v in obj] elif isinstance(obj, dict): @@ -265,6 +267,15 @@ def from_json(obj: Any, project_id: str, server: TraceServerInterface) -> Any: and (baseObject := BASE_OBJECT_REGISTRY.get(val_type)) ): return baseObject.model_validate(obj) + elif ( + isinstance(val_type, str) + and obj.get("_class_name") == val_type + and (baseObject := get_builtin(val_type)) + ): + valid_keys = baseObject.model_fields.keys() + return baseObject.model_validate( + {k: v for k, v in obj.items() if k in valid_keys} + ) else: return ObjectRecord( {k: from_json(v, project_id, server) for k, v in obj.items()} diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 1d5d54b9b23..d933d29dbc1 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -10,7 +10,17 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + 
Callable, + Generic, + Protocol, + TypedDict, + TypeVar, + cast, + overload, +) import pydantic from requests import HTTPError @@ -83,6 +93,9 @@ ) from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer +if TYPE_CHECKING: + from weave.scorers.base_scorer import Scorer + # Controls if objects can have refs to projects not the WeaveClient project. # If False, object refs with with mismatching projects will be recreated. # If True, use existing ref to object in other project. @@ -302,6 +315,11 @@ def map_to_refs(obj: Any) -> Any: return obj +class ApplyScorerResult(TypedDict): + feedback_id: str + score_call: Call + + @dataclasses.dataclass class Call: """A Call represents a single operation that was executed as part of a trace.""" @@ -445,7 +463,7 @@ def set_display_name(self, name: str | None) -> None: def remove_display_name(self) -> None: self.set_display_name(None) - def _apply_scorer(self, scorer_op: Op) -> None: + def _apply_scorer(self, scorer_op: Op | Scorer) -> ApplyScorerResult: """ This is a private method that applies a scorer to a call and records the feedback. In the near future, this will be made public, but for now it is only used internally @@ -455,18 +473,32 @@ def _apply_scorer(self, scorer_op: Op) -> None: inside `eval.py` uses this method inside the scorer block. Current limitations: - - only works for ops (not Scorer class) - no async support - no context yet (ie. 
ground truth) """ + from weave.scorers.base_scorer import Scorer + + self_arg = None + orig_scorer_op = scorer_op + if isinstance(scorer_op, Scorer): + self_arg = scorer_op + scorer_op = scorer_op.score + client = weave_client_context.require_weave_client() scorer_signature = inspect.signature(scorer_op) scorer_arg_names = list(scorer_signature.parameters.keys()) - score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names} + if "inputs" in scorer_arg_names: + score_args = { + "inputs": {k: v for k, v in self.inputs.items() if k != "self"} + } + else: + score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names} + if self_arg is not None: + score_args["self"] = self_arg if "output" in scorer_arg_names: score_args["output"] = self.output _, score_call = scorer_op.call(**score_args) - scorer_op_ref = get_ref(scorer_op) + scorer_op_ref = get_ref(orig_scorer_op) if scorer_op_ref is None: raise ValueError("Scorer op has no ref") self_ref = get_ref(self) @@ -476,13 +508,19 @@ def _apply_scorer(self, scorer_op: Op) -> None: score_call_ref = get_ref(score_call) if score_call_ref is None: raise ValueError("Score call has no ref") - client._add_runnable_feedback( + feedback_id = client._add_runnable_feedback( weave_ref_uri=self_ref.uri(), output=score_results, call_ref_uri=score_call_ref.uri(), runnable_ref_uri=scorer_op_ref.uri(), ) + # TODO: Make this a class + return ApplyScorerResult( + feedback_id=feedback_id, + score_call=score_call, + ) + def make_client_call( entity: str, project: str, server_call: CallSchema, server: TraceServerInterface @@ -1522,6 +1560,8 @@ def _ref_is_own(self, ref: Ref) -> bool: return isinstance(ref, Ref) def _project_id(self) -> str: + # if self.entity == "": + # return self.project return f"{self.entity}/{self.project}" @trace_sentry.global_trace_sentry.watch() diff --git a/weave/trace_server/base_object_class_util.py b/weave/trace_server/base_object_class_util.py index 1c52f766c0c..55595583093 100644 --- 
a/weave/trace_server/base_object_class_util.py +++ b/weave/trace_server/base_object_class_util.py @@ -29,65 +29,76 @@ def get_base_object_class(val: Any) -> Optional[str]: return None +def get_leaf_object_class(val: Any) -> Optional[str]: + if isinstance(val, dict): + if "_bases" in val: + if isinstance(val["_bases"], list): + if len(val["_bases"]) >= 2: + if val["_bases"][-1] == "BaseModel": + if val["_bases"][-2] in base_object_class_names: + if "_class_name" in val: + return val["_class_name"] + return None + + def process_incoming_object_val( - val: Any, req_base_object_class: Optional[str] = None + val: Any, req_leaf_object_class: Optional[str] = None ) -> tuple[dict, Optional[str]]: """ This method is responsible for accepting an incoming object from the user, validating it - against the base object class, and returning the object with the base object class + against the leaf object class, and returning the object with the base object class set. It does not mutate the original object, but returns a new object with values set if needed. Specifically,: 1. If the object is not a dict, it is returned as is, and the base object class is set to None. - 2. There are 2 ways to specify the base object class: - a. The `req_base_object_class` argument. + 2. There are 2 ways to specify the leaf object class: + a. The `req_leaf_object_class` argument. * used by non-pythonic writers of weave objects b. The `_bases` & `_class_name` attributes of the object, which is a list of base class names. * used by pythonic weave object writers (legacy) - 3. If the object has a base object class that does not match the requested base object class, + 3. If the object has a leaf object class that does not match the requested leaf object class, an error is thrown. - 4. if the object contains a base object class inside the payload, then we simply validate - the object against the base object class (if a match is found in BASE_OBJECT_REGISTRY) - 5. 
If the object does not have a base object class and a requested base object class is + 4. if the object contains a leaf object class inside the payload, then we simply validate + the object against the leaf object class (if a match is found in BASE_OBJECT_REGISTRY) + 5. If the object does not have a leaf object class and a requested leaf object class is provided, we require a match in BASE_OBJECT_REGISTRY and validate the object against - the requested base object class. Finally, we set the correct feilds. + the requested leaf object class. Finally, we set the correct fields. """ if not isinstance(val, dict): - if req_base_object_class is not None: + if req_leaf_object_class is not None: raise ValueError( - "set_base_object_class cannot be provided for non-dict objects" + "set_leaf_object_class cannot be provided for non-dict objects" ) return val, None dict_val = val.copy() - val_base_object_class = get_base_object_class(dict_val) + val_leaf_object_class = get_leaf_object_class(dict_val) if ( - val_base_object_class != None - and req_base_object_class != None - and val_base_object_class != req_base_object_class + val_leaf_object_class != None + and req_leaf_object_class != None + and val_leaf_object_class != req_leaf_object_class ): raise ValueError( - f"set_base_object_class must match base_object_class: {val_base_object_class} != {req_base_object_class}" + f"set_leaf_object_class must match found leaf class: {req_leaf_object_class} != {val_leaf_object_class}" ) - if val_base_object_class is not None: + if val_leaf_object_class is not None: # In this case, we simply validate if the match is found - if base_object_class_type := BASE_OBJECT_REGISTRY.get(val_base_object_class): - base_object_class_type.model_validate(dict_val) - elif req_base_object_class is not None: + if object_class_type := BASE_OBJECT_REGISTRY.get(val_leaf_object_class): + object_class_type.model_validate(dict_val) + elif req_leaf_object_class is not None: # In this case, we require that the 
base object class is registered - if base_object_class_type := BASE_OBJECT_REGISTRY.get(req_base_object_class): - dict_val = dump_base_object(base_object_class_type.model_validate(dict_val)) + if object_class_type := BASE_OBJECT_REGISTRY.get(req_leaf_object_class): + dict_val = dump_base_object(object_class_type.model_validate(dict_val)) else: - raise ValueError(f"Unknown base object class: {req_base_object_class}") - - base_object_class = val_base_object_class or req_base_object_class + raise ValueError(f"Unknown leaf object class: {req_leaf_object_class}") - return dict_val, base_object_class + return dict_val, get_base_object_class(dict_val) +# BIG TODO: Replace this with a true object serialization step using some synthetic weave client! # Server-side version of `pydantic_object_record` def dump_base_object(val: BaseModel) -> dict: cls = val.__class__ diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index d40d7bcc2a3..f7cf8e808d8 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -196,6 +196,16 @@ def __init__( self._use_async_insert = use_async_insert self._model_to_provider_info_map = read_model_to_provider_info_map() + def model_dump(self) -> dict[str, Any]: + return { + "host": self._host, + "port": self._port, + "user": self._user, + "password": self._password, + "database": self._database, + "use_async_insert": self._use_async_insert, + } + @classmethod def from_env(cls, use_async_insert: bool = False) -> "ClickHouseTraceServer": # Explicitly calling `RemoteHTTPTraceServer` constructor here to ensure @@ -563,8 +573,22 @@ def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: return tsi.OpQueryRes(op_objs=objs) def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: + from weave.builtin_objects.builtin_registry import get_builtin + + if req.obj.set_leaf_object_class is not None: + from 
weave.trace_server.server_side_object_saver import RunAsUser + + object_class_type = get_builtin(req.obj.set_leaf_object_class) + + new_obj = object_class_type.model_validate(req.obj.val) + runner = RunAsUser(ch_server_dump=self.model_dump()) + digest = runner.run_save_object( + new_obj, req.obj.project_id, req.obj.object_id, None + ) + return tsi.ObjCreateRes(digest=digest) + val, base_object_class = process_incoming_object_val( - req.obj.val, req.obj.set_base_object_class + req.obj.val, req.obj.set_leaf_object_class ) json_val = json.dumps(val) @@ -1503,6 +1527,32 @@ def completions_create( response=res.response, weave_call_id=start_call.id ) + def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: + from weave.trace_server.server_side_object_saver import RunAsUser + + if req.wb_user_id is None: + raise ValueError("User ID is required") + + runner = RunAsUser(ch_server_dump=self.model_dump()) + # TODO: handle errors here + res = runner.run_call_method( + req.object_ref, req.project_id, req.wb_user_id, req.method_name, req.args + ) + return tsi.CallMethodRes.model_validate(res) + + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + from weave.trace_server.server_side_object_saver import RunAsUser + + runner = RunAsUser(ch_server_dump=self.model_dump()) + res = runner.run_score_call(req) + + return tsi.ScoreCallRes( + feedback_id=res["feedback_id"], + score_call=self.call_read( + tsi.CallReadReq(project_id=req.project_id, id=res["scorer_call_id"]) + ).call, + ) + # Private Methods @property def ch_client(self) -> CHClient: diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 1df739adbcd..fe38a2b1d91 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -376,3 +376,19 @@ def completions_create( req.project_id = 
self._idc.ext_to_int_project_id(req.project_id) res = self._ref_apply(self._internal_trace_server.completions_create, req) return res + + def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + original_user_id = req.wb_user_id + if original_user_id is None: + raise ValueError("wb_user_id cannot be None") + req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id) + return self._ref_apply(self._internal_trace_server.call_method, req) + + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + original_user_id = req.wb_user_id + if original_user_id is None: + raise ValueError("wb_user_id cannot be None") + req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id) + return self._ref_apply(self._internal_trace_server.score_call, req) diff --git a/weave/trace_server/interface/base_object_classes/base_object_registry.py b/weave/trace_server/interface/base_object_classes/base_object_registry.py index 19eb865daea..bc1bcc6d7a7 100644 --- a/weave/trace_server/interface/base_object_classes/base_object_registry.py +++ b/weave/trace_server/interface/base_object_classes/base_object_registry.py @@ -9,6 +9,7 @@ TestOnlyNestedBaseObject, ) +# TODO: Migrate everything to normal Objects! 
BASE_OBJECT_REGISTRY: dict[str, type[BaseObject]] = {} diff --git a/weave/trace_server/migrations/002_add_deleted_at.up.sql b/weave/trace_server/migrations/002_add_deleted_at.up.sql index 77538f75495..9d16534eb27 100644 --- a/weave/trace_server/migrations/002_add_deleted_at.up.sql +++ b/weave/trace_server/migrations/002_add_deleted_at.up.sql @@ -8,48 +8,6 @@ This migration adds: ALTER TABLE object_versions ADD COLUMN deleted_at Nullable(DateTime64(3)) DEFAULT NULL; -CREATE OR REPLACE VIEW object_versions_deduped as - SELECT project_id, - object_id, - created_at, - deleted_at, -- **** Add deleted_at to the view **** - kind, - base_object_class, - refs, - val_dump, - digest, - if (kind = 'op', 1, 0) AS is_op, - row_number() OVER ( - PARTITION BY project_id, - kind, - object_id - ORDER BY created_at ASC - ) AS _version_index_plus_1, - _version_index_plus_1 - 1 AS version_index, - count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count, - if(_version_index_plus_1 = version_count, 1, 0) AS is_latest - FROM ( - SELECT *, - row_number() OVER ( - PARTITION BY project_id, - kind, - object_id, - digest - ORDER BY created_at ASC - ) AS rn - FROM object_versions - ) - WHERE rn = 1 WINDOW w AS ( - PARTITION BY project_id, - kind, - object_id - ORDER BY created_at ASC - ) - ORDER BY project_id, - kind, - object_id, - created_at; - ALTER TABLE call_parts ADD COLUMN deleted_at Nullable(DateTime64(3)) DEFAULT NULL; diff --git a/weave/trace_server/server_side_object_saver.py b/weave/trace_server/server_side_object_saver.py new file mode 100644 index 00000000000..14d85b61f27 --- /dev/null +++ b/weave/trace_server/server_side_object_saver.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +import multiprocessing +from typing import Any, Callable, TypedDict + +import weave +from weave.trace import autopatch +from weave.trace.refs import ObjectRef +from weave.trace.weave_client import WeaveClient +from weave.trace.weave_init import InitializedClient +from 
weave.trace_server import external_to_internal_trace_server_adapter +from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.refs_internal import ( + InternalCallRef, + InternalObjectRef, + parse_internal_uri, +) + + +class ScoreCallResult(TypedDict): + feedback_id: str + scorer_call_id: str + + +class RunSaveObjectException(Exception): + pass + + +class RunCallMethodException(Exception): + pass + + +class RunScoreCallException(Exception): + pass + + +class RunAsUser: + """Executes a function in a separate process for memory isolation. + + This class provides a way to run functions in an isolated memory space using + multiprocessing. The function and its arguments are executed in a new Process, + ensuring complete memory isolation from the parent process. + """ + + def __init__(self, ch_server_dump: dict[str, Any]): + self.ch_server_dump = ch_server_dump + + @staticmethod + def _process_runner( + func: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + result_queue: multiprocessing.Queue, + ) -> None: + """Execute the function and put its result in the queue. + + Args: + func: The function to execute + args: Positional arguments for the function + kwargs: Keyword arguments for the function + result_queue: Queue to store the function's result + """ + try: + result = func(*args, **kwargs) + result_queue.put(("success", result)) + except Exception as e: + result_queue.put(("error", str(e))) + + def run_save_object( + self, + new_obj: Any, + project_id: str, + object_name: str | None, + user_id: str | None, + ) -> str: + """Run the save_object operation in a separate process. 
+ + Args: + new_obj: The object to save + project_id: The project identifier + user_id: The user identifier + + Returns: + str: The digest of the saved object + + Raises: + Exception: If the save operation fails in the child process + """ + result_queue: multiprocessing.Queue[tuple[str, str]] = multiprocessing.Queue() + + process = multiprocessing.Process( + target=self._save_object, + args=( + new_obj, + project_id, + object_name, + user_id, + result_queue, + ), # Pass result_queue here + ) + + process.start() + status, result = result_queue.get() + process.join() + + if status == "error": + raise RunSaveObjectException(f"Process execution failed: {result}") + + return result + + def _save_object( + self, + new_obj: Any, + project_id: str, + object_name: str | None, + user_id: str | None, + result_queue: multiprocessing.Queue, + ) -> None: + """Save an object in a separate process. + + Args: + new_obj: The object to save + project_id: The project identifier + object_name: The name of the object + user_id: The user identifier + result_queue: Queue to store the operation's result + """ + try: + from weave.trace_server.clickhouse_trace_server_batched import ( + ClickHouseTraceServer, + ) + + client = WeaveClient( + "_SERVER_", + project_id, + UserInjectingExternalTraceServer( + ClickHouseTraceServer(**self.ch_server_dump), + id_converter=IdConverter(), + user_id=user_id, + ), + False, + ) + + ic = InitializedClient(client) + autopatch.autopatch() + + res = weave.publish(new_obj, name=object_name).digest + autopatch.reset_autopatch() + client._flush() + ic.reset() + result_queue.put(("success", res)) # Put the result in the queue + except Exception as e: + result_queue.put(("error", str(e))) # Put any errors in the queue + + def run_call_method( + self, + obj_ref: str, + project_id: str, + user_id: str, + method_name: str, + args: dict[str, Any], + ) -> str: + result_queue: multiprocessing.Queue[tuple[str, Any]] = multiprocessing.Queue() + + process = 
multiprocessing.Process( + target=self._call_method, + args=(obj_ref, project_id, user_id, method_name, args, result_queue), + ) + + process.start() + status, result = result_queue.get() + process.join() + + if status == "error": + raise RunCallMethodException(f"Process execution failed: {result}") + + return result + + def _call_method( + self, + obj_ref: str, + project_id: str, + user_id: str, + method_name: str, + args: dict[str, Any], + result_queue: multiprocessing.Queue, + ) -> None: + try: + from weave.trace_server.clickhouse_trace_server_batched import ( + ClickHouseTraceServer, + ) + + client = WeaveClient( + "_SERVER_", + project_id, + UserInjectingExternalTraceServer( + ClickHouseTraceServer(**self.ch_server_dump), + id_converter=IdConverter(), + user_id=user_id, + ), + False, + ) + + ic = InitializedClient(client) + autopatch.autopatch() + + # TODO: validate project alignment? + int_ref = parse_internal_uri(obj_ref) + assert isinstance(int_ref, InternalObjectRef) + ref = ObjectRef( + entity="_SERVER_", + project=int_ref.project_id, + name=int_ref.name, + _digest=int_ref.version, + ) + obj = client.get(ref) + method = getattr(obj, method_name) + # TODO: Self might be wrong + res, call = method.call(self=obj, **args) + autopatch.reset_autopatch() + client._flush() + ic.reset() + result_queue.put( + ("success", {"output": res, "call_id": call.id}) + ) # Put the result in the queue + except Exception as e: + result_queue.put(("error", str(e))) # Put any errors in the queue + + def run_score_call(self, req: tsi.ScoreCallReq) -> ScoreCallResult: + result_queue: multiprocessing.Queue[tuple[str, ScoreCallResult | str]] = ( + multiprocessing.Queue() + ) + + process = multiprocessing.Process( + target=self._score_call, + args=(req, result_queue), + ) + + process.start() + status, result = result_queue.get() + process.join() + + if status == "error": + raise RunScoreCallException(f"Process execution failed: {result}") + + if isinstance(result, dict): + return 
result + else: + raise RunScoreCallException(f"Unexpected result: {result}") + + def _score_call( + self, + req: tsi.ScoreCallReq, + result_queue: multiprocessing.Queue[tuple[str, ScoreCallResult | str]], + ) -> None: + try: + from weave.trace.weave_client import Call + from weave.trace_server.clickhouse_trace_server_batched import ( + ClickHouseTraceServer, + ) + + client = WeaveClient( + "_SERVER_", + req.project_id, + UserInjectingExternalTraceServer( + ClickHouseTraceServer(**self.ch_server_dump), + id_converter=IdConverter(), + user_id=req.wb_user_id, + ), + False, + ) + + ic = InitializedClient(client) + autopatch.autopatch() + + target_call_ref = parse_internal_uri(req.call_ref) + if not isinstance(target_call_ref, InternalCallRef): + raise TypeError("Invalid call reference") + target_call = client.get_call(target_call_ref.id)._val + if not isinstance(target_call, Call): + raise TypeError("Invalid call reference") + scorer_ref = parse_internal_uri(req.scorer_ref) + if not isinstance(scorer_ref, InternalObjectRef): + raise TypeError("Invalid scorer reference") + scorer = weave.ref( + ObjectRef( + entity="_SERVER_", + project=scorer_ref.project_id, + name=scorer_ref.name, + _digest=scorer_ref.version, + ).uri() + ).get() + if not isinstance(scorer, weave.Scorer): + raise TypeError("Invalid scorer reference") + apply_scorer_res = target_call._apply_scorer(scorer) + + autopatch.reset_autopatch() + client._flush() + ic.reset() + scorer_call_id = apply_scorer_res["score_call"].id + if not scorer_call_id: + raise ValueError("Scorer call ID is required") + result_queue.put( + ( + "success", + ScoreCallResult( + feedback_id=apply_scorer_res["feedback_id"], + scorer_call_id=scorer_call_id, + ), + ) + ) # Put the result in the queue + except Exception as e: + result_queue.put(("error", str(e))) # Put any errors in the queue + + +class IdConverter(external_to_internal_trace_server_adapter.IdConverter): + def ext_to_int_project_id(self, project_id: str) -> str: + assert 
project_id.startswith("_SERVER_/") + return project_id[len("_SERVER_/") :] + + def int_to_ext_project_id(self, project_id: str) -> str | None: + return "_SERVER_/" + project_id + + def ext_to_int_run_id(self, run_id: str) -> str: + return run_id + + def int_to_ext_run_id(self, run_id: str) -> str: + return run_id + + def ext_to_int_user_id(self, user_id: str) -> str: + return user_id + + def int_to_ext_user_id(self, user_id: str) -> str: + return user_id + + +class UserInjectingExternalTraceServer( + external_to_internal_trace_server_adapter.ExternalTraceServer +): + def __init__( + self, + internal_trace_server: tsi.TraceServerInterface, + id_converter: external_to_internal_trace_server_adapter.IdConverter, + user_id: str | None, + ): + super().__init__(internal_trace_server, id_converter) + self._user_id = user_id + + def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: + if self._user_id is None: + raise ValueError("User ID is required") + req.start.wb_user_id = self._user_id + return super().call_start(req) + + def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: + if self._user_id is None: + raise ValueError("User ID is required") + req.wb_user_id = self._user_id + return super().calls_delete(req) + + def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: + if self._user_id is None: + raise ValueError("User ID is required") + req.wb_user_id = self._user_id + return super().call_update(req) + + def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: + if self._user_id is None: + raise ValueError("User ID is required") + req.wb_user_id = self._user_id + return super().feedback_create(req) + + def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: + if self._user_id is None: + raise ValueError("User ID is required") + req.wb_user_id = self._user_id + return super().cost_create(req) + + def actions_execute_batch( + self, req: tsi.ActionsExecuteBatchReq + ) -> 
tsi.ActionsExecuteBatchRes: + if self._user_id is None: + raise ValueError("User ID is required") + req.wb_user_id = self._user_id + return super().actions_execute_batch(req) + + def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: + if self._user_id is None: + raise ValueError("User ID is required") + req.wb_user_id = self._user_id + return super().call_method(req) + + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + if self._user_id is None: + raise ValueError("User ID is required") + req.wb_user_id = self._user_id + return super().score_call(req) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index e14cef8041e..0cd4f7db5ca 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -612,7 +612,7 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: conn, cursor = get_conn_cursor(self.db_path) val, base_object_class = process_incoming_object_val( - req.obj.val, req.obj.set_base_object_class + req.obj.val, req.obj.set_leaf_object_class ) json_val = json.dumps(val) digest = str_digest(json_val) @@ -1221,6 +1221,12 @@ def table_query_stream( results = self.table_query(req) yield from results.rows + def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: + raise NotImplementedError("call_method is not implemented for local sqlite") + + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + raise NotImplementedError("score_call is not implemented for local sqlite") + def get_type(val: Any) -> str: if val == None: diff --git a/weave/trace_server/todo.md b/weave/trace_server/todo.md new file mode 100644 index 00000000000..77ce36d0d04 --- /dev/null +++ b/weave/trace_server/todo.md @@ -0,0 +1,35 @@ +* [x] Finish Porting over weave client +* [ ] Add end-to-end test with new lifecycle: + * [x] Create Dummy Model via API -> See in object explorer + * [x] Requires finishing the leaf object registry + * [x] 
Requires sub process construction & saving? + * [x] Invoke Dummy Model via API -> See in trace explorer + * [x] Requires execution in sub process + * New TODOs: + * API Keys are not correctly set up in the sub runner + * A bunch of naming issues all around (leaf, saver, etc...) + * Not very good error handling + * Server seems to get stuck when it crashes + * Accidentally checked in the change that disables parallelization + * Create a Dummy Scorer via API -> See in object explorer + * [x] Should be pretty straightforward at this point + * Invoke the Scorer against the previous call via API -> see in traces AND in feedback + * [x] Should be mostly straightforward (the Scorer API itself is a bit wonky) + * Important Proof of system: + * [x] create the same dummy model locally & invoke -> notice no version change + * [x] Run locally against the call -> notice that there are no extra objects + * [ ] Should de-ref inputs if they contain refs +* [ ] Refactor the entire "base model" system to conform to this new way of doing things (leaf models) + * [ ] Might get hairy with nested refs - consider implications +* [ ] Figure out how to refactor scorers that use LLMs + * [ ] a new process with correct env setup (from secret fetcher?) + * [ ] scorers should have a client-spec, not a specific client + * [ ] How to model a scorer's stub (input, output, context, reference(s), etc...) + * [ ] How to handle output types from scorers (boolean, number, reason, etc...) + * [ ] Investigate why the tests are running so slowly +* [ ] Consider a rule: objectId must be different from the name of the class when creating these objects + + +---- Decomposition PRs ---- +1. Change/Add the set_object_class instead of base_object_class +2.
Add new methods to the server \ No newline at end of file diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 77ef6198ecf..02edeb4b909 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -191,7 +191,7 @@ class ObjSchemaForInsert(BaseModel): project_id: str object_id: str val: Any - set_base_object_class: Optional[str] = None + set_leaf_object_class: Optional[str] = None class TableSchemaForInsert(BaseModel): @@ -867,6 +867,31 @@ class ActionsExecuteBatchRes(BaseModel): pass +class CallMethodReq(BaseModel): + project_id: str + object_ref: str + method_name: str + args: dict[str, Any] + wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) + + +class CallMethodRes(BaseModel): + call_id: str + output: Any + + +class ScoreCallReq(BaseModel): + project_id: str + call_ref: str + scorer_ref: str + wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) + + +class ScoreCallRes(BaseModel): + feedback_id: str + score_call: CallSchema + + class TraceServerInterface(Protocol): def ensure_project_exists( self, entity: str, project: str @@ -910,6 +935,10 @@ def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ... def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ... def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: ... + # Execute API + def call_method(self, req: CallMethodReq) -> CallMethodRes: ... + def score_call(self, req: ScoreCallReq) -> ScoreCallRes: ... 
+ # Action API def actions_execute_batch( self, req: ActionsExecuteBatchReq diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index b749af60d24..f5446bbfb71 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -569,6 +569,16 @@ def completions_create( tsi.CompletionsCreateRes, ) + def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes: + return self._generic_request( + "/execute/method", req, tsi.CallMethodReq, tsi.CallMethodRes + ) + + def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes: + return self._generic_request( + "/execute/score_call", req, tsi.ScoreCallReq, tsi.ScoreCallRes + ) + __docspec__ = [ RemoteHTTPTraceServer,