diff --git a/tests/trace/test_feedback.py b/tests/trace/test_feedback.py
index 264cb6e8c58..f12b419624b 100644
--- a/tests/trace/test_feedback.py
+++ b/tests/trace/test_feedback.py
@@ -7,6 +7,11 @@
from weave.trace.weave_client import WeaveClient, get_ref
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.errors import InvalidRequest
+from weave.trace_server.trace_server_interface import (
+ FeedbackCreateReq,
+ FeedbackQueryReq,
+ FeedbackReplaceReq,
+)
def test_client_feedback(client) -> None:
@@ -173,9 +178,11 @@ def test_annotation_feedback(client: WeaveClient) -> None:
"wb_user_id": "shawn",
"creator": None,
# Sad - seems like sqlite and clickhouse remote different types here
- "created_at": create_res.created_at.isoformat().replace("T", " ")
- if client_is_sqlite(client)
- else MatchAnyDatetime(),
+ "created_at": (
+ create_res.created_at.isoformat().replace("T", " ")
+ if client_is_sqlite(client)
+ else MatchAnyDatetime()
+ ),
"feedback_type": feedback_type,
"payload": payload,
"annotation_ref": annotation_ref,
@@ -321,9 +328,11 @@ def test_runnable_feedback(client: WeaveClient) -> None:
"wb_user_id": "shawn",
"creator": None,
# Sad - seems like sqlite and clickhouse remote different types here
- "created_at": create_res.created_at.isoformat().replace("T", " ")
- if client_is_sqlite(client)
- else MatchAnyDatetime(),
+ "created_at": (
+ create_res.created_at.isoformat().replace("T", " ")
+ if client_is_sqlite(client)
+ else MatchAnyDatetime()
+ ),
"feedback_type": feedback_type,
"payload": payload,
"annotation_ref": None,
@@ -535,3 +544,80 @@ def test_filter_and_sort_by_feedback(client: WeaveClient) -> None:
calls = list(calls)
assert len(calls) == 2
assert [c.id for c in calls] == [ids[2], ids[0]]
+
+
+def test_feedback_replace(client) -> None:
+ # Create initial feedback
+ create_req = FeedbackCreateReq(
+ project_id="test/project",
+ weave_ref="weave:///test/project/obj/123:abc",
+ feedback_type="reaction",
+ payload={"emoji": "👍"},
+ wb_user_id="test_user",
+ )
+ initial_feedback = client.server.feedback_create(create_req)
+
+ # Create another feedback with different type
+ note_feedback = client.server.feedback_create(
+ FeedbackCreateReq(
+ project_id="test/project",
+ weave_ref="weave:///test/project/obj/456:def",
+ feedback_type="note",
+ payload={"note": "This is a test note"},
+ wb_user_id="test_user",
+ )
+ )
+
+ # Replace the first feedback with new content
+ replace_req = FeedbackReplaceReq(
+ project_id="test/project",
+ weave_ref="weave:///test/project/obj/123:abc",
+ feedback_type="note",
+ payload={"note": "Updated feedback"},
+ feedback_id=initial_feedback.id,
+ wb_user_id="test_user",
+ )
+ replaced_feedback = client.server.feedback_replace(replace_req)
+
+ # Verify the replacement
+ assert note_feedback.id != replaced_feedback.id
+
+ # Verify the other feedback remains unchanged
+ query_res = client.server.feedback_query(
+ FeedbackQueryReq(
+ project_id="test/project", fields=["id", "feedback_type", "payload"]
+ )
+ )
+
+ feedbacks = query_res.result
+ assert len(feedbacks) == 2
+
+ # Find the non-replaced feedback and verify it's unchanged
+ other_feedback = next(f for f in feedbacks if f["id"] == note_feedback.id)
+ assert other_feedback["feedback_type"] == "note"
+ assert other_feedback["payload"] == {"note": "This is a test note"}
+
+ # now replace the replaced feedback with the original content
+ replace_req = FeedbackReplaceReq(
+ project_id="test/project",
+ weave_ref="weave:///test/project/obj/123:abc",
+ feedback_type="reaction",
+ payload={"emoji": "👍"},
+ feedback_id=replaced_feedback.id,
+ wb_user_id="test_user",
+ )
+ replaced_feedback = client.server.feedback_replace(replace_req)
+
+ assert replaced_feedback.id != initial_feedback.id
+
+ # Verify the latest feedback payload
+ query_res = client.server.feedback_query(
+ FeedbackQueryReq(
+ project_id="test/project", fields=["id", "feedback_type", "payload"]
+ )
+ )
+ feedbacks = query_res.result
+ assert len(feedbacks) == 2
+ new_feedback = next(f for f in feedbacks if f["id"] == replaced_feedback.id)
+ assert new_feedback["feedback_type"] == "reaction"
+ assert new_feedback["payload"] == {"emoji": "👍"}
diff --git a/weave-js/src/assets/icons/icon-buzz-bot10.svg b/weave-js/src/assets/icons/icon-buzz-bot10.svg
new file mode 100644
index 00000000000..a4a833c2847
--- /dev/null
+++ b/weave-js/src/assets/icons/icon-buzz-bot10.svg
@@ -0,0 +1,6 @@
+
diff --git a/weave-js/src/components/Icon/Icon.tsx b/weave-js/src/components/Icon/Icon.tsx
index e552741b926..2dfc93e678f 100644
--- a/weave-js/src/components/Icon/Icon.tsx
+++ b/weave-js/src/components/Icon/Icon.tsx
@@ -18,6 +18,7 @@ import {ReactComponent as ImportBookDictionary} from '../../assets/icons/icon-bo
import {ReactComponent as ImportBoolean} from '../../assets/icons/icon-boolean.svg';
import {ReactComponent as ImportBoxPlot} from '../../assets/icons/icon-box-plot.svg';
import {ReactComponent as ImportBug} from '../../assets/icons/icon-bug.svg';
+import {ReactComponent as ImportBuzzBot10} from '../../assets/icons/icon-buzz-bot10.svg';
import {ReactComponent as ImportCategoryMultimodal} from '../../assets/icons/icon-category-multimodal.svg';
import {ReactComponent as ImportChartHorizontalBars} from '../../assets/icons/icon-chart-horizontal-bars.svg';
import {ReactComponent as ImportChartPie} from '../../assets/icons/icon-chart-pie.svg';
@@ -324,6 +325,9 @@ export const IconBoxPlot = (props: SVGIconProps) => (
export const IconBug = (props: SVGIconProps) => (
);
+export const IconBuzzBot10 = (props: SVGIconProps) => (
+
+);
export const IconCategoryMultimodal = (props: SVGIconProps) => (
);
@@ -1058,6 +1062,7 @@ const ICON_NAME_TO_ICON: Record = {
boolean: IconBoolean,
'box-plot': IconBoxPlot,
bug: IconBug,
+ 'buzz-bot10': IconBuzzBot10,
'category-multimodal': IconCategoryMultimodal,
'chart-horizontal-bars': IconChartHorizontalBars,
'chart-pie': IconChartPie,
diff --git a/weave-js/src/components/Icon/index.ts b/weave-js/src/components/Icon/index.ts
index f2e4964c77f..3d0b581786d 100644
--- a/weave-js/src/components/Icon/index.ts
+++ b/weave-js/src/components/Icon/index.ts
@@ -18,6 +18,7 @@ export {
IconBoolean,
IconBoxPlot,
IconBug,
+ IconBuzzBot10,
IconCategoryMultimodal,
IconChartHorizontalBars,
IconChartPie,
diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts
index d5a53de5f86..b886bca869e 100644
--- a/weave-js/src/components/Icon/types.ts
+++ b/weave-js/src/components/Icon/types.ts
@@ -17,6 +17,7 @@ export const IconNames = {
Boolean: 'boolean',
BoxPlot: 'box-plot',
Bug: 'bug',
+ BuzzBot10: 'buzz-bot10',
CategoryMultimodal: 'category-multimodal',
ChartHorizontalBars: 'chart-horizontal-bars',
ChartPie: 'chart-pie',
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx
index 2c458814245..562a0ca5cb8 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx
@@ -94,6 +94,7 @@ import {OpPage} from './Browse3/pages/OpPage';
import {OpsPage} from './Browse3/pages/OpsPage';
import {OpVersionPage} from './Browse3/pages/OpVersionPage';
import {OpVersionsPage} from './Browse3/pages/OpVersionsPage';
+import {PlaygroundPage} from './Browse3/pages/PlaygroundPage/PlaygroundPage';
import {TablePage} from './Browse3/pages/TablePage';
import {TablesPage} from './Browse3/pages/TablesPage';
import {useURLSearchParamsDict} from './Browse3/pages/util';
@@ -523,6 +524,14 @@ const Browse3ProjectRoot: FC<{
+ {/* PLAYGROUND */}
+
+
+
);
@@ -1035,6 +1044,17 @@ const AppBarLink = (props: ComponentProps) => (
/>
);
+const PlaygroundPageBinding = () => {
+ const params = useParamsDecoded();
+ return (
+
+ );
+};
+
const Browse3Breadcrumbs: FC = props => {
const params = useParamsDecoded();
const query = useURLSearchParamsDict();
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx
new file mode 100644
index 00000000000..c677694af50
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx
@@ -0,0 +1,95 @@
+import {Box} from '@mui/material';
+import {WeaveLoader} from '@wandb/weave/common/components/WeaveLoader';
+import React, {useEffect, useMemo, useState} from 'react';
+
+import {SimplePageLayout} from '../common/SimplePageLayout';
+import {useWFHooks} from '../wfReactInterface/context';
+import {DEFAULT_SYSTEM_MESSAGE, usePlaygroundState} from './usePlaygroundState';
+
+export type PlaygroundPageProps = {
+ entity: string;
+ project: string;
+ callId: string;
+};
+
+export const PlaygroundPage = (props: PlaygroundPageProps) => {
+ return (
+ ,
+ },
+ ]}
+ />
+ );
+};
+
+export const PlaygroundPageInner = (props: PlaygroundPageProps) => {
+ const {
+ setPlaygroundStateFromTraceCall,
+ playgroundStates,
+ setPlaygroundStateField,
+ } = usePlaygroundState();
+
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ const [settingsTab, setSettingsTab] = useState(null);
+
+ const {useCall} = useWFHooks();
+ const call = useCall(
+ useMemo(() => {
+ return props.callId
+ ? {
+ entity: props.entity,
+ project: props.project,
+ callId: props.callId,
+ }
+ : null;
+ }, [props.entity, props.project, props.callId])
+ );
+
+ useEffect(() => {
+ if (!call.loading && call.result) {
+ if (call.result.traceCall?.inputs) {
+ setPlaygroundStateFromTraceCall(call.result.traceCall);
+ }
+ } else if (
+ playgroundStates.length === 1 &&
+ !playgroundStates[0].traceCall.project_id
+ ) {
+ setPlaygroundStateField(0, 'traceCall', {
+ inputs: {
+ messages: [DEFAULT_SYSTEM_MESSAGE],
+ },
+ project_id: `${props.entity}/${props.project}`,
+ });
+ }
+ // Only set the call the first time the page loads, and we get the call
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, [props.callId]);
+
+ return (
+
+ {call.loading ? (
+
+
+
+ ) : (
+ Playground
+ )}
+ {settingsTab !== null && Settings
}
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/index.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/index.ts
new file mode 100644
index 00000000000..34cee0ea092
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/index.ts
@@ -0,0 +1 @@
+export * from './PlaygroundPage';
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts
new file mode 100644
index 00000000000..1334e32a454
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/llmMaxTokens.ts
@@ -0,0 +1,123 @@
+// This is a mapping of LLM names to their max token limits.
+// Directly from the pycache model_providers.json in trace_server.
+// Some were removed because they are not supported when Josiah tried on Oct 30, 2024.
+export const LLM_MAX_TOKENS = {
+ 'gpt-4o-mini': {max_tokens: 16384, supports_function_calling: true},
+ 'claude-3-5-sonnet-20240620': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'claude-3-5-sonnet-20241022': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'claude-3-haiku-20240307': {
+ max_tokens: 4096,
+ supports_function_calling: true,
+ },
+ 'claude-3-opus-20240229': {max_tokens: 4096, supports_function_calling: true},
+ 'claude-3-sonnet-20240229': {
+ max_tokens: 4096,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-flash-001': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-flash-002': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-flash-8b-exp-0827': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-flash-8b-exp-0924': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-flash-exp-0827': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-flash-latest': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-flash': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-pro-001': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-pro-002': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-pro-exp-0801': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-pro-exp-0827': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-pro-latest': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'gemini/gemini-1.5-pro': {max_tokens: 8192, supports_function_calling: true},
+ 'gemini/gemini-pro': {max_tokens: 8192, supports_function_calling: true},
+ 'gpt-3.5-turbo-0125': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-3.5-turbo-1106': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-3.5-turbo-16k': {max_tokens: 4096, supports_function_calling: false},
+ 'gpt-3.5-turbo': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-4-0125-preview': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-4-0314': {max_tokens: 4096, supports_function_calling: false},
+ 'gpt-4-0613': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-4-1106-preview': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-4-32k-0314': {max_tokens: 4096, supports_function_calling: false},
+ 'gpt-4-turbo-2024-04-09': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-4-turbo-preview': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-4-turbo': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-4': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-4o-2024-05-13': {max_tokens: 4096, supports_function_calling: true},
+ 'gpt-4o-2024-08-06': {max_tokens: 16384, supports_function_calling: true},
+ 'gpt-4o-mini-2024-07-18': {
+ max_tokens: 16384,
+ supports_function_calling: true,
+ },
+ 'gpt-4o': {max_tokens: 4096, supports_function_calling: true},
+ 'groq/gemma-7b-it': {max_tokens: 8192, supports_function_calling: true},
+ 'groq/gemma2-9b-it': {max_tokens: 8192, supports_function_calling: true},
+ 'groq/llama-3.1-70b-versatile': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'groq/llama-3.1-8b-instant': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'groq/llama3-70b-8192': {max_tokens: 8192, supports_function_calling: true},
+ 'groq/llama3-8b-8192': {max_tokens: 8192, supports_function_calling: true},
+ 'groq/llama3-groq-70b-8192-tool-use-preview': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'groq/llama3-groq-8b-8192-tool-use-preview': {
+ max_tokens: 8192,
+ supports_function_calling: true,
+ },
+ 'groq/mixtral-8x7b-32768': {
+ max_tokens: 32768,
+ supports_function_calling: true,
+ },
+ 'o1-mini-2024-09-12': {max_tokens: 65536, supports_function_calling: true},
+ 'o1-mini': {max_tokens: 65536, supports_function_calling: true},
+ 'o1-preview-2024-09-12': {max_tokens: 32768, supports_function_calling: true},
+ 'o1-preview': {max_tokens: 32768, supports_function_calling: true},
+};
+
+export type LLMMaxTokensKey = keyof typeof LLM_MAX_TOKENS;
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts
new file mode 100644
index 00000000000..3e4f9dd30c8
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts
@@ -0,0 +1,30 @@
+import {TraceCallSchema} from '../wfReactInterface/traceServerClientTypes';
+import {LLMMaxTokensKey} from './llmMaxTokens';
+
+export enum PlaygroundResponseFormats {
+ Text = 'text',
+ JsonObject = 'json_object',
+ // Fast follow
+ // JsonSchema = 'json_schema',
+}
+
+export type PlaygroundState = {
+ traceCall: OptionalTraceCallSchema;
+ trackLLMCall: boolean;
+ loading: boolean;
+ functions: Array<{name: string; [key: string]: any}>;
+ responseFormat: PlaygroundResponseFormats;
+ temperature: number;
+ maxTokens: number;
+ stopSequences: string[];
+ topP: number;
+ frequencyPenalty: number;
+ presencePenalty: number;
+ // nTimes: number;
+ maxTokensLimit: number;
+ model: LLMMaxTokensKey;
+};
+
+export type PlaygroundStateKey = keyof PlaygroundState;
+
+export type OptionalTraceCallSchema = Partial;
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts
new file mode 100644
index 00000000000..f034302c468
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts
@@ -0,0 +1,147 @@
+import {SetStateAction, useCallback, useState} from 'react';
+
+import {LLMMaxTokensKey} from './llmMaxTokens';
+import {
+ OptionalTraceCallSchema,
+ PlaygroundResponseFormats,
+ PlaygroundState,
+ PlaygroundStateKey,
+} from './types';
+
+export const DEFAULT_SYSTEM_MESSAGE_CONTENT =
+ 'You are an AI assistant designed to assist users by providing clear, concise, and helpful responses.';
+
+export const DEFAULT_SYSTEM_MESSAGE = {
+ role: 'system',
+ content: DEFAULT_SYSTEM_MESSAGE_CONTENT,
+};
+
+const DEFAULT_PLAYGROUND_STATE = {
+ traceCall: {
+ inputs: {
+ messages: [DEFAULT_SYSTEM_MESSAGE],
+ },
+ },
+ trackLLMCall: true,
+ loading: false,
+ functions: [],
+ responseFormat: PlaygroundResponseFormats.Text,
+ temperature: 1,
+ maxTokens: 4096,
+ stopSequences: [],
+ topP: 1,
+ frequencyPenalty: 0,
+ presencePenalty: 0,
+ // nTimes: 1,
+ maxTokensLimit: 16384,
+ model: 'gpt-4o-mini' as LLMMaxTokensKey,
+};
+
+export const usePlaygroundState = () => {
+ const [playgroundStates, setPlaygroundStates] = useState([
+ DEFAULT_PLAYGROUND_STATE,
+ ]);
+
+ const setPlaygroundStateField = useCallback(
+ (
+ index: number,
+ field: PlaygroundStateKey,
+ value: SetStateAction
+ ) => {
+ setPlaygroundStates(prevStates =>
+ prevStates.map((state, i) =>
+ i === index
+ ? {
+ ...state,
+ [field]:
+ typeof value === 'function'
+ ? (value as SetStateAction)(state[field])
+ : value,
+ }
+ : state
+ )
+ );
+ },
+ []
+ );
+
+ // Takes in a function input and sets the state accordingly
+ const setPlaygroundStateFromTraceCall = useCallback(
+ (traceCall: OptionalTraceCallSchema) => {
+ const inputs = traceCall.inputs;
+ // https://docs.litellm.ai/docs/completion/input
+ // pulled from litellm
+ setPlaygroundStates(prevState => {
+ const newState = {...prevState[0]};
+
+ newState.traceCall = traceCall;
+
+ if (!inputs) {
+ return [newState];
+ }
+
+ if (inputs.tools) {
+ newState.functions = [];
+ for (const tool of inputs.tools) {
+ if (tool.type === 'function') {
+ newState.functions = [...newState.functions, tool.function];
+ }
+ }
+ }
+ // if (inputs.n) {
+ // newState.nTimes = parseInt(inputs.n, 10);
+ // }
+ if (inputs.temperature) {
+ newState.temperature = parseFloat(inputs.temperature);
+ }
+ if (inputs.response_format) {
+ newState.responseFormat = inputs.response_format.type;
+ }
+ if (inputs.top_p) {
+ newState.topP = parseFloat(inputs.top_p);
+ }
+ if (inputs.frequency_penalty) {
+ newState.frequencyPenalty = parseFloat(inputs.frequency_penalty);
+ }
+ if (inputs.presence_penalty) {
+ newState.presencePenalty = parseFloat(inputs.presence_penalty);
+ }
+ return [newState];
+ });
+ },
+ []
+ );
+
+ return {
+ playgroundStates,
+ setPlaygroundStates,
+ setPlaygroundStateField,
+ setPlaygroundStateFromTraceCall,
+ };
+};
+
+export const getInputFromPlaygroundState = (state: PlaygroundState) => {
+ const tools = state.functions.map(func => ({
+ type: 'function',
+ function: func,
+ }));
+ return {
+ // Adding this to prevent the exact same call from not getting run
+ // eg running the same call in parallel
+ key: Math.random() * 1000,
+
+ messages: state.traceCall?.inputs?.messages,
+ model: state.model,
+ temperature: state.temperature,
+ max_tokens: state.maxTokens,
+ stop: state.stopSequences,
+ top_p: state.topP,
+ frequency_penalty: state.frequencyPenalty,
+ presence_penalty: state.presencePenalty,
+ // n: state.nTimes,
+ response_format: {
+ type: state.responseFormat,
+ },
+ tools: tools.length > 0 ? tools : undefined,
+ };
+};
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py
index 8547c97648f..34a9cbfaa16 100644
--- a/weave/trace_server/clickhouse_trace_server_batched.py
+++ b/weave/trace_server/clickhouse_trace_server_batched.py
@@ -31,12 +31,7 @@
from collections import defaultdict
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
-from typing import (
- Any,
- Optional,
- Union,
- cast,
-)
+from typing import Any, Optional, Union, cast
from zoneinfo import ZoneInfo
import clickhouse_connect
@@ -50,9 +45,7 @@
from weave.trace_server import refs_internal as ri
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.actions_worker.dispatcher import execute_batch
-from weave.trace_server.base_object_class_util import (
- process_incoming_object,
-)
+from weave.trace_server.base_object_class_util import process_incoming_object
from weave.trace_server.calls_query_builder import (
CallsQuery,
HardCodedFilter,
@@ -1407,6 +1400,32 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes:
self.ch_client.query(prepared.sql, prepared.parameters)
return tsi.FeedbackPurgeRes()
+ def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes:
+ # To replace, first purge, then if successful, create.
+ query = tsi.Query(
+ **{
+ "$expr": {
+ "$eq": [
+ {"$getField": "id"},
+ {"$literal": req.feedback_id},
+ ],
+ }
+ }
+ )
+ purge_request = tsi.FeedbackPurgeReq(
+ project_id=req.project_id,
+ query=query,
+ )
+ self.feedback_purge(purge_request)
+ create_req = tsi.FeedbackCreateReq(**req.model_dump(exclude={"feedback_id"}))
+ create_result = self.feedback_create(create_req)
+ return tsi.FeedbackReplaceRes(
+ id=create_result.id,
+ created_at=create_result.created_at,
+ wb_user_id=create_result.wb_user_id,
+ payload=create_result.payload,
+ )
+
def actions_execute_batch(
self, req: tsi.ActionsExecuteBatchReq
) -> tsi.ActionsExecuteBatchRes:
diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py
index 5878bdad0ee..1df739adbcd 100644
--- a/weave/trace_server/external_to_internal_trace_server_adapter.py
+++ b/weave/trace_server/external_to_internal_trace_server_adapter.py
@@ -327,6 +327,18 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes:
req.project_id = self._idc.ext_to_int_project_id(req.project_id)
return self._ref_apply(self._internal_trace_server.feedback_purge, req)
+ def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes:
+ req.project_id = self._idc.ext_to_int_project_id(req.project_id)
+ original_user_id = req.wb_user_id
+ if original_user_id is None:
+ raise ValueError("wb_user_id cannot be None")
+ req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id)
+ res = self._ref_apply(self._internal_trace_server.feedback_replace, req)
+ if res.wb_user_id != req.wb_user_id:
+ raise ValueError("Internal Error - User Mismatch")
+ res.wb_user_id = original_user_id
+ return res
+
def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes:
req.project_id = self._idc.ext_to_int_project_id(req.project_id)
return self._ref_apply(self._internal_trace_server.cost_create, req)
diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py
index f500b0fe9bd..0ff9afa325d 100644
--- a/weave/trace_server/sqlite_trace_server.py
+++ b/weave/trace_server/sqlite_trace_server.py
@@ -14,9 +14,7 @@
from weave.trace_server import refs_internal as ri
from weave.trace_server import trace_server_interface as tsi
-from weave.trace_server.base_object_class_util import (
- process_incoming_object,
-)
+from weave.trace_server.base_object_class_util import process_incoming_object
from weave.trace_server.emoji_util import detone_emojis
from weave.trace_server.errors import InvalidRequest
from weave.trace_server.feedback import (
@@ -1054,6 +1052,29 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes:
conn.commit()
return tsi.FeedbackPurgeRes()
+ def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes:
+ purge_request = tsi.FeedbackPurgeReq(
+ project_id=req.project_id,
+ query={
+ "$expr": {
+ "$eq": [
+ {"$getField": "id"},
+ {"$literal": req.feedback_id},
+ ],
+ }
+ },
+ )
+ self.feedback_purge(purge_request)
+ create_req = tsi.FeedbackCreateReq(**req.model_dump(exclude={"feedback_id"}))
+ create_result = self.feedback_create(create_req)
+
+ return tsi.FeedbackReplaceRes(
+ id=create_result.id,
+ created_at=create_result.created_at,
+ wb_user_id=create_result.wb_user_id,
+ payload=create_result.payload,
+ )
+
def actions_execute_batch(
self, req: tsi.ActionsExecuteBatchReq
) -> tsi.ActionsExecuteBatchRes:
diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py
index b3fb3b58a7f..77ef6198ecf 100644
--- a/weave/trace_server/trace_server_interface.py
+++ b/weave/trace_server/trace_server_interface.py
@@ -743,6 +743,14 @@ class FeedbackPurgeRes(BaseModel):
pass
+class FeedbackReplaceReq(FeedbackCreateReq):
+ feedback_id: str
+
+
+class FeedbackReplaceRes(FeedbackCreateRes):
+ pass
+
+
class FileCreateReq(BaseModel):
project_id: str
name: str
@@ -900,6 +908,7 @@ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ...
def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ...
def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ...
def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ...
+ def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: ...
# Action API
def actions_execute_batch(
diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py
index f929c2065a5..d0b2fb06c36 100644
--- a/weave/trace_server_bindings/remote_http_trace_server.py
+++ b/weave/trace_server_bindings/remote_http_trace_server.py
@@ -527,6 +527,13 @@ def feedback_purge(
"/feedback/purge", req, tsi.FeedbackPurgeReq, tsi.FeedbackPurgeRes
)
+ def feedback_replace(
+ self, req: Union[tsi.FeedbackReplaceReq, dict[str, Any]]
+ ) -> tsi.FeedbackReplaceRes:
+ return self._generic_request(
+ "/feedback/replace", req, tsi.FeedbackReplaceReq, tsi.FeedbackReplaceRes
+ )
+
def actions_execute_batch(
self, req: Union[tsi.ActionsExecuteBatchReq, dict[str, Any]]
) -> tsi.ActionsExecuteBatchRes:
diff --git a/weave_query/weave_query/wandb_interface/wandb_stream_table.py b/weave_query/weave_query/wandb_interface/wandb_stream_table.py
index eda69512f2b..d89616e4b9c 100644
--- a/weave_query/weave_query/wandb_interface/wandb_stream_table.py
+++ b/weave_query/weave_query/wandb_interface/wandb_stream_table.py
@@ -13,7 +13,6 @@
from wandb.sdk.lib.ipython import _get_python_type
from wandb.sdk.lib.paths import LogicalPath
-from wandb.sdk.lib.printer import get_printer
from weave_query import weave_types
from weave_query import (
@@ -186,11 +185,6 @@ def _ensure_remote_initialized(self) -> StreamTableType:
)
self._artifact = WandbLiveRunFiles(name=uri.name, uri=uri)
self._artifact.set_file_pusher(self._lite_run.pusher)
- if print_url:
- base_url = environment.weave_server_url()
- url = f"{base_url}/browse/wandb/{self._entity_name}/{self._project_name}/table/{self._table_name}"
- printer = get_printer(_get_python_type() != "python")
- # printer.display(f'{printer.emoji("star")} View data at {printer.link(url)}')
return self._weave_stream_table
def log(self, row_or_rows: ROW_TYPE) -> None: