diff --git a/lib/extension/package.json b/lib/extension/package.json
index a1e0f16..3698e70 100644
--- a/lib/extension/package.json
+++ b/lib/extension/package.json
@@ -15,7 +15,7 @@
     "@rubberduck/common": "*",
     "handlebars": "^4.7.7",
     "marked": "4.2.12",
-    "modelfusion": "0.40.1",
+    "modelfusion": "0.41.1",
     "secure-json-parse": "2.7.0",
     "simple-git": "3.16.1",
     "zod": "3.21.4"
diff --git a/lib/extension/src/ai/AIClient.ts b/lib/extension/src/ai/AIClient.ts
index e111de0..e3f6be5 100644
--- a/lib/extension/src/ai/AIClient.ts
+++ b/lib/extension/src/ai/AIClient.ts
@@ -1,8 +1,11 @@
 import {
+  InstructionPrompt,
   LlamaCppTextGenerationModel,
   OpenAIApiConfiguration,
   OpenAIChatModel,
   OpenAITextEmbeddingModel,
+  OpenAITextEmbeddingResponse,
+  TextStreamingModel,
   embed,
   mapInstructionPromptToLlama2Format,
   mapInstructionPromptToOpenAIChatFormat,
@@ -66,49 +69,56 @@ export class AIClient {
     });
   }
 
-  async streamText({
-    prompt,
+  async getTextStreamingModel({
     maxTokens,
     stop,
     temperature = 0,
   }: {
-    prompt: string;
     maxTokens: number;
     stop?: string[] | undefined;
     temperature?: number | undefined;
-  }) {
-    this.logger.log(["--- Start prompt ---", prompt, "--- End prompt ---"]);
-
+  }): Promise<TextStreamingModel<InstructionPrompt>> {
     const modelConfiguration = getModel();
 
-    if (modelConfiguration === "llama.cpp") {
-      return streamText(
-        new LlamaCppTextGenerationModel({
+    return modelConfiguration === "llama.cpp"
+      ? new LlamaCppTextGenerationModel({
           maxCompletionTokens: maxTokens,
           stopSequences: stop,
           temperature,
-        }).withPromptFormat(mapInstructionPromptToLlama2Format()),
-        { instruction: prompt }
-      );
-    }
+        }).withPromptFormat(mapInstructionPromptToLlama2Format())
+      : new OpenAIChatModel({
+          api: await this.getOpenAIApiConfiguration(),
+          model: modelConfiguration,
+          maxCompletionTokens: maxTokens,
+          stopSequences: stop,
+          temperature,
+          frequencyPenalty: 0,
+          presencePenalty: 0,
+        }).withPromptFormat(mapInstructionPromptToOpenAIChatFormat());
+  }
+
+  async streamText({
+    prompt,
+    maxTokens,
+    stop,
+    temperature = 0,
+  }: {
+    prompt: string;
+    maxTokens: number;
+    stop?: string[] | undefined;
+    temperature?: number | undefined;
+  }) {
+    this.logger.log(["--- Start prompt ---", prompt, "--- End prompt ---"]);
 
     return streamText(
-      new OpenAIChatModel({
-        api: await this.getOpenAIApiConfiguration(),
-        model: modelConfiguration,
-        maxCompletionTokens: maxTokens,
-        stopSequences: stop,
-        temperature,
-        frequencyPenalty: 0,
-        presencePenalty: 0,
-      }).withPromptFormat(mapInstructionPromptToOpenAIChatFormat()),
+      await this.getTextStreamingModel({ maxTokens, stop, temperature }),
       { instruction: prompt }
     );
   }
 
   async generateEmbedding({ input }: { input: string }) {
     try {
-      const { output, response } = await embed(
+      const { value, response } = await embed(
         new OpenAITextEmbeddingModel({
           api: await this.getOpenAIApiConfiguration(),
           model: "text-embedding-ada-002",
@@ -118,8 +128,9 @@ export class AIClient {
 
       return {
         type: "success" as const,
-        embedding: output,
-        totalTokenCount: response[0]!.usage.total_tokens,
+        embedding: value,
+        totalTokenCount: (response as OpenAITextEmbeddingResponse).usage
+          .total_tokens,
       };
     } catch (error: any) {
       console.log(error);
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 755eee8..5681a2d 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -74,8 +74,8 @@ importers:
         specifier: 4.2.12
         version: 4.2.12
       modelfusion:
-        specifier: 0.40.1
-        version: 0.40.1
+        specifier: 0.41.1
+        version: 0.41.1
       secure-json-parse:
         specifier: 2.7.0
         version: 2.7.0
@@ -3009,8 +3009,8 @@ packages:
       ufo: 1.0.1
     dev: true
 
-  /modelfusion@0.40.1:
-    resolution: {integrity: sha512-eSyuwkIELr41gkJupZXl5mkEnHh6LdA6P6XRoCy/MynAvgxJo21CnF9P12D530wPKMEU3QGmir1Lbuk5afnXtA==}
+  /modelfusion@0.41.1:
+    resolution: {integrity: sha512-vRiXBU29+vOu1zLDKhGEcZXc8iX9YNJkNqmm/w9kftwMr3LqWg0Rxh2QlA3XzH//JKqxiXEi2z1ThTPEGuii4g==}
     engines: {node: '>=18'}
     dependencies:
       deep-equal: 2.2.2