diff --git a/examples/kotlin/build.gradle.kts b/examples/kotlin/build.gradle.kts index 4b76004af..333c4cbf2 100644 --- a/examples/kotlin/build.gradle.kts +++ b/examples/kotlin/build.gradle.kts @@ -27,6 +27,7 @@ dependencies { implementation(projects.xefOpenai) implementation(projects.xefReasoning) implementation(projects.xefOpentelemetry) + implementation(projects.xefMlflow) implementation(libs.kotlinx.serialization.json) implementation(libs.logback) implementation(libs.klogging) diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/mlflow/Example.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/mlflow/Example.kt new file mode 100644 index 000000000..b84ceecb2 --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/mlflow/Example.kt @@ -0,0 +1,76 @@ +package com.xebia.functional.xef.conversation.mlflow + +import com.xebia.functional.xef.mlflow.* +import com.xebia.functional.xef.mlflow.MlflowClient +import io.ktor.client.* + +suspend fun main() { + + val gatewayUri = "http://localhost:5000" + + val httpClient = HttpClient() + + val client = MlflowClient(gatewayUri, httpClient) + + println("MLflow Gateway client created. Press any key to continue...") + readlnOrNull() + + println("Searching available models...") + println() + val routes = client.searchRoutes() + + println( + """ + |######### Routes found ######### + |${routes.joinToString(separator = "\n") { printRoute(it) }} + | + """ + .trimMargin() + ) + println() + + while (true) { + + println("Select the route you want to interact with") + val route = readlnOrNull() ?: "chat" + + val gptRoute = client.getRoute(route) + println("Route found: ${gptRoute?.name}. What do you want to ask?") + + val question = readlnOrNull() ?: "What's the best day of the week and why?" + + val response = + gptRoute?.name?.let { it -> + client.chat( + it, + listOf( + ChatMessage(ChatRole.SYSTEM, "You are a helpful assistant. Be concise"), + ChatMessage(ChatRole.USER, question), + ), + temperature = 0.7, + maxTokens = 200 + ) + } + + val chatResponse = response?.candidates?.get(0)?.message?.content + + println("Chat GPT response was: \n\n$chatResponse") + println() + println("Do you want to continue? (y/N)") + val userInput = readlnOrNull() ?: "" + if (!userInput.equals("y", true)) break + } + + httpClient.close() +} + +private fun printModel(model: RouteModel): String = + "(name = '${model.name}', provider = '${model.provider}')" + +private fun printRoute(r: RouteDefinition): String = + """ + |Name: ${r.name} + | * Route type: ${r.routeType} + | * Route url: ${r.routeUrl} + | * Model: ${printModel(r.model)}""" + .trimMargin() diff --git a/integrations/mlflow/build.gradle.kts b/integrations/mlflow/build.gradle.kts new file mode 100644 index 000000000..da8ed47ff --- /dev/null +++ b/integrations/mlflow/build.gradle.kts @@ -0,0 +1,123 @@ +plugins { + id(libs.plugins.kotlin.multiplatform.get().pluginId) + id(libs.plugins.kotlinx.serialization.get().pluginId) + alias(libs.plugins.spotless) + alias(libs.plugins.arrow.gradle.publish) + alias(libs.plugins.semver.gradle) + alias(libs.plugins.detekt) +} + + +dependencies { + detektPlugins(project(":detekt-rules")) +} + +detekt { + toolVersion = "1.23.1" + source = files("src/commonMain/kotlin", "src/jvmMain/kotlin") + config.setFrom("../../config/detekt/detekt.yml") + autoCorrect = true +} + + +repositories { + mavenCentral() +} + +java { + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + toolchain { + languageVersion = JavaLanguageVersion.of(11) + } +} + +kotlin { + jvm() + js(IR) { + browser() + nodejs() + } + + linuxX64() + macosX64() + macosArm64() + mingwX64() + + sourceSets { + val commonMain by getting { + dependencies { + api(projects.xefCore) + implementation(libs.bundles.ktor.client) + implementation(libs.uuid) + implementation(libs.kotlinx.datetime) + } + } + + val jvmMain by getting { + dependencies { + implementation(libs.logback) + api(libs.ktor.client.cio) + } + } + + val jsMain by getting { + dependencies { + api(libs.ktor.client.js) + } + } + + val linuxX64Main by getting { + dependencies { + api(libs.ktor.client.cio) + } + } + + val macosX64Main by getting { + dependencies { + api(libs.ktor.client.cio) + } + } + + val macosArm64Main by getting { + dependencies { + api(libs.ktor.client.cio) + } + } + + val mingwX64Main by getting { + dependencies { + api(libs.ktor.client.winhttp) + } + } + } +} + +spotless { + kotlin { + target("**/*.kt") + ktfmt().googleStyle().configure { + it.setRemoveUnusedImport(true) + } + } +} + +tasks{ + withType().configureEach { + dependsOn(":detekt-rules:assemble") + autoCorrect = true + } + named("detektJvmMain") { + dependsOn(":detekt-rules:assemble") + getByName("build").dependsOn(this) + } + named("detekt") { + dependsOn(":detekt-rules:assemble") + getByName("build").dependsOn(this) + } + withType { + dependsOn(withType()) + } + +} + diff --git a/integrations/mlflow/src/commonMain/kotlin/com/xebia/functional/xef/mlflow/MlflowClient.kt b/integrations/mlflow/src/commonMain/kotlin/com/xebia/functional/xef/mlflow/MlflowClient.kt new file mode 100644 index 000000000..4255b188b --- /dev/null +++ b/integrations/mlflow/src/commonMain/kotlin/com/xebia/functional/xef/mlflow/MlflowClient.kt @@ -0,0 +1,123 @@ +package com.xebia.functional.xef.mlflow + +import com.xebia.functional.xef.conversation.AutoClose +import com.xebia.functional.xef.conversation.autoClose +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.plugins.contentnegotiation.* +import io.ktor.client.request.* +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.http.isSuccess +import io.ktor.serialization.kotlinx.json.* +import kotlinx.serialization.json.Json + +class MlflowClient(private val gatewayUrl: String, client: HttpClient) : AutoClose by autoClose() { + + private val internal = + client.config { + install(ContentNegotiation) { + json( + Json { + encodeDefaults = false + isLenient = true + ignoreUnknownKeys = true + } + ) + } + } + + private val json = Json { ignoreUnknownKeys = true } + + private suspend fun routes(): List { + + val response = internal.get("$gatewayUrl/api/2.0/gateway/routes/") + if (response.status.isSuccess()) { + val textResponse = response.bodyAsText() + val data = json.decodeFromString(textResponse) + return data.routes + } else { + throw MLflowClientUnexpectedError(response.status, response.bodyAsText()) + } + } + + suspend fun searchRoutes(): List = routes() + + suspend fun getRoute(name: String): RouteDefinition? = routes().find { it.name == name } + + suspend fun prompt( + route: String, + prompt: String, + candidateCount: Int? = null, + temperature: Double? = null, + maxTokens: Int? = null, + stop: List? = null + ): PromptResponse { + val body = Prompt(prompt, temperature, candidateCount, stop, maxTokens) + val response = + internal.post("$gatewayUrl/gateway/$route/invocations") { + accept(ContentType.Application.Json) + contentType(ContentType.Application.Json) + setBody(body) + } + + return if (response.status.isSuccess()) response.body() + else if (response.status.value == 422) + throw MLflowValidationError( + response.status, + response.body().detail?.firstOrNull()?.msg ?: "Unknown error" + ) + else throw MLflowClientUnexpectedError(response.status, response.bodyAsText()) + } + + suspend fun chat( + route: String, + messages: List, + candidateCount: Int? = null, + temperature: Double? = null, + maxTokens: Int? = null, + stop: List? = null + ): ChatResponse { + val body = Chat(messages, temperature, candidateCount, stop, maxTokens) + val response = + internal.post("$gatewayUrl/gateway/$route/invocations") { + accept(ContentType.Application.Json) + contentType(ContentType.Application.Json) + setBody(body) + } + + return if (response.status.isSuccess()) response.body() + else if (response.status.value == 422) + throw MLflowValidationError( + response.status, + response.body().detail?.firstOrNull()?.msg ?: "Unknown error" + ) + else throw MLflowClientUnexpectedError(response.status, response.bodyAsText()) + } + + suspend fun embeddings(route: String, text: List): EmbeddingsResponse { + val body = Embeddings(text) + val response = + internal.post("$gatewayUrl/gateway/$route/invocations") { + accept(ContentType.Application.Json) + contentType(ContentType.Application.Json) + setBody(body) + } + + return if (response.status.isSuccess()) response.body() + else if (response.status.value == 422) + throw MLflowValidationError( + response.status, + response.body().detail?.firstOrNull()?.msg ?: "Unknown error" + ) + else throw MLflowClientUnexpectedError(response.status, response.bodyAsText()) + } + + class MLflowValidationError(httpStatusCode: HttpStatusCode, error: String) : + IllegalStateException("$httpStatusCode: $error") + + class MLflowClientUnexpectedError(httpStatusCode: HttpStatusCode, error: String) : + IllegalStateException("$httpStatusCode: $error") +} diff --git a/integrations/mlflow/src/commonMain/kotlin/com/xebia/functional/xef/mlflow/models.kt b/integrations/mlflow/src/commonMain/kotlin/com/xebia/functional/xef/mlflow/models.kt new file mode 100644 index 000000000..f4e55f9e6 --- /dev/null +++ b/integrations/mlflow/src/commonMain/kotlin/com/xebia/functional/xef/mlflow/models.kt @@ -0,0 +1,84 @@ +package com.xebia.functional.xef.mlflow + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable data class RoutesResponse(val routes: List) + +@Serializable +data class RouteDefinition( + val name: String, + @SerialName("route_type") val routeType: String, + val model: RouteModel, + @SerialName("route_url") val routeUrl: String, +) + +@Serializable +data class RouteModel( + val name: String, + val provider: String, +) + +@Serializable +data class Prompt( + val prompt: String, + val temperature: Double? = null, + @SerialName("candidate_count") val candidateCount: Int? = null, + val stop: List? = null, + @SerialName("max_tokens") val maxTokens: Int? = null +) + +@Serializable data class CandidateMetadata(@SerialName("finish_reason") val finishReason: String?) + +@Serializable data class PromptCandidate(val text: String, val metadata: CandidateMetadata?) + +@Serializable +enum class RouteType { + @SerialName("llm/v1/completions") COMPLETIONS, + @SerialName("llm/v1/chat") CHAT, + @SerialName("llm/v1/embeddings") EMBEDDINGS +} + +@Serializable +data class ResponseMetadata( + val model: String, + @SerialName("route_type") val routeType: RouteType, + @SerialName("input_tokens") val inputTokens: Int? = null, + @SerialName("output_tokens") val outputTokens: Int? = null, + @SerialName("total_tokens") val totalTokens: Int? = null +) + +@Serializable +data class PromptResponse(val candidates: List, val metadata: ResponseMetadata) + +@Serializable data class ValidationDetail(val msg: String, val type: String) + +@Serializable data class ValidationError(val detail: List?) + +@Serializable +enum class ChatRole { + @SerialName("system") SYSTEM, + @SerialName("user") USER, + @SerialName("assistant") ASSISTANT +} + +@Serializable data class ChatMessage(val role: ChatRole, val content: String) + +@Serializable +data class Chat( + val messages: List, + val temperature: Double? = null, + @SerialName("candidate_count") val candidateCount: Int? = null, + val stop: List? = null, + @SerialName("max_tokens") val maxTokens: Int? = null +) + +@Serializable data class ChatCandidate(val message: ChatMessage, val metadata: CandidateMetadata) + +@Serializable +data class ChatResponse(val candidates: List, val metadata: ResponseMetadata) + +@Serializable data class Embeddings(val text: List) + +@Serializable +data class EmbeddingsResponse(val embeddings: List>, val metadata: ResponseMetadata) diff --git a/settings.gradle.kts b/settings.gradle.kts index 43ac9bce2..6c8417956 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -51,6 +51,9 @@ project(":xef-gcp").projectDir = file("integrations/gcp") include("xef-opentelemetry") project(":xef-opentelemetry").projectDir = file("integrations/opentelemetry") + +include("xef-mlflow") +project(":xef-mlflow").projectDir = file("integrations/mlflow") // //