MessagePolicy in Conversations (#504)
* First approach for improving nested conversations

* Removing first system message from conversation block

* Spotless apply

* suspend functions removed

* Adding tests
javipacheco authored Oct 25, 2023
1 parent 8cad10f commit 2081223
Showing 16 changed files with 342 additions and 49 deletions.
@@ -13,7 +13,21 @@ import kotlinx.serialization.Serializable
  */
 @Serializable
 data class MessagePolicy(
-  val historyPercent: Int = 50,
-  val historyPaddingTokens: Int = 100,
-  val contextPercent: Int = 50,
+  var historyPercent: Int = 50,
+  var historyPaddingTokens: Int = 100,
+  var contextPercent: Int = 50,
+  var addMessagesFromConversation: MessagesFromHistory = MessagesFromHistory.ALL,
+  var addMessagesToConversation: MessagesToHistory = MessagesToHistory.ALL,
 )
+
+enum class MessagesFromHistory {
+  ALL,
+  NONE,
+}
+
+enum class MessagesToHistory {
+  ALL,
+  ONLY_SYSTEM_MESSAGES,
+  NOT_SYSTEM_MESSAGES,
+  NONE,
+}
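MessagePolicy is a plain @Serializable data class, so a policy can also be built directly with named arguments. A minimal sketch (the imports assume MessagePolicy lives in com.xebia.functional.xef.conversation, the package the enums are imported from elsewhere in this diff):

import com.xebia.functional.xef.conversation.MessagePolicy
import com.xebia.functional.xef.conversation.MessagesFromHistory
import com.xebia.functional.xef.conversation.MessagesToHistory

// Read the full history when adapting the prompt, but keep system
// messages out of what gets written back to the conversation.
val policy = MessagePolicy(
  addMessagesFromConversation = MessagesFromHistory.ALL,
  addMessagesToConversation = MessagesToHistory.NOT_SYSTEM_MESSAGES,
)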
@@ -13,7 +13,7 @@ expect abstract class PlatformConversation(
     fun create(
       store: VectorStore,
       metric: Metric,
-      conversationId: ConversationId?,
+      conversationId: ConversationId?
     ): PlatformConversation
   }
 }
core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt (8 additions, 2 deletions)

@@ -36,7 +36,7 @@ interface Chat : LLM {
       .also { finalText ->
         val aiResponseMessage = assistant(finalText)
         val newMessages = prompt.messages + listOf(aiResponseMessage)
-        newMessages.addToMemory(scope)
+        newMessages.addToMemory(scope, prompt.configuration.messagePolicy.addMessagesToConversation)
       }
   }
 
@@ -51,6 +51,8 @@ interface Chat : LLM {
     val adaptedPrompt =
       PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat)
 
+    adaptedPrompt.addMetrics(scope)
+
     val request =
       ChatCompletionRequest(
         user = adaptedPrompt.configuration.user,
@@ -63,7 +65,11 @@ interface Chat : LLM {
     createChatCompletion(request)
       .addMetrics(scope)
       .choices
-      .addChoiceToMemory(scope, promptMemories)
+      .addChoiceToMemory(
+        scope,
+        promptMemories,
+        prompt.configuration.messagePolicy.addMessagesToConversation
+      )
       .mapNotNull { it.message?.content }
   }
 }
@@ -74,6 +74,8 @@ interface ChatWithFunctions : LLM {
         this@ChatWithFunctions
       )
 
+    adaptedPrompt.addMetrics(scope)
+
     val request =
       FunChatCompletionRequest(
         user = adaptedPrompt.configuration.user,
@@ -90,7 +92,11 @@ interface ChatWithFunctions : LLM {
     createChatCompletionWithFunctions(request)
       .addMetrics(scope)
       .choices
-      .addChoiceWithFunctionsToMemory(scope, requestedMemories)
+      .addChoiceWithFunctionsToMemory(
+        scope,
+        requestedMemories,
+        prompt.configuration.messagePolicy.addMessagesToConversation
+      )
       .mapNotNull { it.message?.functionCall?.arguments }
   }
 }
@@ -128,7 +134,7 @@ interface ChatWithFunctions : LLM {
   ) {
     streamFunctionCall(
       chat = this@ChatWithFunctions,
-      promptMessages = prompt.messages,
+      prompt = prompt,
       request = request,
       scope = scope,
       serializer = serializer,
@@ -1,19 +1,20 @@
 package com.xebia.functional.xef.llm
 
 import com.xebia.functional.xef.conversation.Conversation
+import com.xebia.functional.xef.conversation.MessagesToHistory
 import com.xebia.functional.xef.llm.models.chat.Choice
 import com.xebia.functional.xef.llm.models.chat.ChoiceWithFunctions
 import com.xebia.functional.xef.llm.models.chat.Message
+import com.xebia.functional.xef.llm.models.chat.Role
 import com.xebia.functional.xef.store.ConversationId
 import com.xebia.functional.xef.store.Memory
+import com.xebia.functional.xef.store.VectorStore
 
-internal suspend fun List<Message>.addToMemory(scope: Conversation) {
+internal suspend fun List<Message>.addToMemory(scope: Conversation, history: MessagesToHistory) {
   val cid = scope.conversationId
-  if (cid != null) {
+  if (history != MessagesToHistory.NONE && cid != null) {
     val memories = toMemory(scope)
-    if (memories.isNotEmpty()) {
-      scope.store.addMemories(memories)
-    }
+    scope.store.addMemoriesByHistory(history, memories)
   }
 }
 
@@ -29,27 +30,44 @@ internal fun List<Message>.toMemory(scope: Conversation): List<Memory> {
 
 internal suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
   scope: Conversation,
-  previousMemories: List<Memory>
+  previousMemories: List<Memory>,
+  history: MessagesToHistory
 ): List<ChoiceWithFunctions> = also {
   val cid = scope.conversationId
-  if (isNotEmpty() && cid != null) {
+  if (history != MessagesToHistory.NONE && isNotEmpty() && cid != null) {
     val aiMemory =
       this.mapNotNull { it.message }
         .map { it.toMessage().toMemory(cid, scope.store.incrementIndexAndGet()) }
     val newMessages = previousMemories + aiMemory
-    scope.store.addMemories(newMessages)
+    scope.store.addMemoriesByHistory(history, newMessages)
   }
 }
 
 internal suspend fun List<Choice>.addChoiceToMemory(
   scope: Conversation,
-  previousMemories: List<Memory>
+  previousMemories: List<Memory>,
+  history: MessagesToHistory
 ): List<Choice> = also {
   val cid = scope.conversationId
-  if (isNotEmpty() && cid != null) {
+  if (history != MessagesToHistory.NONE && isNotEmpty() && cid != null) {
     val aiMemory =
       this.mapNotNull { it.message }.map { it.toMemory(cid, scope.store.incrementIndexAndGet()) }
     val newMessages = previousMemories + aiMemory
-    scope.store.addMemories(newMessages)
+    scope.store.addMemoriesByHistory(history, newMessages)
   }
 }
+
+suspend fun VectorStore.addMemoriesByHistory(history: MessagesToHistory, memories: List<Memory>) {
+  when (history) {
+    MessagesToHistory.ALL -> {
+      addMemories(memories)
+    }
+    MessagesToHistory.ONLY_SYSTEM_MESSAGES -> {
+      addMemories(memories.filter { it.content.role == Role.SYSTEM })
+    }
+    MessagesToHistory.NOT_SYSTEM_MESSAGES -> {
+      addMemories(memories.filter { it.content.role != Role.SYSTEM })
+    }
+    MessagesToHistory.NONE -> {}
+  }
+}
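In short: ALL persists every memory, ONLY_SYSTEM_MESSAGES keeps just those whose content.role is Role.SYSTEM, NOT_SYSTEM_MESSAGES keeps everything else, and NONE writes nothing (the call sites above also short-circuit on NONE before building any memories).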
@@ -3,6 +3,7 @@ package com.xebia.functional.xef.llm
 import com.xebia.functional.xef.conversation.Conversation
 import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponse
 import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
+import com.xebia.functional.xef.prompt.Prompt
 
 fun ChatCompletionResponseWithFunctions.addMetrics(
   conversation: Conversation
@@ -23,3 +24,12 @@ fun ChatCompletionResponse.addMetrics(conversation: Conversation): ChatCompletionResponse {
   )
   return this
 }
+
+fun Prompt.addMetrics(conversation: Conversation) {
+  conversation.metric.log(
+    conversation,
+    "Number of messages: ${messages.size} (${messages.map { it.role.toString().firstOrNull() ?: "" }.joinToString("-")})"
+  )
+  conversation.metric.log(conversation, "Functions: ${function?.let { "yes" } ?: "no"}")
+  conversation.metric.log(conversation, "Temperature: ${configuration.temperature}")
+}
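Given the firstOrNull() mapping above, a prompt holding a system, a user, and an assistant message would be logged as "Number of messages: 3 (S-U-A)", assuming the usual SYSTEM/USER/ASSISTANT role names.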
@@ -3,6 +3,7 @@ package com.xebia.functional.xef.llm
 import com.xebia.functional.tokenizer.truncateText
 import com.xebia.functional.xef.AIError
 import com.xebia.functional.xef.conversation.Conversation
+import com.xebia.functional.xef.conversation.MessagesFromHistory
 import com.xebia.functional.xef.llm.models.chat.Message
 import com.xebia.functional.xef.prompt.Prompt
 import com.xebia.functional.xef.prompt.templates.assistant
@@ -14,6 +15,16 @@ internal object PromptCalculator {
     prompt: Prompt,
     scope: Conversation,
     llm: LLM
+  ): Prompt =
+    when (prompt.configuration.messagePolicy.addMessagesFromConversation) {
+      MessagesFromHistory.ALL -> adaptPromptFromConversation(prompt, scope, llm)
+      MessagesFromHistory.NONE -> prompt
+    }
+
+  private suspend fun adaptPromptFromConversation(
+    prompt: Prompt,
+    scope: Conversation,
+    llm: LLM
   ): Prompt {
 
     // calculate tokens for history and context
@@ -110,12 +121,6 @@ internal object PromptCalculator {
     return remainingTokensForContexts
   }
 
-  private suspend fun Conversation.memories(llm: LLM, limitTokens: Int): List<Memory> {
-    val cid = conversationId
-    return if (cid != null) {
-      store.memories(llm, cid, limitTokens)
-    } else {
-      emptyList()
-    }
-  }
+  private suspend fun Conversation.memories(llm: LLM, limitTokens: Int): List<Memory> =
+    conversationId?.let { store.memories(llm, it, limitTokens) } ?: emptyList()
 }
@@ -6,6 +6,7 @@ import com.xebia.functional.xef.llm.models.chat.Message
 import com.xebia.functional.xef.llm.models.functions.CFunction
 import com.xebia.functional.xef.llm.models.functions.FunChatCompletionRequest
 import com.xebia.functional.xef.llm.models.functions.FunctionCall
+import com.xebia.functional.xef.prompt.Prompt
 import com.xebia.functional.xef.prompt.templates.assistant
 import kotlin.jvm.JvmSynthetic
 import kotlinx.coroutines.flow.FlowCollector
@@ -41,7 +42,7 @@ sealed class StreamedFunction<out A> {
   @JvmSynthetic
   internal suspend fun <A> FlowCollector<StreamedFunction<A>>.streamFunctionCall(
     chat: ChatWithFunctions,
-    promptMessages: List<Message>,
+    prompt: Prompt,
     request: FunChatCompletionRequest,
     scope: Conversation,
     serializer: (json: String) -> A,
@@ -64,8 +65,11 @@ sealed class StreamedFunction<out A> {
       chat
         .createChatCompletionsWithFunctions(request)
         .onCompletion {
-          val newMessages = promptMessages + messages
-          newMessages.addToMemory(scope)
+          val newMessages = prompt.messages + messages
+          newMessages.addToMemory(
+            scope,
+            prompt.configuration.messagePolicy.addMessagesToConversation
+          )
         }
         .collect { responseChunk ->
           // Each chunk is emitted from the LLM and it will include a delta.parameters with
@@ -16,13 +16,14 @@ class LogsMetric : Metric {
     val milis = getTimeMillis()
     val name = prompt.messages.lastOrNull()?.content ?: "empty"
     println("Prompt-Span: $name")
+    println("${writeIdent()}|-- Conversation Id: ${conversation.conversationId?.value ?: "empty"}")
     val output = block()
     println("${writeIdent()}|-- Finished in ${getTimeMillis()-milis} ms")
     return output
   }
 
   override fun log(conversation: Conversation, message: String) {
-    println("${writeIdent()}|-- $message".padStart(identSize, ' '))
+    println("${writeIdent()}|-- $message")
   }
 
   private fun writeIdent(times: Int = 1) = (1..identSize * times).fold("") { a, b -> "$a " }
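Combined with the new Prompt.addMetrics, a span should now print roughly the following (the values and conversation id are illustrative, and the indent width depends on identSize):

Prompt-Span: What is the weather in Paris?
    |-- Conversation Id: 0b8b5d29-...
    |-- Number of messages: 2 (S-U)
    |-- Functions: no
    |-- Temperature: 0.4
    |-- Finished in 1234 ms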
@@ -10,15 +10,21 @@ import kotlinx.serialization.Serializable
 data class PromptConfiguration
 @JvmOverloads
 constructor(
-  val maxDeserializationAttempts: Int = 3,
-  val user: String = Role.USER.name,
-  val temperature: Double = 0.4,
-  val numberOfPredictions: Int = 1,
-  val docsInContext: Int = 5,
-  val minResponseTokens: Int = 500,
-  val messagePolicy: MessagePolicy = MessagePolicy(),
+  var maxDeserializationAttempts: Int = 3,
+  var user: String = Role.USER.name,
+  var temperature: Double = 0.4,
+  var numberOfPredictions: Int = 1,
+  var docsInContext: Int = 5,
+  var minResponseTokens: Int = 500,
+  var messagePolicy: MessagePolicy = MessagePolicy(),
 ) {
+
+  fun messagePolicy(block: MessagePolicy.() -> Unit) = messagePolicy.apply { block() }
+
   companion object {
     @JvmField val DEFAULTS = PromptConfiguration()
+
+    operator fun invoke(block: PromptConfiguration.() -> Unit) =
+      PromptConfiguration().apply { block() }
   }
 }
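The companion invoke plus the messagePolicy block turn the now-mutable properties into a small builder-style DSL. A usage sketch relying only on the API shown in this diff (the import for PromptConfiguration itself is omitted, since its package is not visible here):

import com.xebia.functional.xef.conversation.MessagesFromHistory
import com.xebia.functional.xef.conversation.MessagesToHistory

// Build the prompt without pulling prior history in, and persist only
// system messages back to the conversation afterwards.
val config = PromptConfiguration {
  temperature = 0.0
  messagePolicy {
    addMessagesFromConversation = MessagesFromHistory.NONE
    addMessagesToConversation = MessagesToHistory.ONLY_SYSTEM_MESSAGES
  }
}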