diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 9a1f1a1303f..fbef064d42b 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -30,7 +30,6 @@ import threading from collections import defaultdict from contextlib import contextmanager -from functools import partial from typing import ( Any, Dict, @@ -80,19 +79,6 @@ validate_feedback_purge_req, ) from weave.trace_server.ids import generate_id -from weave.trace_server.interface.base_models.action_base_models import ( - LLM_JUDGE_ACTION_NAME, - ConfiguredAction, -) -from weave.trace_server.interface.base_models.base_model_registry import ( - base_model_dump, - base_model_name, - base_models, -) -from weave.trace_server.interface.base_models.feedback_base_model_registry import ( - ActionScore, - feedback_base_models, -) from weave.trace_server.llm_completion import lite_llm_completion from weave.trace_server.model_providers.model_providers import ( MODEL_PROVIDERS_FILE, @@ -139,8 +125,6 @@ MAX_DELETE_CALLS_COUNT = 100 MAX_CALLS_STREAM_BATCH_SIZE = 500 -WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID = "WEAVE_ACTION_EXECUTOR" - class NotFoundError(Exception): pass @@ -592,28 +576,17 @@ def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: return tsi.OpQueryRes(op_objs=objs) def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: + json_val = json.dumps(req.obj.val) + digest = str_digest(json_val) + req_obj = req.obj - dict_val = req_obj.val - - if req.obj.base_object_class: - for base_model in base_models: - if base_model_name(base_model) == req.obj.base_object_class: - # 1. Validate the object against the base model & re-dump to a dict - dict_val = base_model_dump(base_model.model_validate(dict_val)) - break - else: - raise ValueError( - f"Unknown base object class: {req.obj.base_object_class}" - ) - json_val = json.dumps(dict_val) - digest = str_digest(json_val) ch_obj = ObjCHInsertable( project_id=req_obj.project_id, object_id=req_obj.object_id, - kind=get_kind(dict_val), - base_object_class=get_base_object_class(dict_val), - refs=extract_refs_from_values(dict_val), + kind=get_kind(req.obj.val), + base_object_class=get_base_object_class(req.obj.val), + refs=extract_refs_from_values(req.obj.val), val_dump=json_val, digest=digest, ) @@ -1356,16 +1329,8 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: assert_non_null_wb_user_id(req) validate_feedback_create_req(req) - feedback_type = req.feedback_type res_payload = req.payload - for feedback_base_model in feedback_base_models: - if base_model_name(feedback_base_model) == feedback_type: - res_payload = base_model_dump( - feedback_base_model.model_validate(res_payload) - ) - break - # Augment emoji with alias. if req.feedback_type == "wandb.reaction.1": em = req.payload["emoji"] @@ -1432,97 +1397,6 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: self.ch_client.query(prepared.sql, prepared.parameters) return tsi.FeedbackPurgeRes() - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - # WARNING: THIS IS NOT GOING TO WORK IN PRODUCTION - # UNTIL WE HAVE THE API KEY PIECE IN PLACE - configured_action_ref = req.configured_action_ref - - action_dict_res = self.refs_read_batch( - tsi.RefsReadBatchReq(refs=[configured_action_ref]) - ) - - action_dict = action_dict_res.vals[0] - action = ConfiguredAction.model_validate(action_dict) - - if action.action.action_type != "builtin": - raise InvalidRequest( - "Only builtin actions are supported for batch execution" - ) - - if action.action.name != LLM_JUDGE_ACTION_NAME: - raise InvalidRequest("Only llm_judge is supported for batch execution") - - # Step 1: Get all the calls in the batch - calls = self.calls_query_stream( - tsi.CallsQueryReq( - project_id=req.project_id, - filter=tsi.CallsFilter( - call_ids=req.call_ids, - ), - ) - ) - - # Normally we would dispatch here, but just hard coding for now - # We should do some validation here - config = action.config - model = config["model"] - - if model not in ["gpt-4o-mini", "gpt-4o"]: - raise InvalidRequest("Only gpt-4o-mini and gpt-4o are supported") - - system_prompt = config["system_prompt"] - response_format_schema = config["response_format_schema"] - response_format = { - "type": "json_schema", - "json_schema": { - "name": "response_format", - "schema": response_format_schema, - }, - } - - # mapping = mapping.input_mapping - - # Step 2: For Each call, execute the action: (this needs a lot of safety checks) - for call in calls: - args = { - "inputs": call.inputs, - "output": call.output, - } - from openai import OpenAI - - client = OpenAI() - # Silly hack to get around issue in tests: - create = client.chat.completions.create - if hasattr(create, "resolve_fn"): - create = partial(create.resolve_fn, self=client.chat.completions) - completion = create( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": json.dumps(args)}, - ], - response_format=response_format, - ) - self.feedback_create( - tsi.FeedbackCreateReq( - project_id=req.project_id, - weave_ref=ri.InternalCallRef( - project_id=req.project_id, - id=call.id, - ).uri(), - feedback_type=base_model_name(ActionScore), - wb_user_id=WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID, # - THIS IS NOT GOOD! - payload=ActionScore( - configured_action_ref=configured_action_ref, - output=json.loads(completion.choices[0].message.content), - ).model_dump(), - ) - ) - - return tsi.ExecuteBatchActionRes() - def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: @@ -2144,7 +2018,7 @@ def _process_parameters( def get_type(val: Any) -> str: - if val is None: + if val == None: return "none" elif isinstance(val, dict): if "_type" in val: @@ -2165,24 +2039,16 @@ def get_kind(val: Any) -> str: def get_base_object_class(val: Any) -> Optional[str]: - """ - Get the base object class of a value using: - 1. The last base class that is a subclass of BaseModel and not Object - 2. The _class_name attribute if it exists - 3. None if no base class is found - """ if isinstance(val, dict): if "_bases" in val: if isinstance(val["_bases"], list): - bases = val["_bases"] - if len(bases) > 0 and bases[-1] == "BaseModel": - bases = bases[:-1] - if len(bases) > 0 and bases[-1] == "Object": - bases = bases[:-1] - if len(bases) > 0: - return bases[-1] - elif "_class_name" in val: - return val["_class_name"] + if len(val["_bases"]) >= 2: + if val["_bases"][-1] == "BaseModel": + if val["_bases"][-2] == "Object": + if len(val["_bases"]) > 2: + return val["_bases"][-3] + elif "_class_name" in val: + return val["_class_name"] return None diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index ffc229a9270..7e085b8f75e 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -346,12 +346,6 @@ def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: cost["pricing_level_id"] = original_project_id return res - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - req.project_id = self._idc.ext_to_int_project_id(req.project_id) - return self._ref_apply(self._internal_trace_server.execute_batch_action, req) - def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: diff --git a/weave/trace_server/interface/base_models/action_base_models.py b/weave/trace_server/interface/base_models/action_base_models.py deleted file mode 100644 index c3c91b9903d..00000000000 --- a/weave/trace_server/interface/base_models/action_base_models.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - -LLM_JUDGE_ACTION_NAME = "llm_judge" - - -class _BuiltinAction(BaseModel): - action_type: Literal["builtin"] = "builtin" - name: str - - -class ConfiguredAction(BaseModel): - name: str - action: _BuiltinAction - config: dict - - -class ActionDispatchFilter(BaseModel): - op_name: str - sample_rate: float - configured_action_ref: str diff --git a/weave/trace_server/interface/base_models/base_model_registry.py b/weave/trace_server/interface/base_models/base_model_registry.py deleted file mode 100644 index ea582f119f6..00000000000 --- a/weave/trace_server/interface/base_models/base_model_registry.py +++ /dev/null @@ -1,20 +0,0 @@ -from pydantic import BaseModel - -from weave.trace_server.interface.base_models.action_base_models import ( - ActionDispatchFilter, - ConfiguredAction, -) - - -def base_model_name(base_model_class: type[BaseModel]) -> str: - return base_model_class.__name__ - - -def base_model_dump(base_model_obj: BaseModel) -> dict: - d = base_model_obj.model_dump() - d["_class_name"] = base_model_name(base_model_obj.__class__) - d["_bases"] = [base_model_name(b) for b in base_model_obj.__class__.mro()[1:-1]] - return d - - -base_models: list[type[BaseModel]] = [ConfiguredAction, ActionDispatchFilter] diff --git a/weave/trace_server/interface/base_models/feedback_base_model_registry.py b/weave/trace_server/interface/base_models/feedback_base_model_registry.py deleted file mode 100644 index 16f2033fec3..00000000000 --- a/weave/trace_server/interface/base_models/feedback_base_model_registry.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any - -from pydantic import BaseModel - - -class ActionScore(BaseModel): - configured_action_ref: str - output: Any - - -feedback_base_models: list[type[BaseModel]] = [ActionScore] diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 6f5b5648f63..93a4f510090 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1081,13 +1081,6 @@ def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: print("COST PURGE is not implemented for local sqlite", req) return tsi.CostPurgeRes() - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - raise NotImplementedError( - "EXECUTE BATCH ACTION is not implemented for local sqlite" - ) - def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 4f064759505..23be8b2c3b3 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -189,7 +189,6 @@ class ObjSchema(BaseModel): class ObjSchemaForInsert(BaseModel): project_id: str object_id: str - base_object_class: Optional[str] = None val: Any @@ -888,10 +887,6 @@ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ... def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ... def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ... - # Action API - def execute_batch_action( - self, req: ExecuteBatchActionReq - ) -> ExecuteBatchActionRes: ... # Execute LLM API def completions_create(self, req: CompletionsCreateReq) -> CompletionsCreateRes: ... diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index af13a7c856f..34b906a560c 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -265,7 +265,7 @@ def call_start( req_as_obj = tsi.CallStartReq.model_validate(req) else: req_as_obj = req - if req_as_obj.start.id is None or req_as_obj.start.trace_id is None: + if req_as_obj.start.id == None or req_as_obj.start.trace_id == None: raise ValueError( "CallStartReq must have id and trace_id when batching." ) @@ -549,16 +549,6 @@ def cost_purge( "/cost/purge", req, tsi.CostPurgeReq, tsi.CostPurgeRes ) - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - return self._generic_request( - "/execute/batch_action", - req, - tsi.ExecuteBatchActionReq, - tsi.ExecuteBatchActionRes, - ) - def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: