[ts] Change `serialize` output from `Prompt | Prompt[]` to `Prompt[]` (#658)

TSIA; this was already done for Python in #632, and this change does the same for TypeScript.
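
In practice this removes the `Array.isArray` branching that callers previously needed. A minimal sketch of the new calling convention (the model name and data values here are illustrative; the pattern mirrors the demo changes below):

```typescript
import { AIConfigRuntime, Prompt } from "aiconfig";

async function example() {
  const aiConfig = AIConfigRuntime.create("example", "example AIConfig");

  // Illustrative model/data values; any registered model parser behaves the same.
  const model = "gpt-3.5-turbo";
  const data = { model, prompt: "Recommend a book to read." };

  // serialize() now always resolves to Prompt[], even when only a single
  // prompt is produced, so the result can be indexed or iterated directly.
  const prompts: Prompt[] = await aiConfig.serialize(model, data, "examplePrompt");
  for (const prompt of prompts) {
    aiConfig.addPrompt(prompt);
  }
}
```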

# Test plan
Run the TypeScript demos to make sure they continue to work. Run these commands in a terminal from the `aiconfig` top-level dir:
```bash
export OPENAI_API_KEY=sk-XYZ # Set your API key; you can get one from https://platform.openai.com/api-keys
npx ts-node typescript/demo/function-call-stream.ts
npx ts-node typescript/demo/demo.ts
npx ts-node typescript/demo/test-hf.ts
```
After running the demos, make sure to delete any generated files so they don't end up in the GitHub pull request.

Now enter the `aiconfig/typescript` dir and run these commands to make sure the tests pass and the package compiles:
```bash
yarn test
yarn build
rm -rf dist/ # Run this every time after you run `yarn build`
```
rossdanlm authored Dec 29, 2023
2 parents b9b0ec2 + f1f1c2a commit c88b9c9
Showing 15 changed files with 52 additions and 62 deletions.
```diff
@@ -1,4 +1,4 @@
-import { AIConfigRuntime, ExecuteResult } from "aiconfig";
+import { AIConfigRuntime, ExecuteResult, Prompt } from "aiconfig";
 import { Chat } from "openai/resources";
 
 const BOOKS_DB = [
@@ -131,11 +131,15 @@ async function main() {
   };
 
   // Add the prompt
-  let newPrompts = await config.serialize(model, data, "recommend_book", {
-    book: "To Kill a Mockingbird",
-  });
-  let newPrompt = Array.isArray(newPrompts) ? newPrompts[0] : newPrompts;
-  config.addPrompt(newPrompt);
+  let newPrompts: Prompt[] = await config.serialize(
+    model,
+    data,
+    "recommend_book",
+    {
+      book: "To Kill a Mockingbird",
+    }
+  );
+  config.addPrompt(newPrompts[0]);
 
   const params = { book: "Where the Crawdads Sing" };
   const inferenceOptions = {
@@ -194,8 +198,7 @@ async function main() {
   };
 
   newPrompts = await config.serialize(model, promptData, "gen_summary");
-  newPrompt = Array.isArray(newPrompts) ? newPrompts[0] : newPrompts;
-  config.addPrompt(newPrompt);
+  config.addPrompt(newPrompts[0]);
 
   await config.run("gen_summary", { book_info: value }, inferenceOptions);
 }
```
2 changes: 1 addition & 1 deletion cookbooks/HuggingFace/typescript/hf.ts
```diff
@@ -45,7 +45,7 @@ export class HuggingFaceTextGenerationModelParser extends aiconfig.Parameterized
     data: TextGenerationArgs,
     aiConfig: AIConfigRuntime,
     params?: JSONObject | undefined
-  ): Prompt | Prompt[] {
+  ): Prompt[] {
     const input: PromptInput = data.inputs;
 
     let modelMetadata: ModelMetadata | string;
```
2 changes: 1 addition & 1 deletion extensions/HuggingFace/typescript/hf.ts
```diff
@@ -47,7 +47,7 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
     data: TextGenerationArgs,
     aiConfig: AIConfigRuntime,
     params?: JSONObject | undefined
-  ): Prompt | Prompt[] {
+  ): Prompt[] {
     const startEvent = {
       name: "on_serialize_start",
       file: __filename,
```
19 changes: 6 additions & 13 deletions extensions/llama/typescript/llama.ts
```diff
@@ -105,7 +105,7 @@ export class LlamaModelParser extends ParameterizedModelParser<LlamaCompletionPa
     data: LlamaCompletionParams,
     aiConfig: AIConfigRuntime,
     params?: JSONObject | undefined
-  ): Prompt | Prompt[] {
+  ): Prompt[] {
     const startEvent = {
       name: "on_serialize_start",
       file: __filename,
@@ -126,7 +126,7 @@ export class LlamaModelParser extends ParameterizedModelParser<LlamaCompletionPa
       completionParams.model
     );
 
-    let prompts: Prompt[] = [];
+    const prompts: Prompt[] = [];
 
     if (conversationHistory && conversationHistory.length > 0) {
       let i = 0;
@@ -153,7 +153,7 @@ export class LlamaModelParser extends ParameterizedModelParser<LlamaCompletionPa
       }
     }
 
-    const prompt: Prompt = {
+    const newPrompt: Prompt = {
       name: promptName,
       input,
       metadata: {
@@ -162,23 +162,16 @@ export class LlamaModelParser extends ParameterizedModelParser<LlamaCompletionPa
         remember_chat_context: (conversationHistory?.length ?? 0) > 0,
       },
     };
-
-    let result: Prompt | Prompt[] = prompt;
-    if (prompts.length > 0) {
-      prompts.push(prompt);
-      result = prompts;
-    }
+    prompts.push(newPrompt);
 
     const endEvent = {
       name: "on_serialize_end",
       file: __filename,
-      data: {
-        result,
-      },
+      data: { prompts },
     };
     aiConfig.callbackManager.runCallbacks(endEvent);
 
-    return result;
+    return prompts;
   }
 
   public refineCompletionParams(
```
12 changes: 2 additions & 10 deletions typescript/__tests__/config.test.ts
```diff
@@ -123,7 +123,7 @@ describe("Loading an AIConfig", () => {
       ],
     });
 
-    const serializeResult = await aiConfig.serialize(
+    const prompts: Prompt[] = await aiConfig.serialize(
       "gpt-3.5-turbo",
       completionParams,
       "prompt",
@@ -132,10 +132,6 @@
       }
     );
 
-    expect(Array.isArray(serializeResult)).toBe(true);
-
-    const prompts: Prompt[] = serializeResult as Prompt[];
-
     expect(prompts.length).toBe(2);
 
     const prompt1 = prompts[0];
@@ -177,7 +173,7 @@ describe("Loading an AIConfig", () => {
       ],
     };
 
-    const serializeResult = await aiConfig.serialize(
+    const prompts: Prompt[] = await aiConfig.serialize(
       "gpt-3.5-turbo",
       completionParams,
       "prompt",
@@ -186,10 +182,6 @@
       }
     );
 
-    expect(Array.isArray(serializeResult)).toBe(true);
-
-    const prompts: Prompt[] = serializeResult as Prompt[];
-
     expect(prompts.length).toBe(2);
 
     const prompt1 = prompts[0];
```
8 changes: 4 additions & 4 deletions typescript/__tests__/parsers/hf/hf.test.ts
```diff
@@ -126,11 +126,11 @@ describe("HuggingFaceTextGeneration ModelParser", () => {
       inputs: "What are 5 interesting things to do in Toronto?",
     };
 
-    const prompts = parser.serialize(
+    const prompts: Prompt[] = parser.serialize(
       "interestingThingsToronto",
       completionParams,
       aiConfig
-    ) as Prompt[];
+    );
 
     expect(prompts).toHaveLength(1);
     const prompt = prompts[0];
@@ -165,11 +165,11 @@ describe("HuggingFaceTextGeneration ModelParser", () => {
     const callbackManager = new CallbackManager([callback]);
     aiConfig.setCallbackManager(callbackManager);
 
-    const prompts = parser.serialize(
+    const prompts: Prompt[] = parser.serialize(
       "interestingThingsToronto",
       completionParams,
       aiConfig
-    ) as Prompt[];
+    );
 
     const onStartEvent = callback.mock.calls[0][0];
     expect(onStartEvent.name).toEqual("on_serialize_start");
```
4 changes: 2 additions & 2 deletions typescript/__tests__/parsers/palm-text/palm.test.ts
```diff
@@ -38,11 +38,11 @@ describe("PaLM Text ModelParser", () => {
     };
 
     // Casting as JSONObject since the type of completionParams is protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest doesn't conform to the shape even though it looks like it does
-    const prompts = (await aiConfig.serialize(
+    const prompts: Prompt[] = await aiConfig.serialize(
       "models/text-bison-001",
       completionParams as JSONObject,
       "interestingThingsToronto"
-    )) as Prompt[];
+    );
 
     expect(prompts).toHaveLength(1);
     const prompt = prompts[0];
```
11 changes: 4 additions & 7 deletions typescript/demo/demo.ts
```diff
@@ -4,6 +4,7 @@ import OpenAI from "openai";
 import * as path from "path";
 import { AIConfigRuntime } from "../lib/config";
 import { InferenceOptions } from "../lib/modelParser";
+import { Prompt } from "../types";
 
 // This example is taken from https://github.com/openai/openai-node/blob/v4/examples/demo.ts
 // and modified to show the same functionality using AIConfig.
@@ -104,14 +105,10 @@ async function createAIConfig() {
   };
 
   const aiConfig = AIConfigRuntime.create("demo", "this is a demo AIConfig");
-  const result = await aiConfig.serialize(model, data, "demoPrompt");
+  const prompts: Prompt[] = await aiConfig.serialize(model, data, "demoPrompt");
 
-  if (Array.isArray(result)) {
-    for (const prompt of result) {
-      aiConfig.addPrompt(prompt);
-    }
-  } else {
-    aiConfig.addPrompt(result);
+  for (const prompt of prompts) {
+    aiConfig.addPrompt(prompt);
   }
 
   aiConfig.save("demo/demo.aiconfig.json", { serializeOutputs: true });
```
Expand Down
14 changes: 7 additions & 7 deletions typescript/demo/function-call-stream.ts
```diff
@@ -289,14 +289,14 @@ async function createAIConfig() {
     "function-call-demo",
     "this is a demo AIConfig to show function calling using OpenAI"
   );
-  const result = await aiConfig.serialize(model, data, "functionCallResult");
+  const prompts: Prompt[] = await aiConfig.serialize(
+    model,
+    data,
+    "functionCallResult"
+  );
 
-  if (Array.isArray(result)) {
-    for (const prompt of result) {
-      aiConfig.addPrompt(prompt);
-    }
-  } else {
-    aiConfig.addPrompt(result);
+  for (const prompt of prompts) {
+    aiConfig.addPrompt(prompt);
   }
 
   aiConfig.save("demo/function-call.aiconfig.json", {
```
4 changes: 2 additions & 2 deletions typescript/demo/test-hf.ts
```diff
@@ -23,11 +23,11 @@ async function run() {
   console.log("Latest output: ", config.getOutputText("prompt1"));
 
   console.log("serialize prompt2: ");
-  const prompts = (await config.serialize(
+  const prompts: Prompt[] = await config.serialize(
     "mistralai/Mistral-7B-v0.1",
     { inputs: "Hello, world!" },
     "prompt2"
-  )) as Prompt[];
+  );
 
   const prompt2 = prompts[0];
 
```
9 changes: 7 additions & 2 deletions typescript/lib/config.ts
```diff
@@ -360,7 +360,7 @@ export class AIConfigRuntime implements AIConfig {
     data: JSONObject,
     promptName: string,
     params?: JSONObject
-  ): Promise<Prompt | Prompt[]> {
+  ): Promise<Prompt[]> {
     const startEvent = {
       name: "on_serialize_start",
       file: __filename,
@@ -376,7 +376,12 @@ export class AIConfigRuntime implements AIConfig {
       );
     }
 
-    const prompts = modelParser.serialize(promptName, data, this, params);
+    const prompts: Prompt[] = modelParser.serialize(
+      promptName,
+      data,
+      this,
+      params
+    );
     const endEvent = {
       name: "on_serialize_end",
       file: __filename,
```
2 changes: 1 addition & 1 deletion typescript/lib/modelParser.ts
```diff
@@ -62,7 +62,7 @@ export abstract class ModelParser<T = JSONObject, R = T> {
     data: T,
     aiConfig: AIConfigRuntime,
     params?: JSONObject
-  ): Prompt | Prompt[];
+  ): Prompt[];
 
   /**
    * Deserialize a Prompt object loaded from an AIConfig into a structure that can be used for model inference.
```
2 changes: 1 addition & 1 deletion typescript/lib/parsers/hf.ts
```diff
@@ -39,7 +39,7 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
     data: TextGenerationArgs,
     aiConfig: AIConfigRuntime,
     params?: JSONObject | undefined
-  ): Prompt | Prompt[] {
+  ): Prompt[] {
     const startEvent = {
       name: "on_serialize_start",
       file: __filename,
```
4 changes: 2 additions & 2 deletions typescript/lib/parsers/openai.ts
```diff
@@ -40,7 +40,7 @@ export class OpenAIModelParser extends ParameterizedModelParser<CompletionCreate
     data: CompletionCreateParams,
     aiConfig: AIConfigRuntime,
     params?: JSONObject
-  ): Prompt {
+  ): Prompt[] {
     // Serialize prompt input
     let input: PromptInput;
     if (typeof data.prompt === "string") {
@@ -105,7 +105,7 @@ export class OpenAIModelParser extends ParameterizedModelParser<CompletionCreate
       },
     };
 
-    return prompt;
+    return [prompt];
   }
 
   public deserialize(
```
2 changes: 1 addition & 1 deletion typescript/lib/parsers/palm.ts
```diff
@@ -35,7 +35,7 @@ export class PaLMTextParser extends ParameterizedModelParser {
     data: JSONObject,
     aiConfig: AIConfigRuntime,
     params?: JSONObject | undefined
-  ): Prompt | Prompt[] {
+  ): Prompt[] {
     const startEvent = {
       name: "on_serialize_start",
       file: __filename,
```
