Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

major refactoring to add computer use and batches with proper polymorphic serialization #14

Merged
merged 9 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build-branch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ jobs:
run: sudo apt-get install -y libcurl4-gnutls-dev

- name: Build
run: ./gradlew build
run: ./gradlew -PjvmOnlyBuild=false build
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
2 changes: 1 addition & 1 deletion .github/workflows/build-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
run: sudo apt-get install -y libcurl4-gnutls-dev

- name: Build
run: ./gradlew build sourcesJar dokkaHtml publish
run: ./gradlew -PjvmOnlyBuild=false build sourcesJar dokkaHtml publish
env:
ORG_GRADLE_PROJECT_githubActor: ${{ secrets.GITHUBACTOR }}
ORG_GRADLE_PROJECT_githubToken: ${{ secrets.GITHUBTOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
ORG_GRADLE_PROJECT_sonatypeUser: ${{ secrets.SONATYPE_USER }}
ORG_GRADLE_PROJECT_sonatypePassword: ${{ secrets.SONATYPE_PASSWORD }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
run: ./gradlew -Pversion=$VERSION build sourcesJar dokkaHtml publishToSonatype closeAndReleaseSonatypeStagingRepository
run: ./gradlew -Pversion=$VERSION -PjvmOnlyBuild=false build sourcesJar dokkaHtml publishToSonatype closeAndReleaseSonatypeStagingRepository

- name: Find branch from tag
id: find-branch
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ bin/
.DS_Store

/*.hprof

/kotlin-js-store
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ and many other environments.

## Usage

[!CAUTION]
> [!CAUTION]
> This SDK is in the early stage of development, so still a subject to API changes,
> however at the same time it is completely functional and passing all the
> [test cases](src/commonTest/kotlin).
Expand Down Expand Up @@ -78,7 +78,7 @@ dependencies {
}
```

, ff you are planning to use tools, you will also need:
, if you are planning to use tools, you will also need:

```kotlin
plugins {
Expand Down Expand Up @@ -132,7 +132,7 @@ If you want to write AI agents, you need tools, and this is where this library s
```kotlin
@AnthropicTool("get_weather")
@Description("Get the weather for a specific location")
data class WeatherTool(val location: String): UsableTool {
data class WeatherTool(val location: String): ToolInput {
override fun use(
toolUseId: String
) = ToolResult(
Expand All @@ -152,7 +152,7 @@ fun main() = runBlocking {

val initialResponse = client.messages.create {
messages = conversation
useTools()
allTools()
}
println("Initial response:")
println(initialResponse)
Expand Down Expand Up @@ -192,7 +192,7 @@ internet or DB connection pool to access the database.
```kotlin
@AnthropicTool("query_database")
@Description("Executes SQL on the database")
data class DatabaseQueryTool(val sql: String): UsableTool {
data class QueryDatabase(val sql: String): ToolInput {

@Transient
internal lateinit var connection: Connection
Expand All @@ -213,14 +213,14 @@ data class DatabaseQueryTool(val sql: String): UsableTool {
fun main() = runBlocking {

val client = Anthropic {
tool<DatabaseQueryTool> {
tool<QueryDatabase> {
connection = DriverManager.getConnection("jdbc:...")
}
}

val response = client.messages.create {
+Message { +"Select all the users who never logged in to the the system" }
useTools()
singleTool<QueryDatabase>()
}

val tool = response.content.filterIsInstance<ToolUse>().first()
Expand Down
52 changes: 39 additions & 13 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.gradle.api.tasks.testing.logging.TestLogEvent
import org.jetbrains.kotlin.gradle.ExperimentalKotlinGradlePluginApi
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
import org.jetbrains.kotlin.gradle.dsl.KotlinVersion
import org.jetbrains.kotlin.gradle.targets.js.testing.KotlinJsTest
import org.jetbrains.kotlin.gradle.targets.native.tasks.KotlinNativeTest

plugins {
Expand All @@ -25,6 +26,8 @@ val javaTarget = libs.versions.javaTarget.get()
val kotlinTarget = KotlinVersion.fromVersion(libs.versions.kotlinTarget.get())

val isReleaseBuild = !project.version.toString().endsWith("-SNAPSHOT")
val jvmOnlyBuild: String? by project
val isJvmOnlyBuild: Boolean = (jvmOnlyBuild == null) || (jvmOnlyBuild!!.uppercase() == "true")
val githubActor: String? by project
val githubToken: String? by project
val signingKey: String? by project
Expand Down Expand Up @@ -64,6 +67,14 @@ kotlin {
}
}

if (!isJvmOnlyBuild) {

js {
browser()
nodejs()
binaries.library()
}

// linuxX64()
//
// mingwX64()
Expand All @@ -78,6 +89,8 @@ kotlin {
// else -> throw GradleException("Host OS is not supported in Kotlin/Native.")
// }

}

sourceSets {

commonMain {
Expand Down Expand Up @@ -109,21 +122,23 @@ kotlin {
}
}

linuxTest {
dependencies {
implementation(libs.ktor.client.curl)
if (!isJvmOnlyBuild) {
linuxTest {
dependencies {
implementation(libs.ktor.client.curl)
}
}
}

mingwTest {
dependencies {
implementation(libs.ktor.client.curl)
mingwTest {
dependencies {
implementation(libs.ktor.client.curl)
}
}
}

macosTest {
dependencies {
implementation(libs.ktor.client.darwin)
macosTest {
dependencies {
implementation(libs.ktor.client.darwin)
}
}
}

Expand Down Expand Up @@ -157,8 +172,19 @@ tasks.withType<Test> {
enabled = !skipTests
}

tasks.withType<KotlinNativeTest> {
enabled = !skipTests


if (!isJvmOnlyBuild) {

tasks.withType<KotlinNativeTest> {
enabled = !skipTests
}

tasks.withType<KotlinJsTest> {
// for now always skip JS tests, until we will find how to safely pass apiKey to them
enabled = false
}

}

powerAssert {
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ kotlin.code.style=official
kotlin.js.generate.executable.default=false
kotlin.native.ignoreDisabledTargets=true
group=com.xemantic.anthropic
version=0.5-SNAPSHOT
version=0.7-SNAPSHOT
6 changes: 2 additions & 4 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ javaTarget = "17"
kotlin = "2.0.21"
kotlinxCoroutines = "1.9.0"
kotlinxDatetime = "0.6.1"
ktor = "3.0.0"
ktor = "3.0.1"
kotest = "6.0.0.M1"

# logging is not used at the moment, might be enabled later
#kotlinLogging = "7.0.0"
log4j = "2.24.1"
jackson = "2.18.0"
jackson = "2.18.1"

versionsPlugin = "0.51.0"
dokkaPlugin = "1.9.20"
Expand All @@ -23,7 +22,6 @@ kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-t
kotlinx-datetime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "kotlinxDatetime" }

# logging libs
#kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "kotlinLogging" }
log4j-slf4j2 = { module = "org.apache.logging.log4j:log4j-slf4j2-impl", version.ref = "log4j" }
log4j-core = { module = "org.apache.logging.log4j:log4j-core", version.ref = "log4j" }
jackson-databind = { module = "com.fasterxml.jackson.core:jackson-databind", version.ref = "jackson" }
Expand Down
80 changes: 30 additions & 50 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package com.xemantic.anthropic

import com.xemantic.anthropic.error.AnthropicException
import com.xemantic.anthropic.error.ErrorResponse
import com.xemantic.anthropic.event.Event
import com.xemantic.anthropic.cache.CacheControl
import com.xemantic.anthropic.message.MessageRequest
import com.xemantic.anthropic.message.Tool
import com.xemantic.anthropic.message.ToolUse
import com.xemantic.anthropic.tool.UsableTool
import com.xemantic.anthropic.tool.toolOf
import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.tool.BuiltInTool
import com.xemantic.anthropic.tool.ToolUse
import com.xemantic.anthropic.tool.Tool
import com.xemantic.anthropic.tool.ToolInput
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.*
Expand All @@ -26,11 +30,6 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.Json
import kotlinx.serialization.serializer
import kotlin.reflect.KType
import kotlin.reflect.typeOf

/**
* The default Anthropic API base.
Expand All @@ -42,26 +41,10 @@ const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/"
*/
const val DEFAULT_ANTHROPIC_VERSION: String = "2023-06-01"

/**
* An exception thrown when API requests returns error.
*/
class AnthropicException(
error: Error,
httpStatusCode: HttpStatusCode
) : RuntimeException(error.toString())

expect val envApiKey: String?

expect val missingApiKeyMessage: String

/**
* A JSON format suitable for communication with Anthropic API.
*/
val anthropicJson: Json = Json {
allowSpecialFloatingPointValues = true
explicitNulls = false
encodeDefaults = true
}

/**
* The public constructor function which for the Anthropic API client.
Expand All @@ -82,10 +65,9 @@ fun Anthropic(
defaultModel = config.defaultModel.id,
defaultMaxTokens = config.defaultMaxTokens,
directBrowserAccess = config.directBrowserAccess,
logLevel = if (config.logHttp) LogLevel.ALL else LogLevel.NONE
).apply {
toolEntryMap = (config.usableTools as List<Anthropic.ToolEntry<UsableTool>>).associateBy { it.tool.name }
}
logLevel = if (config.logHttp) LogLevel.ALL else LogLevel.NONE,
toolMap = config.tools.associateBy { it.name }
)
} // TODO this can be a second constructor, then toolMap can be private

class Anthropic internal constructor(
Expand All @@ -96,7 +78,8 @@ class Anthropic internal constructor(
val defaultModel: String,
val defaultMaxTokens: Int,
val directBrowserAccess: Boolean,
val logLevel: LogLevel
val logLevel: LogLevel,
private val toolMap: Map<String, Tool>
) {

class Config {
Expand All @@ -110,27 +93,25 @@ class Anthropic internal constructor(
var directBrowserAccess: Boolean = false
var logHttp: Boolean = false

@PublishedApi
internal var usableTools: List<ToolEntry<out UsableTool>> = emptyList()
var tools: List<Tool> = emptyList()

inline fun <reified T : UsableTool> tool(
noinline block: T.() -> Unit = {}
inline fun <reified T : ToolInput> tool(
cacheControl: CacheControl? = null,
noinline inputInitializer: T.() -> Unit = {}
) {
val entry = ToolEntry(typeOf<T>(), toolOf<T>(), serializer<T>(), block)
usableTools += entry
tools += Tool<T>(cacheControl, initializer = inputInitializer)
}

}

@PublishedApi
internal class ToolEntry<T : UsableTool>(
val type: KType,
val tool: Tool, // TODO, no cache control
val serializer: KSerializer<T>,
val initialize: T.() -> Unit = {}
)
inline fun <reified T : BuiltInTool> builtInTool(
tool: T,
noinline inputInitializer: T.() -> Unit = {}
) {
@Suppress("UNCHECKED_CAST")
tool.inputInitializer = inputInitializer as ToolInput.() -> Unit
tools += tool
}

internal var toolEntryMap = mapOf<String, ToolEntry<UsableTool>>()
}

private val client = HttpClient {

Expand Down Expand Up @@ -179,7 +160,7 @@ class Anthropic internal constructor(
val request = MessageRequest.Builder(
defaultModel,
defaultMaxTokens,
toolEntryMap
toolMap
).apply(block).build()

val apiResponse = client.post("/v1/messages") {
Expand All @@ -191,8 +172,7 @@ class Anthropic internal constructor(
is MessageResponse -> response.apply {
content.filterIsInstance<ToolUse>()
.forEach { toolUse ->
val entry = toolEntryMap[toolUse.name]!!
toolUse.toolEntry = entry
toolUse.tool = toolMap[toolUse.name]!!
}
}
is ErrorResponse -> throw AnthropicException(
Expand All @@ -211,7 +191,7 @@ class Anthropic internal constructor(
val request = MessageRequest.Builder(
defaultModel,
defaultMaxTokens,
toolEntryMap
toolMap
).apply {
block(this)
stream = true
Expand Down
Loading