From b1e14b6c0bd174eb79bcbb1d84287f179820d759 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Wed, 30 Oct 2024 11:03:01 -0700 Subject: [PATCH] chore(weave): Revert Object/Feedback schema validation from hackweek (#2824) * init * init * init * init * init * init --- tests/trace/test_actions.py | 158 ----------------- .../clickhouse_trace_server_batched.py | 166 ++---------------- ...ternal_to_internal_trace_server_adapter.py | 6 - .../base_models/action_base_models.py | 22 --- .../base_models/base_model_registry.py | 20 --- .../feedback_base_model_registry.py | 11 -- weave/trace_server/sqlite_trace_server.py | 7 - weave/trace_server/trace_server_interface.py | 16 -- .../remote_http_trace_server.py | 12 +- 9 files changed, 16 insertions(+), 402 deletions(-) delete mode 100644 tests/trace/test_actions.py delete mode 100644 weave/trace_server/interface/base_models/action_base_models.py delete mode 100644 weave/trace_server/interface/base_models/base_model_registry.py delete mode 100644 weave/trace_server/interface/base_models/feedback_base_model_registry.py diff --git a/tests/trace/test_actions.py b/tests/trace/test_actions.py deleted file mode 100644 index 7d85151e280..00000000000 --- a/tests/trace/test_actions.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -from typing import Any - -import pytest -from pydantic import BaseModel - -import weave -from weave.trace.refs import ObjectRef -from weave.trace.weave_client import WeaveClient -from weave.trace_server.interface.base_models.action_base_models import ( - ConfiguredAction, - _BuiltinAction, -) -from weave.trace_server.sqlite_trace_server import SqliteTraceServer -from weave.trace_server.trace_server_interface import ( - ExecuteBatchActionReq, - FeedbackCreateReq, - ObjCreateReq, - ObjQueryReq, -) - - -def test_action_execute_workflow(client: WeaveClient): - is_sqlite = isinstance(client.server._internal_trace_server, SqliteTraceServer) - if is_sqlite: - # dont run this test for sqlite - return - - # part 1: create the action - class ExampleResponse(BaseModel): - score: int - reason: str - - digest = client.server.obj_create( - ObjCreateReq.model_validate( - { - "obj": { - "project_id": client._project_id(), - "object_id": "test_object", - "base_object_class": "ConfiguredAction", - "val": ConfiguredAction( - name="test_action", - action=_BuiltinAction(name="llm_judge"), - config={ - "system_prompt": "you are a judge", - "model": "gpt-4o-mini", - "response_format_schema": ExampleResponse.model_json_schema(), - }, - ).model_dump(), - } - } - ) - ).digest - - configured_actions = client.server.objs_query( - ObjQueryReq.model_validate( - { - "project_id": client._project_id(), - "filter": {"base_object_classes": ["ConfiguredAction"]}, - } - ) - ) - - assert len(configured_actions.objs) == 1 - assert configured_actions.objs[0].digest == digest - action_ref_uri = ObjectRef( - entity=client.entity, - project=client.project, - name="test_object", - _digest=digest, - ).uri() - - # part 2: manually create feedback - @weave.op - def example_op(input: str) -> str: - return input[::-1] - - _, call1 = example_op.call("hello") - with pytest.raises(Exception): - client.server.feedback_create( - FeedbackCreateReq.model_validate( - { - "project_id": client._project_id(), - "weave_ref": call1.ref.uri(), - "feedback_type": "ActionScore", - "payload": { - "output": { - "score": 1, - "reason": "because", - } - }, - } - ) - ) - - res = client.server.feedback_create( - FeedbackCreateReq.model_validate( - { - "project_id": client._project_id(), - "weave_ref": call1.ref.uri(), - "feedback_type": "ActionScore", - "payload": { - "configured_action_ref": action_ref_uri, - "output": { - "score": 1, - "reason": "because", - }, - }, - } - ) - ) - - feedbacks = list(call1.feedback) - assert len(feedbacks) == 1 - assert feedbacks[0].payload == { - "configured_action_ref": action_ref_uri, - "output": { - "score": 1, - "reason": "because", - }, - } - - # Step 3: execute the action - if os.environ.get("CI"): - # skip this test in CI for now - return - - _, call2 = example_op.call("hello") - - res = client.server.execute_batch_action( - ExecuteBatchActionReq.model_validate( - { - "project_id": client._project_id(), - "call_ids": [call2.id], - "configured_action_ref": action_ref_uri, - } - ) - ) - - feedbacks = list(call2.feedback) - assert len(feedbacks) == 1 - assert feedbacks[0].payload == { - "configured_action_ref": action_ref_uri, - "output": { - "score": MatchesAnyNumber(), - "reason": MatchesAnyStr(), - }, - } - - -class MatchesAnyStr: - def __eq__(self, other: Any) -> bool: - return isinstance(other, str) - - -class MatchesAnyNumber(BaseModel): - def __eq__(self, other: Any) -> bool: - return isinstance(other, (int, float)) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 9a1f1a1303f..bcdb4413493 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -30,7 +30,6 @@ import threading from collections import defaultdict from contextlib import contextmanager -from functools import partial from typing import ( Any, Dict, @@ -80,19 +79,6 @@ validate_feedback_purge_req, ) from weave.trace_server.ids import generate_id -from weave.trace_server.interface.base_models.action_base_models import ( - LLM_JUDGE_ACTION_NAME, - ConfiguredAction, -) -from weave.trace_server.interface.base_models.base_model_registry import ( - base_model_dump, - base_model_name, - base_models, -) -from weave.trace_server.interface.base_models.feedback_base_model_registry import ( - ActionScore, - feedback_base_models, -) from weave.trace_server.llm_completion import lite_llm_completion from weave.trace_server.model_providers.model_providers import ( MODEL_PROVIDERS_FILE, @@ -139,8 +125,6 @@ MAX_DELETE_CALLS_COUNT = 100 MAX_CALLS_STREAM_BATCH_SIZE = 500 -WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID = "WEAVE_ACTION_EXECUTOR" - class NotFoundError(Exception): pass @@ -592,28 +576,16 @@ def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: return tsi.OpQueryRes(op_objs=objs) def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: - req_obj = req.obj - dict_val = req_obj.val - - if req.obj.base_object_class: - for base_model in base_models: - if base_model_name(base_model) == req.obj.base_object_class: - # 1. Validate the object against the base model & re-dump to a dict - dict_val = base_model_dump(base_model.model_validate(dict_val)) - break - else: - raise ValueError( - f"Unknown base object class: {req.obj.base_object_class}" - ) - - json_val = json.dumps(dict_val) + json_val = json.dumps(req.obj.val) digest = str_digest(json_val) + + req_obj = req.obj ch_obj = ObjCHInsertable( project_id=req_obj.project_id, object_id=req_obj.object_id, - kind=get_kind(dict_val), - base_object_class=get_base_object_class(dict_val), - refs=extract_refs_from_values(dict_val), + kind=get_kind(req.obj.val), + base_object_class=get_base_object_class(req.obj.val), + refs=extract_refs_from_values(req.obj.val), val_dump=json_val, digest=digest, ) @@ -1356,17 +1328,8 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: assert_non_null_wb_user_id(req) validate_feedback_create_req(req) - feedback_type = req.feedback_type - res_payload = req.payload - - for feedback_base_model in feedback_base_models: - if base_model_name(feedback_base_model) == feedback_type: - res_payload = base_model_dump( - feedback_base_model.model_validate(res_payload) - ) - break - # Augment emoji with alias. + res_payload = {} if req.feedback_type == "wandb.reaction.1": em = req.payload["emoji"] if emoji.emoji_count(em) != 1: @@ -1432,97 +1395,6 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: self.ch_client.query(prepared.sql, prepared.parameters) return tsi.FeedbackPurgeRes() - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - # WARNING: THIS IS NOT GOING TO WORK IN PRODUCTION - # UNTIL WE HAVE THE API KEY PIECE IN PLACE - configured_action_ref = req.configured_action_ref - - action_dict_res = self.refs_read_batch( - tsi.RefsReadBatchReq(refs=[configured_action_ref]) - ) - - action_dict = action_dict_res.vals[0] - action = ConfiguredAction.model_validate(action_dict) - - if action.action.action_type != "builtin": - raise InvalidRequest( - "Only builtin actions are supported for batch execution" - ) - - if action.action.name != LLM_JUDGE_ACTION_NAME: - raise InvalidRequest("Only llm_judge is supported for batch execution") - - # Step 1: Get all the calls in the batch - calls = self.calls_query_stream( - tsi.CallsQueryReq( - project_id=req.project_id, - filter=tsi.CallsFilter( - call_ids=req.call_ids, - ), - ) - ) - - # Normally we would dispatch here, but just hard coding for now - # We should do some validation here - config = action.config - model = config["model"] - - if model not in ["gpt-4o-mini", "gpt-4o"]: - raise InvalidRequest("Only gpt-4o-mini and gpt-4o are supported") - - system_prompt = config["system_prompt"] - response_format_schema = config["response_format_schema"] - response_format = { - "type": "json_schema", - "json_schema": { - "name": "response_format", - "schema": response_format_schema, - }, - } - - # mapping = mapping.input_mapping - - # Step 2: For Each call, execute the action: (this needs a lot of safety checks) - for call in calls: - args = { - "inputs": call.inputs, - "output": call.output, - } - from openai import OpenAI - - client = OpenAI() - # Silly hack to get around issue in tests: - create = client.chat.completions.create - if hasattr(create, "resolve_fn"): - create = partial(create.resolve_fn, self=client.chat.completions) - completion = create( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": json.dumps(args)}, - ], - response_format=response_format, - ) - self.feedback_create( - tsi.FeedbackCreateReq( - project_id=req.project_id, - weave_ref=ri.InternalCallRef( - project_id=req.project_id, - id=call.id, - ).uri(), - feedback_type=base_model_name(ActionScore), - wb_user_id=WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID, # - THIS IS NOT GOOD! - payload=ActionScore( - configured_action_ref=configured_action_ref, - output=json.loads(completion.choices[0].message.content), - ).model_dump(), - ) - ) - - return tsi.ExecuteBatchActionRes() - def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: @@ -2144,7 +2016,7 @@ def _process_parameters( def get_type(val: Any) -> str: - if val is None: + if val == None: return "none" elif isinstance(val, dict): if "_type" in val: @@ -2165,24 +2037,16 @@ def get_kind(val: Any) -> str: def get_base_object_class(val: Any) -> Optional[str]: - """ - Get the base object class of a value using: - 1. The last base class that is a subclass of BaseModel and not Object - 2. The _class_name attribute if it exists - 3. None if no base class is found - """ if isinstance(val, dict): if "_bases" in val: if isinstance(val["_bases"], list): - bases = val["_bases"] - if len(bases) > 0 and bases[-1] == "BaseModel": - bases = bases[:-1] - if len(bases) > 0 and bases[-1] == "Object": - bases = bases[:-1] - if len(bases) > 0: - return bases[-1] - elif "_class_name" in val: - return val["_class_name"] + if len(val["_bases"]) >= 2: + if val["_bases"][-1] == "BaseModel": + if val["_bases"][-2] == "Object": + if len(val["_bases"]) > 2: + return val["_bases"][-3] + elif "_class_name" in val: + return val["_class_name"] return None diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index ffc229a9270..7e085b8f75e 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -346,12 +346,6 @@ def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: cost["pricing_level_id"] = original_project_id return res - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - req.project_id = self._idc.ext_to_int_project_id(req.project_id) - return self._ref_apply(self._internal_trace_server.execute_batch_action, req) - def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: diff --git a/weave/trace_server/interface/base_models/action_base_models.py b/weave/trace_server/interface/base_models/action_base_models.py deleted file mode 100644 index c3c91b9903d..00000000000 --- a/weave/trace_server/interface/base_models/action_base_models.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - -LLM_JUDGE_ACTION_NAME = "llm_judge" - - -class _BuiltinAction(BaseModel): - action_type: Literal["builtin"] = "builtin" - name: str - - -class ConfiguredAction(BaseModel): - name: str - action: _BuiltinAction - config: dict - - -class ActionDispatchFilter(BaseModel): - op_name: str - sample_rate: float - configured_action_ref: str diff --git a/weave/trace_server/interface/base_models/base_model_registry.py b/weave/trace_server/interface/base_models/base_model_registry.py deleted file mode 100644 index ea582f119f6..00000000000 --- a/weave/trace_server/interface/base_models/base_model_registry.py +++ /dev/null @@ -1,20 +0,0 @@ -from pydantic import BaseModel - -from weave.trace_server.interface.base_models.action_base_models import ( - ActionDispatchFilter, - ConfiguredAction, -) - - -def base_model_name(base_model_class: type[BaseModel]) -> str: - return base_model_class.__name__ - - -def base_model_dump(base_model_obj: BaseModel) -> dict: - d = base_model_obj.model_dump() - d["_class_name"] = base_model_name(base_model_obj.__class__) - d["_bases"] = [base_model_name(b) for b in base_model_obj.__class__.mro()[1:-1]] - return d - - -base_models: list[type[BaseModel]] = [ConfiguredAction, ActionDispatchFilter] diff --git a/weave/trace_server/interface/base_models/feedback_base_model_registry.py b/weave/trace_server/interface/base_models/feedback_base_model_registry.py deleted file mode 100644 index 16f2033fec3..00000000000 --- a/weave/trace_server/interface/base_models/feedback_base_model_registry.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any - -from pydantic import BaseModel - - -class ActionScore(BaseModel): - configured_action_ref: str - output: Any - - -feedback_base_models: list[type[BaseModel]] = [ActionScore] diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 6f5b5648f63..93a4f510090 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1081,13 +1081,6 @@ def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: print("COST PURGE is not implemented for local sqlite", req) return tsi.CostPurgeRes() - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - raise NotImplementedError( - "EXECUTE BATCH ACTION is not implemented for local sqlite" - ) - def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 4f064759505..abdfeae38ac 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -189,7 +189,6 @@ class ObjSchema(BaseModel): class ObjSchemaForInsert(BaseModel): project_id: str object_id: str - base_object_class: Optional[str] = None val: Any @@ -837,16 +836,6 @@ class CostPurgeRes(BaseModel): pass -class ExecuteBatchActionReq(BaseModel): - project_id: str - call_ids: list[str] - configured_action_ref: str - - -class ExecuteBatchActionRes(BaseModel): - pass - - class TraceServerInterface(Protocol): def ensure_project_exists( self, entity: str, project: str @@ -888,10 +877,5 @@ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ... def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ... def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ... - # Action API - def execute_batch_action( - self, req: ExecuteBatchActionReq - ) -> ExecuteBatchActionRes: ... - # Execute LLM API def completions_create(self, req: CompletionsCreateReq) -> CompletionsCreateRes: ... diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index af13a7c856f..34b906a560c 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -265,7 +265,7 @@ def call_start( req_as_obj = tsi.CallStartReq.model_validate(req) else: req_as_obj = req - if req_as_obj.start.id is None or req_as_obj.start.trace_id is None: + if req_as_obj.start.id == None or req_as_obj.start.trace_id == None: raise ValueError( "CallStartReq must have id and trace_id when batching." ) @@ -549,16 +549,6 @@ def cost_purge( "/cost/purge", req, tsi.CostPurgeReq, tsi.CostPurgeRes ) - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - return self._generic_request( - "/execute/batch_action", - req, - tsi.ExecuteBatchActionReq, - tsi.ExecuteBatchActionRes, - ) - def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: