Skip to content

Commit

Permalink
Recent changes in cost calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
morisil committed Dec 13, 2024
1 parent 0ce0976 commit 4b527ce
Show file tree
Hide file tree
Showing 10 changed files with 444 additions and 182 deletions.
11 changes: 6 additions & 5 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,15 @@ kotlin {
implementation(libs.ktor.client.logging)
implementation(libs.ktor.serialization.kotlinx.json)
implementation(libs.xemantic.ai.tool.schema)
api(libs.xemantic.ai.money)
}
}

commonTest {
dependencies {
implementation(libs.kotlin.test)
implementation(libs.kotlinx.coroutines.test)
implementation(libs.kotest.assertions.core)
implementation(libs.xemantic.kotlin.test)
implementation(libs.kotest.assertions.json)
}
}
Expand Down Expand Up @@ -242,10 +243,10 @@ tasks.withType<Test> {
}

powerAssert {
// functions = listOf(
// "io.kotest.matchers.shouldBe"
// )
// includedSourceSets = listOf("commonTest", "jvmTest", "nativeTest")
functions = listOf(
"com.xemantic.kotlin.test.assert",
"com.xemantic.kotlin.test.have"
)
}

// maybe this one is not necessary?
Expand Down
32 changes: 19 additions & 13 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.xemantic.anthropic.tool.Tool
import com.xemantic.anthropic.tool.ToolInput
import com.xemantic.anthropic.usage.Cost
import com.xemantic.anthropic.usage.Usage
import com.xemantic.anthropic.usage.UsageCollector
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.*
Expand All @@ -28,8 +29,6 @@ import io.ktor.http.HttpMethod
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.update
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.flow
Expand Down Expand Up @@ -93,7 +92,7 @@ class Anthropic internal constructor(
var anthropicVersion: String = DEFAULT_ANTHROPIC_VERSION
var anthropicBeta: String? = null
var apiBase: String = ANTHROPIC_API_BASE
var defaultModel: Model = Model.DEFAULT
var defaultModel: AnthropicModel = Model.DEFAULT
var defaultMaxTokens: Int = defaultModel.maxOutput

var directBrowserAccess: Boolean = false
Expand Down Expand Up @@ -184,6 +183,7 @@ class Anthropic internal constructor(
val response = apiResponse.body<Response>()
when (response) {
is MessageResponse -> response.apply {
updateUsage(response)
content.filterIsInstance<ToolUse>()
.forEach { toolUse ->
val tool = toolMap[toolUse.name]
Expand All @@ -195,13 +195,14 @@ class Anthropic internal constructor(
println("Error!!! Unexpected tool use: ${toolUse.name}")
}
}
updateTotals()
}
is ErrorResponse -> throw AnthropicException(
error = response.error,
httpStatusCode = apiResponse.status
)
else -> throw RuntimeException("Unsupported response: $response") // should never happen
else -> throw RuntimeException(
"Unsupported response: $response"
) // should never happen
}
return response
}
Expand Down Expand Up @@ -234,7 +235,8 @@ class Anthropic internal constructor(
.collect { event ->
// TODO we need a better way of handling subsequent deltas that carry usage
if (event is Event.MessageStart) {
event.message.updateTotals()
// TODO more rules are needed here
updateUsage(event.message)
}
emit(event)
}
Expand All @@ -245,21 +247,25 @@ class Anthropic internal constructor(

val messages = Messages()

private val _totalUsage = atomic(Usage.ZERO)
val totalUsage: Usage get() = _totalUsage.value
private val usageCollector = UsageCollector()

private val _totalCost = atomic(Cost.ZERO)
val totalCost: Cost get() = _totalCost.value
val usage: Usage get() = usageCollector.usage

val cost: Cost get() = usageCollector.cost

override fun toString(): String = "Anthropic($usage, $cost)"

/**
 * Resolves the [AnthropicModel] registered in `modelMap` under the `model`
 * of this response.
 *
 * @throws IllegalArgumentException if `modelMap` has no entry for the model.
 */
private val MessageResponse.anthropicModel: AnthropicModel get() = requireNotNull(
    modelMap[model]
) {
    // Report the unrecognized model itself; the response id is kept for traceability.
    "The model returned in the response is not known to Anthropic API client: $model (response id: $id)"
}

private fun MessageResponse.updateTotals() {
_totalUsage.update { it + usage }
_totalCost.update { it + (usage.cost(anthropicModel) / Model.PRICE_UNIT) }
/**
 * Feeds the usage reported by [response] into the usage collector,
 * priced with the cost of the model that produced the response.
 */
private fun updateUsage(response: MessageResponse) {
    // Resolve the model cost first, matching the original argument evaluation order.
    val costOfModel = response.anthropicModel.cost
    usageCollector.update(usage = response.usage, modelCost = costOfModel)
}

}
10 changes: 10 additions & 0 deletions src/commonMain/kotlin/AnthropicJson.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.buildSerialDescriptor
import kotlinx.serialization.encodeToString
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -65,6 +66,15 @@ val anthropicJson: Json = Json {
encodeDefaults = true
}

/**
 * A [Json] configuration derived from [anthropicJson] with pretty printing enabled.
 *
 * Marked [PublishedApi] so that public inline functions can reference it.
 */
@PublishedApi
@OptIn(ExperimentalSerializationApi::class)
internal val prettyAnthropicJson: Json = Json(anthropicJson) {
    prettyPrintIndent = " "
    prettyPrint = true
}

/**
 * Encodes the receiver as a pretty-printed JSON string using [prettyAnthropicJson].
 */
inline fun <reified T> T.toPrettyJson(): String {
    return prettyAnthropicJson.encodeToString<T>(this)
}

private object ResponseSerializer : JsonContentPolymorphicSerializer<Response>(
baseClass = Response::class
) {
Expand Down
59 changes: 30 additions & 29 deletions src/commonMain/kotlin/usage/Usage.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package com.xemantic.anthropic.usage

import com.xemantic.anthropic.AnthropicModel
import com.xemantic.anthropic.Model
import com.xemantic.ai.money.Money
import com.xemantic.ai.money.ONE
import com.xemantic.ai.money.Ratio
import com.xemantic.ai.money.times
import com.xemantic.ai.money.ZERO
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

Expand Down Expand Up @@ -31,28 +34,29 @@ data class Usage(
operator fun plus(usage: Usage): Usage = Usage(
inputTokens = inputTokens + usage.inputTokens,
outputTokens = outputTokens + usage.outputTokens,
cacheReadInputTokens = (cacheReadInputTokens ?: 0) + (usage.cacheReadInputTokens ?: 0),
cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) + (usage.cacheCreationInputTokens ?: 0),
cacheReadInputTokens = (cacheReadInputTokens ?: 0) + (usage.cacheReadInputTokens ?: 0)
)

fun cost(
model: AnthropicModel,
isBatch: Boolean = false
modelCost: Cost,
costRatio: Money.Ratio = Money.Ratio.ONE
): Cost = Cost(
inputTokens = inputTokens * model.cost.inputTokens / Model.PRICE_UNIT,
outputTokens = outputTokens * model.cost.outputTokens / Model.PRICE_UNIT,
cacheReadInputTokens = (cacheReadInputTokens ?: 0) / Model.PRICE_UNIT,
cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) / Model.PRICE_UNIT
).let { if (isBatch) it * .5 else it }
inputTokens = inputTokens * modelCost.inputTokens * costRatio,
outputTokens = outputTokens * modelCost.outputTokens * costRatio,
// TODO: how do cache creation pricing and the batch cost ratio interact?
cacheCreationInputTokens = (cacheCreationInputTokens ?: 0) * modelCost.cacheCreationInputTokens * costRatio,
cacheReadInputTokens = (cacheReadInputTokens ?: 0) * modelCost.cacheReadInputTokens * costRatio
)

}

@Serializable
data class Cost(
val inputTokens: Double,
val outputTokens: Double,
val cacheCreationInputTokens: Double = inputTokens * .25,
val cacheReadInputTokens: Double = inputTokens * .25
val inputTokens: Money,
val outputTokens: Money,
val cacheCreationInputTokens: Money = inputTokens * Money.Ratio("1.25"),
val cacheReadInputTokens: Money = inputTokens * Money.Ratio("0.1"),
) {

operator fun plus(cost: Cost): Cost = Cost(
Expand All @@ -62,26 +66,23 @@ data class Cost(
cacheReadInputTokens = cacheReadInputTokens + cost.cacheReadInputTokens
)

operator fun times(value: Double): Cost = Cost(
inputTokens = inputTokens * value,
outputTokens = outputTokens * value,
cacheCreationInputTokens = cacheCreationInputTokens * value,
cacheReadInputTokens = cacheReadInputTokens * value
)

operator fun div(value: Double): Cost = Cost(
inputTokens = inputTokens / value,
outputTokens = outputTokens / value,
cacheCreationInputTokens = cacheCreationInputTokens / value,
cacheReadInputTokens = cacheReadInputTokens / value
operator fun times(amount: Money): Cost = Cost(
inputTokens = inputTokens * amount,
outputTokens = outputTokens * amount,
cacheCreationInputTokens = cacheCreationInputTokens * amount,
cacheReadInputTokens = cacheReadInputTokens * amount
)

val total: Double get() = inputTokens + outputTokens + cacheCreationInputTokens + cacheReadInputTokens
val total: Money get() =
inputTokens +
outputTokens +
cacheCreationInputTokens +
cacheReadInputTokens

companion object {
val ZERO = Cost(
inputTokens = 0.0,
outputTokens = 0.0
inputTokens = Money.ZERO,
outputTokens = Money.ZERO
)
}

Expand Down
54 changes: 54 additions & 0 deletions src/commonMain/kotlin/usage/UsageCollector.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.xemantic.anthropic.usage

import com.xemantic.ai.money.Money
import com.xemantic.ai.money.ONE
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.update

/**
 * Collects overall [Usage] and calculates [Cost] information
 * based on [com.xemantic.anthropic.message.MessageResponse]s returned
 * by API calls.
 *
 * Usage and cost are kept in two separate atomic references. Each is updated
 * atomically on its own, so a reader may briefly observe an already updated
 * [usage] together with a not yet updated [cost] while [update] is running.
 */
class UsageCollector {

    // Atomic in case of several threads updating this data concurrently
    private val _usage = atomic(Usage.ZERO)

    /**
     * The current accumulated usage.
     */
    val usage: Usage get() = _usage.value

    // Atomic in case of several threads updating this data concurrently
    private val _cost = atomic(Cost.ZERO)

    /**
     * The current accumulated cost.
     */
    val cost: Cost get() = _cost.value

    /**
     * Updates the usage and cost based on the provided parameters.
     *
     * @param usage The usage to add.
     * @param modelCost The cost of the used model.
     * @param costRatio The cost ratio to apply, defaults to 1, but might be different for batch requests, etc.
     */
    fun update(
        usage: Usage,
        modelCost: Cost,
        costRatio: Money.Ratio = Money.Ratio.ONE,
    ) {
        // Compute the cost once, before touching any state: the lambda passed to
        // `update` may be re-executed under contention, and pricing up front means
        // a failure in the cost calculation cannot leave usage updated without cost.
        val addedCost = usage.cost(modelCost, costRatio)
        _usage.update { it + usage }
        _cost.update { it + addedCost }
    }

    /**
     * Returns a string representation of the UsageCollector.
     *
     * @return A string containing the current usage and cost.
     */
    override fun toString(): String = "UsageCollector(usage=$usage, cost=$cost)"

}
Loading

0 comments on commit 4b527ce

Please sign in to comment.