diff --git a/README.md b/README.md index 1ce1790..a85eb90 100644 --- a/README.md +++ b/README.md @@ -144,7 +144,7 @@ fun main() = runBlocking { println("Initial response:") println(initialResponse) - conversation += initialResponse.asMessage() + conversation += initialResponse val tool = initialResponse.content.filterIsInstance().first() val toolResult = tool.use() conversation += Message { +toolResult } diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index bc46cdd..7777897 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -230,7 +230,7 @@ data class System( @Serializable data class Tool( val name: String, - val description: String, + val description: String?, @SerialName("input_schema") val inputSchema: JsonSchema, @SerialName("cache_control") @@ -354,6 +354,21 @@ fun ToolResult( content = listOf(Text(text)) ) +inline fun ToolResult( + toolUseId: String, + value: T +): ToolResult = ToolResult( + toolUseId, + content = listOf( + Text( + anthropicJson.encodeToString( + serializer = serializer(), + value = value + ) + ) + ) +) + @Serializable data class CacheControl( val type: Type @@ -409,3 +424,9 @@ data class Usage( @SerialName("output_tokens") val outputTokens: Int ) + +operator fun MutableCollection.plusAssign( + response: MessageResponse +) { + this += response.asMessage() +} diff --git a/src/commonMain/kotlin/schema/JsonSchema.kt b/src/commonMain/kotlin/schema/JsonSchema.kt index 2c15f6b..d6ef32b 100644 --- a/src/commonMain/kotlin/schema/JsonSchema.kt +++ b/src/commonMain/kotlin/schema/JsonSchema.kt @@ -1,8 +1,17 @@ package com.xemantic.anthropic.schema +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.MetaSerializable import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +@OptIn(ExperimentalSerializationApi::class) +@Target(AnnotationTarget.PROPERTY) +@MetaSerializable +annotation class Description( + val value: String +) + @Serializable data class JsonSchema( val type: String = "object", @@ -16,17 +25,9 @@ data class JsonSchema( @Serializable data class JsonSchemaProperty( val type: String? = null, + val description: String? = null, val items: JsonSchemaProperty? = null, val enum: List? = null, @SerialName("\$ref") val ref: String? = null -) { - - companion object { - val STRING = JsonSchemaProperty("string") - val INTEGER = JsonSchemaProperty("integer") - val NUMBER = JsonSchemaProperty("number") - val BOOLEAN = JsonSchemaProperty("boolean") - } - -} \ No newline at end of file +) diff --git a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt index ad418cd..e44ed12 100644 --- a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt +++ b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt @@ -17,7 +17,15 @@ fun generateSchema(descriptor: SerialDescriptor): JsonSchema { for (i in 0 until descriptor.elementsCount) { val name = descriptor.getElementName(i) val elementDescriptor = descriptor.getElementDescriptor(i) - val property = generateSchemaProperty(elementDescriptor, definitions) + val elementAnnotations = descriptor.getElementAnnotations(i) + val property = generateSchemaProperty( + elementDescriptor, + description = elementAnnotations + .filterIsInstance() + .firstOrNull() + ?.value, + definitions + ) properties[name] = property if (!descriptor.isElementOptional(i)) { required.add(name) @@ -35,38 +43,45 @@ fun generateSchema(descriptor: SerialDescriptor): JsonSchema { @OptIn(ExperimentalSerializationApi::class) private fun generateSchemaProperty( descriptor: SerialDescriptor, + description: String?, definitions: MutableMap ): JsonSchemaProperty { return when (descriptor.kind) { - PrimitiveKind.STRING -> JsonSchemaProperty.STRING - PrimitiveKind.INT, PrimitiveKind.LONG -> JsonSchemaProperty.INTEGER - PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> JsonSchemaProperty.NUMBER - PrimitiveKind.BOOLEAN -> JsonSchemaProperty.BOOLEAN - SerialKind.ENUM -> enumProperty(descriptor) + PrimitiveKind.STRING -> JsonSchemaProperty("string", description) + PrimitiveKind.INT, PrimitiveKind.LONG -> JsonSchemaProperty("integer", description) + PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> JsonSchemaProperty("number", description) + PrimitiveKind.BOOLEAN -> JsonSchemaProperty("boolean", description) + SerialKind.ENUM -> enumProperty(descriptor, description) StructureKind.LIST -> JsonSchemaProperty( type = "array", items = generateSchemaProperty( descriptor.getElementDescriptor(0), + description, definitions ) ) - StructureKind.MAP -> JsonSchemaProperty("object") + StructureKind.MAP -> JsonSchemaProperty("object", description) StructureKind.CLASS -> { // dots are not allowed in JSON Schema name, if the @SerialName was not // specified, then fully qualified class name will be used, and we need // to translate it val refName = descriptor.serialName.replace('.', '_').trimEnd('?') definitions[refName] = generateSchema(descriptor) - JsonSchemaProperty(ref = "#/definitions/$refName") + JsonSchemaProperty( + ref = "#/definitions/$refName", + description = description + ) } - else -> JsonSchemaProperty("object") // Default case + else -> JsonSchemaProperty("object", description) // Default case } } private fun enumProperty( - descriptor: SerialDescriptor -) = JsonSchemaProperty( - enum = descriptor.elementNames() + descriptor: SerialDescriptor, + description: String? +) = JsonSchemaProperty( // TODO should it return type enum? + enum = descriptor.elementNames(), + description = description, ) @OptIn(ExperimentalSerializationApi::class) diff --git a/src/commonMain/kotlin/tool/Tools.kt b/src/commonMain/kotlin/tool/Tools.kt index 649f0af..5fafc58 100644 --- a/src/commonMain/kotlin/tool/Tools.kt +++ b/src/commonMain/kotlin/tool/Tools.kt @@ -23,7 +23,7 @@ import kotlinx.serialization.serializer @Target(AnnotationTarget.CLASS) annotation class AnthropicTool( val name: String, - val description: String + val description: String = "" ) /** @@ -80,7 +80,8 @@ inline fun toolOf( return Tool( name = anthropicTool.name, - description = anthropicTool.description, + // annotation description cannot be null, so we allow empty and detect it here + description = if (anthropicTool.description.isNotBlank()) anthropicTool.description else null, inputSchema = jsonSchemaOf(), cacheControl = cacheControl ) diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt index b7f8f24..5dfb438 100644 --- a/src/commonTest/kotlin/AnthropicTest.kt +++ b/src/commonTest/kotlin/AnthropicTest.kt @@ -9,6 +9,7 @@ import com.xemantic.anthropic.message.Role import com.xemantic.anthropic.message.StopReason import com.xemantic.anthropic.message.Text import com.xemantic.anthropic.message.ToolUse +import com.xemantic.anthropic.message.plusAssign import com.xemantic.anthropic.test.Calculator import com.xemantic.anthropic.test.DatabaseQueryTool import com.xemantic.anthropic.test.FibonacciTool @@ -117,7 +118,7 @@ class AnthropicTest { messages = conversation useTools() } - conversation += initialResponse.asMessage() + conversation += initialResponse // then assertSoftly(initialResponse) { @@ -189,7 +190,7 @@ class AnthropicTest { messages = conversation useTools() } - conversation += fibonacciResponse.asMessage() + conversation += fibonacciResponse val fibonacciToolUse = fibonacciResponse.content.filterIsInstance().first() fibonacciToolUse.name shouldBe "FibonacciTool" @@ -200,7 +201,7 @@ class AnthropicTest { messages = conversation useTools() } - conversation += calculatorResponse.asMessage() + conversation += calculatorResponse val calculatorToolUse = calculatorResponse.content.filterIsInstance().first() calculatorToolUse.name shouldBe "Calculator" diff --git a/src/commonTest/kotlin/message/ToolResultTest.kt b/src/commonTest/kotlin/message/ToolResultTest.kt new file mode 100644 index 0000000..535e742 --- /dev/null +++ b/src/commonTest/kotlin/message/ToolResultTest.kt @@ -0,0 +1,34 @@ +package com.xemantic.anthropic.message + +import io.kotest.matchers.shouldBe +import kotlinx.serialization.Serializable +import kotlin.test.Test + +class ToolResultTest { + + @Test + fun shouldCreateToolResultForSingleString() { + ToolResult( + toolUseId = "42", + "foo" + ) shouldBe ToolResult( + toolUseId = "42", + content = listOf(Text(text = "foo")) + ) + } + + @Serializable + data class Foo(val bar: String) + + @Test + fun shouldCreateToolResultForSerializableInstance() { + ToolResult( + toolUseId = "42", + Foo("buzz") + ) shouldBe ToolResult( + toolUseId = "42", + content = listOf(Text(text = "{\"bar\":\"buzz\"}")) + ) + } + +} diff --git a/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt b/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt index 411e56e..be0071e 100644 --- a/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt +++ b/src/commonTest/kotlin/schema/JsonSchemaGeneratorTest.kt @@ -20,6 +20,7 @@ data class Address( @Serializable data class Person( + @Description("The official name") val name: String, val age: Int, val email: String?, @@ -102,7 +103,8 @@ class JsonSchemaGeneratorTest { }, "properties": { "name": { - "type": "string" + "type": "string", + "description": "The official name" }, "age": { "type": "integer" diff --git a/src/commonTest/kotlin/tool/UsableToolTest.kt b/src/commonTest/kotlin/tool/UsableToolTest.kt index cd1b6d4..4c32b0c 100644 --- a/src/commonTest/kotlin/tool/UsableToolTest.kt +++ b/src/commonTest/kotlin/tool/UsableToolTest.kt @@ -2,6 +2,7 @@ package com.xemantic.anthropic.tool import com.xemantic.anthropic.message.CacheControl import com.xemantic.anthropic.message.ToolResult +import com.xemantic.anthropic.schema.Description import com.xemantic.anthropic.schema.JsonSchema import com.xemantic.anthropic.schema.JsonSchemaProperty import io.kotest.assertions.assertSoftly @@ -18,6 +19,7 @@ class UsableToolTest { description = "Test tool receiving a message and outputting it back" ) class TestTool( + @Description("the message") val message: String ) : UsableTool { override suspend fun use( @@ -34,15 +36,19 @@ class UsableToolTest { name shouldBe "TestTool" description shouldBe "Test tool receiving a message and outputting it back" inputSchema shouldBe JsonSchema( - properties = mapOf("message" to JsonSchemaProperty.STRING), + properties = mapOf("message" to JsonSchemaProperty( + type = "string", + description = "the message" + )), required = listOf("message") ) cacheControl shouldBe null } } + // TODO maybe we need a builder here? @Test - fun shouldCreateToolWithCacheControlFromUsableTool() { + fun shouldCreateToolWithCacheControlFromUsableToolSuppliedWithCacheControl() { // when val tool = toolOf( cacheControl = CacheControl(type = CacheControl.Type.EPHEMERAL) @@ -52,7 +58,10 @@ class UsableToolTest { name shouldBe "TestTool" description shouldBe "Test tool receiving a message and outputting it back" inputSchema shouldBe JsonSchema( - properties = mapOf("message" to JsonSchemaProperty.STRING), + properties = mapOf("message" to JsonSchemaProperty( + type = "string", + description = "the message" + )), required = listOf("message") ) cacheControl shouldBe CacheControl(type = CacheControl.Type.EPHEMERAL)