Skip to content

Commit

Permalink
Add local configuration support
Browse files Browse the repository at this point in the history
  • Loading branch information
aurelien-sudre committed Oct 23, 2024
1 parent b59c191 commit c110932
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 46 deletions.
117 changes: 87 additions & 30 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import { classNames } from '~/utils/classNames';
import { MODEL_LIST, DEFAULT_PROVIDER } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { useState } from 'react';
import { useState, useEffect } from 'react';
import { openDatabase, setProvider, getProviderById } from '~/lib/persistence/db';

import styles from './BaseChat.module.scss';
import type { Provider } from '~/types/provider';

const EXAMPLE_PROMPTS = [
{ text: 'Build a todo app in React using Tailwind' },
Expand All @@ -22,46 +24,101 @@ const EXAMPLE_PROMPTS = [
{ text: 'How do I center a div?' },
];

const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))];
const ModelSelector = ({ model, setModel }) => {
const [providerList, setProviderList] = useState<Provider[]>([]);
const [selectedProvider, setSelectedProvider] = useState<Provider>();
const [apiKeyOrUrl, setApiKeyOrUrl] = useState<string>('');

useEffect(() => {
const loadProviders = async () => {
const providerIdList = [...new Set(MODEL_LIST.map((model) => model.provider))];
const db = await openDatabase();
if (db) {
const providerPromises = providerIdList.map(async (id) => {
return (await getProviderById(db, id)) || { id, models: MODEL_LIST.filter((m) => m.provider === id) };
});

setProviderList(await Promise.all(providerPromises));
setSelectedProvider(providerList.find((p) => p.id === DEFAULT_PROVIDER));
}
};
loadProviders();
}, []);

useEffect(() => {
if (selectedProvider) {
console.log('Selected Provider:', selectedProvider);
setApiKeyOrUrl(selectedProvider.apiKey ?? '');
}
}, [selectedProvider]);

const saveProvider = async () => {
if (selectedProvider) {
selectedProvider.apiKey = apiKeyOrUrl;
const db = await openDatabase();
if (db) {
await setProvider(db, selectedProvider);
}
}
};

const ModelSelector = ({ model, setModel, modelList, providerList }) => {
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
return (
<div className="mb-2">
<select
value={provider}
onChange={(e) => {
setProvider(e.target.value);
const firstModel = [...modelList].find((m) => m.provider == e.target.value);
setModel(firstModel ? firstModel.name : '');
}}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{providerList.map((provider) => (
<option key={provider} value={provider}>
{provider}
</option>
))}
</select>
<select
value={model}
onChange={(e) => setModel(e.target.value)}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{[...modelList]
.filter((e) => e.provider == provider && e.name)
.map((modelOption) => (
<div className="flex mt-2">
<select
value={selectedProvider?.id}
onChange={(e) => {
const provider = providerList.find((p) => p.id === e.target.value);
setSelectedProvider(provider);
const firstModel = provider?.models.find((m) => m);
setModel(firstModel?.name);
}}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{providerList.map((provider) => (
<option key={provider.id} value={provider.id}>
{provider.id}
</option>
))}
</select>
</div>
{selectedProvider?.id === 'Ollama' || (
<div className="flex mt-2">
<input
type="text"
value={apiKeyOrUrl}
placeholder="Custom URL"
onChange={(e) => setApiKeyOrUrl(e.target.value)}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
/>
<button
onClick={async (e) => {
await saveProvider();
}}
className="ml-2 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
Save
</button>
</div>
)}
<div className="flex mt-2">
<select
value={model}
onChange={(e) => setModel(e.target.value)}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{selectedProvider?.models.map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}>
{modelOption.label}
</option>
))}
</select>
</select>
</div>
</div>
);
};

const TEXTAREA_MIN_HEIGHT = 76;

interface BaseChatProps {
textareaRef?: React.RefObject<HTMLTextAreaElement> | undefined;
messageRef?: RefCallback<HTMLDivElement> | undefined;
Expand Down Expand Up @@ -149,7 +206,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
'sticky bottom-0': chatStarted,
})}
>
<ModelSelector model={model} setModel={setModel} modelList={MODEL_LIST} providerList={providerList} />
<ModelSelector model={model} setModel={setModel} />
<div
className={classNames(
'shadow-sm border border-bolt-elements-borderColor bg-bolt-elements-prompt-background backdrop-filter backdrop-blur-[8px] rounded-lg overflow-hidden',
Expand Down
19 changes: 15 additions & 4 deletions app/components/chat/Chat.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ import { useAnimate } from 'framer-motion';
import { memo, useEffect, useRef, useState } from 'react';
import { cssTransition, toast, ToastContainer } from 'react-toastify';
import { useMessageParser, usePromptEnhancer, useShortcuts, useSnapScroll } from '~/lib/hooks';
import { useChatHistory } from '~/lib/persistence';
import { useChatHistory, openDatabase, getProviderById } from '~/lib/persistence';
import { chatStore } from '~/lib/stores/chat';
import { workbenchStore } from '~/lib/stores/workbench';
import { fileModificationsToHTML } from '~/utils/diff';
import { DEFAULT_MODEL } from '~/utils/constants';
import { MODEL_LIST, DEFAULT_MODEL } from '~/utils/constants';
import { cubicEasingFn } from '~/utils/easings';
import { createScopedLogger, renderLogger } from '~/utils/logger';
import { BaseChat } from './BaseChat';
import type { Provider } from '~/types/provider';

const toastAnimation = cssTransition({
enter: 'animated fadeInRight',
Expand Down Expand Up @@ -172,6 +173,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp

runAnimation();

let apiKey = '';
let provider: Provider | undefined;
const providerId = MODEL_LIST.find((m) => m.name === model)?.provider;
if (providerId) {
const db = await openDatabase();
if (db) {
provider = await getProviderById(db, providerId);
}
}

if (fileModifications !== undefined) {
const diff = fileModificationsToHTML(fileModifications);

Expand All @@ -182,15 +193,15 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
* manually reset the input and we'd have to manually pass in file attachments. However, those
* aren't relevant here.
*/
append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}][APIKey: ${provider?.apiKey || ''}]\n\n${diff}\n\n${_input}` });

/**
* After sending a new message we reset all modifications since the model
* should now be aware of all the changes.
*/
workbenchStore.resetAllFileModifications();
} else {
append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}][APIKey: ${provider?.apiKey || ''}]\n\n${_input}` });
}

setInput('');
Expand Down
3 changes: 2 additions & 1 deletion app/lib/.server/llm/api-key.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// @ts-nocheck
// Preventing TS checks with files presented in the video for a better presentation.
import { env } from 'node:process';
import { openDatabase, getProviderById } from '~/lib/persistence/db';

export function getAPIKey(cloudflareEnv: Env, provider: string) {
export async function getAPIKey(cloudflareEnv: Env, provider: string) {
/**
* The `cloudflareEnv` is only used when deployed or when previewing locally.
* In development the environment variables are available through `env`.
Expand Down
11 changes: 8 additions & 3 deletions app/lib/.server/llm/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export function getGoogleModel(apiKey: string, model: string) {
}

export function getMistralModel(apiKey: string, model: string) {
const mistral = createMistral(apiKey);
const mistral = createMistral({ apiKey });

return mistral(model);
}
Expand All @@ -57,8 +57,13 @@ export function getOpenRouterModel(apiKey: string, model: string) {
return openRouter.chat(model);
}

export function getModel(provider: string, model: string, env: Env) {
const apiKey = getAPIKey(env, provider);
export function getModel(provider: string, customApiKey: string, model: string, env: Env) {
let apiKey = getAPIKey(env, provider);

// If a custom API key is provided, use it instead of the environment variable.
if (customApiKey !== '') {
apiKey = customApiKey;
}

switch (provider) {
case 'Anthropic':
Expand Down
18 changes: 12 additions & 6 deletions app/lib/.server/llm/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,33 @@ export type Messages = Message[];

export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;

function extractModelFromMessage(message: Message): { model: string; content: string } {
const modelRegex = /^\[Model: (.*?)\]\n\n/;
function extractModelFromMessage(message: Message): { model: string; apiKey: string; content: string } {
const modelRegex = /^\[Model: (.*?)\]\[APIKey: (.*?)\]\n\n/;
const match = message.content.match(modelRegex);

if (match) {
const model = match[1];
const apiKey = match[2];

console.log('APIKey:', apiKey);

const content = message.content.replace(modelRegex, '');
return { model, content };
return { model, apiKey, content };
}

// Default model if not specified
return { model: DEFAULT_MODEL, content: message.content };
return { model: DEFAULT_MODEL, apiKey: '', content: message.content };
}

export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
let currentModel = DEFAULT_MODEL;
let customApiKey = '';
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, content } = extractModelFromMessage(message);
const { model, apiKey, content } = extractModelFromMessage(message);
if (model && MODEL_LIST.find((m) => m.name === model)) {
currentModel = model; // Update the current model
customApiKey = apiKey; // Update the API key if provided
}
return { ...message, content };
}
Expand All @@ -54,7 +60,7 @@ export function streamText(messages: Messages, env: Env, options?: StreamingOpti
const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER;

return _streamText({
model: getModel(provider, currentModel, env),
model: getModel(provider, customApiKey, currentModel, env),
system: getSystemPrompt(),
maxTokens: MAX_TOKENS,
// headers: {
Expand Down
30 changes: 29 additions & 1 deletion app/lib/persistence/db.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import type { Message } from 'ai';
import { createScopedLogger } from '~/utils/logger';
import type { ChatHistoryItem } from './useChatHistory';
import type { Provider } from '~/types/provider';

const logger = createScopedLogger('ChatHistory');

// this is used at the top level and never rejects
export async function openDatabase(): Promise<IDBDatabase | undefined> {
return new Promise((resolve) => {
const request = indexedDB.open('boltHistory', 2);
const request = indexedDB.open('boltHistory', 4);

request.onupgradeneeded = (event: IDBVersionChangeEvent) => {
const db = (event.target as IDBOpenDBRequest).result;
Expand All @@ -21,6 +22,10 @@ export async function openDatabase(): Promise<IDBDatabase | undefined> {
if (!db.objectStoreNames.contains('theme')) {
db.createObjectStore('theme', { keyPath: 'id' });
}

if (!db.objectStoreNames.contains('providers')) {
db.createObjectStore('providers', { keyPath: 'id' });
}
};

request.onsuccess = (event: Event) => {
Expand Down Expand Up @@ -185,3 +190,26 @@ export async function getTheme(db: IDBDatabase): Promise<{ id: string; value: st
request.onerror = () => reject(request.error);
});
}

export async function setProvider(db: IDBDatabase, provider: Provider): Promise<void> {
return new Promise((resolve, reject) => {
const transaction = db.transaction('providers', 'readwrite');
const store = transaction.objectStore('providers');

const request = store.put(provider);

request.onsuccess = () => resolve();
request.onerror = () => reject(request.error);
});
}

export async function getProviderById(db: IDBDatabase, id: string): Promise<Provider> {
return new Promise((resolve, reject) => {
const transaction = db.transaction('providers', 'readonly');
const store = transaction.objectStore('providers');
const request = store.get(id);

request.onsuccess = () => resolve(request.result as Provider);
request.onerror = () => reject(request.error);
});
}
7 changes: 7 additions & 0 deletions app/types/provider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import type { ModelInfo } from '~/utils/types';

export interface Provider {
id: string;
apiKey: string;
models: ModelInfo[];
}
2 changes: 1 addition & 1 deletion app/utils/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const staticModels: ModelInfo[] = [
{ name: 'gpt-4-turbo', label: 'GPT-4 Turbo', provider: 'OpenAI' },
{ name: 'gpt-4', label: 'GPT-4', provider: 'OpenAI' },
{ name: 'gpt-3.5-turbo', label: 'GPT-3.5 Turbo', provider: 'OpenAI' },
{ name: 'mistral-large', label: 'Mistral-large', provider: 'Mistral' },
{ name: 'mistral-large-latest', label: 'Mistral-large-latest', provider: 'Mistral' },
];

export let MODEL_LIST: ModelInfo[] = [...staticModels];
Expand Down

0 comments on commit c110932

Please sign in to comment.