From 27ec106c4ec457acbf737c3c7804c537ce164b01 Mon Sep 17 00:00:00 2001 From: neuecc Date: Fri, 5 Apr 2024 16:29:37 +0900 Subject: [PATCH 1/2] WIP tool function calling --- sandbox/ConsoleApp1/GeneratedMock.cs | 191 +++++++++++++++++------ sandbox/ConsoleApp1/Program.cs | 16 +- src/Claudia.FunctionGenerator/Emitter.cs | 78 +++++++++ src/Claudia/Anthropic.cs | 1 - src/Claudia/MessageRequest.cs | 55 +++++++ 5 files changed, 281 insertions(+), 60 deletions(-) diff --git a/sandbox/ConsoleApp1/GeneratedMock.cs b/sandbox/ConsoleApp1/GeneratedMock.cs index 2fd55bd..df09413 100644 --- a/sandbox/ConsoleApp1/GeneratedMock.cs +++ b/sandbox/ConsoleApp1/GeneratedMock.cs @@ -17,16 +17,17 @@ //using Claudia; //using System; +//using System.Collections.Generic; //using System.Linq; //using System.Text; //using System.Threading.Tasks; //using System.Xml.Linq; -//static partial class FunctionTools +//static partial class FunctionTools2 //{ // public const string SystemPrompt = @$" -//In this environment you have access to a set of tools you can use to answer the user's question. If there are multiple tags, please consolidate them into a single block. +//In this environment you have access to a set of tools you can use to answer the user's question. If your solution involves the use of multiple tools, please include multiple s within a single tag. Each step-by-step answer or tag is not required. Only a single tag should be returned at the beginning. //You may call them like this: // @@ -37,59 +38,50 @@ // ... // // +// +// $TOOL_NAME +// +// <$PARAMETER_NAME>$PARAMETER_VALUE +// ... +// +// +// ... // //Here are the tools available: //{PromptXml.ToolsAll} + +//Again, including multiple tags in the reply is prohibited. //"; // public static class PromptXml // { // public const string ToolsAll = @$" -//{Today} -//{Sum} +// +//{TimeOfDay} //{DoPairwiseArithmetic} +//{GetHtmlFromWeb} +// //"; - -// public const string Today = @" +// public const string TimeOfDay = @" // -// Today -// Date of target location. +// TimeOfDay +// Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London. // // -// timeZoneId +// timeZone // string -// TimeZone of localtion like 'Tokeyo Standard Time', 'Eastern Standard Time', etc. +// The time zone to get the current time for, such as UTC, US/Pacific, Europe/London. // // // //"; - -// public const string Sum = @" -// -// Sum -// Sum of two integer parameters. -// -// -// x -// int -// parameter1. -// -// -// y -// int -// parameter2. -// -// -// -//"; - // public const string DoPairwiseArithmetic = @" // // DoPairwiseArithmetic // Calculator function for doing basic arithmetic. -// Supports addition, subtraction, multiplication +// Supports addition, subtraction, multiplication // // // firstOperand @@ -109,8 +101,108 @@ // // //"; +// public const string GetHtmlFromWeb = @" +// +// GetHtmlFromWeb +// Retrieves the HTML from the specified URL. +// +// +// url +// string +// The URL to retrieve the HTML from. +// +// +// +//"; +// public static class Tools +// { +// public static readonly Tool TimeOfDay = new Tool +// { +// Name = "TimeOfDay", +// Description = "Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London.", +// InputSchema = new InputSchema +// { +// Type = "object", +// Properties = new Dictionary +// { +// { +// "timeZone", new ToolProperty() +// { +// Type = "string", +// Description = "Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London." +// } +// }, + +// }, +// Required = new[] { timeZone } +// } + +// }; +// public static readonly Tool DoPairwiseArithmetic = new Tool +// { +// Name = "DoPairwiseArithmetic", +// Description = "Calculator function for doing basic arithmetic. +// Supports addition, +// subtraction, +// multiplication", +// InputSchema = new InputSchema +// { +// Type = "object", +// Properties = new Dictionary +// { +// { +// "firstOperand", new ToolProperty() +// { +// Type = "double", +// Description = "Calculator function for doing basic arithmetic. +// Supports addition, subtraction, multiplication" +// }, +// }, +// { +// "secondOperand", new ToolProperty() +// { +// Type = "double", +// Description = "Calculator function for doing basic arithmetic. +// Supports addition, subtraction, multiplication" +// }, +// }, +// { +// "operator", new ToolProperty() +// { +// Type = "string", +// Description = "Calculator function for doing basic arithmetic. +// Supports addition, subtraction, multiplication" +// }, +// }, + +// }, +// Required = new[] { firstOperand, secondOperand, operator } +// } + +// }; +// public static readonly Tool GetHtmlFromWeb = new Tool +// { +// Name = "GetHtmlFromWeb", +// Description = "Retrieves the HTML from the specified URL.", +// InputSchema = new InputSchema +// { +// Type = "object", +// Properties = new Dictionary +// { +// { +// "url", new ToolProperty() +// { +// Type = "string", +// Description = "Retrieves the HTML from the specified URL." +// }, +// }, +// }, +// Required = new[] { url } +// } +// }; +// } // } //#pragma warning disable CS1998 @@ -118,9 +210,9 @@ // { // var content = message.Content.FirstOrDefault(x => x.Text != null); // if (content == null) return null; - + // var text = content.Text; -// var tagStart = text .IndexOf(""); +// var tagStart = text.IndexOf(""); // if (tagStart == -1) return null; // var functionCalls = text.Substring(tagStart) + ""; @@ -135,34 +227,33 @@ // var name = (string)item.Element("tool_name")!; // switch (name) // { -// case "Today": +// case "TimeOfDay": // { // var parameters = item.Element("parameters")!; -// var _0 = (string)parameters.Element("timeZoneId")!; +// var _0 = (string)parameters.Element("timeZone")!; -// BuildResult(sb, "Today", Today(_0)); +// BuildResult(sb, "TimeOfDay", TimeOfDay(_0)); // break; // } -// case "Sum": +// case "DoPairwiseArithmetic": // { // var parameters = item.Element("parameters")!; -// var _0 = (int)parameters.Element("x")!; -// var _1 = (int)parameters.Element("y")!; +// var _0 = (double)parameters.Element("firstOperand")!; +// var _1 = (double)parameters.Element("secondOperand")!; +// var _2 = (string)parameters.Element("operator")!; -// BuildResult(sb, "Sum", Sum(_0, _1)); +// BuildResult(sb, "DoPairwiseArithmetic", DoPairwiseArithmetic(_0, _1, _2)); // break; // } -// case "DoPairwiseArithmetic": +// case "GetHtmlFromWeb": // { // var parameters = item.Element("parameters")!; -// var _0 = (double)parameters.Element("firstOperand")!; -// var _1 = (double)parameters.Element("secondOperand")!; -// var _2 = (string)parameters.Element("operator")!; +// var _0 = (string)parameters.Element("url")!; -// BuildResult(sb, "DoPairwiseArithmetic", DoPairwiseArithmetic(_0, _1, _2)); +// BuildResult(sb, "GetHtmlFromWeb", await GetHtmlFromWeb(_0).ConfigureAwait(false)); // break; // } @@ -177,14 +268,10 @@ // static void BuildResult(StringBuilder sb, string toolName, T result) // { -// sb.AppendLine(@$" -// +// sb.AppendLine(@$" // {toolName} -// -// {result} -// -// -//"); +// {result} +// "); // } // } //#pragma warning restore CS1998 diff --git a/sandbox/ConsoleApp1/Program.cs b/sandbox/ConsoleApp1/Program.cs index d25a2ad..95475ee 100644 --- a/sandbox/ConsoleApp1/Program.cs +++ b/sandbox/ConsoleApp1/Program.cs @@ -46,26 +46,28 @@ - +anthropic.HttpClient.DefaultRequestHeaders.Add("anthropic-beta", "tools-2024-04-04"); var input = new Message { Role = Roles.User, - Content = """ - What time is it in Seattle and Tokyo? - Incidentally multiply 1,984,135 by 9,343,116. -""" + Content = "What time is it in Tokyo?", }; var message = await anthropic.Messages.CreateAsync(new() { Model = Models.Claude3Haiku, MaxTokens = 1024, - System = FunctionTools.SystemPrompt, // set generated prompt - StopSequences = [StopSequnces.CloseFunctionCalls], // set as stop sequence + + //System = FunctionTools.SystemPrompt, // set generated prompt + //StopSequences = [StopSequnces.CloseFunctionCalls], // set as stop sequence + + Tools = [FunctionTools.PromptXml.Tools.TimeOfDay], Messages = [input], }); + + var partialAssistantMessage = await FunctionTools.InvokeAsync(message); var callResult = await anthropic.Messages.CreateAsync(new() diff --git a/src/Claudia.FunctionGenerator/Emitter.cs b/src/Claudia.FunctionGenerator/Emitter.cs index ab14b25..2986475 100644 --- a/src/Claudia.FunctionGenerator/Emitter.cs +++ b/src/Claudia.FunctionGenerator/Emitter.cs @@ -1,6 +1,7 @@ using Microsoft.CodeAnalysis; using System.Text; using System.Xml.Linq; +using static Claudia.FunctionGenerator.Emitter; namespace Claudia.FunctionGenerator; @@ -88,6 +89,8 @@ public static class PromptXml EmitToolDescription(method); } + EmitTools(parseResult); + sb.AppendLine(" }"); // close PromptXml sb.AppendLine(); @@ -206,6 +209,81 @@ static void BuildResult(StringBuilder sb, string toolName, T result) sb.AppendLine(code); } + void EmitTools(ParseResult parseResult) + { + // TODO: Add All + // public static readonly Tool[] All = @$" + sb.AppendLine($$""" + public static class Tools + { +"""); + + // Emit Tool + foreach (var method in parseResult.Methods) + { + var docComment = method.Syntax.GetDocumentationCommentTriviaSyntax()!; + var description = docComment.GetSummary().Replace("\"", "'").Replace("\r\n", " ").Replace("\n", " "); + + // property + var inputSchema = new StringBuilder(); + if (method.Symbol.Parameters.Length != 0) + { + var propBuilder = new StringBuilder(); + var paramRequired = new List(); + foreach (var p in docComment.GetParams()) + { + var paramDescription = p.Description.Replace("\"", "'").Replace("\r\n", " ").Replace("\n", " "); + + // type retrieve from method symbol + var name = p.Name; + var paramType = method.Symbol.Parameters.First(x => x.Name == name).Type.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + + // TODO: support enum + propBuilder.AppendLine($$""" + { + "{{name}}", new ToolProperty() + { + Type = "{{paramType}}", + Description = "{{description}}" + } + }, +"""); + // TODO: support optional parameter + paramRequired.Add("\"" + name + "\""); + } + var required = string.Join(", ", paramRequired); + if (required.Length != 0) + { + required = "Required = new [] { " + required + " }"; + } + + inputSchema.AppendLine($$""" + InputSchema = new InputSchema + { + Type = "object", + Properties = new System.Collections.Generic.Dictionary + { +{{propBuilder}} + }, + {{required}} + } +"""); + } + + sb.AppendLine($$""" + public static readonly Tool {{method.Name}} = new Tool + { + Name = "{{method.Name}}", + Description = "{{description}}", +{{inputSchema}} + }; + +"""); + } + + sb.AppendLine(" }"); // close Tools + } + static void AddSource(SourceProductionContext context, ISymbol targetSymbol, string code, string fileExtension = ".g.cs") { var fullType = targetSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) diff --git a/src/Claudia/Anthropic.cs b/src/Claudia/Anthropic.cs index 250647a..1907293 100644 --- a/src/Claudia/Anthropic.cs +++ b/src/Claudia/Anthropic.cs @@ -68,7 +68,6 @@ async Task IMessages.CreateAsync(MessageRequest request, Reques { request.Stream = null; using var msg = await SendRequestAsync(request, overrideOptions, cancellationToken).ConfigureAwait(ConfigureAwait); - var result = await RequestWithAsync(msg, cancellationToken, overrideOptions, static (x, ct, _) => x.Content.ReadFromJsonAsync(AnthropicJsonSerialzierContext.Default.Options, ct), null).ConfigureAwait(ConfigureAwait); return result!; } diff --git a/src/Claudia/MessageRequest.cs b/src/Claudia/MessageRequest.cs index ffc2b3e..94b46d5 100644 --- a/src/Claudia/MessageRequest.cs +++ b/src/Claudia/MessageRequest.cs @@ -90,6 +90,11 @@ public override string ToString() { return JsonSerializer.Serialize(this, AnthropicJsonSerialzierContext.Default.Options); } + + // 2024-04-04 beta: https://docs.anthropic.com/claude/docs/tool-use + [JsonPropertyName("tools")] + public Tool[]? Tools { get; set; } + } public record class Message { @@ -145,6 +150,19 @@ public record class Content [JsonPropertyName("source")] public Source? Source { get; set; } + #region tool_use response + + [JsonPropertyName("id")] + public string? Id { get; set; } + + [JsonPropertyName("name")] + public string? Name { get; set; } + + [JsonPropertyName("input")] + public Dictionary? Input { get; set; } + + #endregion + public static implicit operator Content(string text) => new Content(text); public Content() @@ -214,3 +232,40 @@ public record class Source [JsonPropertyName("data")] public required ReadOnlyMemory Data { get; set; } // Base64 } + +// https://docs.anthropic.com/claude/docs/tool-use +public record class Tool +{ + [JsonPropertyName("name")] + public required string Name { get; set; } + + [JsonPropertyName("description")] + public required string Description { get; set; } + + [JsonPropertyName("input_schema")] + public InputSchema? InputSchema { get; set; } +} + +public record class InputSchema +{ + [JsonPropertyName("type")] + public required string Type { get; set; } + + [JsonPropertyName("properties")] + public Dictionary? Properties { get; set; } + + [JsonPropertyName("required")] + public string[]? Required { get; set; } +} + +public record class ToolProperty +{ + [JsonPropertyName("type")] + public required string Type { get; set; } + + [JsonPropertyName("enum")] + public string[]? Enum { get; set; } + + [JsonPropertyName("description")] + public required string Description { get; set; } +} \ No newline at end of file From 7e68967a976729a3589f73e4f74b130e9aff7d79 Mon Sep 17 00:00:00 2001 From: neuecc Date: Sat, 6 Apr 2024 02:50:39 +0900 Subject: [PATCH 2/2] done tools --- README.md | 102 +++++- sandbox/ConsoleApp1/GeneratedMock.cs | 263 ++------------ sandbox/ConsoleApp1/Program.cs | 336 +++--------------- sandbox/ConsoleApp1/ToolUseSamples.cs | 108 ++++++ src/Claudia.FunctionGenerator/Emitter.cs | 306 +++++++++++++++- src/Claudia.FunctionGenerator/Parser.cs | 4 + src/Claudia/AnthropicJsonSerialzierContext.cs | 3 + src/Claudia/Constants.cs | 4 + src/Claudia/MessageRequest.cs | 49 ++- src/Claudia/MessageResponse.cs | 6 +- 10 files changed, 635 insertions(+), 546 deletions(-) create mode 100644 sandbox/ConsoleApp1/ToolUseSamples.cs diff --git a/README.md b/README.md index 7dc1bb3..e5b294b 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Unofficial [Anthropic Claude API](https://www.anthropic.com/api) client for .NET We have built a C# API similar to the official [Python SDK](https://github.com/anthropics/anthropic-sdk-python) and [TypeScript SDK](https://github.com/anthropics/anthropic-sdk-typescript). It supports netstandard2.1, net6.0, and net8.0. -In addition to the pure client SDK, it also includes a C# Source Generator for performing Function Calling, similar to [anthropic-tools](https://github.com/anthropics/anthropic-tools/). +In addition to the pure client SDK, it also includes a C# Source Generator for performing Function Calling. Installation --- @@ -494,15 +494,107 @@ void Load() Function Calling --- -Claude supports Function Calling. The [Anthropic Cookbook](https://github.com/anthropics/anthropic-cookbook) provides examples of Function Calling. To achieve this, complex XML generation and parsing processing, as well as execution based on the parsed results, are required. +Claude supports Function Calling. -With Claudia, you only need to define static methods annotated with `[ClaudiaFunction]`, and the C# Source Generator automatically generates the necessary code, including parsers and system messages. +## TOol use + +[Tool use(function calling)](https://docs.anthropic.com/claude/docs/tool-use) is new style of function calling. Currently it is beta and need to add `anthropic-beta` flag in header. + +```csharp +var anthropic = new Anthropic(); +anthropic.HttpClient.DefaultRequestHeaders.Add("anthropic-beta", "tools-2024-04-04"); +``` + +With Claudia, you only need to define static methods annotated with `[ClaudiaFunction]`, and the C# Source Generator automatically generates the necessary code. + +```csharp +public static partial class FunctionTools +{ + /// + /// Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London. + /// + /// The time zone to get the current time for, such as UTC, US/Pacific, Europe/London. + [ClaudiaFunction] + public static string TimeOfDay(string timeZone) + { + var time = TimeZoneInfo.ConvertTimeBySystemTimeZoneId(DateTime.UtcNow, timeZone); + return time.ToString("HH:mm:ss"); + } +} +``` + +The `partial class` includes the generated `.AllTools`, `.Tools.[Methods]` and `.InvokeToolAsync(MessageResponse)`. + +Function Calling requires two requests to Claude. The flow is as follows: "Initial request to Claude with available tools in System Prompt -> Execute functions based on the message containing the necessary tools -> Include the results in a new message and send another request to Claude." + +```csharp +var anthropic = new Anthropic(); +anthropic.HttpClient.DefaultRequestHeaders.Add("anthropic-beta", "tools-2024-04-04"); + +var input = new Message { Role = Roles.User, Content = "What time is it in Los Angeles?" }; +var message = await anthropic.Messages.CreateAsync(new() +{ + Model = Models.Claude3Haiku, + MaxTokens = 1024, + Tools = FunctionTools.AllTools, // use generated Tools + Messages = [input], +}); + +// invoke local function +var toolResult = await FunctionTools.InvokeToolAsync(message); + +var response = await anthropic.Messages.CreateAsync(new() +{ + Model = Models.Claude3Haiku, + MaxTokens = 1024, + Tools = [ToolUseSamples.Tools.Calculator], + Messages = [ + input, + new() { Role = Roles.Assistant, Content = message.Content }, + new() { Role = Roles.User, Content = toolResult! } + ], +}); + +// The current time in Los Angeles is 10:45 AM. +Console.WriteLine(response.Content.ToString()); +``` + +The return type of `ClaudiaFunction` can also be specified as `Task` or `ValueTask`. This allows you to execute a variety of tasks, such as HTTP requests or database requests. For example, a function that retrieves the content of a specified webpage can be defined as shown above. ```csharp public static partial class FunctionTools { - // Sample of anthropic-tools https://github.com/anthropics/anthropic-tools#basetool + // ... + /// + /// Retrieves the HTML from the specified URL. + /// + /// The URL to retrieve the HTML from. + [ClaudiaFunction] + static async Task GetHtmlFromWeb(string url) + { + // When using this in a real-world application, passing the raw HTML might consume too many tokens. + // You can parse the HTML locally using libraries like AngleSharp and convert it into a compact text structure to save tokens. + using var client = new HttpClient(); + return await client.GetStringAsync(url); + } +} +``` + +Note that the allowed parameter types are `bool`, `sbyte`, `byte`, `short`, `ushort`, `int`, `uint`, `long`, `ulong`, `decimal`, `float`, `double`, `string`, `DateTime`, `DateTimeOffset`, `Guid`, `TimeSpan` and `Enum`. + +The return value can be of any type, but it will be converted to a string using `ToString()`. If you want to return a custom string, make the return type `string` and format the string within the function. + + +## Legacy style + +The [Anthropic Cookbook](https://github.com/anthropics/anthropic-cookbook) provides examples of Function Calling. To achieve this, complex XML generation and parsing processing, as well as execution based on the parsed results, are required. + +With Claudia, you only need to define static methods annotated with `[ClaudiaFunction]`, and the C# Source Generator automatically generates the necessary code, including parsers and system messages. + +```csharp +public static partial class FunctionTools +{ /// /// Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London. /// @@ -739,7 +831,7 @@ var callResult = await anthropic.Messages.CreateAsync(new() Console.WriteLine(callResult); ``` -Note that the allowed parameter types are `bool`, `sbyte`, `byte`, `short`, `ushort`, `int`, `uint`, `long`, `ulong`, `decimal`, `float`, `double`, `string`, `DateTime`, `DateTimeOffset`, `Guid`, and `TimeSpan`. +Note that the allowed parameter types are `bool`, `sbyte`, `byte`, `short`, `ushort`, `int`, `uint`, `long`, `ulong`, `decimal`, `float`, `double`, `string`, `DateTime`, `DateTimeOffset`, `Guid`, `TimeSpan` and `Enum`. The return value can be of any type, but it will be converted to a string using `ToString()`. If you want to return a custom string, make the return type `string` and format the string within the function. diff --git a/sandbox/ConsoleApp1/GeneratedMock.cs b/sandbox/ConsoleApp1/GeneratedMock.cs index df09413..4ae9b9d 100644 --- a/sandbox/ConsoleApp1/GeneratedMock.cs +++ b/sandbox/ConsoleApp1/GeneratedMock.cs @@ -20,260 +20,73 @@ //using System.Collections.Generic; //using System.Linq; //using System.Text; +//using System.Text.Json; //using System.Threading.Tasks; //using System.Xml.Linq; -//static partial class FunctionTools2 +//static partial class FunctionTools //{ -// public const string SystemPrompt = @$" -//In this environment you have access to a set of tools you can use to answer the user's question. If your solution involves the use of multiple tools, please include multiple s within a single tag. Each step-by-step answer or tag is not required. Only a single tag should be returned at the beginning. -//You may call them like this: -// -// -// $TOOL_NAME -// -// <$PARAMETER_NAME>$PARAMETER_VALUE -// ... -// -// -// -// $TOOL_NAME -// -// <$PARAMETER_NAME>$PARAMETER_VALUE -// ... -// -// -// ... -// -//Here are the tools available: -//{PromptXml.ToolsAll} - -//Again, including multiple tags in the reply is prohibited. -//"; - -// public static class PromptXml +//#pragma warning disable CS1998 +// public static async ValueTask InvokeToolAsync(MessageResponse message) // { -// public const string ToolsAll = @$" -// -//{TimeOfDay} -//{DoPairwiseArithmetic} -//{GetHtmlFromWeb} -// -//"; -// public const string TimeOfDay = @" -// -// TimeOfDay -// Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London. -// -// -// timeZone -// string -// The time zone to get the current time for, such as UTC, US/Pacific, Europe/London. -// -// -// -//"; -// public const string DoPairwiseArithmetic = @" -// -// DoPairwiseArithmetic -// Calculator function for doing basic arithmetic. -// Supports addition, subtraction, multiplication -// -// -// firstOperand -// double -// First operand (before the operator) -// -// -// secondOperand -// double -// Second operand (after the operator) -// -// -// operator -// string -// The operation to perform. Must be either +, -, *, or / -// -// -// -//"; -// public const string GetHtmlFromWeb = @" -// -// GetHtmlFromWeb -// Retrieves the HTML from the specified URL. -// -// -// url -// string -// The URL to retrieve the HTML from. -// -// -// -//"; -// public static class Tools -// { -// public static readonly Tool TimeOfDay = new Tool -// { -// Name = "TimeOfDay", -// Description = "Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London.", -// InputSchema = new InputSchema -// { -// Type = "object", -// Properties = new Dictionary -// { -// { -// "timeZone", new ToolProperty() -// { -// Type = "string", -// Description = "Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London." -// } -// }, +// var result = new List(); -// }, -// Required = new[] { timeZone } -// } +// foreach (var item in message.Content) +// { +// if (item.Type != ContentTypes.ToolUse) continue; -// }; -// public static readonly Tool DoPairwiseArithmetic = new Tool +// switch (item.ToolUseName) // { -// Name = "DoPairwiseArithmetic", -// Description = "Calculator function for doing basic arithmetic. -// Supports addition, -// subtraction, -// multiplication", -// InputSchema = new InputSchema -// { -// Type = "object", -// Properties = new Dictionary -// { -// { -// "firstOperand", new ToolProperty() -// { -// Type = "double", -// Description = "Calculator function for doing basic arithmetic. -// Supports addition, subtraction, multiplication" -// }, -// }, +// case "TimeOfDay": // { -// "secondOperand", new ToolProperty() +// // if (!item.ToolUseInput.TryGetValue("timeZone", out var _0)) _0 = default; +// var _0 = GetValueOrDefault(item, "timeZone", default!); +// string? _callResult; +// bool? _isError = null; +// try // { -// Type = "double", -// Description = "Calculator function for doing basic arithmetic. -// Supports addition, subtraction, multiplication" -// }, -// }, -// { -// "operator", new ToolProperty() +// _callResult = TimeOfDay(_0).ToString(); +// } +// catch (Exception ex) // { -// Type = "string", -// Description = "Calculator function for doing basic arithmetic. -// Supports addition, subtraction, multiplication" -// }, -// }, - -// }, -// Required = new[] { firstOperand, secondOperand, operator } -// } +// _callResult = ex.Message; +// _isError = true; +// } -// }; -// public static readonly Tool GetHtmlFromWeb = new Tool -// { -// Name = "GetHtmlFromWeb", -// Description = "Retrieves the HTML from the specified URL.", -// InputSchema = new InputSchema -// { -// Type = "object", -// Properties = new Dictionary -// { -// { -// "url", new ToolProperty() +// result.Add(new Content // { -// Type = "string", -// Description = "Retrieves the HTML from the specified URL." -// }, -// }, - -// }, -// Required = new[] { url } -// } - -// }; -// } -// } - -//#pragma warning disable CS1998 -// public static async ValueTask InvokeAsync(MessageResponse message) -// { -// var content = message.Content.FirstOrDefault(x => x.Text != null); -// if (content == null) return null; - -// var text = content.Text; -// var tagStart = text.IndexOf(""); -// if (tagStart == -1) return null; +// Type = ContentTypes.ToolResult, +// ToolUseId = item.ToolUseId, +// ToolResultContent = _callResult, +// ToolResultIsError = _isError +// }); -// var functionCalls = text.Substring(tagStart) + ""; -// var xmlResult = XElement.Parse(functionCalls); - -// var sb = new StringBuilder(); -// sb.AppendLine(functionCalls); -// sb.AppendLine(""); - -// foreach (var item in xmlResult.Elements("invoke")) -// { -// var name = (string)item.Element("tool_name")!; -// switch (name) -// { -// case "TimeOfDay": -// { -// var parameters = item.Element("parameters")!; - -// var _0 = (string)parameters.Element("timeZone")!; - -// BuildResult(sb, "TimeOfDay", TimeOfDay(_0)); -// break; -// } -// case "DoPairwiseArithmetic": -// { -// var parameters = item.Element("parameters")!; - -// var _0 = (double)parameters.Element("firstOperand")!; -// var _1 = (double)parameters.Element("secondOperand")!; -// var _2 = (string)parameters.Element("operator")!; - -// BuildResult(sb, "DoPairwiseArithmetic", DoPairwiseArithmetic(_0, _1, _2)); // break; // } -// case "GetHtmlFromWeb": -// { -// var parameters = item.Element("parameters")!; - -// var _0 = (string)parameters.Element("url")!; - -// BuildResult(sb, "GetHtmlFromWeb", await GetHtmlFromWeb(_0).ConfigureAwait(false)); -// break; -// } - // default: // break; // } // } -// sb.Append(""); // final assistant content cannot end with trailing whitespace +// return result.ToArray(); -// return sb.ToString(); - -// static void BuildResult(StringBuilder sb, string toolName, T result) +// static T GetValueOrDefault(Content content, string name, T defaultValue) // { -// sb.AppendLine(@$" -// {toolName} -// {result} -// "); +// if (content.ToolUseInput.TryGetValue(name, out var stringValue)) +// { +// return System.Text.Json.JsonSerializer.Deserialize(stringValue)!; +// } +// else +// { +// return defaultValue; +// } // } // } + //#pragma warning restore CS1998 //} diff --git a/sandbox/ConsoleApp1/Program.cs b/sandbox/ConsoleApp1/Program.cs index 95475ee..a42c24e 100644 --- a/sandbox/ConsoleApp1/Program.cs +++ b/sandbox/ConsoleApp1/Program.cs @@ -1,332 +1,52 @@ using Claudia; +using ConsoleApp1; using System; +using System.Collections.Specialized; +using System.Data; using System.Linq; using System.Net.Http; using System.Net.NetworkInformation; +using System.Runtime.CompilerServices; using System.Text; +using System.Text.Json; using System.Threading.Tasks; using System.Xml; using System.Xml.Linq; -// function calling -// https://github.com/anthropics/anthropic-cookbook/blob/main/function_calling/function_calling.ipynb - -var anthropic = new Anthropic(); - -//var userInput = """ -//Translate and summarize this Japanese site to English. -//https://scrapbox.io/hadashiA/ZLogger_v2%E3%81%AE%E6%96%B0%E3%82%B9%E3%83%88%E3%83%A9%E3%82%AF%E3%83%81%E3%83%A3%E3%83%BC%E3%83%89%E3%83%AD%E3%82%AE%E3%83%B3%E3%82%B0%E4%BD%93%E9%A8%93 -//"""; - -//var message = await anthropic.Messages.CreateAsync(new() -//{ -// Model = Models.Claude3Haiku, -// MaxTokens = 1024, -// System = SystemPrompts.Claude3 + "\n" + FunctionTools.SystemPrompt, -// StopSequences = [StopSequnces.CloseFunctionCalls], -// Messages = [ -// new() { Role = Roles.User, Content = userInput }, -// ], -//}); - -//var partialAssistantMessage = await FunctionTools.InvokeAsync(message); - -//var callResult = await anthropic.Messages.CreateAsync(new() -//{ -// Model = Models.Claude3Haiku, -// MaxTokens = 1024, -// System = SystemPrompts.Claude3 + "\n" + FunctionTools.SystemPrompt + "\n" + "Return message from assistant should be humanreadable so don't use xml tags, and json.", -// Messages = [ -// new() { Role = Roles.User, Content = userInput }, -// new() { Role = Roles.Assistant, Content = partialAssistantMessage! }, -// ], -//}); - -//Console.WriteLine(callResult); - +var anthropic = new Anthropic(); anthropic.HttpClient.DefaultRequestHeaders.Add("anthropic-beta", "tools-2024-04-04"); -var input = new Message -{ - Role = Roles.User, - Content = "What time is it in Tokyo?", -}; - +var input = new Message { Role = Roles.User, Content = "What time is it in Los Angeles?" }; var message = await anthropic.Messages.CreateAsync(new() { Model = Models.Claude3Haiku, MaxTokens = 1024, - - //System = FunctionTools.SystemPrompt, // set generated prompt - //StopSequences = [StopSequnces.CloseFunctionCalls], // set as stop sequence - - Tools = [FunctionTools.PromptXml.Tools.TimeOfDay], + Tools = FunctionTools.AllTools, // use generated Tools Messages = [input], }); +var toolResult = await FunctionTools.InvokeToolAsync(message); - -var partialAssistantMessage = await FunctionTools.InvokeAsync(message); - -var callResult = await anthropic.Messages.CreateAsync(new() +var response = await anthropic.Messages.CreateAsync(new() { Model = Models.Claude3Haiku, MaxTokens = 1024, - System = FunctionTools.SystemPrompt, + Tools = [ToolUseSamples.Tools.Calculator], Messages = [ input, - new() { Role = Roles.Assistant, Content = partialAssistantMessage! } // set as Assistant + new() { Role = Roles.Assistant, Content = message.Content }, + new() { Role = Roles.User, Content = toolResult! } ], }); -Console.WriteLine(callResult); - - -//var systemPrompt = """ -//In this environment you have access to a set of tools you can use to answer the user's question. - -//You may call them like this: -// -// -// $TOOL_NAME -// -// <$PARAMETER_NAME>$PARAMETER_VALUE -// ... -// -// -// - -//Here are the tools available: -// -// -// calculator -// -// Calculator function for doing basic arithmetic. -// Supports addition, subtraction, multiplication -// -// -// -// first_operand -// int -// First operand (before the operator) -// -// -// second_operand -// int -// Second operand (after the operator) -// -// -// operator -// str -// The operation to perform. Must be either +, -, *, or / -// -// -// -// -//"""; - - -//var message = await anthropic.Messages.CreateAsync(new() -//{ -// Model = Models.Claude3Opus, -// MaxTokens = 1024, -// System = systemPrompt, -// StopSequences = ["\n\nHuman:", "\n\nAssistant", ""], -// Messages = [new() { Role = "user", Content = "Multiply 1,984,135 by 9,343,116" }], -//}); - -//// Result XML:: - - -//var text = message.Content[0].Text!; -//var tagStart = text.IndexOf(""); -//var xmlResult = XElement.Parse(text.Substring(tagStart) + message.StopSequence); -//var parameters = xmlResult.Descendants("parameters").Elements(); - -//var first = (double)parameters.First(x => x.Name == "first_operand"); -//var second = (double)parameters.First(x => x.Name == "second_operand"); -//var operation = (string)parameters.First(x => x.Name == "operator"); - -//var result = DoPairwiseArithmetic(first, second, operation); - -//Console.WriteLine(result); - - - - - - - - - - - - - - - - -//var imageBytes = File.ReadAllBytes(@"dish.jpg"); - -//var anthropic = new Anthropic(); - -//var message = await anthropic.Messages.CreateAsync(new() -//{ -// Model = "claude-3-opus-20240229", -// MaxTokens = 1024, -// Messages = [new() -// { -// Role = "user", -// Content = [ -// new() -// { -// Type = "image", -// Source = new() -// { -// Type = "base64", -// MediaType = "image/jpeg", -// Data = imageBytes -// } -// }, -// new() -// { -// Type = "text", -// Text = "Describe this image." -// } -// ] -// }], -//}); -//Console.WriteLine(message); - -//var simple = await anthropic.Messages.CreateAsync(new() -//{ -// Model = Models.Claude3Opus, -// MaxTokens = 1024, -// Messages = [new() -// { -// Role = Roles.User, -// Content = [ -// new(imageBytes, "image/jpeg"), -// "Describe this image." -// ] -// }], -//}); -//Console.WriteLine(simple); - -//// convert to array. -//var array = await stream.ToObservable().ToArrayAsync(); - -//// filterling and execute. -//await stream.ToObservable() -// .OfType() -// .Where(x => x.Delta.Text != null) -// .ForEachAsync(x => -// { -// Console.WriteLine(x.Delta.Text); -// }); - -//// branching query -//var branch = stream.ToObservable().Publish(); - -//var messageStartTask = branch.OfType().FirstAsync(); -//var messageDeltaTask = branch.OfType().FirstAsync(); - -//branch.Connect(); // start consume stream - -//Console.WriteLine((await messageStartTask)); -//Console.WriteLine((await messageDeltaTask)); - - - - - - - -//Console.WriteLine("---"); - -//Console.WriteLine(sb.ToString()); - -// Counting Tokens -//var anthropic = new Anthropic(); - -//var msg = await anthropic.Messages.CreateAsync(new() -//{ -// Model = Models.Claude3Opus, -// MaxTokens = 1024, -// Messages = [new() { Role = "user", Content = "Hello, Claude." }] -//}); - -//// Usage { InputTokens = 11, OutputTokens = 18 } -//Console.WriteLine(msg.Usage); - - - -//Messages = [new() { Role = "user", Content = "Hello, Claude. Responses, please break line after each word." }] - - -//// error - -//try -//{ -// var msg = await anthropic.Messages.CreateAsync(new() -// { -// Model = Models.Claude3Opus, -// MaxTokens = 1024, -// Messages = [new() { Role = "user", Content = "Hello, Claude" }] -// }); -//} -//catch (ClaudiaException ex) -//{ -// Console.WriteLine((int)ex.Status); // 400(ErrorCode.InvalidRequestError) -// Console.WriteLine(ex.Name); // invalid_request_error -// Console.WriteLine(ex.Message); // Field required. Input:... -//} - -// retry - -// Configure the default for all requests: -//var anthropic = new Anthropic -//{ -// MaxRetries = 0, // default is 2 -//}; - -//// Or, configure per-request: -//await anthropic.Messages.CreateAsync(new() -//{ -// MaxTokens = 1024, -// Messages = [new() { Role = "user", Content = "Hello, Claude" }], -// Model = "claude-3-opus-20240229" -//}, new() -//{ -// MaxRetries = 5 -//}); - -// timeout - -//// Configure the default for all requests: -//var anthropic = new Anthropic -//{ -// Timeout = TimeSpan.FromSeconds(20) // 20 seconds (default is 10 minutes) -//}; - -//// Override per-request: -//await anthropic.Messages.CreateAsync(new() -//{ -// MaxTokens = 1024, -// Messages = [new() { Role = "user", Content = "Hello, Claude" }], -// Model = "claude-3-opus-20240229" -//}, new() -//{ -// Timeout = TimeSpan.FromSeconds(5) -//}); +// The current time in Los Angeles is 10:45 AM. +Console.WriteLine(response.Content.ToString()); public static partial class FunctionTools { - // Sample of anthropic-tools https://github.com/anthropics/anthropic-tools#basetool - /// /// Retrieve the current time of day in Hour-Minute-Second format for a specified time zone. Time zones should be written in standard formats such as UTC, US/Pacific, Europe/London. /// @@ -370,7 +90,33 @@ static async Task GetHtmlFromWeb(string url) using var client = new HttpClient(); return await client.GetStringAsync(url); } + + /// + /// Sum of two parameters. + /// + /// x. + /// y. + [ClaudiaFunction] + static int Sum(int x, int y = 100) + { + return x + y; + } + + /// + /// Choose which fruits + /// + /// Fruits basket. + /// Fruits basket2. + [ClaudiaFunction] + static string ChooseFruit(Fruits basket, Fruits more = Fruits.Grape) + { + return basket.ToString(); + } } +public enum Fruits +{ + Orange, Grape +} diff --git a/sandbox/ConsoleApp1/ToolUseSamples.cs b/sandbox/ConsoleApp1/ToolUseSamples.cs new file mode 100644 index 0000000..a97de88 --- /dev/null +++ b/sandbox/ConsoleApp1/ToolUseSamples.cs @@ -0,0 +1,108 @@ +using Claudia; +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Threading.Tasks; + +namespace ConsoleApp1; + +public static partial class ToolUseSamples +{ + // https://github.com/anthropics/anthropic-cookbook/blob/main/tool_use/calculator_tool.ipynb + public static async Task TryCalculateAsync() + { + var anthropic = new Anthropic(); + anthropic.HttpClient.DefaultRequestHeaders.Add("anthropic-beta", "tools-2024-04-04"); + + await ChatWithClaude("What is the result of 1,984,135 * 9,343,116?"); + await ChatWithClaude("Calculate (12851 - 593) * 301 + 76"); + await ChatWithClaude("What is 15910385 divided by 193053?"); + + async Task ChatWithClaude(string userMessage) + { + Console.WriteLine("=================================================="); + Console.WriteLine($"User Message: {userMessage}"); + Console.WriteLine("=================================================="); + + var message = await anthropic.Messages.CreateAsync(new() + { + Model = Models.Claude3Haiku, + MaxTokens = 1024, + Tools = [ToolUseSamples.Tools.Calculator], + Messages = [new() { Role = Roles.User, Content = userMessage }] + }); + + Console.WriteLine("Initial Response:"); + Console.WriteLine($"Stop Reason: {message.StopReason}"); + Console.WriteLine($"Content: {message.Content}"); + + var toolResult = await ToolUseSamples.InvokeToolAsync(message); + + var response = await anthropic.Messages.CreateAsync(new() + { + Model = Models.Claude3Haiku, + MaxTokens = 1024, + Tools = [ToolUseSamples.Tools.Calculator], + Messages = [ + new() { Role = Roles.User, Content = userMessage }, + new() { Role = Roles.Assistant, Content = message.Content }, + new() { Role = Roles.User, Content = toolResult! } + ], + }); + + Console.WriteLine(response.Content.ToString()); + } + } + + + /// + /// A simple calculator that performs basic arithmetic operations. + /// + /// The mathematical expression to evaluate (e.g., '2 + 3 * 4'). + [ClaudiaFunction] + static double Calculator(string expression) + { + // cheap calculator, only calc 32bit. + var dt = new DataTable(); + return Convert.ToDouble(dt.Compute(expression, "")); + } + + // https://github.com/anthropics/anthropic-cookbook/blob/main/tool_use/customer_service_agent.ipynb + + + + ///// + ///// Retrieves customer information based on their customer ID. Returns the customer's name, email, and phone number. + ///// + ///// The unique identifier for the customer. + //static string GetCustomerInfo(string customerId) + //{ + // // Simulated customer data + // var customers = new Dictionary { + // { "C1", new Customer(name: "John Doe", email: "john@example.com", phone: "123-456-7890") }, + // { "C2", new Customer(name: "Jane Smith", email: "jane@example.com", phone: "987-654-3210") }, + // }; + + // return customers.TryGetValue(customerId, out var customer) ? customer.ToString() : "Customer not found"; + //} + + + ///// + ///// Retrieves the details of a specific order based on the order ID. Returns the order ID, product name, quantity, price, and order status. + ///// + ///// Retrieves the details of a specific order based on the order ID. Returns the order ID, product name, quantity, price, and order status. + //static string GetOrderDetails(string orderId) + //{ + //} + + ///// + ///// The unique identifier for the order to be cancelled. + ///// + ///// Cancels an order based on the provided order ID. Returns a confirmation message if the cancellation is successful. + //static string CancelOrder(string orderId) + //{ + //} + + //record class Customer(string name, string email, string phone); +} diff --git a/src/Claudia.FunctionGenerator/Emitter.cs b/src/Claudia.FunctionGenerator/Emitter.cs index 2986475..6e66471 100644 --- a/src/Claudia.FunctionGenerator/Emitter.cs +++ b/src/Claudia.FunctionGenerator/Emitter.cs @@ -33,7 +33,13 @@ internal void Emit() sb.AppendLine("{"); sb.AppendLine(""); - EmitCore(parseResult); + // new, beta: https://docs.anthropic.com/claude/docs/tool-use + EmitToolCallingCore(parseResult); + + sb.AppendLine(); + + // legacy + EmitXmlCallingCore(parseResult); sb.AppendLine("}"); @@ -41,7 +47,13 @@ internal void Emit() } } - void EmitCore(ParseResult parseResult) + void EmitToolCallingCore(ParseResult parseResult) + { + EmitTools(parseResult); + EmitToolInvoke(parseResult); + } + + void EmitXmlCallingCore(ParseResult parseResult) { var toolsAll = string.Join(Environment.NewLine, parseResult.Methods.Select(x => "{" + x.Name + "}")); @@ -89,14 +101,216 @@ public static class PromptXml EmitToolDescription(method); } - EmitTools(parseResult); - sb.AppendLine(" }"); // close PromptXml sb.AppendLine(); EmitInvoke(parseResult); } + void EmitToolInvoke(ParseResult parseResult) + { + var methodInvoke = BuildToolLocalMethodInvoke(parseResult.Methods); + + var code = $$"""" +#pragma warning disable CS1998 + public static async ValueTask InvokeToolAsync(MessageResponse message) + { + var result = new Contents(); + + foreach (var item in message.Content) + { + if (item.Type != ContentTypes.ToolUse) continue; + + switch (item.ToolUseName) + { +{{methodInvoke}} + default: + break; + } + } + + return result; + + static T GetValueOrDefault(Content content, string name, T defaultValue) + { + if (content.ToolUseInput!.TryGetValue(name, out var stringValue)) + { + if (typeof(T) == typeof(Boolean)) + { + var v = bool.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(SByte)) + { + var v = SByte.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(Byte)) + { + var v = Byte.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(Int16)) + { + var v = Int16.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(UInt16)) + { + var v = UInt16.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(Int32)) + { + var v = Int32.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(UInt32)) + { + var v = UInt32.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(Int64)) + { + var v = Int64.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(UInt64)) + { + var v = UInt64.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(Decimal)) + { + var v = Decimal.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(Single)) + { + var v = Single.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(Double)) + { + var v = Double.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(String)) + { + return (T)(object)stringValue; + } + else if (typeof(T) == typeof(DateTime)) + { + var v = DateTime.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(DateTimeOffset)) + { + var v = DateTimeOffset.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(Guid)) + { + var v = Guid.Parse(stringValue); + return Unsafe.As(ref v); + } + else if (typeof(T) == typeof(TimeSpan)) + { + var v = TimeSpan.Parse(stringValue); + return Unsafe.As(ref v); + } + else + { + if (typeof(T).IsEnum) + { + return (T)Enum.Parse(typeof(T), stringValue); + } + throw new NotSupportedException(); + } + } + else + { + return defaultValue; + } + } + } +#pragma warning restore CS1998 +""""; + + sb.AppendLine(code); + } + + string BuildToolLocalMethodInvoke(Method[] methods) + { + var sb = new StringBuilder(); + + foreach (var method in methods) + { + var returnType = method.Symbol.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var isTask = (returnType.StartsWith("global::System.Threading.Tasks.Task") || returnType.StartsWith("global::System.Threading.Tasks.ValueTask")); + + var i = 0; + var parameterParseString = new StringBuilder(); + var parameterNames = new StringBuilder(); + foreach (var p in method.Symbol.Parameters) + { + var defaultValue = "default!"; + if (p.HasExplicitDefaultValue && p.ExplicitDefaultValue != null) + { + if (p.Type.TypeKind == TypeKind.Enum) + { + defaultValue = $"({p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}){p.ExplicitDefaultValue}"; + } + else + { + defaultValue = p.ExplicitDefaultValue.ToString(); + } + } + var parameterType = p.Type.ToDisplayString(); + parameterNames.Append((i != 0) ? $", _{i}" : $"_{i}"); + parameterParseString.AppendLine($" var _{i} = GetValueOrDefault<{parameterType}>(item, \"{p.Name}\", {defaultValue});"); + i++; + } + + var methodCall = $"{method.Name}({parameterNames})"; + if (isTask) + { + methodCall = $"(await {methodCall})"; + } + + sb.AppendLine($$""" + case "{{method.Name}}": + { +{{parameterParseString}} + string? _callResult; + bool? _isError = null; + try + { + _callResult = {{methodCall}}.ToString(); + } + catch (Exception ex) + { + _callResult = ex.Message; + _isError = true; + } + + result.Add(new Content + { + Type = ContentTypes.ToolResult, + ToolResultId = item.ToolUseId, + ToolResultContent = _callResult, + ToolResultIsError = _isError + }); + + break; + } +"""); + } + + + return sb.ToString(); + } + void EmitToolDescription(Method method) { var docComment = method.Syntax.GetDocumentationCommentTriviaSyntax()!; @@ -147,7 +361,14 @@ void EmitInvoke(ParseResult parseResult) foreach (var p in method.Symbol.Parameters) { parameterNames.Append((i != 0) ? $", _{i}" : $"_{i}"); - parameterParseString.AppendLine($" var _{i++} = ({p.Type.ToDisplayString()})parameters.Element(\"{p.Name}\")!;"); + if (p.Type.TypeKind == TypeKind.Enum) + { + parameterParseString.AppendLine($" var _{i++} = Enum.Parse<{p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>((string)parameters.Element(\"{p.Name}\")!);"); + } + else + { + parameterParseString.AppendLine($" var _{i++} = ({p.Type.ToDisplayString()})parameters.Element(\"{p.Name}\")!;"); + } } parameterParseString.AppendLine(); if (isTask) @@ -211,9 +432,10 @@ static void BuildResult(StringBuilder sb, string toolName, T result) void EmitTools(ParseResult parseResult) { - // TODO: Add All - // public static readonly Tool[] All = @$" + var allTools = string.Join(", ", parseResult.Methods.Select(x => $"Tools.{x.Name}")); sb.AppendLine($$""" + public static readonly Tool[] AllTools = new[] { {{allTools}} }; + public static class Tools { """); @@ -222,7 +444,7 @@ public static class Tools foreach (var method in parseResult.Methods) { var docComment = method.Syntax.GetDocumentationCommentTriviaSyntax()!; - var description = docComment.GetSummary().Replace("\"", "'").Replace("\r\n", " ").Replace("\n", " "); + var description = RemoveStringNewLine(docComment.GetSummary().Replace("\"", "'")); // property var inputSchema = new StringBuilder(); @@ -232,24 +454,58 @@ public static class Tools var paramRequired = new List(); foreach (var p in docComment.GetParams()) { - var paramDescription = p.Description.Replace("\"", "'").Replace("\r\n", " ").Replace("\n", " "); + var paramDescription = RemoveStringNewLine(p.Description.Replace("\"", "'")); // type retrieve from method symbol var name = p.Name; - var paramType = method.Symbol.Parameters.First(x => x.Name == name).Type.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + var pSymbol = method.Symbol.Parameters.First(x => x.Name == name); + var paramType = "string"; + + var enumMembers = "null"; + if (pSymbol.Type.TypeKind == TypeKind.Enum) + { + enumMembers = string.Join(", ", pSymbol.Type.GetMembers().Where(x => x.Name != ".ctor").Select(x => "\"" + x.Name + "\"")); + enumMembers = $"new [] {{ {enumMembers} }}"; + } + + // mapping jsonschema paramtype https://json-schema.org/understanding-json-schema/reference/type + switch (pSymbol.Type.SpecialType) + { + case SpecialType.System_Boolean: + paramType = "boolean"; + break; + case SpecialType.System_SByte: + case SpecialType.System_Byte: + case SpecialType.System_Int16: + case SpecialType.System_UInt16: + case SpecialType.System_Int32: + case SpecialType.System_UInt32: + case SpecialType.System_Int64: + case SpecialType.System_UInt64: + case SpecialType.System_Decimal: + case SpecialType.System_Single: + case SpecialType.System_Double: + paramType = "number"; + break; + default: + break; + } - // TODO: support enum propBuilder.AppendLine($$""" { "{{name}}", new ToolProperty() { Type = "{{paramType}}", - Description = "{{description}}" + Description = "{{paramDescription}}", + Enum = {{enumMembers}} } }, """); - // TODO: support optional parameter - paramRequired.Add("\"" + name + "\""); + + if (!pSymbol.HasExplicitDefaultValue) + { + paramRequired.Add("\"" + name + "\""); + } } var required = string.Join(", ", paramRequired); if (required.Length != 0) @@ -312,6 +568,7 @@ static void AddSource(SourceProductionContext context, ISymbol targetSymbol, str #pragma warning disable CA1050 using Claudia; +using System.Runtime.CompilerServices; using System; using System.Linq; using System.Text; @@ -336,4 +593,25 @@ static void AddSource(SourceProductionContext context, ISymbol targetSymbol, str var sourceCode = sb.ToString(); context.AddSource($"{fullType}{fileExtension}", sourceCode); } + + static string RemoveStringNewLine(string str) + { + var sb = new StringBuilder(); + var first = true; + using var sr = new StringReader(str); + string line = default!; + while ((line = sr.ReadLine()) != null) + { + if (first) + { + first = false; + } + else + { + sb.Append(" "); + } + sb.Append(line.Trim()); + } + return sb.ToString(); + } } \ No newline at end of file diff --git a/src/Claudia.FunctionGenerator/Parser.cs b/src/Claudia.FunctionGenerator/Parser.cs index 64959ce..9cb3503 100644 --- a/src/Claudia.FunctionGenerator/Parser.cs +++ b/src/Claudia.FunctionGenerator/Parser.cs @@ -130,6 +130,10 @@ internal ParseResult[] Parse() { break; } + if (p.Type.TypeKind == TypeKind.Enum) + { + break; + } hasError = true; context.ReportDiagnostic(Diagnostic.Create(DiagnosticDescriptors.ParameterTypeIsNotSupported, method.Locations[0], method.Name, p.Name, p.Type.Name)); diff --git a/src/Claudia/AnthropicJsonSerialzierContext.cs b/src/Claudia/AnthropicJsonSerialzierContext.cs index c408f00..119e1ee 100644 --- a/src/Claudia/AnthropicJsonSerialzierContext.cs +++ b/src/Claudia/AnthropicJsonSerialzierContext.cs @@ -27,6 +27,9 @@ namespace Claudia; [JsonSerializable(typeof(ContentBlockStop))] [JsonSerializable(typeof(MessageStartBody))] [JsonSerializable(typeof(MessageDeltaBody))] +[JsonSerializable(typeof(Tool))] +[JsonSerializable(typeof(InputSchema))] +[JsonSerializable(typeof(ToolProperty))] public partial class AnthropicJsonSerialzierContext : JsonSerializerContext { } \ No newline at end of file diff --git a/src/Claudia/Constants.cs b/src/Claudia/Constants.cs index 78f75ba..b39d32b 100644 --- a/src/Claudia/Constants.cs +++ b/src/Claudia/Constants.cs @@ -46,6 +46,8 @@ public static class ContentTypes { public const string Text = "text"; public const string Image = "image"; + public const string ToolUse = "tool_use"; + public const string ToolResult = "tool_result"; } public static class StopSequnces @@ -62,6 +64,8 @@ public static class StopReasons public const string MaxTokens = "max_tokens"; /// one of your provided custom stop_sequences was generated public const string StopSequence = "stop_sequence"; + + public const string ToolUse = "tool_use"; } public static class SystemPrompts diff --git a/src/Claudia/MessageRequest.cs b/src/Claudia/MessageRequest.cs index 94b46d5..3d20f02 100644 --- a/src/Claudia/MessageRequest.cs +++ b/src/Claudia/MessageRequest.cs @@ -2,6 +2,7 @@ using System.Text.Json; using System.Collections.ObjectModel; using System.Diagnostics.CodeAnalysis; +using System.Text; namespace Claudia; @@ -132,7 +133,7 @@ public override string ToString() } else { - return base.ToString() ?? ""; + return "[" + string.Join(", ", this.Select(x => x.ToString())) + "]"; } } } @@ -152,14 +153,30 @@ public record class Content #region tool_use response + /// A unique identifier for this particular tool use block. This will be used to match up the tool results later. [JsonPropertyName("id")] - public string? Id { get; set; } + public string? ToolUseId { get; set; } + /// The name of the tool being used. [JsonPropertyName("name")] - public string? Name { get; set; } + public string? ToolUseName { get; set; } + /// An object containing the input being passed to the tool, conforming to the tool's input_schema. [JsonPropertyName("input")] - public Dictionary? Input { get; set; } + public Dictionary? ToolUseInput { get; set; } + + /// The result of the tool. + [JsonPropertyName("content")] + public Contents? ToolResultContent { get; set; } + + /// The id of the tool use request this is a result for. + [JsonPropertyName("tool_use_id")] + public string? ToolResultId { get; set; } + + /// Set to true if the tool execution resulted in an error. + [JsonPropertyName("is_error")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public bool? ToolResultIsError { get; set; } #endregion @@ -198,6 +215,30 @@ public override string ToString() { return $"{Source.Type}(Source.Data.Length)"; } + else if (ToolUseId != null) + { + var sb = new StringBuilder(); + sb.Append(ToolUseName); + sb.Append("("); + if (ToolUseInput != null) + { + var first = true; + foreach (var item in ToolUseInput) + { + if (first) + { + first = true; + } + else + { + sb.Append(", "); + } + sb.Append(item.Key + ": " + item.Value); + } + } + sb.Append(")"); + return sb.ToString(); + } else { return base.ToString() ?? ""; diff --git a/src/Claudia/MessageResponse.cs b/src/Claudia/MessageResponse.cs index 393c55e..5bfa186 100644 --- a/src/Claudia/MessageResponse.cs +++ b/src/Claudia/MessageResponse.cs @@ -24,10 +24,10 @@ public class MessageResponse /// /// Content generated by the model. - /// This is an array of content blocks, each of which has a type that determines its shape. Currently, the only type in responses is "text". + /// This is an array of content blocks, each of which has a type that determines its shape. /// [JsonPropertyName("content")] - public required Content[] Content { get; set; } + public required Contents Content { get; set; } /// /// The model that handled the request. @@ -55,7 +55,7 @@ public class MessageResponse public override string ToString() { - if (Content.Length == 1 && Content[0].Text != null) + if (Content.Count == 1 && Content[0].Text != null) { return Content[0].Text!; }