Skip to content

Commit

Permalink
chore(weave): Add call chat to playground page (#2993)
Browse files Browse the repository at this point in the history
* add playground context

* clean up

* types

* make context nullable, remove all the optionals

* add call chat to playground and add functions interacting with chat messages

* lint

* remove callchat edits, move provider into playgroundchat

* pr comments, tighten type defs
  • Loading branch information
jwlee64 authored Nov 18, 2024
1 parent dd2d5de commit f664268
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@ import {Box} from '@mui/material';
import {Select} from '@wandb/weave/components/Form/Select';
import React from 'react';

import {LLM_MAX_TOKENS} from '../llmMaxTokens';
import {LLM_MAX_TOKENS, LLMMaxTokensKey} from '../llmMaxTokens';

interface LLMDropdownProps {
value: string;
onChange: (value: string, maxTokens: number) => void;
value: LLMMaxTokensKey;
onChange: (value: LLMMaxTokensKey, maxTokens: number) => void;
}

export const LLMDropdown: React.FC<LLMDropdownProps> = ({value, onChange}) => {
const options = Object.keys(LLM_MAX_TOKENS).map(llm => ({
value: llm,
label: llm,
}));
const options: Array<{value: LLMMaxTokensKey; label: LLMMaxTokensKey}> =
Object.keys(LLM_MAX_TOKENS).map(llm => ({
value: llm as LLMMaxTokensKey,
label: llm as LLMMaxTokensKey,
}));

return (
<Box
Expand All @@ -35,7 +36,7 @@ export const LLMDropdown: React.FC<LLMDropdownProps> = ({value, onChange}) => {
LLM_MAX_TOKENS[
(option as {value: string}).value as keyof typeof LLM_MAX_TOKENS
]?.max_tokens || 0;
onChange((option as {value: string}).value, maxTokens);
onChange((option as {value: LLMMaxTokensKey}).value, maxTokens);
}
}}
options={options}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import {Box, CircularProgress, Divider} from '@mui/material';
import {MOON_200} from '@wandb/weave/common/css/color.styles';
import {Tailwind} from '@wandb/weave/components/Tailwind';
import React, {SetStateAction, useState} from 'react';
import React, {useState} from 'react';

import {CallChat} from '../../CallPage/CallChat';
import {TraceCallSchema} from '../../wfReactInterface/traceServerClientTypes';
import {PlaygroundState, PlaygroundStateKey} from '../types';
import {PlaygroundContext} from '../PlaygroundContext';
import {PlaygroundState} from '../types';
import {PlaygroundCallStats} from './PlaygroundCallStats';
import {PlaygroundChatInput} from './PlaygroundChatInput';
import {PlaygroundChatTopBar} from './PlaygroundChatTopBar';
import {
SetPlaygroundStateFieldFunctionType,
useChatFunctions,
} from './useChatFunctions';

export type PlaygroundChatProps = {
entity: string;
project: string;
setPlaygroundStates: (states: PlaygroundState[]) => void;
playgroundStates: PlaygroundState[];
setPlaygroundStateField: (
index: number,
field: PlaygroundStateKey,
value: SetStateAction<PlaygroundState[PlaygroundStateKey]>
) => void;
setPlaygroundStateField: SetPlaygroundStateFieldFunctionType;
setSettingsTab: (callIndex: number | null) => void;
settingsTab: number | null;
};
Expand All @@ -37,6 +39,16 @@ export const PlaygroundChat = ({
const [isLoading, setIsLoading] = useState(false);
const chatPercentWidth = 100 / playgroundStates.length;

const {deleteMessage, editMessage, deleteChoice, editChoice, addMessage} =
useChatFunctions(setPlaygroundStateField);

const handleAddMessage = (role: 'assistant' | 'user', text: string) => {
for (let i = 0; i < playgroundStates.length; i++) {
addMessage(i, {role, content: text});
}
setChatText('');
};

return (
<Box
sx={{
Expand Down Expand Up @@ -119,7 +131,36 @@ export const PlaygroundChat = ({
}}>
<Tailwind>
<div className="mx-auto h-full min-w-[400px] max-w-[800px] pb-8">
Chat
{state.traceCall && (
<PlaygroundContext.Provider
value={{
isPlayground: true,
deleteMessage: (messageIndex, responseIndexes) =>
deleteMessage(idx, messageIndex, responseIndexes),
editMessage: (messageIndex, newMessage) =>
editMessage(idx, messageIndex, newMessage),
deleteChoice: choiceIndex =>
deleteChoice(idx, choiceIndex),
addMessage: newMessage => addMessage(idx, newMessage),
editChoice: (choiceIndex, newChoice) =>
editChoice(idx, choiceIndex, newChoice),
retry: (messageIndex: number, isChoice?: boolean) =>
console.log('retry', messageIndex, isChoice),
sendMessage: (
role: 'assistant' | 'user' | 'tool',
content: string,
toolCallId?: string
) =>
console.log(
'sendMessage',
role,
content,
toolCallId
),
}}>
<CallChat call={state.traceCall as TraceCallSchema} />
</PlaygroundContext.Provider>
)}
</div>
</Tailwind>
</Box>
Expand Down Expand Up @@ -147,7 +188,7 @@ export const PlaygroundChat = ({
setChatText={setChatText}
isLoading={isLoading}
onSend={() => {}}
onAdd={() => {}}
onAdd={handleAddMessage}
/>
</Box>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@ import React from 'react';
import {useHistory} from 'react-router-dom';

import {CopyableId} from '../../common/Id';
import {LLMMaxTokensKey} from '../llmMaxTokens';
import {OptionalTraceCallSchema, PlaygroundState} from '../types';
import {DEFAULT_SYSTEM_MESSAGE} from '../usePlaygroundState';
import {LLMDropdown} from './LLMDropdown';
import {SetPlaygroundStateFieldFunctionType} from './useChatFunctions';

type PlaygroundChatTopBarProps = {
idx: number;
settingsTab: number | null;
setSettingsTab: (tab: number | null) => void;
setPlaygroundStateField: (
index: number,
field: keyof PlaygroundState,
value: any
) => void;
setPlaygroundStateField: SetPlaygroundStateFieldFunctionType;
entity: string;
project: string;
playgroundStates: PlaygroundState[];
Expand Down Expand Up @@ -60,7 +58,7 @@ export const PlaygroundChatTopBar: React.FC<PlaygroundChatTopBarProps> = ({

const handleModelChange = (
index: number,
model: string,
model: LLMMaxTokensKey,
maxTokens: number
) => {
setPlaygroundStateField(index, 'model', model);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import {cloneDeep} from 'lodash';
import {SetStateAction} from 'react';

import {Choice, Message} from '../../ChatView/types';
import {OptionalTraceCallSchema, PlaygroundState} from '../types';
import {DEFAULT_SYSTEM_MESSAGE} from '../usePlaygroundState';
type TraceCallOutput = {
choices?: any[];
};

export type SetPlaygroundStateFieldFunctionType = (
index: number,
field: keyof PlaygroundState,
// The value here is a function that returns a PlaygroundState field
value: SetStateAction<PlaygroundState[keyof PlaygroundState]>
) => void;

export const useChatFunctions = (
setPlaygroundStateField: SetPlaygroundStateFieldFunctionType
) => {
const deleteMessage = (
callIndex: number,
messageIndex: number,
responseIndexes?: number[]
) => {
setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => {
const newTraceCall = clearTraceCall(
cloneDeep(prevTraceCall as OptionalTraceCallSchema)
);
if (newTraceCall && newTraceCall.inputs?.messages) {
// Remove the message and all responses to it
newTraceCall.inputs.messages = newTraceCall.inputs.messages.filter(
(_: any, index: number) =>
index !== messageIndex && !responseIndexes?.includes(index)
);

// If there are no messages left, add a system message
if (newTraceCall.inputs.messages.length === 0) {
newTraceCall.inputs.messages = [DEFAULT_SYSTEM_MESSAGE];
}
}
return newTraceCall;
});
};

const editMessage = (
callIndex: number,
messageIndex: number,
newMessage: Message
) => {
setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => {
const newTraceCall = clearTraceCall(
cloneDeep(prevTraceCall as OptionalTraceCallSchema)
);
if (newTraceCall && newTraceCall.inputs?.messages) {
// Replace the message
newTraceCall.inputs.messages[messageIndex] = newMessage;
}
return newTraceCall;
});
};

const addMessage = (callIndex: number, newMessage: Message) => {
setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => {
const newTraceCall = clearTraceCall(
cloneDeep(prevTraceCall as OptionalTraceCallSchema)
);
if (newTraceCall && newTraceCall.inputs?.messages) {
if (
newTraceCall.output &&
(newTraceCall.output as TraceCallOutput).choices &&
Array.isArray((newTraceCall.output as TraceCallOutput).choices)
) {
// Add all the choices as messages
(newTraceCall.output as TraceCallOutput).choices!.forEach(
(choice: any) => {
if (choice.message) {
newTraceCall.inputs!.messages.push(choice.message);
}
}
);
// Set the choices to undefined
(newTraceCall.output as TraceCallOutput).choices = undefined;
}
// Add the new message
newTraceCall.inputs.messages.push(newMessage);
}
return newTraceCall;
});
};

const deleteChoice = (callIndex: number, choiceIndex: number) => {
setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => {
const newTraceCall = clearTraceCall(
cloneDeep(prevTraceCall as OptionalTraceCallSchema)
);
const output = newTraceCall?.output as TraceCallOutput;
if (output && Array.isArray(output.choices)) {
// Remove the choice
output.choices.splice(choiceIndex, 1);
if (newTraceCall) {
newTraceCall.output = output;
}
}
return newTraceCall;
});
};

const editChoice = (
callIndex: number,
choiceIndex: number,
newChoice: Choice
) => {
setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => {
const newTraceCall = clearTraceCall(
cloneDeep(prevTraceCall as OptionalTraceCallSchema)
);
if (
newTraceCall?.output &&
Array.isArray((newTraceCall.output as TraceCallOutput).choices)
) {
// Delete the old choice
(newTraceCall.output as TraceCallOutput).choices!.splice(
choiceIndex,
1
);

// Add the new choice as a message
newTraceCall.inputs = newTraceCall.inputs ?? {};
newTraceCall.inputs.messages = newTraceCall.inputs.messages ?? [];
newTraceCall.inputs.messages.push({
role: 'assistant',
content: newChoice.message?.content,
});
}
return newTraceCall;
});
};

return {
deleteMessage,
editMessage,
addMessage,
deleteChoice,
editChoice,
};
};

export const clearTraceCall = (traceCall: OptionalTraceCallSchema) => {
if (traceCall) {
traceCall.id = '';
traceCall.summary = undefined;
}
return traceCall;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import {createContext, useContext} from 'react';

import {Choice, Message} from '../ChatView/types';

export type PlaygroundContextType = {
isPlayground: boolean;
addMessage: (newMessage: Message) => void;
editMessage: (messageIndex: number, newMessage: Message) => void;
deleteMessage: (messageIndex: number, responseIndexes?: number[]) => void;

editChoice: (choiceIndex: number, newChoice: Choice) => void;
deleteChoice: (choiceIndex: number) => void;

retry: (messageIndex: number, isChoice?: boolean) => void;
sendMessage: (
role: 'assistant' | 'user' | 'tool',
content: string,
toolCallId?: string
) => void;
};

const DEFAULT_CONTEXT: PlaygroundContextType = {
isPlayground: false,
addMessage: () => {},
editMessage: () => {},
deleteMessage: () => {},

editChoice: () => {},
deleteChoice: () => {},

retry: () => {},
sendMessage: () => {},
};

// Create context that can be undefined
export const PlaygroundContext = createContext<
PlaygroundContextType | undefined
>(DEFAULT_CONTEXT);

// Custom hook that handles the undefined context
export const usePlaygroundContext = () => {
const context = useContext(PlaygroundContext);
return context ?? DEFAULT_CONTEXT;
};
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,18 @@ import {MOON_250} from '@wandb/weave/common/css/color.styles';
import {Switch} from '@wandb/weave/components';
import * as Tabs from '@wandb/weave/components/Tabs';
import {Tag} from '@wandb/weave/components/Tag';
import React, {SetStateAction} from 'react';
import React from 'react';

import {PlaygroundState, PlaygroundStateKey} from '../types';
import {SetPlaygroundStateFieldFunctionType} from '../PlaygroundChat/useChatFunctions';
import {PlaygroundState} from '../types';
import {FunctionEditor} from './FunctionEditor';
import {PlaygroundSlider} from './PlaygroundSlider';
import {ResponseFormatEditor} from './ResponseFormatEditor';
import {StopSequenceEditor} from './StopSequenceEditor';

export type PlaygroundSettingsProps = {
playgroundStates: PlaygroundState[];
setPlaygroundStateField: (
index: number,
field: PlaygroundStateKey,
value: SetStateAction<PlaygroundState[PlaygroundStateKey]>
) => void;
setPlaygroundStateField: SetPlaygroundStateFieldFunctionType;
settingsTab: number;
setSettingsTab: (tab: number) => void;
};
Expand Down

0 comments on commit f664268

Please sign in to comment.