Skip to content

Commit

Permalink
refactor and removing non-necessary dependency (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
Montagon authored Sep 26, 2023
1 parent 7c92ed5 commit 77213c8
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 127 deletions.
14 changes: 3 additions & 11 deletions server/src/main/kotlin/com/xebia/functional/xef/server/Server.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,8 @@ import com.xebia.functional.xef.server.db.psql.XefDatabaseConfig
import com.xebia.functional.xef.server.db.psql.XefVectorStoreConfig
import com.xebia.functional.xef.server.db.psql.XefVectorStoreConfig.Companion.getVectorStoreService
import com.xebia.functional.xef.server.exceptions.exceptionsHandler
import com.xebia.functional.xef.server.http.routes.genAIRoutes
import com.xebia.functional.xef.server.http.routes.organizationRoutes
import com.xebia.functional.xef.server.http.routes.projectsRoutes
import com.xebia.functional.xef.server.http.routes.userRoutes
import com.xebia.functional.xef.server.services.OrganizationRepositoryService
import com.xebia.functional.xef.server.services.ProjectRepositoryService
import com.xebia.functional.xef.server.http.routes.*
import com.xebia.functional.xef.server.services.RepositoryService
import com.xebia.functional.xef.server.services.UserRepositoryService
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.auth.*
Expand Down Expand Up @@ -85,10 +79,8 @@ object Server {
}
exceptionsHandler()
routing {
genAIRoutes(ktorClient, vectorStoreService)
userRoutes(UserRepositoryService(logger))
organizationRoutes(OrganizationRepositoryService(logger))
projectsRoutes(ProjectRepositoryService(logger))
xefRoutes(logger)
aiRoutes(ktorClient)
}
}
awaitCancellation()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package com.xebia.functional.xef.server.http.routes

import com.aallam.openai.api.BetaOpenAI
import com.xebia.functional.xef.server.models.Token
import com.xebia.functional.xef.server.models.exceptions.XefExceptions
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.utils.io.jvm.javaio.*
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.boolean
import kotlinx.serialization.json.jsonPrimitive

enum class Provider {
OPENAI, GPT4ALL, GCP
}

fun String.toProvider(): Provider? = when (this) {
"openai" -> Provider.OPENAI
"gpt4all" -> Provider.GPT4ALL
"gcp" -> Provider.GCP
else -> Provider.OPENAI
}

@OptIn(BetaOpenAI::class)
fun Routing.aiRoutes(
client: HttpClient
) {
val openAiUrl = "https://api.openai.com/v1"

authenticate("auth-bearer") {
post("/chat/completions") {
val token = call.getToken()
val body = call.receive<String>()
val data = Json.decodeFromString<JsonObject>(body)

val isStream = data["stream"]?.jsonPrimitive?.boolean ?: false

if (!isStream) {
client.makeRequest(call, "$openAiUrl/chat/completions", body, token)
} else {
client.makeStreaming(call, "$openAiUrl/chat/completions", body, token)
}
}

post("/embeddings") {
val token = call.getToken()
val context = call.receive<String>()
client.makeRequest(call, "$openAiUrl/embeddings", context, token)
}
}
}

private suspend fun HttpClient.makeRequest(
call: ApplicationCall,
url: String,
body: String,
token: Token
) {
val response = this.request(url) {
headers {
bearerAuth(token.value)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(body)
}
call.response.headers.copyFrom(response.headers)
call.respond(response.status, response.body<String>())
}

private suspend fun HttpClient.makeStreaming(
call: ApplicationCall,
url: String,
body: String,
token: Token
) {
this.preparePost(url) {
headers {
bearerAuth(token.value)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(body)
}.execute { httpResponse ->
call.response.headers.copyFrom(httpResponse.headers)
call.respondOutputStream {
httpResponse
.bodyAsChannel()
.copyTo(this@respondOutputStream)
}
}
}

private fun ResponseHeaders.copyFrom(headers: Headers) = headers
.entries()
.filter { (key, _) -> !HttpHeaders.isUnsafe(key) } // setting unsafe headers results in exception
.forEach { (key, values) ->
values.forEach { value -> this.appendIfAbsent(key, value) }
}

private fun ApplicationCall.getProvider(): Provider =
request.headers["xef-provider"]?.toProvider()
?: Provider.OPENAI

fun ApplicationCall.getToken(): Token =
principal<UserIdPrincipal>()?.name?.let { Token(it) } ?: throw XefExceptions.AuthorizationException("No token found")

fun ApplicationCall.getId(): Int = getInt("id")

fun ApplicationCall.getInt(field: String): Int =
this.parameters[field]?.toInt() ?: throw XefExceptions.ValidationException("Invalid $field")

Original file line number Diff line number Diff line change
@@ -1,123 +1,16 @@
package com.xebia.functional.xef.server.http.routes

import com.aallam.openai.api.BetaOpenAI
import com.xebia.functional.xef.server.models.Token
import com.xebia.functional.xef.server.models.exceptions.XefExceptions
import com.xebia.functional.xef.server.services.VectorStoreService
import com.xebia.functional.xef.server.services.OrganizationRepositoryService
import com.xebia.functional.xef.server.services.ProjectRepositoryService
import com.xebia.functional.xef.server.services.UserRepositoryService
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.utils.io.jvm.javaio.*
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.boolean
import kotlinx.serialization.json.jsonPrimitive
import org.slf4j.Logger

enum class Provider {
OPENAI, GPT4ALL, GCP
}

fun String.toProvider(): Provider? = when (this) {
"openai" -> Provider.OPENAI
"gpt4all" -> Provider.GPT4ALL
"gcp" -> Provider.GCP
else -> Provider.OPENAI
}

@OptIn(BetaOpenAI::class)
fun Routing.genAIRoutes(
client: HttpClient,
vectorStoreService: VectorStoreService
) {
val openAiUrl = "https://api.openai.com/v1"

authenticate("auth-bearer") {
post("/chat/completions") {
val token = call.getToken()
val body = call.receive<String>()
val data = Json.decodeFromString<JsonObject>(body)

val isStream = data["stream"]?.jsonPrimitive?.boolean ?: false

if (!isStream) {
client.makeRequest(call, "$openAiUrl/chat/completions", body, token)
} else {
client.makeStreaming(call, "$openAiUrl/chat/completions", body, token)
}
}

post("/embeddings") {
val token = call.getToken()
val context = call.receive<String>()
client.makeRequest(call, "$openAiUrl/embeddings", context, token)
}
}
}

private suspend fun HttpClient.makeRequest(
call: ApplicationCall,
url: String,
body: String,
token: Token
fun Routing.xefRoutes(
logger: Logger
) {
val response = this.request(url) {
headers {
bearerAuth(token.value)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(body)
}
call.response.headers.copyFrom(response.headers)
call.respond(response.status, response.body<String>())
userRoutes(UserRepositoryService(logger))
organizationRoutes(OrganizationRepositoryService(logger))
projectsRoutes(ProjectRepositoryService(logger))
}

private suspend fun HttpClient.makeStreaming(
call: ApplicationCall,
url: String,
body: String,
token: Token
) {
this.preparePost(url) {
headers {
bearerAuth(token.value)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(body)
}.execute { httpResponse ->
call.response.headers.copyFrom(httpResponse.headers)
call.respondOutputStream {
httpResponse
.bodyAsChannel()
.copyTo(this@respondOutputStream)
}
}
}

private fun ResponseHeaders.copyFrom(headers: Headers) = headers
.entries()
.filter { (key, _) -> !HttpHeaders.isUnsafe(key) } // setting unsafe headers results in exception
.forEach { (key, values) ->
values.forEach { value -> this.appendIfAbsent(key, value) }
}

private fun ApplicationCall.getProvider(): Provider =
request.headers["xef-provider"]?.toProvider()
?: Provider.OPENAI

fun ApplicationCall.getToken(): Token =
principal<UserIdPrincipal>()?.name?.let { Token(it) } ?: throw XefExceptions.AuthorizationException("No token found")

fun ApplicationCall.getId(): Int = getInt("id")

fun ApplicationCall.getInt(field: String): Int =
this.parameters[field]?.toInt() ?: throw XefExceptions.ValidationException("Invalid $field")

0 comments on commit 77213c8

Please sign in to comment.