Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): Wire up completions #3009

Merged
merged 3 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {PlaygroundState} from '../types';
import {PlaygroundCallStats} from './PlaygroundCallStats';
import {PlaygroundChatInput} from './PlaygroundChatInput';
import {PlaygroundChatTopBar} from './PlaygroundChatTopBar';
import {useChatCompletionFunctions} from './useChatCompletionFunctions';
import {
SetPlaygroundStateFieldFunctionType,
useChatFunctions,
Expand All @@ -35,10 +36,19 @@ export const PlaygroundChat = ({
settingsTab,
}: PlaygroundChatProps) => {
const [chatText, setChatText] = useState('');
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const [isLoading, setIsLoading] = useState(false);
const chatPercentWidth = 100 / playgroundStates.length;

const {handleRetry, handleSend} = useChatCompletionFunctions(
setPlaygroundStates,
setIsLoading,
chatText,
playgroundStates,
entity,
project,
setChatText
);

const {deleteMessage, editMessage, deleteChoice, editChoice, addMessage} =
useChatFunctions(setPlaygroundStateField);

Expand Down Expand Up @@ -145,18 +155,14 @@ export const PlaygroundChat = ({
editChoice: (choiceIndex, newChoice) =>
editChoice(idx, choiceIndex, newChoice),
retry: (messageIndex: number, isChoice?: boolean) =>
console.log('retry', messageIndex, isChoice),
handleRetry(idx, messageIndex, isChoice),
sendMessage: (
role: 'assistant' | 'user' | 'tool',
content: string,
toolCallId?: string
) =>
console.log(
'sendMessage',
role,
content,
toolCallId
),
) => {
handleSend(role, idx, content, toolCallId);
},
}}>
<CallChat call={state.traceCall as TraceCallSchema} />
</PlaygroundContext.Provider>
Expand Down Expand Up @@ -187,7 +193,7 @@ export const PlaygroundChat = ({
chatText={chatText}
setChatText={setChatText}
isLoading={isLoading}
onSend={() => {}}
onSend={handleSend}
onAdd={handleAddMessage}
/>
</Box>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export const PlaygroundChatTopBar: React.FC<PlaygroundChatTopBarProps> = ({
display: 'flex',
gap: '8px',
alignItems: 'center',
backgroundColor: 'white',
backgroundColor: 'transparent',
}}>
{!onlyOneChat && <Tag label={`${idx + 1}`} />}
<LLMDropdown
Expand All @@ -97,7 +97,7 @@ export const PlaygroundChatTopBar: React.FC<PlaygroundChatTopBarProps> = ({
display: 'flex',
alignItems: 'center',
gap: '4px',
backgroundColor: 'white',
backgroundColor: 'transparent',
}}>
<Button
tooltip={'Clear chat'}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import {toast} from '@wandb/weave/common/components/elements/Toast';
import React from 'react';
import {Link} from 'react-router-dom';

import {Message} from '../../ChatView/types';
import {useGetTraceServerClientContext} from '../../wfReactInterface/traceServerClientContext';
import {CompletionsCreateRes} from '../../wfReactInterface/traceServerClientTypes';
import {PlaygroundState} from '../types';
import {getInputFromPlaygroundState} from '../usePlaygroundState';
import {clearTraceCall} from './useChatFunctions';

export const useChatCompletionFunctions = (
setPlaygroundStates: (states: PlaygroundState[]) => void,
setIsLoading: (isLoading: boolean) => void,
chatText: string,
playgroundStates: PlaygroundState[],
entity: string,
project: string,
setChatText: (text: string) => void
) => {
const getTsClient = useGetTraceServerClientContext();

const makeCompletionRequest = async (
callIndex: number,
updatedStates: PlaygroundState[]
): Promise<CompletionsCreateRes | null> => {
const inputs = getInputFromPlaygroundState(updatedStates[callIndex]);

return getTsClient().completionsCreate({
project_id: `${entity}/${project}`,
inputs,
track_llm_call: updatedStates[callIndex].trackLLMCall,
});
};

const handleErrorsAndUpdate = async (
response: Array<CompletionsCreateRes | null>,
updatedStates: PlaygroundState[],
callIndex?: number
): Promise<boolean> => {
const hasMissingLLMApiKey = handleMissingLLMApiKey(response, entity);
const hasError = handleErrorResponse(response.map(r => r?.response));

if (hasMissingLLMApiKey || hasError) {
return false;
}

const finalStates = updatedStates.map((state, index) => {
if (callIndex === undefined || index === callIndex) {
return handleUpdateCallWithResponse(state, response[index]);
}
return state;
});

setPlaygroundStates(finalStates);
return true;
};

const handleSend = async (
role: 'assistant' | 'user' | 'tool',
callIndex?: number,
content?: string,
toolCallId?: string
) => {
try {
setIsLoading(true);
const newMessage = createMessage(role, content || chatText, toolCallId);
const updatedStates = playgroundStates.map((state, index) => {
if (callIndex !== undefined && callIndex !== index) {
return state;
}
const updatedState = appendChoicesToMessages(state);
if (updatedState.traceCall?.inputs?.messages) {
updatedState.traceCall.inputs.messages.push(newMessage);
}
return updatedState;
});

setPlaygroundStates(updatedStates);
setChatText('');

const responses = await Promise.all(
updatedStates.map(async (_, index) => {
if (callIndex !== undefined && callIndex !== index) {
return Promise.resolve(null);
}
return await makeCompletionRequest(index, updatedStates);
})
);
await handleErrorsAndUpdate(responses, updatedStates);
} catch (error) {
console.error('Error processing completion:', error);
} finally {
setIsLoading(false);
}
};

const handleRetry = async (
callIndex: number,
messageIndex: number,
isChoice?: boolean
) => {
try {
setIsLoading(true);
const updatedStates = playgroundStates.map((state, index) => {
if (index === callIndex) {
if (isChoice) {
return appendChoicesToMessages(state);
}
const updatedState = JSON.parse(JSON.stringify(state));
if (updatedState.traceCall?.inputs?.messages) {
updatedState.traceCall.inputs.messages =
updatedState.traceCall.inputs.messages.slice(0, messageIndex + 1);
}
return updatedState;
}
return state;
});

const response = await makeCompletionRequest(callIndex, updatedStates);
await handleErrorsAndUpdate(
updatedStates.map(() => response),
updatedStates,
callIndex
);
} catch (error) {
console.error('Error processing completion:', error);
} finally {
setIsLoading(false);
}
};

return {handleRetry, handleSend};
};

// Helper functions
const createMessage = (
role: 'assistant' | 'user' | 'tool',
content: string,
toolCallId?: string
): Message | undefined => {
return content.trim() ? {role, content, tool_call_id: toolCallId} : undefined;
};

const handleMissingLLMApiKey = (responses: any, entity: string): boolean => {
if (Array.isArray(responses)) {
responses.forEach((response: any) => {
handleMissingLLMApiKey(response, entity);
});
} else {
if (responses && responses.api_key && responses.reason) {
toast(
<div>
<div>{responses.reason}</div>
Please add your API key to{' '}
<Link to={`/${entity}/settings`}>Team secrets in settings</Link> to
use this LLM
</div>,
{
type: 'error',
}
);
return true;
}
}
return false;
};

const handleErrorResponse = (
responses: Array<CompletionsCreateRes | null | {error: string}>
): boolean => {
if (!responses) {
return true;
}
if (responses.some(r => r && 'error' in r)) {
const errorResponse = responses.find(r => r && 'error' in r) as {
error: string;
};
toast(errorResponse?.error, {
type: 'error',
});
return true;
}

return false;
};

const handleUpdateCallWithResponse = (
updatedCall: PlaygroundState,
response: any
): PlaygroundState => {
if (!response) {
return updatedCall;
}
return {
...updatedCall,
traceCall: {
...clearTraceCall(updatedCall.traceCall),
id: response.weave_call_id ?? '',
output: response.response,
},
};
};

const appendChoicesToMessages = (state: PlaygroundState): PlaygroundState => {
const updatedState = JSON.parse(JSON.stringify(state));
if (
updatedState.traceCall?.inputs?.messages &&
updatedState.traceCall.output?.choices
) {
updatedState.traceCall.output.choices.forEach((choice: any) => {
if (choice.message) {
updatedState.traceCall.inputs.messages.push(choice.message);
}
});
updatedState.traceCall.output.choices = undefined;
}
return updatedState;
};
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {cloneDeep} from 'lodash';
import {SetStateAction} from 'react';

import {Choice, Message} from '../../ChatView/types';
import {Message} from '../../ChatView/types';
import {OptionalTraceCallSchema, PlaygroundState} from '../types';
import {DEFAULT_SYSTEM_MESSAGE} from '../usePlaygroundState';
type TraceCallOutput = {
Expand Down Expand Up @@ -109,7 +109,7 @@ export const useChatFunctions = (
const editChoice = (
callIndex: number,
choiceIndex: number,
newChoice: Choice
newChoice: Message
) => {
setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => {
const newTraceCall = clearTraceCall(
Expand All @@ -128,10 +128,7 @@ export const useChatFunctions = (
// Add the new choice as a message
newTraceCall.inputs = newTraceCall.inputs ?? {};
newTraceCall.inputs.messages = newTraceCall.inputs.messages ?? [];
newTraceCall.inputs.messages.push({
role: 'assistant',
content: newChoice.message?.content,
});
newTraceCall.inputs.messages.push(newChoice);
}
return newTraceCall;
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import {createContext, useContext} from 'react';

import {Choice, Message} from '../ChatView/types';
import {Message} from '../ChatView/types';

export type PlaygroundContextType = {
isPlayground: boolean;
addMessage: (newMessage: Message) => void;
editMessage: (messageIndex: number, newMessage: Message) => void;
deleteMessage: (messageIndex: number, responseIndexes?: number[]) => void;

editChoice: (choiceIndex: number, newChoice: Choice) => void;
editChoice: (choiceIndex: number, newChoice: Message) => void;
deleteChoice: (choiceIndex: number) => void;

retry: (messageIndex: number, isChoice?: boolean) => void;
Expand Down
Loading