Skip to content

Commit

Permalink
feat: Improved tool calling in Chat().
Browse files Browse the repository at this point in the history
  • Loading branch information
HavenDV committed Jul 30, 2024
1 parent dfeb069 commit 851dc8f
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 87 deletions.
73 changes: 73 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,79 @@ while (true)
}
```

### Tools
```csharp
using var ollama = new OllamaApiClient();
var chat = ollama.Chat(
model: "llama3.1",
systemMessage: "You are a helpful weather assistant.",
autoCallTools: true);

var service = new WeatherService();
chat.AddToolService(service.AsTools(), service.AsCalls());

try
{
_ = await chat.SendAsync("What is the current temperature in Dubai, UAE in Celsius?");
}
finally
{
Console.WriteLine(chat.PrintMessages());
}
```
```
> System:
You are a helpful weather assistant.
> User:
What is the current temperature in Dubai, UAE in Celsius?
> Assistant:
Tool calls:
GetCurrentWeather({"location":"Dubai, UAE","unit":"celsius"})
> Tool:
{"location":"Dubai, UAE","temperature":22,"unit":"celsius","description":"Sunny"}
> Assistant:
The current temperature in Dubai, UAE is 22°C.
```
```csharp
public enum Unit
{
Celsius,
Fahrenheit,
}

public class Weather
{
public string Location { get; set; } = string.Empty;
public double Temperature { get; set; }
public Unit Unit { get; set; }
public string Description { get; set; } = string.Empty;
}

[OllamaTools]
public interface IWeatherFunctions
{
[Description("Get the current weather in a given location")]
public Task<Weather> GetCurrentWeatherAsync(
[Description("The city and state, e.g. San Francisco, CA")] string location,
Unit unit = Unit.Celsius,
CancellationToken cancellationToken = default);
}

public class WeatherService : IWeatherFunctions
{
public Task<Weather> GetCurrentWeatherAsync(string location, Unit unit = Unit.Celsius, CancellationToken cancellationToken = default)
{
return Task.FromResult(new Weather
{
Location = location,
Temperature = 22.0,
Unit = unit,
Description = "Sunny",
});
}
}
```

## Credits

Icon and name were reused from the amazing [Ollama project](https://github.com/jmorganca/ollama).
Expand Down
166 changes: 116 additions & 50 deletions src/libs/Ollama/Chat.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace Ollama;
using System.Text;

namespace Ollama;

/// <summary>
///
Expand All @@ -8,7 +10,17 @@ public class Chat
/// <summary>
///
/// </summary>
public IList<Message> History { get; } = new List<Message>();
public List<Message> History { get; } = new();

/// <summary>
///
/// </summary>
public List<Tool> Tools { get; } = new();

/// <summary>
///
/// </summary>
public Dictionary<string, Func<string, CancellationToken, Task<string>>> Calls { get; } = new();

/// <summary>
///
Expand All @@ -19,84 +31,85 @@ public class Chat
///
/// </summary>
public string Model { get; set; }

/// <summary>
///
/// </summary>
public bool AutoCallTools { get; set; } = true;

/// <summary>
///
/// </summary>
/// <param name="client"></param>
/// <param name="model"></param>
/// <param name="systemMessage"></param>
/// <exception cref="ArgumentNullException"></exception>
public Chat(OllamaApiClient client, string model)
public Chat(
OllamaApiClient client,
string model,
string? systemMessage = null)
{
Client = client ?? throw new ArgumentNullException(nameof(client));
Model = model ?? throw new ArgumentNullException(nameof(model));

if (systemMessage != null)
{
History.Add(new Message
{
Role = MessageRole.System,
Content = systemMessage,
});
}
}

/// <summary>
/// Sends a message to the currently selected model
/// </summary>
/// <param name="message">The message to send</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
public Task<Message> SendAsync(
string message,
CancellationToken cancellationToken = default)
{
return SendAsync(message, null, cancellationToken);
}

/// <summary>
/// Sends a message to the currently selected model
/// </summary>
/// <param name="message">The message to send</param>
/// <param name="imagesAsBase64">Base64 encoded images to send to the model</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
public Task<Message> SendAsync(
string message,
IEnumerable<string>? imagesAsBase64,
CancellationToken cancellationToken = default)
{
return SendAsAsync(MessageRole.User, message, imagesAsBase64, cancellationToken);
}

/// <summary>
/// Sends a message in a given role to the currently selected model
///
/// </summary>
/// <param name="role">The role in which the message should be sent</param>
/// <param name="message">The message to send</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
public Task<Message> SendAsAsync(
MessageRole role,
string message,
CancellationToken cancellationToken = default)
/// <param name="tools"></param>
/// <param name="calls"></param>
public void AddToolService(
IList<Tool> tools,
IReadOnlyDictionary<string, Func<string, CancellationToken, Task<string>>> calls)
{
return SendAsAsync(role, message, null, cancellationToken);
tools = tools ?? throw new ArgumentNullException(nameof(tools));
calls = calls ?? throw new ArgumentNullException(nameof(calls));

Tools.AddRange(tools);
foreach (var call in calls)
{
Calls.Add(call.Key, call.Value);
}
}

/// <summary>
/// Sends a message in a given role to the currently selected model
/// Sends a message in a given role(User by default) to the currently selected model
/// </summary>
/// <param name="role">The role in which the message should be sent</param>
/// <param name="message">The message to send</param>
/// <param name="imagesAsBase64">Base64 encoded images to send to the model</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
public async Task<Message> SendAsAsync(
MessageRole role,
string message,
IEnumerable<string?>? imagesAsBase64,
public async Task<Message> SendAsync(
string? message = null,
MessageRole role = MessageRole.User,
IEnumerable<string?>? imagesAsBase64 = null,
CancellationToken cancellationToken = default)
{
History.Add(new Message
if (message != null)
{
Content = message,
Role = role,
Images = imagesAsBase64?.ToList() ?? [],
});
History.Add(new Message
{
Content = message,
Role = role,
Images = imagesAsBase64?.ToList() ?? [],
});
}

var request = new GenerateChatCompletionRequest
{
Messages = History.ToList(),
Messages = History,
Model = Model,
Stream = true,
Stream = false,
Tools = Tools.Count == 0 ? null : Tools,
};

var answer = await Client.Chat.GenerateChatCompletionAsync(request, cancellationToken).WaitAsync().ConfigureAwait(false);
Expand All @@ -106,7 +119,60 @@ public async Task<Message> SendAsAsync(
}

History.Add(answer.Message);

if (AutoCallTools && answer.Message.ToolCalls?.Count > 0)
{
foreach (var call in answer.Message.ToolCalls)
{
var func = Calls[call.Function?.Name ?? string.Empty];

var json = await func(
call.Function?.Arguments.AsJson() ?? string.Empty,
cancellationToken).ConfigureAwait(false);
History.Add(json.AsToolMessage());
}

return await SendAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
}

return answer.Message;
}

/// <summary>
///
/// </summary>
public string PrintMessages()
{
return PrintMessages(History);
}

/// <summary>
///
/// </summary>
/// <param name="messages"></param>
public static string PrintMessages(List<Message> messages)
{
messages = messages ?? throw new ArgumentNullException(nameof(messages));

var builder = new StringBuilder();
foreach (var message in messages)
{
builder.AppendLine($"> {message.Role}:");

Check warning on line 160 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 160 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 160 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 160 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 160 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 160 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 160 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 160 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)
if (!string.IsNullOrWhiteSpace(message.Content))
{
builder.AppendLine(message.Content);
}
if (message.ToolCalls?.Count > 0)
{
builder.AppendLine("Tool calls:");

foreach (var call in message.ToolCalls)
{
builder.AppendLine($"{call.Function?.Name}({call.Function?.Arguments.AsJson()})");

Check warning on line 171 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 171 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 171 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 171 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 171 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 171 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 171 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 171 in src/libs/Ollama/Chat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'Chat.PrintMessages(List<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)
}
}
}

return builder.ToString();
}
}
2 changes: 1 addition & 1 deletion src/libs/Ollama/Ollama.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<TargetFrameworks>netstandard2.0;net4.6.2;net6.0;net8.0</TargetFrameworks>
<NoWarn>$(NoWarn);CA2016;CA2227</NoWarn>
<NoWarn>$(NoWarn);CA2016;CA2227;CA1002;CA1303</NoWarn>
</PropertyGroup>

<ItemGroup>
Expand Down
13 changes: 10 additions & 3 deletions src/libs/Ollama/OllamaApiClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,23 @@ public static class OllamaApiClientExtensions
/// Starts a new chat with the currently selected model.
/// </summary>
/// <param name="client">The client to start the chat with</param>
/// <param name="model"></param>
/// <param name="model">The model to chat with</param>
/// <param name="systemMessage">Optional. A system message to send to the model</param>
/// <param name="autoCallTools">Optional. If set to true, the client will automatically call tools.</param>
/// <returns>
/// A chat instance that can be used to receive and send messages from and to
/// the Ollama endpoint while maintaining the message history.
/// </returns>
public static Chat Chat(
this OllamaApiClient client,
string model)
string model,
string? systemMessage = null,
bool autoCallTools = true)
{
return new Chat(client, model);
return new Chat(client, model, systemMessage)
{
AutoCallTools = autoCallTools,
};
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/tests/Ollama.IntegrationTests/Tests.Chat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public async Task Sends_Messages_As_Defined_Role()
var ollama = MockApiClient(MessageRole.Assistant, "hi system!");

var chat = new Chat(ollama, string.Empty);
var message = await chat.SendAsAsync(MessageRole.System, "henlo hooman");
var message = await chat.SendAsync("henlo hooman", MessageRole.System);

chat.History.Count.Should().Be(2);
chat.History[0].Role.Should().Be(MessageRole.System);
Expand Down
Loading

0 comments on commit 851dc8f

Please sign in to comment.