diff --git a/cookbooks/Function-Calling-OpenAI/typescript/function-calling-with-openai.ts b/cookbooks/Function-Calling-OpenAI/typescript/function-calling-with-openai.ts index b7f3a5d7b..b2943d413 100644 --- a/cookbooks/Function-Calling-OpenAI/typescript/function-calling-with-openai.ts +++ b/cookbooks/Function-Calling-OpenAI/typescript/function-calling-with-openai.ts @@ -1,4 +1,4 @@ -import { AIConfigRuntime, ExecuteResult } from "aiconfig"; +import { AIConfigRuntime, ExecuteResult, Prompt } from "aiconfig"; import { Chat } from "openai/resources"; const BOOKS_DB = [ @@ -131,11 +131,15 @@ async function main() { }; // Add the prompt - let newPrompts = await config.serialize(model, data, "recommend_book", { - book: "To Kill a Mockingbird", - }); - let newPrompt = Array.isArray(newPrompts) ? newPrompts[0] : newPrompts; - config.addPrompt(newPrompt); + let newPrompts: Prompt[] = await config.serialize( + model, + data, + "recommend_book", + { + book: "To Kill a Mockingbird", + } + ); + config.addPrompt(newPrompts[0]); const params = { book: "Where the Crawdads Sing" }; const inferenceOptions = { @@ -194,8 +198,7 @@ async function main() { }; newPrompts = await config.serialize(model, promptData, "gen_summary"); - newPrompt = Array.isArray(newPrompts) ? newPrompts[0] : newPrompts; - config.addPrompt(newPrompt); + config.addPrompt(newPrompts[0]); await config.run("gen_summary", { book_info: value }, inferenceOptions); } diff --git a/cookbooks/HuggingFace/typescript/hf.ts b/cookbooks/HuggingFace/typescript/hf.ts index 697d89836..53fc96b02 100644 --- a/cookbooks/HuggingFace/typescript/hf.ts +++ b/cookbooks/HuggingFace/typescript/hf.ts @@ -45,7 +45,7 @@ export class HuggingFaceTextGenerationModelParser extends aiconfig.Parameterized data: TextGenerationArgs, aiConfig: AIConfigRuntime, params?: JSONObject | undefined - ): Prompt | Prompt[] { + ): Prompt[] { const input: PromptInput = data.inputs; let modelMetadata: ModelMetadata | string; diff --git a/extensions/HuggingFace/typescript/hf.ts b/extensions/HuggingFace/typescript/hf.ts index 89bb283ef..cf25f8674 100644 --- a/extensions/HuggingFace/typescript/hf.ts +++ b/extensions/HuggingFace/typescript/hf.ts @@ -47,7 +47,7 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized data: TextGenerationArgs, aiConfig: AIConfigRuntime, params?: JSONObject | undefined - ): Prompt | Prompt[] { + ): Prompt[] { const startEvent = { name: "on_serialize_start", file: __filename, diff --git a/extensions/llama/typescript/llama.ts b/extensions/llama/typescript/llama.ts index 3df085e95..cec9f3b5e 100644 --- a/extensions/llama/typescript/llama.ts +++ b/extensions/llama/typescript/llama.ts @@ -105,7 +105,7 @@ export class LlamaModelParser extends ParameterizedModelParser 0) { let i = 0; @@ -153,7 +153,7 @@ export class LlamaModelParser extends ParameterizedModelParser 0, }, }; - - let result: Prompt | Prompt[] = prompt; - if (prompts.length > 0) { - prompts.push(prompt); - result = prompts; - } + prompts.push(newPrompt); const endEvent = { name: "on_serialize_end", file: __filename, - data: { - result, - }, + data: { prompts }, }; aiConfig.callbackManager.runCallbacks(endEvent); - return result; + return prompts; } public refineCompletionParams( diff --git a/typescript/__tests__/config.test.ts b/typescript/__tests__/config.test.ts index 07950822e..1a6c57c1c 100644 --- a/typescript/__tests__/config.test.ts +++ b/typescript/__tests__/config.test.ts @@ -123,7 +123,7 @@ describe("Loading an AIConfig", () => { ], }); - const serializeResult = await aiConfig.serialize( + const prompts: Prompt[] = await aiConfig.serialize( "gpt-3.5-turbo", completionParams, "prompt", @@ -132,10 +132,6 @@ describe("Loading an AIConfig", () => { } ); - expect(Array.isArray(serializeResult)).toBe(true); - - const prompts: Prompt[] = serializeResult as Prompt[]; - expect(prompts.length).toBe(2); const prompt1 = prompts[0]; @@ -177,7 +173,7 @@ describe("Loading an AIConfig", () => { ], }; - const serializeResult = await aiConfig.serialize( + const prompts: Prompt[] = await aiConfig.serialize( "gpt-3.5-turbo", completionParams, "prompt", @@ -186,10 +182,6 @@ describe("Loading an AIConfig", () => { } ); - expect(Array.isArray(serializeResult)).toBe(true); - - const prompts: Prompt[] = serializeResult as Prompt[]; - expect(prompts.length).toBe(2); const prompt1 = prompts[0]; diff --git a/typescript/__tests__/parsers/hf/hf.test.ts b/typescript/__tests__/parsers/hf/hf.test.ts index ed3dd1c51..26ecda8d6 100644 --- a/typescript/__tests__/parsers/hf/hf.test.ts +++ b/typescript/__tests__/parsers/hf/hf.test.ts @@ -126,11 +126,11 @@ describe("HuggingFaceTextGeneration ModelParser", () => { inputs: "What are 5 interesting things to do in Toronto?", }; - const prompts = parser.serialize( + const prompts: Prompt[] = parser.serialize( "interestingThingsToronto", completionParams, aiConfig - ) as Prompt[]; + ); expect(prompts).toHaveLength(1); const prompt = prompts[0]; @@ -165,11 +165,11 @@ describe("HuggingFaceTextGeneration ModelParser", () => { const callbackManager = new CallbackManager([callback]); aiConfig.setCallbackManager(callbackManager); - const prompts = parser.serialize( + const prompts: Prompt[] = parser.serialize( "interestingThingsToronto", completionParams, aiConfig - ) as Prompt[]; + ); const onStartEvent = callback.mock.calls[0][0]; expect(onStartEvent.name).toEqual("on_serialize_start"); diff --git a/typescript/__tests__/parsers/palm-text/palm.test.ts b/typescript/__tests__/parsers/palm-text/palm.test.ts index cdbffe309..89ca8003a 100644 --- a/typescript/__tests__/parsers/palm-text/palm.test.ts +++ b/typescript/__tests__/parsers/palm-text/palm.test.ts @@ -38,11 +38,11 @@ describe("PaLM Text ModelParser", () => { }; // Casting as JSONObject since the type of completionParams is protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest doesn't confrom to shape even though it looks like it does - const prompts = (await aiConfig.serialize( + const prompts: Prompt[] = await aiConfig.serialize( "models/text-bison-001", completionParams as JSONObject, "interestingThingsToronto" - )) as Prompt[]; + ); expect(prompts).toHaveLength(1); const prompt = prompts[0]; diff --git a/typescript/demo/demo.ts b/typescript/demo/demo.ts index 0417431d1..994d51c06 100644 --- a/typescript/demo/demo.ts +++ b/typescript/demo/demo.ts @@ -4,6 +4,7 @@ import OpenAI from "openai"; import * as path from "path"; import { AIConfigRuntime } from "../lib/config"; import { InferenceOptions } from "../lib/modelParser"; +import { Prompt } from "../types"; // This example is taken from https://github.com/openai/openai-node/blob/v4/examples/demo.ts // and modified to show the same functionality using AIConfig. @@ -104,14 +105,10 @@ async function createAIConfig() { }; const aiConfig = AIConfigRuntime.create("demo", "this is a demo AIConfig"); - const result = await aiConfig.serialize(model, data, "demoPrompt"); + const prompts: Prompt[] = await aiConfig.serialize(model, data, "demoPrompt"); - if (Array.isArray(result)) { - for (const prompt of result) { - aiConfig.addPrompt(prompt); - } - } else { - aiConfig.addPrompt(result); + for (const prompt of prompts) { + aiConfig.addPrompt(prompt); } aiConfig.save("demo/demo.aiconfig.json", { serializeOutputs: true }); diff --git a/typescript/demo/function-call-stream.ts b/typescript/demo/function-call-stream.ts index fbd427a4d..b4ae19033 100644 --- a/typescript/demo/function-call-stream.ts +++ b/typescript/demo/function-call-stream.ts @@ -289,14 +289,14 @@ async function createAIConfig() { "function-call-demo", "this is a demo AIConfig to show function calling using OpenAI" ); - const result = await aiConfig.serialize(model, data, "functionCallResult"); + const prompts: Prompt[] = await aiConfig.serialize( + model, + data, + "functionCallResult" + ); - if (Array.isArray(result)) { - for (const prompt of result) { - aiConfig.addPrompt(prompt); - } - } else { - aiConfig.addPrompt(result); + for (const prompt of prompts) { + aiConfig.addPrompt(prompt); } aiConfig.save("demo/function-call.aiconfig.json", { diff --git a/typescript/demo/test-hf.ts b/typescript/demo/test-hf.ts index 466646d5a..d29ef1d1a 100644 --- a/typescript/demo/test-hf.ts +++ b/typescript/demo/test-hf.ts @@ -23,11 +23,11 @@ async function run() { console.log("Latest output: ", config.getOutputText("prompt1")); console.log("serialize prompt2: "); - const prompts = (await config.serialize( + const prompts: Prompt[] = await config.serialize( "mistralai/Mistral-7B-v0.1", { inputs: "Hello, world!" }, "prompt2" - )) as Prompt[]; + ); const prompt2 = prompts[0]; diff --git a/typescript/lib/config.ts b/typescript/lib/config.ts index 4baab769a..1ce06ebfa 100644 --- a/typescript/lib/config.ts +++ b/typescript/lib/config.ts @@ -360,7 +360,7 @@ export class AIConfigRuntime implements AIConfig { data: JSONObject, promptName: string, params?: JSONObject - ): Promise { + ): Promise { const startEvent = { name: "on_serialize_start", file: __filename, @@ -376,7 +376,12 @@ export class AIConfigRuntime implements AIConfig { ); } - const prompts = modelParser.serialize(promptName, data, this, params); + const prompts: Prompt[] = modelParser.serialize( + promptName, + data, + this, + params + ); const endEvent = { name: "on_serialize_end", file: __filename, diff --git a/typescript/lib/modelParser.ts b/typescript/lib/modelParser.ts index 25e320319..92943c371 100644 --- a/typescript/lib/modelParser.ts +++ b/typescript/lib/modelParser.ts @@ -62,7 +62,7 @@ export abstract class ModelParser { data: T, aiConfig: AIConfigRuntime, params?: JSONObject - ): Prompt | Prompt[]; + ): Prompt[]; /** * Deserialize a Prompt object loaded from an AIConfig into a structure that can be used for model inference. diff --git a/typescript/lib/parsers/hf.ts b/typescript/lib/parsers/hf.ts index 8bebae787..cd8712967 100644 --- a/typescript/lib/parsers/hf.ts +++ b/typescript/lib/parsers/hf.ts @@ -39,7 +39,7 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser