Skip to content

Commit

Permalink
Safely spawn fine tuned model (#494)
Browse files Browse the repository at this point in the history
* safer way to spawn a fine tuned model by its job id

* fixes example

* spotless
  • Loading branch information
Intex32 authored Oct 21, 2023
1 parent a2db786 commit cdbc923
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
package com.xebia.functional.xef.conversation.finetuning

import arrow.core.getOrElse
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.prompt.Prompt

suspend fun main() {
val spawnModelId =
getenv("OPENAI_FINE_TUNED_MODEL_ID")
?: error("Please set the OPENAI_FINE_TUNED_MODEL_ID environment variable.")

val OAI = OpenAI()
val model = OAI.spawnModel(spawnModelId, OAI.GPT_3_5_TURBO)
val baseModel = OAI.GPT_3_5_TURBO

val fineTunedModelId = getenv("OPENAI_FINE_TUNED_MODEL_ID")
val fineTuneJobId = getenv("OPENAI_FINE_TUNE_JOB_ID")

val model =
when {
fineTunedModelId != null -> OAI.spawnModel(fineTunedModelId, baseModel)
fineTuneJobId != null -> OAI.spawnFineTunedModel(fineTuneJobId, baseModel)
else ->
error(
"Please set the OPENAI_FINE_TUNED_MODEL_ID or OPENAI_FINE_TUNE_JOB_ID environment variable."
)
}.getOrElse { error(it) }

OpenAI.conversation {
while (true) {
print("> ")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package com.xebia.functional.xef.conversation.llm.openai

import arrow.core.nonEmptyListOf
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.ensureNotNull
import com.aallam.openai.api.exception.InvalidRequestException
import com.aallam.openai.api.finetuning.FineTuningId
import com.aallam.openai.api.logging.LogLevel
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.LoggingConfig
Expand Down Expand Up @@ -137,6 +141,7 @@ class OpenAI(

@JvmField val DEFAULT_IMAGES = DALLE_2

/** Returns a list of all publicly available, supported models. */
fun supportedModels(): List<LLM> = // TODO: impl of abstract provider function
listOf(
GPT_4,
Expand All @@ -155,26 +160,53 @@ class OpenAI(
DALLE_2,
)

suspend fun <T : LLM> spawnModel(
modelId: String,
baseModel: T
): T { // TODO: impl of abstract provider function
if (findModel(modelId) == null) error("model not found")
return baseModel.copy(ModelType.FineTunedModel(modelId, baseModel = baseModel.modelType)) as? T
?: error("${baseModel::class} does not follow contract to return the most specific type")
/**
* Spawns a model by its [modelId]. It should have the same capabilities as [baseModel]. The model
* to spawn can i.e. be a fine-tuned model which is not known to the public.
*
* Warning: Throws an error at runtime during querying if the model does not provide the same
* capabilities as [baseModel].
*/
suspend fun <T : LLM> spawnModel(modelId: String, baseModel: T) =
either { // TODO: impl of abstract provider function
ensure(modelExists(modelId)) { "model $modelId not found" }
@Suppress("UNCHECKED_CAST")
baseModel.copy(ModelType.FineTunedModel(modelId, baseModel = baseModel.modelType)) as? T
?: error("${baseModel::class} does not follow contract to return the most specific type")
}

/**
* Spawns a model based off a [fineTuningJobId]. It should have the same capabilities as
* [baseModel].
*
* This function is safer than [spawnModel] because it checks if the base model the fine-tuned
* model was derived from matches [baseModel].
*/
suspend fun <T : LLM> spawnFineTunedModel(fineTuningJobId: String, baseModel: T) = either {
val job = defaultClient.fineTuningJob(FineTuningId(fineTuningJobId))
ensureNotNull(job) { "job $fineTuningJobId not found" }
val fineTunedModel = job.fineTunedModel
ensureNotNull(fineTunedModel) { "fine tuned model not available, status ${job.status}" }
ensure(baseModel.modelType.name == job.model.id) {
"base model instance does not match the job's base model"
}
spawnModel(fineTunedModel.id, baseModel).bind()
}

private suspend fun findModel(modelId: String): Any? { // TODO: impl of abstract provider function
/** Checks if the model exists. */
private suspend fun modelExists(
modelId: String
): Boolean { // TODO: impl of abstract provider function
val model =
try {
defaultClient.model(ModelId(modelId))
} catch (e: InvalidRequestException) {
when (e.error.detail?.code) {
"model_not_found" -> return null
"model_not_found" -> return false
else -> throw e
}
}
return ModelType.TODO(model.id.id)
return true
}

companion object {
Expand Down

0 comments on commit cdbc923

Please sign in to comment.