Skip to content

Commit

Permalink
Merge branch 'master' into griffin/objs-delete
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Oct 30, 2024
2 parents d12d667 + b1e14b6 commit 667ccf5
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 402 deletions.
158 changes: 0 additions & 158 deletions tests/trace/test_actions.py

This file was deleted.

166 changes: 15 additions & 151 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import threading
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -87,19 +86,6 @@
validate_feedback_purge_req,
)
from weave.trace_server.ids import generate_id
from weave.trace_server.interface.base_models.action_base_models import (
LLM_JUDGE_ACTION_NAME,
ConfiguredAction,
)
from weave.trace_server.interface.base_models.base_model_registry import (
base_model_dump,
base_model_name,
base_models,
)
from weave.trace_server.interface.base_models.feedback_base_model_registry import (
ActionScore,
feedback_base_models,
)
from weave.trace_server.llm_completion import lite_llm_completion
from weave.trace_server.model_providers.model_providers import (
MODEL_PROVIDERS_FILE,
Expand Down Expand Up @@ -146,8 +132,6 @@
MAX_DELETE_CALLS_COUNT = 100
MAX_CALLS_STREAM_BATCH_SIZE = 500

WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID = "WEAVE_ACTION_EXECUTOR"


CallCHInsertable = Union[
CallStartCHInsertable,
Expand Down Expand Up @@ -595,28 +579,16 @@ def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes:
return tsi.OpQueryRes(op_objs=objs)

def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
req_obj = req.obj
dict_val = req_obj.val

if req.obj.base_object_class:
for base_model in base_models:
if base_model_name(base_model) == req.obj.base_object_class:
# 1. Validate the object against the base model & re-dump to a dict
dict_val = base_model_dump(base_model.model_validate(dict_val))
break
else:
raise ValueError(
f"Unknown base object class: {req.obj.base_object_class}"
)

json_val = json.dumps(dict_val)
json_val = json.dumps(req.obj.val)
digest = str_digest(json_val)

req_obj = req.obj
ch_obj = ObjCHInsertable(
project_id=req_obj.project_id,
object_id=req_obj.object_id,
kind=get_kind(dict_val),
base_object_class=get_base_object_class(dict_val),
refs=extract_refs_from_values(dict_val),
kind=get_kind(req.obj.val),
base_object_class=get_base_object_class(req.obj.val),
refs=extract_refs_from_values(req.obj.val),
val_dump=json_val,
digest=digest,
)
Expand Down Expand Up @@ -1441,17 +1413,8 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
assert_non_null_wb_user_id(req)
validate_feedback_create_req(req)

feedback_type = req.feedback_type
res_payload = req.payload

for feedback_base_model in feedback_base_models:
if base_model_name(feedback_base_model) == feedback_type:
res_payload = base_model_dump(
feedback_base_model.model_validate(res_payload)
)
break

# Augment emoji with alias.
res_payload = {}
if req.feedback_type == "wandb.reaction.1":
em = req.payload["emoji"]
if emoji.emoji_count(em) != 1:
Expand Down Expand Up @@ -1517,97 +1480,6 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes:
self.ch_client.query(prepared.sql, prepared.parameters)
return tsi.FeedbackPurgeRes()

def execute_batch_action(
    self, req: tsi.ExecuteBatchActionReq
) -> tsi.ExecuteBatchActionRes:
    """Run a configured LLM-judge action over a batch of calls.

    Resolves the configured action from its ref, validates that it is the
    builtin llm_judge action, streams the requested calls, scores each call
    with an OpenAI chat completion constrained to a JSON schema, and records
    each score on the call as ActionScore feedback.

    Args:
        req: Batch-execution request carrying the project id, the ref of the
            configured action, and the ids of the calls to score.

    Returns:
        An empty ExecuteBatchActionRes on success.

    Raises:
        InvalidRequest: If the action is not a builtin llm_judge action, or
            the configured model is not one of the supported OpenAI models.
    """
    # WARNING: THIS IS NOT GOING TO WORK IN PRODUCTION
    # UNTIL WE HAVE THE API KEY PIECE IN PLACE
    configured_action_ref = req.configured_action_ref

    # Resolve the serialized action definition from its ref.
    action_dict_res = self.refs_read_batch(
        tsi.RefsReadBatchReq(refs=[configured_action_ref])
    )

    action_dict = action_dict_res.vals[0]
    # Validate the raw dict into the typed ConfiguredAction model.
    action = ConfiguredAction.model_validate(action_dict)

    if action.action.action_type != "builtin":
        raise InvalidRequest(
            "Only builtin actions are supported for batch execution"
        )

    if action.action.name != LLM_JUDGE_ACTION_NAME:
        raise InvalidRequest("Only llm_judge is supported for batch execution")

    # Step 1: Get all the calls in the batch
    calls = self.calls_query_stream(
        tsi.CallsQueryReq(
            project_id=req.project_id,
            filter=tsi.CallsFilter(
                call_ids=req.call_ids,
            ),
        )
    )

    # Normally we would dispatch here, but just hard coding for now
    # We should do some validation here
    config = action.config
    model = config["model"]

    if model not in ["gpt-4o-mini", "gpt-4o"]:
        raise InvalidRequest("Only gpt-4o-mini and gpt-4o are supported")

    system_prompt = config["system_prompt"]
    response_format_schema = config["response_format_schema"]
    # Constrain the judge's reply to JSON matching the configured schema.
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "response_format",
            "schema": response_format_schema,
        },
    }

    # mapping = mapping.input_mapping

    # Step 2: For Each call, execute the action: (this needs a lot of safety checks)
    for call in calls:
        # The judge sees only the call's inputs and output, serialized as JSON.
        args = {
            "inputs": call.inputs,
            "output": call.output,
        }
        from openai import OpenAI

        client = OpenAI()
        # Silly hack to get around issue in tests:
        create = client.chat.completions.create
        if hasattr(create, "resolve_fn"):
            create = partial(create.resolve_fn, self=client.chat.completions)
        completion = create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": json.dumps(args)},
            ],
            response_format=response_format,
        )
        # Record the judge's parsed JSON verdict as ActionScore feedback
        # attached to the scored call.
        self.feedback_create(
            tsi.FeedbackCreateReq(
                project_id=req.project_id,
                weave_ref=ri.InternalCallRef(
                    project_id=req.project_id,
                    id=call.id,
                ).uri(),
                feedback_type=base_model_name(ActionScore),
                wb_user_id=WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID,  # - THIS IS NOT GOOD!
                payload=ActionScore(
                    configured_action_ref=configured_action_ref,
                    output=json.loads(completion.choices[0].message.content),
                ).model_dump(),
            )
        )

    return tsi.ExecuteBatchActionRes()

def completions_create(
self, req: tsi.CompletionsCreateReq
) -> tsi.CompletionsCreateRes:
Expand Down Expand Up @@ -2244,7 +2116,7 @@ def _process_parameters(


def get_type(val: Any) -> str:
if val is None:
if val == None:
return "none"
elif isinstance(val, dict):
if "_type" in val:
Expand All @@ -2265,24 +2137,16 @@ def get_kind(val: Any) -> str:


def get_base_object_class(val: Any) -> Optional[str]:
"""
Get the base object class of a value using:
1. The last base class that is a subclass of BaseModel and not Object
2. The _class_name attribute if it exists
3. None if no base class is found
"""
if isinstance(val, dict):
if "_bases" in val:
if isinstance(val["_bases"], list):
bases = val["_bases"]
if len(bases) > 0 and bases[-1] == "BaseModel":
bases = bases[:-1]
if len(bases) > 0 and bases[-1] == "Object":
bases = bases[:-1]
if len(bases) > 0:
return bases[-1]
elif "_class_name" in val:
return val["_class_name"]
if len(val["_bases"]) >= 2:
if val["_bases"][-1] == "BaseModel":
if val["_bases"][-2] == "Object":
if len(val["_bases"]) > 2:
return val["_bases"][-3]
elif "_class_name" in val:
return val["_class_name"]
return None


Expand Down
Loading

0 comments on commit 667ccf5

Please sign in to comment.