Extract getTextStreamingModel.
lgrammel committed Oct 5, 2023
1 parent e7e63f3 commit 3629c7c
Showing 3 changed files with 41 additions and 30 deletions.
2 changes: 1 addition & 1 deletion lib/extension/package.json
@@ -15,7 +15,7 @@
     "@rubberduck/common": "*",
     "handlebars": "^4.7.7",
     "marked": "4.2.12",
-    "modelfusion": "0.40.1",
+    "modelfusion": "0.41.1",
     "secure-json-parse": "2.7.0",
     "simple-git": "3.16.1",
     "zod": "3.21.4"
61 changes: 36 additions & 25 deletions lib/extension/src/ai/AIClient.ts
@@ -1,8 +1,11 @@
 import {
+  InstructionPrompt,
   LlamaCppTextGenerationModel,
   OpenAIApiConfiguration,
   OpenAIChatModel,
   OpenAITextEmbeddingModel,
+  OpenAITextEmbeddingResponse,
+  TextStreamingModel,
   embed,
   mapInstructionPromptToLlama2Format,
   mapInstructionPromptToOpenAIChatFormat,
@@ -66,49 +69,56 @@ export class AIClient {
     });
   }
 
-  async streamText({
-    prompt,
+  async getTextStreamingModel({
     maxTokens,
     stop,
     temperature = 0,
   }: {
-    prompt: string;
     maxTokens: number;
     stop?: string[] | undefined;
     temperature?: number | undefined;
-  }) {
-    this.logger.log(["--- Start prompt ---", prompt, "--- End prompt ---"]);
-
+  }): Promise<TextStreamingModel<InstructionPrompt>> {
     const modelConfiguration = getModel();
 
-    if (modelConfiguration === "llama.cpp") {
-      return streamText(
-        new LlamaCppTextGenerationModel({
+    return modelConfiguration === "llama.cpp"
+      ? new LlamaCppTextGenerationModel({
           maxCompletionTokens: maxTokens,
           stopSequences: stop,
           temperature,
-        }).withPromptFormat(mapInstructionPromptToLlama2Format()),
-        { instruction: prompt }
-      );
-    }
+        }).withPromptFormat(mapInstructionPromptToLlama2Format())
+      : new OpenAIChatModel({
+          api: await this.getOpenAIApiConfiguration(),
+          model: modelConfiguration,
+          maxCompletionTokens: maxTokens,
+          stopSequences: stop,
+          temperature,
+          frequencyPenalty: 0,
+          presencePenalty: 0,
+        }).withPromptFormat(mapInstructionPromptToOpenAIChatFormat());
+  }
 
+  async streamText({
+    prompt,
+    maxTokens,
+    stop,
+    temperature = 0,
+  }: {
+    prompt: string;
+    maxTokens: number;
+    stop?: string[] | undefined;
+    temperature?: number | undefined;
+  }) {
+    this.logger.log(["--- Start prompt ---", prompt, "--- End prompt ---"]);
+
     return streamText(
-      new OpenAIChatModel({
-        api: await this.getOpenAIApiConfiguration(),
-        model: modelConfiguration,
-        maxCompletionTokens: maxTokens,
-        stopSequences: stop,
-        temperature,
-        frequencyPenalty: 0,
-        presencePenalty: 0,
-      }).withPromptFormat(mapInstructionPromptToOpenAIChatFormat()),
+      await this.getTextStreamingModel({ maxTokens, stop, temperature }),
       { instruction: prompt }
     );
   }
 
   async generateEmbedding({ input }: { input: string }) {
     try {
-      const { output, response } = await embed(
+      const { value, response } = await embed(
         new OpenAITextEmbeddingModel({
           api: await this.getOpenAIApiConfiguration(),
           model: "text-embedding-ada-002",
@@ -118,8 +128,9 @@ export class AIClient {
 
       return {
         type: "success" as const,
-        embedding: output,
-        totalTokenCount: response[0]!.usage.total_tokens,
+        embedding: value,
+        totalTokenCount: (response as OpenAITextEmbeddingResponse).usage
+          .total_tokens,
       };
     } catch (error: any) {
       console.log(error);
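The refactor keeps the public streamText API unchanged while moving model selection (llama.cpp vs. OpenAI chat) into getTextStreamingModel, so existing callers are unaffected. Below is a minimal usage sketch, assuming an already-constructed AIClient instance and that the modelfusion streamText result can be consumed as an async iterable of text deltas; the import path, function name, prompt, and stop sequence are illustrative only and not part of this commit.

```ts
// Hypothetical caller sketch (not from this commit). The import path, the
// `aiClient` argument, and the iteration over string deltas are assumptions
// based on the surrounding code, not shown in this diff.
import { AIClient } from "./AIClient";

async function summarizeCode(aiClient: AIClient, code: string): Promise<string> {
  // The streamText signature is unchanged by the refactor; model selection
  // now happens inside getTextStreamingModel.
  const stream = await aiClient.streamText({
    prompt: `Summarize this code:\n\n${code}`,
    maxTokens: 1024,
    stop: ["<end>"],
    temperature: 0,
  });

  // Assumption: modelfusion's streamText yields text deltas as strings.
  let result = "";
  for await (const delta of stream) {
    result += delta;
  }
  return result;
}
```

Extracting the model construction into its own method also means other call sites could presumably reuse the same llama.cpp/OpenAI selection logic without going through streamText.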
8 changes: 4 additions & 4 deletions pnpm-lock.yaml

Some generated files (pnpm-lock.yaml) are not rendered by default.
