diff --git a/tests/trace/test_actions_e2e.py b/tests/trace/test_actions_e2e.py index 4f9e83d4822..548ec22ea94 100644 --- a/tests/trace/test_actions_e2e.py +++ b/tests/trace/test_actions_e2e.py @@ -97,6 +97,8 @@ def example_op(input: str) -> str: assert feedbacks[0].payload == { "runnable_ref": action_ref_uri, "value": {action_name: {digest: True}}, + "call_ref": None, + "trigger_ref": None, } # Step 3: test that we can in-place execute one action at a time. @@ -118,4 +120,6 @@ def example_op(input: str) -> str: assert feedbacks[0].payload == { "runnable_ref": action_ref_uri, "value": {action_name: {digest: False}}, + "call_ref": None, + "trigger_ref": None, } diff --git a/weave/actions_worker/tasks.py b/weave/actions_worker/tasks.py index 84bb9363fc4..47650a87857 100644 --- a/weave/actions_worker/tasks.py +++ b/weave/actions_worker/tasks.py @@ -1,7 +1,7 @@ import json import logging from functools import partial, wraps -from typing import Any, Callable, TypeVar +from typing import Any, Callable, Optional, TypeVar from weave.actions_worker.celery_app import app from weave.trace_server.action_executor import TaskCtx @@ -49,7 +49,10 @@ def ack_on_clickhouse(ctx: TaskCtx, succeeded: bool) -> None: def publish_results_as_feedback( - ctx: TaskCtx, result: Any, configured_action_ref: str + ctx: TaskCtx, + result: Any, + configured_action_ref: str, + trigger_ref: Optional[str] = None, ) -> None: project_id = ctx["project_id"] call_id = ctx["call_id"] @@ -68,6 +71,7 @@ def publish_results_as_feedback( payload=MachineScore( runnable_ref=configured_action_ref, value={action_name: {digest: result}}, + trigger_ref=trigger_ref, ).model_dump(), wb_user_id=WEAVE_ACTION_EXECUTOR_PACEHOLDER_ID, ) @@ -102,7 +106,7 @@ def resolve_call(ctx: TaskCtx) -> CallSchema: def action_task( func: Callable[[str, str, ActionConfigT], ActionResultT], -) -> Callable[[TaskCtx, str, str, str, ActionConfigT], ActionResultT]: +) -> Callable[[TaskCtx, str, str, str, ActionConfigT, Optional[str]], ActionResultT]: @wraps(func) def wrapper( ctx: TaskCtx, @@ -110,11 +114,12 @@ def wrapper( call_output: str, configured_action_ref: str, configured_action: ActionConfigT, + trigger_ref: Optional[str] = None, ) -> ActionResultT: success = True try: result = func(call_input, call_output, configured_action) - publish_results_as_feedback(ctx, result, configured_action_ref) + publish_results_as_feedback(ctx, result, configured_action_ref, trigger_ref) logging.info(f"Successfully ran {func.__name__}") logging.info(f"Result: {result}") except Exception as e: @@ -128,7 +133,9 @@ def wrapper( @app.task() -def do_task(ctx: TaskCtx, configured_action_ref: str) -> None: +def do_task( + ctx: TaskCtx, configured_action_ref: str, trigger_ref: Optional[str] = None +) -> None: action = resolve_action_ref(configured_action_ref) call = resolve_call(ctx) call_input = json.dumps(call.inputs) @@ -137,14 +144,40 @@ def do_task(ctx: TaskCtx, configured_action_ref: str) -> None: call_output = json.dumps(call_output) if action.config.action_type == "wordcount": - wordcount(ctx, call_input, call_output, configured_action_ref, action.config) + wordcount( + ctx, + call_input, + call_output, + configured_action_ref, + action.config, + trigger_ref, + ) elif action.config.action_type == "llm_judge": - llm_judge(ctx, call_input, call_output, configured_action_ref, action.config) + llm_judge( + ctx, + call_input, + call_output, + configured_action_ref, + action.config, + trigger_ref, + ) elif action.config.action_type == "noop": - noop(ctx, call_input, call_output, configured_action_ref, action.config) + noop( + ctx, + call_input, + call_output, + configured_action_ref, + action.config, + trigger_ref, + ) elif action.config.action_type == "contains_words": contains_words( - ctx, call_input, call_output, configured_action_ref, action.config + ctx, + call_input, + call_output, + configured_action_ref, + action.config, + trigger_ref, ) else: raise ValueError(f"Unknown action type: {action.config.action_type}") diff --git a/weave/trace_server/action_executor.py b/weave/trace_server/action_executor.py index a961f70f137..608ecc00d54 100644 --- a/weave/trace_server/action_executor.py +++ b/weave/trace_server/action_executor.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Optional from redis import Redis from typing_extensions import TypedDict @@ -20,11 +21,21 @@ def queue_from_addr(addr: str) -> "ActionExecutor": class ActionExecutor(ABC): @abstractmethod - def enqueue(self, ctx: TaskCtx, configured_action_ref: str) -> None: + def enqueue( + self, + ctx: TaskCtx, + configured_action_ref: str, + trigger_ref: Optional[str] = None, + ) -> None: pass @abstractmethod - def do_now(self, ctx: TaskCtx, configured_action_ref: str) -> None: + def do_now( + self, + ctx: TaskCtx, + configured_action_ref: str, + trigger_ref: Optional[str] = None, + ) -> None: pass @abstractmethod @@ -33,16 +44,26 @@ def _TESTONLY_clear_queue(self) -> None: class CeleryActionQueue(ActionExecutor): - def enqueue(self, ctx: TaskCtx, configured_action_ref: str) -> None: + def enqueue( + self, + ctx: TaskCtx, + configured_action_ref: str, + trigger_ref: Optional[str] = None, + ) -> None: # TODO We put this in here to break a circular import. Fix this later. import weave.actions_worker.tasks as tasks - tasks.do_task.delay(ctx, configured_action_ref) + tasks.do_task.delay(ctx, configured_action_ref, trigger_ref) - def do_now(self, ctx: TaskCtx, configured_action_ref: str) -> None: + def do_now( + self, + ctx: TaskCtx, + configured_action_ref: str, + trigger_ref: Optional[str] = None, + ) -> None: import weave.actions_worker.tasks as tasks - tasks.do_task(ctx, configured_action_ref) + tasks.do_task(ctx, configured_action_ref, trigger_ref) def _TESTONLY_clear_queue(self) -> None: redis = Redis.from_url(wf_env.wf_action_executor()) @@ -50,10 +71,20 @@ def _TESTONLY_clear_queue(self) -> None: class NoOpActionQueue(ActionExecutor): - def enqueue(self, ctx: TaskCtx, configured_action_ref: str) -> None: + def enqueue( + self, + ctx: TaskCtx, + configured_action_ref: str, + trigger_ref: Optional[str] = None, + ) -> None: pass - def do_now(self, ctx: TaskCtx, configured_action_ref: str) -> None: + def do_now( + self, + ctx: TaskCtx, + configured_action_ref: str, + trigger_ref: Optional[str] = None, + ) -> None: pass def _TESTONLY_clear_queue(self) -> None: diff --git a/weave/trace_server/actions.py b/weave/trace_server/actions.py index 921497f675a..0d19f91a605 100644 --- a/weave/trace_server/actions.py +++ b/weave/trace_server/actions.py @@ -9,7 +9,7 @@ Column("project_id", "string"), Column("call_id", "string"), Column("id", "string"), - Column("rule_matched", "string", nullable=True), + Column("trigger_ref", "string", nullable=True), Column("configured_action", "string", nullable=True), # Updated column name Column("created_at", "datetime", nullable=True), Column("finished_at", "datetime", nullable=True), @@ -23,7 +23,7 @@ Column("project_id", "string"), Column("call_id", "string"), Column("id", "string"), - Column("rule_matched", "string", nullable=True), + Column("trigger_ref", "string", nullable=True), Column("configured_action", "string", nullable=True), # Updated column name Column("created_at", "datetime", nullable=True), Column("finished_at", "datetime", nullable=True), @@ -56,7 +56,7 @@ def get_stale_actions(older_than: datetime.datetime) -> PreparedSelect: "project_id", "call_id", "id", - "any(rule_matched) as rule_matched", + "any(trigger_ref) as trigger_ref", "any(configured_action) as configured_action", "max(created_at) as created_at", "max(finished_at) as finished_at", diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index c6601cbd094..c4d0ea76091 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1554,6 +1554,7 @@ def _actions_requeue_stale(self) -> None: "id": row["id"], # type: ignore }, row["configured_action"], # type: ignore + row.get("trigger_ref"), # type: ignore ) except Exception as e: logger.error(f"Failed to requeue action: {row}. Error: {str(e)}") @@ -1864,7 +1865,7 @@ def _insert_call(self, ch_call: CallCHInsertable) -> None: def _get_matched_calls_for_filters( self, project_id: str, call_ids: list[str] - ) -> list[tuple[ActionDispatchFilter, list[tsi.CallSchema]]]: + ) -> list[tuple[ActionDispatchFilter, str, list[tsi.CallSchema]]]: """Helper function to get calls that match action filters. Returns a list of tuples containing (filter, matched_calls) pairs. @@ -1879,8 +1880,16 @@ def _get_matched_calls_for_filters( ), ) filter_res = self.objs_query(filter_req) - filters: list[ActionDispatchFilter] = [ - ActionDispatchFilter.model_validate(obj.val) for obj in filter_res.objs + filters: list[tuple[ActionDispatchFilter, str]] = [ + ( + ActionDispatchFilter.model_validate(obj.val), + ri.InternalObjectRef( + project_id=project_id, + name=obj.object_id, + version=obj.digest, + ).uri(), + ) + for obj in filter_res.objs ] if not filters: @@ -1929,7 +1938,7 @@ def _get_matched_calls_for_filters( # Match calls to filters matched_filters_and_calls = [] - for filter in filters: + for filter, filter_ref in filters: if filter.disabled: continue calls_with_refs = [ @@ -1951,7 +1960,7 @@ def _get_matched_calls_for_filters( ] if matched_calls: - matched_filters_and_calls.append((filter, matched_calls)) + matched_filters_and_calls.append((filter, filter_ref, matched_calls)) return matched_filters_and_calls @@ -1972,12 +1981,13 @@ def _flush_calls(self) -> None: self._call_batch.project_id, self._call_batch.call_ids ) - for filter, matched_calls in matched_filters_and_calls: + for filter, filter_ref, matched_calls in matched_filters_and_calls: self.actions_execute_batch( tsi.ActionsExecuteBatchReq( project_id=self._call_batch.project_id, call_ids=[call.id for call in matched_calls], configured_action_ref=filter.configured_action_ref, + trigger_ref=filter_ref, ) ) diff --git a/weave/trace_server/migrations/007_actions.up.sql b/weave/trace_server/migrations/007_actions.up.sql index fb9a0e6f5b5..ee18f6b4725 100644 --- a/weave/trace_server/migrations/007_actions.up.sql +++ b/weave/trace_server/migrations/007_actions.up.sql @@ -3,6 +3,7 @@ CREATE TABLE actions_parts ( call_id String, id String, configured_action Nullable(String), + trigger_ref Nullable(String), created_at Nullable(DateTime64(3)), finished_at Nullable(DateTime64(3)), failed_at Nullable(DateTime64(3)) @@ -15,6 +16,7 @@ CREATE TABLE actions_merged ( call_id String, id String, configured_action SimpleAggregateFunction(any, Nullable(String)), + trigger_ref SimpleAggregateFunction(any, Nullable(String)), created_at SimpleAggregateFunction(max, Nullable(DateTime64(3))), finished_at SimpleAggregateFunction(max, Nullable(DateTime64(3))), failed_at SimpleAggregateFunction(max, Nullable(DateTime64(3))), @@ -33,6 +35,7 @@ SELECT call_id, id, anySimpleState(configured_action) AS configured_action, + anySimpleState(trigger_ref) AS trigger_ref, maxSimpleState(created_at) AS created_at, maxSimpleState(finished_at) AS finished_at, maxSimpleState(failed_at) AS failed_at diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 50337220a8c..ac5a01fc061 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -801,6 +801,7 @@ class ActionsExecuteBatchReq(BaseModel): project_id: str call_ids: list[str] configured_action_ref: str + trigger_ref: Optional[str] = None # `id` is here so that clients can potentially guarantee idempotence. # Repeated calls with the same id will not result in duplicate actions. id: Optional[str] = None