Skip to content

Commit

Permalink
finished scorer basics
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Dec 12, 2024
1 parent d64802a commit 9612a9d
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 11 deletions.
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
req.wb_user_id = self._user_id
return super().call_method(req)

def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
req.wb_user_id = self._user_id
return super().score_call(req)


# https://docs.pytest.org/en/7.1.x/example/simple.html#pytest-current-test-environment-variable
def get_test_name():
Expand Down
23 changes: 17 additions & 6 deletions tests/trace/builtin_objects/test_builtin_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# 5. Remote Create, Remote Direct Score
import weave
from weave.builtin_objects.scorers.LLMJudgeScorer import LLMJudgeScorer
from weave.trace.weave_client import Call, WeaveClient
from weave.trace.weave_client import ApplyScorerResult, Call, WeaveClient
from weave.trace_server import trace_server_interface as tsi

scorer_args = {
Expand Down Expand Up @@ -63,12 +63,23 @@ def simple_op(question: str) -> str:
return res, call


def assert_expected_outcome(target_call: Call, scorer_res_as_dict: dict):
assert scorer_res_as_dict["score_call"].output == expected_score
def assert_expected_outcome(
target_call: Call, scorer_res: ApplyScorerResult | tsi.ScoreCallRes
):
scorer_output = None
feedback_id = None
if isinstance(scorer_res, tsi.ScoreCallRes):
scorer_output = scorer_res.score_call.output
feedback_id = scorer_res.feedback_id
else:
scorer_output = scorer_res["score_call"].output
feedback_id = scorer_res["feedback_id"]

assert scorer_output == expected_score
feedbacks = list(target_call.feedback)
assert len(feedbacks) == 1
assert feedbacks[0].payload["output"] == expected_score
assert feedbacks[0].id == scorer_res_as_dict["feedback_id"]
assert feedbacks[0].id == feedback_id


def do_remote_score(
Expand Down Expand Up @@ -120,7 +131,7 @@ def test_scorer_local_create_remote_use(client: WeaveClient):
res, call = make_simple_call()
publish_ref = weave.publish(scorer)
remote_score_res = do_remote_score(client, call, publish_ref)
assert_expected_outcome(call, remote_score_res.model_dump())
assert_expected_outcome(call, remote_score_res)


def test_scorer_remote_create_local_use(client: WeaveClient):
Expand All @@ -136,4 +147,4 @@ def test_scorer_remote_create_remote_use(client: WeaveClient):
obj_ref = make_remote_scorer(client)
res, call = make_simple_call()
remote_score_res = do_remote_score(client, call, obj_ref)
assert_expected_outcome(call, remote_score_res.model_dump())
assert_expected_outcome(call, remote_score_res)
13 changes: 13 additions & 0 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,19 @@ def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
)
return tsi.CallMethodRes.model_validate(res)

def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
from weave.trace_server.server_side_object_saver import RunAsUser

runner = RunAsUser(ch_server_dump=self.model_dump())
res = runner.run_score_call(req)

return tsi.ScoreCallRes(
feedback_id=res["feedback_id"],
score_call=self.call_read(
tsi.CallReadReq(project_id=req.project_id, id=res["scorer_call_id"])
).call,
)

# Private Methods
@property
def ch_client(self) -> CHClient:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,11 @@ def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
raise ValueError("wb_user_id cannot be None")
req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id)
return self._ref_apply(self._internal_trace_server.call_method, req)

def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
req.project_id = self._idc.ext_to_int_project_id(req.project_id)
original_user_id = req.wb_user_id
if original_user_id is None:
raise ValueError("wb_user_id cannot be None")
req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id)
return self._ref_apply(self._internal_trace_server.score_call, req)
88 changes: 87 additions & 1 deletion weave/trace_server/server_side_object_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from weave.trace.weave_init import InitializedClient
from weave.trace_server import external_to_internal_trace_server_adapter
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.refs_internal import InternalObjectRef, parse_internal_uri
from weave.trace_server.refs_internal import (
InternalCallRef,
InternalObjectRef,
parse_internal_uri,
)


class RunAsUser:
Expand Down Expand Up @@ -233,6 +237,84 @@ def _call_method(
except Exception as e:
result_queue.put(("error", str(e))) # Put any errors in the queue

def run_score_call(self, req: tsi.ScoreCallReq) -> str:
result_queue = multiprocessing.Queue()

process = multiprocessing.Process(
target=self._score_call,
args=(req, result_queue),
)

process.start()
status, result = result_queue.get()
process.join()

if status == "error":
raise Exception(f"Process execution failed: {result}")

return result

def _score_call(
self,
req: tsi.ScoreCallReq,
result_queue: multiprocessing.Queue,
) -> None:
try:
from weave.trace.weave_client import Call
from weave.trace_server.clickhouse_trace_server_batched import (
ClickHouseTraceServer,
)

client = WeaveClient(
"_SERVER_",
req.project_id,
UserInjectingExternalTraceServer(
ClickHouseTraceServer(**self.ch_server_dump),
id_converter=IdConverter(),
user_id=req.wb_user_id,
),
False,
)

ic = InitializedClient(client)
autopatch.autopatch()

target_call_ref = parse_internal_uri(req.call_ref)
if not isinstance(target_call_ref, InternalCallRef):
raise ValueError("Invalid call reference")
target_call = client.get_call(target_call_ref.id)._val
if not isinstance(target_call, Call):
raise ValueError("Invalid call reference")
scorer_ref = parse_internal_uri(req.scorer_ref)
if not isinstance(scorer_ref, InternalObjectRef):
raise ValueError("Invalid scorer reference")
scorer = weave.ref(
ObjectRef(
entity="_SERVER_",
project=scorer_ref.project_id,
name=scorer_ref.name,
_digest=scorer_ref.version,
).uri()
).get()
if not isinstance(scorer, weave.Scorer):
raise ValueError("Invalid scorer reference")
apply_scorer_res = target_call._apply_scorer(scorer)

autopatch.reset_autopatch()
client._flush()
ic.reset()
result_queue.put(
(
"success",
{
"feedback_id": apply_scorer_res["feedback_id"],
"scorer_call_id": apply_scorer_res["score_call"].id,
},
)
) # Put the result in the queue
except Exception as e:
result_queue.put(("error", str(e))) # Put any errors in the queue


class IdConverter(external_to_internal_trace_server_adapter.IdConverter):
def ext_to_int_project_id(self, project_id: str) -> str:
Expand Down Expand Up @@ -296,3 +378,7 @@ def actions_execute_batch(
def call_method(self, req: tsi.CallMethodReq) -> tsi.CallMethodRes:
req.wb_user_id = self._user_id
return super().call_method(req)

def score_call(self, req: tsi.ScoreCallReq) -> tsi.ScoreCallRes:
req.wb_user_id = self._user_id
return super().score_call(req)
8 changes: 4 additions & 4 deletions weave/trace_server/todo.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
* Server seems to get stuck when it crashes
* Accidentally checked in parallelization disabling
* Create a Dummy Scorer via API -> See in object explorer
* Should be pretty straight forward at this point
* [x] Should be pretty straight forward at this point
* Invoke the Scorer against the previous call via API -> see in traces AND in feedback
* Should be mostly straight forward (the Scorer API itself is a bit wonky)
* [x] Should be mostly straight forward (the Scorer API itself is a bit wonky)
* Important Proof of system:
* [x] create the same dummy model locally & invoke -> notice no version change
* Run locally against the call -> notice that there are no extra objects
* [x] Run locally against the call -> notice that there are no extra objects
* [ ]Should de-ref inputs if they contain refs
* [ ] Refactor the entire "base model" system to conform to this new way of doing things (leaf models)
* [ ] Might get hairy with nested refs - consider implications
Expand All @@ -26,4 +26,4 @@
* [ ] scorers should have a client-spec, not a specific client
* [ ] How to model a scorers's stub (input, output, context, reference(s), etc...)
* [ ] How to handle output types from scorers (boolean, number, reason, etc...)

* [ ]Investigate why the tests are running so slowly
13 changes: 13 additions & 0 deletions weave/trace_server/trace_server_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,18 @@ class CallMethodRes(BaseModel):
output: Any


class ScoreCallReq(BaseModel):
project_id: str
call_ref: str
scorer_ref: str
wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)


class ScoreCallRes(BaseModel):
feedback_id: str
score_call: CallSchema


class TraceServerInterface(Protocol):
def ensure_project_exists(
self, entity: str, project: str
Expand Down Expand Up @@ -925,6 +937,7 @@ def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: ...

# Execute API
def call_method(self, req: CallMethodReq) -> CallMethodRes: ...
def score_call(self, req: ScoreCallReq) -> ScoreCallRes: ...

# Action API
def actions_execute_batch(
Expand Down

0 comments on commit 9612a9d

Please sign in to comment.