Skip to content

Commit

Permalink
incremental improvement of Tool; better handling of prompt in gemini api
Browse files Browse the repository at this point in the history
  • Loading branch information
jnaglick committed Jun 20, 2024
1 parent e1425b8 commit 098fc40
Show file tree
Hide file tree
Showing 8 changed files with 506 additions and 183 deletions.
283 changes: 226 additions & 57 deletions packages/core/src/apis/google/__snapshots__/gemini.spec.ts.snap

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions packages/core/src/apis/google/errors.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
export const FUNCTION_CALL_WITHOUT_TOOLS =
"The last item of conversation history (`contents`) contains a `function_call`, but no `$tools` were passed, so generative-ts cannot append `function_response` to conversation history. Instead, appending prompt to end of conversation history. Model behavior might be unexpected, because model's function calls were effectively ignored.";

export const UNRESOLVED_INVOCATION =
"The last item of conversation history (`contents`) contains a `function_call`, and a $tool with matching invocations was found, but that invocation is NOT resolved, and thus does not have a `returned` value (Did invoking the tool fail?), so generative-ts cannot append `function_response` to conversation history. Model behavior might be unexpected, because model's function calls were effectively ignored.";

export const NO_MATCHING_INVOCATION =
"The last item of conversation history (`contents`) contains a `function_call`, and a matching `$tool` was found, but no matching invocation was found in the tool's invocations (Did you forget to invoke the tool?), so generative-ts cannot append `function_response` to conversation history. Model behavior might be unexpected, because model's function calls were effectively ignored.";

export const NO_MATCHING_TOOL =
"The last item of conversation history (`contents`) contains a `function_call`, but no matching tool was found in `$tools`, so generative-ts cannot append `function_response` to conversation history. Model behavior might be unexpected, because model's function calls were effectively ignored.";
99 changes: 82 additions & 17 deletions packages/core/src/apis/google/gemini.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,30 +58,43 @@ describe("GoogleGeminiApi.requestTemplate", () => {
* "Native" few shot options (prompt, contents, system_instruction):
*/

test("prompt, contents (appends prompt)", () => {
test("prompt, contents with user / model (appends prompt)", () => {
const rendered = render({
prompt: "mock-prompt",
contents: [
{
role: "user",
parts: [{ text: "mock-user-text" }],
},
{
role: "model",
parts: [{ text: "mock-model-text" }],
},
],
});
expect(rendered).toMatchSnapshot();
});

test("prompt, contents with model / user (prepends prompt)", () => {
const rendered = render({
prompt: "mock-prompt",
contents: [
{
role: "user",
parts: [{ text: "mock-user-text" }],
role: "model",
parts: [{ text: "mock-model-text" }],
},
{
role: "model",
parts: [{ text: "mock-model-text-2" }],
role: "user",
parts: [{ text: "mock-user-text-2" }],
},
],
});
expect(rendered).toMatchSnapshot();
});

test("prompt, contents ending with user (does not append prompt)", () => {
test("prompt, contents starting and ending with model (appends AND prepends prompt)", () => {
const rendered = render({
prompt: "mock-prompt-should-not-appear",
prompt: "mock-prompt",
contents: [
{
role: "model",
Expand All @@ -95,6 +108,23 @@ describe("GoogleGeminiApi.requestTemplate", () => {
role: "model",
parts: [{ text: "mock-model-text-2" }],
},
],
});
expect(rendered).toMatchSnapshot();
});

test("prompt, contents starting and ending with user (does not add prompt)", () => {
const rendered = render({
prompt: "mock-prompt-should-not-appear",
contents: [
{
role: "user",
parts: [{ text: "mock-user-text" }],
},
{
role: "model",
parts: [{ text: "mock-model-text" }],
},
{
role: "user",
parts: [{ text: "mock-user-text-2" }],
Expand Down Expand Up @@ -126,10 +156,6 @@ describe("GoogleGeminiApi.requestTemplate", () => {
role: "user",
parts: [{ text: "mock-user-text" }],
},
{
role: "model",
parts: [{ text: "mock-model-text-2" }],
},
],
system_instruction: {
parts: [{ text: "mock-system-text" }],
Expand All @@ -142,13 +168,20 @@ describe("GoogleGeminiApi.requestTemplate", () => {
* Combinations of FewShotRequestOptions and "native" options:
*/

test("prompt, examplePairs, contents", () => {
test("prompt, examplePairs, contents with user / model (appends prompt)", () => {
const rendered = render({
prompt: "mock-prompt",
examplePairs: [
{ user: "mock-user-msg-1", assistant: "mock-assistant-msg-1" },
{
user: "mock-user-example-pair",
assistant: "mock-assistant-example-pair",
},
],
contents: [
{
role: "user",
parts: [{ text: "mock-user-text" }],
},
{
role: "model",
parts: [{ text: "mock-model-text" }],
Expand All @@ -158,6 +191,29 @@ describe("GoogleGeminiApi.requestTemplate", () => {
expect(rendered).toMatchSnapshot();
});

test("prompt, examplePairs, contents with model / user (inserts prompt, conversation is valid)", () => {
const rendered = render({
prompt: "mock-prompt",
examplePairs: [
{
user: "mock-user-example-pair",
assistant: "mock-assistant-example-pair",
},
],
contents: [
{
role: "model",
parts: [{ text: "mock-model-text" }],
},
{
role: "user",
parts: [{ text: "mock-user-text" }],
},
],
});
expect(rendered).toMatchSnapshot();
});

test("prompt, examplePairs, system_instruction", () => {
const rendered = render({
prompt: "mock-prompt",
Expand Down Expand Up @@ -222,10 +278,14 @@ describe("GoogleGeminiApi.requestTemplate", () => {
* Tool-related:
*/

test("prompt, contents ending with function_call, $tools with matching invocation (adds function_response content items)", () => {
test("prompt, contents ending with function_call, $tools with matching invocation (appends function_response content items)", () => {
const rendered = render({
prompt: "mock-prompt",
contents: [
{
role: "user",
parts: [{ text: "mock-user-text" }],
},
{
role: "model",
parts: [
Expand Down Expand Up @@ -253,6 +313,7 @@ describe("GoogleGeminiApi.requestTemplate", () => {
invocations: [
{
arguments: { key: "value" },
resolved: true,
returned: { responseKey: "responseValue" },
},
],
Expand All @@ -266,6 +327,10 @@ describe("GoogleGeminiApi.requestTemplate", () => {
const rendered = render({
prompt: "mock-prompt",
contents: [
{
role: "user",
parts: [{ text: "mock-user-text" }],
},
{
role: "model",
parts: [
Expand Down Expand Up @@ -298,7 +363,7 @@ describe("GoogleGeminiApi.requestTemplate", () => {
// TODO expect warning
});

test("prompt, contents ending with model function_call, no $tools (appends prompt; TODO logs warning)", () => {
test("prompt, contents ending with model function_call, no $tools (prepends and appends prompt; TODO logs warning)", () => {
const rendered = render({
prompt: "mock-prompt",
contents: [
Expand All @@ -319,9 +384,9 @@ describe("GoogleGeminiApi.requestTemplate", () => {
// TODO expect warning
});

test("prompt, contents ending with user function_response (doesnt append prompt)", () => {
test("prompt, contents ending with user function_response (prepends prompt)", () => {
const rendered = render({
prompt: "mock-prompt-should-not-appear",
prompt: "mock-prompt",
contents: [
{
role: "model",
Expand Down
122 changes: 72 additions & 50 deletions packages/core/src/apis/google/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ import { composite } from "../_utils/ioTsHelpers";

import type { FewShotRequestOptions, ToolUseRequestOptions } from "../shared";

import {
FUNCTION_CALL_WITHOUT_TOOLS,
NO_MATCHING_INVOCATION,
NO_MATCHING_TOOL,
UNRESOLVED_INVOCATION,
} from "./errors";

interface FunctionCall {
name: string;
args: Record<string, unknown>;
Expand Down Expand Up @@ -164,46 +171,39 @@ export const GoogleGeminiTemplate = new FnTemplate(
safety_settings,
generation_config,
}: GoogleGeminiOptions) => {
const _contents: GoogleGeminiContentItem[] = [
...(examplePairs
? examplePairs.flatMap((pair) => [
{
role: "user" as const,
parts: [{ text: pair.user }],
},
{
role: "model" as const,
parts: [{ text: pair.assistant }],
},
])
: []),

...(contents ? (Array.isArray(contents) ? contents : [contents]) : []),
];
let _contents = (
contents && Array.isArray(contents) ? contents : [contents]
).filter((i) => i);

const firstItem = _contents[0];
const lastItem = _contents[_contents.length - 1];

if (!lastItem) {
_contents.push({
role: "user",
parts: [{ text: prompt }],
});
} else if (lastItem.role === "model") {
// the conversation must start with a user message:
if (!firstItem || firstItem.role === "model") {
_contents = [
{
role: "user",
parts: [{ text: prompt }],
},
..._contents,
];
}

// the conversation must end with a user message (either prompt or tool responses):
if (lastItem && lastItem.role === "model") {
const functionCalls = lastItem.parts
.filter((part): part is PartWithFunctionCall => "functionCall" in part)
.map((part) => part.functionCall);

if (functionCalls.length && $tools) {
// append tool responses:
const responses: PartWithFunctionResponse[] = [];

functionCalls.forEach(({ name, args }) => {
const matchingTool = $tools.find((tool) => tool.name === name);

if (!matchingTool) {
console.warn(
"The last item of conversation history (`contents`) contains a `function_call`, but no matching tool was found in `$tools`. Model behavior might be unexpected, because model's function calls were effectively ignored.",
);

console.warn(NO_MATCHING_TOOL);
return;
}

Expand All @@ -218,53 +218,75 @@ export const GoogleGeminiTemplate = new FnTemplate(
);

if (!matchingInvocation) {
console.warn(
"The last item of conversation history (`contents`) contains a `function_call`, and a matching `$tool` was found, but no matching invocation was found in the tool's invocations. (Did you forget to call `mapGeminiResponseToToolInvocations`?) Model behavior might be unexpected, because model's function calls were effectively ignored.",
);

console.warn(NO_MATCHING_INVOCATION);
return;
}

const { returned } = matchingInvocation;

if (!returned) {
console.warn(
"The last item of conversation history (`contents`) contains a `function_call`, and a $tool with matching invocations was found, but that invocation does NOT have a `returned` value! (Did you forget to execute the tool?) Model behavior might be unexpected, because model's function calls were effectively ignored.",
);

if (!matchingInvocation.resolved) {
console.warn(UNRESOLVED_INVOCATION);
return;
}

responses.push({
function_response: {
name: matchingTool.name,
// TODO dont wrap in object if already an object?
response: {
returned,
returned: matchingInvocation.returned,
},
},
});
});

_contents.push({
role: "user",
parts: responses,
});
if (responses.length) {
_contents = [
..._contents,
{
role: "user",
parts: responses,
},
];
} else {
// if no tool responses, we logged a warning above, and now append prompt as fallback:
_contents = [
..._contents,
{
role: "user",
parts: [{ text: prompt }],
},
];
}
} else {
if (functionCalls.length && !$tools) {
console.warn(
"The last item of conversation history (`contents`) contains a `function_call`, but no `$tools` were passed, so generative-ts cannot append `function_response` to conversation history. Instead, appending prompt to end of conversation history. Model behavior might be unexpected, because model's function calls were effectively ignored.",
);
console.warn(FUNCTION_CALL_WITHOUT_TOOLS);
}

_contents.push({
role: "user",
parts: [{ text: prompt }],
});
_contents = [
..._contents,
{
role: "user",
parts: [{ text: prompt }],
},
];
}
}

const rewritten = {
contents: _contents,
contents: [
...(examplePairs
? examplePairs.flatMap((pair) => [
{
role: "user" as const,
parts: [{ text: pair.user }],
},
{
role: "model" as const,
parts: [{ text: pair.assistant }],
},
])
: []),
..._contents,
],
...(tools || $tools
? {
tools: [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export function mapGeminiResponseToToolInvocations<
validatedArgs[param.name] = argValue;
});

tool.invoke(validatedArgs as ConvertParamMapToArgs<TParamMap>);
tool.addInvocation(validatedArgs as ConvertParamMapToArgs<TParamMap>);
}
});
});
Expand Down
Loading

0 comments on commit 098fc40

Please sign in to comment.