Skip to content

Commit

Permalink
Vision DSL (#620)
Browse files Browse the repository at this point in the history
* Images Vision API

* Updating timeout to 60 sec in ApiClient

---------

Co-authored-by: Javi Pacheco <[email protected]>
  • Loading branch information
raulraja and javipacheco authored Jan 2, 2024
1 parent f338c03 commit 248771d
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 9 deletions.
11 changes: 10 additions & 1 deletion core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ sealed interface AI {

fun images(
api: ImagesApi = fromEnvironment(::ImagesApi),
): Images = Images(api)
chatApi: ChatApi = fromEnvironment(::ChatApi)
): Images = Images(api, chatApi)

@PublishedApi
internal suspend inline fun <reified A : Any> invokeEnum(
Expand Down Expand Up @@ -73,6 +74,14 @@ sealed interface AI {
conversation: Conversation = Conversation()
): A = chat(Prompt(CustomModel(model.value), prompt), target, api, conversation)

@AiDsl
suspend inline operator fun <reified A : Any> invoke(
prompt: Prompt<CreateChatCompletionRequestModel>,
target: KType = typeOf<A>(),
api: ChatApi = fromEnvironment(::ChatApi),
conversation: Conversation = Conversation()
): A = chat(prompt, target, api, conversation)

@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
@AiDsl
suspend inline fun <reified A : Any> chat(
Expand Down
39 changes: 34 additions & 5 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Images.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,27 @@ package com.xebia.functional.xef

import ai.xef.openai.OpenAIModel
import ai.xef.openai.StandardModel
import com.xebia.functional.openai.apis.ChatApi
import com.xebia.functional.openai.apis.ImagesApi
import com.xebia.functional.openai.apis.UploadFile
import com.xebia.functional.openai.infrastructure.HttpResponse
import com.xebia.functional.openai.models.CreateImageEditRequestModel
import com.xebia.functional.openai.models.CreateImageRequest
import com.xebia.functional.openai.models.CreateImageRequestModel
import com.xebia.functional.openai.models.ImagesResponse
import com.xebia.functional.openai.models.*
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.prompt
import com.xebia.functional.xef.llm.promptStreaming
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.user
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.request.forms.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.toList
import kotlinx.serialization.serializer

data class Images(val api: ImagesApi) {
data class Images(val api: ImagesApi, val chatApi: ChatApi) {

sealed class Image {
data class Url(
Expand All @@ -28,6 +33,30 @@ data class Images(val api: ImagesApi) {
data class B64Json(val content: String, val revisedPrompt: String) : Image()
}

suspend inline fun <reified A> visionStructured(
prompt: String,
url: String,
conversation: Conversation = Conversation(),
model: OpenAIModel<CreateChatCompletionRequestModel> =
StandardModel(CreateChatCompletionRequestModel.gpt_4_0613)
): A {
val response = vision(prompt, url, conversation).toList().joinToString("") { it }
return chatApi.prompt(Prompt(model) { +user(response) }, conversation, serializer())
}

fun vision(
prompt: String,
url: String,
conversation: Conversation = Conversation()
): Flow<String> =
chatApi.promptStreaming(
prompt =
Prompt(StandardModel(CreateChatCompletionRequestModel.gpt_4_vision_preview)) {
+com.xebia.functional.xef.prompt.templates.image(prompt, url)
},
scope = conversation
)

suspend fun image(
prompt: String,
amount: Int = 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.xebia.functional.xef.prompt.templates

import com.xebia.functional.openai.models.ChatCompletionRole
import com.xebia.functional.openai.models.ChatCompletionRole.*
import com.xebia.functional.openai.models.ext.chat.ChatCompletionRequestMessage
import com.xebia.functional.openai.models.ext.chat.*
import com.xebia.functional.xef.prompt.message

fun system(context: String): ChatCompletionRequestMessage = context.message(system)
Expand All @@ -11,6 +11,16 @@ fun assistant(context: String): ChatCompletionRequestMessage = context.message(a

fun user(context: String): ChatCompletionRequestMessage = context.message(user)

fun image(prompt: String, url: String): ChatCompletionRequestMessage =
ChatCompletionRequestUserMessage(
listOf(
ChatCompletionRequestUserMessageContentText(prompt),
ChatCompletionRequestUserMessageContentImage(
imageUrl = ChatCompletionRequestUserMessageContentImageUrl(url)
)
)
)

inline fun <reified A> system(data: A): ChatCompletionRequestMessage = data.message(system)

inline fun <reified A> assistant(data: A): ChatCompletionRequestMessage = data.message(assistant)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.xebia.functional.xef.dsl.vision

import com.xebia.functional.xef.AI

suspend fun main() {
val images = AI.images()
val stream =
images.vision(
prompt = "Describe the image in detail",
url = "https://apod.nasa.gov/apod/image/2401/ngc1232b_vlt_960.jpg"
)
stream.collect(::print)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.xebia.functional.xef.dsl.vision

import com.xebia.functional.xef.AI
import kotlinx.serialization.Serializable

@Serializable
data class ImageAnalysisResult(
val topic: String,
val description: String,
)

suspend fun main() {
val images = AI.images()
val result: ImageAnalysisResult =
images.visionStructured(
prompt = "Describe the image in detail",
url = "https://apod.nasa.gov/apod/image/2401/ngc1232b_vlt_960.jpg"
)
println(result)
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ open class ApiClient(val baseUrl: String) {
val clientConfig: (HttpClientConfig<*>) -> Unit by lazy {
{
it.install(ContentNegotiation) { json(jsonBlock) }
it.install(HttpTimeout)
it.install(HttpTimeout) {
requestTimeoutMillis = 60 * 1000
connectTimeoutMillis = 60 * 1000
socketTimeoutMillis = 60 * 1000
}
it.install(Logging) { level = LogLevel.NONE }
httpClientConfig?.invoke(it)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ import {{packageName}}.auth.*
val clientConfig: (HttpClientConfig<*>) -> Unit by lazy {
{
it.install(ContentNegotiation) { json(jsonBlock) }
it.install(HttpTimeout)
it.install(HttpTimeout) {
requestTimeoutMillis = 60 * 1000
connectTimeoutMillis = 60 * 1000
socketTimeoutMillis = 60 * 1000
}
it.install(Logging) { level = LogLevel.NONE }
httpClientConfig?.invoke(it)
}
Expand Down

0 comments on commit 248771d

Please sign in to comment.