diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt index b1a747660..72d51f3d4 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt @@ -2,30 +2,63 @@ package com.xebia.functional.xef.llm.assistants import com.xebia.functional.openai.apis.AssistantApi import com.xebia.functional.openai.apis.AssistantsApi +import com.xebia.functional.openai.infrastructure.ApiClient import com.xebia.functional.openai.models.AssistantObject import com.xebia.functional.openai.models.CreateAssistantRequest import com.xebia.functional.openai.models.ModifyAssistantRequest import com.xebia.functional.openai.models.ext.assistant.AssistantTools import com.xebia.functional.xef.llm.fromEnvironment import io.ktor.client.statement.* +import io.ktor.util.logging.* +import kotlinx.serialization.KSerializer +import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive class Assistant( val assistantId: String, + val toolsConfig: List> = emptyList(), private val assistantsApi: AssistantsApi = fromEnvironment(::AssistantsApi), private val api: AssistantApi = fromEnvironment(::AssistantApi) ) { constructor( assistantObject: AssistantObject, + toolsConfig: List> = emptyList(), assistantsApi: AssistantsApi = fromEnvironment(::AssistantsApi), api: AssistantApi = fromEnvironment(::AssistantApi) - ) : this(assistantObject.id, assistantsApi, api) + ) : this(assistantObject.id, toolsConfig, assistantsApi, api) suspend fun get(): AssistantObject = assistantsApi.getAssistant(assistantId).body() suspend fun modify(modifyAssistantRequest: ModifyAssistantRequest): Assistant = - Assistant(api.modifyAssistant(assistantId, modifyAssistantRequest).body(), assistantsApi, api) + Assistant( + api.modifyAssistant(assistantId, modifyAssistantRequest).body(), + toolsConfig, + assistantsApi, + api + ) + + suspend inline fun getToolRegistered(name: String, args: String): JsonElement = + try { + val toolConfig = toolsConfig.firstOrNull { it.functionObject.name == name } + + val toolSerializer = toolConfig?.serializers ?: error("Function $name not registered") + val input = ApiClient.JSON_DEFAULT.decodeFromString(toolSerializer.inputSerializer, args) + + val tool: Tool = toolConfig.tool as Tool + + val output: Any? = tool(input) + ApiClient.JSON_DEFAULT.encodeToJsonElement( + toolSerializer.outputSerializer as KSerializer, + output + ) + } catch (e: Exception) { + val message = "Error calling to tool registered $name: ${e.message}" + val logger = KtorSimpleLogger("Functions") + logger.error(message) + JsonObject(mapOf("error" to JsonPrimitive(message))) + } companion object { @@ -37,6 +70,7 @@ class Assistant( tools: List = arrayListOf(), fileIds: List = arrayListOf(), metadata: JsonObject? = null, + toolsConfig: List> = emptyList(), assistantsApi: AssistantsApi = fromEnvironment(::AssistantsApi), api: AssistantApi = fromEnvironment(::AssistantApi) ): Assistant = @@ -50,17 +84,19 @@ class Assistant( fileIds = fileIds, metadata = metadata ), + toolsConfig, assistantsApi, api ) suspend operator fun invoke( request: CreateAssistantRequest, + toolsConfig: List> = emptyList(), assistantsApi: AssistantsApi = fromEnvironment(::AssistantsApi), api: AssistantApi = fromEnvironment(::AssistantApi) ): Assistant { val response = assistantsApi.createAssistant(request) - return Assistant(response.body(), assistantsApi, api) + return Assistant(response.body(), toolsConfig, assistantsApi, api) } } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt index b60fec7de..2b5b9ec07 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt @@ -7,6 +7,7 @@ import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsMessageCre import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsToolCallsObject import com.xebia.functional.openai.models.ext.assistant.RunStepObjectStepDetails import com.xebia.functional.xef.llm.fromEnvironment +import com.xebia.functional.xef.prompt.templates.assistant import kotlin.jvm.JvmName import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.FlowCollector @@ -63,7 +64,7 @@ class AssistantThread( suspend fun run(assistant: Assistant): Flow { val run = createRun(assistant) - return awaitRun(run.id) + return awaitRun(assistant, run.id) } suspend fun cancelRun(runId: String): RunObject = api.cancelRun(threadId, runId).body() @@ -79,13 +80,13 @@ class AssistantThread( data class Step(val runStep: RunStepObject) : RunDelta() } - fun awaitRun(runId: String): Flow = flow { + fun awaitRun(assistant: Assistant, runId: String): Flow = flow { val stepCache = mutableSetOf() val messagesCache = mutableSetOf() val runCache = mutableSetOf() var run = checkRun(runId = runId, cache = runCache) while (run.status != RunObject.Status.completed) { - checkSteps(runId = runId, cache = stepCache) + checkSteps(assistant = assistant, runId = runId, cache = stepCache) checkMessages(cache = messagesCache) run = checkRun(runId = runId, cache = runCache) } @@ -124,6 +125,7 @@ class AssistantThread( } private suspend fun FlowCollector.checkSteps( + assistant: Assistant, runId: String, cache: MutableSet ) { @@ -135,7 +137,7 @@ class AssistantThread( step.stepDetails.toolCalls().forEach { toolCall -> val function = toolCall.function if (function != null && function.arguments.isNotBlank()) { - val result: JsonElement = Tool(function.name, function.arguments) + val result: JsonElement = assistant.getToolRegistered(function.name, function.arguments) api.submitToolOuputsToRun( threadId = threadId, runId = runId, diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Tool.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Tool.kt index 62db4a0ef..38de06e5c 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Tool.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Tool.kt @@ -1,44 +1,35 @@ package com.xebia.functional.xef.llm.assistants -import com.xebia.functional.openai.infrastructure.ApiClient import com.xebia.functional.openai.models.FunctionObject import com.xebia.functional.xef.llm.chatFunction import kotlinx.serialization.KSerializer -import kotlinx.serialization.json.JsonElement import kotlinx.serialization.serializer -fun interface Tool { - suspend operator fun invoke(): Output +fun interface Tool { + suspend operator fun invoke(input: Input): Output companion object { + data class ToolConfig( + val functionObject: FunctionObject, + val serializers: ToolSerializer, + val tool: Tool + ) + data class ToolSerializer( val inputSerializer: KSerializer<*>, val outputSerializer: KSerializer<*> ) - @PublishedApi internal val toolRegistry = mutableMapOf() - - inline operator fun , reified O> invoke(): FunctionObject { - val serializer = serializer() + inline fun toolOf(tool: Tool): ToolConfig { + val serializer = serializer() val outputSerializer = serializer() val toolSerializer = ToolSerializer(serializer, outputSerializer) val fn = chatFunction(serializer.descriptor) - if (toolRegistry.containsKey(fn.name)) { - error("Function ${fn.name} already registered") - } - toolRegistry[fn.name] = toolSerializer - return fn - } - - suspend inline operator fun invoke(name: String, args: String): JsonElement { - val toolSerializer = toolRegistry[name] ?: error("Function $name not registered") - val input = - ApiClient.JSON_DEFAULT.decodeFromString(toolSerializer.inputSerializer, args) as Tool - val output: Any? = input.invoke() - return ApiClient.JSON_DEFAULT.encodeToJsonElement( - toolSerializer.outputSerializer as KSerializer, - output + return ToolConfig( + fn.copy(name = tool::class.simpleName ?: error("unnamed class")), + toolSerializer, + tool ) } } diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/assistants/DSL.kt b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/DSL.kt index eb8c3b101..262c49b81 100644 --- a/examples/src/main/kotlin/com/xebia/functional/xef/assistants/DSL.kt +++ b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/DSL.kt @@ -6,20 +6,19 @@ import com.xebia.functional.openai.models.ext.assistant.RunStepDetailsToolCallsO import com.xebia.functional.xef.llm.assistants.Assistant import com.xebia.functional.xef.llm.assistants.AssistantThread import com.xebia.functional.xef.llm.assistants.Tool +import io.ktor.client.* import kotlinx.serialization.Serializable -@Serializable -data class SumTool( - val left: Int, - val right: Int, -) : Tool { - override suspend fun invoke(): Int = left + right +@Serializable data class SumInput(val left: Int, val right: Int) + +class SumTool : Tool { + override suspend fun invoke(input: SumInput): Int { + return input.left + input.right + } } suspend fun main() { - Tool() // register the tool, Int is the output type - // val assistant2 = Assistant( // name = "Math Tutor", // instructions = "You help the user with all kinds of math problems.", @@ -38,7 +37,11 @@ suspend fun main() { // model = "gpt-4-1106-preview" // ) // println("generated assistant: ${assistant2.assistantId}") - val assistant = Assistant(assistantId = "asst_mYw6e4wddJvRcjdQQ2qcWFsn") + val assistant = + Assistant( + assistantId = "asst_UxczzpJkysC0l424ood87DAk", + toolsConfig = listOf(Tool.toolOf(SumTool())), + ) val thread = AssistantThread() println("Welcome to the Math tutor, ask me anything about math:") while (true) { diff --git a/examples/src/main/resources/logback.xml b/examples/src/main/resources/logback.xml index 0ef56869d..7b6f77f01 100644 --- a/examples/src/main/resources/logback.xml +++ b/examples/src/main/resources/logback.xml @@ -12,7 +12,7 @@ - +