From c13384f73bff95a1aafc98bfca06e96192f5bca8 Mon Sep 17 00:00:00 2001 From: jnaglick Date: Wed, 19 Jun 2024 16:44:45 -0400 Subject: [PATCH] temp stop point on Tools --- .../google/__snapshots__/gemini.spec.ts.snap | 274 ++++++++++++- packages/core/src/apis/google/gemini.spec.ts | 245 ++++++++++-- packages/core/src/apis/google/gemini.ts | 370 ++++++++++-------- .../mapGeminiResponseToToolInvocations.ts | 91 +++++ .../core/src/apis/openai/openAiChatApi.ts | 22 +- .../src/apis/shared/FewShotRequestOptions.ts | 3 + .../src/apis/shared/ToolUseRequestOptions.ts | 93 ++++- packages/core/src/index.ts | 2 +- .../__snapshots__/vertexai-tools.test.ts.snap | 140 +++++++ tests/integration/vertexai-tools.test.ts | 99 ++++- 10 files changed, 1131 insertions(+), 208 deletions(-) create mode 100644 packages/core/src/apis/google/mapGeminiResponseToToolInvocations.ts diff --git a/packages/core/src/apis/google/__snapshots__/gemini.spec.ts.snap b/packages/core/src/apis/google/__snapshots__/gemini.spec.ts.snap index 8ec9cd8..f62cf22 100644 --- a/packages/core/src/apis/google/__snapshots__/gemini.spec.ts.snap +++ b/packages/core/src/apis/google/__snapshots__/gemini.spec.ts.snap @@ -1,12 +1,12 @@ // Jest Snapshot v1, https://goo.gl/fbAQLP -exports[`GoogleGeminiApi.requestTemplate doesnt include prompt if contents end with a user role 1`] = ` +exports[`GoogleGeminiApi.requestTemplate prompt 1`] = ` { "contents": [ { "parts": [ { - "text": "mock-user-text", + "text": "mock-prompt", }, ], "role": "user", @@ -15,7 +15,7 @@ exports[`GoogleGeminiApi.requestTemplate doesnt include prompt if contents end w } `; -exports[`GoogleGeminiApi.requestTemplate prompt 1`] = ` +exports[`GoogleGeminiApi.requestTemplate prompt, $tools 1`] = ` { "contents": [ { @@ -27,10 +27,48 @@ exports[`GoogleGeminiApi.requestTemplate prompt 1`] = ` "role": "user", }, ], + "tools": [ + { + "function_declarations": [ + { + "description": "mock-description-1", + "name": "mock-function-1", + "parameters": { + "properties": { + "mock-function-1-param-1": { + "description": "mock-function-1-param-1-description-1", + "type": "STRING", + }, + "mock-function-1-param-2": { + "description": "mock-function-1-param-2-description-2", + "type": "NUMBER", + }, + }, + "required": [], + "type": "OBJECT", + }, + }, + { + "description": "mock-description-2", + "name": "mock-function-2", + "parameters": { + "properties": { + "mock-function-2-param-1": { + "description": "mock-function-2-param-1-description-1", + "type": "BOOLEAN", + }, + }, + "required": [], + "type": "OBJECT", + }, + }, + ], + }, + ], } `; -exports[`GoogleGeminiApi.requestTemplate prompt, contents 1`] = ` +exports[`GoogleGeminiApi.requestTemplate prompt, contents (appends prompt) 1`] = ` { "contents": [ { @@ -69,7 +107,7 @@ exports[`GoogleGeminiApi.requestTemplate prompt, contents 1`] = ` } `; -exports[`GoogleGeminiApi.requestTemplate prompt, contents with function_call 1`] = ` +exports[`GoogleGeminiApi.requestTemplate prompt, contents ending with function_call, $tools with matching invocation (adds function_response content items) 1`] = ` { "contents": [ { @@ -83,33 +121,108 @@ exports[`GoogleGeminiApi.requestTemplate prompt, contents with function_call 1`] }, }, ], + "role": "model", }, { "parts": [ { - "text": "mock-prompt", + "function_response": { + "name": "mock-function", + "response": { + "returned": { + "responseKey": "responseValue", + }, + }, + }, }, ], "role": "user", }, ], + "tools": [ + { + "function_declarations": [ + { + "description": "mock-description", + "name": "mock-function", + "parameters": { + "properties": { + "key": { + "description": "mock-key-description", + "type": "STRING", + }, + }, + "required": [ + "key", + ], + "type": "OBJECT", + }, + }, + ], + }, + ], } `; -exports[`GoogleGeminiApi.requestTemplate prompt, contents with function_response 1`] = ` +exports[`GoogleGeminiApi.requestTemplate prompt, contents ending with model function_call, $tools without matching invocation (appends prompt; TODO logs warning) 1`] = ` { "contents": [ { "parts": [ { - "function_response": { + "function_call": { + "args": { + "key": "value", + }, "name": "mock-function", - "response": { + }, + }, + ], + "role": "model", + }, + { + "parts": [], + "role": "user", + }, + ], + "tools": [ + { + "function_declarations": [ + { + "description": "another-description", + "name": "another-function", + "parameters": { + "properties": { + "another-key": { + "description": "another-key-description", + "type": "STRING", + }, + }, + "required": [], + "type": "OBJECT", + }, + }, + ], + }, + ], +} +`; + +exports[`GoogleGeminiApi.requestTemplate prompt, contents ending with model function_call, no $tools (appends prompt; TODO logs warning) 1`] = ` +{ + "contents": [ + { + "parts": [ + { + "function_call": { + "args": { "key": "value", }, + "name": "mock-function", }, }, ], + "role": "model", }, { "parts": [ @@ -123,6 +236,78 @@ exports[`GoogleGeminiApi.requestTemplate prompt, contents with function_response } `; +exports[`GoogleGeminiApi.requestTemplate prompt, contents ending with user (does not append prompt) 1`] = ` +{ + "contents": [ + { + "parts": [ + { + "text": "mock-model-text", + }, + ], + "role": "model", + }, + { + "parts": [ + { + "text": "mock-user-text", + }, + ], + "role": "user", + }, + { + "parts": [ + { + "text": "mock-model-text-2", + }, + ], + "role": "model", + }, + { + "parts": [ + { + "text": "mock-user-text-2", + }, + ], + "role": "user", + }, + ], +} +`; + +exports[`GoogleGeminiApi.requestTemplate prompt, contents ending with user function_response (doesnt append prompt) 1`] = ` +{ + "contents": [ + { + "parts": [ + { + "function_call": { + "args": { + "key": "value", + }, + "name": "mock-function", + }, + }, + ], + "role": "model", + }, + { + "parts": [ + { + "function_response": { + "name": "mock-function", + "response": { + "responseKey": "responseValue", + }, + }, + }, + ], + "role": "user", + }, + ], +} +`; + exports[`GoogleGeminiApi.requestTemplate prompt, contents, system_instruction 1`] = ` { "contents": [ @@ -600,6 +785,77 @@ exports[`GoogleGeminiApi.requestTemplate prompt, tools 1`] = ` } `; +exports[`GoogleGeminiApi.requestTemplate prompt, tools, $tools 1`] = ` +{ + "contents": [ + { + "parts": [ + { + "text": "mock-prompt", + }, + ], + "role": "user", + }, + ], + "tools": [ + { + "function_declarations": [ + { + "description": "mock-description", + "name": "mock-function", + "parameters": { + "properties": { + "key": { + "type": "STRING", + }, + }, + "type": "OBJECT", + }, + }, + ], + }, + { + "function_declarations": [ + { + "description": "mock-description-1", + "name": "mock-function-1", + "parameters": { + "properties": { + "mock-function-1-param-1": { + "description": "mock-function-1-param-1-description-1", + "type": "STRING", + }, + "mock-function-1-param-2": { + "description": "mock-function-1-param-2-description-2", + "type": "NUMBER", + }, + }, + "required": [ + "mock-function-1-param-2", + ], + "type": "OBJECT", + }, + }, + { + "description": "mock-description-2", + "name": "mock-function-2", + "parameters": { + "properties": { + "mock-function-2-param-1": { + "description": "mock-function-2-param-1-description-1", + "type": "BOOLEAN", + }, + }, + "required": [], + "type": "OBJECT", + }, + }, + ], + }, + ], +} +`; + exports[`GoogleGeminiApi.requestTemplate prompt, tools_config 1`] = ` { "contents": [ diff --git a/packages/core/src/apis/google/gemini.spec.ts b/packages/core/src/apis/google/gemini.spec.ts index 34f80c2..6d715b5 100644 --- a/packages/core/src/apis/google/gemini.spec.ts +++ b/packages/core/src/apis/google/gemini.spec.ts @@ -58,7 +58,7 @@ describe("GoogleGeminiApi.requestTemplate", () => { * "Native" few shot options (prompt, contents, system_instruction): */ - test("prompt, contents", () => { + test("prompt, contents (appends prompt)", () => { const rendered = render({ prompt: "mock-prompt", contents: [ @@ -79,6 +79,31 @@ describe("GoogleGeminiApi.requestTemplate", () => { expect(rendered).toMatchSnapshot(); }); + test("prompt, contents ending with user (does not append prompt)", () => { + const rendered = render({ + prompt: "mock-prompt-should-not-appear", + contents: [ + { + role: "model", + parts: [{ text: "mock-model-text" }], + }, + { + role: "user", + parts: [{ text: "mock-user-text" }], + }, + { + role: "model", + parts: [{ text: "mock-model-text-2" }], + }, + { + role: "user", + parts: [{ text: "mock-user-text-2" }], + }, + ], + }); + expect(rendered).toMatchSnapshot(); + }); + test("prompt, system_instruction", () => { const rendered = render({ prompt: "mock-prompt", @@ -112,6 +137,7 @@ describe("GoogleGeminiApi.requestTemplate", () => { }); expect(rendered).toMatchSnapshot(); }); + /** * Combinations of FewShotRequestOptions and "native" options: */ @@ -191,18 +217,20 @@ describe("GoogleGeminiApi.requestTemplate", () => { }); expect(rendered).toMatchSnapshot(); }); + /** * Tool-related: */ - test("prompt, contents with function_call", () => { + test("prompt, contents ending with function_call, $tools with matching invocation (adds function_response content items)", () => { const rendered = render({ prompt: "mock-prompt", contents: [ { + role: "model", parts: [ { - function_call: { + functionCall: { name: "mock-function", args: { key: "value" }, }, @@ -210,20 +238,77 @@ describe("GoogleGeminiApi.requestTemplate", () => { ], }, ], + $tools: [ + { + name: "mock-function", + description: "mock-description", + parameters: [ + { + name: "key", + description: "mock-key-description", + type: "STR", + required: true, + }, + ], + invocations: [ + { + arguments: { key: "value" }, + returned: { responseKey: "responseValue" }, + }, + ], + }, + ], }); expect(rendered).toMatchSnapshot(); }); - test("prompt, contents with function_response", () => { + test("prompt, contents ending with model function_call, $tools without matching invocation (appends prompt; TODO logs warning)", () => { const rendered = render({ prompt: "mock-prompt", contents: [ { + role: "model", parts: [ { - function_response: { + functionCall: { + name: "mock-function", + args: { key: "value" }, + }, + }, + ], + }, + ], + $tools: [ + { + name: "another-function", + description: "another-description", + parameters: [ + { + name: "another-key", + description: "another-key-description", + type: "STR", + required: false, + }, + ], + invocations: [], + }, + ], + }); + expect(rendered).toMatchSnapshot(); + // TODO expect warning + }); + + test("prompt, contents ending with model function_call, no $tools (appends prompt; TODO logs warning)", () => { + const rendered = render({ + prompt: "mock-prompt", + contents: [ + { + role: "model", + parts: [ + { + functionCall: { name: "mock-function", - response: { key: "value" }, + args: { key: "value" }, }, }, ], @@ -231,8 +316,42 @@ describe("GoogleGeminiApi.requestTemplate", () => { ], }); expect(rendered).toMatchSnapshot(); + // TODO expect warning }); + test("prompt, contents ending with user function_response (doesnt append prompt)", () => { + const rendered = render({ + prompt: "mock-prompt-should-not-appear", + contents: [ + { + role: "model", + parts: [ + { + functionCall: { + name: "mock-function", + args: { key: "value" }, + }, + }, + ], + }, + { + role: "user", + parts: [ + { + function_response: { + name: "mock-function", + response: { responseKey: "responseValue" }, + }, + }, + ], + }, + ], + }); + expect(rendered).toMatchSnapshot(); + }); + /* + * Tool declarations: + */ test("prompt, tools", () => { const rendered = render({ prompt: "mock-prompt", @@ -256,6 +375,104 @@ describe("GoogleGeminiApi.requestTemplate", () => { expect(rendered).toMatchSnapshot(); }); + test("prompt, $tools", () => { + const rendered = render({ + prompt: "mock-prompt", + $tools: [ + { + name: "mock-function-1", + description: "mock-description-1", + parameters: [ + { + name: "mock-function-1-param-1", + description: "mock-function-1-param-1-description-1", + type: "STR", + required: false, + }, + { + name: "mock-function-1-param-2", + description: "mock-function-1-param-2-description-2", + type: "NUM", + required: false, + }, + ], + invocations: [], + }, + { + name: "mock-function-2", + description: "mock-description-2", + parameters: [ + { + name: "mock-function-2-param-1", + description: "mock-function-2-param-1-description-1", + type: "BOOL", + required: false, + }, + ], + invocations: [], + }, + ], + }); + expect(rendered).toMatchSnapshot(); + }); + + test("prompt, tools, $tools", () => { + const rendered = render({ + prompt: "mock-prompt", + tools: [ + { + function_declarations: [ + { + name: "mock-function", + description: "mock-description", + parameters: { + type: "OBJECT", + properties: { + key: { type: "STRING" }, + }, + }, + }, + ], + }, + ], + $tools: [ + { + name: "mock-function-1", + description: "mock-description-1", + parameters: [ + { + name: "mock-function-1-param-1", + description: "mock-function-1-param-1-description-1", + type: "STR", + required: false, + }, + { + name: "mock-function-1-param-2", + description: "mock-function-1-param-2-description-2", + type: "NUM", + required: true, + }, + ], + invocations: [], + }, + { + name: "mock-function-2", + description: "mock-description-2", + parameters: [ + { + name: "mock-function-2-param-1", + description: "mock-function-2-param-1-description-1", + type: "BOOL", + required: false, + }, + ], + invocations: [], + }, + ], + }); + expect(rendered).toMatchSnapshot(); + }); + test("prompt, tools_config", () => { const rendered = render({ prompt: "mock-prompt", @@ -327,20 +544,4 @@ describe("GoogleGeminiApi.requestTemplate", () => { }); expect(rendered).toMatchSnapshot(); }); - - /** - * Custom 1-off logic: - */ - test("doesnt include prompt if contents end with a user role", () => { - const rendered = render({ - prompt: "mock-prompt", - contents: [ - { - role: "user", - parts: [{ text: "mock-user-text" }], - }, - ], - }); - expect(rendered).toMatchSnapshot(); - }); }); diff --git a/packages/core/src/apis/google/gemini.ts b/packages/core/src/apis/google/gemini.ts index 1e9a70b..c9c9eee 100644 --- a/packages/core/src/apis/google/gemini.ts +++ b/packages/core/src/apis/google/gemini.ts @@ -1,3 +1,4 @@ +/* eslint-disable no-nested-ternary */ /* eslint-disable camelcase */ import * as t from "io-ts"; import type { TypeOf } from "io-ts"; @@ -9,55 +10,71 @@ import { FnTemplate } from "../../utils/Template"; import { composite } from "../_utils/ioTsHelpers"; -import type { FewShotRequestOptions } from "../shared"; +import type { FewShotRequestOptions, ToolUseRequestOptions } from "../shared"; -interface Content { +interface FunctionCall { + name: string; + args: Record; +} + +interface FunctionResponse { + name: string; + response: Record; +} + +interface Part { + text?: string; + functionCall?: FunctionCall; // TODO deal with casing disparity + function_response?: FunctionResponse; +} + +interface PartWithFunctionCall extends Part { + functionCall: FunctionCall; +} + +interface PartWithFunctionResponse extends Part { + function_response: FunctionResponse; +} + +interface GoogleGeminiContentItem { role?: "user" | "model"; - parts: { - text?: string; - function_call?: { - name: string; - args: Record; - }; - function_response?: { - name: string; - response: Record; - }; - // inline_data (not supported) - // file_data (not supported) - // video_metadata (not supported) - }[]; + parts: Part[]; } -interface Schema { +interface GoogleGeminiSchema { type: "STRING" | "INTEGER" | "BOOLEAN" | "NUMBER" | "ARRAY" | "OBJECT"; description?: string; enum?: string[]; - items?: Schema[]; + items?: GoogleGeminiSchema[]; properties?: { - [key: string]: Schema; + [key: string]: GoogleGeminiSchema; }; required?: string[]; nullable?: boolean; } -/** - * @category Google Gemini - * @category Requests - */ -export interface GoogleGeminiOptions - extends FewShotRequestOptions, - ModelRequestOptions { - contents?: Content | Content[]; - system_instruction?: Content; +interface GoogleGeminiToolsOptions { tools?: { function_declarations: { name: string; description?: string; - parameters?: Schema; - response?: Schema; + parameters?: GoogleGeminiSchema; + response?: GoogleGeminiSchema; }[]; }[]; +} + +/** + * @category Google Gemini + * @category Requests + */ +export interface GoogleGeminiOptions + extends ModelRequestOptions, + FewShotRequestOptions, + ToolUseRequestOptions, + GoogleGeminiToolsOptions { + contents?: GoogleGeminiContentItem | GoogleGeminiContentItem[]; + system_instruction?: GoogleGeminiContentItem; tools_config?: { mode?: "AUTO" | "NONE" | "ANY"; allowed_function_names?: string[]; @@ -81,6 +98,55 @@ export interface GoogleGeminiOptions }; } +const toGeminiToolParamType = (type: "STR" | "NUM" | "BOOL") => { + return ( + (type === "STR" && "STRING") || (type === "NUM" && "NUMBER") || "BOOLEAN" + ); +}; + +function mapToolDescriptionsToGeminiRequest({ + $tools: tools, +}: ToolUseRequestOptions): GoogleGeminiToolsOptions { + if (!tools) { + return { + tools: [], + }; + } + + return { + tools: [ + { + function_declarations: tools.map((tool) => { + return { + name: tool.name, + description: tool.description, + ...(tool.parameters + ? { + parameters: { + type: "OBJECT", + properties: tool.parameters.reduce( + (acc, param) => { + acc[param.name] = { + type: toGeminiToolParamType(param.type), + description: param.description, + }; + return acc; + }, + {} as { [key: string]: GoogleGeminiSchema }, + ), + required: tool.parameters + .filter(({ required }) => required) + .map(({ name }) => name), + }, + } + : {}), + }; + }), + }, + ], + }; +} + /** * @category Google Gemini * @category Templates @@ -89,97 +155,124 @@ export const GoogleGeminiTemplate = new FnTemplate( ({ prompt, examplePairs, - contents, system, + $tools, + contents, tools, tools_config, system_instruction, safety_settings, generation_config, }: GoogleGeminiOptions) => { - const rewritten = { - contents: [ - ...(examplePairs - ? examplePairs.flatMap((pair) => [ - { - role: "user", - parts: [{ text: pair.user }], - }, - { - role: "model", - parts: [{ text: pair.assistant }], - }, - ]) - : []), - ...(contents - ? (Array.isArray(contents) ? contents : [contents]).map( - (contentItem) => ({ - parts: contentItem.parts.map((part) => ({ - ...(part.text ? { text: part.text } : {}), - ...(part.function_call - ? { - function_call: { - name: part.function_call.name, - args: part.function_call.args, - }, - } - : {}), - ...(part.function_response - ? { - function_response: { - name: part.function_response.name, - response: part.function_response.response, - }, - } - : {}), - })), - ...(contentItem.role ? { role: contentItem.role } : {}), - }), - ) - : []), - // Only insert a user prompt if the last item in contents is NOT user - // TODO: revisit this logic. it's basically a hack caused by the facts (1) tool results in gemini are specified via user messages (2) our interface requires 'prompt' (3) gemini errors if two user messages are consecutive so, in the case tool results are given, our interface still requires 'prompt' but CANNOT insert it - ...(!contents || - (Array.isArray(contents) && - (!contents.length || contents[contents.length - 1]?.role !== "user")) - ? [ - { - role: "user", - parts: [{ text: prompt }], - }, - ] - : []), - ], - ...(tools - ? { - tools: tools.map((tool) => ({ - function_declarations: tool.function_declarations.map( - (declaration) => ({ - name: declaration.name, - ...(declaration.description - ? { description: declaration.description } - : {}), - ...(declaration.parameters - ? { parameters: declaration.parameters } - : {}), - ...(declaration.response - ? { response: declaration.response } - : {}), - }), + const _contents: GoogleGeminiContentItem[] = [ + ...(examplePairs + ? examplePairs.flatMap((pair) => [ + { + role: "user" as const, + parts: [{ text: pair.user }], + }, + { + role: "model" as const, + parts: [{ text: pair.assistant }], + }, + ]) + : []), + + ...(contents ? (Array.isArray(contents) ? contents : [contents]) : []), + ]; + + const lastItem = _contents[_contents.length - 1]; + + if (!lastItem) { + _contents.push({ + role: "user", + parts: [{ text: prompt }], + }); + } else if (lastItem.role === "model") { + const functionCalls = lastItem.parts + .filter((part): part is PartWithFunctionCall => "functionCall" in part) + .map((part) => part.functionCall); + + if (functionCalls.length && $tools) { + const responses: PartWithFunctionResponse[] = []; + + functionCalls.forEach(({ name, args }) => { + const matchingTool = $tools.find((tool) => tool.name === name); + + if (!matchingTool) { + console.warn( + "The last item of conversation history (`contents`) contains a `function_call`, but no matching tool was found in `$tools`. Model behavior might be unexpected, because model's function calls were effectively ignored.", + ); + + return; + } + + const matchingInvocation = matchingTool.invocations + ?.reverse() + .find((invocation) => + Object.keys(args).every( + (key) => + key in invocation.arguments && + invocation.arguments[key] === args[key], ), - })), + ); + + if (!matchingInvocation) { + console.warn( + "The last item of conversation history (`contents`) contains a `function_call`, and a matching `$tool` was found, but no matching invocation was found in the tool's invocations. (Did you forget to call `mapGeminiResponseToToolInvocations`?) Model behavior might be unexpected, because model's function calls were effectively ignored.", + ); + + return; } - : {}), - ...(tools_config - ? { - tools_config: { - ...(tools_config.mode ? { mode: tools_config.mode } : {}), - ...(tools_config.allowed_function_names - ? { - allowed_function_names: tools_config.allowed_function_names, - } - : {}), + + const { returned } = matchingInvocation; + + if (!returned) { + console.warn( + "The last item of conversation history (`contents`) contains a `function_call`, and a $tool with matching invocations was found, but that invocation does NOT have a `returned` value! (Did you forget to execute the tool?) Model behavior might be unexpected, because model's function calls were effectively ignored.", + ); + + return; + } + + responses.push({ + function_response: { + name: matchingTool.name, + response: { + returned, + }, }, + }); + }); + + _contents.push({ + role: "user", + parts: responses, + }); + } else { + if (functionCalls.length && !$tools) { + console.warn( + "The last item of conversation history (`contents`) contains a `function_call`, but no `$tools` were passed, so generative-ts cannot append `function_response` to conversation history. Instead, appending prompt to end of conversation history. Model behavior might be unexpected, because model's function calls were effectively ignored.", + ); + } + + _contents.push({ + role: "user", + parts: [{ text: prompt }], + }); + } + } + + const rewritten = { + contents: _contents, + ...(tools || $tools + ? { + tools: [ + ...(tools || []), + ...($tools + ? mapToolDescriptionsToGeminiRequest({ $tools }).tools || [] + : []), + ], } : {}), ...(system_instruction || system @@ -196,58 +289,19 @@ export const GoogleGeminiTemplate = new FnTemplate( }, } : {}), + ...(tools_config + ? { + tools_config, + } + : {}), ...(safety_settings ? { - safety_settings: { - ...(safety_settings.category - ? { category: safety_settings.category } - : {}), - ...(safety_settings.threshold - ? { threshold: safety_settings.threshold } - : {}), - ...(safety_settings.max_influential_terms - ? { - max_influential_terms: - safety_settings.max_influential_terms, - } - : {}), - ...(safety_settings.method - ? { method: safety_settings.method } - : {}), - }, + safety_settings, } : {}), ...(generation_config ? { - generation_config: { - ...(generation_config.temperature - ? { temperature: generation_config.temperature } - : {}), - ...(generation_config.top_p - ? { top_p: generation_config.top_p } - : {}), - ...(generation_config.top_k - ? { top_k: generation_config.top_k } - : {}), - ...(generation_config.candidate_count - ? { candidate_count: generation_config.candidate_count } - : {}), - ...(generation_config.max_output_tokens - ? { max_output_tokens: generation_config.max_output_tokens } - : {}), - ...(generation_config.stop_sequences - ? { stop_sequences: generation_config.stop_sequences } - : {}), - ...(generation_config.presence_penalty - ? { presence_penalty: generation_config.presence_penalty } - : {}), - ...(generation_config.frequency_penalty - ? { frequency_penalty: generation_config.frequency_penalty } - : {}), - ...(generation_config.response_mime_type - ? { response_mime_type: generation_config.response_mime_type } - : {}), - }, + generation_config, } : {}), }; @@ -271,7 +325,7 @@ const GoogleGeminiResponseCodec = t.type({ text: t.string, functionCall: t.type({ name: t.string, - args: t.record(t.string, t.string), + args: t.record(t.string, t.unknown), }), }), ), diff --git a/packages/core/src/apis/google/mapGeminiResponseToToolInvocations.ts b/packages/core/src/apis/google/mapGeminiResponseToToolInvocations.ts new file mode 100644 index 0000000..350ac7d --- /dev/null +++ b/packages/core/src/apis/google/mapGeminiResponseToToolInvocations.ts @@ -0,0 +1,91 @@ +import { ConvertParamMapToArgs, Tool, ToolParamMap } from "../shared"; + +import type { GoogleGeminiResponse } from "./gemini"; + +/** + * @category Google Gemini + * @category Tools + */ +export function mapGeminiResponseToToolInvocations< + TParamMap extends ToolParamMap, +>( + { data: { candidates } }: GoogleGeminiResponse, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + tools: Tool[], +) { + if (!tools.length) { + return; + } + + candidates.forEach((candidate) => { + candidate.content?.parts.forEach((part) => { + if (part.functionCall) { + const { name, args } = part.functionCall; + + const tool = tools.find(({ descriptor }) => descriptor.name === name); + + if (!tool) { + throw new Error( + `Model attempted to invoke tool ${name} that does not exist`, + ); + } + + // TODO: check unexpected arguments + // Object.keys(args).forEach((argName) => { + // if (!$tool.parameters?.some((param) => param.name === argName)) { + // throw new Error( + // `Model attempted to invoke tool ${name} using unexpected argument ${argName}`, + // ); + // } + // }); + + const validatedArgs: Record = {}; + + tool.descriptor.parameters.forEach((param) => { + const argValue = args[param.name]; + + if (!argValue && param.required) { + throw new Error( + `Model attempted to call function ${name} without providing required argument ${param.name}`, + ); + } + + if (!argValue) { + return; + } + + switch (param.type) { + case "STR": + if (typeof argValue !== "string") { + throw new Error( + `Model attempted to call function ${name} with invalid argument type for ${param.name}. Should have been ${param.type} but got ${typeof argValue}`, + ); + } + break; + case "NUM": + if (typeof argValue !== "number") { + throw new Error( + `Model attempted to call function ${name} with invalid argument type for ${param.name}. Should have been ${param.type} but got ${typeof argValue}`, + ); + } + break; + case "BOOL": + if (typeof argValue !== "boolean") { + throw new Error( + `Model attempted to call function ${name} with invalid argument type for ${param.name}. Should have been ${param.type} but got ${typeof argValue}`, + ); + } + break; + default: + // impossible + return; + } + + validatedArgs[param.name] = argValue; + }); + + tool.invoke(validatedArgs as ConvertParamMapToArgs); + } + }); + }); +} diff --git a/packages/core/src/apis/openai/openAiChatApi.ts b/packages/core/src/apis/openai/openAiChatApi.ts index c1c5293..15a1ffe 100644 --- a/packages/core/src/apis/openai/openAiChatApi.ts +++ b/packages/core/src/apis/openai/openAiChatApi.ts @@ -30,13 +30,25 @@ interface ChatCompletionRequestMessage { }; } +interface OpenAiChatToolsOptions { + tools?: { + type: "function"; + function: { + name: string; + description?: string; + parameters?: object; // TODO JsonSchema + }; + }[]; +} + /** * @category OpenAI ChatCompletion * @category Requests */ export interface OpenAiChatOptions extends ModelRequestOptions, - FewShotRequestOptions { + FewShotRequestOptions, + OpenAiChatToolsOptions { messages?: ChatCompletionRequestMessage[]; frequency_penalty?: number; logit_bias?: Record; @@ -57,14 +69,6 @@ export interface OpenAiChatOptions temperature?: number; top_p?: number; user?: string; - tools?: { - type: "function"; - function: { - name: string; - description?: string; - parameters?: object; // TODO JsonSchema - }; - }[]; tool_choice?: | "none" | "auto" diff --git a/packages/core/src/apis/shared/FewShotRequestOptions.ts b/packages/core/src/apis/shared/FewShotRequestOptions.ts index e59dcd2..ca51cfe 100644 --- a/packages/core/src/apis/shared/FewShotRequestOptions.ts +++ b/packages/core/src/apis/shared/FewShotRequestOptions.ts @@ -1,3 +1,6 @@ +/** + * @category Core Interfaces + */ export interface FewShotRequestOptions { prompt: string; system?: string; diff --git a/packages/core/src/apis/shared/ToolUseRequestOptions.ts b/packages/core/src/apis/shared/ToolUseRequestOptions.ts index 92533a0..95e02af 100644 --- a/packages/core/src/apis/shared/ToolUseRequestOptions.ts +++ b/packages/core/src/apis/shared/ToolUseRequestOptions.ts @@ -1,15 +1,98 @@ -interface ToolParameter { +type ToolParameterTypes = "STR" | "NUM" | "BOOL"; + +export type ToolParamMap = { + [key: string]: { + description: string; + type: ToolParameterTypes; + required: boolean; + }; +}; + +type ExtractArgumentType = T extends "STR" + ? string + : T extends "NUM" + ? number + : T extends "BOOL" + ? boolean + : never; + +export type ConvertParamMapToArgs = { + [K in keyof TParamMap as TParamMap[K]["required"] extends true + ? K + : never]: ExtractArgumentType; +} & { + [K in keyof TParamMap as TParamMap[K]["required"] extends false | undefined + ? K + : never]?: ExtractArgumentType; +}; + +/* + * Ifaces + */ +interface ToolParam { name: string; description: string; - type: "STR" | "INT" | "BOOL"; + type: ToolParameterTypes; + required: boolean; } -interface ToolDescription { +interface ToolInvocation { + arguments: TArgs; + returned?: unknown; +} + +interface ToolDescriptor { name: string; description: string; - parameters: ToolParameter[]; + parameters: ToolParam[]; + invocations: ToolInvocation[]; +} + +export class Tool { + private invokeFn: (args: ConvertParamMapToArgs) => TReturns; + + public descriptor: ToolDescriptor; + + constructor( + name: string, + description: string, + paramMap: TParamMap, + invokeFn: (args: ConvertParamMapToArgs) => TReturns, + ) { + this.invokeFn = invokeFn; + this.descriptor = { + name, + description, + parameters: this.createParameters(paramMap), + invocations: [], + }; + } + + // eslint-disable-next-line class-methods-use-this + private createParameters(paramMap: TParamMap): ToolParam[] { + return Object.entries(paramMap).map(([name, paramInfo]) => ({ + name, + description: paramInfo.description, + type: paramInfo.type, + required: paramInfo.required, + })); + } + + public invoke(args: ConvertParamMapToArgs): TReturns { + const returned = this.invokeFn(args); + + this.descriptor.invocations.push({ + arguments: args, + returned, + }); + + return returned; + } } +/** + * @category Core Interfaces + */ export interface ToolUseRequestOptions { - $toolDescriptions?: ToolDescription[]; + $tools?: ToolDescriptor[]; } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 609fb1a..89c60a3 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -79,7 +79,7 @@ export { createLmStudioModelProvider, createMistralModelProvider, createOpenAiChatModelProvider, - AwsBedrockAuthConfig as AwsAuthConfig, + AwsBedrockAuthConfig as AwsAuthConfig, // TODO ??? CohereAuthConfig, GroqAuthConfig, HuggingfaceAuthConfig, diff --git a/tests/integration/__snapshots__/vertexai-tools.test.ts.snap b/tests/integration/__snapshots__/vertexai-tools.test.ts.snap index 08a3557..f94c8b7 100644 --- a/tests/integration/__snapshots__/vertexai-tools.test.ts.snap +++ b/tests/integration/__snapshots__/vertexai-tools.test.ts.snap @@ -1,5 +1,145 @@ // Jest Snapshot v1, https://goo.gl/fbAQLP +exports[`VertexAI - Google Gemini ($tools) 1`] = ` +{ + "config": { + "body": "string", + "errorRedactor": "function", + "headers": { + "Authorization": "string", + "Content-Type": "string", + "User-Agent": "string", + "x-goog-api-client": "string", + }, + "method": "string", + "paramsSerializer": "function", + "responseType": "string", + "url": "string", + "validateStatus": "function", + }, + "data": { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "string", + }, + ], + "role": "string", + }, + "finishReason": "string", + "safetyRatings": [ + { + "category": "string", + "probability": "string", + "probabilityScore": "number", + "severity": "string", + "severityScore": "number", + }, + ], + }, + ], + "usageMetadata": { + "candidatesTokenCount": "number", + "promptTokenCount": "number", + "totalTokenCount": "number", + }, + }, + "headers": { + "alt-svc": "string", + "cache-control": "string", + "content-encoding": "string", + "content-type": "string", + "date": "string", + "server": "string", + "transfer-encoding": "string", + "vary": "string", + "x-content-type-options": "string", + "x-frame-options": "string", + "x-xss-protection": "string", + }, + "request": { + "responseURL": "string", + }, + "status": "number", + "statusText": "string", +} +`; + +exports[`VertexAI - Google Gemini (COOL TOOLS) 1`] = ` +{ + "config": { + "body": "string", + "errorRedactor": "function", + "headers": { + "Authorization": "string", + "Content-Type": "string", + "User-Agent": "string", + "x-goog-api-client": "string", + }, + "method": "string", + "paramsSerializer": "function", + "responseType": "string", + "url": "string", + "validateStatus": "function", + }, + "data": { + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "args": { + "city": "string", + "state": "string", + }, + "name": "string", + }, + }, + ], + "role": "string", + }, + "finishReason": "string", + "safetyRatings": [ + { + "category": "string", + "probability": "string", + "probabilityScore": "number", + "severity": "string", + "severityScore": "number", + }, + ], + }, + ], + "usageMetadata": { + "candidatesTokenCount": "number", + "promptTokenCount": "number", + "totalTokenCount": "number", + }, + }, + "headers": { + "alt-svc": "string", + "cache-control": "string", + "content-encoding": "string", + "content-type": "string", + "date": "string", + "server": "string", + "transfer-encoding": "string", + "vary": "string", + "x-content-type-options": "string", + "x-frame-options": "string", + "x-xss-protection": "string", + }, + "request": { + "responseURL": "string", + }, + "status": "number", + "statusText": "string", +} +`; + exports[`VertexAI - Google Gemini (Tools with Responses) 1`] = ` { "config": { diff --git a/tests/integration/vertexai-tools.test.ts b/tests/integration/vertexai-tools.test.ts index 4da51cd..bd7b879 100644 --- a/tests/integration/vertexai-tools.test.ts +++ b/tests/integration/vertexai-tools.test.ts @@ -1,6 +1,9 @@ +/* eslint-disable import/no-relative-packages */ import { createVertexAiModelProvider } from "@packages/gcloud-vertex-ai"; +import { mapGeminiResponseToToolInvocations } from "../../packages/core/src/apis/google/mapGeminiResponseToToolInvocations"; +import { Tool } from "../../packages/core/src/apis/shared/ToolUseRequestOptions"; -test("VertexAI - Google Gemini (Tools)", async () => { +xtest("VertexAI - Google Gemini (Tools)", async () => { // arrange const model = await createVertexAiModelProvider({ modelId: "gemini-1.0-pro", @@ -42,7 +45,7 @@ test("VertexAI - Google Gemini (Tools)", async () => { expect(response).toMatchApiSnapshot(); }); -test("VertexAI - Google Gemini (Tools with Responses)", async () => { +xtest("VertexAI - Google Gemini (Tools with Responses)", async () => { // arrange const model = await createVertexAiModelProvider({ modelId: "gemini-1.0-pro", @@ -76,7 +79,7 @@ test("VertexAI - Google Gemini (Tools with Responses)", async () => { role: "model", parts: [ { - function_call: { + functionCall: { name: "get_current_weather", args: { city: "Boston", @@ -85,7 +88,7 @@ test("VertexAI - Google Gemini (Tools with Responses)", async () => { }, }, { - function_call: { + functionCall: { name: "get_current_weather", args: { city: "New York City", @@ -158,3 +161,91 @@ test("VertexAI - Google Gemini (Tools with Responses)", async () => { // assert expect(response).toMatchApiSnapshot(); }); + +test("VertexAI - Google Gemini ($tools)", async () => { + // arrange + const model = await createVertexAiModelProvider({ + modelId: "gemini-1.0-pro", + }); + + const tools = [ + new Tool( + "get_current_weather", + "Get the current weather for a given location", + { + city: { + description: "The city name", + type: "STR", + required: true, + }, + state: { + description: "The state name", + type: "STR", + required: true, + }, + zipcode: { + description: "An optional zipcode", + type: "NUM", + required: false, + }, + }, + ({ city, state, zipcode }) => { + console.log("Invoking get_current_weather tool...", { + city, + state, + zipcode, + }); + return { + temperature: "70", + }; + }, + ), + ]; + + // act + const response = await model.sendRequest({ + system: "Use tools to help answer questions.", + prompt: "What is the weather in Boston and New York City?", + $tools: tools.map(({ descriptor }) => descriptor), + }); + + console.log("Got response. Mapping tool invocations...."); + + mapGeminiResponseToToolInvocations(response, tools); + + console.log(JSON.stringify(tools, null, 2)); + + const last = response.data.candidates[0]?.content; + + if (!last) { + throw new Error("No content found in response!?"); + } + + const response2 = await model.sendRequest({ + system: "Use tools to help answer questions.", + prompt: "What is the weather in Boston and New York City?", + contents: [ + { + // TODO eyyy insert prompt at start etc etc... + role: "user", + parts: [ + { + text: "What is the weather in Boston and New York City?", + }, + ], + }, + { + // ...last, + // TODO fix the need for this (above line gives error cant convert between t.string() and "model" | "user") + role: "model", + parts: last.parts, + }, + ], + $tools: tools.map(({ descriptor }) => descriptor), + }); + + console.log(JSON.stringify(response2.data.candidates[0]?.content, null, 2)); + + // assert + expect(response2).toMatchApiSnapshot(); +});