From 2ca1851f8567010825175f41c73be60e5bbee8e9 Mon Sep 17 00:00:00 2001 From: jwlee64 Date: Mon, 18 Nov 2024 13:47:38 -0800 Subject: [PATCH 1/2] wire up completions --- .../PlaygroundChat/PlaygroundChat.tsx | 26 ++- .../useChatCompletionFunctions.tsx | 214 ++++++++++++++++++ .../PlaygroundChat/useChatFunctions.tsx | 9 +- .../PlaygroundPage/PlaygroundContext.tsx | 4 +- 4 files changed, 235 insertions(+), 18 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx index ec665380197..270f030a1b6 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx @@ -10,6 +10,7 @@ import {PlaygroundState} from '../types'; import {PlaygroundCallStats} from './PlaygroundCallStats'; import {PlaygroundChatInput} from './PlaygroundChatInput'; import {PlaygroundChatTopBar} from './PlaygroundChatTopBar'; +import {useChatCompletionFunctions} from './useChatCompletionFunctions'; import { SetPlaygroundStateFieldFunctionType, useChatFunctions, @@ -35,10 +36,19 @@ export const PlaygroundChat = ({ settingsTab, }: PlaygroundChatProps) => { const [chatText, setChatText] = useState(''); - // eslint-disable-next-line @typescript-eslint/no-unused-vars const [isLoading, setIsLoading] = useState(false); const chatPercentWidth = 100 / playgroundStates.length; + const {handleRetry, handleSend} = useChatCompletionFunctions( + setPlaygroundStates, + setIsLoading, + chatText, + playgroundStates, + entity, + project, + setChatText + ); + const {deleteMessage, editMessage, deleteChoice, editChoice, addMessage} = useChatFunctions(setPlaygroundStateField); @@ -145,18 +155,14 @@ export const PlaygroundChat = ({ editChoice: (choiceIndex, newChoice) => editChoice(idx, choiceIndex, newChoice), retry: (messageIndex: number, isChoice?: boolean) => - console.log('retry', messageIndex, isChoice), + handleRetry(idx, messageIndex, isChoice), sendMessage: ( role: 'assistant' | 'user' | 'tool', content: string, toolCallId?: string - ) => - console.log( - 'sendMessage', - role, - content, - toolCallId - ), + ) => { + handleSend(role, idx, content, toolCallId); + }, }}> @@ -187,7 +193,7 @@ export const PlaygroundChat = ({ chatText={chatText} setChatText={setChatText} isLoading={isLoading} - onSend={() => {}} + onSend={handleSend} onAdd={handleAddMessage} /> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx new file mode 100644 index 00000000000..0ccb62d8140 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx @@ -0,0 +1,214 @@ +import {toast} from '@wandb/weave/common/components/elements/Toast'; +import React from 'react'; +import {Link} from 'react-router-dom'; + +import {Message} from '../../ChatView/types'; +import {useGetTraceServerClientContext} from '../../wfReactInterface/traceServerClientContext'; +import {CompletionsCreateRes} from '../../wfReactInterface/traceServerClientTypes'; +import {PlaygroundState} from '../types'; +import {getInputFromPlaygroundState} from '../usePlaygroundState'; +import {clearTraceCall} from './useChatFunctions'; + +export const useChatCompletionFunctions = ( + setPlaygroundStates: (states: PlaygroundState[]) => void, + setIsLoading: (isLoading: boolean) => void, + chatText: string, + playgroundStates: PlaygroundState[], + entity: string, + project: string, + setChatText: (text: string) => void +) => { + const getTsClient = useGetTraceServerClientContext(); + + const makeCompletionRequest = async ( + callIndex: number, + updatedStates: PlaygroundState[] + ) => { + const inputs = getInputFromPlaygroundState(updatedStates[callIndex]); + + return getTsClient().completionsCreate({ + project_id: `${entity}/${project}`, + inputs, + track_llm_call: updatedStates[callIndex].trackLLMCall, + }); + }; + + const handleErrorsAndUpdate = async ( + response: Array, + updatedStates: PlaygroundState[], + callIndex?: number + ) => { + const hasMissingLLMApiKey = handleMissingLLMApiKey(response, entity); + const hasError = handleErrorResponse(response.map(r => r?.response)); + + if (hasMissingLLMApiKey || hasError) { + return false; + } + + const finalStates = updatedStates.map((state, index) => { + if (callIndex === undefined || index === callIndex) { + return handleUpdateCallWithResponse(state, response[index]); + } + return state; + }); + + setPlaygroundStates(finalStates); + return true; + }; + + const handleSend = async ( + role: 'assistant' | 'user' | 'tool', + callIndex?: number, + content?: string, + toolCallId?: string + ) => { + try { + setIsLoading(true); + + const newMessage = createMessage(role, content || chatText, toolCallId); + const updatedStates = playgroundStates.map((state, index) => { + if (callIndex === undefined || callIndex !== index) { + return state; + } + const updatedState = appendChoicesToMessages(state); + if (updatedState.traceCall?.inputs?.messages) { + updatedState.traceCall.inputs.messages.push(newMessage); + } + return updatedState; + }); + + setPlaygroundStates(updatedStates); + setChatText(''); + + const responses = await Promise.all( + updatedStates.map((_, index) => { + if (callIndex === undefined || callIndex !== index) { + return Promise.resolve(null); + } + return makeCompletionRequest(index, updatedStates); + }) + ); + await handleErrorsAndUpdate(responses, updatedStates); + } catch (error) { + console.error('Error processing completion:', error); + } finally { + setIsLoading(false); + } + }; + + const handleRetry = async ( + callIndex: number, + messageIndex: number, + isChoice?: boolean + ) => { + try { + setIsLoading(true); + const updatedStates = playgroundStates.map((state, index) => { + if (index === callIndex) { + if (isChoice) { + return appendChoicesToMessages(state); + } + const updatedState = JSON.parse(JSON.stringify(state)); + if (updatedState.traceCall?.inputs?.messages) { + updatedState.traceCall.inputs.messages = + updatedState.traceCall.inputs.messages.slice(0, messageIndex + 1); + } + return updatedState; + } + return state; + }); + + const response = await makeCompletionRequest(callIndex, updatedStates); + await handleErrorsAndUpdate( + updatedStates.map(() => response), + updatedStates, + callIndex + ); + } catch (error) { + console.error('Error processing completion:', error); + } finally { + setIsLoading(false); + } + }; + + return {handleRetry, handleSend}; +}; + +// Helper functions +const createMessage = ( + role: 'assistant' | 'user' | 'tool', + content: string, + toolCallId?: string +): Message | undefined => { + return content.trim() ? {role, content, tool_call_id: toolCallId} : undefined; +}; + +const handleMissingLLMApiKey = (responses: any, entity: string) => { + if (Array.isArray(responses)) { + responses.forEach((response: any) => { + handleMissingLLMApiKey(response, entity); + }); + } else { + if (responses.api_key && responses.reason) { + toast( +
+
{responses.reason}
+ Please add your API key to{' '} + Team secrets in settings to + use this LLM +
, + { + type: 'error', + } + ); + return true; + } + } + return false; +}; + +const handleErrorResponse = ( + responses: Array +): boolean => { + if (!responses) { + return true; + } + if (responses.some(r => r && 'error' in r)) { + const errorResponse = responses.find(r => r && 'error' in r) as { + error: string; + }; + toast(errorResponse?.error, { + type: 'error', + }); + return true; + } + + return false; +}; + +const handleUpdateCallWithResponse = (updatedCall: any, response: any) => { + return { + ...updatedCall, + traceCall: { + ...clearTraceCall(updatedCall.traceCall), + id: response.weave_call_id ?? '', + output: response.response, + }, + }; +}; + +const appendChoicesToMessages = (state: PlaygroundState) => { + const updatedState = JSON.parse(JSON.stringify(state)); + if ( + updatedState.traceCall?.inputs?.messages && + updatedState.traceCall.output?.choices + ) { + updatedState.traceCall.output.choices.forEach((choice: any) => { + if (choice.message) { + updatedState.traceCall.inputs.messages.push(choice.message); + } + }); + updatedState.traceCall.output.choices = undefined; + } + return updatedState; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx index 178752a71da..46c47abab74 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx @@ -1,7 +1,7 @@ import {cloneDeep} from 'lodash'; import {SetStateAction} from 'react'; -import {Choice, Message} from '../../ChatView/types'; +import {Message} from '../../ChatView/types'; import {OptionalTraceCallSchema, PlaygroundState} from '../types'; import {DEFAULT_SYSTEM_MESSAGE} from '../usePlaygroundState'; type TraceCallOutput = { @@ -109,7 +109,7 @@ export const useChatFunctions = ( const editChoice = ( callIndex: number, choiceIndex: number, - newChoice: Choice + newChoice: Message ) => { setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { const newTraceCall = clearTraceCall( @@ -128,10 +128,7 @@ export const useChatFunctions = ( // Add the new choice as a message newTraceCall.inputs = newTraceCall.inputs ?? {}; newTraceCall.inputs.messages = newTraceCall.inputs.messages ?? []; - newTraceCall.inputs.messages.push({ - role: 'assistant', - content: newChoice.message?.content, - }); + newTraceCall.inputs.messages.push(newChoice); } return newTraceCall; }); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx index 14485ed2434..ec16276bc46 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx @@ -1,6 +1,6 @@ import {createContext, useContext} from 'react'; -import {Choice, Message} from '../ChatView/types'; +import {Message} from '../ChatView/types'; export type PlaygroundContextType = { isPlayground: boolean; @@ -8,7 +8,7 @@ export type PlaygroundContextType = { editMessage: (messageIndex: number, newMessage: Message) => void; deleteMessage: (messageIndex: number, responseIndexes?: number[]) => void; - editChoice: (choiceIndex: number, newChoice: Choice) => void; + editChoice: (choiceIndex: number, newChoice: Message) => void; deleteChoice: (choiceIndex: number) => void; retry: (messageIndex: number, isChoice?: boolean) => void; From 3b826ecd81b1096dff77acc2ae9e6446070f58cc Mon Sep 17 00:00:00 2001 From: jwlee64 Date: Mon, 18 Nov 2024 14:30:23 -0800 Subject: [PATCH 2/2] fix styles --- .../PlaygroundChat/PlaygroundChatTopBar.tsx | 4 +-- .../useChatCompletionFunctions.tsx | 27 +++++++++++-------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx index 2fcd24f56e3..7e1b4b13bb1 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx @@ -79,7 +79,7 @@ export const PlaygroundChatTopBar: React.FC = ({ display: 'flex', gap: '8px', alignItems: 'center', - backgroundColor: 'white', + backgroundColor: 'transparent', }}> {!onlyOneChat && } = ({ display: 'flex', alignItems: 'center', gap: '4px', - backgroundColor: 'white', + backgroundColor: 'transparent', }}>