Skip to content

Commit

Permalink
Improving assistant tools (#645)
Browse files Browse the repository at this point in the history
* Improving assistant tools

* try catch on tool invocation
  • Loading branch information
javipacheco authored Jan 25, 2024
1 parent f417c9c commit 7f0afe0
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,63 @@ package com.xebia.functional.xef.llm.assistants

import com.xebia.functional.openai.apis.AssistantApi
import com.xebia.functional.openai.apis.AssistantsApi
import com.xebia.functional.openai.infrastructure.ApiClient
import com.xebia.functional.openai.models.AssistantObject
import com.xebia.functional.openai.models.CreateAssistantRequest
import com.xebia.functional.openai.models.ModifyAssistantRequest
import com.xebia.functional.openai.models.ext.assistant.AssistantTools
import com.xebia.functional.xef.llm.fromEnvironment
import io.ktor.client.statement.*
import io.ktor.util.logging.*
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive

class Assistant(
val assistantId: String,
val toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
private val assistantsApi: AssistantsApi = fromEnvironment(::AssistantsApi),
private val api: AssistantApi = fromEnvironment(::AssistantApi)
) {

constructor(
assistantObject: AssistantObject,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
assistantsApi: AssistantsApi = fromEnvironment(::AssistantsApi),
api: AssistantApi = fromEnvironment(::AssistantApi)
) : this(assistantObject.id, assistantsApi, api)
) : this(assistantObject.id, toolsConfig, assistantsApi, api)

suspend fun get(): AssistantObject = assistantsApi.getAssistant(assistantId).body()

suspend fun modify(modifyAssistantRequest: ModifyAssistantRequest): Assistant =
Assistant(api.modifyAssistant(assistantId, modifyAssistantRequest).body(), assistantsApi, api)
Assistant(
api.modifyAssistant(assistantId, modifyAssistantRequest).body(),
toolsConfig,
assistantsApi,
api
)

suspend inline fun getToolRegistered(name: String, args: String): JsonElement =
try {
val toolConfig = toolsConfig.firstOrNull { it.functionObject.name == name }

val toolSerializer = toolConfig?.serializers ?: error("Function $name not registered")
val input = ApiClient.JSON_DEFAULT.decodeFromString(toolSerializer.inputSerializer, args)

val tool: Tool<Any?, Any?> = toolConfig.tool as Tool<Any?, Any?>

val output: Any? = tool(input)
ApiClient.JSON_DEFAULT.encodeToJsonElement(
toolSerializer.outputSerializer as KSerializer<Any?>,
output
)
} catch (e: Exception) {
val message = "Error calling to tool registered $name: ${e.message}"
val logger = KtorSimpleLogger("Functions")
logger.error(message)
JsonObject(mapOf("error" to JsonPrimitive(message)))
}

companion object {

Expand All @@ -37,6 +70,7 @@ class Assistant(
tools: List<AssistantTools> = arrayListOf(),
fileIds: List<String> = arrayListOf(),
metadata: JsonObject? = null,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
assistantsApi: AssistantsApi = fromEnvironment(::AssistantsApi),
api: AssistantApi = fromEnvironment(::AssistantApi)
): Assistant =
Expand All @@ -50,17 +84,19 @@ class Assistant(
fileIds = fileIds,
metadata = metadata
),
toolsConfig,
assistantsApi,
api
)

suspend operator fun invoke(
request: CreateAssistantRequest,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
assistantsApi: AssistantsApi = fromEnvironment(::AssistantsApi),
api: AssistantApi = fromEnvironment(::AssistantApi)
): Assistant {
val response = assistantsApi.createAssistant(request)
return Assistant(response.body(), assistantsApi, api)
return Assistant(response.body(), toolsConfig, assistantsApi, api)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsMessageCre
import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsToolCallsObject
import com.xebia.functional.openai.models.ext.assistant.RunStepObjectStepDetails
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.prompt.templates.assistant
import kotlin.jvm.JvmName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
Expand Down Expand Up @@ -63,7 +64,7 @@ class AssistantThread(

suspend fun run(assistant: Assistant): Flow<RunDelta> {
val run = createRun(assistant)
return awaitRun(run.id)
return awaitRun(assistant, run.id)
}

suspend fun cancelRun(runId: String): RunObject = api.cancelRun(threadId, runId).body()
Expand All @@ -79,13 +80,13 @@ class AssistantThread(
data class Step(val runStep: RunStepObject) : RunDelta()
}

fun awaitRun(runId: String): Flow<RunDelta> = flow {
fun awaitRun(assistant: Assistant, runId: String): Flow<RunDelta> = flow {
val stepCache = mutableSetOf<RunStepObject>()
val messagesCache = mutableSetOf<MessageObject>()
val runCache = mutableSetOf<RunObject>()
var run = checkRun(runId = runId, cache = runCache)
while (run.status != RunObject.Status.completed) {
checkSteps(runId = runId, cache = stepCache)
checkSteps(assistant = assistant, runId = runId, cache = stepCache)
checkMessages(cache = messagesCache)
run = checkRun(runId = runId, cache = runCache)
}
Expand Down Expand Up @@ -124,6 +125,7 @@ class AssistantThread(
}

private suspend fun FlowCollector<RunDelta>.checkSteps(
assistant: Assistant,
runId: String,
cache: MutableSet<RunStepObject>
) {
Expand All @@ -135,7 +137,7 @@ class AssistantThread(
step.stepDetails.toolCalls().forEach { toolCall ->
val function = toolCall.function
if (function != null && function.arguments.isNotBlank()) {
val result: JsonElement = Tool(function.name, function.arguments)
val result: JsonElement = assistant.getToolRegistered(function.name, function.arguments)
api.submitToolOuputsToRun(
threadId = threadId,
runId = runId,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,44 +1,35 @@
package com.xebia.functional.xef.llm.assistants

import com.xebia.functional.openai.infrastructure.ApiClient
import com.xebia.functional.openai.models.FunctionObject
import com.xebia.functional.xef.llm.chatFunction
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.serializer

fun interface Tool<out Output> {
suspend operator fun invoke(): Output
fun interface Tool<Input, out Output> {
suspend operator fun invoke(input: Input): Output

companion object {

data class ToolConfig<Input, out Output>(
val functionObject: FunctionObject,
val serializers: ToolSerializer,
val tool: Tool<Input, Output>
)

data class ToolSerializer(
val inputSerializer: KSerializer<*>,
val outputSerializer: KSerializer<*>
)

@PublishedApi internal val toolRegistry = mutableMapOf<String, ToolSerializer>()

inline operator fun <reified T : Tool<O>, reified O> invoke(): FunctionObject {
val serializer = serializer<T>()
inline fun <reified I, reified O> toolOf(tool: Tool<I, O>): ToolConfig<I, O> {
val serializer = serializer<I>()
val outputSerializer = serializer<O>()
val toolSerializer = ToolSerializer(serializer, outputSerializer)
val fn = chatFunction(serializer.descriptor)
if (toolRegistry.containsKey(fn.name)) {
error("Function ${fn.name} already registered")
}
toolRegistry[fn.name] = toolSerializer
return fn
}

suspend inline operator fun invoke(name: String, args: String): JsonElement {
val toolSerializer = toolRegistry[name] ?: error("Function $name not registered")
val input =
ApiClient.JSON_DEFAULT.decodeFromString(toolSerializer.inputSerializer, args) as Tool<Any?>
val output: Any? = input.invoke()
return ApiClient.JSON_DEFAULT.encodeToJsonElement(
toolSerializer.outputSerializer as KSerializer<Any?>,
output
return ToolConfig(
fn.copy(name = tool::class.simpleName ?: error("unnamed class")),
toolSerializer,
tool
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@ import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsToolCallsO
import com.xebia.functional.xef.llm.assistants.Assistant
import com.xebia.functional.xef.llm.assistants.AssistantThread
import com.xebia.functional.xef.llm.assistants.Tool
import io.ktor.client.*
import kotlinx.serialization.Serializable

@Serializable
data class SumTool(
val left: Int,
val right: Int,
) : Tool<Int> {
override suspend fun invoke(): Int = left + right
@Serializable data class SumInput(val left: Int, val right: Int)

class SumTool : Tool<SumInput, Int> {
override suspend fun invoke(input: SumInput): Int {
return input.left + input.right
}
}

suspend fun main() {

Tool<SumTool, Int>() // register the tool, Int is the output type

// val assistant2 = Assistant(
// name = "Math Tutor",
// instructions = "You help the user with all kinds of math problems.",
Expand All @@ -38,7 +37,11 @@ suspend fun main() {
// model = "gpt-4-1106-preview"
// )
// println("generated assistant: ${assistant2.assistantId}")
val assistant = Assistant(assistantId = "asst_mYw6e4wddJvRcjdQQ2qcWFsn")
val assistant =
Assistant(
assistantId = "asst_UxczzpJkysC0l424ood87DAk",
toolsConfig = listOf(Tool.toolOf(SumTool())),
)
val thread = AssistantThread()
println("Welcome to the Math tutor, ask me anything about math:")
while (true) {
Expand Down
2 changes: 1 addition & 1 deletion examples/src/main/resources/logback.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
</encoder>
</appender>

<root level="off">
<root level="error">
<appender-ref ref="STDOUT"/>
</root>

Expand Down

0 comments on commit 7f0afe0

Please sign in to comment.