-
{message.role}
- {message.content && (
-
- {_.isString(message.content) ? (
-
+ {!isNested && (
+
+ {!isUser && !isTool && (
+
- ) : (
- message.content.map((p, i) => (
-
- ))
)}
)}
- {message.tool_calls && }
+
+ setIsHovering(true)}
+ onMouseLeave={() => setIsHovering(false)}>
+
+ {isSystemPrompt && (
+
+ )}
+
+ {isTool && (
+
+ )}
+
+
+ {editorHeight ? (
+
+ ) : (
+ <>
+ {hasContent && (
+
+ {_.isString(message.content) ? (
+
+ ) : (
+ message.content!.map((p, i) => (
+
+ ))
+ )}
+
+ )}
+ {hasToolCalls && (
+
+
+
+ )}
+ >
+ )}
+
+
+ {isOverflowing && !editorHeight && (
+
+ )}
+
+ {isPlayground && isHovering && !editorHeight && (
+
+
+
+
+
+
+
+ )}
+
+
);
};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanelPart.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanelPart.tsx
index b4d1904473a..6540d2e6ae5 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanelPart.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanelPart.tsx
@@ -19,10 +19,11 @@ export const MessagePanelPart = ({
const reformat = JSON.stringify(JSON.parse(value), null, 2);
return
;
}
+ // Markdown is slowing down chat view, maybe bring this back if users complain
if (isLikelyMarkdown(value)) {
return
;
}
- return
{value};
+ return
{value}
;
}
if (value.type === 'text' && 'text' in value) {
return
{value.text}
;
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ShowMoreButton.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ShowMoreButton.tsx
new file mode 100644
index 00000000000..24184c226a8
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ShowMoreButton.tsx
@@ -0,0 +1,34 @@
+import {Button} from '@wandb/weave/components/Button';
+import classNames from 'classnames';
+import React, {Dispatch, SetStateAction} from 'react';
+
+type ShowMoreButtonProps = {
+ isShowingMore: boolean;
+ setIsShowingMore: Dispatch
>;
+ isUser?: boolean;
+};
+export const ShowMoreButton = ({
+ isShowingMore,
+ setIsShowingMore,
+ isUser,
+}: ShowMoreButtonProps) => {
+ return (
+
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCallPanel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCallPanel.tsx
new file mode 100644
index 00000000000..164da8c5c78
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCallPanel.tsx
@@ -0,0 +1,117 @@
+import {Button} from '@wandb/weave/components/Button';
+import {Callout} from '@wandb/weave/components/Callout';
+import classNames from 'classnames';
+import _ from 'lodash';
+import React, {useEffect, useRef, useState} from 'react';
+
+import {MessagePanelPart} from './MessagePanelPart';
+import {ShowMoreButton} from './ShowMoreButton';
+import {ToolCalls} from './ToolCalls';
+import {Message} from './types';
+
+type ToolCallProps = {
+ message: Message;
+ isStructuredOutput?: boolean;
+};
+
+export const ToolCallPanel = ({message, isStructuredOutput}: ToolCallProps) => {
+ const [isShowingMore, setIsShowingMore] = useState(false);
+ const [isOverflowing, setIsOverflowing] = useState(false);
+ const contentRef = useRef(null);
+ useEffect(() => {
+ if (contentRef.current) {
+ setIsOverflowing(contentRef.current.scrollHeight > 400);
+ }
+ }, [message.content]);
+
+ const isUser = message.role === 'user';
+ const isSystemPrompt = message.role === 'system';
+ const isTool = message.role === 'tool';
+ const hasToolCalls = message.tool_calls;
+
+ const bg = isUser ? 'bg-cactus-300/[0.24]' : 'bg-moon-50';
+ const justification = isUser ? 'ml-auto' : 'mr-auto';
+ const maxHeight = isShowingMore ? 'max-h-full' : 'max-h-[400px]';
+ const maxWidth = isSystemPrompt ? 'w-full' : 'max-w-3xl';
+ const toolWidth = isTool || hasToolCalls ? 'w-3/4' : '';
+
+ const capitalizedRole =
+ message.role.charAt(0).toUpperCase() + message.role.slice(1);
+
+ return (
+
+
+ {!isUser && !isTool && (
+
+ )}
+
+
+
+ {hasToolCalls && (
+
+ )}
+
+ {isSystemPrompt && (
+
+ )}
+
+
+ {message.content && (
+
+ {_.isString(message.content) ? (
+
+ ) : (
+ message.content.map((p, i) => (
+
+ ))
+ )}
+
+ )}
+ {message.tool_calls && (
+
+
+
+ )}
+
+
+ {isOverflowing && (
+
+ )}
+
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCalls.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCalls.tsx
index 7a624fa77ab..f8e38e031cb 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCalls.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCalls.tsx
@@ -1,7 +1,9 @@
+import {Button} from '@wandb/weave/components/Button';
import Prism from 'prismjs';
-import React, {useEffect, useRef} from 'react';
+import React, {useEffect, useRef, useState} from 'react';
import {Alert} from '../../../../../Alert';
+import {MessagePanel} from './MessagePanel';
import {ToolCall} from './types';
type OneToolCallProps = {
@@ -9,6 +11,19 @@ type OneToolCallProps = {
};
const OneToolCall = ({toolCall}: OneToolCallProps) => {
+ const [isCopying, setIsCopying] = useState(false);
+
+ const handleCopyText = (text: string) => {
+ try {
+ setIsCopying(true);
+ navigator.clipboard.writeText(text);
+ } finally {
+ setTimeout(() => {
+ setIsCopying(false);
+ }, 2000);
+ }
+ };
+
const ref = useRef(null);
useEffect(() => {
if (ref.current) {
@@ -16,6 +31,10 @@ const OneToolCall = ({toolCall}: OneToolCallProps) => {
}
});
+ if (!toolCall.function) {
+ return Null tool call
;
+ }
+
const {function: toolCallFunction} = toolCall;
const {name, arguments: args} = toolCallFunction;
let parsedArgs: any = null;
@@ -27,17 +46,54 @@ const OneToolCall = ({toolCall}: OneToolCallProps) => {
}
} catch (e) {
// The model does not always generate valid JSON
- return Invalid JSON: {args};
+ return (
+
+ );
}
+ const copyText = `${name}(${parsedArgs})`;
return (
-
- {name}(
-
- {parsedArgs}
-
- )
-
+
+
+
+
Function
+
+
+
+
+
+ {name}(
+
+ {parsedArgs}
+
+ )
+
+
+
+
+
+
+
);
};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts
index 5e4429e7611..bb63cca2f6f 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts
@@ -8,7 +8,7 @@ import {
TraceCallSchema,
} from '../wfReactInterface/traceServerClientTypes';
import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface';
-import {ChatCompletion, ChatRequest, Choice} from './types';
+import {Chat, ChatCompletion, ChatRequest, Choice} from './types';
export enum ChatFormat {
None = 'None',
@@ -59,6 +59,10 @@ export const isToolCall = (toolCall: any): boolean => {
};
export const isToolCalls = (toolCalls: any): boolean => {
+ if (toolCalls === null) {
+ return true;
+ }
+
if (!_.isArray(toolCalls)) {
return false;
}
@@ -242,10 +246,12 @@ export const isTraceCallChatFormatGemini = (call: TraceCallSchema): boolean => {
export const isTraceCallChatFormatOpenAI = (call: TraceCallSchema): boolean => {
if (!('messages' in call.inputs)) {
+ console.log('no messages');
return false;
}
const {messages} = call.inputs;
if (!_.isArray(messages)) {
+ console.log('not array');
return false;
}
return messages.every(isMessage);
@@ -361,10 +367,7 @@ export const useCallAsChat = (
call: TraceCallSchema
): {
loading: boolean;
- isStructuredOutput: boolean;
- request: ChatRequest;
- result: ChatCompletion | null;
-} => {
+} & Chat => {
// Traverse the data and find all ref URIs.
const refs = getRefs(call);
const {useRefsData} = useWFHooks();
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts
index 3968fdafa2d..cc7cb290865 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts
@@ -32,12 +32,14 @@ export type ToolCall = {
// Validate the arguments in your code before calling your function.
arguments: string;
};
+ response?: Message;
};
export type Message = {
role: string;
content?: string | MessagePart[];
tool_calls?: ToolCall[];
+ tool_call_id?: string;
};
export type Messages = Message[];
@@ -105,7 +107,7 @@ export type ChatCompletion = {
export type Chat = {
// TODO: Maybe optional information linking back to Call?
isStructuredOutput: boolean;
- request: ChatRequest;
+ request: ChatRequest | null;
result: ChatCompletion | null;
};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/LLMDropdown.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/LLMDropdown.tsx
new file mode 100644
index 00000000000..1cf2d0cb339
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/LLMDropdown.tsx
@@ -0,0 +1,73 @@
+import {MenuItem, TextField} from '@mui/material';
+import {MOON_250, TEAL_400} from '@wandb/weave/common/css/color.styles';
+import {Icon} from '@wandb/weave/components/Icon';
+import React from 'react';
+
+import {llmMaxTokens} from './llm_max_tokens';
+
+interface LLMDropdownProps {
+ value: string;
+ onChange: (value: string, maxTokens: number) => void;
+}
+
+export const LLMDropdown: React.FC = ({value, onChange}) => {
+ const handleChange = (
+ event: React.SyntheticEvent,
+ newValue: string | null
+ ) => {
+ if (newValue) {
+ const maxTokens =
+ llmMaxTokens[newValue as keyof typeof llmMaxTokens] || 0;
+ onChange(newValue, maxTokens);
+ }
+ };
+
+ return (
+ handleChange(e, e.target.value)}
+ size="small"
+ slotProps={{
+ select: {
+ IconComponent: props => ,
+ },
+ }}
+ sx={{
+ width: '100%',
+ minWidth: '100px',
+ height: '32px',
+ padding: 0,
+ fontFamily: 'Source Sans Pro',
+ fontSize: '16px',
+ '& .MuiSelect-select': {
+ fontFamily: 'Source Sans Pro',
+ fontSize: '16px',
+ height: '16px',
+ },
+ '& .MuiMenuItem-root': {
+ fontFamily: 'Source Sans Pro',
+ fontSize: '16px',
+ },
+ '& .MuiInputBase-root': {
+ height: '32px',
+ paddingY: '4px',
+ },
+ '& .MuiOutlinedInput-notchedOutline': {
+ border: `1px solid ${MOON_250}`,
+ },
+ '& .Mui-focused .MuiOutlinedInput-notchedOutline': {
+ border: `1px solid ${MOON_250}`,
+ },
+ '& .MuiOutlinedInput-root:hover .MuiOutlinedInput-notchedOutline': {
+ border: `1px solid ${TEAL_400}`,
+ },
+ }}>
+ {Object.keys(llmMaxTokens).map(llmWithToken => (
+
+ ))}
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundCallStats.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundCallStats.tsx
new file mode 100644
index 00000000000..274aa2a57cd
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundCallStats.tsx
@@ -0,0 +1,72 @@
+import {Button} from '@wandb/weave/components/Button';
+import {Tailwind} from '@wandb/weave/components/Tailwind';
+import {makeRefCall} from '@wandb/weave/util/refs';
+import React from 'react';
+import {useHistory} from 'react-router-dom';
+
+import {useWeaveflowRouteContext} from '../../../context';
+import {Reactions} from '../../../feedback/Reactions';
+import {TraceCallSchema} from '../../wfReactInterface/traceServerClientTypes';
+
+export const PlaygroundCallStats = ({call}: {call: TraceCallSchema}) => {
+ let totalTokens = 0;
+ if (call?.summary?.usage) {
+ for (const key of Object.keys(call.summary.usage)) {
+ totalTokens +=
+ call.summary.usage[key].prompt_tokens ||
+ call.summary.usage[key].input_tokens ||
+ 0;
+ totalTokens +=
+ call.summary.usage[key].completion_tokens ||
+ call.summary.usage[key].output_tokens ||
+ 0;
+ }
+ }
+
+ const [entityName, projectName] = call?.project_id?.split('/') || [];
+ const callId = call?.id || '';
+ const latency = call?.summary?.weave?.latency_ms;
+ const {peekingRouter} = useWeaveflowRouteContext();
+ const history = useHistory();
+
+ if (!callId) {
+ return null;
+ }
+
+ const weaveRef = makeRefCall(entityName, projectName, callId);
+ const callLink = peekingRouter.callUIUrl(
+ entityName,
+ projectName,
+ '',
+ callId,
+ null,
+ false
+ );
+
+ return (
+
+
+ Latency: {latency}ms
+ •
+ {/* Finish reason: {choice.finish_reason}
+ • */}
+ {totalTokens} tokens
+ •
+ {callLink && (
+
+ )}
+ {/* • */}
+ {weaveRef && }
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx
new file mode 100644
index 00000000000..14c3edf6c74
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx
@@ -0,0 +1,388 @@
+import {Box, Divider} from '@mui/material';
+import {MOON_200, MOON_250} from '@wandb/weave/common/css/color.styles';
+import {Tailwind} from '@wandb/weave/components/Tailwind';
+import React, {useState} from 'react';
+
+import {CallChat} from '../../ChatView/CallChat';
+import {TraceCallSchema} from '../../wfReactInterface/traceServerClientTypes';
+import {CallSchema} from '../../wfReactInterface/wfDataModelHooksInterface';
+import {PlaygroundState} from '../PlaygroundSettings/PlaygroundSettings';
+import {PlaygroundCallStats} from './PlaygroundCallStats';
+import {PlaygroundChatInput} from './PlaygroundChatInput';
+import {PlaygroundChatTopBar} from './PlaygroundChatTopBar';
+import {useGetTraceServerClientContext} from '../../wfReactInterface/traceServerClientContext';
+import {toast} from '@wandb/weave/common/components/elements/Toast';
+import {Link} from 'react-router-dom';
+export type OptionalTraceCallSchema = Partial;
+export type OptionalCallSchema = Partial;
+
+export type PlaygroundChatProps = {
+ setCalls: (calls: OptionalCallSchema[]) => void;
+ calls: OptionalCallSchema[];
+ entity: string;
+ project: string;
+ setPlaygroundStates: (states: PlaygroundState[]) => void;
+ playgroundStates: PlaygroundState[];
+ setPlaygroundStateField: (
+ index: number,
+ key: keyof PlaygroundState,
+ value: PlaygroundState[keyof PlaygroundState]
+ ) => void;
+ deleteMessage: (callIndex: number, messageIndex: number) => void;
+ editMessage: (
+ callIndex: number,
+ messageIndex: number,
+ newMessage: any
+ ) => void;
+ addMessage: (callIndex: number, newMessage: any) => void;
+ editChoice: (callIndex: number, choiceIndex: number, newChoice: any) => void;
+ deleteChoice: (callIndex: number, choiceIndex: number) => void;
+ setSettingsTab: (callIndex: number | null) => void;
+ settingsTab: number | null;
+};
+
+export const PlaygroundChat = ({
+ setCalls,
+ calls,
+ deleteMessage,
+ editMessage,
+ addMessage,
+ editChoice,
+ deleteChoice,
+ entity,
+ project,
+ setPlaygroundStates,
+ playgroundStates,
+ setPlaygroundStateField,
+ setSettingsTab,
+ settingsTab,
+}: PlaygroundChatProps) => {
+ const [chatText, setChatText] = useState('');
+ const [isLoading, setIsLoading] = useState(false);
+ const getTsClient = useGetTraceServerClientContext();
+
+ const handleAddMessage = (role: 'assistant' | 'user', text: string) => {
+ for (let i = 0; i < calls.length; i++) {
+ addMessage(i, {role, content: text});
+ }
+ setChatText('');
+ };
+
+ const handleUpdateCallWithResponse = (updatedCall: any, response: any) => {
+ return {
+ ...updatedCall,
+ traceCall: {
+ ...updatedCall.traceCall,
+ id: response.weave_call_id ?? updatedCall.traceCall?.id ?? '',
+ output: response.response,
+ },
+ };
+ };
+
+ const handleUpdateCallsWithResponses = (
+ updatedCalls: any[],
+ responses: any[]
+ ) => {
+ const newCalls = updatedCalls.map((call, index) =>
+ handleUpdateCallWithResponse(call, responses[index])
+ );
+
+ setCalls(newCalls);
+ };
+
+ const handleMissingLLMApiKey = (responses: any[] | any) => {
+ if (Array.isArray(responses)) {
+ responses.forEach((response: any) => {
+ handleMissingLLMApiKey(response);
+ });
+ } else {
+ if (responses.api_key && responses.reason) {
+ toast(
+
+
{responses.reason}
+ Please add your API key to{' '}
+
Team secrets in settings to
+ use this LLM
+
,
+ {
+ type: 'error',
+ }
+ );
+ }
+ }
+ };
+
+ const handleSend = async () => {
+ setIsLoading(true);
+ const newMessage = chatText.trim()
+ ? {role: 'user', content: chatText}
+ : undefined;
+ const updatedCalls = calls.map((call, index) => {
+ const updatedCall = JSON.parse(JSON.stringify(call));
+ if (updatedCall.traceCall?.inputs?.messages) {
+ if (updatedCall.traceCall.output?.choices) {
+ updatedCall.traceCall.output.choices.forEach((choice: any) => {
+ if (choice.message) {
+ updatedCall.traceCall.inputs.messages.push(choice.message);
+ }
+ });
+ updatedCall.traceCall.output.choices = undefined;
+ }
+ if (newMessage) {
+ updatedCall.traceCall.inputs.messages.push(newMessage);
+ }
+ }
+ return updatedCall;
+ });
+
+ setCalls(updatedCalls);
+ setChatText('');
+
+ try {
+ const responses = await Promise.all(
+ updatedCalls.map((call, index) => {
+ const tools = playgroundStates[index].functions.map(func => ({
+ type: 'function',
+ function: func,
+ }));
+ const inputs = {
+ messages: call.traceCall?.inputs?.messages || [],
+ model: playgroundStates[index].model,
+ temperature: playgroundStates[index].temperature,
+ max_tokens: playgroundStates[index].maxTokens,
+ stop: playgroundStates[index].stopSequences,
+ top_p: playgroundStates[index].topP,
+ frequency_penalty: playgroundStates[index].frequencyPenalty,
+ presence_penalty: playgroundStates[index].presencePenalty,
+ n: playgroundStates[index].nTimes,
+ response_format: {
+ type: playgroundStates[index].responseFormat,
+ },
+ tools: tools.length > 0 ? tools : undefined,
+ };
+ return getTsClient().completionsCreate({
+ project_id: `${entity}/${project}`,
+ inputs,
+ });
+ })
+ );
+
+ handleMissingLLMApiKey(responses);
+ handleUpdateCallsWithResponses(updatedCalls, responses);
+ } catch (error) {
+ console.error('Error processing completion:', error);
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
+ const handleRetry = async (
+ callIndex: number,
+ messageIndex: number,
+ isChoice?: boolean
+ ) => {
+ setIsLoading(true);
+ try {
+ const updatedCalls = calls.map((call, index) => {
+ if (index === callIndex) {
+ const updatedCall = JSON.parse(JSON.stringify(call));
+ if (updatedCall.traceCall?.inputs?.messages) {
+ if (isChoice) {
+ // If it's a choice, add it to the message list
+ const choiceMessage =
+ updatedCall.traceCall.output?.choices?.[messageIndex]?.message;
+ if (choiceMessage) {
+ updatedCall.traceCall.inputs.messages.push(choiceMessage);
+ }
+ updatedCall.traceCall.output = undefined; // Clear previous output
+ } else {
+ // If it's a regular message, truncate the list
+ updatedCall.traceCall.inputs.messages =
+ updatedCall.traceCall.inputs.messages.slice(
+ 0,
+ messageIndex + 1
+ );
+ updatedCall.traceCall.output = undefined; // Clear previous output
+ }
+ }
+ return updatedCall;
+ }
+ return call;
+ });
+
+ // Update the calls state
+ setCalls(updatedCalls);
+
+ const messagesToSend =
+ updatedCalls[callIndex].traceCall?.inputs?.messages || [];
+
+ const tools = playgroundStates[callIndex].functions.map(func => ({
+ type: 'function',
+ function: func,
+ }));
+ const inputs = {
+ messages: messagesToSend,
+ model: playgroundStates[callIndex].model,
+ temperature: playgroundStates[callIndex].temperature,
+ max_tokens: playgroundStates[callIndex].maxTokens,
+ stop: playgroundStates[callIndex].stopSequences,
+ top_p: playgroundStates[callIndex].topP,
+ frequency_penalty: playgroundStates[callIndex].frequencyPenalty,
+ presence_penalty: playgroundStates[callIndex].presencePenalty,
+ n: playgroundStates[callIndex].nTimes,
+ response_format: {
+ type: playgroundStates[callIndex].responseFormat,
+ },
+ tools: tools.length > 0 ? tools : undefined,
+ };
+
+ const response = await getTsClient().completionsCreate({
+ project_id: `${entity}/${project}`,
+ inputs,
+ });
+
+ handleMissingLLMApiKey(response);
+
+ // Update the call with the new response
+ const finalCalls = updatedCalls.map((call, index) => {
+ if (index === callIndex) {
+ return handleUpdateCallWithResponse(call, response);
+ }
+ return call;
+ });
+
+ setCalls(finalCalls);
+ } catch (error) {
+ console.error('Error retrying call:', error);
+ // Handle error (e.g., show an error message to the user)
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
+ return (
+
+
+ {calls.map((call, idx) => (
+
+ {idx > 0 && (
+
+ )}
+
+
+
+
+
+
+
+ {call?.traceCall && (
+
+ deleteMessage(idx, messageIndex)
+ }
+ editMessage={(messageIndex, newMessage) =>
+ editMessage(idx, messageIndex, newMessage)
+ }
+ deleteChoice={choiceIndex =>
+ deleteChoice(idx, choiceIndex)
+ }
+ addMessage={newMessage => addMessage(idx, newMessage)}
+ editChoice={(choiceIndex, newChoice) =>
+ editChoice(idx, choiceIndex, newChoice)
+ }
+ retry={(messageIndex: number, isChoice?: boolean) =>
+ handleRetry(idx, messageIndex, isChoice)
+ }
+ />
+ )}
+
+
+
+
+ {call?.traceCall && (
+
+ )}
+
+
+
+ ))}
+
+
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatInput.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatInput.tsx
new file mode 100644
index 00000000000..5e0c5dc067f
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatInput.tsx
@@ -0,0 +1,119 @@
+import {Box, Divider} from '@mui/material';
+import {
+ MOON_250,
+ MOON_500,
+ TEAL_500,
+} from '@wandb/weave/common/css/color.styles';
+import {Button} from '@wandb/weave/components/Button';
+import React, {useState} from 'react';
+
+import {TextArea} from '../Textarea';
+
+type PlaygroundChatInputProps = {
+ chatText: string;
+ setChatText: (text: string) => void;
+ isLoading: boolean;
+ onSend: () => void;
+ onAdd: (role: 'assistant' | 'user', text: string) => void;
+};
+
+export const PlaygroundChatInput: React.FC = ({
+ chatText,
+ setChatText,
+ isLoading,
+ onSend,
+ onAdd,
+}) => {
+ const [addMessageRole, setAddMessageRole] = useState<'assistant' | 'user'>(
+ 'user'
+ );
+
+ const handleKeyDown = (event: React.KeyboardEvent) => {
+ if (event.key === 'Enter' && (event.metaKey || event.ctrlKey)) {
+ event.preventDefault(); // Prevent default to avoid newline in textarea
+ onSend();
+ }
+ };
+
+ return (
+
+
+ Press CMD + Enter to send
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx
new file mode 100644
index 00000000000..06d6dcb7229
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx
@@ -0,0 +1,176 @@
+import {Box} from '@mui/material';
+import {Button} from '@wandb/weave/components/Button';
+import {Tag} from '@wandb/weave/components/Tag';
+import React from 'react';
+import {useHistory} from 'react-router-dom';
+
+import {CopyableId} from '../../common/Id';
+import {PlaygroundState} from '../PlaygroundSettings/PlaygroundSettings';
+import {LLMDropdown} from './LLMDropdown';
+import {OptionalCallSchema} from './PlaygroundChat';
+
+type PlaygroundChatTopBarProps = {
+ idx: number;
+ settingsTab: number | null;
+ setSettingsTab: (tab: number | null) => void;
+ setPlaygroundStateField: (
+ index: number,
+ field: keyof PlaygroundState,
+ value: any
+ ) => void;
+ setCalls: (calls: OptionalCallSchema[]) => void;
+ calls: OptionalCallSchema[];
+ entity: string;
+ project: string;
+ playgroundStates: PlaygroundState[];
+ setPlaygroundStates: (playgroundStates: PlaygroundState[]) => void;
+};
+
+export const PlaygroundChatTopBar: React.FC = ({
+ idx,
+ settingsTab,
+ setSettingsTab,
+ setPlaygroundStateField,
+ setCalls,
+ calls,
+ entity,
+ project,
+ playgroundStates,
+ setPlaygroundStates,
+}) => {
+ const history = useHistory();
+ const handleModelChange = (
+ index: number,
+ model: string,
+ maxTokens: number
+ ) => {
+ setPlaygroundStateField(index, 'model', model);
+ setPlaygroundStateField(index, 'maxTokensLimit', maxTokens);
+ setPlaygroundStateField(index, 'maxTokens', maxTokens / 2);
+ };
+
+ const clearCall = (index: number) => {
+ history.push(`/${entity}/${project}/weave/playground`);
+ setCalls(
+ calls.map((call, i) =>
+ i === index
+ ? {
+ entity,
+ project,
+ traceCall: {
+ project_id: project,
+ id: '', // Generate a new ID or use a placeholder
+ op_name: '',
+ trace_id: '',
+ inputs: {
+ messages: [
+ {
+ role: 'system',
+ content: 'You are a helpful assistant.',
+ },
+ ],
+ },
+ },
+ }
+ : call
+ )
+ );
+ };
+
+ const handleCompare = () => {
+ if (calls.length < 2) {
+ setCalls([calls[0], JSON.parse(JSON.stringify(calls[0]))]);
+ setPlaygroundStates([
+ ...playgroundStates,
+ JSON.parse(JSON.stringify(playgroundStates[0])),
+ ]);
+ }
+ };
+
+ return (
+
+
+ {calls.length > 1 && }
+
+ handleModelChange(idx, model, maxTokens)
+ }
+ />
+ {calls[idx].traceCall?.id && (
+
+ )}
+
+
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundContext.tsx
new file mode 100644
index 00000000000..9729956892d
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundContext.tsx
@@ -0,0 +1,15 @@
+import {createContext, useContext} from 'react';
+
+// Create a new context for the isPlayground value
+export const PlaygroundContext = createContext<{
+ isPlayground: boolean;
+ deleteMessage?: (messageIndex: number) => void;
+ editMessage?: (messageIndex: number, newMessage: any) => void;
+ deleteChoice?: (choiceIndex: number) => void;
+ editChoice?: (choiceIndex: number, newChoice: any) => void;
+ addMessage?: (newMessage: any) => void;
+ retry?: (messageIndex: number, isChoice?: boolean) => void;
+}>({isPlayground: false});
+
+// Create a custom hook to use the PlaygroundContext
+export const usePlaygroundContext = () => useContext(PlaygroundContext);
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/llm_max_tokens.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/llm_max_tokens.ts
new file mode 100644
index 00000000000..b5940b34d68
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/llm_max_tokens.ts
@@ -0,0 +1,119 @@
+// This is a mapping of LLM names to their max token limits.
+// Directly from the pycache model_providers.json in trace_server.
+// Some are commented out because they are not supported when Josiah tried on Oct 30, 2024.
+export const llmMaxTokens = {
+ 'gpt-4o-mini': 16384,
+
+ 'claude-3-5-sonnet-20240620': 8192,
+ 'claude-3-5-sonnet-20241022': 8192,
+ 'claude-3-haiku-20240307': 4096,
+ 'claude-3-opus-20240229': 4096,
+ 'claude-3-sonnet-20240229': 4096,
+
+ 'gemini/gemini-1.5-flash-001': 8192,
+ 'gemini/gemini-1.5-flash-002': 8192,
+ 'gemini/gemini-1.5-flash-8b-exp-0827': 8192,
+ 'gemini/gemini-1.5-flash-8b-exp-0924': 8192,
+ 'gemini/gemini-1.5-flash-exp-0827': 8192,
+ 'gemini/gemini-1.5-flash-latest': 8192,
+ 'gemini/gemini-1.5-flash': 8192,
+ 'gemini/gemini-1.5-pro-001': 8192,
+ 'gemini/gemini-1.5-pro-002': 8192,
+ 'gemini/gemini-1.5-pro-exp-0801': 8192,
+ 'gemini/gemini-1.5-pro-exp-0827': 8192,
+ 'gemini/gemini-1.5-pro-latest': 8192,
+ 'gemini/gemini-1.5-pro': 8192,
+ 'gemini/gemini-pro': 8192,
+
+ 'gpt-3.5-turbo-0125': 4096,
+ 'gpt-3.5-turbo-1106': 4096,
+ 'gpt-3.5-turbo-16k': 4096,
+ 'gpt-3.5-turbo': 4096,
+ 'gpt-4-0125-preview': 4096,
+ 'gpt-4-0314': 4096,
+ 'gpt-4-0613': 4096,
+ 'gpt-4-1106-preview': 4096,
+ 'gpt-4-32k-0314': 4096,
+ 'gpt-4-turbo-2024-04-09': 4096,
+ 'gpt-4-turbo-preview': 4096,
+ 'gpt-4-turbo': 4096,
+ 'gpt-4': 4096,
+ 'gpt-4o-2024-05-13': 4096,
+ 'gpt-4o-2024-08-06': 16384,
+ 'gpt-4o-mini-2024-07-18': 16384,
+ 'gpt-4o': 4096,
+
+ 'groq/gemma-7b-it': 8192,
+ 'groq/gemma2-9b-it': 8192,
+ 'groq/llama-3.1-70b-versatile': 8192,
+ 'groq/llama-3.1-8b-instant': 8192,
+ 'groq/llama3-70b-8192': 8192,
+ 'groq/llama3-8b-8192': 8192,
+ 'groq/llama3-groq-70b-8192-tool-use-preview': 8192,
+ 'groq/llama3-groq-8b-8192-tool-use-preview': 8192,
+ 'groq/mixtral-8x7b-32768': 32768,
+
+ 'o1-mini-2024-09-12': 65536,
+ 'o1-mini': 65536,
+ 'o1-preview-2024-09-12': 32768,
+ 'o1-preview': 32768,
+
+ // These were all in our model_providers.json (but dont work)
+ // This seems like a dupe of claude-3-5-sonnet-20241022.
+ // 'anthropic/claude-3-5-sonnet-20241022': 8192,
+
+ // 422 Unprocessable Entity
+ // 'claude-2.1': 8191,
+ // 'claude-2': 8191,
+ // 'claude-instant-1.2': 8191,
+ // 'claude-instant-1': 8191,
+
+ // error litellm.BadRequestError: OpenAIException - Error code: 400 - {'error': {'message': "[{'type': 'string_type', 'loc': ('body', 'stop', 'str'), 'msg': 'Input should be a valid string', 'input': []}, {'type': 'too_short', 'loc': ('body', 'stop', 'list[str]'), 'msg': 'List should have at least 1 item after validation, not 0', 'input': [], 'ctx': {'field_type': 'List', 'min_length': 1, 'actual_length': 0}}, {'type': 'too_short', 'loc': ('body', 'stop', 'list[list[int]]'), 'msg': 'List should have at least 1 item after validation, not 0', 'input': [], 'ctx': {'field_type': 'List', 'min_length': 1, 'actual_length': 0}}]", 'type': 'invalid_request_error', 'param': None, 'code': None}}
+ // 'chatgpt-4o-latest': 4096,
+ // 'gpt-4o-audio-preview-2024-10-01': 16384,
+ // 'gpt-4o-audio-preview': 16384,
+
+ // error litellm.NotFoundError: OpenAIException - Error code: 404 - {'error': {'message': 'The model `ft:gpt-3.5-turbo-0125` does not exist or you do not have access to it.', 'type': 'invalid_request_error', 'param': None, 'code': 'model_not_found'}}
+ // 'ft:gpt-3.5-turbo-0125': 4096,
+ // 'ft:gpt-3.5-turbo-0613': 4096,
+ // 'ft:gpt-3.5-turbo-1106': 4096,
+ // 'ft:gpt-3.5-turbo': 4096,
+ // 'ft:gpt-4-0613': 4096,
+ // 'ft:gpt-4o-2024-08-06': 16384,
+ // 'ft:gpt-4o-mini-2024-07-18': 16384,
+ // 'gpt-4-32k-0613': 4096,
+ // 'gpt-4-32k': 4096,
+ // 'groq/llama-3.1-405b-reasoning': 8192,
+ // 'groq/llama2-70b-4096': 4096,
+
+ // error litellm.NotFoundError: OpenAIException - Error code: 404 - {'error': {'message': 'The model `gpt-3.5-turbo-0301` has been deprecated, learn more here: https://platform.openai.com/docs/deprecations', 'type': 'invalid_request_error', 'param': None, 'code': 'model_not_found'}}
+ // 'gpt-3.5-turbo-0301': 4096,
+ // 'gpt-3.5-turbo-0613': 4096,
+ // 'gpt-3.5-turbo-16k-0613': 4096,
+ // 'gpt-4-1106-vision-preview': 4096,
+ // 'gpt-4-vision-preview': 4096,
+
+ // error litellm.NotFoundError: VertexAIException - {
+ // "error": {
+ // "code": 404,
+ // "message": "models/gemini-gemma-2-27b-it is not found for API version v1beta, or is not supported for generateContent. Call ListModels to see the list of available models and their supported methods.",
+ // "status": "NOT_FOUND"
+ // }
+ // }
+ // 'gemini/gemini-gemma-2-27b-it': 8192,
+ // 'gemini/gemini-gemma-2-9b-it': 8192,
+
+ // error litellm.NotFoundError: VertexAIException - {
+ // "error": {
+ // "code": 404,
+ // "message": "Gemini 1.0 Pro Vision has been deprecated on July 12, 2024. Consider switching to different model, for example gemini-1.5-flash.",
+ // "status": "NOT_FOUND"
+ // }
+ // }
+ // 'gemini/gemini-pro-vision': 2048,
+
+ // These are 0 tokens, idk why we would want to use them.
+ // "text-moderation-007": 0,
+ // "text-moderation-latest": 0,
+ // "text-moderation-stable": 0
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx
new file mode 100644
index 00000000000..d9a8ac7bd1c
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx
@@ -0,0 +1,344 @@
+import {Box} from '@mui/material';
+import {WeaveLoader} from '@wandb/weave/common/components/WeaveLoader';
+import React, {
+ SetStateAction,
+ useCallback,
+ useEffect,
+ useMemo,
+ useState,
+} from 'react';
+
+import {SimplePageLayout} from '../common/SimplePageLayout';
+import {useWFHooks} from '../wfReactInterface/context';
+import {
+ OptionalCallSchema,
+ PlaygroundChat,
+} from './PlaygroundChat/PlaygroundChat';
+import {
+ PlaygroundSettings,
+ PlaygroundState,
+} from './PlaygroundSettings/PlaygroundSettings';
+import {PlaygroundResponseFormats} from './PlaygroundSettings/ResponseFormatEditor';
+
+export type PlaygroundPageProps = {
+ entity: string;
+ project: string;
+ callId: string;
+};
+
+type PlaygroundStateKey = keyof PlaygroundState;
+type TraceCallOutput = {
+ choices?: any[];
+};
+
+export const PlaygroundPage = (props: PlaygroundPageProps) => {
+ return (
+ ,
+ },
+ ]}
+ />
+ );
+};
+
+export const PlaygroundPageInner = (props: PlaygroundPageProps) => {
+ const [settingsTab, setSettingsTab] = useState(null);
+ const [playgroundStates, setPlaygroundStates] = useState([
+ {
+ loading: false,
+ functions: [],
+ responseFormat: PlaygroundResponseFormats.Text,
+ temperature: 1,
+ maxTokens: 4000,
+ stopSequences: [],
+ topP: 1,
+ frequencyPenalty: 0,
+ presencePenalty: 0,
+ nTimes: 1,
+ maxTokensLimit: 16000,
+ model: 'gpt-4o-mini',
+ },
+ ]);
+
+ const setPlaygroundStateField = useCallback(
+ (
+ index: number,
+ key: PlaygroundStateKey,
+ value:
+ | PlaygroundState[PlaygroundStateKey]
+ | SetStateAction>
+ | SetStateAction
+ | SetStateAction
+ | SetStateAction
+ ) => {
+ setPlaygroundStates(prevStates =>
+ prevStates.map((state, i) =>
+ i === index
+ ? {
+ ...state,
+ [key]:
+ typeof value === 'function'
+ ? (value as SetStateAction)(state[key])
+ : value,
+ }
+ : state
+ )
+ );
+ },
+ []
+ );
+
+ const {useCall} = useWFHooks();
+ const call = useCall(
+ useMemo(() => {
+ return props.callId
+ ? {
+ entity: props.entity,
+ project: props.project,
+ callId: props.callId,
+ }
+ : null;
+ }, [props.entity, props.project, props.callId])
+ );
+
+ const [calls, setCalls] = useState([]);
+
+ const deleteMessage = (callIndex: number, messageIndex: number) => {
+ setCalls(prevCalls => {
+ const updatedCalls = [...prevCalls];
+ const newCall = clearTraceCallId(updatedCalls[callIndex]);
+ if (newCall && newCall.traceCall?.inputs?.messages) {
+ newCall.traceCall.inputs.messages =
+ newCall.traceCall.inputs.messages.filter(
+ (_: any, index: number) => index !== messageIndex
+ );
+
+ if (newCall.traceCall.inputs.messages.length === 0) {
+ newCall.traceCall.inputs.messages = [
+ {
+ role: 'system',
+ content: 'You are a helpful assistant.',
+ },
+ ];
+ }
+ }
+ return updatedCalls;
+ });
+ };
+
+ const editMessage = (
+ callIndex: number,
+ messageIndex: number,
+ newMessage: any // Replace 'any' with the appropriate type for a message
+ ) => {
+ setCalls(prevCalls => {
+ const updatedCalls = [...prevCalls];
+ const newCall = clearTraceCallId(updatedCalls[callIndex]);
+ if (newCall && newCall.traceCall?.inputs?.messages) {
+ newCall.traceCall.inputs.messages[messageIndex] = newMessage;
+ }
+ return updatedCalls;
+ });
+ };
+
+ const addMessage = (callIndex: number, newMessage: any) => {
+ setCalls(prevCalls => {
+ const updatedCalls = [...prevCalls];
+ const newCall = clearTraceCallId(updatedCalls[callIndex]);
+ if (newCall && newCall.traceCall?.inputs?.messages) {
+ if (
+ newCall.traceCall.output &&
+ (newCall.traceCall.output as TraceCallOutput).choices &&
+ Array.isArray((newCall.traceCall.output as TraceCallOutput).choices)
+ ) {
+ (newCall.traceCall.output as TraceCallOutput).choices!.forEach(
+ (choice: any) => {
+ if (choice.message) {
+ newCall.traceCall?.inputs!.messages.push(choice.message);
+ }
+ }
+ );
+ (newCall.traceCall.output as TraceCallOutput).choices = undefined;
+ }
+ newCall.traceCall.inputs.messages.push(newMessage);
+ }
+ return updatedCalls;
+ });
+ };
+
+ const deleteChoice = (callIndex: number, choiceIndex: number) => {
+ setCalls(prevCalls => {
+ const updatedCalls = [...prevCalls];
+ const newCall = clearTraceCallId(updatedCalls[callIndex]);
+ const output = newCall?.traceCall?.output as TraceCallOutput;
+ if (output && Array.isArray(output.choices)) {
+ output.choices = output.choices.filter(
+ (_, index: number) => index !== choiceIndex
+ );
+ if (newCall && newCall.traceCall) {
+ newCall.traceCall.output = output;
+ updatedCalls[callIndex] = newCall;
+ }
+ }
+ return updatedCalls;
+ });
+ };
+
+ const editChoice = (
+ callIndex: number,
+ choiceIndex: number,
+ newChoice: any
+ ) => {
+ setCalls(prevCalls => {
+ const updatedCalls = [...prevCalls];
+ const newCall = clearTraceCallId(updatedCalls[callIndex]);
+ if (
+ newCall?.traceCall?.output &&
+ Array.isArray((newCall.traceCall.output as TraceCallOutput).choices)
+ ) {
+ // Delete the old choice
+ (newCall.traceCall.output as TraceCallOutput).choices = (
+ newCall.traceCall.output as TraceCallOutput
+ ).choices!.filter((_, index) => index !== choiceIndex);
+
+ // Add the new choice as a message
+ newCall.traceCall.inputs = newCall.traceCall.inputs ?? {};
+ newCall.traceCall.inputs.messages =
+ newCall.traceCall.inputs.messages ?? [];
+ newCall.traceCall.inputs.messages.push({
+ role: 'assistant',
+ content: newChoice.message?.content || newChoice.content,
+ });
+ }
+ return updatedCalls;
+ });
+ };
+
+ const setPlaygroundStateFromInputs = useCallback(
+ (inputs: Record) => {
+ // https://docs.litellm.ai/docs/completion/input
+ // pulled from litellm
+ setPlaygroundStates(prevState => {
+ const newState = {...prevState[0]};
+ if (inputs.tools) {
+ newState.functions = [];
+ for (const tool of inputs.tools) {
+ if (tool.type === 'function') {
+ newState.functions = [...newState.functions, tool.function];
+ }
+ }
+ }
+ if (inputs.n) {
+ newState.nTimes = parseInt(inputs.n, 10);
+ }
+ if (inputs.temperature) {
+ newState.temperature = parseFloat(inputs.temperature);
+ }
+ if (inputs.response_format) {
+ newState.responseFormat = inputs.response_format.type;
+ }
+ if (inputs.top_p) {
+ newState.topP = parseFloat(inputs.top_p);
+ }
+ if (inputs.frequency_penalty) {
+ newState.frequencyPenalty = parseFloat(inputs.frequency_penalty);
+ }
+ if (inputs.presence_penalty) {
+ newState.presencePenalty = parseFloat(inputs.presence_penalty);
+ }
+ return [newState];
+ });
+ },
+ []
+ );
+
+ const clearTraceCallId = (callWithTraceCallId: OptionalCallSchema) => {
+ if (callWithTraceCallId.traceCall) {
+ callWithTraceCallId.traceCall.id = '';
+ }
+ return callWithTraceCallId;
+ };
+
+ useEffect(() => {
+ if (!call.loading && call.result) {
+ setCalls([call.result]);
+ if (call.result.traceCall?.inputs) {
+ setPlaygroundStateFromInputs(call.result.traceCall.inputs);
+ }
+ } else if (calls.length === 0) {
+ setCalls([
+ {
+ entity: props.entity,
+ project: props.project,
+ traceCall: {
+ inputs: {
+ messages: [
+ {
+ role: 'system',
+ content: 'You are a helpful assistant.',
+ },
+ ],
+ },
+ },
+ },
+ ]);
+ }
+ }, [
+ call,
+ props.entity,
+ props.project,
+ // TODO: eslint fix these
+ // calls.length,
+ // setPlaygroundStateFromInputs,
+ ]);
+
+ return (
+
+ {call.loading ? (
+
+
+
+ ) : (
+
+ )}
+ {settingsTab !== null && (
+
+ )}
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/FunctionDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/FunctionDrawer.tsx
new file mode 100644
index 00000000000..e5eec61178c
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/FunctionDrawer.tsx
@@ -0,0 +1,163 @@
+import {Drawer} from '@mui/material';
+import {Alert} from '@wandb/weave/components/Alert';
+import {Button} from '@wandb/weave/components/Button';
+import React, {useEffect, useState} from 'react';
+
+import {Tailwind} from '../../../../../../Tailwind';
+import {TextArea} from '../Textarea';
+
+type FunctionDrawerProps = {
+ drawerFunctionIndex: number | null;
+ onClose: () => void;
+ functions: Array<{name: string; [key: string]: any}>;
+ onAddFunction: (functionJSON: string, index: number) => void;
+};
+
+const FUNCTION_PLACEHOLDER = `{
+ "name": "get_stock_price",
+ "description": "Get the current stock price",
+ "strict": true,
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "symbol": {
+ "type": "string",
+ "description": "The stock symbol"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "symbol"
+ ]
+ }
+}`;
+
+export const FunctionDrawer: React.FC = ({
+ drawerFunctionIndex,
+ onClose,
+ onAddFunction,
+ functions,
+}) => {
+ const [functionJSON, setFunctionJSON] = useState(
+ drawerFunctionIndex !== null
+ ? JSON.stringify(functions[drawerFunctionIndex], null, 2) ?? ''
+ : ''
+ );
+
+ let jsonValidationError = null;
+ let parsedFunctionJSON = null;
+ try {
+ parsedFunctionJSON = JSON.parse(functionJSON);
+ JSON.stringify(parsedFunctionJSON, null, 2);
+ } catch (err) {
+ jsonValidationError = `${err}`;
+ }
+
+ const handleAddFunction = () => {
+ if (drawerFunctionIndex !== null) {
+ onAddFunction(functionJSON, drawerFunctionIndex);
+ }
+ setFunctionJSON('');
+ onClose();
+ };
+
+ useEffect(() => {
+ setFunctionJSON(
+ drawerFunctionIndex !== null
+ ? JSON.stringify(functions[drawerFunctionIndex], null, 2) ?? ''
+ : ''
+ );
+ }, [drawerFunctionIndex]);
+
+ const checkFunctionName = (name: string, index: number) => {
+ return functions.some((func, idx) => func.name === name && idx !== index);
+ };
+
+ const disableActionButton =
+ functionJSON.length === 0 ||
+ !functionJSON.trim() ||
+ !!jsonValidationError ||
+ parsedFunctionJSON.name === null ||
+ checkFunctionName(parsedFunctionJSON.name, drawerFunctionIndex ?? 0);
+
+ let buttonTooltip = `${
+ drawerFunctionIndex !== null && drawerFunctionIndex < functions.length
+ ? 'Update'
+ : 'Add'
+ } function`;
+
+ if (disableActionButton) {
+ if (functionJSON.length === 0 || !functionJSON.trim()) {
+ buttonTooltip = 'Function JSON is empty';
+ } else if (!!jsonValidationError) {
+ buttonTooltip = jsonValidationError;
+ } else if (parsedFunctionJSON.name === null) {
+ buttonTooltip = 'Function JSON has no name';
+ } else if (
+ checkFunctionName(parsedFunctionJSON.name, drawerFunctionIndex ?? 0)
+ ) {
+ buttonTooltip = 'Function with this name already exists';
+ }
+ }
+
+ return (
+
+
+
+
+
+ The model will intelligently decide to call functions based on input
+ it receives from the user.
+
+
setFunctionJSON(FUNCTION_PLACEHOLDER)}>
+ Load placeholder
+
+
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/FunctionEditor.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/FunctionEditor.tsx
new file mode 100644
index 00000000000..042a71f9f0d
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/FunctionEditor.tsx
@@ -0,0 +1,106 @@
+import {Box, Chip} from '@mui/material';
+import {styled} from '@mui/material';
+import {Button} from '@wandb/weave/components/Button';
+import React, {useState} from 'react';
+
+import {FunctionDrawer} from './FunctionDrawer';
+
+const StyledChip = styled(Chip)(({theme}) => ({
+ cursor: 'pointer',
+ width: '100%',
+ justifyContent: 'space-between',
+ '& .MuiChip-label': {
+ paddingLeft: 8,
+ paddingRight: 0,
+ flex: 1,
+ textAlign: 'left',
+ },
+ '& .MuiChip-deleteIcon': {
+ marginRight: 8,
+ marginLeft: 'auto',
+ },
+}));
+
+type FunctionEditorProps = {
+ functions: Array<{name: string; [key: string]: any}>;
+ setFunctions: React.Dispatch<
+ React.SetStateAction>
+ >;
+};
+
+export const FunctionEditor: React.FC = ({
+ functions,
+ setFunctions,
+}) => {
+ // null means the drawer is closed
+ const [drawerFunctionIndex, setDrawerFunctionIndex] = useState(
+ null
+ );
+
+ const handleAddFunction = (functionJSON: string, index: number) => {
+ try {
+ const json = JSON.parse(functionJSON);
+ if (
+ typeof json === 'object' &&
+ json !== null &&
+ 'name' in json &&
+ (functions[index] || functions.every(func => func.name !== json.name))
+ ) {
+ setFunctions(prevFunctions => {
+ const newFunctions = [...prevFunctions];
+ if (index < newFunctions.length) {
+ newFunctions[index] = json;
+ } else {
+ newFunctions.push(json);
+ }
+ return newFunctions;
+ });
+ } else {
+ console.error('Function JSON must have a name property');
+ }
+ } catch (err) {
+ console.error('Error parsing function json', err);
+ }
+ };
+
+ const handleDeleteFunction = (functionToDelete: string) => {
+ setFunctions(functions.filter(func => func.name !== functionToDelete));
+ };
+
+ return (
+
+ Functions
+
+ {functions.map((func, index) => (
+ handleDeleteFunction(func.name)}
+ size="small"
+ onClick={() => setDrawerFunctionIndex(index)}
+ />
+ ))}
+
+
+ setDrawerFunctionIndex(functions.length)}>
+ Add function
+
+
+ setDrawerFunctionIndex(null)}
+ onAddFunction={handleAddFunction}
+ functions={functions}
+ />
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx
new file mode 100644
index 00000000000..c6a708fb9cb
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx
@@ -0,0 +1,169 @@
+import Box from '@mui/material/Box';
+import {MOON_250} from '@wandb/weave/common/css/color.styles';
+import * as Tabs from '@wandb/weave/components/Tabs';
+import {Tag} from '@wandb/weave/components/Tag';
+import React from 'react';
+
+import {FunctionEditor} from './FunctionEditor';
+import {PlaygroundSlider} from './PlaygroundSlider';
+import {
+ PlaygroundResponseFormats,
+ ResponseFormatEditor,
+} from './ResponseFormatEditor';
+import {StopSequenceEditor} from './StopSequenceEditor';
+
+export type PlaygroundSettingsProps = {
+ playgroundStates: PlaygroundState[];
+ setPlaygroundStateField: (
+ idx: number,
+ key: keyof PlaygroundState,
+ value:
+ | PlaygroundState[keyof PlaygroundState]
+ | React.SetStateAction>
+ | React.SetStateAction
+ | React.SetStateAction
+ | React.SetStateAction
+ ) => void;
+ settingsTab: number;
+ setSettingsTab: (tab: number) => void;
+};
+
+export type PlaygroundState = {
+ loading: boolean;
+ functions: Array<{name: string; [key: string]: any}>;
+ responseFormat: PlaygroundResponseFormats;
+ temperature: number;
+ maxTokens: number;
+ stopSequences: string[];
+ topP: number;
+ frequencyPenalty: number;
+ presencePenalty: number;
+ nTimes: number;
+ maxTokensLimit: number;
+ model: string;
+};
+
+export const PlaygroundSettings: React.FC = ({
+ playgroundStates,
+ setPlaygroundStateField,
+ settingsTab,
+ setSettingsTab,
+}) => {
+ return (
+
+
+
+ {playgroundStates.map((state, idx) => (
+ setSettingsTab(idx)}>
+ {playgroundStates.length > 1 && }
+ {state.model}
+
+ ))}
+
+ {playgroundStates.map((playgroundState, idx) => (
+
+
+
+ setPlaygroundStateField(idx, 'functions', value)
+ }
+ />
+
+
+ setPlaygroundStateField(idx, 'responseFormat', value)
+ }
+ />
+
+
+ setPlaygroundStateField(idx, 'temperature', value)
+ }
+ label="Temperature"
+ value={playgroundState.temperature}
+ />
+
+
+ setPlaygroundStateField(idx, 'maxTokens', value)
+ }
+ label="Maximum tokens"
+ value={playgroundState.maxTokens}
+ />
+
+
+ setPlaygroundStateField(idx, 'stopSequences', value)
+ }
+ />
+
+ setPlaygroundStateField(idx, 'topP', value)}
+ label="Top P"
+ value={playgroundState.topP}
+ />
+
+
+ setPlaygroundStateField(idx, 'frequencyPenalty', value)
+ }
+ label="Frequency penalty"
+ value={playgroundState.frequencyPenalty}
+ />
+
+
+ setPlaygroundStateField(idx, 'presencePenalty', value)
+ }
+ label="Presence penalty"
+ value={playgroundState.presencePenalty}
+ />
+
+
+ setPlaygroundStateField(idx, 'nTimes', value)
+ }
+ label="n times to run"
+ value={playgroundState.nTimes}
+ />
+
+
+ ))}
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSlider.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSlider.tsx
new file mode 100644
index 00000000000..e3e41d5359a
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSlider.tsx
@@ -0,0 +1,151 @@
+import {Box} from '@material-ui/core';
+import {TextField} from '@mui/material';
+import Slider from '@mui/material/Slider';
+import {styled} from '@mui/material/styles';
+import {
+ MOON_250,
+ MOON_350,
+ TEAL_500,
+} from '@wandb/weave/common/css/color.styles';
+import {isArray} from 'lodash';
+import React, {useEffect} from 'react';
+
+export const StyledSlider = styled(Slider)(({theme}) => ({
+ color: MOON_250,
+ height: 4,
+ marginBottom: 0,
+ padding: 0,
+ // So that the track lines up with the margin
+ width: 'calc( 100% - 12px )',
+ '& .MuiSlider-thumb': {
+ height: 12,
+ width: 12,
+ backgroundColor: '#fff',
+ boxShadow: '0 0 2px 0px rgba(0, 0, 0, 0.1)',
+ marginLeft: 6,
+ marginRight: 6,
+ border: `1px solid ${MOON_350}`,
+ '&:focus, &:hover, &.Mui-active': {
+ boxShadow: '0px 0px 0px 0px rgba(0, 0, 0, 0.1)',
+ '@media (hover: none)': {
+ boxShadow:
+ '0px 0px 0px 0px rgba(0,0,0,0.2), 0px 0px 0px 0px rgba(0,0,0,0.14), 0px 0px 1px 0px rgba(0,0,0,0.12)',
+ },
+ },
+ '&:after': {
+ height: 12,
+ width: 12,
+ },
+ },
+ '& .MuiSlider-track': {
+ border: 'none',
+ height: 4,
+ backgroundColor: TEAL_500,
+ },
+ '& .MuiSlider-rail': {
+ opacity: 1,
+ backgroundColor: MOON_250,
+ paddingRight: 12,
+ },
+}));
+
+type PlaygroundSliderProps = {
+ label: string;
+ min: number;
+ max: number;
+ step: number;
+ value: number;
+ setValue: (value: number) => void;
+};
+
+export const PlaygroundSlider = ({
+ setValue,
+ ...props
+}: PlaygroundSliderProps) => {
+ const [editing, setEditing] = React.useState(null);
+
+ const handleInputChange = (e: React.ChangeEvent) => {
+ setEditing(e.target.value);
+ };
+
+ const handleInputBlur = (): void => {
+ if (editing !== null) {
+ const newValue = parseFloat(editing);
+ if (!isNaN(newValue)) {
+ setValue(newValue);
+ }
+ setEditing(null);
+ }
+ };
+
+ const handleInputKeyPress = (e: React.KeyboardEvent) => {
+ if (e.key === 'Enter') {
+ handleInputBlur();
+ }
+ };
+
+ useEffect(() => {
+ if (props.value < props.min) {
+ setValue(props.min);
+ } else if (props.value > props.max) {
+ setValue(props.max);
+ }
+ }, [props.value, props.min, props.max, setValue]);
+
+ return (
+
+
+ {props.label}
+
+
+ {
+ if (isArray(value)) {
+ setValue(value[0]);
+ } else {
+ setValue(value);
+ }
+ }}
+ {...props}
+ />
+
+ );
+};
+
+export const formatValueToStep = (value: number, step: number): string => {
+ if (step >= 1) {
+ return value.toFixed(0);
+ }
+
+ const decimalPlaces = Math.max(0, -Math.floor(Math.log10(step)));
+ return value.toFixed(decimalPlaces);
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/ResponseFormatEditor.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/ResponseFormatEditor.tsx
new file mode 100644
index 00000000000..0799065a911
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/ResponseFormatEditor.tsx
@@ -0,0 +1,78 @@
+import {Box, MenuItem, TextField} from '@mui/material';
+import {MOON_250, TEAL_400} from '@wandb/weave/common/css/color.styles';
+import {Icon} from '@wandb/weave/components/Icon';
+import React from 'react';
+
+export enum PlaygroundResponseFormats {
+ Text = 'text',
+ JsonObject = 'json_object',
+ JsonSchema = 'json_schema',
+}
+
+const RESPONSE_FORMATS: PlaygroundResponseFormats[] = Object.values(
+ PlaygroundResponseFormats
+);
+
+type ResponseFormatEditorProps = {
+ responseFormat: PlaygroundResponseFormats;
+ setResponseFormat: React.Dispatch<
+ React.SetStateAction
+ >;
+};
+
+export const ResponseFormatEditor: React.FC = ({
+ responseFormat,
+ setResponseFormat,
+}) => {
+ return (
+
+ Response format
+
+ setResponseFormat(e.target.value as PlaygroundResponseFormats)
+ }
+ size="small"
+ slotProps={{
+ select: {
+ IconComponent: props => ,
+ },
+ }}
+ sx={{
+ width: '100%',
+ padding: 0,
+ fontFamily: 'Source Sans Pro',
+ fontSize: '16px',
+ '& .MuiSelect-select': {
+ fontFamily: 'Source Sans Pro',
+ fontSize: '16px',
+ },
+ '& .MuiMenuItem-root': {
+ fontFamily: 'Source Sans Pro',
+ fontSize: '16px',
+ },
+ '& .MuiOutlinedInput-notchedOutline': {
+ border: `1px solid ${MOON_250}`,
+ },
+ '& .Mui-focused .MuiOutlinedInput-notchedOutline': {
+ border: `1px solid ${MOON_250}`,
+ },
+ '& .MuiOutlinedInput-root:hover .MuiOutlinedInput-notchedOutline': {
+ border: `1px solid ${TEAL_400}`,
+ },
+ }}>
+ {RESPONSE_FORMATS.map(format => (
+
+ ))}
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/StopSequenceEditor.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/StopSequenceEditor.tsx
new file mode 100644
index 00000000000..1dd1318fbb4
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/StopSequenceEditor.tsx
@@ -0,0 +1,94 @@
+import {Box, Chip, TextField} from '@mui/material';
+import {MOON_250} from '@wandb/weave/common/css/color.styles';
+import React, {KeyboardEvent, useState} from 'react';
+
+type StopSequenceEditorProps = {
+ stopSequences: string[];
+ setStopSequences: React.Dispatch>;
+};
+
+export const StopSequenceEditor: React.FC = ({
+ stopSequences,
+ setStopSequences,
+}) => {
+ const [currentStopSequence, setCurrentStopSequence] = useState('');
+
+ const handleStopSequenceKeyDown = (
+ event: KeyboardEvent
+ ) => {
+ if (event.key === 'Enter' && currentStopSequence.trim() !== '') {
+ if (!stopSequences.includes(currentStopSequence.trim())) {
+ setStopSequences([...stopSequences, currentStopSequence.trim()]);
+ }
+ setCurrentStopSequence('');
+ }
+ };
+
+ const handleDeleteStopSequence = (sequenceToDelete: string) => {
+ setStopSequences(stopSequences.filter(seq => seq !== sequenceToDelete));
+ };
+
+ return (
+
+ Stop sequences
+ 0
+ ? {padding: '8px', paddingBottom: 0}
+ : {}),
+ borderRadius: '4px',
+ width: '100%',
+ }}>
+
+ {stopSequences.map((seq, index) => (
+ handleDeleteStopSequence(seq)}
+ size="small"
+ />
+ ))}
+
+ setCurrentStopSequence(e.target.value)}
+ onKeyDown={handleStopSequenceKeyDown}
+ placeholder="Type and press Enter"
+ size="small"
+ fullWidth
+ variant="standard"
+ InputProps={{
+ disableUnderline: true,
+ }}
+ sx={{
+ fontFamily: 'Source Sans Pro',
+ fontSize: '16px',
+ '& .MuiInputBase-root': {
+ border: 'none',
+ '&:before, &:after': {
+ borderBottom: 'none',
+ },
+ '&:hover:not(.Mui-disabled):before': {
+ borderBottom: 'none',
+ },
+ },
+ '& .MuiInputBase-input': {
+ padding: '8px',
+ fontFamily: 'Source Sans Pro',
+ fontSize: '16px',
+ },
+ }}
+ />
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/Textarea.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/Textarea.tsx
new file mode 100644
index 00000000000..86370516d31
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/Textarea.tsx
@@ -0,0 +1,53 @@
+/**
+ * A form multi-line text input.
+ */
+import {Tailwind} from '@wandb/weave/components/Tailwind';
+import classNames from 'classnames';
+import React, {forwardRef, TextareaHTMLAttributes} from 'react';
+
+type TextAreaProps = React.TextareaHTMLAttributes;
+
+export const TextArea = forwardRef(
+ ({className, ...props}, ref) => {
+ return (
+
+
+
+ );
+ }
+);
+
+TextArea.displayName = 'TextArea';
+
+export const FixedSizeTextArea = ({
+ className,
+ height = '80px',
+ ...props
+}: TextareaHTMLAttributes & {height?: string}) => {
+ return (
+
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/Untitled-1.md b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/Untitled-1.md
new file mode 100644
index 00000000000..f81c356931f
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/Untitled-1.md
@@ -0,0 +1,4 @@
+null tool calls https://app.wandb.test/shawn/programmerdev-eval-edits1/weave/calls/0191e2dd-e3ea-7680-9c99-a483f005de6a?path=AgentTextEditor.step*0+litellm.completion*0&tracetree=1
+multiple tool calls, n times run, null choice https://app.wandb.test/shawn/programmerdev-eval-edits1/weave/playground/0191e2c6-2a08-7491-bc38-0f0c682aeba2
+message with tool call https://app.wandb.test/shawn/programmerdev-eval-edits1/weave/playground/0191e2f5-2ac5-7e32-8cc1-b03d8f5a7ec2
+structured output https://app.wandb.test/shawn/programmerjs-dev1/weave/playground/0192260e-04e4-79d7-849d-cd890017cf55?peekPath=%2Fshawn%2Fprogrammerjs-dev1%2Fcalls%2F0192260e-04e4-79d7-849d-cd890017cf55%3Ftracetree%3D0
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/index.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/index.ts
new file mode 100644
index 00000000000..34cee0ea092
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/index.ts
@@ -0,0 +1 @@
+export * from './PlaygroundPage';
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx
index 4a5d169ad29..2c05b24420f 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx
@@ -14,6 +14,7 @@ import React, {
import {ErrorBoundary} from '../../../../../ErrorBoundary';
import {SplitPanel} from './SplitPanel';
import {isPrimitive} from './util';
+import {MOON_200} from '@wandb/weave/common/css/color.styles';
type SimplePageLayoutContextType = {
headerPrefix?: ReactNode;
@@ -69,7 +70,7 @@ export const SimplePageLayout: FC<{
pb: 0,
height: 44,
width: '100%',
- borderBottom: '1px solid #e0e0e0',
+ borderBottom: `1px solid ${MOON_200}`,
display: 'flex',
flexDirection: 'row',
alignItems: 'center',
@@ -139,7 +140,7 @@ export const SimplePageLayout: FC<{
overflow: 'hidden',
height: '100%',
maxHeight: '100%',
- borderRight: '1px solid #e0e0e0',
+ borderRight: `1px solid ${MOON_200}`,
}}>
{props.leftSidebar}
@@ -212,7 +213,7 @@ export const SimplePageLayoutWithHeader: FC<{
zIndex: 1,
backgroundColor: 'white',
pb: 0,
- borderBottom: '1px solid #e0e0e0',
+ borderBottom: `1px solid ${MOON_200}`,
justifyContent: 'flex-start',
}}>
{simplePageLayoutContextValue.headerPrefix}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts
index af997016a4a..0c696d98202 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts
@@ -1,6 +1,8 @@
import _ from 'lodash';
import {
+ CompletionsCreateReq,
+ CompletionsCreateRes,
FeedbackCreateReq,
FeedbackCreateRes,
FeedbackPurgeReq,
@@ -117,6 +119,12 @@ export class TraceServerClient extends DirectTraceServerClient {
return this.requestReadBatch(req);
}
+ public completionsCreate(
+ req: CompletionsCreateReq
+ ): Promise {
+ return super.completionsCreate(req);
+ }
+
private requestReadBatch(
req: TraceRefsReadBatchReq
): Promise {
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts
index 88113a37a74..52f8c44fa89 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts
@@ -273,6 +273,17 @@ export type TraceFileContentReadRes = {
content: ArrayBuffer;
};
+export type CompletionsCreateReq = {
+ project_id: string;
+ // TODO make this type better
+ inputs: any;
+};
+
+export type CompletionsCreateRes = {
+ response: any;
+ weave_call_id: string;
+};
+
export enum ContentType {
csv = 'text/csv',
tsv = 'text/tab-separated-values',
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts
index caaf63b7f56..da7713666af 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts
@@ -44,6 +44,8 @@ import {
TraceTableQueryRes,
TraceTableQueryStatsReq,
TraceTableQueryStatsRes,
+ CompletionsCreateReq,
+ CompletionsCreateRes,
} from './traceServerClientTypes';
export class DirectTraceServerClient {
@@ -287,6 +289,26 @@ export class DirectTraceServerClient {
});
}
+ public completionsCreate(
+ req: CompletionsCreateReq
+ ): Promise {
+ try {
+ return this.makeRequest(
+ '/completions/create',
+ req
+ );
+ } catch (error: any) {
+ if (error?.api_key_name) {
+ console.log('Missing LLM API key:', error.api_key_name);
+ }
+ }
+
+ return this.makeRequest(
+ '/completions/create',
+ req
+ );
+ }
+
private makeRequest = async (
endpoint: string,
req: QT,
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py
index 9a1f1a1303f..5e9a3ccee2f 100644
--- a/weave/trace_server/clickhouse_trace_server_batched.py
+++ b/weave/trace_server/clickhouse_trace_server_batched.py
@@ -73,7 +73,12 @@
)
from weave.trace_server.constants import COMPLETIONS_CREATE_OP_NAME
from weave.trace_server.emoji_util import detone_emojis
-from weave.trace_server.errors import InsertTooLarge, InvalidRequest, RequestTooLarge
+from weave.trace_server.errors import (
+ InsertTooLarge,
+ InvalidRequest,
+ RequestTooLarge,
+ MissingLLMApiKeyError,
+)
from weave.trace_server.feedback import (
TABLE_FEEDBACK,
validate_feedback_create_req,
@@ -1540,7 +1545,10 @@ def completions_create(
raise InvalidRequest(f"No secret name found for model {model_name}")
api_key = secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)
if not api_key:
- raise InvalidRequest(f"No API key found for model {model_name}")
+ raise MissingLLMApiKeyError(
+ f"No API key {secret_name} found for model {model_name}",
+ api_key_name=secret_name,
+ )
start_time = datetime.datetime.now()
res = lite_llm_completion(api_key, req.inputs)
@@ -1579,7 +1587,10 @@ def completions_create(
batch_data.append(values)
self._insert_call_batch(batch_data)
- return res
+
+ return tsi.CompletionsCreateRes(
+ response=res.response, weave_call_id=start_call.id
+ )
# Private Methods
@property
diff --git a/weave/trace_server/errors.py b/weave/trace_server/errors.py
index 865ec908f5f..b2014fc1a48 100644
--- a/weave/trace_server/errors.py
+++ b/weave/trace_server/errors.py
@@ -26,3 +26,11 @@ class InvalidFieldError(Error):
"""Raised when a field is invalid."""
pass
+
+
+class MissingLLMApiKeyError(Error):
+ """Raised when a LLM API key is missing for completion."""
+
+ def __init__(self, message: str, api_key_name: str):
+ self.api_key_name = api_key_name
+ super().__init__(message)
diff --git a/weave/trace_server/llm_completion.py b/weave/trace_server/llm_completion.py
index 907e8fe413b..551731e1332 100644
--- a/weave/trace_server/llm_completion.py
+++ b/weave/trace_server/llm_completion.py
@@ -4,10 +4,17 @@
def lite_llm_completion(
api_key: str, inputs: tsi.CompletionsCreateRequestInputs
) -> tsi.CompletionsCreateRes:
- from litellm import completion
+ import litellm
+ print("inputs", flush=True)
+
+ litellm.drop_params = True
try:
- res = completion(**inputs.model_dump(exclude_none=True), api_key=api_key)
+ res = litellm.completion(
+ **inputs.model_dump(exclude_none=True), api_key=api_key
+ )
+ print("res", res.model_dump(), flush=True)
return tsi.CompletionsCreateRes(response=res.model_dump())
except Exception as e:
+ print("error", str(e), flush=True)
return tsi.CompletionsCreateRes(response={"error": str(e)})
diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py
index 4f064759505..f25b828eb96 100644
--- a/weave/trace_server/trace_server_interface.py
+++ b/weave/trace_server/trace_server_interface.py
@@ -275,6 +275,7 @@ class CompletionsCreateReq(BaseModel):
class CompletionsCreateRes(BaseModel):
response: Dict[str, Any]
+ weave_call_id: Optional[str] = None
class CallsFilter(BaseModel):
@@ -888,6 +889,7 @@ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ...
def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ...
def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ...
def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ...
+
# Action API
def execute_batch_action(
self, req: ExecuteBatchActionReq