Skip to content

Commit

Permalink
AI DSL : Generic serialization support for Sealed classes and other primitives (#602)
Browse files Browse the repository at this point in the history

* AI DSL : Generic serialization support for Sealed classes and other primitives

* AI DSL : Support for streaming text and functions

* remove println

* Provide Prompt operators

* Uses of Prompt and configuration

* added seed in prompt configuration (#609)

---------

Co-authored-by: José Carlos Montañez <[email protected]>
  • Loading branch information
raulraja and Montagon authored Dec 26, 2023
1 parent 7711366 commit bdc12b6
Show file tree
Hide file tree
Showing 27 changed files with 488 additions and 229 deletions.
236 changes: 154 additions & 82 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
Original file line number Diff line number Diff line change
@@ -1,134 +1,206 @@
package com.xebia.functional.xef

import ai.xef.openai.CustomModel
import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import com.xebia.functional.openai.apis.ChatApi
import com.xebia.functional.openai.models.CreateChatCompletionRequest
import com.xebia.functional.openai.models.CreateChatCompletionRequestModel
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessage
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestUserMessageContentText
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.StreamedFunction
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.llm.models.modelType
import com.xebia.functional.xef.llm.prompt
import com.xebia.functional.xef.llm.promptStreaming
import com.xebia.functional.xef.prompt.Prompt
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.elementDescriptors
import kotlinx.serialization.serializer
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.typeOf
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*

interface AI<A> {
val model: CreateChatCompletionRequestModel
interface AI<A : Any> {
val target: KType
val model: OpenAIModel<CreateChatCompletionRequestModel>
val api: ChatApi
val serializer: () -> KSerializer<A>
val conversation: Conversation
val enumSerializer: ((case: String) -> A)?
val caseSerializers: List<KSerializer<A>>

@Serializable data class Value<A>(val value: A)

private suspend fun <B> runWithSerializer(prompt: String, serializer: KSerializer<B>): B =
api.prompt(Prompt(StandardModel(model), prompt), conversation, serializer)
private suspend fun <B> runWithSerializer(
prompt: Prompt<CreateChatCompletionRequestModel>,
serializer: KSerializer<B>
): B = api.prompt(prompt, conversation, serializer)

@OptIn(ExperimentalSerializationApi::class)
suspend operator fun invoke(prompt: String): A {
// Streams the assistant's raw text reply for [prompt] chunk by chunk,
// without attempting any structured deserialization of the output.
private fun runStreamingWithStringSerializer(
prompt: Prompt<CreateChatCompletionRequestModel>
): Flow<String> = api.promptStreaming(prompt, conversation)

// Streams partial function-call results for [prompt]: each emitted
// [StreamedFunction] is decoded from the in-flight tool/function-call
// payload using [serializer].
private fun <B> runStreamingWithFunctionSerializer(
prompt: Prompt<CreateChatCompletionRequestModel>,
serializer: KSerializer<B>
): Flow<StreamedFunction<B>> = api.promptStreaming(prompt, conversation, serializer)

// Runs [prompt] and decodes the reply with [serializer], additionally passing
// the candidate [descriptors] (here: the element descriptors of a sealed
// hierarchy's subclasses) down to the LLM layer.
// NOTE(review): how [descriptors] influences the request (e.g. one function
// schema per case) is defined in ChatApi.prompt — confirm there.
private suspend fun <B> runWithDescriptors(
prompt: Prompt<CreateChatCompletionRequestModel>,
serializer: KSerializer<B>,
descriptors: List<SerialDescriptor>
): B = api.prompt(prompt, conversation, serializer, descriptors)

@OptIn(ExperimentalSerializationApi::class, InternalSerializationApi::class)
suspend operator fun invoke(prompt: Prompt<CreateChatCompletionRequestModel>): A {
val serializer = serializer()
return when (serializer.descriptor.kind) {
PrimitiveKind.BOOLEAN ->
runWithSerializer(prompt, Value.serializer(Boolean.serializer())).value as A
PrimitiveKind.BYTE ->
runWithSerializer(prompt, Value.serializer(Byte.serializer())).value as A
PrimitiveKind.CHAR ->
runWithSerializer(prompt, Value.serializer(Char.serializer())).value as A
PrimitiveKind.DOUBLE ->
runWithSerializer(prompt, Value.serializer(Double.serializer())).value as A
PrimitiveKind.FLOAT ->
runWithSerializer(prompt, Value.serializer(Float.serializer())).value as A
PrimitiveKind.INT -> runWithSerializer(prompt, Value.serializer(Int.serializer())).value as A
PrimitiveKind.LONG ->
runWithSerializer(prompt, Value.serializer(Long.serializer())).value as A
PrimitiveKind.SHORT ->
runWithSerializer(prompt, Value.serializer(Short.serializer())).value as A
PrimitiveKind.STRING ->
runWithSerializer(prompt, Value.serializer(String.serializer())).value as A
SerialKind.ENUM -> {
val encoding = StandardModel(model).modelType(forFunctions = false).encoding
val cases =
serializer.descriptor.elementDescriptors.map { it.serialName.substringAfterLast(".") }
val logitBias =
cases
.flatMap {
val result = encoding.encode(it)
if (result.size > 1) {
error("Cannot encode enum case $it into one token")
}
result
}
.associate { "$it" to 100 }
val result =
api.createChatCompletion(
CreateChatCompletionRequest(
messages =
listOf(
ChatCompletionRequestUserMessage(
content = listOf(ChatCompletionRequestUserMessageContentText(prompt)),
)
),
model = StandardModel(model),
logitBias = logitBias,
maxTokens = 1,
temperature = 0.0
)
)
val choice = result.body().choices[0].message.content
val enumSerializer = enumSerializer
if (choice != null && enumSerializer != null) {
enumSerializer(choice)
} else {
error("Cannot decode enum case from $choice")
runWithEnumSingleTokenSerializer(serializer, prompt)
}
// else -> runWithSerializer(prompt, serializer)
PolymorphicKind.OPEN ->
when {
target == typeOf<Flow<String>>() -> {
runStreamingWithStringSerializer(prompt) as A
}
(target.classifier == Flow::class &&
target.arguments.firstOrNull()?.type?.classifier == StreamedFunction::class) -> {
val functionClass =
target.arguments.first().type?.arguments?.firstOrNull()?.type?.classifier
as? KClass<*>
val functionSerializer =
functionClass?.serializer() ?: error("Cannot find serializer for $functionClass")
runStreamingWithFunctionSerializer(prompt, functionSerializer) as A
}
else -> {
runWithSerializer(prompt, Value.serializer(serializer)) as A
}
}
PolymorphicKind.SEALED -> {
val s = serializer as SealedClassSerializer<A>
val cases = s.descriptor.elementDescriptors.toList()[1].elementDescriptors.toList()
runWithDescriptors(prompt, s, cases)
}
else -> runWithSerializer(prompt, serializer)
SerialKind.CONTEXTUAL -> runWithSerializer(prompt, serializer)
StructureKind.CLASS -> runWithSerializer(prompt, serializer)
else -> runWithSerializer(prompt, Value.serializer(serializer)).value
}
}

/**
 * Resolves an enum value of [A] with a single-token completion.
 *
 * Every enum case name (the serial name after the last '.') must encode to
 * exactly one token for this model's tokenizer; those tokens are strongly
 * favored via `logitBias` (+100) and the completion is capped at
 * `maxTokens = 1` with `temperature = 0.0`, so the model is effectively
 * forced to answer with exactly one of the case names. [enumSerializer]
 * then maps that name back to the enum constant.
 *
 * Fails with an error if any case name spans more than one token, or if the
 * reply cannot be mapped back to a case (null content or no [enumSerializer]).
 */
@OptIn(ExperimentalSerializationApi::class)
suspend fun runWithEnumSingleTokenSerializer(
serializer: KSerializer<A>,
prompt: Prompt<CreateChatCompletionRequestModel>
): A {
// Tokenizer for the standard (non-function-calling) variant of this model.
val encoding = StandardModel(model).modelType(forFunctions = false).encoding
// Simple names of the enum cases, e.g. "com.acme.Color.RED" -> "RED".
val cases =
serializer.descriptor.elementDescriptors.map { it.serialName.substringAfterLast(".") }
val logitBias =
cases
.flatMap {
val result = encoding.encode(it)
if (result.size > 1) {
// The single-token trick only works when one token == one case.
error("Cannot encode enum case $it into one token")
}
result
}
// Map token id -> maximum positive bias so only case tokens are viable.
.associate { "$it" to 100 }
val result =
api.createChatCompletion(
CreateChatCompletionRequest(
messages = prompt.messages,
model = model,
logitBias = logitBias,
maxTokens = 1, // exactly one token: the chosen case name
temperature = 0.0 // deterministic pick
)
)
val choice = result.body().choices[0].message.content
// Copy to a local so the compiler can smart-cast the nullable property.
val enumSerializer = enumSerializer
return if (choice != null && enumSerializer != null) {
enumSerializer(choice)
} else {
error("Cannot decode enum case from $choice")
}
}

companion object {
@PublishedApi
internal operator fun <A : Any> invoke(
model: CreateChatCompletionRequestModel,
operator fun <A : Any> invoke(
target: KType,
model: OpenAIModel<CreateChatCompletionRequestModel>,
api: ChatApi,
conversation: Conversation,
enumSerializer: ((case: String) -> A)?,
caseSerializers: List<KSerializer<A>>,
serializer: () -> KSerializer<A>,
): AI<A> =
object : AI<A> {
override val model: CreateChatCompletionRequestModel = model
override val target: KType = target
override val model: OpenAIModel<CreateChatCompletionRequestModel> = model
override val api: ChatApi = api
override val serializer: () -> KSerializer<A> = serializer
override val conversation: Conversation = conversation
override val enumSerializer: ((case: String) -> A)? = enumSerializer
override val caseSerializers: List<KSerializer<A>> = caseSerializers
}

@AiDsl
inline fun <reified A : Enum<A>> enum(
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
api: ChatApi = fromEnvironment(::ChatApi)
): AI<A> = invoke(model, api, Conversation(), { name -> enumValueOf<A>(name) }) { serializer() }

inline operator fun <reified A : Any> invoke(
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
@PublishedApi
internal suspend inline fun <reified A : Any> invokeEnum(
prompt: Prompt<CreateChatCompletionRequestModel>,
target: KType = typeOf<A>(),
api: ChatApi = fromEnvironment(::ChatApi),
conversation: Conversation = Conversation()
): AI<A> = invoke(model, api, conversation, null) { serializer() }
): A =
invoke(
target = target,
model = prompt.model,
api = api,
conversation = conversation,
enumSerializer = { @Suppress("UPPER_BOUND_VIOLATED") enumValueOf<A>(it) },
caseSerializers = emptyList()
) {
serializer<A>()
}
.invoke(prompt)

@AiDsl
suspend inline operator fun <reified A : Any> invoke(
prompt: String,
target: KType = typeOf<A>(),
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
api: ChatApi = fromEnvironment(::ChatApi),
conversation: Conversation = Conversation()
): A = invoke(model, api, conversation, null) { serializer<A>() }.invoke(prompt)
): A = invoke(Prompt(CustomModel(model.value), prompt), target, api, conversation)

/**
 * Entry point: runs [prompt] and decodes the reply as [A].
 *
 * Inspects the [SerialKind] of [target]'s serializer to dispatch: enums go
 * through the single-token path ([invokeEnum]); every other kind builds an
 * [AI] instance (no enum serializer, no case serializers) and delegates to
 * its `invoke`, which branches further on primitive/sealed/streaming kinds.
 *
 * Fails with an error when no serializer can be derived for [target]
 * (e.g. a non-@Serializable classifier).
 */
@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
@AiDsl
suspend inline operator fun <reified A : Any> invoke(
prompt: Prompt<CreateChatCompletionRequestModel>,
target: KType = typeOf<A>(),
api: ChatApi = fromEnvironment(::ChatApi),
conversation: Conversation = Conversation()
): A {
val kind =
(target.classifier as? KClass<*>)?.serializer()?.descriptor?.kind
?: error("Cannot find SerialKind for $target")
return when (kind) {
SerialKind.ENUM -> invokeEnum<A>(prompt, target, api, conversation)
else -> {
invoke(
target = target,
model = prompt.model,
api = api,
conversation = conversation,
enumSerializer = null, // only the ENUM branch needs one
caseSerializers = emptyList()
) {
serializer<A>()
}
.invoke(prompt)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ constructor(
suspend fun <A> ChatApi.prompt(
prompt: Prompt<CreateChatCompletionRequestModel>,
function: FunctionObject,
serializer: (String) -> A
): A = prompt(prompt, this@Conversation, function, serializer)
serializer: (FunctionCall) -> A
): A = prompt(prompt, this@Conversation, listOf(function), serializer)

@AiDsl
@JvmSynthetic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ fun ChatApi.promptStreaming(
n = prompt.configuration.numberOfPredictions,
temperature = prompt.configuration.temperature,
maxTokens = prompt.configuration.maxTokens,
model = prompt.model
model = prompt.model,
seed = prompt.configuration.seed,
)

val buffer = StringBuilder()
Expand Down Expand Up @@ -73,7 +74,8 @@ suspend fun ChatApi.promptMessages(
n = adaptedPrompt.configuration.numberOfPredictions,
temperature = adaptedPrompt.configuration.temperature,
maxTokens = adaptedPrompt.configuration.maxTokens,
model = prompt.model
model = prompt.model,
seed = adaptedPrompt.configuration.seed,
)

createChatCompletion(request)
Expand Down
Loading

0 comments on commit bdc12b6

Please sign in to comment.