Skip to content

Commit

Permalink
[ts] 6/n tests for PaLM Text Model Parser (#377)
Browse files Browse the repository at this point in the history
[ts] 6/n tests for PaLM Text Model Parser




## What

tests for PaLM text model parser

Something to note (not yet resolved): the PaLM APIs for Node.js and Python
appear to differ.

So what should the config JSON contain? The best option is likely to
standardize it — see PR #400 (RFC).

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/377).
* __->__ #377
* #396
* #397
* #394
* #376
* #375
* #395
  • Loading branch information
Ankush-lastmile authored Dec 11, 2023
2 parents fcdb5d9 + 8902bef commit a132969
Show file tree
Hide file tree
Showing 4 changed files with 374 additions and 2 deletions.
48 changes: 48 additions & 0 deletions typescript/__tests__/parsers/palm-text/palm-text.aiconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
{
"name": "Exploring NYC through PaLM Config",
"description": "",
"schema_version": "latest",
"metadata": {
"models": {
"models/text-bison-001": {
"model": "models/text-bison-001",
"topP": 0.9,
"temperature": 0.9
}
},
"default_model": "models/text-bison-001",
"model_parsers": {
"models/text-bison-001": "models/text-bison-001"
}
},
"prompts": [
{
"name": "prompt1",
"input": "What is your favorite condiment?"
},
{
"name": "prompt2",
"input": "What are 5 interesting things to do in Rome?",
"metadata": {
"model": {
"name": "models/text-bison-001",
"settings": {
"top_p": 0.7
}
}
}
},
{
"name": "promptWithParams",
"input": "What's your favorite building in {{city}}?",
"metadata": {
"model": {
"name": "models/text-bison-001"
},
"parameters": {
"city": "London"
}
}
}
]
}
118 changes: 118 additions & 0 deletions typescript/__tests__/parsers/palm-text/palm.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import { AIConfigRuntime } from "../../../lib/config";
import path from "path";
import { ExecuteResult, Output, Prompt } from "../../../types";
import { TextServiceClient, protos } from "@google-ai/generativelanguage";
import { PaLMTextParser } from "../../../lib/parsers/palm";
import { JSONObject } from "../../../common";
import { getAPIKeyFromEnv } from "../../../lib/utils";

const PALM_CONFIG_PATH = path.join(__dirname, "palm-text.aiconfig.json");

const mockGetApiKeyFromEnv = getAPIKeyFromEnv as jest.MockedFunction<typeof getAPIKeyFromEnv>;

// This could probably be abstracted out into a test util
jest.mock("../../../lib/utils", () => {
const originalModule = jest.requireActual("../../../lib/utils");
return {
...originalModule,
getAPIKeyFromEnv: jest.fn(),
};
});

mockGetApiKeyFromEnv.mockReturnValue("test-api-key");

describe("PaLM Text ModelParser", () => {
test("serializing params to config prompt", async () => {
// no need to instantiate model parser. Load will instantiate it for us since its a default parser
const aiConfig = AIConfigRuntime.load(PALM_CONFIG_PATH);

const completionParams: protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest = {
model: "models/text-bison-001",
// Note: top_p matches global config settings for the model and temperature is different
topP: 0.9,
temperature: 0.8,
prompt: { text: "What are 5 interesting things to do in Toronto?" },
};

// Casting as JSONObject since the type of completionParams is protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest doesn't confrom to shape even though it looks like it does
const prompts = (await aiConfig.serialize("models/text-bison-001", completionParams as JSONObject, "interestingThingsToronto")) as Prompt[];

expect(prompts).toHaveLength(1);
const prompt = prompts[0];

expect(prompt.name).toEqual("interestingThingsToronto");
expect(prompt.input).toEqual("What are 5 interesting things to do in Toronto?");
expect(prompt.metadata?.model).toEqual({
name: "models/text-bison-001",
settings: {
temperature: 0.8,
},
});
});

test("deserializing params to config", async () => {
const aiconfig = AIConfigRuntime.load(PALM_CONFIG_PATH);

const deserialized = await aiconfig.resolve("prompt1");
expect(deserialized).toEqual({
model: "models/text-bison-001",
temperature: 0.9,
topP: 0.9,
prompt: { text: "What is your favorite condiment?" },
topK: null,
candidateCount: null,
maxOutputTokens: null,
safetySettings: null,
stopSequences: null,
});
});

test("run prompt, non-streaming", async () => {
// When Jest Mocking Palm Text Generation, Typing requires a never type for the return value of generateText. Not sure why this is happening
// TODO: @ankush-lastmile Figure out why this is happening
jest.spyOn(TextServiceClient.prototype, "generateText").mockResolvedValue([
{
candidates: [
{
safetyRatings: [
{
category: "HARM_CATEGORY_DEROGATORY",
probability: "NEGLIGIBLE",
},
{
category: "HARM_CATEGORY_TOXICITY",
probability: "NEGLIGIBLE",
},
{
category: "HARM_CATEGORY_VIOLENCE",
probability: "NEGLIGIBLE",
},
{
category: "HARM_CATEGORY_SEXUAL",
probability: "NEGLIGIBLE",
},
{
category: "HARM_CATEGORY_MEDICAL",
probability: "NEGLIGIBLE",
},
{
category: "HARM_CATEGORY_DANGEROUS",
probability: "NEGLIGIBLE",
},
],
output: "Ranch",
},
],
filters: [],
safetyFeedback: [],
},
null,
null,
] as never);

const aiconfig = AIConfigRuntime.load(PALM_CONFIG_PATH);

const [result] = (await aiconfig.run("prompt1")) as Output[];
expect((result as ExecuteResult).data).toEqual("Ranch");
});
});
5 changes: 3 additions & 2 deletions typescript/lib/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import _ from "lodash";
import { getAPIKeyFromEnv } from "./utils";
import { ParameterizedModelParser } from "./parameterizedModelParser";
import { OpenAIChatModelParser, OpenAIModelParser } from "./parsers/openai";
import { PaLMTextParser } from "./parsers/palm";
import { extractOverrideSettings } from "./utils";
import { HuggingFaceTextGenerationParser } from "./parsers/hf";
import { CallbackEvent, CallbackManager } from "./callback";
Expand Down Expand Up @@ -57,8 +58,8 @@ ModelParserRegistry.registerModelParser(new OpenAIChatModelParser(), [
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
]);

ModelParserRegistry.registerModelParser(new HuggingFaceTextGenerationParser());
ModelParserRegistry.registerModelParser(new PaLMTextParser());

/**
* Represents an AIConfig. This is the main class for interacting with AIConfig files.
Expand Down Expand Up @@ -276,7 +277,7 @@ export class AIConfigRuntime implements AIConfig {
* @param id The ID of the model parser to get.
*/
public static getModelParser(id: string) {
ModelParserRegistry.getModelParser(id);
return ModelParserRegistry.getModelParser(id);
}

//#endregion
Expand Down
Loading

0 comments on commit a132969

Please sign in to comment.