Skip to content

Commit

Permalink
ais: simplify json handling
Browse files Browse the repository at this point in the history
  • Loading branch information
fwbrasil committed Dec 17, 2023
1 parent c6cde9e commit 4afcf45
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 84 deletions.
16 changes: 8 additions & 8 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,13 @@ lazy val `kyo-llm-macros` =
.dependsOn(`kyo-core` % "test->test;compile->compile")
.settings(
`kyo-settings`,
scalaVersion := scala3Version,
crossScalaVersions := List(scala2Version, scala3Version),
libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.16",
scalaVersion := scala3Version,
crossScalaVersions := List(scala2Version, scala3Version),
libraryDependencies += "com.softwaremill.sttp.client3" %% "zio-json" % "3.9.1",
libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.16",
libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.16",
libraryDependencies += "dev.zio" %% "zio-schema-json" % "0.4.16",
libraryDependencies += "dev.zio" %% "zio-schema-protobuf" % "0.4.16",
libraryDependencies += "dev.zio" %% "zio-schema-derivation" % "0.4.16",
libraryDependencies ++= (CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, _)) => Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value)
Expand All @@ -257,11 +261,7 @@ lazy val `kyo-llm` =
.settings(
`kyo-settings`,
`with-cross-scala`,
libraryDependencies += "com.softwaremill.sttp.client3" %% "zio-json" % "3.9.1",
libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.16",
libraryDependencies += "dev.zio" %% "zio-schema-json" % "0.4.16",
libraryDependencies += "dev.zio" %% "zio-schema-protobuf" % "0.4.16",
libraryDependencies += "com.knuddels" % "jtokkit" % "0.6.1"
libraryDependencies += "com.knuddels" % "jtokkit" % "0.6.1"
)
.jsSettings(`js-settings`)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
package kyo.llm
package kyo.llm.json

import zio.schema._
import scala.language.experimental.macros
import scala.reflect.macros.blackbox

case class ValueSchema[T](get: Schema[Value[T]])

object ValueSchema {
trait JsonDerive {
implicit def deriveJson[T]: Json[T] = macro JsonDerive.genMacro[T]
}

implicit def gen[T]: ValueSchema[T] = macro genMacro[T]
object JsonDerive {

def genMacro[T](c: blackbox.Context)(implicit t: c.WeakTypeTag[T]): c.Tree = {
import c.universe._
q"""
import kyo.llm.Value
kyo.llm.ValueSchema[$t](zio.schema.DeriveSchema.gen)
kyo.llm.json.Json.fromZio[$t](zio.schema.DeriveSchema.gen)
"""
}
}
11 changes: 0 additions & 11 deletions kyo-llm-macros/shared/src/main/scala-3/kyo/llm/ValueSchema.scala

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package kyo.llm.json

import kyo._
import kyo.ios._
import zio.schema.codec.JsonCodec
import scala.compiletime.constValue
import zio.schema._
import zio.Chunk

trait JsonDerive {
inline implicit def deriveJson[T]: Json[T] =
Json.fromZio(DeriveSchema.gen)
}
17 changes: 0 additions & 17 deletions kyo-llm-macros/shared/src/main/scala/kyo/llm/Value.scala

This file was deleted.

46 changes: 46 additions & 0 deletions kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package kyo.llm.json

import kyo._
import kyo.ios._
import zio.schema.codec.JsonCodec
import zio.schema._
import zio.Chunk

trait Json[T] {
def schema: JsonSchema
def encode(v: T): String > IOs
def decode(s: String): T > IOs
}

object Json extends JsonDerive {

def schema[T](implicit j: Json[T]): JsonSchema =
j.schema

def encode[T](v: T)(implicit j: Json[T]): String > IOs =
j.encode(v)

def decode[T](s: String)(implicit j: Json[T]): T > IOs =
j.decode(s)

implicit def primitive[T](implicit t: StandardType[T]): Json[T] =
fromZio(Schema.Primitive(t, Chunk.empty))

def fromZio[T](z: Schema[T]) =
new Json[T] {
lazy val schema: JsonSchema = JsonSchema(z)
private lazy val decoder = JsonCodec.jsonDecoder(z)
private lazy val encoder = JsonCodec.jsonEncoder(z)

def encode(v: T): String > IOs =
IOs(encoder.encodeJson(v).toString)

def decode(s: String): T > IOs =
IOs {
decoder.decodeJson(s) match {
case Left(fail) => IOs.fail(fail)
case Right(v) => v
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package kyo.llm.util
package kyo.llm.json

import kyo.llm.ais._
import zio.schema._
import zio.json._
import zio.json.ast._
Expand All @@ -24,7 +23,7 @@ object JsonSchema {
def desc(c: Chunk[Any]): List[(String, Json)] =
c.collect {
case desc(v) =>
"description" -> Json.Str(p"$v")
"description" -> Json.Str(v)
}.distinct.toList

def convert(schema: Schema[_]): List[(String, Json)] = {
Expand Down
5 changes: 5 additions & 0 deletions kyo-llm-macros/shared/src/main/scala/kyo/llm/json/desc.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package kyo.llm.json

import scala.annotation.StaticAnnotation

final case class desc(value: String) extends StaticAnnotation
28 changes: 8 additions & 20 deletions kyo-llm/shared/src/main/scala/kyo/llm/agents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import kyo.llm.contexts._
import kyo.concurrent.atomics._
import zio.schema.Schema
import zio.schema.codec.JsonCodec
import kyo.llm.util.JsonSchema
import scala.annotation.implicitNotFound

package object agents {
Expand All @@ -25,13 +24,9 @@ package object agents {
name: String,
description: String
)(implicit
val input: ValueSchema[Input],
val output: ValueSchema[Output]
) {
val schema = JsonSchema(input.get)
val decoder = JsonCodec.jsonDecoder(input.get)
val encoder = JsonCodec.jsonEncoder(output.get)
}
val input: Json[Input],
val output: Json[Output]
)

val info: Info

Expand All @@ -51,16 +46,9 @@ package object agents {
}

private[kyo] def handle(ai: AI, v: String): String > AIs =
info.decoder.decodeJson(v) match {
case Left(error) =>
AIs.fail(
"Invalid json input. **Correct any mistakes before retrying**. " + error
)
case Right(value) =>
run(ai, value.value).map { v =>
info.encoder.encodeJson(Value(v)).toString()
}
}
info.input.decode(v)
.map(run(ai, _))
.map(info.output.encode)
}

object Agents {
Expand All @@ -79,8 +67,8 @@ package object agents {
def disable[T, S](f: T > S): T > (AIs with S) =
local.let(Set.empty)(f)

private[kyo] def resultAgent[T](implicit
t: ValueSchema[T]
private[kyo] def resultAgent[T](
implicit t: Json[T]
): (Agent, Option[T] > AIs) > AIs =
Atomics.initRef(Option.empty[T]).map { ref =>
val agent =
Expand Down
29 changes: 16 additions & 13 deletions kyo-llm/shared/src/main/scala/kyo/llm/ais.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ object ais {

type AIs >: AIs.Effects <: AIs.Effects

type desc = kyo.llm.desc
val desc = kyo.llm.desc
type desc = kyo.llm.json.desc
val desc = kyo.llm.json.desc

type Json[T] = kyo.llm.json.Json[T]
val Json = kyo.llm.json.Json

implicit class PromptInterpolator(val sc: StringContext) extends AnyVal {
def p(args: Any*): String =
Expand All @@ -44,7 +47,7 @@ object ais {
def save: Context > AIs =
State.get.map(_.getOrElse(ref, Context.empty))

def dump: Unit > (AIs with Consoles) =
def dump: Unit > AIs =
save.map(_.dump).map(Consoles.println(_))

def restore(ctx: Context): Unit > AIs =
Expand Down Expand Up @@ -99,10 +102,10 @@ object ais {
Agents.get.map(eval)
}

def gen[T](msg: String)(implicit t: ValueSchema[T]): T > AIs =
def gen[T](msg: String)(implicit t: Json[T]): T > AIs =
userMessage(msg).andThen(gen[T])

def gen[T](implicit t: ValueSchema[T]): T > AIs = {
def gen[T](implicit t: Json[T]): T > AIs = {
Agents.resultAgent[T].map { case (resultAgent, result) =>
def eval(): T > AIs =
fetch(Set(resultAgent), Some(resultAgent)).map { r =>
Expand All @@ -119,10 +122,10 @@ object ais {
}
}

def infer[T](msg: String)(implicit t: ValueSchema[T]): T > AIs =
def infer[T](msg: String)(implicit t: Json[T]): T > AIs =
userMessage(msg).andThen(infer[T])

def infer[T](implicit t: ValueSchema[T]): T > AIs = {
def infer[T](implicit t: Json[T]): T > AIs = {
Agents.resultAgent[T].map { case (resultAgent, result) =>
def eval(agents: Set[Agent], constrain: Option[Agent] = None): T > AIs =
fetch(agents, constrain).map { r =>
Expand Down Expand Up @@ -184,28 +187,28 @@ object ais {
def ask(msg: String): String > AIs =
init.map(_.ask(msg))

def gen[T](msg: String)(implicit t: ValueSchema[T]): T > AIs =
def gen[T](msg: String)(implicit t: Json[T]): T > AIs =
init.map(_.gen[T](msg))

def infer[T](msg: String)(implicit t: ValueSchema[T]): T > AIs =
def infer[T](msg: String)(implicit t: Json[T]): T > AIs =
init.map(_.infer[T](msg))

def ask(seed: String, msg: String): String > AIs =
init(seed).map(_.ask(msg))

def gen[T](seed: String, msg: String)(implicit t: ValueSchema[T]): T > AIs =
def gen[T](seed: String, msg: String)(implicit t: Json[T]): T > AIs =
init(seed).map(_.gen[T](msg))

def infer[T](seed: String, msg: String)(implicit t: ValueSchema[T]): T > AIs =
def infer[T](seed: String, msg: String)(implicit t: Json[T]): T > AIs =
init(seed).map(_.infer[T](msg))

def ask(seed: String, reminder: String, msg: String): String > AIs =
init(seed, reminder).map(_.ask(msg))

def gen[T](seed: String, reminder: String, msg: String)(implicit t: ValueSchema[T]): T > AIs =
def gen[T](seed: String, reminder: String, msg: String)(implicit t: Json[T]): T > AIs =
init(seed, reminder).map(_.gen[T](msg))

def infer[T](seed: String, reminder: String, msg: String)(implicit t: ValueSchema[T]): T > AIs =
def infer[T](seed: String, reminder: String, msg: String)(implicit t: Json[T]): T > AIs =
init(seed, reminder).map(_.infer[T](msg))

def restore(ctx: Context): AI > AIs =
Expand Down
4 changes: 2 additions & 2 deletions kyo-llm/shared/src/main/scala/kyo/llm/completions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import kyo.llm.configs._
import kyo.llm.contexts._
import kyo.llm.agents._
import kyo.llm.ais._
import kyo.llm.util.JsonSchema
import kyo.llm.json._
import kyo.ios._
import kyo.requests._
import kyo.tries._
Expand Down Expand Up @@ -173,7 +173,7 @@ object completions {
ToolDef(FunctionDef(
p.info.description,
p.info.name,
p.info.schema
p.info.input.schema
))
).toList)
Request(
Expand Down
4 changes: 1 addition & 3 deletions kyo-llm/shared/src/main/scala/kyo/llm/index/tokens.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@ object tokens {

private val encoding = Encodings.newLazyEncodingRegistry.getEncoding(EncodingType.CL100K_BASE)


private case class Concat(a: Tokens, b: Tokens)


type Tokens // = Unit | Array[Int] | Concat

implicit class TokensOps(a: Tokens) {

def append(s: String): Tokens =
if (!s.isEmpty()) {
append(encoding.encode(s).toArray[Int].asInstanceOf[Tokens])
append(encoding.encode(s).toArray.asInstanceOf[Tokens])
} else {
a
}
Expand Down

0 comments on commit 4afcf45

Please sign in to comment.