Skip to content

Commit

Permalink
move stuff around
Browse files Browse the repository at this point in the history
  • Loading branch information
jnaglick committed Jun 19, 2024
1 parent c13384f commit e1425b8
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 93 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ConvertParamMapToArgs, Tool, ToolParamMap } from "../shared";
import { ConvertParamMapToArgs, Tool, ToolParamMap } from "../../utils/Tool";

import type { GoogleGeminiResponse } from "./gemini";

Expand Down
92 changes: 1 addition & 91 deletions packages/core/src/apis/shared/ToolUseRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,94 +1,4 @@
type ToolParameterTypes = "STR" | "NUM" | "BOOL";

export type ToolParamMap = {
[key: string]: {
description: string;
type: ToolParameterTypes;
required: boolean;
};
};

type ExtractArgumentType<T extends ToolParameterTypes> = T extends "STR"
? string
: T extends "NUM"
? number
: T extends "BOOL"
? boolean
: never;

export type ConvertParamMapToArgs<TParamMap extends ToolParamMap> = {
[K in keyof TParamMap as TParamMap[K]["required"] extends true
? K
: never]: ExtractArgumentType<TParamMap[K]["type"]>;
} & {
[K in keyof TParamMap as TParamMap[K]["required"] extends false | undefined
? K
: never]?: ExtractArgumentType<TParamMap[K]["type"]>;
};

/*
* Ifaces
*/
interface ToolParam {
name: string;
description: string;
type: ToolParameterTypes;
required: boolean;
}

interface ToolInvocation<TArgs> {
arguments: TArgs;
returned?: unknown;
}

interface ToolDescriptor {
name: string;
description: string;
parameters: ToolParam[];
invocations: ToolInvocation<any>[];
}

export class Tool<TParamMap extends ToolParamMap, TReturns = unknown> {
private invokeFn: (args: ConvertParamMapToArgs<TParamMap>) => TReturns;

public descriptor: ToolDescriptor;

constructor(
name: string,
description: string,
paramMap: TParamMap,
invokeFn: (args: ConvertParamMapToArgs<TParamMap>) => TReturns,
) {
this.invokeFn = invokeFn;
this.descriptor = {
name,
description,
parameters: this.createParameters(paramMap),
invocations: [],
};
}

// eslint-disable-next-line class-methods-use-this
private createParameters(paramMap: TParamMap): ToolParam[] {
return Object.entries(paramMap).map(([name, paramInfo]) => ({
name,
description: paramInfo.description,
type: paramInfo.type,
required: paramInfo.required,
}));
}

public invoke(args: ConvertParamMapToArgs<TParamMap>): TReturns {
const returned = this.invokeFn(args);

this.descriptor.invocations.push({
arguments: args,
returned,
});

return returned;
}
}
import { ToolDescriptor } from "@typeDefs";

/**
* @category Core Interfaces
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ export {
* Public API - Utils
*/
export { FnTemplate } from "./utils";
// TODO export Tool

/*
* Public API - TypeDefs
Expand Down
21 changes: 21 additions & 0 deletions packages/core/src/typeDefs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,24 @@ export type InferRequestOptions<T> =
T extends ModelApi<infer U, unknown> ? U : never;
export type InferResponse<T> =
T extends ModelApi<ModelRequestOptions, infer V> ? V : never;

export type ToolParameterTypes = "STR" | "NUM" | "BOOL";

export interface ToolParam {
name: string;
description: string;
type: ToolParameterTypes;
required: boolean;
}

export interface ToolInvocation<TArgs> {
arguments: TArgs;
returned?: unknown;
}

export interface ToolDescriptor {
name: string;
description: string;
parameters: ToolParam[];
invocations: ToolInvocation<any>[];
}
144 changes: 144 additions & 0 deletions packages/core/src/utils/Tool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import { ToolDescriptor, ToolParam, ToolParameterTypes } from "@typeDefs";

export type ToolParamMap = {
[key: string]: {
description: string;
type: ToolParameterTypes;
required: boolean;
};
};

type ExtractArgumentType<T extends ToolParameterTypes> = T extends "STR"
? string
: T extends "NUM"
? number
: T extends "BOOL"
? boolean
: never;

export type ConvertParamMapToArgs<TParamMap extends ToolParamMap> = {
[K in keyof TParamMap as TParamMap[K]["required"] extends true
? K
: never]: ExtractArgumentType<TParamMap[K]["type"]>;
} & {
[K in keyof TParamMap as TParamMap[K]["required"] extends false | undefined
? K
: never]?: ExtractArgumentType<TParamMap[K]["type"]>;
};

export class Tool<TParamMap extends ToolParamMap, TReturns = unknown> {
private invokeFn: (args: ConvertParamMapToArgs<TParamMap>) => TReturns;

public descriptor: ToolDescriptor;

constructor(
name: string,
description: string,
paramMap: TParamMap,
invokeFn: (args: ConvertParamMapToArgs<TParamMap>) => TReturns,
) {
this.invokeFn = invokeFn;
this.descriptor = {
name,
description,
parameters: this.createParameters(paramMap),
invocations: [],
};
}

// eslint-disable-next-line class-methods-use-this
private createParameters(paramMap: TParamMap): ToolParam[] {
return Object.entries(paramMap).map(([name, paramInfo]) => ({
name,
description: paramInfo.description,
type: paramInfo.type,
required: paramInfo.required,
}));
}

public invoke(args: ConvertParamMapToArgs<TParamMap>): TReturns {
const returned = this.invokeFn(args);

this.descriptor.invocations.push({
arguments: args,
returned,
});

return returned;
}
}

export const a = new Tool(
"get_current_weather",
"Get the current weather for a given location",
{
city: {
description: "The city name",
type: "STR",
required: true,
},
state: {
description: "The state name",
type: "STR",
required: true,
},
zipcode: {
description: "An optional zipcode",
type: "NUM",
required: false,
},
},
// should work:
({ city, state, zipcode }) => {
console.log("Invoking get_current_weather tool...", {
city,
state,
zipcode,
});
return {
temperature: "70",
};
},
);

a.invoke({ city: "San Francisco", state: "CA" });
a.invoke({ city: "San Francisco", state: "CA", zipcode: 94105 });

// bad:
// a.invoke({ city: "San Francisco" });
// a.invoke({ city: "San Francisco", state: "CA", zipcode: 94105, xxx: "yyy" });
// a.invoke({ city: "San Francisco", state: 123 });

// export const x1 = new Tool(
// "get_current_weather",
// "Get the current weather for a given location",
// {
// city: {
// description: "The city name",
// type: "STR",
// required: true,
// },
// state: {
// description: "The state name",
// type: "STR",
// required: true,
// },
// zipcode: {
// description: "An optional zipcode",
// type: "NUM",
// required: false,
// },
// },
// // should break, prop not defined:
// ({ city, state, zipcode, xxx }) => {
// console.log("Invoking get_current_weather tool...", {
// city,
// state,
// zipcode,
// xxx,
// });
// return {
// temperature: "70",
// };
// },
// );
1 change: 1 addition & 0 deletions packages/core/src/utils/index.ts
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
export * from "./Template";
export * from "./Tool";
3 changes: 2 additions & 1 deletion tests/integration/vertexai-tools.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/* eslint-disable import/no-relative-packages */
import { createVertexAiModelProvider } from "@packages/gcloud-vertex-ai";
// TODO decide on these import/exports as part of public API
import { mapGeminiResponseToToolInvocations } from "../../packages/core/src/apis/google/mapGeminiResponseToToolInvocations";
import { Tool } from "../../packages/core/src/apis/shared/ToolUseRequestOptions";
import { Tool } from "../../packages/core/src/utils/Tool";

xtest("VertexAI - Google Gemini (Tools)", async () => {
// arrange
Expand Down

0 comments on commit e1425b8

Please sign in to comment.