diff --git a/tests/trace/demo.ipynb b/tests/trace/demo.ipynb index 5eef9295938f..b1177dd0b8ba 100644 --- a/tests/trace/demo.ipynb +++ b/tests/trace/demo.ipynb @@ -24,20 +24,6 @@ "client = weave.init(\"action_test_6\")" ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# from weave.trace import autopatch, weave_init\n", - "# from weave.trace_server import clickhouse_trace_server_batched\n", - "\n", - "# ch_server = clickhouse_trace_server_batched.ClickHouseTraceServer.from_env()\n", - "# inited_client = weave_init.InitializedClient(client)\n", - "# autopatch.autopatch()\n" - ] - }, { "cell_type": "code", "execution_count": 3, @@ -89,89 +75,6 @@ "\n", "print(calls)" ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "from weave.collection_objects import action_objects\n", - "from weave.trace.weave_client import get_ref\n", - "from weave.trace_server.interface.collections import action_collection\n", - "\n", - "action = action_objects.ActionWithConfig(\n", - " name=\"is_name_extracted\",\n", - " action=action_collection._BuiltinAction(\n", - " name=\"openai_completion\",\n", - " ),\n", - " config={\n", - " \"model\": \"gpt-4o-mini\",\n", - " \"system_prompt\": \"Given the following prompt and response, determine if the name was extracted correctly.\",\n", - " \"response_format\": {\n", - " \"type\": \"json_schema\",\n", - " \"json_schema\": {\n", - " \"name\": \"is_name_extracted\",\n", - " \"schema\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\"is_extracted\": {\"type\": \"boolean\"}},\n", - " \"required\": [\"is_extracted\"],\n", - " \"additionalProperties\": False,\n", - " },\n", - " \"strict\": True,\n", - " },\n", - " },\n", - " },\n", - ")\n", - "mapping = action_objects.ActionOpMapping(\n", - " name=\"extract_name-is_name_extracted5\",\n", - " action=action,\n", - " op_name=get_ref(extract_name).name,\n", - " op_digest=get_ref(extract_name).digest,\n", - " input_mapping={\n", - " \"prompt\": \"inputs.user_input\",\n", - " \"response\": \"output\",\n", - " },\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from weave.trace_server import trace_server_interface as tsi\n", - "\n", - "res = client.server.execute_batch_action(\n", - " req=tsi.ExecuteBatchActionReq(\n", - " project_id=client._project_id(),\n", - " call_ids=[c.id for c in calls[:1]],\n", - " mapping=mapping,\n", - " )\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx index 9f5e723513b0..b4ff2a0258a0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallActionsViewer.tsx @@ -43,7 +43,7 @@ const RunButton: React.FC<{ setIsRunning(true); setError(null); try { - await getClient().executeBatchAction({ + await getClient().actionsExecuteBatch({ project_id: projectIdFromParts({entity, project}), call_ids: [callId], configured_action_ref: actionRef, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index f6164d6cbd76..5db08c27bf35 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -301,10 +301,10 @@ export const fileExtensions = { [ContentType.json]: 'json', }; -export type ExecuteBatchActionReq = { +export type ActionsExecuteBatchReq = { project_id: string; call_ids: string[]; configured_action_ref: string; }; -export type ExecuteBatchActionRes = {}; +export type ActionsExecuteBatchRes = {}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts index d3eeb8738a26..ae9c6c440e25 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts @@ -16,9 +16,9 @@ import {getCookie} from '@wandb/weave/common/util/cookie'; import fetch from 'isomorphic-unfetch'; import { + ActionsExecuteBatchReq, + ActionsExecuteBatchRes, ContentType, - ExecuteBatchActionReq, - ExecuteBatchActionRes, FeedbackCreateReq, FeedbackCreateRes, FeedbackPurgeReq, @@ -279,10 +279,10 @@ export class DirectTraceServerClient { ); } - public executeBatchAction( - req: ExecuteBatchActionReq - ): Promise { - return this.makeRequest( + public actionsExecuteBatch( + req: ActionsExecuteBatchReq + ): Promise { + return this.makeRequest( '/actions/execute_batch', req ); diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 695d5ce45574..9550ae15b6ac 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1093,13 +1093,6 @@ def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: print("COST PURGE is not implemented for local sqlite", req) return tsi.CostPurgeRes() - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - raise NotImplementedError( - "EXECUTE BATCH ACTION is not implemented for local sqlite" - ) - def _table_row_read(self, project_id: str, row_digest: str) -> tsi.TableRowSchema: conn, cursor = get_conn_cursor(self.db_path) # Now get the rows diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 737a95dad051..84da03a23be9 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -824,16 +824,6 @@ class ActionsAckBatchRes(BaseModel): id: str -class ExecuteBatchActionReq(BaseModel): - project_id: str - call_ids: list[str] - configured_action_ref: str - - -class ExecuteBatchActionRes(BaseModel): - pass - - class TraceServerInterface(Protocol): def ensure_project_exists( self, entity: str, project: str diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index 361225c8a2e6..0f32232cf4a0 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -567,16 +567,6 @@ def cost_purge( "/cost/purge", req, tsi.CostPurgeReq, tsi.CostPurgeRes ) - def execute_batch_action( - self, req: tsi.ExecuteBatchActionReq - ) -> tsi.ExecuteBatchActionRes: - return self._generic_request( - "/actions/execute_batch", - req, - tsi.ExecuteBatchActionReq, - tsi.ExecuteBatchActionRes, - ) - __docspec__ = [ RemoteHTTPTraceServer,