Skip to content

Commit

Permalink
Add a default instance for OpenAI config
Browse files Browse the repository at this point in the history
  • Loading branch information
franciscodr committed Aug 27, 2024
1 parent 0c9e69c commit fe6c46d
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import com.xebia.functional.xef.conversation.Conversation
data class AIConfig(
val tools: List<Tool<*>> = emptyList(),
val model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
val config: Config = Config(),
val config: Config = Config.Default,
val openAI: OpenAI = OpenAI(config, logRequests = false),
val api: Chat = openAI.chat,
val conversation: Conversation = Conversation()
Expand Down
54 changes: 33 additions & 21 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import kotlin.math.pow
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json

sealed interface HttpClientRetryPolicy {
Expand Down Expand Up @@ -85,25 +86,36 @@ data class HttpClientTimeoutPolicy(
}

data class Config(
val baseUrl: String = getenv(HOST_ENV_VAR) ?: "https://api.openai.com/v1/",
val httpClientRetryPolicy: HttpClientRetryPolicy =
HttpClientRetryPolicy.Incremental(250.milliseconds, 5.seconds, 5),
val httpClientTimeoutPolicy: HttpClientTimeoutPolicy =
HttpClientTimeoutPolicy(45.seconds, 45.seconds, 45.seconds),
val token: String? = null,
val org: String? = getenv(ORG_ENV_VAR),
val json: Json = Json {
ignoreUnknownKeys = true
prettyPrint = false
isLenient = true
explicitNulls = false
classDiscriminator = TYPE_DISCRIMINATOR
},
val streamingPrefix: String = "data:",
val streamingDelimiter: String = "data: [DONE]"
val baseUrl: String,
val httpClientRetryPolicy: HttpClientRetryPolicy,
val httpClientTimeoutPolicy: HttpClientTimeoutPolicy,
val apiToken: String?,
val organization: String?,
val json: Json,
val streamingPrefix: String,
val streamingDelimiter: String
) {
companion object {
val DEFAULT = Config()
@OptIn(ExperimentalSerializationApi::class)
val Default =
Config(
baseUrl = getenv(HOST_ENV_VAR) ?: "https://api.openai.com/v1/",
httpClientRetryPolicy = HttpClientRetryPolicy.Incremental(250.milliseconds, 5.seconds, 5),
httpClientTimeoutPolicy = HttpClientTimeoutPolicy(45.seconds, 45.seconds, 45.seconds),
json =
Json {
ignoreUnknownKeys = true
prettyPrint = false
isLenient = true
explicitNulls = false
classDiscriminator = TYPE_DISCRIMINATOR
},
organization = getenv(ORG_ENV_VAR),
streamingDelimiter = "data: [DONE]",
streamingPrefix = "data:",
apiToken = null
)

const val TYPE_DISCRIMINATOR = "_type_"
}
}
Expand All @@ -117,13 +129,13 @@ private const val KEY_ENV_VAR = "OPENAI_TOKEN"
* Just simple fun on top of generated API.
*/
fun OpenAI(
config: Config = Config(),
config: Config = Config.Default,
httpClientEngine: HttpClientEngine? = null,
httpClientConfig: ((HttpClientConfig<*>) -> Unit)? = null,
logRequests: Boolean = false
): OpenAI {
val token =
config.token
config.apiToken
?: getenv(KEY_ENV_VAR)
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing $KEY_ENV_VAR env var"))
val clientConfig: HttpClientConfig<*>.() -> Unit = {
Expand All @@ -134,7 +146,7 @@ fun OpenAI(
httpClientConfig?.invoke(this)
defaultRequest {
url(config.baseUrl)
config.org?.let { headers.append("OpenAI-Organization", it) }
config.organization?.let { headers.append("OpenAI-Organization", it) }
bearerAuth(token)
}
}
Expand All @@ -144,7 +156,7 @@ fun OpenAI(
OpenAIConfig(
baseUrl = config.baseUrl,
token = token,
org = config.org,
org = config.organization,
json = config.json,
streamingPrefix = config.streamingPrefix,
streamingDelimiter = config.streamingDelimiter
Expand Down
12 changes: 6 additions & 6 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ sealed class Tool<out A>(
val typeSerializer = targetClass.serializer()
val functionObject = chatFunction(typeSerializer.descriptor)
return Callable(functionObject) {
Config.DEFAULT.json.decodeFromString(typeSerializer, it.arguments)
Config.Default.json.decodeFromString(typeSerializer, it.arguments)
}
}

Expand All @@ -137,7 +137,7 @@ sealed class Tool<out A>(
val functionSerializer = Value.serializer(targetClass.serializer())
val functionObject = chatFunction(functionSerializer.descriptor)
return Primitive(functionObject) {
Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value
Config.Default.json.decodeFromString(functionSerializer, it.arguments).value
}
}

Expand All @@ -161,7 +161,7 @@ sealed class Tool<out A>(
}
val functionObject = chatFunction(functionSerializer.descriptor)
return Callable(functionObject) {
Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value as A
Config.Default.json.decodeFromString(functionSerializer, it.arguments).value as A
}
}

Expand Down Expand Up @@ -205,7 +205,7 @@ sealed class Tool<out A>(
descriptor: SerialDescriptor
): Enumeration<A> {
val enumSerializer = { value: String ->
Config.DEFAULT.json.decodeFromString(targetClass.serializer(), value) as A
Config.Default.json.decodeFromString(targetClass.serializer(), value) as A
}
val functionObject = chatFunction(descriptor)
val cases =
Expand Down Expand Up @@ -251,7 +251,7 @@ sealed class Tool<out A>(
sealedClassSerializer: SealedClassSerializer<out Any>
): A {
val newJson = descriptorChoice(it, functionObjectMap)
return Config.DEFAULT.json.decodeFromString(
return Config.Default.json.decodeFromString(
sealedClassSerializer,
Json.encodeToString(newJson)
) as A
Expand All @@ -263,7 +263,7 @@ sealed class Tool<out A>(
): JsonObject {
// adds a `type` field with the call.functionName serial name equivalent to the call arguments
val jsonWithDiscriminator =
Config.DEFAULT.json.decodeFromString(JsonElement.serializer(), call.arguments)
Config.Default.json.decodeFromString(JsonElement.serializer(), call.arguments)
val descriptor =
descriptors.values.firstOrNull { it.name.endsWith(call.functionName) }
?: error("No descriptor found for ${call.functionName}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fun chatFunction(descriptor: SerialDescriptor): FunctionObject {
@OptIn(ExperimentalSerializationApi::class)
fun functionSchema(descriptor: SerialDescriptor): JsonObject =
descriptor.annotations.filterIsInstance<Schema>().firstOrNull()?.value?.let {
Config.DEFAULT.json.decodeFromString(JsonObject.serializer(), it)
Config.Default.json.decodeFromString(JsonObject.serializer(), it)
} ?: buildJsonSchema(descriptor)

@OptIn(ExperimentalSerializationApi::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ import net.mamoe.yamlkt.toYamlElement
class Assistant(
val assistantId: String,
val toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
val config: Config = Config(),
val config: Config = Config.Default,
private val assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
) {

constructor(
assistantObject: AssistantObject,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
config: Config = Config(),
config: Config = Config.Default,
assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
) : this(assistantObject.id, toolsConfig, config, assistantsApi)

Expand Down Expand Up @@ -85,7 +85,7 @@ class Assistant(
toolResources: CreateAssistantRequestToolResources? = null,
metadata: JsonObject? = null,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
config: Config = Config(),
config: Config = Config.Default,
assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
): Assistant =
Assistant(
Expand All @@ -106,7 +106,7 @@ class Assistant(
suspend operator fun invoke(
request: CreateAssistantRequest,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
config: Config = Config(),
config: Config = Config.Default,
assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
): Assistant {
val response = assistantsApi.createAssistant(request, configure = ::defaultConfig)
Expand All @@ -116,7 +116,7 @@ class Assistant(
suspend fun fromConfig(
request: String,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
config: Config = Config(),
config: Config = Config.Default,
assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
): Assistant {
val parsed = Yaml.Default.decodeYamlMapFromString(request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import kotlinx.serialization.json.JsonPrimitive
class AssistantThread(
val threadId: String,
val metric: Metric = Metric.EMPTY,
private val config: Config = Config(),
private val config: Config = Config.Default,
private val api: Assistants = OpenAI(config).assistants
) {

Expand Down Expand Up @@ -271,7 +271,7 @@ class AssistantThread(
messages: List<MessageWithFiles>,
metadata: JsonObject? = null,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand Down Expand Up @@ -303,7 +303,7 @@ class AssistantThread(
messages: List<String>,
metadata: JsonObject? = null,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand Down Expand Up @@ -333,7 +333,7 @@ class AssistantThread(
messages: List<CreateMessageRequest> = emptyList(),
metadata: JsonObject? = null,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand All @@ -351,7 +351,7 @@ class AssistantThread(
suspend operator fun invoke(
request: CreateThreadRequest,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand All @@ -364,7 +364,7 @@ class AssistantThread(
suspend operator fun invoke(
request: CreateThreadAndRunRequest,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ sealed interface RunDelta {
RunDeltaEvent.values().find {
type.replace(".", "").replace("_", "").equals(it.name, ignoreCase = true)
}
val json = Config.DEFAULT.json
val json = Config.Default.json
return when (event) {
RunDeltaEvent.ThreadCreated ->
ThreadCreated(json.decodeFromJsonElement(ThreadObject.serializer(), data))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ suspend fun main() {
assistantId = "asst_BwQvmWIbGUMDvCuXOtAFH8B6",
toolsConfig = listOf(Tool.toolOf(SumTool()))
)
val config = Config(org = null)
val config = Config.Default.copy(organization = null)
val api = OpenAI(config = config, logRequests = false).assistants
val thread = AssistantThread(api = api, metric = metric)
println("Welcome to the Math tutor, ask me anything about math:")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import java.io.File
import javax.media.bean.playerbean.MediaPlayer

suspend fun main() {
val config = Config()
val config = Config.Default
val audio = OpenAI(config).audio
println("ask me something!")
while (true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,5 @@ import com.xebia.functional.xef.store.VectorStore

class LocalVectorStoreService : VectorStoreService() {
override fun getVectorStore(token: String?, org: String?): VectorStore =
LocalVectorStore(
OpenAI(
Config(
token = token,
org = org,
)
)
.embeddings
)
LocalVectorStore(OpenAI(Config.Default.copy(apiToken = token, organization = org)).embeddings)
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,7 @@ class PostgresVectorStoreService(
}

override fun getVectorStore(token: String?, org: String?): VectorStore {
val embeddingsApi =
OpenAI(
Config(
token = token,
org = org,
)
)
.embeddings
val embeddingsApi = OpenAI(Config.Default.copy(apiToken = token, organization = org)).embeddings

return PGVectorStore(
vectorSize = vectorSize,
Expand Down

0 comments on commit fe6c46d

Please sign in to comment.