Skip to content

Commit

Permalink
Add models list
Browse files Browse the repository at this point in the history
  • Loading branch information
aurelien-sudre committed Nov 12, 2024
1 parent e8e5e27 commit 273e8fd
Show file tree
Hide file tree
Showing 11 changed files with 426 additions and 164 deletions.
100 changes: 73 additions & 27 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames';
import { MODEL_LIST, DEFAULT_PROVIDER } from '~/utils/constants';
import { DEFAULT_PROVIDER, staticProviders } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { useState } from 'react';
Expand All @@ -24,17 +24,57 @@ const EXAMPLE_PROMPTS = [
{ text: 'How do I center a div?' },
];

const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))]
const providerList = staticProviders;

const ModelSelector = ({ model, setModel, modelList, providerList, provider, setProvider, setModelList }) => {
const [apiKeys, setApiKeys] = useState<Record<string, string>>({});

const refreshModels = () => {
console.log('Refreshing models...');
};

useEffect(() => {
const storedApiKeys = Cookies.get('apiKeys');
if (storedApiKeys) {
setApiKeys(JSON.parse(storedApiKeys));
}
}, []);

useEffect(() => {
const firstModel = [...modelList].find((m) => m.provider == selectedProvider);
setModel(firstModel ? firstModel : null);
}, [provider]);

const ModelSelector = ({ model, setModel, modelList, providerList, provider, setProvider }) => {
return (
<div className="mb-2 flex gap-2">
<select
<select
value={provider}
onChange={(e) => {
setProvider(e.target.value);
const firstModel = [...modelList].find(m => m.provider == e.target.value);
setModel(firstModel ? firstModel.name : '');
onChange={async (e) => {
const selectedProvider = e.target.value;
setProvider(selectedProvider);
const firstModel = [...modelList].find((m) => m.provider == selectedProvider);
setModel(firstModel ? firstModel : null);

// Appel à l'API pour charger les modèles
try {
const response = await fetch('/api/models', {
method: 'POST',
body: JSON.stringify({
provider: selectedProvider,
apiKeys: apiKeys, // Passer les clés API
}),
});

if (!response.ok) {
const errorBody = await response.text();
throw new Error(`Erreur lors de la récupération des modèles: ${errorBody}`);
}

const models = await response.json();
setModelList(models); // Mettre à jour la liste des modèles
} catch (error) {
console.error('Erreur lors du chargement des modèles:', error);
}
}}
className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all"
>
Expand All @@ -43,24 +83,26 @@ const ModelSelector = ({ model, setModel, modelList, providerList, provider, set
{provider}
</option>
))}
<option key="Ollama" value="Ollama">
Ollama
</option>
<option key="OpenAILike" value="OpenAILike">
OpenAILike
</option>
</select>
<select
value={model}
onChange={(e) => setModel(e.target.value)}
value={model?.name}
onChange={(e) => {
const selectedModel = modelList.find((m) => m.name === e.target.value);
setModel(selectedModel ? selectedModel : null);
}}
className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all"
>
{[...modelList].filter(e => e.provider == provider && e.name).map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}>
{modelOption.label}
</option>
))}
{[...modelList]
.filter((e) => e.provider == provider && e.name)
.map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}>
{modelOption.name}
</option>
))}
</select>
<IconButton onClick={refreshModels} title="Refresh Models">
<div className="i-ph:arrows-clockwise" />
</IconButton>
</div>
);
};
Expand All @@ -78,8 +120,8 @@ interface BaseChatProps {
enhancingPrompt?: boolean;
promptEnhanced?: boolean;
input?: string;
model: string;
setModel: (model: string) => void;
model: ModelInfo;
setModel: (model: ModelInfo) => void;
handleStop?: () => void;
sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
Expand Down Expand Up @@ -111,6 +153,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
const [modelList, setModelList] = useState([]); // État pour la liste des modèles

useEffect(() => {
// Load API keys from cookies on component mount
Expand Down Expand Up @@ -138,7 +181,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
expires: 30, // 30 days
secure: true, // Only send over HTTPS
sameSite: 'strict', // Protect against CSRF
path: '/' // Accessible across the site
path: '/', // Accessible across the site
});
} catch (error) {
console.error('Error saving API keys to cookies:', error);
Expand Down Expand Up @@ -192,10 +235,11 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
<ModelSelector
model={model}
setModel={setModel}
modelList={MODEL_LIST}
modelList={modelList} // Passer la liste des modèles
providerList={providerList}
provider={provider}
setProvider={setProvider}
setModelList={setModelList} // Passer la fonction pour mettre à jour la liste des modèles
/>
<APIKeyManager
provider={provider}
Expand Down Expand Up @@ -275,7 +319,9 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
</div>
{input.length > 3 ? (
<div className="text-xs text-bolt-elements-textTertiary">
Use <kbd className="kdb px-1.5 py-0.5 rounded bg-bolt-elements-background-depth-2">Shift</kbd> + <kbd className="kdb px-1.5 py-0.5 rounded bg-bolt-elements-background-depth-2">Return</kbd> for a new line
Use <kbd className="kdb px-1.5 py-0.5 rounded bg-bolt-elements-background-depth-2">Shift</kbd> +{' '}
<kbd className="kdb px-1.5 py-0.5 rounded bg-bolt-elements-background-depth-2">Return</kbd> for
a new line
</div>
) : null}
</div>
Expand Down Expand Up @@ -309,4 +355,4 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
</div>
);
},
);
);
9 changes: 5 additions & 4 deletions app/components/chat/Chat.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
const textareaRef = useRef<HTMLTextAreaElement>(null);

const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
const [model, setModel] = useState(DEFAULT_MODEL);
const [model, setModel] = useState<ModelInfo>(null);

const { showChat } = useStore(chatStore);

Expand All @@ -85,7 +85,8 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
const { messages, isLoading, input, handleInputChange, setInput, stop, append } = useChat({
api: '/api/chat',
body: {
apiKeys
model,
apiKeys,
},
onError: (error) => {
logger.error('Request failed\n\n', error);
Expand Down Expand Up @@ -188,15 +189,15 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
* manually reset the input and we'd have to manually pass in file attachments. However, those
* aren't relevant here.
*/
append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` });
append({ role: 'user', content: `${diff}\n\n${_input}` });

/**
* After sending a new message we reset all modifications since the model
* should now be aware of all the changes.
*/
workbenchStore.resetAllFileModifications();
} else {
append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` });
append({ role: 'user', content: `${_input}` });
}

setInput('');
Expand Down
3 changes: 0 additions & 3 deletions app/entry.server.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { renderToReadableStream } from 'react-dom/server';
import { renderHeadToString } from 'remix-island';
import { Head } from './root';
import { themeStore } from '~/lib/stores/theme';
import { initializeModelList } from '~/utils/constants';

export default async function handleRequest(
request: Request,
Expand All @@ -14,8 +13,6 @@ export default async function handleRequest(
remixContext: EntryContext,
_loadContext: AppLoadContext,
) {
await initializeModelList();

const readable = await renderToReadableStream(<RemixServer context={remixContext} url={request.url} />, {
signal: request.signal,
onError(error: unknown) {
Expand Down
22 changes: 11 additions & 11 deletions app/lib/.server/llm/api-key.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,27 @@ export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Re
case 'OpenRouter':
return env.OPEN_ROUTER_API_KEY || cloudflareEnv.OPEN_ROUTER_API_KEY;
case 'Deepseek':
return env.DEEPSEEK_API_KEY || cloudflareEnv.DEEPSEEK_API_KEY
return env.DEEPSEEK_API_KEY || cloudflareEnv.DEEPSEEK_API_KEY;
case 'Mistral':
return env.MISTRAL_API_KEY || cloudflareEnv.MISTRAL_API_KEY;
case "OpenAILike":
return env.MISTRAL_API_KEY || cloudflareEnv.MISTRAL_API_KEY;
case 'OpenAILike':
return env.OPENAI_LIKE_API_KEY || cloudflareEnv.OPENAI_LIKE_API_KEY;
default:
return "";
return '';
}
}

export function getBaseURL(cloudflareEnv: Env, provider: string) {
switch (provider) {
case 'OpenAILike':
return env.OPENAI_LIKE_API_BASE_URL || cloudflareEnv.OPENAI_LIKE_API_BASE_URL;
return env.OPENAI_LIKE_API_BASE_URL || cloudflareEnv.OPENAI_LIKE_API_BASE_URL || 'http://localhost:4000';
case 'Ollama':
let baseUrl = env.OLLAMA_API_BASE_URL || cloudflareEnv.OLLAMA_API_BASE_URL || "http://localhost:11434";
if (env.RUNNING_IN_DOCKER === 'true') {
baseUrl = baseUrl.replace("localhost", "host.docker.internal");
}
return baseUrl;
let baseUrl = env.OLLAMA_API_BASE_URL || cloudflareEnv.OLLAMA_API_BASE_URL || 'http://localhost:11434';
if (env.RUNNING_IN_DOCKER === 'true') {
baseUrl = baseUrl.replace('localhost', 'host.docker.internal');
}
return baseUrl;
default:
return "";
return '';
}
}
Loading

0 comments on commit 273e8fd

Please sign in to comment.