Skip to content

Commit

Permalink
chore(weave): Wire up completions (#3009)
Browse files Browse the repository at this point in the history
* wire up completions

* fix styles
  • Loading branch information
jwlee64 authored Nov 20, 2024
1 parent 7aa0140 commit db8a09c
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {PlaygroundState} from '../types';
import {PlaygroundCallStats} from './PlaygroundCallStats';
import {PlaygroundChatInput} from './PlaygroundChatInput';
import {PlaygroundChatTopBar} from './PlaygroundChatTopBar';
import {useChatCompletionFunctions} from './useChatCompletionFunctions';
import {
SetPlaygroundStateFieldFunctionType,
useChatFunctions,
Expand All @@ -35,10 +36,19 @@ export const PlaygroundChat = ({
settingsTab,
}: PlaygroundChatProps) => {
const [chatText, setChatText] = useState('');
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const [isLoading, setIsLoading] = useState(false);
const chatPercentWidth = 100 / playgroundStates.length;

const {handleRetry, handleSend} = useChatCompletionFunctions(
setPlaygroundStates,
setIsLoading,
chatText,
playgroundStates,
entity,
project,
setChatText
);

const {deleteMessage, editMessage, deleteChoice, editChoice, addMessage} =
useChatFunctions(setPlaygroundStateField);

Expand Down Expand Up @@ -145,18 +155,14 @@ export const PlaygroundChat = ({
editChoice: (choiceIndex, newChoice) =>
editChoice(idx, choiceIndex, newChoice),
retry: (messageIndex: number, isChoice?: boolean) =>
console.log('retry', messageIndex, isChoice),
handleRetry(idx, messageIndex, isChoice),
sendMessage: (
role: 'assistant' | 'user' | 'tool',
content: string,
toolCallId?: string
) =>
console.log(
'sendMessage',
role,
content,
toolCallId
),
) => {
handleSend(role, idx, content, toolCallId);
},
}}>
<CallChat call={state.traceCall as TraceCallSchema} />
</PlaygroundContext.Provider>
Expand Down Expand Up @@ -187,7 +193,7 @@ export const PlaygroundChat = ({
chatText={chatText}
setChatText={setChatText}
isLoading={isLoading}
onSend={() => {}}
onSend={handleSend}
onAdd={handleAddMessage}
/>
</Box>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export const PlaygroundChatTopBar: React.FC<PlaygroundChatTopBarProps> = ({
display: 'flex',
gap: '8px',
alignItems: 'center',
backgroundColor: 'white',
backgroundColor: 'transparent',
}}>
{!onlyOneChat && <Tag label={`${idx + 1}`} />}
<LLMDropdown
Expand All @@ -97,7 +97,7 @@ export const PlaygroundChatTopBar: React.FC<PlaygroundChatTopBarProps> = ({
display: 'flex',
alignItems: 'center',
gap: '4px',
backgroundColor: 'white',
backgroundColor: 'transparent',
}}>
<Button
tooltip={'Clear chat'}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import {toast} from '@wandb/weave/common/components/elements/Toast';
import React from 'react';
import {Link} from 'react-router-dom';

import {Message} from '../../ChatView/types';
import {useGetTraceServerClientContext} from '../../wfReactInterface/traceServerClientContext';
import {CompletionsCreateRes} from '../../wfReactInterface/traceServerClientTypes';
import {PlaygroundState} from '../types';
import {getInputFromPlaygroundState} from '../usePlaygroundState';
import {clearTraceCall} from './useChatFunctions';

/**
 * Playground chat completion actions.
 *
 * Provides `handleSend` (append a message and request completions) and
 * `handleRetry` (re-run a completion for one chat). Both toggle the shared
 * loading flag and publish updated playground states via
 * `setPlaygroundStates`.
 */
export const useChatCompletionFunctions = (
  setPlaygroundStates: (states: PlaygroundState[]) => void,
  setIsLoading: (isLoading: boolean) => void,
  chatText: string,
  playgroundStates: PlaygroundState[],
  entity: string,
  project: string,
  setChatText: (text: string) => void
) => {
  const getTsClient = useGetTraceServerClientContext();

  // Issue one completions request for the playground state at `callIndex`.
  const makeCompletionRequest = async (
    callIndex: number,
    updatedStates: PlaygroundState[]
  ): Promise<CompletionsCreateRes | null> => {
    const inputs = getInputFromPlaygroundState(updatedStates[callIndex]);

    return getTsClient().completionsCreate({
      project_id: `${entity}/${project}`,
      inputs,
      track_llm_call: updatedStates[callIndex].trackLLMCall,
    });
  };

  // Toast any missing-API-key / error failures; if none, merge each response
  // into its corresponding state and publish. Returns true on success.
  const handleErrorsAndUpdate = async (
    response: Array<CompletionsCreateRes | null>,
    updatedStates: PlaygroundState[],
    callIndex?: number
  ): Promise<boolean> => {
    const hasMissingLLMApiKey = handleMissingLLMApiKey(response, entity);
    const hasError = handleErrorResponse(response.map(r => r?.response));

    if (hasMissingLLMApiKey || hasError) {
      return false;
    }

    const finalStates = updatedStates.map((state, index) => {
      // When callIndex is set, only that chat's state absorbs a response.
      if (callIndex === undefined || index === callIndex) {
        return handleUpdateCallWithResponse(state, response[index]);
      }
      return state;
    });

    setPlaygroundStates(finalStates);
    return true;
  };

  /**
   * Append a new message (from `content`, falling back to the chat input
   * text) and request fresh completions. When `callIndex` is given only that
   * chat is updated; otherwise every chat sends in parallel.
   */
  const handleSend = async (
    role: 'assistant' | 'user' | 'tool',
    callIndex?: number,
    content?: string,
    toolCallId?: string
  ) => {
    try {
      setIsLoading(true);
      const newMessage = createMessage(role, content || chatText, toolCallId);
      const updatedStates = playgroundStates.map((state, index) => {
        if (callIndex !== undefined && callIndex !== index) {
          return state;
        }
        const updatedState = appendChoicesToMessages(state);
        // createMessage returns undefined for blank content; guard so we
        // never push `undefined` into the messages sent to the model.
        if (newMessage && updatedState.traceCall?.inputs?.messages) {
          updatedState.traceCall.inputs.messages.push(newMessage);
        }
        return updatedState;
      });

      setPlaygroundStates(updatedStates);
      setChatText('');

      const responses = await Promise.all(
        updatedStates.map((_, index) => {
          if (callIndex !== undefined && callIndex !== index) {
            return Promise.resolve<CompletionsCreateRes | null>(null);
          }
          return makeCompletionRequest(index, updatedStates);
        })
      );
      await handleErrorsAndUpdate(responses, updatedStates);
    } catch (error) {
      console.error('Error processing completion:', error);
    } finally {
      setIsLoading(false);
    }
  };

  /**
   * Re-run the completion for chat `callIndex`. For a message retry the
   * history is truncated after `messageIndex`; for a choice retry the prior
   * choices are folded back into the message list first.
   */
  const handleRetry = async (
    callIndex: number,
    messageIndex: number,
    isChoice?: boolean
  ) => {
    try {
      setIsLoading(true);
      const updatedStates = playgroundStates.map((state, index) => {
        if (index === callIndex) {
          if (isChoice) {
            return appendChoicesToMessages(state);
          }
          const updatedState = JSON.parse(JSON.stringify(state));
          if (updatedState.traceCall?.inputs?.messages) {
            updatedState.traceCall.inputs.messages =
              updatedState.traceCall.inputs.messages.slice(0, messageIndex + 1);
          }
          return updatedState;
        }
        return state;
      });

      const response = await makeCompletionRequest(callIndex, updatedStates);
      // Only the retried chat (callIndex) consumes the single response.
      await handleErrorsAndUpdate(
        updatedStates.map(() => response),
        updatedStates,
        callIndex
      );
    } catch (error) {
      console.error('Error processing completion:', error);
    } finally {
      setIsLoading(false);
    }
  };

  return {handleRetry, handleSend};
};

// Helper functions
/**
 * Build a chat message, or return undefined when the content is blank
 * (whitespace-only content is treated as "nothing to send").
 */
const createMessage = (
  role: 'assistant' | 'user' | 'tool',
  content: string,
  toolCallId?: string
): Message | undefined => {
  if (!content.trim()) {
    return undefined;
  }
  return {role, content, tool_call_id: toolCallId};
};

/**
 * Toast a "missing LLM API key" error for every response that reports one.
 * Returns true if any response (or nested response, when given an array)
 * carried the missing-key failure shape ({api_key, reason}).
 */
const handleMissingLLMApiKey = (responses: any, entity: string): boolean => {
  if (Array.isArray(responses)) {
    // Recurse into every entry so each failing response still gets its own
    // toast, then report whether any entry failed. (The previous forEach
    // discarded the recursive results, so this always returned false for
    // arrays and callers proceeded as if the request had succeeded.)
    return responses
      .map((response: any) => handleMissingLLMApiKey(response, entity))
      .some(Boolean);
  }
  if (responses && responses.api_key && responses.reason) {
    toast(
      <div>
        <div>{responses.reason}</div>
        Please add your API key to{' '}
        <Link to={`/${entity}/settings`}>Team secrets in settings</Link> to
        use this LLM
      </div>,
      {
        type: 'error',
      }
    );
    return true;
  }
  return false;
};

/**
 * Toast the first error-shaped response, if any. Returns true when the
 * input is missing entirely or when at least one response carries an
 * `error` field.
 */
const handleErrorResponse = (
  responses: Array<CompletionsCreateRes | null | {error: string}>
): boolean => {
  if (!responses) {
    return true;
  }
  // A single scan replaces the original some()+find() pair; behavior is
  // identical — only the first error is surfaced to the user.
  const failed = responses.find(r => r && 'error' in r) as
    | {error: string}
    | undefined;
  if (failed) {
    toast(failed.error, {
      type: 'error',
    });
    return true;
  }
  return false;
};

/**
 * Fold a completion response into a playground state: reset the trace call
 * (via clearTraceCall), record the new weave call id, and attach the model
 * output. A null/undefined response leaves the state untouched.
 */
const handleUpdateCallWithResponse = (
  updatedCall: PlaygroundState,
  response: any
): PlaygroundState => {
  if (!response) {
    return updatedCall;
  }
  const refreshedTraceCall = {
    ...clearTraceCall(updatedCall.traceCall),
    id: response.weave_call_id ?? '',
    output: response.response,
  };
  return {...updatedCall, traceCall: refreshedTraceCall};
};

/**
 * Deep-copy a playground state and fold any output choices back into the
 * input message list, clearing the choices afterwards. The input state is
 * never mutated.
 */
const appendChoicesToMessages = (state: PlaygroundState): PlaygroundState => {
  const clonedState = JSON.parse(JSON.stringify(state));
  const inputs = clonedState.traceCall?.inputs;
  const output = clonedState.traceCall?.output;
  if (inputs?.messages && output?.choices) {
    for (const choice of output.choices) {
      if (choice.message) {
        inputs.messages.push(choice.message);
      }
    }
    output.choices = undefined;
  }
  return clonedState;
};
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {cloneDeep} from 'lodash';
import {SetStateAction} from 'react';

import {Choice, Message} from '../../ChatView/types';
import {Message} from '../../ChatView/types';
import {OptionalTraceCallSchema, PlaygroundState} from '../types';
import {DEFAULT_SYSTEM_MESSAGE} from '../usePlaygroundState';
type TraceCallOutput = {
Expand Down Expand Up @@ -109,7 +109,7 @@ export const useChatFunctions = (
const editChoice = (
callIndex: number,
choiceIndex: number,
newChoice: Choice
newChoice: Message
) => {
setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => {
const newTraceCall = clearTraceCall(
Expand All @@ -128,10 +128,7 @@ export const useChatFunctions = (
// Add the new choice as a message
newTraceCall.inputs = newTraceCall.inputs ?? {};
newTraceCall.inputs.messages = newTraceCall.inputs.messages ?? [];
newTraceCall.inputs.messages.push({
role: 'assistant',
content: newChoice.message?.content,
});
newTraceCall.inputs.messages.push(newChoice);
}
return newTraceCall;
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import {createContext, useContext} from 'react';

import {Choice, Message} from '../ChatView/types';
import {Message} from '../ChatView/types';

export type PlaygroundContextType = {
isPlayground: boolean;
addMessage: (newMessage: Message) => void;
editMessage: (messageIndex: number, newMessage: Message) => void;
deleteMessage: (messageIndex: number, responseIndexes?: number[]) => void;

editChoice: (choiceIndex: number, newChoice: Choice) => void;
editChoice: (choiceIndex: number, newChoice: Message) => void;
deleteChoice: (choiceIndex: number) => void;

retry: (messageIndex: number, isChoice?: boolean) => void;
Expand Down

0 comments on commit db8a09c

Please sign in to comment.