diff --git a/src/helpers/FixOpenApiSpec/Program.cs b/src/helpers/FixOpenApiSpec/Program.cs index 808b801..d16c969 100644 --- a/src/helpers/FixOpenApiSpec/Program.cs +++ b/src/helpers/FixOpenApiSpec/Program.cs @@ -40,11 +40,6 @@ }, }, }; -openApiDocument.Components.Schemas["ToolCallFunction"]!.Properties["arguments"] = new OpenApiSchema -{ - Type = "string", -}; - text = openApiDocument.SerializeAsYaml(OpenApiSpecVersion.OpenApi3_0); _ = new OpenApiStringReader().Read(text, out diagnostics); diff --git a/src/libs/Ollama.Generators/OpenAiFunctionsGenerator.cs b/src/libs/Ollama.Generators/OllamaToolsGenerator.cs similarity index 88% rename from src/libs/Ollama.Generators/OpenAiFunctionsGenerator.cs rename to src/libs/Ollama.Generators/OllamaToolsGenerator.cs index 4fd19d0..f1d1dd8 100755 --- a/src/libs/Ollama.Generators/OpenAiFunctionsGenerator.cs +++ b/src/libs/Ollama.Generators/OllamaToolsGenerator.cs @@ -7,11 +7,11 @@ namespace Ollama.Generators; [Generator] -public class OllamaFunctionsGenerator : IIncrementalGenerator +public class OllamaToolsGenerator : IIncrementalGenerator { #region Constants - public const string Name = nameof(OllamaFunctionsGenerator); + public const string Name = nameof(OllamaToolsGenerator); public const string Id = "OAFG"; #endregion @@ -22,7 +22,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var attributes = context.SyntaxProvider - .ForAttributeWithMetadataName("Ollama.OllamaFunctionsAttribute") + .ForAttributeWithMetadataName("Ollama.OllamaToolsAttribute") .SelectManyAllAttributesOfCurrentInterfaceSyntax() .SelectAndReportExceptions(PrepareData, context, Id); diff --git a/src/libs/Ollama/Attributes/OllamaFunctionsAttribute.cs b/src/libs/Ollama/Attributes/OllamaToolsAttribute.cs similarity index 78% rename from src/libs/Ollama/Attributes/OllamaFunctionsAttribute.cs rename to src/libs/Ollama/Attributes/OllamaToolsAttribute.cs index 345441c..7675354 100644 --- a/src/libs/Ollama/Attributes/OllamaFunctionsAttribute.cs +++ b/src/libs/Ollama/Attributes/OllamaToolsAttribute.cs @@ -6,4 +6,4 @@ namespace Ollama; /// [AttributeUsage(AttributeTargets.Interface)] [System.Diagnostics.Conditional("OLLAMA_FUNCTIONS_ATTRIBUTES")] -public sealed class OllamaFunctionsAttribute : Attribute; \ No newline at end of file +public sealed class OllamaToolsAttribute : Attribute; \ No newline at end of file diff --git a/src/libs/Ollama/Extensions/StringExtensions.cs b/src/libs/Ollama/Extensions/StringExtensions.cs index a8f6179..f93d3a3 100644 --- a/src/libs/Ollama/Extensions/StringExtensions.cs +++ b/src/libs/Ollama/Extensions/StringExtensions.cs @@ -61,4 +61,14 @@ public static Message AsToolMessage(this string content) Content = content, }; } + + /// + /// + /// + /// + /// + public static string AsJson(this ToolCallFunctionArgs args) + { + return JsonSerializer.Serialize(args, SourceGenerationContext.Default.ToolCallFunctionArgs); + } } \ No newline at end of file diff --git a/src/libs/Ollama/Generated/JsonSerializerContextTypes.g.cs b/src/libs/Ollama/Generated/JsonSerializerContextTypes.g.cs index 925a9a4..c026d6b 100644 --- a/src/libs/Ollama/Generated/JsonSerializerContextTypes.g.cs +++ b/src/libs/Ollama/Generated/JsonSerializerContextTypes.g.cs @@ -93,114 +93,118 @@ public sealed partial class JsonSerializerContextTypes /// /// /// - public global::System.Collections.Generic.IList? Type20 { get; set; } + public global::Ollama.ToolCallFunctionArgs? Type20 { get; set; } /// /// /// - public global::System.AnyOf? Type21 { get; set; } + public global::System.Collections.Generic.IList? Type21 { get; set; } /// /// /// - public global::System.Collections.Generic.IList? Type22 { get; set; } + public global::System.AnyOf? Type22 { get; set; } /// /// /// - public global::Ollama.ModelDetails? Type23 { get; set; } + public global::System.Collections.Generic.IList? Type23 { get; set; } /// /// /// - public global::System.Collections.Generic.IList? Type24 { get; set; } + public global::Ollama.ModelDetails? Type24 { get; set; } /// /// /// - public global::Ollama.ModelInformation? Type25 { get; set; } + public global::System.Collections.Generic.IList? Type25 { get; set; } /// /// /// - public global::System.AnyOf? Type26 { get; set; } + public global::Ollama.ModelInformation? Type26 { get; set; } /// /// /// - public global::System.AnyOf? Type27 { get; set; } + public global::System.AnyOf? Type27 { get; set; } /// /// /// - public global::Ollama.VersionResponse? Type28 { get; set; } + public global::System.AnyOf? Type28 { get; set; } /// /// /// - public global::Ollama.GenerateCompletionRequest? Type29 { get; set; } + public global::Ollama.VersionResponse? Type29 { get; set; } /// /// /// - public global::Ollama.GenerateCompletionResponse? Type30 { get; set; } + public global::Ollama.GenerateCompletionRequest? Type30 { get; set; } /// /// /// - public global::Ollama.GenerateChatCompletionRequest? Type31 { get; set; } + public global::Ollama.GenerateCompletionResponse? Type31 { get; set; } /// /// /// - public global::Ollama.GenerateChatCompletionResponse? Type32 { get; set; } + public global::Ollama.GenerateChatCompletionRequest? Type32 { get; set; } /// /// /// - public global::Ollama.GenerateEmbeddingRequest? Type33 { get; set; } + public global::Ollama.GenerateChatCompletionResponse? Type33 { get; set; } /// /// /// - public global::Ollama.GenerateEmbeddingResponse? Type34 { get; set; } + public global::Ollama.GenerateEmbeddingRequest? Type34 { get; set; } /// /// /// - public global::Ollama.CreateModelRequest? Type35 { get; set; } + public global::Ollama.GenerateEmbeddingResponse? Type35 { get; set; } /// /// /// - public global::Ollama.CreateModelResponse? Type36 { get; set; } + public global::Ollama.CreateModelRequest? Type36 { get; set; } /// /// /// - public global::Ollama.ModelsResponse? Type37 { get; set; } + public global::Ollama.CreateModelResponse? Type37 { get; set; } /// /// /// - public global::Ollama.ProcessResponse? Type38 { get; set; } + public global::Ollama.ModelsResponse? Type38 { get; set; } /// /// /// - public global::Ollama.ModelInfoRequest? Type39 { get; set; } + public global::Ollama.ProcessResponse? Type39 { get; set; } /// /// /// - public global::Ollama.ModelInfo? Type40 { get; set; } + public global::Ollama.ModelInfoRequest? Type40 { get; set; } /// /// /// - public global::Ollama.CopyModelRequest? Type41 { get; set; } + public global::Ollama.ModelInfo? Type41 { get; set; } /// /// /// - public global::Ollama.DeleteModelRequest? Type42 { get; set; } + public global::Ollama.CopyModelRequest? Type42 { get; set; } /// /// /// - public global::Ollama.PullModelRequest? Type43 { get; set; } + public global::Ollama.DeleteModelRequest? Type43 { get; set; } /// /// /// - public global::Ollama.PullModelResponse? Type44 { get; set; } + public global::Ollama.PullModelRequest? Type44 { get; set; } /// /// /// - public global::Ollama.PushModelRequest? Type45 { get; set; } + public global::Ollama.PullModelResponse? Type45 { get; set; } /// /// /// - public global::Ollama.PushModelResponse? Type46 { get; set; } + public global::Ollama.PushModelRequest? Type46 { get; set; } /// /// /// - public byte[]? Type47 { get; set; } + public global::Ollama.PushModelResponse? Type47 { get; set; } + /// + /// + /// + public byte[]? Type48 { get; set; } } } \ No newline at end of file diff --git a/src/libs/Ollama/Generated/Ollama.Models.ToolCallFunction.g.cs b/src/libs/Ollama/Generated/Ollama.Models.ToolCallFunction.g.cs index 901b432..7be9a20 100644 --- a/src/libs/Ollama/Generated/Ollama.Models.ToolCallFunction.g.cs +++ b/src/libs/Ollama/Generated/Ollama.Models.ToolCallFunction.g.cs @@ -16,11 +16,11 @@ public sealed partial class ToolCallFunction public required string Name { get; set; } /// - /// + /// The arguments to pass to the function. /// [global::System.Text.Json.Serialization.JsonPropertyName("arguments")] [global::System.Text.Json.Serialization.JsonRequired] - public required string Arguments { get; set; } + public required global::Ollama.ToolCallFunctionArgs Arguments { get; set; } /// /// Additional properties that are not explicitly defined in the schema diff --git a/src/libs/Ollama/JsonSerializerContextTypes.AdditionalTypes.cs b/src/libs/Ollama/JsonSerializerContextTypes.AdditionalTypes.cs new file mode 100644 index 0000000..d52fa6c --- /dev/null +++ b/src/libs/Ollama/JsonSerializerContextTypes.AdditionalTypes.cs @@ -0,0 +1,9 @@ +namespace Ollama; + +public partial class JsonSerializerContextTypes +{ + /// + /// + /// + public JsonElement JsonElement { get; set; } +} \ No newline at end of file diff --git a/src/libs/Ollama/OllamaApiClientExtensions.cs b/src/libs/Ollama/OllamaApiClientExtensions.cs index 96b52d8..e052ea4 100644 --- a/src/libs/Ollama/OllamaApiClientExtensions.cs +++ b/src/libs/Ollama/OllamaApiClientExtensions.cs @@ -107,6 +107,7 @@ public static async Task WaitAsync( enumerable = enumerable ?? throw new ArgumentNullException(nameof(enumerable)); MessageRole? responseRole = null; + IList? toolCalls = null; var responseContent = new StringBuilder(); var currentResponse = new GenerateChatCompletionResponse @@ -123,6 +124,7 @@ public static async Task WaitAsync( await foreach (var response in enumerable.ConfigureAwait(false)) { responseRole ??= response.Message.Role; + toolCalls ??= response.Message.ToolCalls; responseContent.Append(response.Message.Content); currentResponse = response; @@ -131,7 +133,8 @@ public static async Task WaitAsync( currentResponse.Message = new Message { Role = responseRole ?? MessageRole.User, - Content = responseContent.ToString() + Content = responseContent.ToString(), + ToolCalls = toolCalls, }; return currentResponse; diff --git a/src/libs/Ollama/openapi.yaml b/src/libs/Ollama/openapi.yaml index e6d8979..a3b4a90 100644 --- a/src/libs/Ollama/openapi.yaml +++ b/src/libs/Ollama/openapi.yaml @@ -677,7 +677,7 @@ components: type: string description: The name of the function to be called. arguments: - type: string + $ref: '#/components/schemas/ToolCallFunctionArgs' description: The function the model wants to call. ToolCallFunctionArgs: type: object diff --git a/src/tests/Ollama.IntegrationTests/Tests.Integration.cs b/src/tests/Ollama.IntegrationTests/Tests.Integration.cs index 7765913..2dc2a3a 100755 --- a/src/tests/Ollama.IntegrationTests/Tests.Integration.cs +++ b/src/tests/Ollama.IntegrationTests/Tests.Integration.cs @@ -176,65 +176,66 @@ public async Task GetChat() Console.WriteLine(message.Content); } -// [TestMethod] -// public async Task Tools() -// { -// #if DEBUG -// await using var container = await PrepareEnvironmentAsync(EnvironmentType.Local, "llama3.1"); -// #else -// await using var container = await PrepareEnvironmentAsync(EnvironmentType.Container, "llama3.1"); -// #endif -// -// var messages = new List -// { -// "You are a helpful weather assistant.".AsSystemMessage(), -// "What is the current temperature in Dubai, UAE in Celsius?".AsUserMessage(), -// }; -// const string model = "llama3.1"; -// -// try -// { -// var service = new WeatherService(); -// var tools = service.AsTools(); -// var response = container.ApiClient.Chat.GenerateChatCompletionAsync( -// model, -// messages, -// tools: tools); -// var doneResponse = await response.WaitAsync(); -// -// Console.WriteLine(doneResponse.Message.Content); -// -// var resultMessage = doneResponse.Message; -// messages.Add(resultMessage); -// -// if (resultMessage.ToolCalls == null || -// resultMessage.ToolCalls.Count == 0) -// { -// throw new InvalidOperationException("Expected a function call."); -// } -// -// foreach (var call in resultMessage.ToolCalls) -// { -// var json = await service.CallAsync( -// functionName: call.Function?.Name ?? string.Empty, -// argumentsAsJson: call.Function?.Arguments ?? string.Empty); -// messages.Add(json.AsToolMessage()); -// } -// -// response = container.ApiClient.Chat.GenerateChatCompletionAsync( -// model, -// messages, -// tools: tools); -// doneResponse = await response.WaitAsync(); -// messages.Add(resultMessage); -// -// Console.WriteLine(doneResponse.Message.Content); -// } -// finally -// { -// PrintMessages(messages); -// } -// } + [TestMethod] + public async Task Tools() + { +#if DEBUG + await using var container = await PrepareEnvironmentAsync(EnvironmentType.Local, "llama3.1"); +#else + await using var container = await PrepareEnvironmentAsync(EnvironmentType.Container, "llama3.1"); +#endif + + var messages = new List + { + "You are a helpful weather assistant.".AsSystemMessage(), + "What is the current temperature in Dubai, UAE in Celsius?".AsUserMessage(), + }; + const string model = "llama3.1"; + + try + { + var service = new WeatherService(); + var tools = service.AsTools(); + var response = container.ApiClient.Chat.GenerateChatCompletionAsync( + model, + messages, + stream: false, + tools: tools); + var doneResponse = await response.WaitAsync(); + + Console.WriteLine(doneResponse.Message.Content); + + var resultMessage = doneResponse.Message; + messages.Add(resultMessage); + + if (resultMessage.ToolCalls == null || + resultMessage.ToolCalls.Count == 0) + { + throw new InvalidOperationException("Expected a function call."); + } + + foreach (var call in resultMessage.ToolCalls) + { + var json = await service.CallAsync( + functionName: call.Function?.Name ?? string.Empty, + argumentsAsJson: call.Function?.Arguments.AsJson() ?? string.Empty); + messages.Add(json.AsToolMessage()); + } + + response = container.ApiClient.Chat.GenerateChatCompletionAsync( + model, + messages, + tools: tools); + doneResponse = await response.WaitAsync(); + messages.Add(resultMessage); + + Console.WriteLine(doneResponse.Message.Content); + } + finally + { + PrintMessages(messages); + } + } private static void PrintMessages(List messages) { @@ -247,7 +248,7 @@ private static void PrintMessages(List messages) Console.WriteLine("Tool calls:"); foreach (var call in message.ToolCalls) { - Console.WriteLine($"> {call.Function?.Name}({call.Function?.Arguments})"); + Console.WriteLine($"{call.Function?.Name}({call.Function?.Arguments.AsJson()})"); } } } diff --git a/src/tests/Ollama.IntegrationTests/VariousTypesFunctions.cs b/src/tests/Ollama.IntegrationTests/VariousTypesFunctions.cs index f3e056f..cd91d4c 100644 --- a/src/tests/Ollama.IntegrationTests/VariousTypesFunctions.cs +++ b/src/tests/Ollama.IntegrationTests/VariousTypesFunctions.cs @@ -3,7 +3,7 @@ namespace Ollama.IntegrationTests; -[OllamaFunctions] +[OllamaTools] public interface IVariousTypesFunctions { [Description("Get the current weather in a given location")] diff --git a/src/tests/Ollama.IntegrationTests/WeatherFunctions.cs b/src/tests/Ollama.IntegrationTests/WeatherFunctions.cs index a81dea1..54eae7f 100644 --- a/src/tests/Ollama.IntegrationTests/WeatherFunctions.cs +++ b/src/tests/Ollama.IntegrationTests/WeatherFunctions.cs @@ -18,7 +18,7 @@ public class Weather public string Description { get; set; } = string.Empty; } -[OllamaFunctions] +[OllamaTools] public interface IWeatherFunctions { [Description("Get the current weather in a given location")] diff --git a/src/tests/Ollama.SnapshotTests/TestHelper.cs b/src/tests/Ollama.SnapshotTests/TestHelper.cs index 3ec1b35..818c7f2 100755 --- a/src/tests/Ollama.SnapshotTests/TestHelper.cs +++ b/src/tests/Ollama.SnapshotTests/TestHelper.cs @@ -15,7 +15,7 @@ public static async Task CheckSourceAsync( var referenceAssemblies = LatestReferenceAssemblies.Net80; var references = await referenceAssemblies.ResolveAsync(null, cancellationToken); references = references - .Add(MetadataReference.CreateFromFile(typeof(OllamaFunctionsAttribute).Assembly.Location)); + .Add(MetadataReference.CreateFromFile(typeof(OllamaToolsAttribute).Assembly.Location)); var compilation = (Compilation)CSharpCompilation.Create( assemblyName: "Tests", @@ -26,7 +26,7 @@ public static async Task CheckSourceAsync( references: references, options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); var driver = CSharpGeneratorDriver - .Create(new OllamaFunctionsGenerator()) + .Create(new OllamaToolsGenerator()) .RunGeneratorsAndUpdateCompilation(compilation, out compilation, out _, cancellationToken); var diagnostics = compilation.GetDiagnostics(cancellationToken);