diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py index 264cb6e8c58..f12b419624b 100644 --- a/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -7,6 +7,11 @@ from weave.trace.weave_client import WeaveClient, get_ref from weave.trace_server import trace_server_interface as tsi from weave.trace_server.errors import InvalidRequest +from weave.trace_server.trace_server_interface import ( + FeedbackCreateReq, + FeedbackQueryReq, + FeedbackReplaceReq, +) def test_client_feedback(client) -> None: @@ -173,9 +178,11 @@ def test_annotation_feedback(client: WeaveClient) -> None: "wb_user_id": "shawn", "creator": None, # Sad - seems like sqlite and clickhouse remote different types here - "created_at": create_res.created_at.isoformat().replace("T", " ") - if client_is_sqlite(client) - else MatchAnyDatetime(), + "created_at": ( + create_res.created_at.isoformat().replace("T", " ") + if client_is_sqlite(client) + else MatchAnyDatetime() + ), "feedback_type": feedback_type, "payload": payload, "annotation_ref": annotation_ref, @@ -321,9 +328,11 @@ def test_runnable_feedback(client: WeaveClient) -> None: "wb_user_id": "shawn", "creator": None, # Sad - seems like sqlite and clickhouse remote different types here - "created_at": create_res.created_at.isoformat().replace("T", " ") - if client_is_sqlite(client) - else MatchAnyDatetime(), + "created_at": ( + create_res.created_at.isoformat().replace("T", " ") + if client_is_sqlite(client) + else MatchAnyDatetime() + ), "feedback_type": feedback_type, "payload": payload, "annotation_ref": None, @@ -535,3 +544,80 @@ def test_filter_and_sort_by_feedback(client: WeaveClient) -> None: calls = list(calls) assert len(calls) == 2 assert [c.id for c in calls] == [ids[2], ids[0]] + + +def test_feedback_replace(client) -> None: + # Create initial feedback + create_req = FeedbackCreateReq( + project_id="test/project", + weave_ref="weave:///test/project/obj/123:abc", + feedback_type="reaction", + payload={"emoji": "👍"}, + wb_user_id="test_user", + ) + initial_feedback = client.server.feedback_create(create_req) + + # Create another feedback with different type + note_feedback = client.server.feedback_create( + FeedbackCreateReq( + project_id="test/project", + weave_ref="weave:///test/project/obj/456:def", + feedback_type="note", + payload={"note": "This is a test note"}, + wb_user_id="test_user", + ) + ) + + # Replace the first feedback with new content + replace_req = FeedbackReplaceReq( + project_id="test/project", + weave_ref="weave:///test/project/obj/123:abc", + feedback_type="note", + payload={"note": "Updated feedback"}, + feedback_id=initial_feedback.id, + wb_user_id="test_user", + ) + replaced_feedback = client.server.feedback_replace(replace_req) + + # Verify the replacement + assert note_feedback.id != replaced_feedback.id + + # Verify the other feedback remains unchanged + query_res = client.server.feedback_query( + FeedbackQueryReq( + project_id="test/project", fields=["id", "feedback_type", "payload"] + ) + ) + + feedbacks = query_res.result + assert len(feedbacks) == 2 + + # Find the non-replaced feedback and verify it's unchanged + other_feedback = next(f for f in feedbacks if f["id"] == note_feedback.id) + assert other_feedback["feedback_type"] == "note" + assert other_feedback["payload"] == {"note": "This is a test note"} + + # now replace the replaced feedback with the original content + replace_req = FeedbackReplaceReq( + project_id="test/project", + weave_ref="weave:///test/project/obj/123:abc", + feedback_type="reaction", + payload={"emoji": "👍"}, + feedback_id=replaced_feedback.id, + wb_user_id="test_user", + ) + replaced_feedback = client.server.feedback_replace(replace_req) + + assert replaced_feedback.id != initial_feedback.id + + # Verify the latest feedback payload + query_res = client.server.feedback_query( + FeedbackQueryReq( + project_id="test/project", fields=["id", "feedback_type", "payload"] + ) + ) + feedbacks = query_res.result + assert len(feedbacks) == 2 + new_feedback = next(f for f in feedbacks if f["id"] == replaced_feedback.id) + assert new_feedback["feedback_type"] == "reaction" + assert new_feedback["payload"] == {"emoji": "👍"} diff --git a/weave-js/src/assets/icons/icon-buzz-bot10.svg b/weave-js/src/assets/icons/icon-buzz-bot10.svg new file mode 100644 index 00000000000..a4a833c2847 --- /dev/null +++ b/weave-js/src/assets/icons/icon-buzz-bot10.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/weave-js/src/components/Icon/Icon.tsx b/weave-js/src/components/Icon/Icon.tsx index e552741b926..2dfc93e678f 100644 --- a/weave-js/src/components/Icon/Icon.tsx +++ b/weave-js/src/components/Icon/Icon.tsx @@ -18,6 +18,7 @@ import {ReactComponent as ImportBookDictionary} from '../../assets/icons/icon-bo import {ReactComponent as ImportBoolean} from '../../assets/icons/icon-boolean.svg'; import {ReactComponent as ImportBoxPlot} from '../../assets/icons/icon-box-plot.svg'; import {ReactComponent as ImportBug} from '../../assets/icons/icon-bug.svg'; +import {ReactComponent as ImportBuzzBot10} from '../../assets/icons/icon-buzz-bot10.svg'; import {ReactComponent as ImportCategoryMultimodal} from '../../assets/icons/icon-category-multimodal.svg'; import {ReactComponent as ImportChartHorizontalBars} from '../../assets/icons/icon-chart-horizontal-bars.svg'; import {ReactComponent as ImportChartPie} from '../../assets/icons/icon-chart-pie.svg'; @@ -324,6 +325,9 @@ export const IconBoxPlot = (props: SVGIconProps) => ( export const IconBug = (props: SVGIconProps) => ( ); +export const IconBuzzBot10 = (props: SVGIconProps) => ( + +); export const IconCategoryMultimodal = (props: SVGIconProps) => ( ); @@ -1058,6 +1062,7 @@ const ICON_NAME_TO_ICON: Record = { boolean: IconBoolean, 'box-plot': IconBoxPlot, bug: IconBug, + 'buzz-bot10': IconBuzzBot10, 'category-multimodal': IconCategoryMultimodal, 'chart-horizontal-bars': IconChartHorizontalBars, 'chart-pie': IconChartPie, diff --git a/weave-js/src/components/Icon/index.ts b/weave-js/src/components/Icon/index.ts index f2e4964c77f..3d0b581786d 100644 --- a/weave-js/src/components/Icon/index.ts +++ b/weave-js/src/components/Icon/index.ts @@ -18,6 +18,7 @@ export { IconBoolean, IconBoxPlot, IconBug, + IconBuzzBot10, IconCategoryMultimodal, IconChartHorizontalBars, IconChartPie, diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts index d5a53de5f86..b886bca869e 100644 --- a/weave-js/src/components/Icon/types.ts +++ b/weave-js/src/components/Icon/types.ts @@ -17,6 +17,7 @@ export const IconNames = { Boolean: 'boolean', BoxPlot: 'box-plot', Bug: 'bug', + BuzzBot10: 'buzz-bot10', CategoryMultimodal: 'category-multimodal', ChartHorizontalBars: 'chart-horizontal-bars', ChartPie: 'chart-pie', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 2c458814245..562a0ca5cb8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -94,6 +94,7 @@ import {OpPage} from './Browse3/pages/OpPage'; import {OpsPage} from './Browse3/pages/OpsPage'; import {OpVersionPage} from './Browse3/pages/OpVersionPage'; import {OpVersionsPage} from './Browse3/pages/OpVersionsPage'; +import {PlaygroundPage} from './Browse3/pages/PlaygroundPage/PlaygroundPage'; import {TablePage} from './Browse3/pages/TablePage'; import {TablesPage} from './Browse3/pages/TablesPage'; import {useURLSearchParamsDict} from './Browse3/pages/util'; @@ -523,6 +524,14 @@ const Browse3ProjectRoot: FC<{ + {/* PLAYGROUND */} + + + ); @@ -1035,6 +1044,17 @@ const AppBarLink = (props: ComponentProps) => ( /> ); +const PlaygroundPageBinding = () => { + const params = useParamsDecoded(); + return ( + + ); +}; + const Browse3Breadcrumbs: FC = props => { const params = useParamsDecoded(); const query = useURLSearchParamsDict(); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx new file mode 100644 index 00000000000..c677694af50 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx @@ -0,0 +1,95 @@ +import {Box} from '@mui/material'; +import {WeaveLoader} from '@wandb/weave/common/components/WeaveLoader'; +import React, {useEffect, useMemo, useState} from 'react'; + +import {SimplePageLayout} from '../common/SimplePageLayout'; +import {useWFHooks} from '../wfReactInterface/context'; +import {DEFAULT_SYSTEM_MESSAGE, usePlaygroundState} from './usePlaygroundState'; + +export type PlaygroundPageProps = { + entity: string; + project: string; + callId: string; +}; + +export const PlaygroundPage = (props: PlaygroundPageProps) => { + return ( + , + }, + ]} + /> + ); +}; + +export const PlaygroundPageInner = (props: PlaygroundPageProps) => { + const { + setPlaygroundStateFromTraceCall, + playgroundStates, + setPlaygroundStateField, + } = usePlaygroundState(); + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const [settingsTab, setSettingsTab] = useState(null); + + const {useCall} = useWFHooks(); + const call = useCall( + useMemo(() => { + return props.callId + ? { + entity: props.entity, + project: props.project, + callId: props.callId, + } + : null; + }, [props.entity, props.project, props.callId]) + ); + + useEffect(() => { + if (!call.loading && call.result) { + if (call.result.traceCall?.inputs) { + setPlaygroundStateFromTraceCall(call.result.traceCall); + } + } else if ( + playgroundStates.length === 1 && + !playgroundStates[0].traceCall.project_id + ) { + setPlaygroundStateField(0, 'traceCall', { + inputs: { + messages: [DEFAULT_SYSTEM_MESSAGE], + }, + project_id: `${props.entity}/${props.project}`, + }); + } + // Only set the call the first time the page loads, and we get the call + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [props.callId]); + + return ( + + {call.loading ? ( + + + + ) : ( +
Playground
+ )} + {settingsTab !== null &&
Settings
} +
+ ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/index.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/index.ts new file mode 100644 index 00000000000..34cee0ea092 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/index.ts @@ -0,0 +1 @@ +export * from './PlaygroundPage'; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts new file mode 100644 index 00000000000..1334e32a454 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts @@ -0,0 +1,123 @@ +// This is a mapping of LLM names to their max token limits. +// Directly from the pycache model_providers.json in trace_server. +// Some were removed because they are not supported when Josiah tried on Oct 30, 2024. +export const LLM_MAX_TOKENS = { + 'gpt-4o-mini': {max_tokens: 16384, supports_function_calling: true}, + 'claude-3-5-sonnet-20240620': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'claude-3-5-sonnet-20241022': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'claude-3-haiku-20240307': { + max_tokens: 4096, + supports_function_calling: true, + }, + 'claude-3-opus-20240229': {max_tokens: 4096, supports_function_calling: true}, + 'claude-3-sonnet-20240229': { + max_tokens: 4096, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-flash-001': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-flash-002': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-flash-8b-exp-0827': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-flash-8b-exp-0924': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-flash-exp-0827': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-flash-latest': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-flash': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-pro-001': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-pro-002': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-pro-exp-0801': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-pro-exp-0827': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-pro-latest': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'gemini/gemini-1.5-pro': {max_tokens: 8192, supports_function_calling: true}, + 'gemini/gemini-pro': {max_tokens: 8192, supports_function_calling: true}, + 'gpt-3.5-turbo-0125': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-3.5-turbo-1106': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-3.5-turbo-16k': {max_tokens: 4096, supports_function_calling: false}, + 'gpt-3.5-turbo': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-4-0125-preview': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-4-0314': {max_tokens: 4096, supports_function_calling: false}, + 'gpt-4-0613': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-4-1106-preview': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-4-32k-0314': {max_tokens: 4096, supports_function_calling: false}, + 'gpt-4-turbo-2024-04-09': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-4-turbo-preview': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-4-turbo': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-4': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-4o-2024-05-13': {max_tokens: 4096, supports_function_calling: true}, + 'gpt-4o-2024-08-06': {max_tokens: 16384, supports_function_calling: true}, + 'gpt-4o-mini-2024-07-18': { + max_tokens: 16384, + supports_function_calling: true, + }, + 'gpt-4o': {max_tokens: 4096, supports_function_calling: true}, + 'groq/gemma-7b-it': {max_tokens: 8192, supports_function_calling: true}, + 'groq/gemma2-9b-it': {max_tokens: 8192, supports_function_calling: true}, + 'groq/llama-3.1-70b-versatile': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'groq/llama-3.1-8b-instant': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'groq/llama3-70b-8192': {max_tokens: 8192, supports_function_calling: true}, + 'groq/llama3-8b-8192': {max_tokens: 8192, supports_function_calling: true}, + 'groq/llama3-groq-70b-8192-tool-use-preview': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'groq/llama3-groq-8b-8192-tool-use-preview': { + max_tokens: 8192, + supports_function_calling: true, + }, + 'groq/mixtral-8x7b-32768': { + max_tokens: 32768, + supports_function_calling: true, + }, + 'o1-mini-2024-09-12': {max_tokens: 65536, supports_function_calling: true}, + 'o1-mini': {max_tokens: 65536, supports_function_calling: true}, + 'o1-preview-2024-09-12': {max_tokens: 32768, supports_function_calling: true}, + 'o1-preview': {max_tokens: 32768, supports_function_calling: true}, +}; + +export type LLMMaxTokensKey = keyof typeof LLM_MAX_TOKENS; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts new file mode 100644 index 00000000000..3e4f9dd30c8 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts @@ -0,0 +1,30 @@ +import {TraceCallSchema} from '../wfReactInterface/traceServerClientTypes'; +import {LLMMaxTokensKey} from './llmMaxTokens'; + +export enum PlaygroundResponseFormats { + Text = 'text', + JsonObject = 'json_object', + // Fast follow + // JsonSchema = 'json_schema', +} + +export type PlaygroundState = { + traceCall: OptionalTraceCallSchema; + trackLLMCall: boolean; + loading: boolean; + functions: Array<{name: string; [key: string]: any}>; + responseFormat: PlaygroundResponseFormats; + temperature: number; + maxTokens: number; + stopSequences: string[]; + topP: number; + frequencyPenalty: number; + presencePenalty: number; + // nTimes: number; + maxTokensLimit: number; + model: LLMMaxTokensKey; +}; + +export type PlaygroundStateKey = keyof PlaygroundState; + +export type OptionalTraceCallSchema = Partial; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts new file mode 100644 index 00000000000..f034302c468 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts @@ -0,0 +1,147 @@ +import {SetStateAction, useCallback, useState} from 'react'; + +import {LLMMaxTokensKey} from './llmMaxTokens'; +import { + OptionalTraceCallSchema, + PlaygroundResponseFormats, + PlaygroundState, + PlaygroundStateKey, +} from './types'; + +export const DEFAULT_SYSTEM_MESSAGE_CONTENT = + 'You are an AI assistant designed to assist users by providing clear, concise, and helpful responses.'; + +export const DEFAULT_SYSTEM_MESSAGE = { + role: 'system', + content: DEFAULT_SYSTEM_MESSAGE_CONTENT, +}; + +const DEFAULT_PLAYGROUND_STATE = { + traceCall: { + inputs: { + messages: [DEFAULT_SYSTEM_MESSAGE], + }, + }, + trackLLMCall: true, + loading: false, + functions: [], + responseFormat: PlaygroundResponseFormats.Text, + temperature: 1, + maxTokens: 4096, + stopSequences: [], + topP: 1, + frequencyPenalty: 0, + presencePenalty: 0, + // nTimes: 1, + maxTokensLimit: 16384, + model: 'gpt-4o-mini' as LLMMaxTokensKey, +}; + +export const usePlaygroundState = () => { + const [playgroundStates, setPlaygroundStates] = useState([ + DEFAULT_PLAYGROUND_STATE, + ]); + + const setPlaygroundStateField = useCallback( + ( + index: number, + field: PlaygroundStateKey, + value: SetStateAction + ) => { + setPlaygroundStates(prevStates => + prevStates.map((state, i) => + i === index + ? { + ...state, + [field]: + typeof value === 'function' + ? (value as SetStateAction)(state[field]) + : value, + } + : state + ) + ); + }, + [] + ); + + // Takes in a function input and sets the state accordingly + const setPlaygroundStateFromTraceCall = useCallback( + (traceCall: OptionalTraceCallSchema) => { + const inputs = traceCall.inputs; + // https://docs.litellm.ai/docs/completion/input + // pulled from litellm + setPlaygroundStates(prevState => { + const newState = {...prevState[0]}; + + newState.traceCall = traceCall; + + if (!inputs) { + return [newState]; + } + + if (inputs.tools) { + newState.functions = []; + for (const tool of inputs.tools) { + if (tool.type === 'function') { + newState.functions = [...newState.functions, tool.function]; + } + } + } + // if (inputs.n) { + // newState.nTimes = parseInt(inputs.n, 10); + // } + if (inputs.temperature) { + newState.temperature = parseFloat(inputs.temperature); + } + if (inputs.response_format) { + newState.responseFormat = inputs.response_format.type; + } + if (inputs.top_p) { + newState.topP = parseFloat(inputs.top_p); + } + if (inputs.frequency_penalty) { + newState.frequencyPenalty = parseFloat(inputs.frequency_penalty); + } + if (inputs.presence_penalty) { + newState.presencePenalty = parseFloat(inputs.presence_penalty); + } + return [newState]; + }); + }, + [] + ); + + return { + playgroundStates, + setPlaygroundStates, + setPlaygroundStateField, + setPlaygroundStateFromTraceCall, + }; +}; + +export const getInputFromPlaygroundState = (state: PlaygroundState) => { + const tools = state.functions.map(func => ({ + type: 'function', + function: func, + })); + return { + // Adding this to prevent the exact same call from not getting run + // eg running the same call in parallel + key: Math.random() * 1000, + + messages: state.traceCall?.inputs?.messages, + model: state.model, + temperature: state.temperature, + max_tokens: state.maxTokens, + stop: state.stopSequences, + top_p: state.topP, + frequency_penalty: state.frequencyPenalty, + presence_penalty: state.presencePenalty, + // n: state.nTimes, + response_format: { + type: state.responseFormat, + }, + tools: tools.length > 0 ? tools : undefined, + }; +}; diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 8547c97648f..34a9cbfaa16 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -31,12 +31,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from contextlib import contextmanager -from typing import ( - Any, - Optional, - Union, - cast, -) +from typing import Any, Optional, Union, cast from zoneinfo import ZoneInfo import clickhouse_connect @@ -50,9 +45,7 @@ from weave.trace_server import refs_internal as ri from weave.trace_server import trace_server_interface as tsi from weave.trace_server.actions_worker.dispatcher import execute_batch -from weave.trace_server.base_object_class_util import ( - process_incoming_object, -) +from weave.trace_server.base_object_class_util import process_incoming_object from weave.trace_server.calls_query_builder import ( CallsQuery, HardCodedFilter, @@ -1407,6 +1400,32 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: self.ch_client.query(prepared.sql, prepared.parameters) return tsi.FeedbackPurgeRes() + def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: + # To replace, first purge, then if successful, create. + query = tsi.Query( + **{ + "$expr": { + "$eq": [ + {"$getField": "id"}, + {"$literal": req.feedback_id}, + ], + } + } + ) + purge_request = tsi.FeedbackPurgeReq( + project_id=req.project_id, + query=query, + ) + self.feedback_purge(purge_request) + create_req = tsi.FeedbackCreateReq(**req.model_dump(exclude={"feedback_id"})) + create_result = self.feedback_create(create_req) + return tsi.FeedbackReplaceRes( + id=create_result.id, + created_at=create_result.created_at, + wb_user_id=create_result.wb_user_id, + payload=create_result.payload, + ) + def actions_execute_batch( self, req: tsi.ActionsExecuteBatchReq ) -> tsi.ActionsExecuteBatchRes: diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 5878bdad0ee..1df739adbcd 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -327,6 +327,18 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: req.project_id = self._idc.ext_to_int_project_id(req.project_id) return self._ref_apply(self._internal_trace_server.feedback_purge, req) + def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + original_user_id = req.wb_user_id + if original_user_id is None: + raise ValueError("wb_user_id cannot be None") + req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id) + res = self._ref_apply(self._internal_trace_server.feedback_replace, req) + if res.wb_user_id != req.wb_user_id: + raise ValueError("Internal Error - User Mismatch") + res.wb_user_id = original_user_id + return res + def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: req.project_id = self._idc.ext_to_int_project_id(req.project_id) return self._ref_apply(self._internal_trace_server.cost_create, req) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index f500b0fe9bd..0ff9afa325d 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -14,9 +14,7 @@ from weave.trace_server import refs_internal as ri from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.base_object_class_util import ( - process_incoming_object, -) +from weave.trace_server.base_object_class_util import process_incoming_object from weave.trace_server.emoji_util import detone_emojis from weave.trace_server.errors import InvalidRequest from weave.trace_server.feedback import ( @@ -1054,6 +1052,29 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: conn.commit() return tsi.FeedbackPurgeRes() + def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: + purge_request = tsi.FeedbackPurgeReq( + project_id=req.project_id, + query={ + "$expr": { + "$eq": [ + {"$getField": "id"}, + {"$literal": req.feedback_id}, + ], + } + }, + ) + self.feedback_purge(purge_request) + create_req = tsi.FeedbackCreateReq(**req.model_dump(exclude={"feedback_id"})) + create_result = self.feedback_create(create_req) + + return tsi.FeedbackReplaceRes( + id=create_result.id, + created_at=create_result.created_at, + wb_user_id=create_result.wb_user_id, + payload=create_result.payload, + ) + def actions_execute_batch( self, req: tsi.ActionsExecuteBatchReq ) -> tsi.ActionsExecuteBatchRes: diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index b3fb3b58a7f..77ef6198ecf 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -743,6 +743,14 @@ class FeedbackPurgeRes(BaseModel): pass +class FeedbackReplaceReq(FeedbackCreateReq): + feedback_id: str + + +class FeedbackReplaceRes(FeedbackCreateRes): + pass + + class FileCreateReq(BaseModel): project_id: str name: str @@ -900,6 +908,7 @@ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ... def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ... def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ... + def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: ... # Action API def actions_execute_batch( diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index f929c2065a5..d0b2fb06c36 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -527,6 +527,13 @@ def feedback_purge( "/feedback/purge", req, tsi.FeedbackPurgeReq, tsi.FeedbackPurgeRes ) + def feedback_replace( + self, req: Union[tsi.FeedbackReplaceReq, dict[str, Any]] + ) -> tsi.FeedbackReplaceRes: + return self._generic_request( + "/feedback/replace", req, tsi.FeedbackReplaceReq, tsi.FeedbackReplaceRes + ) + def actions_execute_batch( self, req: Union[tsi.ActionsExecuteBatchReq, dict[str, Any]] ) -> tsi.ActionsExecuteBatchRes: diff --git a/weave_query/weave_query/wandb_interface/wandb_stream_table.py b/weave_query/weave_query/wandb_interface/wandb_stream_table.py index eda69512f2b..d89616e4b9c 100644 --- a/weave_query/weave_query/wandb_interface/wandb_stream_table.py +++ b/weave_query/weave_query/wandb_interface/wandb_stream_table.py @@ -13,7 +13,6 @@ from wandb.sdk.lib.ipython import _get_python_type from wandb.sdk.lib.paths import LogicalPath -from wandb.sdk.lib.printer import get_printer from weave_query import weave_types from weave_query import ( @@ -186,11 +185,6 @@ def _ensure_remote_initialized(self) -> StreamTableType: ) self._artifact = WandbLiveRunFiles(name=uri.name, uri=uri) self._artifact.set_file_pusher(self._lite_run.pusher) - if print_url: - base_url = environment.weave_server_url() - url = f"{base_url}/browse/wandb/{self._entity_name}/{self._project_name}/table/{self._table_name}" - printer = get_printer(_get_python_type() != "python") - # printer.display(f'{printer.emoji("star")} View data at {printer.link(url)}') return self._weave_stream_table def log(self, row_or_rows: ROW_TYPE) -> None: