From fe6c46de01b38e0741c843f1d8131761c448383a Mon Sep 17 00:00:00 2001 From: franciscodr Date: Tue, 27 Aug 2024 10:59:59 +0200 Subject: [PATCH] Add a default instance for OpenAI config --- .../com/xebia/functional/xef/AIConfig.kt | 2 +- .../kotlin/com/xebia/functional/xef/Config.kt | 54 +++++++++++-------- .../kotlin/com/xebia/functional/xef/Tool.kt | 12 ++--- .../functional/xef/llm/ChatWithFunctions.kt | 2 +- .../xef/llm/assistants/Assistant.kt | 10 ++-- .../xef/llm/assistants/AssistantThread.kt | 12 ++--- .../functional/xef/llm/assistants/RunDelta.kt | 2 +- .../xef/assistants/AssistantStreaming.kt | 2 +- .../functional/xef/dsl/audio/SimpleSpeech.kt | 2 +- .../services/LocalVectorStoreService.kt | 10 +--- .../services/PostgresVectorStoreService.kt | 9 +--- 11 files changed, 57 insertions(+), 60 deletions(-) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt index a9d0f945f..4f5aa943e 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt @@ -8,7 +8,7 @@ import com.xebia.functional.xef.conversation.Conversation data class AIConfig( val tools: List> = emptyList(), val model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o, - val config: Config = Config(), + val config: Config = Config.Default, val openAI: OpenAI = OpenAI(config, logRequests = false), val api: Chat = openAI.chat, val conversation: Conversation = Conversation() diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt index b4c2ea22e..d76ed4446 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt @@ -16,6 +16,7 @@ import kotlin.math.pow import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.json.Json sealed interface HttpClientRetryPolicy { @@ -85,25 +86,36 @@ data class HttpClientTimeoutPolicy( } data class Config( - val baseUrl: String = getenv(HOST_ENV_VAR) ?: "https://api.openai.com/v1/", - val httpClientRetryPolicy: HttpClientRetryPolicy = - HttpClientRetryPolicy.Incremental(250.milliseconds, 5.seconds, 5), - val httpClientTimeoutPolicy: HttpClientTimeoutPolicy = - HttpClientTimeoutPolicy(45.seconds, 45.seconds, 45.seconds), - val token: String? = null, - val org: String? = getenv(ORG_ENV_VAR), - val json: Json = Json { - ignoreUnknownKeys = true - prettyPrint = false - isLenient = true - explicitNulls = false - classDiscriminator = TYPE_DISCRIMINATOR - }, - val streamingPrefix: String = "data:", - val streamingDelimiter: String = "data: [DONE]" + val baseUrl: String, + val httpClientRetryPolicy: HttpClientRetryPolicy, + val httpClientTimeoutPolicy: HttpClientTimeoutPolicy, + val apiToken: String?, + val organization: String?, + val json: Json, + val streamingPrefix: String, + val streamingDelimiter: String ) { companion object { - val DEFAULT = Config() + @OptIn(ExperimentalSerializationApi::class) + val Default = + Config( + baseUrl = getenv(HOST_ENV_VAR) ?: "https://api.openai.com/v1/", + httpClientRetryPolicy = HttpClientRetryPolicy.Incremental(250.milliseconds, 5.seconds, 5), + httpClientTimeoutPolicy = HttpClientTimeoutPolicy(45.seconds, 45.seconds, 45.seconds), + json = + Json { + ignoreUnknownKeys = true + prettyPrint = false + isLenient = true + explicitNulls = false + classDiscriminator = TYPE_DISCRIMINATOR + }, + organization = getenv(ORG_ENV_VAR), + streamingDelimiter = "data: [DONE]", + streamingPrefix = "data:", + apiToken = null + ) + const val TYPE_DISCRIMINATOR = "_type_" } } @@ -117,13 +129,13 @@ private const val KEY_ENV_VAR = "OPENAI_TOKEN" * Just simple fun on top of generated API. */ fun OpenAI( - config: Config = Config(), + config: Config = Config.Default, httpClientEngine: HttpClientEngine? = null, httpClientConfig: ((HttpClientConfig<*>) -> Unit)? = null, logRequests: Boolean = false ): OpenAI { val token = - config.token + config.apiToken ?: getenv(KEY_ENV_VAR) ?: throw AIError.Env.OpenAI(nonEmptyListOf("missing $KEY_ENV_VAR env var")) val clientConfig: HttpClientConfig<*>.() -> Unit = { @@ -134,7 +146,7 @@ fun OpenAI( httpClientConfig?.invoke(this) defaultRequest { url(config.baseUrl) - config.org?.let { headers.append("OpenAI-Organization", it) } + config.organization?.let { headers.append("OpenAI-Organization", it) } bearerAuth(token) } } @@ -144,7 +156,7 @@ fun OpenAI( OpenAIConfig( baseUrl = config.baseUrl, token = token, - org = config.org, + org = config.organization, json = config.json, streamingPrefix = config.streamingPrefix, streamingDelimiter = config.streamingDelimiter diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt index a0b20a62e..4317b346e 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt @@ -128,7 +128,7 @@ sealed class Tool( val typeSerializer = targetClass.serializer() val functionObject = chatFunction(typeSerializer.descriptor) return Callable(functionObject) { - Config.DEFAULT.json.decodeFromString(typeSerializer, it.arguments) + Config.Default.json.decodeFromString(typeSerializer, it.arguments) } } @@ -137,7 +137,7 @@ sealed class Tool( val functionSerializer = Value.serializer(targetClass.serializer()) val functionObject = chatFunction(functionSerializer.descriptor) return Primitive(functionObject) { - Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value + Config.Default.json.decodeFromString(functionSerializer, it.arguments).value } } @@ -161,7 +161,7 @@ sealed class Tool( } val functionObject = chatFunction(functionSerializer.descriptor) return Callable(functionObject) { - Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value as A + Config.Default.json.decodeFromString(functionSerializer, it.arguments).value as A } } @@ -205,7 +205,7 @@ sealed class Tool( descriptor: SerialDescriptor ): Enumeration { val enumSerializer = { value: String -> - Config.DEFAULT.json.decodeFromString(targetClass.serializer(), value) as A + Config.Default.json.decodeFromString(targetClass.serializer(), value) as A } val functionObject = chatFunction(descriptor) val cases = @@ -251,7 +251,7 @@ sealed class Tool( sealedClassSerializer: SealedClassSerializer ): A { val newJson = descriptorChoice(it, functionObjectMap) - return Config.DEFAULT.json.decodeFromString( + return Config.Default.json.decodeFromString( sealedClassSerializer, Json.encodeToString(newJson) ) as A @@ -263,7 +263,7 @@ sealed class Tool( ): JsonObject { // adds a `type` field with the call.functionName serial name equivalent to the call arguments val jsonWithDiscriminator = - Config.DEFAULT.json.decodeFromString(JsonElement.serializer(), call.arguments) + Config.Default.json.decodeFromString(JsonElement.serializer(), call.arguments) val descriptor = descriptors.values.firstOrNull { it.name.endsWith(call.functionName) } ?: error("No descriptor found for ${call.functionName}") diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt index bbbce296a..7b212c689 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt @@ -36,7 +36,7 @@ fun chatFunction(descriptor: SerialDescriptor): FunctionObject { @OptIn(ExperimentalSerializationApi::class) fun functionSchema(descriptor: SerialDescriptor): JsonObject = descriptor.annotations.filterIsInstance().firstOrNull()?.value?.let { - Config.DEFAULT.json.decodeFromString(JsonObject.serializer(), it) + Config.Default.json.decodeFromString(JsonObject.serializer(), it) } ?: buildJsonSchema(descriptor) @OptIn(ExperimentalSerializationApi::class) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt index 55e6ae86e..2d9d6df00 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt @@ -20,14 +20,14 @@ import net.mamoe.yamlkt.toYamlElement class Assistant( val assistantId: String, val toolsConfig: List> = emptyList(), - val config: Config = Config(), + val config: Config = Config.Default, private val assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants, ) { constructor( assistantObject: AssistantObject, toolsConfig: List> = emptyList(), - config: Config = Config(), + config: Config = Config.Default, assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants, ) : this(assistantObject.id, toolsConfig, config, assistantsApi) @@ -85,7 +85,7 @@ class Assistant( toolResources: CreateAssistantRequestToolResources? = null, metadata: JsonObject? = null, toolsConfig: List> = emptyList(), - config: Config = Config(), + config: Config = Config.Default, assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants, ): Assistant = Assistant( @@ -106,7 +106,7 @@ class Assistant( suspend operator fun invoke( request: CreateAssistantRequest, toolsConfig: List> = emptyList(), - config: Config = Config(), + config: Config = Config.Default, assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants, ): Assistant { val response = assistantsApi.createAssistant(request, configure = ::defaultConfig) @@ -116,7 +116,7 @@ class Assistant( suspend fun fromConfig( request: String, toolsConfig: List> = emptyList(), - config: Config = Config(), + config: Config = Config.Default, assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants, ): Assistant { val parsed = Yaml.Default.decodeYamlMapFromString(request) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt index f0c27edc4..64c6cb67e 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt @@ -20,7 +20,7 @@ import kotlinx.serialization.json.JsonPrimitive class AssistantThread( val threadId: String, val metric: Metric = Metric.EMPTY, - private val config: Config = Config(), + private val config: Config = Config.Default, private val api: Assistants = OpenAI(config).assistants ) { @@ -271,7 +271,7 @@ class AssistantThread( messages: List, metadata: JsonObject? = null, metric: Metric = Metric.EMPTY, - config: Config = Config(), + config: Config = Config.Default, api: Assistants = OpenAI(config).assistants ): AssistantThread = AssistantThread( @@ -303,7 +303,7 @@ class AssistantThread( messages: List, metadata: JsonObject? = null, metric: Metric = Metric.EMPTY, - config: Config = Config(), + config: Config = Config.Default, api: Assistants = OpenAI(config).assistants ): AssistantThread = AssistantThread( @@ -333,7 +333,7 @@ class AssistantThread( messages: List = emptyList(), metadata: JsonObject? = null, metric: Metric = Metric.EMPTY, - config: Config = Config(), + config: Config = Config.Default, api: Assistants = OpenAI(config).assistants ): AssistantThread = AssistantThread( @@ -351,7 +351,7 @@ class AssistantThread( suspend operator fun invoke( request: CreateThreadRequest, metric: Metric = Metric.EMPTY, - config: Config = Config(), + config: Config = Config.Default, api: Assistants = OpenAI(config).assistants ): AssistantThread = AssistantThread( @@ -364,7 +364,7 @@ class AssistantThread( suspend operator fun invoke( request: CreateThreadAndRunRequest, metric: Metric = Metric.EMPTY, - config: Config = Config(), + config: Config = Config.Default, api: Assistants = OpenAI(config).assistants ): AssistantThread = AssistantThread( diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt index fec396e02..61537ab26 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt @@ -144,7 +144,7 @@ sealed interface RunDelta { RunDeltaEvent.values().find { type.replace(".", "").replace("_", "").equals(it.name, ignoreCase = true) } - val json = Config.DEFAULT.json + val json = Config.Default.json return when (event) { RunDeltaEvent.ThreadCreated -> ThreadCreated(json.decodeFromJsonElement(ThreadObject.serializer(), data)) diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/assistants/AssistantStreaming.kt b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/AssistantStreaming.kt index 8fecf412b..37bf3264e 100644 --- a/examples/src/main/kotlin/com/xebia/functional/xef/assistants/AssistantStreaming.kt +++ b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/AssistantStreaming.kt @@ -18,7 +18,7 @@ suspend fun main() { assistantId = "asst_BwQvmWIbGUMDvCuXOtAFH8B6", toolsConfig = listOf(Tool.toolOf(SumTool())) ) - val config = Config(org = null) + val config = Config.Default.copy(organization = null) val api = OpenAI(config = config, logRequests = false).assistants val thread = AssistantThread(api = api, metric = metric) println("Welcome to the Math tutor, ask me anything about math:") diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/audio/SimpleSpeech.kt b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/audio/SimpleSpeech.kt index 5a942d26b..fbe81e8ba 100644 --- a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/audio/SimpleSpeech.kt +++ b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/audio/SimpleSpeech.kt @@ -14,7 +14,7 @@ import java.io.File import javax.media.bean.playerbean.MediaPlayer suspend fun main() { - val config = Config() + val config = Config.Default val audio = OpenAI(config).audio println("ask me something!") while (true) { diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt index 0778dc5c4..068fe47d0 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt @@ -7,13 +7,5 @@ import com.xebia.functional.xef.store.VectorStore class LocalVectorStoreService : VectorStoreService() { override fun getVectorStore(token: String?, org: String?): VectorStore = - LocalVectorStore( - OpenAI( - Config( - token = token, - org = org, - ) - ) - .embeddings - ) + LocalVectorStore(OpenAI(Config.Default.copy(apiToken = token, organization = org)).embeddings) } diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt index 6f764ea4e..c547a6614 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt @@ -38,14 +38,7 @@ class PostgresVectorStoreService( } override fun getVectorStore(token: String?, org: String?): VectorStore { - val embeddingsApi = - OpenAI( - Config( - token = token, - org = org, - ) - ) - .embeddings + val embeddingsApi = OpenAI(Config.Default.copy(apiToken = token, organization = org)).embeddings return PGVectorStore( vectorSize = vectorSize,