diff --git a/README.md b/README.md index e6ef0b3..6e98acb 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,79 @@ while (true) } ``` +### Tools +```csharp +using var ollama = new OllamaApiClient(); +var chat = ollama.Chat( + model: "llama3.1", + systemMessage: "You are a helpful weather assistant.", + autoCallTools: true); + +var service = new WeatherService(); +chat.AddToolService(service.AsTools(), service.AsCalls()); + +try +{ + _ = await chat.SendAsync("What is the current temperature in Dubai, UAE in Celsius?"); +} +finally +{ + Console.WriteLine(chat.PrintMessages()); +} +``` +``` +> System: +You are a helpful weather assistant. +> User: +What is the current temperature in Dubai, UAE in Celsius? +> Assistant: +Tool calls: +GetCurrentWeather({"location":"Dubai, UAE","unit":"celsius"}) +> Tool: +{"location":"Dubai, UAE","temperature":22,"unit":"celsius","description":"Sunny"} +> Assistant: +The current temperature in Dubai, UAE is 22°C. +``` +```csharp +public enum Unit +{ + Celsius, + Fahrenheit, +} + +public class Weather +{ + public string Location { get; set; } = string.Empty; + public double Temperature { get; set; } + public Unit Unit { get; set; } + public string Description { get; set; } = string.Empty; +} + +[OllamaTools] +public interface IWeatherFunctions +{ + [Description("Get the current weather in a given location")] + public Task GetCurrentWeatherAsync( + [Description("The city and state, e.g. San Francisco, CA")] string location, + Unit unit = Unit.Celsius, + CancellationToken cancellationToken = default); +} + +public class WeatherService : IWeatherFunctions +{ + public Task GetCurrentWeatherAsync(string location, Unit unit = Unit.Celsius, CancellationToken cancellationToken = default) + { + return Task.FromResult(new Weather + { + Location = location, + Temperature = 22.0, + Unit = unit, + Description = "Sunny", + }); + } +} +``` + ## Credits Icon and name were reused from the amazing [Ollama project](https://github.com/jmorganca/ollama). diff --git a/src/libs/Ollama/Chat.cs b/src/libs/Ollama/Chat.cs index b3d5d94..d114f0a 100644 --- a/src/libs/Ollama/Chat.cs +++ b/src/libs/Ollama/Chat.cs @@ -1,4 +1,6 @@ -namespace Ollama; +using System.Text; + +namespace Ollama; /// /// @@ -8,7 +10,17 @@ public class Chat /// /// /// - public IList History { get; } = new List(); + public List History { get; } = new(); + + /// + /// + /// + public List Tools { get; } = new(); + + /// + /// + /// + public Dictionary>> Calls { get; } = new(); /// /// @@ -19,84 +31,85 @@ public class Chat /// /// public string Model { get; set; } + + /// + /// + /// + public bool AutoCallTools { get; set; } = true; /// /// /// /// /// + /// /// - public Chat(OllamaApiClient client, string model) + public Chat( + OllamaApiClient client, + string model, + string? systemMessage = null) { Client = client ?? throw new ArgumentNullException(nameof(client)); Model = model ?? throw new ArgumentNullException(nameof(model)); + + if (systemMessage != null) + { + History.Add(new Message + { + Role = MessageRole.System, + Content = systemMessage, + }); + } } /// - /// Sends a message to the currently selected model - /// - /// The message to send - /// The token to cancel the operation with - public Task SendAsync( - string message, - CancellationToken cancellationToken = default) - { - return SendAsync(message, null, cancellationToken); - } - - /// - /// Sends a message to the currently selected model - /// - /// The message to send - /// Base64 encoded images to send to the model - /// The token to cancel the operation with - public Task SendAsync( - string message, - IEnumerable? imagesAsBase64, - CancellationToken cancellationToken = default) - { - return SendAsAsync(MessageRole.User, message, imagesAsBase64, cancellationToken); - } - - /// - /// Sends a message in a given role to the currently selected model + /// /// - /// The role in which the message should be sent - /// The message to send - /// The token to cancel the operation with - public Task SendAsAsync( - MessageRole role, - string message, - CancellationToken cancellationToken = default) + /// + /// + public void AddToolService( + IList tools, + IReadOnlyDictionary>> calls) { - return SendAsAsync(role, message, null, cancellationToken); + tools = tools ?? throw new ArgumentNullException(nameof(tools)); + calls = calls ?? throw new ArgumentNullException(nameof(calls)); + + Tools.AddRange(tools); + foreach (var call in calls) + { + Calls.Add(call.Key, call.Value); + } } /// - /// Sends a message in a given role to the currently selected model + /// Sends a message in a given role(User by default) to the currently selected model /// /// The role in which the message should be sent /// The message to send /// Base64 encoded images to send to the model /// The token to cancel the operation with - public async Task SendAsAsync( - MessageRole role, - string message, - IEnumerable? imagesAsBase64, + public async Task SendAsync( + string? message = null, + MessageRole role = MessageRole.User, + IEnumerable? imagesAsBase64 = null, CancellationToken cancellationToken = default) { - History.Add(new Message + if (message != null) { - Content = message, - Role = role, - Images = imagesAsBase64?.ToList() ?? [], - }); + History.Add(new Message + { + Content = message, + Role = role, + Images = imagesAsBase64?.ToList() ?? [], + }); + } var request = new GenerateChatCompletionRequest { - Messages = History.ToList(), + Messages = History, Model = Model, - Stream = true, + Stream = false, + Tools = Tools.Count == 0 ? null : Tools, }; var answer = await Client.Chat.GenerateChatCompletionAsync(request, cancellationToken).WaitAsync().ConfigureAwait(false); @@ -106,7 +119,60 @@ public async Task SendAsAsync( } History.Add(answer.Message); + + if (AutoCallTools && answer.Message.ToolCalls?.Count > 0) + { + foreach (var call in answer.Message.ToolCalls) + { + var func = Calls[call.Function?.Name ?? string.Empty]; + + var json = await func( + call.Function?.Arguments.AsJson() ?? string.Empty, + cancellationToken).ConfigureAwait(false); + History.Add(json.AsToolMessage()); + } + + return await SendAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + } return answer.Message; } + + /// + /// + /// + public string PrintMessages() + { + return PrintMessages(History); + } + + /// + /// + /// + /// + public static string PrintMessages(List messages) + { + messages = messages ?? throw new ArgumentNullException(nameof(messages)); + + var builder = new StringBuilder(); + foreach (var message in messages) + { + builder.AppendLine($"> {message.Role}:"); + if (!string.IsNullOrWhiteSpace(message.Content)) + { + builder.AppendLine(message.Content); + } + if (message.ToolCalls?.Count > 0) + { + builder.AppendLine("Tool calls:"); + + foreach (var call in message.ToolCalls) + { + builder.AppendLine($"{call.Function?.Name}({call.Function?.Arguments.AsJson()})"); + } + } + } + + return builder.ToString(); + } } \ No newline at end of file diff --git a/src/libs/Ollama/Ollama.csproj b/src/libs/Ollama/Ollama.csproj index ca8131b..9e84e05 100644 --- a/src/libs/Ollama/Ollama.csproj +++ b/src/libs/Ollama/Ollama.csproj @@ -2,7 +2,7 @@ netstandard2.0;net4.6.2;net6.0;net8.0 - $(NoWarn);CA2016;CA2227 + $(NoWarn);CA2016;CA2227;CA1002;CA1303 diff --git a/src/libs/Ollama/OllamaApiClientExtensions.cs b/src/libs/Ollama/OllamaApiClientExtensions.cs index e052ea4..1e5a0a4 100644 --- a/src/libs/Ollama/OllamaApiClientExtensions.cs +++ b/src/libs/Ollama/OllamaApiClientExtensions.cs @@ -11,16 +11,23 @@ public static class OllamaApiClientExtensions /// Starts a new chat with the currently selected model. /// /// The client to start the chat with - /// + /// The model to chat with + /// Optional. A system message to send to the model + /// Optional. If set to true, the client will automatically call tools. /// /// A chat instance that can be used to receive and send messages from and to /// the Ollama endpoint while maintaining the message history. /// public static Chat Chat( this OllamaApiClient client, - string model) + string model, + string? systemMessage = null, + bool autoCallTools = true) { - return new Chat(client, model); + return new Chat(client, model, systemMessage) + { + AutoCallTools = autoCallTools, + }; } /// diff --git a/src/tests/Ollama.IntegrationTests/Tests.Chat.cs b/src/tests/Ollama.IntegrationTests/Tests.Chat.cs index b734b3e..9f8c68e 100644 --- a/src/tests/Ollama.IntegrationTests/Tests.Chat.cs +++ b/src/tests/Ollama.IntegrationTests/Tests.Chat.cs @@ -52,7 +52,7 @@ public async Task Sends_Messages_As_Defined_Role() var ollama = MockApiClient(MessageRole.Assistant, "hi system!"); var chat = new Chat(ollama, string.Empty); - var message = await chat.SendAsAsync(MessageRole.System, "henlo hooman"); + var message = await chat.SendAsync("henlo hooman", MessageRole.System); chat.History.Count.Should().Be(2); chat.History[0].Role.Should().Be(MessageRole.System); diff --git a/src/tests/Ollama.IntegrationTests/Tests.Integration.cs b/src/tests/Ollama.IntegrationTests/Tests.Integration.cs index 2dc2a3a..5354210 100755 --- a/src/tests/Ollama.IntegrationTests/Tests.Integration.cs +++ b/src/tests/Ollama.IntegrationTests/Tests.Integration.cs @@ -196,25 +196,17 @@ public async Task Tools() { var service = new WeatherService(); var tools = service.AsTools(); - var response = container.ApiClient.Chat.GenerateChatCompletionAsync( + var response = await container.ApiClient.Chat.GenerateChatCompletionAsync( model, messages, stream: false, - tools: tools); - var doneResponse = await response.WaitAsync(); - - Console.WriteLine(doneResponse.Message.Content); + tools: tools).WaitAsync(); - var resultMessage = doneResponse.Message; - messages.Add(resultMessage); + messages.Add(response.Message); - if (resultMessage.ToolCalls == null || - resultMessage.ToolCalls.Count == 0) - { - throw new InvalidOperationException("Expected a function call."); - } + response.Message.ToolCalls.Should().NotBeNullOrEmpty(because: "Expected a function call."); - foreach (var call in resultMessage.ToolCalls) + foreach (var call in response.Message.ToolCalls!) { var json = await service.CallAsync( functionName: call.Function?.Name ?? string.Empty, @@ -222,35 +214,44 @@ public async Task Tools() messages.Add(json.AsToolMessage()); } - response = container.ApiClient.Chat.GenerateChatCompletionAsync( + response = await container.ApiClient.Chat.GenerateChatCompletionAsync( model, messages, - tools: tools); - doneResponse = await response.WaitAsync(); - messages.Add(resultMessage); - - Console.WriteLine(doneResponse.Message.Content); + stream: false, + tools: tools).WaitAsync(); + messages.Add(response.Message); } finally { - PrintMessages(messages); + Console.WriteLine(Chat.PrintMessages(messages)); } } - private static void PrintMessages(List messages) + + [TestMethod] + public async Task ToolsInChat() { - foreach (var message in messages) +#if DEBUG + await using var container = await PrepareEnvironmentAsync(EnvironmentType.Local, "llama3.1"); +#else + await using var container = await PrepareEnvironmentAsync(EnvironmentType.Container, "llama3.1"); +#endif + + var chat = container.ApiClient.Chat( + model: "llama3.1", + systemMessage: "You are a helpful weather assistant.", + autoCallTools: true); + + var service = new WeatherService(); + chat.AddToolService(service.AsTools(), service.AsCalls()); + + try { - Console.WriteLine($"> {message.Role}:"); - Console.WriteLine($"{message.Content}"); - if (message.ToolCalls?.Count > 0) - { - Console.WriteLine("Tool calls:"); - foreach (var call in message.ToolCalls) - { - Console.WriteLine($"{call.Function?.Name}({call.Function?.Arguments.AsJson()})"); - } - } + _ = await chat.SendAsync("What is the current temperature in Dubai, UAE in Celsius?"); + } + finally + { + Console.WriteLine(chat.PrintMessages()); } } } \ No newline at end of file