-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Adds new integration * Adds some comments * Completes the API and adds an example * Improves the example for the demo * Adapt integration to the models * spotless apply * Allow sending the http client
- Loading branch information
1 parent
57afd49
commit 3986df1
Showing
6 changed files
with
410 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/mlflow/Example.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package com.xebia.functional.xef.conversation.mlflow | ||
|
||
import com.xebia.functional.xef.mlflow.* | ||
import com.xebia.functional.xef.mlflow.MlflowClient | ||
import io.ktor.client.* | ||
|
||
suspend fun main() { | ||
|
||
val gatewayUri = "http://localhost:5000" | ||
|
||
val httpClient = HttpClient() | ||
|
||
val client = MlflowClient(gatewayUri, httpClient) | ||
|
||
println("MLflow Gateway client created. Press any key to continue...") | ||
readlnOrNull() | ||
|
||
println("Searching available models...") | ||
println() | ||
val routes = client.searchRoutes() | ||
|
||
println( | ||
""" | ||
|######### Routes found ######### | ||
|${routes.joinToString(separator = "\n") { printRoute(it) }} | ||
| | ||
""" | ||
.trimMargin() | ||
) | ||
println() | ||
|
||
while (true) { | ||
|
||
println("Select the route you want to interact with") | ||
val route = readlnOrNull() ?: "chat" | ||
|
||
val gptRoute = client.getRoute(route) | ||
println("Route found: ${gptRoute?.name}. What do you want to ask?") | ||
|
||
val question = readlnOrNull() ?: "What's the best day of the week and why?" | ||
|
||
val response = | ||
gptRoute?.name?.let { it -> | ||
client.chat( | ||
it, | ||
listOf( | ||
ChatMessage(ChatRole.SYSTEM, "You are a helpful assistant. Be concise"), | ||
ChatMessage(ChatRole.USER, question), | ||
), | ||
temperature = 0.7, | ||
maxTokens = 200 | ||
) | ||
} | ||
|
||
val chatResponse = response?.candidates?.get(0)?.message?.content | ||
|
||
println("Chat GPT response was: \n\n$chatResponse") | ||
println() | ||
println("Do you want to continue? (y/N)") | ||
val userInput = readlnOrNull() ?: "" | ||
if (!userInput.equals("y", true)) break | ||
} | ||
|
||
httpClient.close() | ||
} | ||
|
||
private fun printModel(model: RouteModel): String = | ||
"(name = '${model.name}', provider = '${model.provider}')" | ||
|
||
private fun printRoute(r: RouteDefinition): String = | ||
""" | ||
|Name: ${r.name} | ||
| * Route type: ${r.routeType} | ||
| * Route url: ${r.routeUrl} | ||
| * Model: ${printModel(r.model)}""" | ||
.trimMargin() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
plugins { | ||
id(libs.plugins.kotlin.multiplatform.get().pluginId) | ||
id(libs.plugins.kotlinx.serialization.get().pluginId) | ||
alias(libs.plugins.spotless) | ||
alias(libs.plugins.arrow.gradle.publish) | ||
alias(libs.plugins.semver.gradle) | ||
alias(libs.plugins.detekt) | ||
} | ||
|
||
|
||
dependencies { | ||
detektPlugins(project(":detekt-rules")) | ||
} | ||
|
||
detekt { | ||
toolVersion = "1.23.1" | ||
source = files("src/commonMain/kotlin", "src/jvmMain/kotlin") | ||
config.setFrom("../../config/detekt/detekt.yml") | ||
autoCorrect = true | ||
} | ||
|
||
|
||
repositories { | ||
mavenCentral() | ||
} | ||
|
||
java { | ||
sourceCompatibility = JavaVersion.VERSION_11 | ||
targetCompatibility = JavaVersion.VERSION_11 | ||
toolchain { | ||
languageVersion = JavaLanguageVersion.of(11) | ||
} | ||
} | ||
|
||
kotlin { | ||
jvm() | ||
js(IR) { | ||
browser() | ||
nodejs() | ||
} | ||
|
||
linuxX64() | ||
macosX64() | ||
macosArm64() | ||
mingwX64() | ||
|
||
sourceSets { | ||
val commonMain by getting { | ||
dependencies { | ||
api(projects.xefCore) | ||
implementation(libs.bundles.ktor.client) | ||
implementation(libs.uuid) | ||
implementation(libs.kotlinx.datetime) | ||
} | ||
} | ||
|
||
val jvmMain by getting { | ||
dependencies { | ||
implementation(libs.logback) | ||
api(libs.ktor.client.cio) | ||
} | ||
} | ||
|
||
val jsMain by getting { | ||
dependencies { | ||
api(libs.ktor.client.js) | ||
} | ||
} | ||
|
||
val linuxX64Main by getting { | ||
dependencies { | ||
api(libs.ktor.client.cio) | ||
} | ||
} | ||
|
||
val macosX64Main by getting { | ||
dependencies { | ||
api(libs.ktor.client.cio) | ||
} | ||
} | ||
|
||
val macosArm64Main by getting { | ||
dependencies { | ||
api(libs.ktor.client.cio) | ||
} | ||
} | ||
|
||
val mingwX64Main by getting { | ||
dependencies { | ||
api(libs.ktor.client.winhttp) | ||
} | ||
} | ||
} | ||
} | ||
|
||
spotless { | ||
kotlin { | ||
target("**/*.kt") | ||
ktfmt().googleStyle().configure { | ||
it.setRemoveUnusedImport(true) | ||
} | ||
} | ||
} | ||
|
||
tasks{ | ||
withType<io.gitlab.arturbosch.detekt.Detekt>().configureEach { | ||
dependsOn(":detekt-rules:assemble") | ||
autoCorrect = true | ||
} | ||
named("detektJvmMain") { | ||
dependsOn(":detekt-rules:assemble") | ||
getByName("build").dependsOn(this) | ||
} | ||
named("detekt") { | ||
dependsOn(":detekt-rules:assemble") | ||
getByName("build").dependsOn(this) | ||
} | ||
withType<AbstractPublishToMaven> { | ||
dependsOn(withType<Sign>()) | ||
} | ||
|
||
} | ||
|
123 changes: 123 additions & 0 deletions
123
integrations/mlflow/src/commonMain/kotlin/com/xebia/functional/xef/mlflow/MlflowClient.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
package com.xebia.functional.xef.mlflow | ||
|
||
import com.xebia.functional.xef.conversation.AutoClose | ||
import com.xebia.functional.xef.conversation.autoClose | ||
import io.ktor.client.HttpClient | ||
import io.ktor.client.call.body | ||
import io.ktor.client.plugins.contentnegotiation.* | ||
import io.ktor.client.request.* | ||
import io.ktor.client.statement.bodyAsText | ||
import io.ktor.http.ContentType | ||
import io.ktor.http.HttpStatusCode | ||
import io.ktor.http.contentType | ||
import io.ktor.http.isSuccess | ||
import io.ktor.serialization.kotlinx.json.* | ||
import kotlinx.serialization.json.Json | ||
|
||
class MlflowClient(private val gatewayUrl: String, client: HttpClient) : AutoClose by autoClose() { | ||
|
||
private val internal = | ||
client.config { | ||
install(ContentNegotiation) { | ||
json( | ||
Json { | ||
encodeDefaults = false | ||
isLenient = true | ||
ignoreUnknownKeys = true | ||
} | ||
) | ||
} | ||
} | ||
|
||
private val json = Json { ignoreUnknownKeys = true } | ||
|
||
private suspend fun routes(): List<RouteDefinition> { | ||
|
||
val response = internal.get("$gatewayUrl/api/2.0/gateway/routes/") | ||
if (response.status.isSuccess()) { | ||
val textResponse = response.bodyAsText() | ||
val data = json.decodeFromString<RoutesResponse>(textResponse) | ||
return data.routes | ||
} else { | ||
throw MLflowClientUnexpectedError(response.status, response.bodyAsText()) | ||
} | ||
} | ||
|
||
suspend fun searchRoutes(): List<RouteDefinition> = routes() | ||
|
||
suspend fun getRoute(name: String): RouteDefinition? = routes().find { it.name == name } | ||
|
||
suspend fun prompt( | ||
route: String, | ||
prompt: String, | ||
candidateCount: Int? = null, | ||
temperature: Double? = null, | ||
maxTokens: Int? = null, | ||
stop: List<String>? = null | ||
): PromptResponse { | ||
val body = Prompt(prompt, temperature, candidateCount, stop, maxTokens) | ||
val response = | ||
internal.post("$gatewayUrl/gateway/$route/invocations") { | ||
accept(ContentType.Application.Json) | ||
contentType(ContentType.Application.Json) | ||
setBody(body) | ||
} | ||
|
||
return if (response.status.isSuccess()) response.body<PromptResponse>() | ||
else if (response.status.value == 422) | ||
throw MLflowValidationError( | ||
response.status, | ||
response.body<ValidationError>().detail?.firstOrNull()?.msg ?: "Unknown error" | ||
) | ||
else throw MLflowClientUnexpectedError(response.status, response.bodyAsText()) | ||
} | ||
|
||
suspend fun chat( | ||
route: String, | ||
messages: List<ChatMessage>, | ||
candidateCount: Int? = null, | ||
temperature: Double? = null, | ||
maxTokens: Int? = null, | ||
stop: List<String>? = null | ||
): ChatResponse { | ||
val body = Chat(messages, temperature, candidateCount, stop, maxTokens) | ||
val response = | ||
internal.post("$gatewayUrl/gateway/$route/invocations") { | ||
accept(ContentType.Application.Json) | ||
contentType(ContentType.Application.Json) | ||
setBody(body) | ||
} | ||
|
||
return if (response.status.isSuccess()) response.body<ChatResponse>() | ||
else if (response.status.value == 422) | ||
throw MLflowValidationError( | ||
response.status, | ||
response.body<ValidationError>().detail?.firstOrNull()?.msg ?: "Unknown error" | ||
) | ||
else throw MLflowClientUnexpectedError(response.status, response.bodyAsText()) | ||
} | ||
|
||
suspend fun embeddings(route: String, text: List<String>): EmbeddingsResponse { | ||
val body = Embeddings(text) | ||
val response = | ||
internal.post("$gatewayUrl/gateway/$route/invocations") { | ||
accept(ContentType.Application.Json) | ||
contentType(ContentType.Application.Json) | ||
setBody(body) | ||
} | ||
|
||
return if (response.status.isSuccess()) response.body<EmbeddingsResponse>() | ||
else if (response.status.value == 422) | ||
throw MLflowValidationError( | ||
response.status, | ||
response.body<ValidationError>().detail?.firstOrNull()?.msg ?: "Unknown error" | ||
) | ||
else throw MLflowClientUnexpectedError(response.status, response.bodyAsText()) | ||
} | ||
|
||
class MLflowValidationError(httpStatusCode: HttpStatusCode, error: String) : | ||
IllegalStateException("$httpStatusCode: $error") | ||
|
||
class MLflowClientUnexpectedError(httpStatusCode: HttpStatusCode, error: String) : | ||
IllegalStateException("$httpStatusCode: $error") | ||
} |
Oops, something went wrong.