diff --git a/dev_docs/BaseObjectClasses.md b/dev_docs/BaseObjectClasses.md
index a571af49755..20a104e2e90 100644
--- a/dev_docs/BaseObjectClasses.md
+++ b/dev_docs/BaseObjectClasses.md
@@ -116,7 +116,7 @@ curl -X POST 'https://trace.wandb.ai/obj/create' \
"project_id": "user/project",
"object_id": "my_config",
"val": {...},
- "set_base_object_class": "MyConfig"
+ "set_leaf_object_class": "MyConfig"
}
}'
@@ -162,7 +162,7 @@ Run `make synchronize-base-object-schemas` to ensure the frontend TypeScript typ
4. Now, each use case uses different parts:
1. `Python Writing`. Users can directly import these classes and use them as normal Pydantic models, which get published with `weave.publish`. The python client correct builds the requisite payload.
2. `Python Reading`. Users can `weave.ref().get()` and the weave python SDK will return the instance with the correct type. Note: we do some special handling such that the returned object is not a WeaveObject, but literally the exact pydantic class.
- 3. `HTTP Writing`. In cases where the client/user does not want to add the special type information, users can publish base objects by setting the `set_base_object_class` setting on `POST obj/create` to the name of the class. The weave server will validate the object against the schema, update the metadata fields, and store the object.
+ 3. `HTTP Writing`. In cases where the client/user does not want to add the special type information, users can publish objects by setting the `set_leaf_object_class` field on `POST obj/create` to the name of the class. The weave server will validate the object against the schema, update the metadata fields, and store the object.
4. `HTTP Reading`. When querying for objects, the server will return the object with the correct type if the `base_object_class` metadata field is set.
5. `Frontend`. The frontend will read the zod schema from `weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts` and use that to provide compile time type safety when using `useBaseObjectInstances` and runtime type safety when using `useCreateBaseObjectInstance`.
* Note: it is critical that all techniques produce the same digest for the same data - which is tested in the tests. This way versions are not thrashed by different clients/users.
@@ -185,7 +185,7 @@ graph TD
subgraph "Trace Server"
subgraph "HTTP API"
- R --> |validates using| HW["POST obj/create
set_base_object_class"]
+ R --> |validates using| HW["POST obj/create
set_leaf_object_class"]
HW --> DB[(Weave Object Store)]
HR["POST objs/query
base_object_classes"] --> |Filters base_object_class| DB
end
@@ -203,7 +203,7 @@ graph TD
Z --> |import| UBI["useBaseObjectInstances"]
Z --> |import| UCI["useCreateBaseObjectInstance"]
UBI --> |Filters base_object_class| HR
- UCI --> |set_base_object_class| HW
+ UCI --> |set_leaf_object_class| HW
UI[React UI] --> UBI
UI --> UCI
end
diff --git a/pyproject.toml b/pyproject.toml
index f34757b315e..447368bbd0d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -106,6 +106,9 @@ test = [
"pillow",
"filelock",
"httpx",
+
+ # Builtins
+ "litellm>=1.49.1",
]
[project.scripts]
diff --git a/tests/conftest.py b/tests/conftest.py
index 85e9b53c36b..febf5305f4b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -330,6 +330,14 @@ def actions_execute_batch(
req.wb_user_id = self._user_id
return super().actions_execute_batch(req)
+ def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
+ req.wb_user_id = self._user_id
+ return super().call_method(req)
+
+ def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
+ req.wb_user_id = self._user_id
+ return super().score_call(req)
+
# https://docs.pytest.org/en/7.1.x/example/simple.html#pytest-current-test-environment-variable
def get_test_name():
diff --git a/tests/trace/builtin_objects/backend_models.ipynb b/tests/trace/builtin_objects/backend_models.ipynb
new file mode 100644
index 00000000000..3247b5a8978
--- /dev/null
+++ b/tests/trace/builtin_objects/backend_models.ipynb
@@ -0,0 +1,351 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"WF_TRACE_SERVER_URL\"] = \"http://127.0.01:6345\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import weave"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Logged in as Weights & Biases user: timssweeney.\n",
+ "View Weave data at https://wandb.ai/timssweeney/remote_model_demo_4/weave\n"
+ ]
+ }
+ ],
+ "source": [
+ "client = weave.init(\"remote_model_demo_4\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "🍩 https://wandb.ai/timssweeney/remote_model_demo_4/r/call/0193ba4f-de91-79a2-9028-67c3c500703e\n",
+ "weave:///timssweeney/remote_model_demo_4/object/LiteLLMCompletionModel:KBsfUswVpEHFYmZuJjmhM2YH4EttkRZJSoH0Z0ZaNRY\n",
+ "{'name': 'Fred', 'age': 30}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Demonstrates creating a model in python\n",
+ "\n",
+ "from weave.builtin_objects.models.CompletionModel import LiteLLMCompletionModel\n",
+ "\n",
+ "model = LiteLLMCompletionModel(\n",
+ " model=\"gpt-4o\",\n",
+ " messages_template=[\n",
+ " {\n",
+ " \"role\": \"system\",\n",
+ " \"content\": \"Please extract the name and age from the following text\",\n",
+ " },\n",
+ " {\"role\": \"user\", \"content\": \"{user_input}\"},\n",
+ " ],\n",
+ " response_format={\n",
+ " \"type\": \"json_schema\",\n",
+ " \"json_schema\": {\n",
+ " \"name\": \"Person\",\n",
+ " \"schema\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"age\": {\"type\": \"integer\"},\n",
+ " \"name\": {\"type\": \"string\"},\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "res = model.predict(user_input=\"Hello, my name is Fred and I am 30 years old.\")\n",
+ "\n",
+ "print(model.ref.uri())\n",
+ "\n",
+ "print(res)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "CallMethodRes(call_id='0193ba4f-fc13-79c2-b217-03e6fdd7e7c4', output={'name': 'Charles', 'age': 40})"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Demonstrates calling a model created in python\n",
+ "\n",
+ "from weave.trace_server.trace_server_interface import CallMethodReq\n",
+ "\n",
+ "call_res = client.server.call_method(\n",
+ " CallMethodReq(\n",
+ " project_id=client._project_id(),\n",
+ " object_ref=model.ref.uri(),\n",
+ " method_name=\"predict\",\n",
+ " args={\"user_input\": \"Hello, my name is Charles and I am 40 years old.\"},\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "call_res"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ObjCreateRes(digest='k85wXnWLVxpHujpohAqNBIirXZSM6XRSOSk84n1XR84')"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Demonstrates creating a model in the UI - notice the digest match\n",
+ "\n",
+ "from weave.trace_server.trace_server_interface import ObjCreateReq\n",
+ "\n",
+ "obj_res = client.server.obj_create(\n",
+ " ObjCreateReq.model_validate(\n",
+ " {\n",
+ " \"obj\": {\n",
+ " \"project_id\": client._project_id(),\n",
+ " \"object_id\": \"LiteLLMCompletionModel\",\n",
+ " \"val\": {\n",
+ " \"model\": \"gpt-4o\",\n",
+ " \"messages_template\": [\n",
+ " {\n",
+ " \"role\": \"system\",\n",
+ " \"content\": \"Please extract the name and age from the following text!\",\n",
+ " },\n",
+ " {\"role\": \"user\", \"content\": \"{user_input}\"},\n",
+ " ],\n",
+ " \"response_format\": {\n",
+ " \"type\": \"json_schema\",\n",
+ " \"json_schema\": {\n",
+ " \"name\": \"Person\",\n",
+ " \"schema\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"age\": {\"type\": \"integer\"},\n",
+ " \"name\": {\"type\": \"string\"},\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ " \"set_leaf_object_class\": \"LiteLLMCompletionModel\",\n",
+ " }\n",
+ " }\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "obj_res"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "🍩 https://wandb.ai/timssweeney/remote_model_demo_4/r/call/0193ba50-1be0-70d2-84cd-dcf8fac3ff09\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'name': 'Fred', 'age': 30}"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Demonstrates fetching a model in python that was created in the UI\n",
+ "\n",
+ "from weave.trace.refs import ObjectRef\n",
+ "\n",
+ "gotten_model = weave.ref(\n",
+ " ObjectRef(\n",
+ " entity=client.entity,\n",
+ " project=client.project,\n",
+ " name=\"LiteLLMCompletionModel\",\n",
+ " _digest=obj_res.digest,\n",
+ " ).uri()\n",
+ ").get()\n",
+ "\n",
+ "res = gotten_model.predict(user_input=\"Hello, my name is Fred and I am 30 years old.\")\n",
+ "res"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "LiteLLMCompletionModel(name=None, description=None, model='gpt-4o', messages_template=WeaveList([{'role': 'system', 'content': 'Please extract the name and age from the following text!'}, {'role': 'user', 'content': '{user_input}'}]), response_format=WeaveDict({'type': 'json_schema', 'json_schema': {'name': 'Person', 'schema': {'type': 'object', 'properties': {'age': {'type': 'integer'}, 'name': {'type': 'string'}}}}}))"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gotten_model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Part 2: Scoring:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ScoreCallRes(feedback_id='0193ba55-d466-74a3-a4de-da0a456b08a7', score_call=CallSchema(id='0193ba55-cb43-7c61-a712-f7249e6dfe4f', project_id='UHJvamVjdEludGVybmFsSWQ6NDA1NzYyOTQ=', op_name='weave:///timssweeney/remote_model_demo_4/op/LLMJudgeScorer.score:LSxb3VBdL8YmPr9vqYhxsMe74D8C04dJL1IKQ61Ke7M', display_name=None, trace_id='0193ba55-cb43-7c61-a712-f71512a66d3b', parent_id=None, started_at=datetime.datetime(2024, 12, 12, 10, 6, 45, 59887, tzinfo=TzInfo(UTC)), attributes={'weave': {'client_version': '0.51.25-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}, inputs={'self': 'weave:///timssweeney/remote_model_demo_4/object/LLMJudgeScorer:uCL086uULzE1HKLFn8YIezCG98HiqayaAp3d1R9ktA0', 'inputs': {'kwargs': {'user_input': 'Hello, my name is Charles and I am 40 years old.'}}, 'output': {'name': 'Charles', 'age': 40}}, ended_at=datetime.datetime(2024, 12, 12, 10, 6, 47, 368348, tzinfo=TzInfo(UTC)), exception=None, output={'is_correct': True}, summary={'usage': {'gpt-4o-2024-08-06': {'prompt_tokens': 91, 'completion_tokens': 6, 'requests': 1, 'total_tokens': 97, 'completion_tokens_details': {'audio_tokens': 0, 'reasoning_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}}, 'weave': {'status': , 'trace_name': 'LLMJudgeScorer.score', 'latency_ms': 2308}}, wb_user_id='VXNlcjo2Mzg4Nw==', wb_run_id=None, deleted_at=None))"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import weave\n",
+ "from weave.trace.refs import CallRef\n",
+ "from weave.trace_server import trace_server_interface as tsi\n",
+ "\n",
+ "obj_create_res = client.server.obj_create(\n",
+ " tsi.ObjCreateReq.model_validate(\n",
+ " {\n",
+ " \"obj\": {\n",
+ " \"project_id\": client._project_id(),\n",
+ " \"object_id\": \"CorrectnessJudge\",\n",
+ " \"val\": {\n",
+ " \"model\": \"gpt-4o\",\n",
+ " \"system_prompt\": \"You are a judge that scores the correctness of a response.\",\n",
+ " \"response_format\": {\n",
+ " \"type\": \"json_schema\",\n",
+ " \"json_schema\": {\n",
+ " \"name\": \"Correctness\",\n",
+ " \"schema\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"is_correct\": {\"type\": \"boolean\"},\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ " \"set_leaf_object_class\": \"LLMJudgeScorer\",\n",
+ " }\n",
+ " }\n",
+ " )\n",
+ ")\n",
+ "client._flush()\n",
+ "scorer_ref = weave.ObjectRef(\n",
+ " entity=client._project_id().split(\"/\")[0],\n",
+ " project=client._project_id().split(\"/\")[1],\n",
+ " name=\"CorrectnessJudge\",\n",
+ " _digest=obj_create_res.digest,\n",
+ ")\n",
+ "\n",
+ "call_ref = CallRef(\n",
+ " entity=client._project_id().split(\"/\")[0],\n",
+ " project=client._project_id().split(\"/\")[1],\n",
+ " id=call_res.call_id,\n",
+ ")\n",
+ "\n",
+ "score_res = client.server.score_call(\n",
+ " tsi.ScoreCallReq.model_validate(\n",
+ " {\n",
+ " \"project_id\": client._project_id(),\n",
+ " \"call_ref\": call_ref.uri(),\n",
+ " \"scorer_ref\": scorer_ref.uri(),\n",
+ " }\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "score_res"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "wandb-weave",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tests/trace/builtin_objects/test_builtin_model.py b/tests/trace/builtin_objects/test_builtin_model.py
new file mode 100644
index 00000000000..fa28ae0755f
--- /dev/null
+++ b/tests/trace/builtin_objects/test_builtin_model.py
@@ -0,0 +1,149 @@
+import weave
+from weave.builtin_objects.models.CompletionModel import LiteLLMCompletionModel
+from weave.trace.refs import ObjectRef
+from weave.trace.weave_client import WeaveClient
+from weave.trace_server import trace_server_interface as tsi
+
+model_args = {
+ "model": "gpt-4o",
+ "messages_template": [{"role": "user", "content": "{input}"}],
+ "response_format": {
+ "type": "json_schema",
+ "json_schema": {
+ "name": "Person",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "age": {"type": "integer"},
+ "name": {"type": "string"},
+ },
+ },
+ },
+ },
+}
+
+input_text = "My name is Carlos and I am 42 years old."
+
+expected_result = {"age": 42, "name": "Carlos"}
+
+
+def test_model_publishing_alignment(client: WeaveClient):
+ model = LiteLLMCompletionModel(**model_args)
+ publish_ref = weave.publish(model)
+
+ obj_create_res = client.server.obj_create(
+ tsi.ObjCreateReq.model_validate(
+ {
+ "obj": {
+ "project_id": client._project_id(),
+ "object_id": "LiteLLMCompletionModel",
+ "val": model_args,
+ "set_leaf_object_class": "LiteLLMCompletionModel",
+ }
+ }
+ )
+ )
+
+ assert obj_create_res.digest == publish_ref.digest
+
+ gotten_model = weave.ref(publish_ref.uri()).get()
+ assert isinstance(gotten_model, LiteLLMCompletionModel)
+
+
+def test_model_local_create_local_use(client: WeaveClient):
+ model = LiteLLMCompletionModel(**model_args)
+ predict_result = model.predict(input=input_text)
+ assert predict_result == expected_result
+
+
+def test_model_local_create_remote_use(client: WeaveClient):
+ model = LiteLLMCompletionModel(**model_args)
+ publish_ref = weave.publish(model)
+ remote_call_res = client.server.call_method(
+ tsi.CallMethodReq.model_validate(
+ {
+ "project_id": client._project_id(),
+ "object_ref": publish_ref.uri(),
+ "method_name": "predict",
+ "args": {"input": input_text},
+ }
+ )
+ )
+ assert remote_call_res.output == expected_result
+
+ remote_call_read = client.server.call_read(
+ tsi.CallReadReq.model_validate(
+ {
+ "project_id": client._project_id(),
+ "id": remote_call_res.call_id,
+ }
+ )
+ )
+ assert remote_call_read.call.output == expected_result
+
+
+def test_model_remote_create_local_use(client: WeaveClient):
+ obj_create_res = client.server.obj_create(
+ tsi.ObjCreateReq.model_validate(
+ {
+ "obj": {
+ "project_id": client._project_id(),
+ "object_id": "LiteLLMCompletionModel",
+ "val": model_args,
+ "set_leaf_object_class": "LiteLLMCompletionModel",
+ }
+ }
+ )
+ )
+ obj_ref = ObjectRef(
+ entity=client._project_id().split("/")[0],
+ project=client._project_id().split("/")[1],
+ name="LiteLLMCompletionModel",
+ _digest=obj_create_res.digest,
+ )
+ fetched = obj_ref.get()
+ assert isinstance(fetched, LiteLLMCompletionModel)
+ predict_res = fetched.predict(input=input_text)
+ assert predict_res == expected_result
+
+
+def test_model_remote_create_remote_use(client: WeaveClient):
+ obj_create_res = client.server.obj_create(
+ tsi.ObjCreateReq.model_validate(
+ {
+ "obj": {
+ "project_id": client._project_id(),
+ "object_id": "LiteLLMCompletionModel",
+ "val": model_args,
+ "set_leaf_object_class": "LiteLLMCompletionModel",
+ }
+ }
+ )
+ )
+ obj_ref = ObjectRef(
+ entity=client._project_id().split("/")[0],
+ project=client._project_id().split("/")[1],
+ name="LiteLLMCompletionModel",
+ _digest=obj_create_res.digest,
+ )
+ obj_call_res = client.server.call_method(
+ tsi.CallMethodReq.model_validate(
+ {
+ "project_id": client._project_id(),
+ "object_ref": obj_ref.uri(),
+ "method_name": "predict",
+ "args": {"input": input_text},
+ }
+ )
+ )
+ assert obj_call_res.output == expected_result
+
+ remote_call_read = client.server.call_read(
+ tsi.CallReadReq.model_validate(
+ {
+ "project_id": client._project_id(),
+ "id": obj_call_res.call_id,
+ }
+ )
+ )
+ assert remote_call_read.call.output == expected_result
diff --git a/tests/trace/builtin_objects/test_builtin_scorer.py b/tests/trace/builtin_objects/test_builtin_scorer.py
new file mode 100644
index 00000000000..d41bfc3fbd8
--- /dev/null
+++ b/tests/trace/builtin_objects/test_builtin_scorer.py
@@ -0,0 +1,152 @@
+# Tests:
+# 1. Publishing alignment & class alignment
+# 2. Local Create, Local Direct Score
+# 3. Local Create, Remote Direct Score
+# 4. Remote Create, Local Direct Score
+# 5. Remote Create, Remote Direct Score
+from __future__ import annotations
+
+import weave
+from weave.builtin_objects.scorers.LLMJudgeScorer import LLMJudgeScorer
+from weave.trace.weave_client import ApplyScorerResult, Call, WeaveClient
+from weave.trace_server import trace_server_interface as tsi
+
+scorer_args = {
+ "model": "gpt-4o",
+ "system_prompt": "You are a judge that scores the correctness of a response.",
+ "response_format": {
+ "type": "json_schema",
+ "json_schema": {
+ "name": "Correctness",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "is_correct": {"type": "boolean"},
+ },
+ },
+ },
+ },
+}
+
+score_input = {"inputs": {"question": "What color is the sky?"}, "output": "blue"}
+
+expected_score = {"is_correct": True}
+
+
+def test_scorer_publishing_alignment(client: WeaveClient):
+ model = LLMJudgeScorer(**scorer_args)
+ publish_ref = weave.publish(model, name="CorrectnessJudge")
+
+ obj_create_res = client.server.obj_create(
+ tsi.ObjCreateReq.model_validate(
+ {
+ "obj": {
+ "project_id": client._project_id(),
+ "object_id": "CorrectnessJudge",
+ "val": scorer_args,
+ "set_leaf_object_class": "LLMJudgeScorer",
+ }
+ }
+ )
+ )
+
+ assert obj_create_res.digest == publish_ref.digest
+
+ gotten_model = weave.ref(publish_ref.uri()).get()
+ assert isinstance(gotten_model, LLMJudgeScorer)
+
+
+def make_simple_call():
+ @weave.op
+ def simple_op(question: str) -> str:
+ return "blue"
+
+ res, call = simple_op.call("What color is the sky?")
+ return res, call
+
+
+def assert_expected_outcome(
+ target_call: Call, scorer_res: ApplyScorerResult | tsi.ScoreCallRes
+):
+ scorer_output = None
+ feedback_id = None
+ if isinstance(scorer_res, tsi.ScoreCallRes):
+ scorer_output = scorer_res.score_call.output
+ feedback_id = scorer_res.feedback_id
+ else:
+ scorer_output = scorer_res["score_call"].output
+ feedback_id = scorer_res["feedback_id"]
+
+ assert scorer_output == expected_score
+ feedbacks = list(target_call.feedback)
+ assert len(feedbacks) == 1
+ assert feedbacks[0].payload["output"] == expected_score
+ assert feedbacks[0].id == feedback_id
+
+
+def do_remote_score(
+ client: WeaveClient, target_call: Call, scorer_ref: weave.ObjectRef
+):
+ return client.server.score_call(
+ tsi.ScoreCallReq.model_validate(
+ {
+ "project_id": client._project_id(),
+ "call_ref": target_call.ref.uri(),
+ "scorer_ref": scorer_ref.uri(),
+ }
+ )
+ )
+
+
+def make_remote_scorer(client: WeaveClient):
+ obj_create_res = client.server.obj_create(
+ tsi.ObjCreateReq.model_validate(
+ {
+ "obj": {
+ "project_id": client._project_id(),
+ "object_id": "CorrectnessJudge",
+ "val": scorer_args,
+ "set_leaf_object_class": "LLMJudgeScorer",
+ }
+ }
+ )
+ )
+ client._flush()
+ obj_ref = weave.ObjectRef(
+ entity=client._project_id().split("/")[0],
+ project=client._project_id().split("/")[1],
+ name="CorrectnessJudge",
+ _digest=obj_create_res.digest,
+ )
+ return obj_ref
+
+
+def test_scorer_local_create_local_use(client: WeaveClient):
+ scorer = LLMJudgeScorer(**scorer_args)
+ res, call = make_simple_call()
+ apply_scorer_res = call._apply_scorer(scorer)
+ assert_expected_outcome(call, apply_scorer_res)
+
+
+def test_scorer_local_create_remote_use(client: WeaveClient):
+ scorer = LLMJudgeScorer(**scorer_args)
+ res, call = make_simple_call()
+ publish_ref = weave.publish(scorer)
+ remote_score_res = do_remote_score(client, call, publish_ref)
+ assert_expected_outcome(call, remote_score_res)
+
+
+def test_scorer_remote_create_local_use(client: WeaveClient):
+ obj_ref = make_remote_scorer(client)
+ fetched = weave.ref(obj_ref.uri()).get()
+ assert isinstance(fetched, LLMJudgeScorer)
+ res, call = make_simple_call()
+ apply_scorer_res = call._apply_scorer(fetched)
+ assert_expected_outcome(call, apply_scorer_res)
+
+
+def test_scorer_remote_create_remote_use(client: WeaveClient):
+ obj_ref = make_remote_scorer(client)
+ res, call = make_simple_call()
+ remote_score_res = do_remote_score(client, call, obj_ref)
+ assert_expected_outcome(call, remote_score_res)
diff --git a/tests/trace/test_base_object_classes.py b/tests/trace/test_base_object_classes.py
index a264941f7b0..98b5dd8458e 100644
--- a/tests/trace/test_base_object_classes.py
+++ b/tests/trace/test_base_object_classes.py
@@ -139,7 +139,7 @@ def test_interface_creation(client):
"project_id": client._project_id(),
"object_id": nested_obj_id,
"val": nested_obj.model_dump(),
- "set_base_object_class": "TestOnlyNestedBaseObject",
+ "set_leaf_object_class": "TestOnlyNestedBaseObject",
}
}
)
@@ -164,7 +164,7 @@ def test_interface_creation(client):
"project_id": client._project_id(),
"object_id": top_level_obj_id,
"val": top_obj.model_dump(),
- "set_base_object_class": "TestOnlyExample",
+ "set_leaf_object_class": "TestOnlyExample",
}
}
)
@@ -271,7 +271,7 @@ def test_digest_equality(client):
"project_id": client._project_id(),
"object_id": nested_obj_id,
"val": nested_obj.model_dump(),
- "set_base_object_class": "TestOnlyNestedBaseObject",
+ "set_leaf_object_class": "TestOnlyNestedBaseObject",
}
}
)
@@ -300,7 +300,7 @@ def test_digest_equality(client):
"project_id": client._project_id(),
"object_id": top_level_obj_id,
"val": top_obj.model_dump(),
- "set_base_object_class": "TestOnlyExample",
+ "set_leaf_object_class": "TestOnlyExample",
}
}
)
@@ -322,7 +322,7 @@ def test_schema_validation(client):
"object_id": "nested_obj",
# Incorrect schema, should raise!
"val": {"a": 2},
- "set_base_object_class": "TestOnlyNestedBaseObject",
+ "set_leaf_object_class": "TestOnlyNestedBaseObject",
}
}
)
@@ -340,7 +340,7 @@ def test_schema_validation(client):
"_class_name": "TestOnlyNestedBaseObject",
"_bases": ["BaseObject", "BaseModel"],
},
- "set_base_object_class": "TestOnlyNestedBaseObject",
+ "set_leaf_object_class": "TestOnlyNestedBaseObject",
}
}
)
@@ -359,7 +359,7 @@ def test_schema_validation(client):
"_class_name": "TestOnlyNestedBaseObject",
"_bases": ["BaseObject", "BaseModel"],
},
- "set_base_object_class": "TestOnlyExample",
+ "set_leaf_object_class": "TestOnlyExample",
}
}
)
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx
index a3315266f65..058475ebef3 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx
@@ -70,7 +70,7 @@ export const CallPage: FC<{
};
export const useShowRunnableUI = () => {
- return false;
+ return true;
// Uncomment to re-enable
// const viewerInfo = useViewerInfo();
// return viewerInfo.loading ? false : viewerInfo.userInfo?.admin;
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx
index 0b4e9374a26..e27f8d06117 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallScoresViewer.tsx
@@ -13,7 +13,7 @@ import {NotApplicable} from '../../../Browse2/NotApplicable';
import {SmallRef} from '../../../Browse2/SmallRef';
import {StyledDataGrid} from '../../StyledDataGrid'; // Import the StyledDataGrid component
import {
- TraceObjSchemaForBaseObjectClass,
+ TraceObjSchemaForObjectClass,
useBaseObjectInstances,
} from '../wfReactInterface/baseObjectClassQuery';
import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants';
@@ -61,7 +61,7 @@ const useRunnableFeedbacksForCall = (call: CallSchema) => {
const useRunnableFeedbackTypeToLatestActionRef = (
call: CallSchema,
- actionSpecs: Array>
+ actionSpecs: Array>
): Record => {
return useMemo(() => {
return _.fromPairs(
@@ -92,7 +92,7 @@ type GroupedRowType = {
};
const useTableRowsForRunnableFeedbacks = (
- actionSpecs: Array>,
+ actionSpecs: Array>,
runnableFeedbacks: Feedback[],
runnableFeedbackTypeToLatestActionRef: Record
): GroupedRowType[] => {
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardListingPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardListingPage.tsx
index 52d6795806c..4d3d37caa57 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardListingPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardListingPage.tsx
@@ -12,7 +12,7 @@ import {SimplePageLayout} from '../common/SimplePageLayout';
import {ObjectVersionsTable} from '../ObjectVersionsPage';
import {
useBaseObjectInstances,
- useCreateBaseObjectInstance,
+ useCreateLeafObjectInstance,
} from '../wfReactInterface/baseObjectClassQuery';
import {sanitizeObjectId} from '../wfReactInterface/traceServerDirectClient';
import {
@@ -162,7 +162,7 @@ const generateLeaderboardId = () => {
};
const useCreateLeaderboard = (entity: string, project: string) => {
- const createLeaderboardInstance = useCreateBaseObjectInstance('Leaderboard');
+ const createLeaderboardInstance = useCreateLeafObjectInstance('Leaderboard');
const createLeaderboard = async () => {
const objectId = sanitizeObjectId(generateLeaderboardId());
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardPage.tsx
index 6fac8eaa599..fa1881462a0 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/LeaderboardPage/LeaderboardPage.tsx
@@ -26,7 +26,7 @@ import {LeaderboardObjectVal} from '../../views/Leaderboard/types/leaderboardCon
import {SimplePageLayout} from '../common/SimplePageLayout';
import {
useBaseObjectInstances,
- useCreateBaseObjectInstance,
+ useCreateLeafObjectInstance,
} from '../wfReactInterface/baseObjectClassQuery';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
import {LeaderboardConfigEditor} from './LeaderboardConfigEditor';
@@ -131,7 +131,7 @@ const useUpdateLeaderboard = (
project: string,
objectId: string
) => {
- const createLeaderboard = useCreateBaseObjectInstance('Leaderboard');
+ const createLeaderboard = useCreateLeafObjectInstance('Leaderboard');
const updateLeaderboard = async (leaderboardVal: LeaderboardObjectVal) => {
return await createLeaderboard({
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx
index a478437facb..3f84196c6c7 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx
@@ -2,7 +2,7 @@ import {Box} from '@material-ui/core';
import React, {FC, useCallback, useEffect, useState} from 'react';
import {z} from 'zod';
-import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
+import {createLeafObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
import {TraceServerClient} from '../wfReactInterface/traceServerClient';
import {sanitizeObjectId} from '../wfReactInterface/traceServerDirectClient';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
@@ -86,7 +86,7 @@ export const onAnnotationScorerSave = async (
) => {
const jsonSchemaType = convertTypeToJsonSchemaType(data.Type.type);
const typeExtras = convertTypeExtrasToJsonSchema(data);
- return createBaseObjectInstance(client, 'AnnotationSpec', {
+ return createLeafObjectInstance(client, 'AnnotationSpec', {
obj: {
project_id: projectIdFromParts({entity, project}),
object_id: sanitizeObjectId(data.Name),
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/LLMJudgeScorerForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/LLMJudgeScorerForm.tsx
index 64823e9d551..31195a5cb10 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/LLMJudgeScorerForm.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/LLMJudgeScorerForm.tsx
@@ -11,7 +11,7 @@ import React, {FC, useCallback, useState} from 'react';
import {z} from 'zod';
import {LlmJudgeActionSpecSchema} from '../wfReactInterface/baseObjectClasses.zod';
-import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
+import {createLeafObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
import {ActionSpecSchema} from '../wfReactInterface/generatedBaseObjectClasses.zod';
import {TraceServerClient} from '../wfReactInterface/traceServerClient';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
@@ -185,7 +185,7 @@ export const onLLMJudgeScorerSave = async (
config: judgeAction,
});
- return createBaseObjectInstance(client, 'ActionSpec', {
+ return createLeafObjectInstance(client, 'ActionSpec', {
obj: {
project_id: projectIdFromParts({entity, project}),
object_id: objectId,
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.test.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.test.ts
index 9918ae7f285..f61f56afa7d 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.test.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.test.ts
@@ -2,7 +2,7 @@ import {expectType} from 'tsd';
import {
useBaseObjectInstances,
- useCreateBaseObjectInstance,
+ useCreateLeafObjectInstance,
} from './baseObjectClassQuery';
import {
TestOnlyExample,
@@ -74,7 +74,7 @@ describe('Type Tests', () => {
it('useCreateCollectionObject return type matches expected structure', () => {
type CreateCollectionObjectReturn = ReturnType<
- typeof useCreateBaseObjectInstance<'TestOnlyExample'>
+ typeof useCreateLeafObjectInstance<'TestOnlyExample'>
>;
// Define the expected type structure
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts
index 6ceb39daa70..9125217f614 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/baseObjectClassQuery.ts
@@ -13,22 +13,22 @@ import {
} from './traceServerClientTypes';
import {Loadable} from './wfDataModelHooksInterface';
-type BaseObjectClassRegistry = typeof baseObjectClassRegistry;
-type BaseObjectClassRegistryKeys = keyof BaseObjectClassRegistry;
-type BaseObjectClassType = z.infer<
- BaseObjectClassRegistry[C]
+type ObjectClassRegistry = typeof baseObjectClassRegistry; // TODO: Add more here - not just bases!
+type ObjectClassRegistryKeys = keyof ObjectClassRegistry;
+type ObjectClassType = z.infer<
+ ObjectClassRegistry[C]
>;
-export type TraceObjSchemaForBaseObjectClass<
- C extends BaseObjectClassRegistryKeys
-> = TraceObjSchema, C>;
+export type TraceObjSchemaForObjectClass<
+ C extends ObjectClassRegistryKeys
+> = TraceObjSchema, C>;
-export const useBaseObjectInstances = (
+export const useBaseObjectInstances = (
baseObjectClassName: C,
req: TraceObjQueryReq
-): Loadable>> => {
+): Loadable>> => {
const [objects, setObjects] = useState<
- Array>
+ Array>
>([]);
const getTsClient = useGetTraceServerClientContext();
const client = getTsClient();
@@ -56,11 +56,11 @@ export const useBaseObjectInstances = (
return {result: objects, loading};
};
-const getBaseObjectInstances = async (
+const getBaseObjectInstances = async (
client: TraceServerClient,
baseObjectClassName: C,
req: TraceObjQueryReq
-): Promise, C>>> => {
+): Promise, C>>> => {
const knownObjectClass = baseObjectClassRegistry[baseObjectClassName];
if (!knownObjectClass) {
console.warn(`Unknown object class: ${baseObjectClassName}`);
@@ -86,47 +86,47 @@ const getBaseObjectInstances = async (
.map(
({obj, parsed}) =>
({...obj, val: parsed.data} as TraceObjSchema<
- BaseObjectClassType,
+ ObjectClassType,
C
>)
);
};
-export const useCreateBaseObjectInstance = <
- C extends BaseObjectClassRegistryKeys,
- T = BaseObjectClassType
+export const useCreateLeafObjectInstance = <
+ C extends ObjectClassRegistryKeys,
+ T = ObjectClassType
>(
- baseObjectClassName: C
+ leafObjectClassName: C
): ((req: TraceObjCreateReq) => Promise) => {
const getTsClient = useGetTraceServerClientContext();
const client = getTsClient();
return (req: TraceObjCreateReq) =>
- createBaseObjectInstance(client, baseObjectClassName, req);
+ createLeafObjectInstance(client, leafObjectClassName, req);
};
-export const createBaseObjectInstance = async <
- C extends BaseObjectClassRegistryKeys,
- T = BaseObjectClassType
+export const createLeafObjectInstance = async <
+ C extends ObjectClassRegistryKeys,
+ T = ObjectClassType
>(
client: TraceServerClient,
- baseObjectClassName: C,
+ leafObjectClassName: C,
req: TraceObjCreateReq
): Promise => {
if (
- req.obj.set_base_object_class != null &&
- req.obj.set_base_object_class !== baseObjectClassName
+ req.obj.set_leaf_object_class != null &&
+ req.obj.set_leaf_object_class !== leafObjectClassName
) {
throw new Error(
- `set_base_object_class must match baseObjectClassName: ${baseObjectClassName}`
+ `set_leaf_object_class must match leafObjectClassName: ${leafObjectClassName}`
);
}
- const knownBaseObjectClass = baseObjectClassRegistry[baseObjectClassName];
- if (!knownBaseObjectClass) {
- throw new Error(`Unknown object class: ${baseObjectClassName}`);
+ const knownObjectClass = baseObjectClassRegistry[leafObjectClassName];
+ if (!knownObjectClass) {
+ throw new Error(`Unknown object class: ${leafObjectClassName}`);
}
- const verifiedObject = knownBaseObjectClass.safeParse(req.obj.val);
+ const verifiedObject = knownObjectClass.safeParse(req.obj.val);
if (!verifiedObject.success) {
throw new Error(
@@ -134,13 +134,13 @@ export const createBaseObjectInstance = async <
);
}
- const reqWithBaseObjectClass: TraceObjCreateReq = {
+ const reqWithLeafObjectClass: TraceObjCreateReq = {
...req,
obj: {
...req.obj,
- set_base_object_class: baseObjectClassName,
+ set_leaf_object_class: leafObjectClassName,
},
};
- return client.objCreate(reqWithBaseObjectClass);
+ return client.objCreate(reqWithLeafObjectClass);
};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts
index c396962f0fb..520926532c1 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts
@@ -243,7 +243,7 @@ export type TraceObjCreateReq = {
project_id: string;
object_id: string;
val: T;
- set_base_object_class?: string;
+ set_leaf_object_class?: string;
};
};
diff --git a/weave/builtin_objects/builtin_registry.py b/weave/builtin_objects/builtin_registry.py
new file mode 100644
index 00000000000..41f7fb82c91
--- /dev/null
+++ b/weave/builtin_objects/builtin_registry.py
@@ -0,0 +1,23 @@
+import weave
+from weave.builtin_objects.models.CompletionModel import LiteLLMCompletionModel
+from weave.builtin_objects.scorers.LLMJudgeScorer import LLMJudgeScorer
+
+_BUILTIN_REGISTRY: dict[str, type[weave.Object]] = {}
+
+
+def register_builtin(cls: type[weave.Object]) -> None:
+    """Add ``cls`` to the builtin registry under its class name.
+
+    Args:
+        cls: The class to register; it must derive from ``weave.Object``.
+
+    Raises:
+        TypeError: If ``cls`` is not a subclass of ``weave.Object``.
+        ValueError: If a class with the same ``__name__`` is already registered.
+    """
+    if not issubclass(cls, weave.Object):
+        raise TypeError(f"Object {cls} is not a subclass of weave.Object")
+
+    # The registry is keyed by the bare class name, so names must be unique.
+    registry_key = cls.__name__
+    if registry_key in _BUILTIN_REGISTRY:
+        raise ValueError(f"Object {cls} already registered")
+
+    _BUILTIN_REGISTRY[registry_key] = cls
+
+
+def get_builtin(name: str) -> type[weave.Object]:
+    """Return the builtin class registered under ``name``.
+
+    Raises:
+        KeyError: If no class has been registered under ``name``. This is a
+            hard lookup; callers that want a soft miss must handle the error.
+    """
+    builtin_cls = _BUILTIN_REGISTRY[name]
+    return builtin_cls
+
+
+register_builtin(LiteLLMCompletionModel)
+register_builtin(LLMJudgeScorer)
diff --git a/weave/builtin_objects/models/CompletionModel.py b/weave/builtin_objects/models/CompletionModel.py
new file mode 100644
index 00000000000..808a764ca69
--- /dev/null
+++ b/weave/builtin_objects/models/CompletionModel.py
@@ -0,0 +1,27 @@
+import json
+from typing import Any, Optional
+
+import litellm
+
+import weave
+
+
+class LiteLLMCompletionModel(weave.Model):
+    """Weave model that fills a chat-message template and calls ``litellm.completion``."""
+
+    # Model identifier passed to ``litellm.completion`` as ``model``.
+    model: str
+    # Chat messages whose ``content`` values are ``str.format`` templates.
+    messages_template: list[dict[str, str]]
+    # Passed through unchanged to ``litellm.completion`` as ``response_format``.
+    response_format: Optional[dict] = None
+
+    @weave.op()
+    def predict(self, **kwargs: Any) -> Any:
+        """Render the message templates with ``kwargs`` and return the parsed reply.
+
+        Args:
+            **kwargs: Values substituted into each message's ``content`` via
+                ``str.format``. Literal braces in a template must therefore be
+                doubled (``{{`` / ``}}``).
+
+        Returns:
+            The first choice's message content decoded with ``json.loads``.
+            This is whatever the JSON document decodes to (dict, list, str,
+            number, ...), not necessarily a string.
+        """
+        messages: list[dict] = [
+            {**m, "content": m["content"].format(**kwargs)}
+            for m in self.messages_template
+        ]
+
+        res = litellm.completion(
+            model=self.model,
+            messages=messages,
+            response_format=self.response_format,
+        )
+
+        return json.loads(res.choices[0].message.content)
diff --git a/weave/builtin_objects/scorers/LLMJudgeScorer.py b/weave/builtin_objects/scorers/LLMJudgeScorer.py
new file mode 100644
index 00000000000..480ef7a47f7
--- /dev/null
+++ b/weave/builtin_objects/scorers/LLMJudgeScorer.py
@@ -0,0 +1,37 @@
+import json
+from typing import Any
+
+import litellm
+
+import weave
+
+# TODO: Questions
+# Should "reasoning" be built into the system itself?
+
+
+class LLMJudgeScorer(weave.Scorer):
+    """Scorer that asks an LLM (through ``litellm.completion``) to judge a call."""
+
+    # Model identifier passed to ``litellm.completion`` as ``model``.
+    model: str
+    # Content of the system message sent ahead of the serialized call data.
+    system_prompt: str
+    # Passed through unchanged to ``litellm.completion`` as ``response_format``.
+    response_format: dict
+
+    @weave.op()
+    def score(self, inputs: dict, output: Any) -> Any:
+        """Send the call's inputs and output to the judge model and parse its reply.
+
+        Args:
+            inputs: Inputs of the call being scored.
+            output: Output of the call being scored.
+
+        Returns:
+            The first choice's message content decoded with ``json.loads``.
+            This is whatever the JSON document decodes to, not necessarily a
+            string.
+        """
+        # The user message is the JSON encoding of the inputs and the output,
+        # so both must be JSON-serializable.
+        user_prompt = json.dumps(
+            {
+                "inputs": inputs,
+                "output": output,
+            }
+        )
+
+        messages = [
+            {"role": "system", "content": self.system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+
+        res = litellm.completion(
+            model=self.model,
+            messages=messages,
+            response_format=self.response_format,
+        )
+
+        return json.loads(res.choices[0].message.content)
diff --git a/weave/trace/serialize.py b/weave/trace/serialize.py
index 63210f9abee..9b0b6923464 100644
--- a/weave/trace/serialize.py
+++ b/weave/trace/serialize.py
@@ -245,6 +245,8 @@ def _load_custom_obj_files(
def from_json(obj: Any, project_id: str, server: TraceServerInterface) -> Any:
+ from weave.builtin_objects.builtin_registry import get_builtin
+
if isinstance(obj, list):
return [from_json(v, project_id, server) for v in obj]
elif isinstance(obj, dict):
@@ -265,6 +267,15 @@ def from_json(obj: Any, project_id: str, server: TraceServerInterface) -> Any:
and (baseObject := BASE_OBJECT_REGISTRY.get(val_type))
):
return baseObject.model_validate(obj)
+ elif (
+ isinstance(val_type, str)
+ and obj.get("_class_name") == val_type
+ and (baseObject := get_builtin(val_type))
+ ):
+ valid_keys = baseObject.model_fields.keys()
+ return baseObject.model_validate(
+ {k: v for k, v in obj.items() if k in valid_keys}
+ )
else:
return ObjectRecord(
{k: from_json(v, project_id, server) for k, v in obj.items()}
diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py
index 1d5d54b9b23..d933d29dbc1 100644
--- a/weave/trace/weave_client.py
+++ b/weave/trace/weave_client.py
@@ -10,7 +10,17 @@
from collections.abc import Iterator, Sequence
from concurrent.futures import Future
from functools import lru_cache
-from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generic,
+ Protocol,
+ TypedDict,
+ TypeVar,
+ cast,
+ overload,
+)
import pydantic
from requests import HTTPError
@@ -83,6 +93,9 @@
)
from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer
+if TYPE_CHECKING:
+ from weave.scorers.base_scorer import Scorer
+
# Controls if objects can have refs to projects not the WeaveClient project.
# If False, object refs with with mismatching projects will be recreated.
# If True, use existing ref to object in other project.
@@ -302,6 +315,11 @@ def map_to_refs(obj: Any) -> Any:
return obj
+class ApplyScorerResult(TypedDict):
+ # Shape of the value returned by `Call._apply_scorer`.
+ # Value returned by `client._add_runnable_feedback` for the recorded score.
+ feedback_id: str
+ # The call produced by invoking the scorer op.
+ score_call: Call
+
+
@dataclasses.dataclass
class Call:
"""A Call represents a single operation that was executed as part of a trace."""
@@ -445,7 +463,7 @@ def set_display_name(self, name: str | None) -> None:
def remove_display_name(self) -> None:
self.set_display_name(None)
- def _apply_scorer(self, scorer_op: Op) -> None:
+ def _apply_scorer(self, scorer_op: Op | Scorer) -> ApplyScorerResult:
"""
This is a private method that applies a scorer to a call and records the feedback.
In the near future, this will be made public, but for now it is only used internally
@@ -455,18 +473,32 @@ def _apply_scorer(self, scorer_op: Op) -> None:
inside `eval.py` uses this method inside the scorer block.
Current limitations:
- - only works for ops (not Scorer class)
- no async support
- no context yet (ie. ground truth)
"""
+ from weave.scorers.base_scorer import Scorer
+
+ self_arg = None
+ orig_scorer_op = scorer_op
+ if isinstance(scorer_op, Scorer):
+ self_arg = scorer_op
+ scorer_op = scorer_op.score
+
client = weave_client_context.require_weave_client()
scorer_signature = inspect.signature(scorer_op)
scorer_arg_names = list(scorer_signature.parameters.keys())
- score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names}
+ if "inputs" in scorer_arg_names:
+ score_args = {
+ "inputs": {k: v for k, v in self.inputs.items() if k != "self"}
+ }
+ else:
+ score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names}
+ if self_arg is not None:
+ score_args["self"] = self_arg
if "output" in scorer_arg_names:
score_args["output"] = self.output
_, score_call = scorer_op.call(**score_args)
- scorer_op_ref = get_ref(scorer_op)
+ scorer_op_ref = get_ref(orig_scorer_op)
if scorer_op_ref is None:
raise ValueError("Scorer op has no ref")
self_ref = get_ref(self)
@@ -476,13 +508,19 @@ def _apply_scorer(self, scorer_op: Op) -> None:
score_call_ref = get_ref(score_call)
if score_call_ref is None:
raise ValueError("Score call has no ref")
- client._add_runnable_feedback(
+ feedback_id = client._add_runnable_feedback(
weave_ref_uri=self_ref.uri(),
output=score_results,
call_ref_uri=score_call_ref.uri(),
runnable_ref_uri=scorer_op_ref.uri(),
)
+ # TODO: Make this a class
+ return ApplyScorerResult(
+ feedback_id=feedback_id,
+ score_call=score_call,
+ )
+
def make_client_call(
entity: str, project: str, server_call: CallSchema, server: TraceServerInterface
@@ -1522,6 +1560,8 @@ def _ref_is_own(self, ref: Ref) -> bool:
return isinstance(ref, Ref)
def _project_id(self) -> str:
+ # if self.entity == "":
+ # return self.project
return f"{self.entity}/{self.project}"
@trace_sentry.global_trace_sentry.watch()
diff --git a/weave/trace_server/base_object_class_util.py b/weave/trace_server/base_object_class_util.py
index 1c52f766c0c..55595583093 100644
--- a/weave/trace_server/base_object_class_util.py
+++ b/weave/trace_server/base_object_class_util.py
@@ -29,65 +29,76 @@ def get_base_object_class(val: Any) -> Optional[str]:
return None
+def get_leaf_object_class(val: Any) -> Optional[str]:
+    """Return the ``_class_name`` of a payload that directly extends a base object class.
+
+    A name is returned only when ``val`` is a dict whose ``_bases`` list has at
+    least two entries, ends with ``"BaseModel"``, has a second-to-last entry
+    listed in ``base_object_class_names``, and the dict carries a
+    ``_class_name`` key. In every other case the result is ``None``.
+    """
+    if not isinstance(val, dict) or "_bases" not in val:
+        return None
+
+    bases = val["_bases"]
+    if not isinstance(bases, list) or len(bases) < 2:
+        return None
+
+    # The chain must end in [<known base object class>, "BaseModel"].
+    if not (bases[-1] == "BaseModel" and bases[-2] in base_object_class_names):
+        return None
+
+    if "_class_name" not in val:
+        return None
+    return val["_class_name"]
+
+
def process_incoming_object_val(
- val: Any, req_base_object_class: Optional[str] = None
+ val: Any, req_leaf_object_class: Optional[str] = None
) -> tuple[dict, Optional[str]]:
"""
This method is responsible for accepting an incoming object from the user, validating it
- against the base object class, and returning the object with the base object class
+ against the leaf object class, and returning the object with the base object class
set. It does not mutate the original object, but returns a new object with values set if needed.
Specifically,:
1. If the object is not a dict, it is returned as is, and the base object class is set to None.
- 2. There are 2 ways to specify the base object class:
- a. The `req_base_object_class` argument.
+ 2. There are 2 ways to specify the leaf object class:
+ a. The `req_leaf_object_class` argument.
* used by non-pythonic writers of weave objects
b. The `_bases` & `_class_name` attributes of the object, which is a list of base class names.
* used by pythonic weave object writers (legacy)
- 3. If the object has a base object class that does not match the requested base object class,
+ 3. If the object has a leaf object class that does not match the requested leaf object class,
an error is thrown.
- 4. if the object contains a base object class inside the payload, then we simply validate
- the object against the base object class (if a match is found in BASE_OBJECT_REGISTRY)
- 5. If the object does not have a base object class and a requested base object class is
+ 4. if the object contains a leaf object class inside the payload, then we simply validate
+ the object against the leaf object class (if a match is found in BASE_OBJECT_REGISTRY)
+ 5. If the object does not have a leaf object class and a requested leaf object class is
provided, we require a match in BASE_OBJECT_REGISTRY and validate the object against
- the requested base object class. Finally, we set the correct feilds.
+ the requested leaf object class. Finally, we set the correct fields.
"""
if not isinstance(val, dict):
- if req_base_object_class is not None:
+ if req_leaf_object_class is not None:
raise ValueError(
- "set_base_object_class cannot be provided for non-dict objects"
+ "set_leaf_object_class cannot be provided for non-dict objects"
)
return val, None
dict_val = val.copy()
- val_base_object_class = get_base_object_class(dict_val)
+ val_leaf_object_class = get_leaf_object_class(dict_val)
if (
- val_base_object_class != None
- and req_base_object_class != None
- and val_base_object_class != req_base_object_class
+ val_leaf_object_class != None
+ and req_leaf_object_class != None
+ and val_leaf_object_class != req_leaf_object_class
):
raise ValueError(
- f"set_base_object_class must match base_object_class: {val_base_object_class} != {req_base_object_class}"
+ f"set_leaf_object_class must match found leaf class: {req_leaf_object_class} != {val_leaf_object_class}"
)
- if val_base_object_class is not None:
+ if val_leaf_object_class is not None:
# In this case, we simply validate if the match is found
- if base_object_class_type := BASE_OBJECT_REGISTRY.get(val_base_object_class):
- base_object_class_type.model_validate(dict_val)
- elif req_base_object_class is not None:
+ if object_class_type := BASE_OBJECT_REGISTRY.get(val_leaf_object_class):
+ object_class_type.model_validate(dict_val)
+ elif req_leaf_object_class is not None:
# In this case, we require that the base object class is registered
- if base_object_class_type := BASE_OBJECT_REGISTRY.get(req_base_object_class):
- dict_val = dump_base_object(base_object_class_type.model_validate(dict_val))
+ if object_class_type := BASE_OBJECT_REGISTRY.get(req_leaf_object_class):
+ dict_val = dump_base_object(object_class_type.model_validate(dict_val))
else:
- raise ValueError(f"Unknown base object class: {req_base_object_class}")
-
- base_object_class = val_base_object_class or req_base_object_class
+ raise ValueError(f"Unknown leaf object class: {req_leaf_object_class}")
- return dict_val, base_object_class
+ return dict_val, get_base_object_class(dict_val)
+# BIG TODO: Replace this with a true object serialization step using some synthetic weave client!
# Server-side version of `pydantic_object_record`
def dump_base_object(val: BaseModel) -> dict:
cls = val.__class__
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py
index d40d7bcc2a3..f7cf8e808d8 100644
--- a/weave/trace_server/clickhouse_trace_server_batched.py
+++ b/weave/trace_server/clickhouse_trace_server_batched.py
@@ -196,6 +196,16 @@ def __init__(
self._use_async_insert = use_async_insert
self._model_to_provider_info_map = read_model_to_provider_info_map()
+    def model_dump(self) -> dict[str, Any]:
+        """Return this server's connection settings as a plain dict.
+
+        Each key maps to the private attribute of the same name prefixed with
+        an underscore (``"host"`` -> ``self._host``, and so on).
+
+        NOTE(review): the dump includes the password, so treat the result as a
+        secret and keep it out of logs.
+        """
+        setting_names = (
+            "host",
+            "port",
+            "user",
+            "password",
+            "database",
+            "use_async_insert",
+        )
+        return {setting: getattr(self, f"_{setting}") for setting in setting_names}
+
@classmethod
def from_env(cls, use_async_insert: bool = False) -> "ClickHouseTraceServer":
# Explicitly calling `RemoteHTTPTraceServer` constructor here to ensure
@@ -563,8 +573,22 @@ def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes:
return tsi.OpQueryRes(op_objs=objs)
def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
+ from weave.builtin_objects.builtin_registry import get_builtin
+
+ if req.obj.set_leaf_object_class is not None:
+ from weave.trace_server.server_side_object_saver import RunAsUser
+
+ object_class_type = get_builtin(req.obj.set_leaf_object_class)
+
+ new_obj = object_class_type.model_validate(req.obj.val)
+ runner = RunAsUser(ch_server_dump=self.model_dump())
+ digest = runner.run_save_object(
+ new_obj, req.obj.project_id, req.obj.object_id, None
+ )
+ return tsi.ObjCreateRes(digest=digest)
+
val, base_object_class = process_incoming_object_val(
- req.obj.val, req.obj.set_base_object_class
+ req.obj.val, req.obj.set_leaf_object_class
)
json_val = json.dumps(val)
@@ -1503,6 +1527,32 @@ def completions_create(
response=res.response, weave_call_id=start_call.id
)
+    def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
+        """Invoke a method of a stored object in a child process.
+
+        The work is delegated to ``RunAsUser.run_call_method``; whatever it
+        returns is validated into a ``tsi.CallMethodRes``.
+
+        Raises:
+            ValueError: If ``req.wb_user_id`` is ``None``.
+        """
+        from weave.trace_server.server_side_object_saver import RunAsUser
+
+        if req.wb_user_id is None:
+            raise ValueError("User ID is required")
+
+        # TODO: handle errors here
+        method_result = RunAsUser(ch_server_dump=self.model_dump()).run_call_method(
+            req.object_ref,
+            req.project_id,
+            req.wb_user_id,
+            req.method_name,
+            req.args,
+        )
+        return tsi.CallMethodRes.model_validate(method_result)
+
+    def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
+        """Apply a scorer to a call in a child process and return the outcome.
+
+        ``RunAsUser.run_score_call`` reports a feedback id and the id of the
+        scorer's own call; that call is then read back with ``call_read`` so
+        the response carries the full call record.
+        """
+        from weave.trace_server.server_side_object_saver import RunAsUser
+
+        score_result = RunAsUser(ch_server_dump=self.model_dump()).run_score_call(req)
+
+        feedback_id = score_result["feedback_id"]
+        scorer_call_req = tsi.CallReadReq(
+            project_id=req.project_id, id=score_result["scorer_call_id"]
+        )
+        return tsi.ScoreCallRes(
+            feedback_id=feedback_id,
+            score_call=self.call_read(scorer_call_req).call,
+        )
+
# Private Methods
@property
def ch_client(self) -> CHClient:
diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py
index 1df739adbcd..fe38a2b1d91 100644
--- a/weave/trace_server/external_to_internal_trace_server_adapter.py
+++ b/weave/trace_server/external_to_internal_trace_server_adapter.py
@@ -376,3 +376,19 @@ def completions_create(
req.project_id = self._idc.ext_to_int_project_id(req.project_id)
res = self._ref_apply(self._internal_trace_server.completions_create, req)
return res
+
+    def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
+        """Translate external ids on ``req`` in place and forward ``call_method``.
+
+        Raises:
+            ValueError: If ``req.wb_user_id`` is ``None``.
+        """
+        req.project_id = self._idc.ext_to_int_project_id(req.project_id)
+        # The project id is rewritten before the user id is checked.
+        if req.wb_user_id is None:
+            raise ValueError("wb_user_id cannot be None")
+        req.wb_user_id = self._idc.ext_to_int_user_id(req.wb_user_id)
+        return self._ref_apply(self._internal_trace_server.call_method, req)
+
+    def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
+        """Translate external ids on ``req`` in place and forward ``score_call``.
+
+        Raises:
+            ValueError: If ``req.wb_user_id`` is ``None``.
+        """
+        req.project_id = self._idc.ext_to_int_project_id(req.project_id)
+        # The project id is rewritten before the user id is checked.
+        if req.wb_user_id is None:
+            raise ValueError("wb_user_id cannot be None")
+        req.wb_user_id = self._idc.ext_to_int_user_id(req.wb_user_id)
+        return self._ref_apply(self._internal_trace_server.score_call, req)
diff --git a/weave/trace_server/interface/base_object_classes/base_object_registry.py b/weave/trace_server/interface/base_object_classes/base_object_registry.py
index 19eb865daea..bc1bcc6d7a7 100644
--- a/weave/trace_server/interface/base_object_classes/base_object_registry.py
+++ b/weave/trace_server/interface/base_object_classes/base_object_registry.py
@@ -9,6 +9,7 @@
TestOnlyNestedBaseObject,
)
+# TODO: Migrate everything to normal Objects!
BASE_OBJECT_REGISTRY: dict[str, type[BaseObject]] = {}
diff --git a/weave/trace_server/migrations/002_add_deleted_at.up.sql b/weave/trace_server/migrations/002_add_deleted_at.up.sql
index 77538f75495..9d16534eb27 100644
--- a/weave/trace_server/migrations/002_add_deleted_at.up.sql
+++ b/weave/trace_server/migrations/002_add_deleted_at.up.sql
@@ -8,48 +8,6 @@ This migration adds:
ALTER TABLE object_versions
ADD COLUMN deleted_at Nullable(DateTime64(3)) DEFAULT NULL;
-CREATE OR REPLACE VIEW object_versions_deduped as
- SELECT project_id,
- object_id,
- created_at,
- deleted_at, -- **** Add deleted_at to the view ****
- kind,
- base_object_class,
- refs,
- val_dump,
- digest,
- if (kind = 'op', 1, 0) AS is_op,
- row_number() OVER (
- PARTITION BY project_id,
- kind,
- object_id
- ORDER BY created_at ASC
- ) AS _version_index_plus_1,
- _version_index_plus_1 - 1 AS version_index,
- count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count,
- if(_version_index_plus_1 = version_count, 1, 0) AS is_latest
- FROM (
- SELECT *,
- row_number() OVER (
- PARTITION BY project_id,
- kind,
- object_id,
- digest
- ORDER BY created_at ASC
- ) AS rn
- FROM object_versions
- )
- WHERE rn = 1 WINDOW w AS (
- PARTITION BY project_id,
- kind,
- object_id
- ORDER BY created_at ASC
- )
- ORDER BY project_id,
- kind,
- object_id,
- created_at;
-
ALTER TABLE call_parts
ADD COLUMN deleted_at Nullable(DateTime64(3)) DEFAULT NULL;
diff --git a/weave/trace_server/server_side_object_saver.py b/weave/trace_server/server_side_object_saver.py
new file mode 100644
index 00000000000..14d85b61f27
--- /dev/null
+++ b/weave/trace_server/server_side_object_saver.py
@@ -0,0 +1,397 @@
+from __future__ import annotations
+
+import multiprocessing
+from typing import Any, Callable, TypedDict
+
+import weave
+from weave.trace import autopatch
+from weave.trace.refs import ObjectRef
+from weave.trace.weave_client import WeaveClient
+from weave.trace.weave_init import InitializedClient
+from weave.trace_server import external_to_internal_trace_server_adapter
+from weave.trace_server import trace_server_interface as tsi
+from weave.trace_server.refs_internal import (
+ InternalCallRef,
+ InternalObjectRef,
+ parse_internal_uri,
+)
+
+
+class ScoreCallResult(TypedDict):
+ # Payload that `RunAsUser._score_call` puts on the result queue on success.
+ # `feedback_id` entry of the value returned by `Call._apply_scorer`.
+ feedback_id: str
+ # `id` of the `score_call` entry returned by `Call._apply_scorer`.
+ scorer_call_id: str
+
+
+class RunSaveObjectException(Exception):
+ """Raised by `RunAsUser.run_save_object` when the child process reports an error."""
+ pass
+
+
+class RunCallMethodException(Exception):
+ """Raised by `RunAsUser.run_call_method` when the child process reports an error."""
+ pass
+
+
+class RunScoreCallException(Exception):
+ """Raised by `RunAsUser.run_score_call` on a child error or an unexpected result."""
+ pass
+
+
+class RunAsUser:
+    """Runs weave client operations on behalf of a user in a child process.
+
+    Each public ``run_*`` method starts a ``multiprocessing.Process``. The child
+    builds its own ``WeaveClient`` on top of a ``ClickHouseTraceServer`` created
+    from ``ch_server_dump``, performs the operation, and reports a
+    ``(status, payload)`` tuple back through a ``multiprocessing.Queue``, where
+    ``status`` is ``"success"`` or ``"error"``.
+    """
+
+    # Seconds to block on the result queue before re-checking child liveness.
+    _RESULT_POLL_SECONDS = 0.5
+
+    def __init__(self, ch_server_dump: dict[str, Any]):
+        # Keyword arguments used to rebuild a ClickHouseTraceServer in the child.
+        self.ch_server_dump = ch_server_dump
+
+    @staticmethod
+    def _process_runner(
+        func: Callable[..., Any],
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
+        result_queue: multiprocessing.Queue,
+    ) -> None:
+        """Execute the function and put its result in the queue.
+
+        Args:
+            func: The function to execute
+            args: Positional arguments for the function
+            kwargs: Keyword arguments for the function
+            result_queue: Queue to store the function's result
+        """
+        try:
+            result = func(*args, **kwargs)
+            result_queue.put(("success", result))
+        except Exception as e:
+            result_queue.put(("error", str(e)))
+
+    @classmethod
+    def _await_result(
+        cls,
+        process: multiprocessing.Process,
+        result_queue: multiprocessing.Queue,
+    ) -> tuple[str, Any]:
+        """Wait for the child's ``(status, payload)`` report, then join the child.
+
+        The queue is read before ``join`` because a process that has put items
+        on a queue may not exit until they have been consumed. A child that
+        exits without reporting anything (for example because it was killed)
+        yields an ``("error", ...)`` result instead of blocking forever.
+        """
+        import queue  # function-scope: only needed for the Empty exception
+
+        while True:
+            try:
+                outcome = result_queue.get(timeout=cls._RESULT_POLL_SECONDS)
+                break
+            except queue.Empty:
+                if process.is_alive():
+                    continue
+                # The child is gone; give a result it wrote just before exiting
+                # one last chance to arrive.
+                try:
+                    outcome = result_queue.get(timeout=cls._RESULT_POLL_SECONDS)
+                except queue.Empty:
+                    outcome = (
+                        "error",
+                        f"child process exited with code {process.exitcode} "
+                        "without reporting a result",
+                    )
+                break
+
+        process.join()
+        return outcome
+
+    def _open_client(
+        self, project_id: str, user_id: str | None
+    ) -> tuple[WeaveClient, InitializedClient]:
+        """Build and activate the ``_SERVER_`` client used inside the child process."""
+        from weave.trace_server.clickhouse_trace_server_batched import (
+            ClickHouseTraceServer,
+        )
+
+        client = WeaveClient(
+            "_SERVER_",
+            project_id,
+            UserInjectingExternalTraceServer(
+                ClickHouseTraceServer(**self.ch_server_dump),
+                id_converter=IdConverter(),
+                user_id=user_id,
+            ),
+            False,
+        )
+        ic = InitializedClient(client)
+        autopatch.autopatch()
+        return client, ic
+
+    @staticmethod
+    def _close_client(client: WeaveClient, ic: InitializedClient) -> None:
+        """Undo autopatching, flush the client, and reset the initialized client."""
+        autopatch.reset_autopatch()
+        client._flush()
+        ic.reset()
+
+    def run_save_object(
+        self,
+        new_obj: Any,
+        project_id: str,
+        object_name: str | None,
+        user_id: str | None,
+    ) -> str:
+        """Run the save_object operation in a separate process.
+
+        Args:
+            new_obj: The object to save
+            project_id: The project identifier
+            object_name: The name to publish the object under
+            user_id: The user identifier
+
+        Returns:
+            str: The digest of the saved object
+
+        Raises:
+            RunSaveObjectException: If the child process reports an error or
+                exits without reporting a result
+        """
+        result_queue: multiprocessing.Queue[tuple[str, str]] = multiprocessing.Queue()
+
+        process = multiprocessing.Process(
+            target=self._save_object,
+            args=(new_obj, project_id, object_name, user_id, result_queue),
+        )
+
+        process.start()
+        status, result = self._await_result(process, result_queue)
+
+        if status == "error":
+            raise RunSaveObjectException(f"Process execution failed: {result}")
+
+        return result
+
+    def _save_object(
+        self,
+        new_obj: Any,
+        project_id: str,
+        object_name: str | None,
+        user_id: str | None,
+        result_queue: multiprocessing.Queue,
+    ) -> None:
+        """Publish an object from inside the child process.
+
+        Args:
+            new_obj: The object to save
+            project_id: The project identifier
+            object_name: The name of the object
+            user_id: The user identifier
+            result_queue: Queue that receives ``("success", digest)`` or
+                ``("error", message)``
+        """
+        try:
+            client, ic = self._open_client(project_id, user_id)
+
+            res = weave.publish(new_obj, name=object_name).digest
+            self._close_client(client, ic)
+            result_queue.put(("success", res))
+        except Exception as e:
+            # The parent only sees the message text, not the exception object.
+            result_queue.put(("error", str(e)))
+
+    def run_call_method(
+        self,
+        obj_ref: str,
+        project_id: str,
+        user_id: str,
+        method_name: str,
+        args: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Call a method of a stored object in a separate process.
+
+        Args:
+            obj_ref: Internal URI of the object whose method is called
+            project_id: The project identifier
+            user_id: The user identifier
+            method_name: Name of the method to look up on the object
+            args: Keyword arguments passed to the method
+
+        Returns:
+            A dict with the method's ``output`` and the ``call_id`` of the
+            resulting call
+
+        Raises:
+            RunCallMethodException: If the child process reports an error or
+                exits without reporting a result
+        """
+        result_queue: multiprocessing.Queue[tuple[str, Any]] = multiprocessing.Queue()
+
+        process = multiprocessing.Process(
+            target=self._call_method,
+            args=(obj_ref, project_id, user_id, method_name, args, result_queue),
+        )
+
+        process.start()
+        status, result = self._await_result(process, result_queue)
+
+        if status == "error":
+            raise RunCallMethodException(f"Process execution failed: {result}")
+
+        return result
+
+    def _call_method(
+        self,
+        obj_ref: str,
+        project_id: str,
+        user_id: str,
+        method_name: str,
+        args: dict[str, Any],
+        result_queue: multiprocessing.Queue,
+    ) -> None:
+        """Load the referenced object and call one of its methods (child process)."""
+        try:
+            client, ic = self._open_client(project_id, user_id)
+
+            # TODO: validate project alignment?
+            int_ref = parse_internal_uri(obj_ref)
+            # Explicit raise instead of assert: asserts are stripped under -O
+            # and carry no message for the parent process.
+            if not isinstance(int_ref, InternalObjectRef):
+                raise TypeError("Invalid object reference")
+            ref = ObjectRef(
+                entity="_SERVER_",
+                project=int_ref.project_id,
+                name=int_ref.name,
+                _digest=int_ref.version,
+            )
+            obj = client.get(ref)
+            method = getattr(obj, method_name)
+            # TODO: Self might be wrong
+            res, call = method.call(self=obj, **args)
+            self._close_client(client, ic)
+            result_queue.put(("success", {"output": res, "call_id": call.id}))
+        except Exception as e:
+            result_queue.put(("error", str(e)))
+
+    def run_score_call(self, req: tsi.ScoreCallReq) -> ScoreCallResult:
+        """Apply a scorer to a call in a separate process.
+
+        Args:
+            req: Request carrying the project id, user id, call ref and
+                scorer ref
+
+        Returns:
+            The feedback id and the id of the scorer's own call
+
+        Raises:
+            RunScoreCallException: If the child process reports an error, exits
+                without reporting a result, or reports a non-dict payload
+        """
+        result_queue: multiprocessing.Queue[tuple[str, ScoreCallResult | str]] = (
+            multiprocessing.Queue()
+        )
+
+        process = multiprocessing.Process(
+            target=self._score_call,
+            args=(req, result_queue),
+        )
+
+        process.start()
+        status, result = self._await_result(process, result_queue)
+
+        if status == "error":
+            raise RunScoreCallException(f"Process execution failed: {result}")
+
+        if isinstance(result, dict):
+            return result
+        else:
+            raise RunScoreCallException(f"Unexpected result: {result}")
+
+    def _score_call(
+        self,
+        req: tsi.ScoreCallReq,
+        result_queue: multiprocessing.Queue[tuple[str, ScoreCallResult | str]],
+    ) -> None:
+        """Resolve the call and scorer refs and apply the scorer (child process)."""
+        try:
+            from weave.trace.weave_client import Call
+
+            client, ic = self._open_client(req.project_id, req.wb_user_id)
+
+            target_call_ref = parse_internal_uri(req.call_ref)
+            if not isinstance(target_call_ref, InternalCallRef):
+                raise TypeError("Invalid call reference")
+            target_call = client.get_call(target_call_ref.id)._val
+            if not isinstance(target_call, Call):
+                raise TypeError("Invalid call reference")
+            scorer_ref = parse_internal_uri(req.scorer_ref)
+            if not isinstance(scorer_ref, InternalObjectRef):
+                raise TypeError("Invalid scorer reference")
+            scorer = weave.ref(
+                ObjectRef(
+                    entity="_SERVER_",
+                    project=scorer_ref.project_id,
+                    name=scorer_ref.name,
+                    _digest=scorer_ref.version,
+                ).uri()
+            ).get()
+            if not isinstance(scorer, weave.Scorer):
+                raise TypeError("Invalid scorer reference")
+            apply_scorer_res = target_call._apply_scorer(scorer)
+
+            self._close_client(client, ic)
+            scorer_call_id = apply_scorer_res["score_call"].id
+            if not scorer_call_id:
+                raise ValueError("Scorer call ID is required")
+            result_queue.put(
+                (
+                    "success",
+                    ScoreCallResult(
+                        feedback_id=apply_scorer_res["feedback_id"],
+                        scorer_call_id=scorer_call_id,
+                    ),
+                )
+            )
+        except Exception as e:
+            result_queue.put(("error", str(e)))
+
+
+class IdConverter(external_to_internal_trace_server_adapter.IdConverter):
+    """ID converter for the server-side client, whose entity is ``_SERVER_``.
+
+    Project ids gain or lose the ``_SERVER_/`` prefix; run ids and user ids
+    pass through unchanged.
+    """
+
+    def ext_to_int_project_id(self, project_id: str) -> str:
+        """Strip the ``_SERVER_/`` prefix from an external project id.
+
+        Raises:
+            ValueError: If ``project_id`` does not start with ``_SERVER_/``.
+        """
+        prefix = "_SERVER_/"
+        # Explicit check rather than assert: asserts are stripped under -O,
+        # which would let a foreign id be silently truncated.
+        if not project_id.startswith(prefix):
+            raise ValueError(
+                f"Expected project id to start with {prefix!r}, got {project_id!r}"
+            )
+        return project_id[len(prefix) :]
+
+    def int_to_ext_project_id(self, project_id: str) -> str | None:
+        """Prefix an internal project id with ``_SERVER_/``."""
+        return "_SERVER_/" + project_id
+
+    def ext_to_int_run_id(self, run_id: str) -> str:
+        """Return ``run_id`` unchanged."""
+        return run_id
+
+    def int_to_ext_run_id(self, run_id: str) -> str:
+        """Return ``run_id`` unchanged."""
+        return run_id
+
+    def ext_to_int_user_id(self, user_id: str) -> str:
+        """Return ``user_id`` unchanged."""
+        return user_id
+
+    def int_to_ext_user_id(self, user_id: str) -> str:
+        """Return ``user_id`` unchanged."""
+        return user_id
+
+
+class UserInjectingExternalTraceServer(
+    external_to_internal_trace_server_adapter.ExternalTraceServer
+):
+    """External trace server that stamps a fixed user id onto outgoing requests.
+
+    Every overridden endpoint raises ``ValueError("User ID is required")`` when
+    the server was built without a user id; otherwise it writes the id onto the
+    request and defers to the parent implementation.
+    """
+
+    def __init__(
+        self,
+        internal_trace_server: tsi.TraceServerInterface,
+        id_converter: external_to_internal_trace_server_adapter.IdConverter,
+        user_id: str | None,
+    ):
+        super().__init__(internal_trace_server, id_converter)
+        self._user_id = user_id
+
+    def _require_user_id(self) -> str:
+        """Return the configured user id, or raise if none was provided."""
+        configured_user_id = self._user_id
+        if configured_user_id is None:
+            raise ValueError("User ID is required")
+        return configured_user_id
+
+    def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes:
+        # Unlike the other endpoints, the user id lives on the nested start record.
+        req.start.wb_user_id = self._require_user_id()
+        return super().call_start(req)
+
+    def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
+        req.wb_user_id = self._require_user_id()
+        return super().calls_delete(req)
+
+    def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
+        req.wb_user_id = self._require_user_id()
+        return super().call_update(req)
+
+    def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
+        req.wb_user_id = self._require_user_id()
+        return super().feedback_create(req)
+
+    def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes:
+        req.wb_user_id = self._require_user_id()
+        return super().cost_create(req)
+
+    def actions_execute_batch(
+        self, req: tsi.ActionsExecuteBatchReq
+    ) -> tsi.ActionsExecuteBatchRes:
+        req.wb_user_id = self._require_user_id()
+        return super().actions_execute_batch(req)
+
+    def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
+        req.wb_user_id = self._require_user_id()
+        return super().call_method(req)
+
+    def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
+        req.wb_user_id = self._require_user_id()
+        return super().score_call(req)
diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py
index e14cef8041e..0cd4f7db5ca 100644
--- a/weave/trace_server/sqlite_trace_server.py
+++ b/weave/trace_server/sqlite_trace_server.py
@@ -612,7 +612,7 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
conn, cursor = get_conn_cursor(self.db_path)
val, base_object_class = process_incoming_object_val(
- req.obj.val, req.obj.set_base_object_class
+ req.obj.val, req.obj.set_leaf_object_class
)
json_val = json.dumps(val)
digest = str_digest(json_val)
@@ -1221,6 +1221,12 @@ def table_query_stream(
results = self.table_query(req)
yield from results.rows
+ def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:  # Not supported by this server: always raises NotImplementedError; req is ignored.
+ raise NotImplementedError("call_method is not implemented for local sqlite")
+
+ def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:  # Not supported by this server: always raises NotImplementedError; req is ignored.
+ raise NotImplementedError("score_call is not implemented for local sqlite")
+
def get_type(val: Any) -> str:
if val == None:
diff --git a/weave/trace_server/todo.md b/weave/trace_server/todo.md
new file mode 100644
index 00000000000..77ce36d0d04
--- /dev/null
+++ b/weave/trace_server/todo.md
@@ -0,0 +1,35 @@
+* [x] Finish Porting over weave client
+* [ ] Add end-to-end test with new lifecycle:
+ * [x] Create Dummy Model via API -> See in object explorer
+ * [x] Requires finishing the leaf object registry
+ * [x] Requires sub process construction & saving?
+ * [x] Invoke Dummy Model via API -> See in trace explorer
+ * [x] Requires execution in sub process
+ * New TODOs:
+ * API keys are not set up correctly in the sub runner
+ * A bunch of naming issues all around (leaf, saver, etc...)
+ * Not very good error handling
+ * Server seems to get stuck when it crashes
+ * Accidentally checked in parallelization disabling
+ * Create a Dummy Scorer via API -> See in object explorer
+ * [x] Should be pretty straightforward at this point
+ * Invoke the Scorer against the previous call via API -> see in traces AND in feedback
+ * [x] Should be mostly straightforward (the Scorer API itself is a bit wonky)
+ * Important Proof of system:
+ * [x] create the same dummy model locally & invoke -> notice no version change
+ * [x] Run locally against the call -> notice that there are no extra objects
+ * [ ] Should de-ref inputs if they contain refs
+* [ ] Refactor the entire "base model" system to conform to this new way of doing things (leaf models)
+ * [ ] Might get hairy with nested refs - consider implications
+* [ ] Figure out how to refactor scorers that use LLMs
+ * [ ] a new process with correct env setup (from secret fetcher?)
+ * [ ] scorers should have a client-spec, not a specific client
+ * [ ] How to model a scorer's stub (input, output, context, reference(s), etc...)
+ * [ ] How to handle output types from scorers (boolean, number, reason, etc...)
+ * [ ] Investigate why the tests are running so slowly
+* [ ] Consider a rule: objectId must be different than the name of the class when creating these objects
+
+
+---- Decomposition PRs ----
+1. Change/Add the set_object_class instead of base_object_class
+2. Add new methods to the server
\ No newline at end of file
diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py
index 77ef6198ecf..02edeb4b909 100644
--- a/weave/trace_server/trace_server_interface.py
+++ b/weave/trace_server/trace_server_interface.py
@@ -191,7 +191,7 @@ class ObjSchemaForInsert(BaseModel):
project_id: str
object_id: str
val: Any
- set_base_object_class: Optional[str] = None
+ set_leaf_object_class: Optional[str] = None
class TableSchemaForInsert(BaseModel):
@@ -867,6 +867,31 @@ class ActionsExecuteBatchRes(BaseModel):
pass
+class CallMethodReq(BaseModel):  # Request model taken by TraceServerInterface.call_method.
+ project_id: str
+ object_ref: str  # NOTE(review): presumably a ref to the object whose method is invoked — confirm.
+ method_name: str  # NOTE(review): presumably the name of the method to call on that object — confirm.
+ args: dict[str, Any]  # NOTE(review): presumably arguments for the method, keyed by parameter name — confirm.
+ wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)  # Optional; defaults to None.
+
+
+class CallMethodRes(BaseModel):  # Response model returned by TraceServerInterface.call_method.
+ call_id: str  # NOTE(review): presumably the id of the call recorded for this invocation — confirm.
+ output: Any  # NOTE(review): presumably the invoked method's return value — confirm.
+
+
+class ScoreCallReq(BaseModel):  # Request model taken by TraceServerInterface.score_call.
+ project_id: str
+ call_ref: str  # NOTE(review): presumably a ref to the call being scored — confirm.
+ scorer_ref: str  # NOTE(review): presumably a ref to the scorer object to apply — confirm.
+ wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)  # Optional; defaults to None.
+
+
+class ScoreCallRes(BaseModel):  # Response model returned by TraceServerInterface.score_call.
+ feedback_id: str  # NOTE(review): presumably the id of the feedback record holding the score — confirm.
+ score_call: CallSchema  # NOTE(review): presumably the call produced by running the scorer — confirm.
+
+
class TraceServerInterface(Protocol):
def ensure_project_exists(
self, entity: str, project: str
@@ -910,6 +935,10 @@ def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ...
def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ...
def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: ...
+ # Execute API
+ def call_method(self, req: CallMethodReq) -> CallMethodRes: ...
+ def score_call(self, req: ScoreCallReq) -> ScoreCallRes: ...
+
# Action API
def actions_execute_batch(
self, req: ActionsExecuteBatchReq
diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py
index b749af60d24..f5446bbfb71 100644
--- a/weave/trace_server_bindings/remote_http_trace_server.py
+++ b/weave/trace_server_bindings/remote_http_trace_server.py
@@ -569,6 +569,16 @@ def completions_create(
tsi.CompletionsCreateRes,
)
+ def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:  # Delegate to _generic_request with the "/execute/method" path and the CallMethodReq/CallMethodRes models.
+ return self._generic_request(
+ "/execute/method", req, tsi.CallMethodReq, tsi.CallMethodRes
+ )
+
+ def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:  # Delegate to _generic_request with the "/execute/score_call" path and the ScoreCallReq/ScoreCallRes models.
+ return self._generic_request(
+ "/execute/score_call", req, tsi.ScoreCallReq, tsi.ScoreCallRes
+ )
+
__docspec__ = [
RemoteHTTPTraceServer,