Upgrade ModelFusion.
lgrammel committed Nov 28, 2023
1 parent d69d2ab commit ecbe183
Showing 3 changed files with 117 additions and 135 deletions.
2 changes: 1 addition & 1 deletion lib/extension/package.json
@@ -15,7 +15,7 @@
     "@rubberduck/common": "*",
     "handlebars": "^4.7.7",
     "marked": "4.2.12",
-    "modelfusion": "0.41.1",
+    "modelfusion": "0.87.2",
     "secure-json-parse": "2.7.0",
     "simple-git": "3.16.1",
     "zod": "3.21.4"
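Context for the version jump: between 0.41 and 0.87, ModelFusion replaced its class-based model constructors (e.g. OpenAIChatModel) with provider facades (openai, llamacpp), which is what drives the AIClient.ts changes below. A minimal before/after sketch; the "gpt-4" model id is illustrative, not taken from this commit:

import { openai } from "modelfusion";

// 0.41.x style (removed): construct model classes directly.
// const model = new OpenAIChatModel({ model: "gpt-4" });

// 0.87.x style: factory functions on the provider facade.
const model = openai.ChatTextGenerator({ model: "gpt-4" });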
52 changes: 28 additions & 24 deletions lib/extension/src/ai/AIClient.ts
@@ -1,14 +1,12 @@
 import {
-  InstructionPrompt,
-  LlamaCppTextGenerationModel,
+  Llama2PromptFormat,
   OpenAIApiConfiguration,
-  OpenAIChatModel,
-  OpenAITextEmbeddingModel,
   OpenAITextEmbeddingResponse,
+  TextInstructionPrompt,
   TextStreamingModel,
   embed,
-  mapInstructionPromptToLlama2Format,
-  mapInstructionPromptToOpenAIChatFormat,
+  llamacpp,
+  openai,
   streamText,
 } from "modelfusion";
 import * as vscode from "vscode";
@@ -77,24 +75,29 @@ export class AIClient {
     maxTokens: number;
     stop?: string[] | undefined;
     temperature?: number | undefined;
-  }): Promise<TextStreamingModel<InstructionPrompt>> {
+  }): Promise<TextStreamingModel<TextInstructionPrompt>> {
     const modelConfiguration = getModel();
 
     return modelConfiguration === "llama.cpp"
-      ? new LlamaCppTextGenerationModel({
-          maxCompletionTokens: maxTokens,
-          stopSequences: stop,
-          temperature,
-        }).withPromptFormat(mapInstructionPromptToLlama2Format())
-      : new OpenAIChatModel({
-          api: await this.getOpenAIApiConfiguration(),
-          model: modelConfiguration,
-          maxCompletionTokens: maxTokens,
-          stopSequences: stop,
-          temperature,
-          frequencyPenalty: 0,
-          presencePenalty: 0,
-        }).withPromptFormat(mapInstructionPromptToOpenAIChatFormat());
+      ? llamacpp
+          .TextGenerator({
+            maxCompletionTokens: maxTokens,
+            stopSequences: stop,
+            temperature,
+          })
+          // TODO the prompt format needs to be configurable for non-Llama2 models
+          .withTextPromptFormat(Llama2PromptFormat.instruction())
+      : openai
+          .ChatTextGenerator({
+            api: await this.getOpenAIApiConfiguration(),
+            model: modelConfiguration,
+            maxCompletionTokens: maxTokens,
+            stopSequences: stop,
+            temperature,
+            frequencyPenalty: 0,
+            presencePenalty: 0,
+          })
+          .withInstructionPrompt();
   }
 
   async streamText({
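The model built above is consumed by streamText elsewhere in the class. A rough usage sketch, assuming the 0.87-era API where awaiting streamText yields an async iterable of text deltas and TextInstructionPrompt has the shape { system?, instruction }; the model id and prompt text are illustrative:

import { openai, streamText } from "modelfusion";

const model = openai
  .ChatTextGenerator({ model: "gpt-4", maxCompletionTokens: 1024 })
  .withInstructionPrompt();

// streamText pairs the model with a TextInstructionPrompt and
// streams the completion back as incremental text fragments.
const textStream = await streamText(model, {
  instruction: "Explain what this diff changes.",
});

for await (const textPart of textStream) {
  process.stdout.write(textPart);
}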
@@ -119,12 +122,13 @@
   async generateEmbedding({ input }: { input: string }) {
     try {
       const { value, response } = await embed(
-        new OpenAITextEmbeddingModel({
+        openai.TextEmbedder({
           api: await this.getOpenAIApiConfiguration(),
           model: "text-embedding-ada-002",
         }),
-        input
-      ).asFullResponse();
+        input,
+        { returnType: "full" }
+      );
 
       return {
         type: "success" as const,
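The embedding change swaps the chained .asFullResponse() call for a { returnType: "full" } option on embed. A sketch of both call styles, assuming 0.87 defaults; the API configuration is omitted here and the input string is illustrative:

import { embed, openai } from "modelfusion";

const embedder = openai.TextEmbedder({ model: "text-embedding-ada-002" });

// Default: resolves to just the embedding vector.
const vector = await embed(embedder, "some text to embed");

// With { returnType: "full" }: also returns the raw provider response,
// matching the { value, response } destructuring in generateEmbedding above.
const { value, response } = await embed(embedder, "some text to embed", {
  returnType: "full",
});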
