Skip to content

Commit

Permalink
Update parameters type in CFunction to JsonObject (#503)
Browse files Browse the repository at this point in the history
* parameters in CFunction is a Json object

* spotless

* removed encodeJsonSchema
  • Loading branch information
Montagon authored Oct 24, 2023
1 parent cdbc923 commit 8cad10f
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import com.xebia.functional.xef.llm.models.chat.ChatCompletionChunk
import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.llm.models.functions.FunChatCompletionRequest
import com.xebia.functional.xef.llm.models.functions.encodeJsonSchema
import com.xebia.functional.xef.llm.models.functions.buildJsonSchema
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.flow.*
Expand All @@ -32,10 +32,10 @@ interface ChatWithFunctions : LLM {
@OptIn(ExperimentalSerializationApi::class)
fun chatFunction(descriptor: SerialDescriptor): CFunction {
val fnName = descriptor.serialName.substringAfterLast(".")
return chatFunction(fnName, encodeJsonSchema(descriptor))
return chatFunction(fnName, buildJsonSchema(descriptor))
}

fun chatFunction(fnName: String, schema: String): CFunction =
fun chatFunction(fnName: String, schema: JsonObject): CFunction =
CFunction(fnName, "Generated function for $fnName", schema)

@AiDsl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ sealed class StreamedFunction<out A> {
// the path to this potential nested property
var path: List<String> = emptyList()
// we extract the expected JSON schema before the LLM replies
val schema = Json.parseToJsonElement(function.parameters)
val schema = function.parameters
// we create an example from the schema from which we can expect and infer the paths
// as the LLM is sending us chunks with malformed JSON
val example = createExampleFromSchema(schema)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.xebia.functional.xef.llm.models.functions

import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

@Serializable
data class CFunction(val name: String, val description: String, val parameters: String)
data class CFunction(val name: String, val description: String, val parameters: JsonObject)
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,8 @@ annotation class JsonSchema {
annotation class NoDefinition
}

fun encodeJsonSchema(descriptor: SerialDescriptor): String =
Json.encodeToString(JsonObject.serializer(), buildJsonSchema(descriptor))

/** Creates a Json Schema using the provided [descriptor] */
private fun buildJsonSchema(descriptor: SerialDescriptor): JsonObject {
fun buildJsonSchema(descriptor: SerialDescriptor): JsonObject {
val autoDefinitions = false
val prepend = mapOf("\$schema" to JsonPrimitive("http://json-schema.org/draft-07/schema"))
val definitions = JsonSchemaDefinitions(autoDefinitions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import kotlinx.coroutines.async
import kotlinx.coroutines.cancel
import kotlinx.coroutines.future.asCompletableFuture
import kotlinx.coroutines.reactive.asPublisher
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.jsonObject
import org.reactivestreams.Publisher

actual abstract class PlatformConversation
Expand Down Expand Up @@ -58,12 +60,14 @@ actual constructor(
}
.asCompletableFuture()

fun chatFunction(target: Class<*>): CFunction =
CFunction(
fun chatFunction(target: Class<*>): CFunction {
val targetString = JacksonSerialization.schemaGenerator.generateSchema(target).toString()
return CFunction(
name = target.simpleName,
description = "Generated function for ${target.simpleName}",
parameters = JacksonSerialization.schemaGenerator.generateSchema(target).toString()
parameters = Json.parseToJsonElement(targetString).jsonObject
)
}

fun promptMessage(chat: Chat, prompt: Prompt): CompletableFuture<String> =
coroutineScope
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import com.xebia.functional.xef.llm.models.functions.FunChatCompletionRequest
import com.xebia.functional.xef.llm.models.functions.FunctionCall
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map
import kotlinx.serialization.json.Json

class OpenAIFunChat(
private val provider: OpenAI, // TODO: use context receiver
Expand Down Expand Up @@ -102,7 +101,7 @@ private fun CFunction.toOpenAI() =
ChatCompletionFunction(
name = name,
description = description,
parameters = Parameters(Json.parseToJsonElement(parameters)),
parameters = Parameters(parameters)
)

private fun Message.toOpenAI() =
Expand Down

0 comments on commit 8cad10f

Please sign in to comment.