From 27ec106c4ec457acbf737c3c7804c537ce164b01 Mon Sep 17 00:00:00 2001 From: neuecc Date: Fri, 5 Apr 2024 16:29:37 +0900 Subject: [PATCH] WIP tool function calling --- sandbox/ConsoleApp1/GeneratedMock.cs | 191 +++++++++++++++++------ sandbox/ConsoleApp1/Program.cs | 16 +- src/Claudia.FunctionGenerator/Emitter.cs | 78 +++++++++ src/Claudia/Anthropic.cs | 1 - src/Claudia/MessageRequest.cs | 55 +++++++ 5 files changed, 281 insertions(+), 60 deletions(-) diff --git a/sandbox/ConsoleApp1/GeneratedMock.cs b/sandbox/ConsoleApp1/GeneratedMock.cs index 2fd55bd..df09413 100644 --- a/sandbox/ConsoleApp1/GeneratedMock.cs +++ b/sandbox/ConsoleApp1/GeneratedMock.cs @@ -17,16 +17,17 @@ //using Claudia; //using System; +//using System.Collections.Generic; //using System.Linq; //using System.Text; //using System.Threading.Tasks; //using System.Xml.Linq; -//static partial class FunctionTools +//static partial class FunctionTools2 //{ // public const string SystemPrompt = @$" -//In this environment you have access to a set of tools you can use to answer the user's question. If there are multiple tags, please consolidate them into a single block. +//In this environment you have access to a set of tools you can use to answer the user's question. If your solution involves the use of multiple tools, please include multiple s within a single tag. Each step-by-step answer or tag is not required. Only a single tag should be returned at the beginning. //You may call them like this: // @@ -37,59 +38,50 @@ // ... // // +// +// $TOOL_NAME +// +// <$PARAMETER_NAME>$PARAMETER_VALUE +// ... +// +// +// ... // //Here are the tools available: //{PromptXml.ToolsAll} + +//Again, including multiple tags in the reply is prohibited. //"; // public static class PromptXml // { // public const string ToolsAll = @$" -//{Today} -//{Sum} +// +//{TimeOfDay} //{DoPairwiseArithmetic} +//{GetHtmlFromWeb} +// //"; - -// public const string Today = @" +// public const string TimeOfDay = @" // -// Today -// Date of target location. +// TimeOfDay +// Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London. // // -// timeZoneId +// timeZone // string -// TimeZone of localtion like 'Tokeyo Standard Time', 'Eastern Standard Time', etc. +// The time zone to get the current time for, such as UTC, US/Pacific, Europe/London. // // // //"; - -// public const string Sum = @" -// -// Sum -// Sum of two integer parameters. -// -// -// x -// int -// parameter1. -// -// -// y -// int -// parameter2. -// -// -// -//"; - // public const string DoPairwiseArithmetic = @" // // DoPairwiseArithmetic // Calculator function for doing basic arithmetic. -// Supports addition, subtraction, multiplication +// Supports addition, subtraction, multiplication // // // firstOperand @@ -109,8 +101,108 @@ // // //"; +// public const string GetHtmlFromWeb = @" +// +// GetHtmlFromWeb +// Retrieves the HTML from the specified URL. +// +// +// url +// string +// The URL to retrieve the HTML from. +// +// +// +//"; +// public static class Tools +// { +// public static readonly Tool TimeOfDay = new Tool +// { +// Name = "TimeOfDay", +// Description = "Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London.", +// InputSchema = new InputSchema +// { +// Type = "object", +// Properties = new Dictionary +// { +// { +// "timeZone", new ToolProperty() +// { +// Type = "string", +// Description = "Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London." +// } +// }, + +// }, +// Required = new[] { timeZone } +// } + +// }; +// public static readonly Tool DoPairwiseArithmetic = new Tool +// { +// Name = "DoPairwiseArithmetic", +// Description = "Calculator function for doing basic arithmetic. +// Supports addition, +// subtraction, +// multiplication", +// InputSchema = new InputSchema +// { +// Type = "object", +// Properties = new Dictionary +// { +// { +// "firstOperand", new ToolProperty() +// { +// Type = "double", +// Description = "Calculator function for doing basic arithmetic. +// Supports addition, subtraction, multiplication" +// }, +// }, +// { +// "secondOperand", new ToolProperty() +// { +// Type = "double", +// Description = "Calculator function for doing basic arithmetic. +// Supports addition, subtraction, multiplication" +// }, +// }, +// { +// "operator", new ToolProperty() +// { +// Type = "string", +// Description = "Calculator function for doing basic arithmetic. +// Supports addition, subtraction, multiplication" +// }, +// }, + +// }, +// Required = new[] { firstOperand, secondOperand, operator } +// } + +// }; +// public static readonly Tool GetHtmlFromWeb = new Tool +// { +// Name = "GetHtmlFromWeb", +// Description = "Retrieves the HTML from the specified URL.", +// InputSchema = new InputSchema +// { +// Type = "object", +// Properties = new Dictionary +// { +// { +// "url", new ToolProperty() +// { +// Type = "string", +// Description = "Retrieves the HTML from the specified URL." +// }, +// }, +// }, +// Required = new[] { url } +// } +// }; +// } // } //#pragma warning disable CS1998 @@ -118,9 +210,9 @@ // { // var content = message.Content.FirstOrDefault(x => x.Text != null); // if (content == null) return null; - + // var text = content.Text; -// var tagStart = text .IndexOf(""); +// var tagStart = text.IndexOf(""); // if (tagStart == -1) return null; // var functionCalls = text.Substring(tagStart) + ""; @@ -135,34 +227,33 @@ // var name = (string)item.Element("tool_name")!; // switch (name) // { -// case "Today": +// case "TimeOfDay": // { // var parameters = item.Element("parameters")!; -// var _0 = (string)parameters.Element("timeZoneId")!; +// var _0 = (string)parameters.Element("timeZone")!; -// BuildResult(sb, "Today", Today(_0)); +// BuildResult(sb, "TimeOfDay", TimeOfDay(_0)); // break; // } -// case "Sum": +// case "DoPairwiseArithmetic": // { // var parameters = item.Element("parameters")!; -// var _0 = (int)parameters.Element("x")!; -// var _1 = (int)parameters.Element("y")!; +// var _0 = (double)parameters.Element("firstOperand")!; +// var _1 = (double)parameters.Element("secondOperand")!; +// var _2 = (string)parameters.Element("operator")!; -// BuildResult(sb, "Sum", Sum(_0, _1)); +// BuildResult(sb, "DoPairwiseArithmetic", DoPairwiseArithmetic(_0, _1, _2)); // break; // } -// case "DoPairwiseArithmetic": +// case "GetHtmlFromWeb": // { // var parameters = item.Element("parameters")!; -// var _0 = (double)parameters.Element("firstOperand")!; -// var _1 = (double)parameters.Element("secondOperand")!; -// var _2 = (string)parameters.Element("operator")!; +// var _0 = (string)parameters.Element("url")!; -// BuildResult(sb, "DoPairwiseArithmetic", DoPairwiseArithmetic(_0, _1, _2)); +// BuildResult(sb, "GetHtmlFromWeb", await GetHtmlFromWeb(_0).ConfigureAwait(false)); // break; // } @@ -177,14 +268,10 @@ // static void BuildResult(StringBuilder sb, string toolName, T result) // { -// sb.AppendLine(@$" -// +// sb.AppendLine(@$" // {toolName} -// -// {result} -// -// -//"); +// {result} +// "); // } // } //#pragma warning restore CS1998 diff --git a/sandbox/ConsoleApp1/Program.cs b/sandbox/ConsoleApp1/Program.cs index d25a2ad..95475ee 100644 --- a/sandbox/ConsoleApp1/Program.cs +++ b/sandbox/ConsoleApp1/Program.cs @@ -46,26 +46,28 @@ - +anthropic.HttpClient.DefaultRequestHeaders.Add("anthropic-beta", "tools-2024-04-04"); var input = new Message { Role = Roles.User, - Content = """ - What time is it in Seattle and Tokyo? - Incidentally multiply 1,984,135 by 9,343,116. -""" + Content = "What time is it in Tokyo?", }; var message = await anthropic.Messages.CreateAsync(new() { Model = Models.Claude3Haiku, MaxTokens = 1024, - System = FunctionTools.SystemPrompt, // set generated prompt - StopSequences = [StopSequnces.CloseFunctionCalls], // set as stop sequence + + //System = FunctionTools.SystemPrompt, // set generated prompt + //StopSequences = [StopSequnces.CloseFunctionCalls], // set as stop sequence + + Tools = [FunctionTools.PromptXml.Tools.TimeOfDay], Messages = [input], }); + + var partialAssistantMessage = await FunctionTools.InvokeAsync(message); var callResult = await anthropic.Messages.CreateAsync(new() diff --git a/src/Claudia.FunctionGenerator/Emitter.cs b/src/Claudia.FunctionGenerator/Emitter.cs index ab14b25..2986475 100644 --- a/src/Claudia.FunctionGenerator/Emitter.cs +++ b/src/Claudia.FunctionGenerator/Emitter.cs @@ -1,6 +1,7 @@ using Microsoft.CodeAnalysis; using System.Text; using System.Xml.Linq; +using static Claudia.FunctionGenerator.Emitter; namespace Claudia.FunctionGenerator; @@ -88,6 +89,8 @@ public static class PromptXml EmitToolDescription(method); } + EmitTools(parseResult); + sb.AppendLine(" }"); // close PromptXml sb.AppendLine(); @@ -206,6 +209,81 @@ static void BuildResult(StringBuilder sb, string toolName, T result) sb.AppendLine(code); } + void EmitTools(ParseResult parseResult) + { + // TODO: Add All + // public static readonly Tool[] All = @$" + sb.AppendLine($$""" + public static class Tools + { +"""); + + // Emit Tool + foreach (var method in parseResult.Methods) + { + var docComment = method.Syntax.GetDocumentationCommentTriviaSyntax()!; + var description = docComment.GetSummary().Replace("\"", "'").Replace("\r\n", " ").Replace("\n", " "); + + // property + var inputSchema = new StringBuilder(); + if (method.Symbol.Parameters.Length != 0) + { + var propBuilder = new StringBuilder(); + var paramRequired = new List(); + foreach (var p in docComment.GetParams()) + { + var paramDescription = p.Description.Replace("\"", "'").Replace("\r\n", " ").Replace("\n", " "); + + // type retrieve from method symbol + var name = p.Name; + var paramType = method.Symbol.Parameters.First(x => x.Name == name).Type.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + + // TODO: support enum + propBuilder.AppendLine($$""" + { + "{{name}}", new ToolProperty() + { + Type = "{{paramType}}", + Description = "{{description}}" + } + }, +"""); + // TODO: support optional parameter + paramRequired.Add("\"" + name + "\""); + } + var required = string.Join(", ", paramRequired); + if (required.Length != 0) + { + required = "Required = new [] { " + required + " }"; + } + + inputSchema.AppendLine($$""" + InputSchema = new InputSchema + { + Type = "object", + Properties = new System.Collections.Generic.Dictionary + { +{{propBuilder}} + }, + {{required}} + } +"""); + } + + sb.AppendLine($$""" + public static readonly Tool {{method.Name}} = new Tool + { + Name = "{{method.Name}}", + Description = "{{description}}", +{{inputSchema}} + }; + +"""); + } + + sb.AppendLine(" }"); // close Tools + } + static void AddSource(SourceProductionContext context, ISymbol targetSymbol, string code, string fileExtension = ".g.cs") { var fullType = targetSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) diff --git a/src/Claudia/Anthropic.cs b/src/Claudia/Anthropic.cs index 250647a..1907293 100644 --- a/src/Claudia/Anthropic.cs +++ b/src/Claudia/Anthropic.cs @@ -68,7 +68,6 @@ async Task IMessages.CreateAsync(MessageRequest request, Reques { request.Stream = null; using var msg = await SendRequestAsync(request, overrideOptions, cancellationToken).ConfigureAwait(ConfigureAwait); - var result = await RequestWithAsync(msg, cancellationToken, overrideOptions, static (x, ct, _) => x.Content.ReadFromJsonAsync(AnthropicJsonSerialzierContext.Default.Options, ct), null).ConfigureAwait(ConfigureAwait); return result!; } diff --git a/src/Claudia/MessageRequest.cs b/src/Claudia/MessageRequest.cs index ffc2b3e..94b46d5 100644 --- a/src/Claudia/MessageRequest.cs +++ b/src/Claudia/MessageRequest.cs @@ -90,6 +90,11 @@ public override string ToString() { return JsonSerializer.Serialize(this, AnthropicJsonSerialzierContext.Default.Options); } + + // 2024-04-04 beta: https://docs.anthropic.com/claude/docs/tool-use + [JsonPropertyName("tools")] + public Tool[]? Tools { get; set; } + } public record class Message { @@ -145,6 +150,19 @@ public record class Content [JsonPropertyName("source")] public Source? Source { get; set; } + #region tool_use response + + [JsonPropertyName("id")] + public string? Id { get; set; } + + [JsonPropertyName("name")] + public string? Name { get; set; } + + [JsonPropertyName("input")] + public Dictionary? Input { get; set; } + + #endregion + public static implicit operator Content(string text) => new Content(text); public Content() @@ -214,3 +232,40 @@ public record class Source [JsonPropertyName("data")] public required ReadOnlyMemory Data { get; set; } // Base64 } + +// https://docs.anthropic.com/claude/docs/tool-use +public record class Tool +{ + [JsonPropertyName("name")] + public required string Name { get; set; } + + [JsonPropertyName("description")] + public required string Description { get; set; } + + [JsonPropertyName("input_schema")] + public InputSchema? InputSchema { get; set; } +} + +public record class InputSchema +{ + [JsonPropertyName("type")] + public required string Type { get; set; } + + [JsonPropertyName("properties")] + public Dictionary? Properties { get; set; } + + [JsonPropertyName("required")] + public string[]? Required { get; set; } +} + +public record class ToolProperty +{ + [JsonPropertyName("type")] + public required string Type { get; set; } + + [JsonPropertyName("enum")] + public string[]? Enum { get; set; } + + [JsonPropertyName("description")] + public required string Description { get; set; } +} \ No newline at end of file