diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/LLMDropdown.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/LLMDropdown.tsx index b7a71e88df05..d14a9db30427 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/LLMDropdown.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/LLMDropdown.tsx @@ -2,18 +2,19 @@ import {Box} from '@mui/material'; import {Select} from '@wandb/weave/components/Form/Select'; import React from 'react'; -import {LLM_MAX_TOKENS} from '../llmMaxTokens'; +import {LLM_MAX_TOKENS, LLMMaxTokensKey} from '../llmMaxTokens'; interface LLMDropdownProps { - value: string; - onChange: (value: string, maxTokens: number) => void; + value: LLMMaxTokensKey; + onChange: (value: LLMMaxTokensKey, maxTokens: number) => void; } export const LLMDropdown: React.FC = ({value, onChange}) => { - const options = Object.keys(LLM_MAX_TOKENS).map(llm => ({ - value: llm, - label: llm, - })); + const options: Array<{value: LLMMaxTokensKey; label: LLMMaxTokensKey}> = + Object.keys(LLM_MAX_TOKENS).map(llm => ({ + value: llm as LLMMaxTokensKey, + label: llm as LLMMaxTokensKey, + })); return ( = ({value, onChange}) => { LLM_MAX_TOKENS[ (option as {value: string}).value as keyof typeof LLM_MAX_TOKENS ]?.max_tokens || 0; - onChange((option as {value: string}).value, maxTokens); + onChange((option as {value: LLMMaxTokensKey}).value, maxTokens); } }} options={options} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx index cdd14a4a5c25..ec6653801978 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx @@ -1,24 +1,26 @@ import {Box, CircularProgress, Divider} from '@mui/material'; import {MOON_200} from '@wandb/weave/common/css/color.styles'; import {Tailwind} from '@wandb/weave/components/Tailwind'; -import React, {SetStateAction, useState} from 'react'; +import React, {useState} from 'react'; +import {CallChat} from '../../CallPage/CallChat'; import {TraceCallSchema} from '../../wfReactInterface/traceServerClientTypes'; -import {PlaygroundState, PlaygroundStateKey} from '../types'; +import {PlaygroundContext} from '../PlaygroundContext'; +import {PlaygroundState} from '../types'; import {PlaygroundCallStats} from './PlaygroundCallStats'; import {PlaygroundChatInput} from './PlaygroundChatInput'; import {PlaygroundChatTopBar} from './PlaygroundChatTopBar'; +import { + SetPlaygroundStateFieldFunctionType, + useChatFunctions, +} from './useChatFunctions'; export type PlaygroundChatProps = { entity: string; project: string; setPlaygroundStates: (states: PlaygroundState[]) => void; playgroundStates: PlaygroundState[]; - setPlaygroundStateField: ( - index: number, - field: PlaygroundStateKey, - value: SetStateAction - ) => void; + setPlaygroundStateField: SetPlaygroundStateFieldFunctionType; setSettingsTab: (callIndex: number | null) => void; settingsTab: number | null; }; @@ -37,6 +39,16 @@ export const PlaygroundChat = ({ const [isLoading, setIsLoading] = useState(false); const chatPercentWidth = 100 / playgroundStates.length; + const {deleteMessage, editMessage, deleteChoice, editChoice, addMessage} = + useChatFunctions(setPlaygroundStateField); + + const handleAddMessage = (role: 'assistant' | 'user', text: string) => { + for (let i = 0; i < playgroundStates.length; i++) { + addMessage(i, {role, content: text}); + } + setChatText(''); + }; + return (
- Chat + {state.traceCall && ( + + deleteMessage(idx, messageIndex, responseIndexes), + editMessage: (messageIndex, newMessage) => + editMessage(idx, messageIndex, newMessage), + deleteChoice: choiceIndex => + deleteChoice(idx, choiceIndex), + addMessage: newMessage => addMessage(idx, newMessage), + editChoice: (choiceIndex, newChoice) => + editChoice(idx, choiceIndex, newChoice), + retry: (messageIndex: number, isChoice?: boolean) => + console.log('retry', messageIndex, isChoice), + sendMessage: ( + role: 'assistant' | 'user' | 'tool', + content: string, + toolCallId?: string + ) => + console.log( + 'sendMessage', + role, + content, + toolCallId + ), + }}> + + + )}
@@ -147,7 +188,7 @@ export const PlaygroundChat = ({ setChatText={setChatText} isLoading={isLoading} onSend={() => {}} - onAdd={() => {}} + onAdd={handleAddMessage} />
); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx index 397f3a3a15d0..2fcd24f56e3a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx @@ -5,19 +5,17 @@ import React from 'react'; import {useHistory} from 'react-router-dom'; import {CopyableId} from '../../common/Id'; +import {LLMMaxTokensKey} from '../llmMaxTokens'; import {OptionalTraceCallSchema, PlaygroundState} from '../types'; import {DEFAULT_SYSTEM_MESSAGE} from '../usePlaygroundState'; import {LLMDropdown} from './LLMDropdown'; +import {SetPlaygroundStateFieldFunctionType} from './useChatFunctions'; type PlaygroundChatTopBarProps = { idx: number; settingsTab: number | null; setSettingsTab: (tab: number | null) => void; - setPlaygroundStateField: ( - index: number, - field: keyof PlaygroundState, - value: any - ) => void; + setPlaygroundStateField: SetPlaygroundStateFieldFunctionType; entity: string; project: string; playgroundStates: PlaygroundState[]; @@ -60,7 +58,7 @@ export const PlaygroundChatTopBar: React.FC = ({ const handleModelChange = ( index: number, - model: string, + model: LLMMaxTokensKey, maxTokens: number ) => { setPlaygroundStateField(index, 'model', model); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx new file mode 100644 index 000000000000..178752a71da9 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx @@ -0,0 +1,155 @@ +import {cloneDeep} from 'lodash'; +import {SetStateAction} from 'react'; + +import {Choice, Message} from '../../ChatView/types'; +import {OptionalTraceCallSchema, PlaygroundState} from '../types'; +import {DEFAULT_SYSTEM_MESSAGE} from '../usePlaygroundState'; +type TraceCallOutput = { + choices?: any[]; +}; + +export type SetPlaygroundStateFieldFunctionType = ( + index: number, + field: keyof PlaygroundState, + // The value here is a function that returns a PlaygroundState field + value: SetStateAction +) => void; + +export const useChatFunctions = ( + setPlaygroundStateField: SetPlaygroundStateFieldFunctionType +) => { + const deleteMessage = ( + callIndex: number, + messageIndex: number, + responseIndexes?: number[] + ) => { + setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { + const newTraceCall = clearTraceCall( + cloneDeep(prevTraceCall as OptionalTraceCallSchema) + ); + if (newTraceCall && newTraceCall.inputs?.messages) { + // Remove the message and all responses to it + newTraceCall.inputs.messages = newTraceCall.inputs.messages.filter( + (_: any, index: number) => + index !== messageIndex && !responseIndexes?.includes(index) + ); + + // If there are no messages left, add a system message + if (newTraceCall.inputs.messages.length === 0) { + newTraceCall.inputs.messages = [DEFAULT_SYSTEM_MESSAGE]; + } + } + return newTraceCall; + }); + }; + + const editMessage = ( + callIndex: number, + messageIndex: number, + newMessage: Message + ) => { + setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { + const newTraceCall = clearTraceCall( + cloneDeep(prevTraceCall as OptionalTraceCallSchema) + ); + if (newTraceCall && newTraceCall.inputs?.messages) { + // Replace the message + newTraceCall.inputs.messages[messageIndex] = newMessage; + } + return newTraceCall; + }); + }; + + const addMessage = (callIndex: number, newMessage: Message) => { + setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { + const newTraceCall = clearTraceCall( + cloneDeep(prevTraceCall as OptionalTraceCallSchema) + ); + if (newTraceCall && newTraceCall.inputs?.messages) { + if ( + newTraceCall.output && + (newTraceCall.output as TraceCallOutput).choices && + Array.isArray((newTraceCall.output as TraceCallOutput).choices) + ) { + // Add all the choices as messages + (newTraceCall.output as TraceCallOutput).choices!.forEach( + (choice: any) => { + if (choice.message) { + newTraceCall.inputs!.messages.push(choice.message); + } + } + ); + // Set the choices to undefined + (newTraceCall.output as TraceCallOutput).choices = undefined; + } + // Add the new message + newTraceCall.inputs.messages.push(newMessage); + } + return newTraceCall; + }); + }; + + const deleteChoice = (callIndex: number, choiceIndex: number) => { + setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { + const newTraceCall = clearTraceCall( + cloneDeep(prevTraceCall as OptionalTraceCallSchema) + ); + const output = newTraceCall?.output as TraceCallOutput; + if (output && Array.isArray(output.choices)) { + // Remove the choice + output.choices.splice(choiceIndex, 1); + if (newTraceCall) { + newTraceCall.output = output; + } + } + return newTraceCall; + }); + }; + + const editChoice = ( + callIndex: number, + choiceIndex: number, + newChoice: Choice + ) => { + setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { + const newTraceCall = clearTraceCall( + cloneDeep(prevTraceCall as OptionalTraceCallSchema) + ); + if ( + newTraceCall?.output && + Array.isArray((newTraceCall.output as TraceCallOutput).choices) + ) { + // Delete the old choice + (newTraceCall.output as TraceCallOutput).choices!.splice( + choiceIndex, + 1 + ); + + // Add the new choice as a message + newTraceCall.inputs = newTraceCall.inputs ?? {}; + newTraceCall.inputs.messages = newTraceCall.inputs.messages ?? []; + newTraceCall.inputs.messages.push({ + role: 'assistant', + content: newChoice.message?.content, + }); + } + return newTraceCall; + }); + }; + + return { + deleteMessage, + editMessage, + addMessage, + deleteChoice, + editChoice, + }; +}; + +export const clearTraceCall = (traceCall: OptionalTraceCallSchema) => { + if (traceCall) { + traceCall.id = ''; + traceCall.summary = undefined; + } + return traceCall; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx new file mode 100644 index 000000000000..14485ed2434d --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx @@ -0,0 +1,44 @@ +import {createContext, useContext} from 'react'; + +import {Choice, Message} from '../ChatView/types'; + +export type PlaygroundContextType = { + isPlayground: boolean; + addMessage: (newMessage: Message) => void; + editMessage: (messageIndex: number, newMessage: Message) => void; + deleteMessage: (messageIndex: number, responseIndexes?: number[]) => void; + + editChoice: (choiceIndex: number, newChoice: Choice) => void; + deleteChoice: (choiceIndex: number) => void; + + retry: (messageIndex: number, isChoice?: boolean) => void; + sendMessage: ( + role: 'assistant' | 'user' | 'tool', + content: string, + toolCallId?: string + ) => void; +}; + +const DEFAULT_CONTEXT: PlaygroundContextType = { + isPlayground: false, + addMessage: () => {}, + editMessage: () => {}, + deleteMessage: () => {}, + + editChoice: () => {}, + deleteChoice: () => {}, + + retry: () => {}, + sendMessage: () => {}, +}; + +// Create context that can be undefined +export const PlaygroundContext = createContext< + PlaygroundContextType | undefined +>(DEFAULT_CONTEXT); + +// Custom hook that handles the undefined context +export const usePlaygroundContext = () => { + const context = useContext(PlaygroundContext); + return context ?? DEFAULT_CONTEXT; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx index f69a6e52a0a9..63d6539f2f1b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx @@ -3,9 +3,10 @@ import {MOON_250} from '@wandb/weave/common/css/color.styles'; import {Switch} from '@wandb/weave/components'; import * as Tabs from '@wandb/weave/components/Tabs'; import {Tag} from '@wandb/weave/components/Tag'; -import React, {SetStateAction} from 'react'; +import React from 'react'; -import {PlaygroundState, PlaygroundStateKey} from '../types'; +import {SetPlaygroundStateFieldFunctionType} from '../PlaygroundChat/useChatFunctions'; +import {PlaygroundState} from '../types'; import {FunctionEditor} from './FunctionEditor'; import {PlaygroundSlider} from './PlaygroundSlider'; import {ResponseFormatEditor} from './ResponseFormatEditor'; @@ -13,11 +14,7 @@ import {StopSequenceEditor} from './StopSequenceEditor'; export type PlaygroundSettingsProps = { playgroundStates: PlaygroundState[]; - setPlaygroundStateField: ( - index: number, - field: PlaygroundStateKey, - value: SetStateAction - ) => void; + setPlaygroundStateField: SetPlaygroundStateFieldFunctionType; settingsTab: number; setSettingsTab: (tab: number) => void; };