diff --git a/.env.example b/.env.example index 234dba955..6b0c93069 100644 --- a/.env.example +++ b/.env.example @@ -43,5 +43,10 @@ OPENAI_LIKE_API_KEY= # You only need this environment variable set if you want to use Mistral models MISTRAL_API_KEY= +# Get your Cerebras API Key by following these instructions - +# https://github.com/Cerebras/inference-examples/blob/main/getting-started/README.md +# You only need this environment variable set if you want to use Cerebras models +CEREBRAS_API_KEY= + # Include this environment variable if you want more logging for debugging locally VITE_LOG_LEVEL=debug diff --git a/app/lib/.server/llm/api-key.ts b/app/lib/.server/llm/api-key.ts index e2764c1dc..42041bd52 100644 --- a/app/lib/.server/llm/api-key.ts +++ b/app/lib/.server/llm/api-key.ts @@ -25,6 +25,8 @@ export function getAPIKey(cloudflareEnv: Env, provider: string) { return env.MISTRAL_API_KEY || cloudflareEnv.MISTRAL_API_KEY; case "OpenAILike": return env.OPENAI_LIKE_API_KEY || cloudflareEnv.OPENAI_LIKE_API_KEY; + case "Cerebras": + return env.CEREBRAS_API_KEY || cloudflareEnv.CEREBRAS_API_KEY; default: return ""; } diff --git a/app/lib/.server/llm/model.ts b/app/lib/.server/llm/model.ts index 390d57aeb..0d7ebab86 100644 --- a/app/lib/.server/llm/model.ts +++ b/app/lib/.server/llm/model.ts @@ -80,6 +80,15 @@ export function getOpenRouterModel(apiKey: string, model: string) { return openRouter.chat(model); } +export function getCerebrasModel(apiKey: string, model: string) { + const openai = createOpenAI({ + baseURL: 'https://api.cerebras.ai/v1', + apiKey, + }); + + return openai(model); +} + export function getModel(provider: string, model: string, env: Env) { const apiKey = getAPIKey(env, provider); const baseURL = getBaseURL(env, provider); @@ -101,6 +110,8 @@ export function getModel(provider: string, model: string, env: Env) { return getDeepseekModel(apiKey, model) case 'Mistral': return getMistralModel(apiKey, model); + case 'Cerebras': + return getCerebrasModel(apiKey, model); default: return getOllamaModel(baseURL, model); } diff --git a/app/utils/constants.ts b/app/utils/constants.ts index b48cb3442..9b6eb2467 100644 --- a/app/utils/constants.ts +++ b/app/utils/constants.ts @@ -43,6 +43,9 @@ const staticModels: ModelInfo[] = [ { name: 'mistral-small-latest', label: 'Mistral Small', provider: 'Mistral' }, { name: 'codestral-latest', label: 'Codestral', provider: 'Mistral' }, { name: 'mistral-large-latest', label: 'Mistral Large Latest', provider: 'Mistral' }, + { name: 'pixtral-12b-2409', label: 'Pixtral 12B', provider: 'Mistral' }, + { name: 'llama3.1-8b', label: 'Llama 3.1 8B', provider: 'Cerebras' }, + { name: 'llama3.1-70b', label: 'Llama 3.1 70B', provider: 'Cerebras' }, ]; export let MODEL_LIST: ModelInfo[] = [...staticModels];