Skip to content

Commit

Permalink
[ts] 6/n tests for PaLM Text Model Parser
Browse files Browse the repository at this point in the history
## What

tests for PaLM text model parser

Something to note (not yet resolved): the PaLM APIs for Node.js and Python appear to differ.

So what should happen in the config JSON? The best approach here is probably to standardize it — see PR #400 (RFC).
  • Loading branch information
Ankush Pala [email protected] committed Dec 11, 2023
1 parent e6b81d1 commit 8902bef
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 615 deletions.
6 changes: 0 additions & 6 deletions package.json

This file was deleted.

48 changes: 48 additions & 0 deletions typescript/__tests__/parsers/palm-text/palm-text.aiconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
{
"name": "Exploring NYC through PaLM Config",
"description": "",
"schema_version": "latest",
"metadata": {
"models": {
"models/text-bison-001": {
"model": "models/text-bison-001",
"topP": 0.9,
"temperature": 0.9
}
},
"default_model": "models/text-bison-001",
"model_parsers": {
"models/text-bison-001": "models/text-bison-001"
}
},
"prompts": [
{
"name": "prompt1",
"input": "What is your favorite condiment?"
},
{
"name": "prompt2",
"input": "What are 5 interesting things to do in Rome?",
"metadata": {
"model": {
"name": "models/text-bison-001",
"settings": {
"top_p": 0.7
}
}
}
},
{
"name": "promptWithParams",
"input": "What's your favorite building in {{city}}?",
"metadata": {
"model": {
"name": "models/text-bison-001"
},
"parameters": {
"city": "London"
}
}
}
]
}
118 changes: 118 additions & 0 deletions typescript/__tests__/parsers/palm-text/palm.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import { AIConfigRuntime } from "../../../lib/config";
import path from "path";
import { ExecuteResult, Output, Prompt } from "../../../types";
import { TextServiceClient, protos } from "@google-ai/generativelanguage";
import { PaLMTextParser } from "../../../lib/parsers/palm";
import { JSONObject } from "../../../common";
import { getAPIKeyFromEnv } from "../../../lib/utils";

const PALM_CONFIG_PATH = path.join(__dirname, "palm-text.aiconfig.json");

// Stub out the API-key lookup so the suite runs without real credentials.
// (This could probably be abstracted out into a shared test util.)
// Note: jest.mock() is hoisted above the imports, so registering it here is safe.
jest.mock("../../../lib/utils", () => ({
  ...jest.requireActual("../../../lib/utils"),
  getAPIKeyFromEnv: jest.fn(),
}));

const mockGetApiKeyFromEnv = getAPIKeyFromEnv as jest.MockedFunction<
  typeof getAPIKeyFromEnv
>;
mockGetApiKeyFromEnv.mockReturnValue("test-api-key");

describe("PaLM Text ModelParser", () => {
  test("serializing params to config prompt", async () => {
    // No need to instantiate the model parser ourselves — load() registers
    // it automatically since PaLM text is a default parser.
    const config = AIConfigRuntime.load(PALM_CONFIG_PATH);

    const completionParams: protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest =
      {
        model: "models/text-bison-001",
        // Note: top_p matches global config settings for the model and temperature is different
        topP: 0.9,
        temperature: 0.8,
        prompt: { text: "What are 5 interesting things to do in Toronto?" },
      };

    // Cast to JSONObject: IGenerateTextRequest doesn't structurally conform
    // to serialize()'s expected shape even though it looks like it does.
    const serialized = (await config.serialize(
      "models/text-bison-001",
      completionParams as JSONObject,
      "interestingThingsToronto"
    )) as Prompt[];

    expect(serialized).toHaveLength(1);
    const [prompt] = serialized;

    expect(prompt.name).toEqual("interestingThingsToronto");
    expect(prompt.input).toEqual(
      "What are 5 interesting things to do in Toronto?"
    );
    // topP equals the global model setting, so only temperature — which
    // differs — should land in the prompt-level settings.
    expect(prompt.metadata?.model).toEqual({
      name: "models/text-bison-001",
      settings: {
        temperature: 0.8,
      },
    });
  });

  test("deserializing params to config", async () => {
    const config = AIConfigRuntime.load(PALM_CONFIG_PATH);

    // resolve() should merge global model settings into a full
    // GenerateTextRequest, with unset fields surfaced as null.
    const resolved = await config.resolve("prompt1");
    expect(resolved).toEqual({
      model: "models/text-bison-001",
      temperature: 0.9,
      topP: 0.9,
      prompt: { text: "What is your favorite condiment?" },
      topK: null,
      candidateCount: null,
      maxOutputTokens: null,
      safetySettings: null,
      stopSequences: null,
    });
  });

  test("run prompt, non-streaming", async () => {
    // Build the boilerplate safety-rating list once instead of spelling
    // out six identical objects.
    const safetyRatings = [
      "HARM_CATEGORY_DEROGATORY",
      "HARM_CATEGORY_TOXICITY",
      "HARM_CATEGORY_VIOLENCE",
      "HARM_CATEGORY_SEXUAL",
      "HARM_CATEGORY_MEDICAL",
      "HARM_CATEGORY_DANGEROUS",
    ].map((category) => ({ category, probability: "NEGLIGIBLE" }));

    // When Jest-mocking PaLM text generation, typing requires a `never`
    // cast for the resolved value of generateText. Not sure why.
    // TODO: @ankush-lastmile Figure out why this is happening
    jest
      .spyOn(TextServiceClient.prototype, "generateText")
      .mockResolvedValue([
        {
          candidates: [
            {
              safetyRatings,
              output: "Ranch",
            },
          ],
          filters: [],
          safetyFeedback: [],
        },
        null,
        null,
      ] as never);

    const config = AIConfigRuntime.load(PALM_CONFIG_PATH);

    const [result] = (await config.run("prompt1")) as Output[];
    expect((result as ExecuteResult).data).toEqual("Ranch");
  });
});
Loading

0 comments on commit 8902bef

Please sign in to comment.