Skip to content

Commit

Permalink
MLflow Gateway models (#507)
Browse files Browse the repository at this point in the history
* Adds new integration

* Adds some comments

* Completes the API and adds an example

* Improves the example for the demo

* Adapt integration to the models

* spotless apply

* Allow sending the http client
  • Loading branch information
fedefernandez authored Oct 25, 2023
1 parent 57afd49 commit 3986df1
Show file tree
Hide file tree
Showing 6 changed files with 410 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies {
implementation(projects.xefOpenai)
implementation(projects.xefReasoning)
implementation(projects.xefOpentelemetry)
implementation(projects.xefMlflow)
implementation(libs.kotlinx.serialization.json)
implementation(libs.logback)
implementation(libs.klogging)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.xebia.functional.xef.conversation.mlflow

import com.xebia.functional.xef.mlflow.*
import com.xebia.functional.xef.mlflow.MlflowClient
import io.ktor.client.*

suspend fun main() {

val gatewayUri = "http://localhost:5000"

val httpClient = HttpClient()

val client = MlflowClient(gatewayUri, httpClient)

println("MLflow Gateway client created. Press any key to continue...")
readlnOrNull()

println("Searching available models...")
println()
val routes = client.searchRoutes()

println(
"""
|######### Routes found #########
|${routes.joinToString(separator = "\n") { printRoute(it) }}
|
"""
.trimMargin()
)
println()

while (true) {

println("Select the route you want to interact with")
val route = readlnOrNull() ?: "chat"

val gptRoute = client.getRoute(route)
println("Route found: ${gptRoute?.name}. What do you want to ask?")

val question = readlnOrNull() ?: "What's the best day of the week and why?"

val response =
gptRoute?.name?.let { it ->
client.chat(
it,
listOf(
ChatMessage(ChatRole.SYSTEM, "You are a helpful assistant. Be concise"),
ChatMessage(ChatRole.USER, question),
),
temperature = 0.7,
maxTokens = 200
)
}

val chatResponse = response?.candidates?.get(0)?.message?.content

println("Chat GPT response was: \n\n$chatResponse")
println()
println("Do you want to continue? (y/N)")
val userInput = readlnOrNull() ?: ""
if (!userInput.equals("y", true)) break
}

httpClient.close()
}

private fun printModel(model: RouteModel): String =
"(name = '${model.name}', provider = '${model.provider}')"

private fun printRoute(r: RouteDefinition): String =
"""
|Name: ${r.name}
| * Route type: ${r.routeType}
| * Route url: ${r.routeUrl}
| * Model: ${printModel(r.model)}"""
.trimMargin()
123 changes: 123 additions & 0 deletions integrations/mlflow/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
plugins {
id(libs.plugins.kotlin.multiplatform.get().pluginId)
id(libs.plugins.kotlinx.serialization.get().pluginId)
alias(libs.plugins.spotless)
alias(libs.plugins.arrow.gradle.publish)
alias(libs.plugins.semver.gradle)
alias(libs.plugins.detekt)
}


dependencies {
detektPlugins(project(":detekt-rules"))
}

detekt {
toolVersion = "1.23.1"
source = files("src/commonMain/kotlin", "src/jvmMain/kotlin")
config.setFrom("../../config/detekt/detekt.yml")
autoCorrect = true
}


repositories {
mavenCentral()
}

java {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
toolchain {
languageVersion = JavaLanguageVersion.of(11)
}
}

kotlin {
jvm()
js(IR) {
browser()
nodejs()
}

linuxX64()
macosX64()
macosArm64()
mingwX64()

sourceSets {
val commonMain by getting {
dependencies {
api(projects.xefCore)
implementation(libs.bundles.ktor.client)
implementation(libs.uuid)
implementation(libs.kotlinx.datetime)
}
}

val jvmMain by getting {
dependencies {
implementation(libs.logback)
api(libs.ktor.client.cio)
}
}

val jsMain by getting {
dependencies {
api(libs.ktor.client.js)
}
}

val linuxX64Main by getting {
dependencies {
api(libs.ktor.client.cio)
}
}

val macosX64Main by getting {
dependencies {
api(libs.ktor.client.cio)
}
}

val macosArm64Main by getting {
dependencies {
api(libs.ktor.client.cio)
}
}

val mingwX64Main by getting {
dependencies {
api(libs.ktor.client.winhttp)
}
}
}
}

spotless {
kotlin {
target("**/*.kt")
ktfmt().googleStyle().configure {
it.setRemoveUnusedImport(true)
}
}
}

tasks{
withType<io.gitlab.arturbosch.detekt.Detekt>().configureEach {
dependsOn(":detekt-rules:assemble")
autoCorrect = true
}
named("detektJvmMain") {
dependsOn(":detekt-rules:assemble")
getByName("build").dependsOn(this)
}
named("detekt") {
dependsOn(":detekt-rules:assemble")
getByName("build").dependsOn(this)
}
withType<AbstractPublishToMaven> {
dependsOn(withType<Sign>())
}

}

Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package com.xebia.functional.xef.mlflow

import com.xebia.functional.xef.conversation.AutoClose
import com.xebia.functional.xef.conversation.autoClose
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.request.*
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.*
import kotlinx.serialization.json.Json

class MlflowClient(private val gatewayUrl: String, client: HttpClient) : AutoClose by autoClose() {

private val internal =
client.config {
install(ContentNegotiation) {
json(
Json {
encodeDefaults = false
isLenient = true
ignoreUnknownKeys = true
}
)
}
}

private val json = Json { ignoreUnknownKeys = true }

private suspend fun routes(): List<RouteDefinition> {

val response = internal.get("$gatewayUrl/api/2.0/gateway/routes/")
if (response.status.isSuccess()) {
val textResponse = response.bodyAsText()
val data = json.decodeFromString<RoutesResponse>(textResponse)
return data.routes
} else {
throw MLflowClientUnexpectedError(response.status, response.bodyAsText())
}
}

suspend fun searchRoutes(): List<RouteDefinition> = routes()

suspend fun getRoute(name: String): RouteDefinition? = routes().find { it.name == name }

suspend fun prompt(
route: String,
prompt: String,
candidateCount: Int? = null,
temperature: Double? = null,
maxTokens: Int? = null,
stop: List<String>? = null
): PromptResponse {
val body = Prompt(prompt, temperature, candidateCount, stop, maxTokens)
val response =
internal.post("$gatewayUrl/gateway/$route/invocations") {
accept(ContentType.Application.Json)
contentType(ContentType.Application.Json)
setBody(body)
}

return if (response.status.isSuccess()) response.body<PromptResponse>()
else if (response.status.value == 422)
throw MLflowValidationError(
response.status,
response.body<ValidationError>().detail?.firstOrNull()?.msg ?: "Unknown error"
)
else throw MLflowClientUnexpectedError(response.status, response.bodyAsText())
}

suspend fun chat(
route: String,
messages: List<ChatMessage>,
candidateCount: Int? = null,
temperature: Double? = null,
maxTokens: Int? = null,
stop: List<String>? = null
): ChatResponse {
val body = Chat(messages, temperature, candidateCount, stop, maxTokens)
val response =
internal.post("$gatewayUrl/gateway/$route/invocations") {
accept(ContentType.Application.Json)
contentType(ContentType.Application.Json)
setBody(body)
}

return if (response.status.isSuccess()) response.body<ChatResponse>()
else if (response.status.value == 422)
throw MLflowValidationError(
response.status,
response.body<ValidationError>().detail?.firstOrNull()?.msg ?: "Unknown error"
)
else throw MLflowClientUnexpectedError(response.status, response.bodyAsText())
}

suspend fun embeddings(route: String, text: List<String>): EmbeddingsResponse {
val body = Embeddings(text)
val response =
internal.post("$gatewayUrl/gateway/$route/invocations") {
accept(ContentType.Application.Json)
contentType(ContentType.Application.Json)
setBody(body)
}

return if (response.status.isSuccess()) response.body<EmbeddingsResponse>()
else if (response.status.value == 422)
throw MLflowValidationError(
response.status,
response.body<ValidationError>().detail?.firstOrNull()?.msg ?: "Unknown error"
)
else throw MLflowClientUnexpectedError(response.status, response.bodyAsText())
}

class MLflowValidationError(httpStatusCode: HttpStatusCode, error: String) :
IllegalStateException("$httpStatusCode: $error")

class MLflowClientUnexpectedError(httpStatusCode: HttpStatusCode, error: String) :
IllegalStateException("$httpStatusCode: $error")
}
Loading

0 comments on commit 3986df1

Please sign in to comment.