diff --git a/src/helpers/FixOpenApiSpec/Program.cs b/src/helpers/FixOpenApiSpec/Program.cs
index 808b801..d16c969 100644
--- a/src/helpers/FixOpenApiSpec/Program.cs
+++ b/src/helpers/FixOpenApiSpec/Program.cs
@@ -40,11 +40,6 @@
},
},
};
-openApiDocument.Components.Schemas["ToolCallFunction"]!.Properties["arguments"] = new OpenApiSchema
-{
- Type = "string",
-};
-
text = openApiDocument.SerializeAsYaml(OpenApiSpecVersion.OpenApi3_0);
_ = new OpenApiStringReader().Read(text, out diagnostics);
diff --git a/src/libs/Ollama.Generators/OpenAiFunctionsGenerator.cs b/src/libs/Ollama.Generators/OllamaToolsGenerator.cs
similarity index 88%
rename from src/libs/Ollama.Generators/OpenAiFunctionsGenerator.cs
rename to src/libs/Ollama.Generators/OllamaToolsGenerator.cs
index 4fd19d0..f1d1dd8 100755
--- a/src/libs/Ollama.Generators/OpenAiFunctionsGenerator.cs
+++ b/src/libs/Ollama.Generators/OllamaToolsGenerator.cs
@@ -7,11 +7,11 @@
namespace Ollama.Generators;
[Generator]
-public class OllamaFunctionsGenerator : IIncrementalGenerator
+public class OllamaToolsGenerator : IIncrementalGenerator
{
#region Constants
- public const string Name = nameof(OllamaFunctionsGenerator);
+ public const string Name = nameof(OllamaToolsGenerator);
public const string Id = "OAFG";
#endregion
@@ -22,7 +22,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
var attributes =
context.SyntaxProvider
- .ForAttributeWithMetadataName("Ollama.OllamaFunctionsAttribute")
+ .ForAttributeWithMetadataName("Ollama.OllamaToolsAttribute")
.SelectManyAllAttributesOfCurrentInterfaceSyntax()
.SelectAndReportExceptions(PrepareData, context, Id);
diff --git a/src/libs/Ollama/Attributes/OllamaFunctionsAttribute.cs b/src/libs/Ollama/Attributes/OllamaToolsAttribute.cs
similarity index 78%
rename from src/libs/Ollama/Attributes/OllamaFunctionsAttribute.cs
rename to src/libs/Ollama/Attributes/OllamaToolsAttribute.cs
index 345441c..7675354 100644
--- a/src/libs/Ollama/Attributes/OllamaFunctionsAttribute.cs
+++ b/src/libs/Ollama/Attributes/OllamaToolsAttribute.cs
@@ -6,4 +6,4 @@ namespace Ollama;
///
[AttributeUsage(AttributeTargets.Interface)]
[System.Diagnostics.Conditional("OLLAMA_FUNCTIONS_ATTRIBUTES")]
-public sealed class OllamaFunctionsAttribute : Attribute;
\ No newline at end of file
+public sealed class OllamaToolsAttribute : Attribute;
\ No newline at end of file
diff --git a/src/libs/Ollama/Extensions/StringExtensions.cs b/src/libs/Ollama/Extensions/StringExtensions.cs
index a8f6179..f93d3a3 100644
--- a/src/libs/Ollama/Extensions/StringExtensions.cs
+++ b/src/libs/Ollama/Extensions/StringExtensions.cs
@@ -61,4 +61,14 @@ public static Message AsToolMessage(this string content)
Content = content,
};
}
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static string AsJson(this ToolCallFunctionArgs args)
+ {
+ return JsonSerializer.Serialize(args, SourceGenerationContext.Default.ToolCallFunctionArgs);
+ }
}
\ No newline at end of file
diff --git a/src/libs/Ollama/Generated/JsonSerializerContextTypes.g.cs b/src/libs/Ollama/Generated/JsonSerializerContextTypes.g.cs
index 925a9a4..c026d6b 100644
--- a/src/libs/Ollama/Generated/JsonSerializerContextTypes.g.cs
+++ b/src/libs/Ollama/Generated/JsonSerializerContextTypes.g.cs
@@ -93,114 +93,118 @@ public sealed partial class JsonSerializerContextTypes
///
///
///
- public global::System.Collections.Generic.IList? Type20 { get; set; }
+ public global::Ollama.ToolCallFunctionArgs? Type20 { get; set; }
///
///
///
- public global::System.AnyOf? Type21 { get; set; }
+ public global::System.Collections.Generic.IList? Type21 { get; set; }
///
///
///
- public global::System.Collections.Generic.IList? Type22 { get; set; }
+ public global::System.AnyOf? Type22 { get; set; }
///
///
///
- public global::Ollama.ModelDetails? Type23 { get; set; }
+ public global::System.Collections.Generic.IList? Type23 { get; set; }
///
///
///
- public global::System.Collections.Generic.IList? Type24 { get; set; }
+ public global::Ollama.ModelDetails? Type24 { get; set; }
///
///
///
- public global::Ollama.ModelInformation? Type25 { get; set; }
+ public global::System.Collections.Generic.IList? Type25 { get; set; }
///
///
///
- public global::System.AnyOf? Type26 { get; set; }
+ public global::Ollama.ModelInformation? Type26 { get; set; }
///
///
///
- public global::System.AnyOf? Type27 { get; set; }
+ public global::System.AnyOf? Type27 { get; set; }
///
///
///
- public global::Ollama.VersionResponse? Type28 { get; set; }
+ public global::System.AnyOf? Type28 { get; set; }
///
///
///
- public global::Ollama.GenerateCompletionRequest? Type29 { get; set; }
+ public global::Ollama.VersionResponse? Type29 { get; set; }
///
///
///
- public global::Ollama.GenerateCompletionResponse? Type30 { get; set; }
+ public global::Ollama.GenerateCompletionRequest? Type30 { get; set; }
///
///
///
- public global::Ollama.GenerateChatCompletionRequest? Type31 { get; set; }
+ public global::Ollama.GenerateCompletionResponse? Type31 { get; set; }
///
///
///
- public global::Ollama.GenerateChatCompletionResponse? Type32 { get; set; }
+ public global::Ollama.GenerateChatCompletionRequest? Type32 { get; set; }
///
///
///
- public global::Ollama.GenerateEmbeddingRequest? Type33 { get; set; }
+ public global::Ollama.GenerateChatCompletionResponse? Type33 { get; set; }
///
///
///
- public global::Ollama.GenerateEmbeddingResponse? Type34 { get; set; }
+ public global::Ollama.GenerateEmbeddingRequest? Type34 { get; set; }
///
///
///
- public global::Ollama.CreateModelRequest? Type35 { get; set; }
+ public global::Ollama.GenerateEmbeddingResponse? Type35 { get; set; }
///
///
///
- public global::Ollama.CreateModelResponse? Type36 { get; set; }
+ public global::Ollama.CreateModelRequest? Type36 { get; set; }
///
///
///
- public global::Ollama.ModelsResponse? Type37 { get; set; }
+ public global::Ollama.CreateModelResponse? Type37 { get; set; }
///
///
///
- public global::Ollama.ProcessResponse? Type38 { get; set; }
+ public global::Ollama.ModelsResponse? Type38 { get; set; }
///
///
///
- public global::Ollama.ModelInfoRequest? Type39 { get; set; }
+ public global::Ollama.ProcessResponse? Type39 { get; set; }
///
///
///
- public global::Ollama.ModelInfo? Type40 { get; set; }
+ public global::Ollama.ModelInfoRequest? Type40 { get; set; }
///
///
///
- public global::Ollama.CopyModelRequest? Type41 { get; set; }
+ public global::Ollama.ModelInfo? Type41 { get; set; }
///
///
///
- public global::Ollama.DeleteModelRequest? Type42 { get; set; }
+ public global::Ollama.CopyModelRequest? Type42 { get; set; }
///
///
///
- public global::Ollama.PullModelRequest? Type43 { get; set; }
+ public global::Ollama.DeleteModelRequest? Type43 { get; set; }
///
///
///
- public global::Ollama.PullModelResponse? Type44 { get; set; }
+ public global::Ollama.PullModelRequest? Type44 { get; set; }
///
///
///
- public global::Ollama.PushModelRequest? Type45 { get; set; }
+ public global::Ollama.PullModelResponse? Type45 { get; set; }
///
///
///
- public global::Ollama.PushModelResponse? Type46 { get; set; }
+ public global::Ollama.PushModelRequest? Type46 { get; set; }
///
///
///
- public byte[]? Type47 { get; set; }
+ public global::Ollama.PushModelResponse? Type47 { get; set; }
+ ///
+ ///
+ ///
+ public byte[]? Type48 { get; set; }
}
}
\ No newline at end of file
diff --git a/src/libs/Ollama/Generated/Ollama.Models.ToolCallFunction.g.cs b/src/libs/Ollama/Generated/Ollama.Models.ToolCallFunction.g.cs
index 901b432..7be9a20 100644
--- a/src/libs/Ollama/Generated/Ollama.Models.ToolCallFunction.g.cs
+++ b/src/libs/Ollama/Generated/Ollama.Models.ToolCallFunction.g.cs
@@ -16,11 +16,11 @@ public sealed partial class ToolCallFunction
public required string Name { get; set; }
///
- ///
+ /// The arguments to pass to the function.
///
[global::System.Text.Json.Serialization.JsonPropertyName("arguments")]
[global::System.Text.Json.Serialization.JsonRequired]
- public required string Arguments { get; set; }
+ public required global::Ollama.ToolCallFunctionArgs Arguments { get; set; }
///
/// Additional properties that are not explicitly defined in the schema
diff --git a/src/libs/Ollama/JsonSerializerContextTypes.AdditionalTypes.cs b/src/libs/Ollama/JsonSerializerContextTypes.AdditionalTypes.cs
new file mode 100644
index 0000000..d52fa6c
--- /dev/null
+++ b/src/libs/Ollama/JsonSerializerContextTypes.AdditionalTypes.cs
@@ -0,0 +1,9 @@
+namespace Ollama;
+
+public partial class JsonSerializerContextTypes
+{
+ ///
+ ///
+ ///
+ public JsonElement JsonElement { get; set; }
+}
\ No newline at end of file
diff --git a/src/libs/Ollama/OllamaApiClientExtensions.cs b/src/libs/Ollama/OllamaApiClientExtensions.cs
index 96b52d8..e052ea4 100644
--- a/src/libs/Ollama/OllamaApiClientExtensions.cs
+++ b/src/libs/Ollama/OllamaApiClientExtensions.cs
@@ -107,6 +107,7 @@ public static async Task WaitAsync(
enumerable = enumerable ?? throw new ArgumentNullException(nameof(enumerable));
MessageRole? responseRole = null;
+ IList? toolCalls = null;
var responseContent = new StringBuilder();
var currentResponse = new GenerateChatCompletionResponse
@@ -123,6 +124,7 @@ public static async Task WaitAsync(
await foreach (var response in enumerable.ConfigureAwait(false))
{
responseRole ??= response.Message.Role;
+ toolCalls ??= response.Message.ToolCalls;
responseContent.Append(response.Message.Content);
currentResponse = response;
@@ -131,7 +133,8 @@ public static async Task WaitAsync(
currentResponse.Message = new Message
{
Role = responseRole ?? MessageRole.User,
- Content = responseContent.ToString()
+ Content = responseContent.ToString(),
+ ToolCalls = toolCalls,
};
return currentResponse;
diff --git a/src/libs/Ollama/openapi.yaml b/src/libs/Ollama/openapi.yaml
index e6d8979..a3b4a90 100644
--- a/src/libs/Ollama/openapi.yaml
+++ b/src/libs/Ollama/openapi.yaml
@@ -677,7 +677,7 @@ components:
type: string
description: The name of the function to be called.
arguments:
- type: string
+ $ref: '#/components/schemas/ToolCallFunctionArgs'
description: The function the model wants to call.
ToolCallFunctionArgs:
type: object
diff --git a/src/tests/Ollama.IntegrationTests/Tests.Integration.cs b/src/tests/Ollama.IntegrationTests/Tests.Integration.cs
index 7765913..2dc2a3a 100755
--- a/src/tests/Ollama.IntegrationTests/Tests.Integration.cs
+++ b/src/tests/Ollama.IntegrationTests/Tests.Integration.cs
@@ -176,65 +176,66 @@ public async Task GetChat()
Console.WriteLine(message.Content);
}
-// [TestMethod]
-// public async Task Tools()
-// {
-// #if DEBUG
-// await using var container = await PrepareEnvironmentAsync(EnvironmentType.Local, "llama3.1");
-// #else
-// await using var container = await PrepareEnvironmentAsync(EnvironmentType.Container, "llama3.1");
-// #endif
-//
-// var messages = new List
-// {
-// "You are a helpful weather assistant.".AsSystemMessage(),
-// "What is the current temperature in Dubai, UAE in Celsius?".AsUserMessage(),
-// };
-// const string model = "llama3.1";
-//
-// try
-// {
-// var service = new WeatherService();
-// var tools = service.AsTools();
-// var response = container.ApiClient.Chat.GenerateChatCompletionAsync(
-// model,
-// messages,
-// tools: tools);
-// var doneResponse = await response.WaitAsync();
-//
-// Console.WriteLine(doneResponse.Message.Content);
-//
-// var resultMessage = doneResponse.Message;
-// messages.Add(resultMessage);
-//
-// if (resultMessage.ToolCalls == null ||
-// resultMessage.ToolCalls.Count == 0)
-// {
-// throw new InvalidOperationException("Expected a function call.");
-// }
-//
-// foreach (var call in resultMessage.ToolCalls)
-// {
-// var json = await service.CallAsync(
-// functionName: call.Function?.Name ?? string.Empty,
-// argumentsAsJson: call.Function?.Arguments ?? string.Empty);
-// messages.Add(json.AsToolMessage());
-// }
-//
-// response = container.ApiClient.Chat.GenerateChatCompletionAsync(
-// model,
-// messages,
-// tools: tools);
-// doneResponse = await response.WaitAsync();
-// messages.Add(resultMessage);
-//
-// Console.WriteLine(doneResponse.Message.Content);
-// }
-// finally
-// {
-// PrintMessages(messages);
-// }
-// }
+ [TestMethod]
+ public async Task Tools()
+ {
+#if DEBUG
+ await using var container = await PrepareEnvironmentAsync(EnvironmentType.Local, "llama3.1");
+#else
+ await using var container = await PrepareEnvironmentAsync(EnvironmentType.Container, "llama3.1");
+#endif
+
+ var messages = new List
+ {
+ "You are a helpful weather assistant.".AsSystemMessage(),
+ "What is the current temperature in Dubai, UAE in Celsius?".AsUserMessage(),
+ };
+ const string model = "llama3.1";
+
+ try
+ {
+ var service = new WeatherService();
+ var tools = service.AsTools();
+ var response = container.ApiClient.Chat.GenerateChatCompletionAsync(
+ model,
+ messages,
+ stream: false,
+ tools: tools);
+ var doneResponse = await response.WaitAsync();
+
+ Console.WriteLine(doneResponse.Message.Content);
+
+ var resultMessage = doneResponse.Message;
+ messages.Add(resultMessage);
+
+ if (resultMessage.ToolCalls == null ||
+ resultMessage.ToolCalls.Count == 0)
+ {
+ throw new InvalidOperationException("Expected a function call.");
+ }
+
+ foreach (var call in resultMessage.ToolCalls)
+ {
+ var json = await service.CallAsync(
+ functionName: call.Function?.Name ?? string.Empty,
+ argumentsAsJson: call.Function?.Arguments.AsJson() ?? string.Empty);
+ messages.Add(json.AsToolMessage());
+ }
+
+ response = container.ApiClient.Chat.GenerateChatCompletionAsync(
+ model,
+ messages,
+ tools: tools);
+ doneResponse = await response.WaitAsync();
+ messages.Add(resultMessage);
+
+ Console.WriteLine(doneResponse.Message.Content);
+ }
+ finally
+ {
+ PrintMessages(messages);
+ }
+ }
private static void PrintMessages(List messages)
{
@@ -247,7 +248,7 @@ private static void PrintMessages(List messages)
Console.WriteLine("Tool calls:");
foreach (var call in message.ToolCalls)
{
- Console.WriteLine($"> {call.Function?.Name}({call.Function?.Arguments})");
+ Console.WriteLine($"{call.Function?.Name}({call.Function?.Arguments.AsJson()})");
}
}
}
diff --git a/src/tests/Ollama.IntegrationTests/VariousTypesFunctions.cs b/src/tests/Ollama.IntegrationTests/VariousTypesFunctions.cs
index f3e056f..cd91d4c 100644
--- a/src/tests/Ollama.IntegrationTests/VariousTypesFunctions.cs
+++ b/src/tests/Ollama.IntegrationTests/VariousTypesFunctions.cs
@@ -3,7 +3,7 @@
namespace Ollama.IntegrationTests;
-[OllamaFunctions]
+[OllamaTools]
public interface IVariousTypesFunctions
{
[Description("Get the current weather in a given location")]
diff --git a/src/tests/Ollama.IntegrationTests/WeatherFunctions.cs b/src/tests/Ollama.IntegrationTests/WeatherFunctions.cs
index a81dea1..54eae7f 100644
--- a/src/tests/Ollama.IntegrationTests/WeatherFunctions.cs
+++ b/src/tests/Ollama.IntegrationTests/WeatherFunctions.cs
@@ -18,7 +18,7 @@ public class Weather
public string Description { get; set; } = string.Empty;
}
-[OllamaFunctions]
+[OllamaTools]
public interface IWeatherFunctions
{
[Description("Get the current weather in a given location")]
diff --git a/src/tests/Ollama.SnapshotTests/TestHelper.cs b/src/tests/Ollama.SnapshotTests/TestHelper.cs
index 3ec1b35..818c7f2 100755
--- a/src/tests/Ollama.SnapshotTests/TestHelper.cs
+++ b/src/tests/Ollama.SnapshotTests/TestHelper.cs
@@ -15,7 +15,7 @@ public static async Task CheckSourceAsync(
var referenceAssemblies = LatestReferenceAssemblies.Net80;
var references = await referenceAssemblies.ResolveAsync(null, cancellationToken);
references = references
- .Add(MetadataReference.CreateFromFile(typeof(OllamaFunctionsAttribute).Assembly.Location));
+ .Add(MetadataReference.CreateFromFile(typeof(OllamaToolsAttribute).Assembly.Location));
var compilation = (Compilation)CSharpCompilation.Create(
assemblyName: "Tests",
@@ -26,7 +26,7 @@ public static async Task CheckSourceAsync(
references: references,
options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
var driver = CSharpGeneratorDriver
- .Create(new OllamaFunctionsGenerator())
+ .Create(new OllamaToolsGenerator())
.RunGeneratorsAndUpdateCompilation(compilation, out compilation, out _, cancellationToken);
var diagnostics = compilation.GetDiagnostics(cancellationToken);