From 4afcf45ab274922a0f3ba243cab0824d16e6a8e7 Mon Sep 17 00:00:00 2001 From: Flavio Brasil Date: Sun, 17 Dec 2023 01:23:13 -0800 Subject: [PATCH] ais: simplify json handling --- build.sbt | 16 +++---- .../JsonDerive.scala} | 13 +++--- .../main/scala-3/kyo/llm/ValueSchema.scala | 11 ----- .../scala-3/kyo/llm/json/JsonDerive.scala | 13 ++++++ .../shared/src/main/scala/kyo/llm/Value.scala | 17 ------- .../src/main/scala/kyo/llm/json/Json.scala | 46 +++++++++++++++++++ .../main/scala/kyo/llm/json}/JsonSchema.scala | 5 +- .../src/main/scala/kyo/llm/json/desc.scala | 5 ++ .../src/main/scala/kyo/llm/agents.scala | 28 ++++------- .../shared/src/main/scala/kyo/llm/ais.scala | 29 ++++++------ .../src/main/scala/kyo/llm/completions.scala | 4 +- .../src/main/scala/kyo/llm/index/tokens.scala | 4 +- 12 files changed, 107 insertions(+), 84 deletions(-) rename kyo-llm-macros/shared/src/main/scala-2/kyo/llm/{ValueSchema.scala => json/JsonDerive.scala} (50%) delete mode 100644 kyo-llm-macros/shared/src/main/scala-3/kyo/llm/ValueSchema.scala create mode 100644 kyo-llm-macros/shared/src/main/scala-3/kyo/llm/json/JsonDerive.scala delete mode 100644 kyo-llm-macros/shared/src/main/scala/kyo/llm/Value.scala create mode 100644 kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala rename {kyo-llm/shared/src/main/scala/kyo/llm/util => kyo-llm-macros/shared/src/main/scala/kyo/llm/json}/JsonSchema.scala (98%) create mode 100644 kyo-llm-macros/shared/src/main/scala/kyo/llm/json/desc.scala diff --git a/build.sbt b/build.sbt index 66ed253f5..a906dfa3b 100644 --- a/build.sbt +++ b/build.sbt @@ -234,9 +234,13 @@ lazy val `kyo-llm-macros` = .dependsOn(`kyo-core` % "test->test;compile->compile") .settings( `kyo-settings`, - scalaVersion := scala3Version, - crossScalaVersions := List(scala2Version, scala3Version), - libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.16", + scalaVersion := scala3Version, + crossScalaVersions := List(scala2Version, scala3Version), + libraryDependencies += "com.softwaremill.sttp.client3" %% "zio-json" % "3.9.1", + libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.16", + libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.16", + libraryDependencies += "dev.zio" %% "zio-schema-json" % "0.4.16", + libraryDependencies += "dev.zio" %% "zio-schema-protobuf" % "0.4.16", libraryDependencies += "dev.zio" %% "zio-schema-derivation" % "0.4.16", libraryDependencies ++= (CrossVersion.partialVersion(scalaVersion.value) match { case Some((2, _)) => Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value) @@ -257,11 +261,7 @@ lazy val `kyo-llm` = .settings( `kyo-settings`, `with-cross-scala`, - libraryDependencies += "com.softwaremill.sttp.client3" %% "zio-json" % "3.9.1", - libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.16", - libraryDependencies += "dev.zio" %% "zio-schema-json" % "0.4.16", - libraryDependencies += "dev.zio" %% "zio-schema-protobuf" % "0.4.16", - libraryDependencies += "com.knuddels" % "jtokkit" % "0.6.1" + libraryDependencies += "com.knuddels" % "jtokkit" % "0.6.1" ) .jsSettings(`js-settings`) diff --git a/kyo-llm-macros/shared/src/main/scala-2/kyo/llm/ValueSchema.scala b/kyo-llm-macros/shared/src/main/scala-2/kyo/llm/json/JsonDerive.scala similarity index 50% rename from kyo-llm-macros/shared/src/main/scala-2/kyo/llm/ValueSchema.scala rename to kyo-llm-macros/shared/src/main/scala-2/kyo/llm/json/JsonDerive.scala index afe88635c..6fff0dc81 100644 --- a/kyo-llm-macros/shared/src/main/scala-2/kyo/llm/ValueSchema.scala +++ b/kyo-llm-macros/shared/src/main/scala-2/kyo/llm/json/JsonDerive.scala @@ -1,20 +1,19 @@ -package kyo.llm +package kyo.llm.json import zio.schema._ import scala.language.experimental.macros import scala.reflect.macros.blackbox -case class ValueSchema[T](get: Schema[Value[T]]) - -object ValueSchema { +trait JsonDerive { + implicit def deriveJson[T]: Json[T] = macro JsonDerive.genMacro[T] +} - implicit def gen[T]: ValueSchema[T] = macro genMacro[T] +object JsonDerive { def genMacro[T](c: blackbox.Context)(implicit t: c.WeakTypeTag[T]): c.Tree = { import c.universe._ q""" - import kyo.llm.Value - kyo.llm.ValueSchema[$t](zio.schema.DeriveSchema.gen) + kyo.llm.json.Json.fromZio[$t](zio.schema.DeriveSchema.gen) """ } } diff --git a/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/ValueSchema.scala b/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/ValueSchema.scala deleted file mode 100644 index 8801b3206..000000000 --- a/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/ValueSchema.scala +++ /dev/null @@ -1,11 +0,0 @@ -package kyo.llm - -import zio.schema._ - -case class ValueSchema[T](get: Schema[Value[T]]) - -object ValueSchema { - - inline implicit def gen[T]: ValueSchema[T] = - ValueSchema[T](DeriveSchema.gen) -} diff --git a/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/json/JsonDerive.scala b/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/json/JsonDerive.scala new file mode 100644 index 000000000..5d03bbf60 --- /dev/null +++ b/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/json/JsonDerive.scala @@ -0,0 +1,13 @@ +package kyo.llm.json + +import kyo._ +import kyo.ios._ +import zio.schema.codec.JsonCodec +import scala.compiletime.constValue +import zio.schema._ +import zio.Chunk + +trait JsonDerive { + inline implicit def deriveJson[T]: Json[T] = + Json.fromZio(DeriveSchema.gen) +} diff --git a/kyo-llm-macros/shared/src/main/scala/kyo/llm/Value.scala b/kyo-llm-macros/shared/src/main/scala/kyo/llm/Value.scala deleted file mode 100644 index 3cd55db73..000000000 --- a/kyo-llm-macros/shared/src/main/scala/kyo/llm/Value.scala +++ /dev/null @@ -1,17 +0,0 @@ -package kyo.llm - -import scala.annotation.StaticAnnotation -import zio.schema._ - -final case class desc(value: String) extends StaticAnnotation - -case class Value[T]( - @desc("Please **generate compact json**.") - willIGenerateCompactJson: Boolean, - @desc("Result is wrapped into a `value` field.") - value: T -) - -object Value { - def apply[T](v: T): Value[T] = new Value(true, v) -} diff --git a/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala new file mode 100644 index 000000000..58a59bebe --- /dev/null +++ b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala @@ -0,0 +1,46 @@ +package kyo.llm.json + +import kyo._ +import kyo.ios._ +import zio.schema.codec.JsonCodec +import zio.schema._ +import zio.Chunk + +trait Json[T] { + def schema: JsonSchema + def encode(v: T): String > IOs + def decode(s: String): T > IOs +} + +object Json extends JsonDerive { + + def schema[T](implicit j: Json[T]): JsonSchema = + j.schema + + def encode[T](v: T)(implicit j: Json[T]): String > IOs = + j.encode(v) + + def decode[T](s: String)(implicit j: Json[T]): T > IOs = + j.decode(s) + + implicit def primitive[T](implicit t: StandardType[T]): Json[T] = + fromZio(Schema.Primitive(t, Chunk.empty)) + + def fromZio[T](z: Schema[T]) = + new Json[T] { + lazy val schema: JsonSchema = JsonSchema(z) + private lazy val decoder = JsonCodec.jsonDecoder(z) + private lazy val encoder = JsonCodec.jsonEncoder(z) + + def encode(v: T): String > IOs = + IOs(encoder.encodeJson(v).toString) + + def decode(s: String): T > IOs = + IOs { + decoder.decodeJson(s) match { + case Left(fail) => IOs.fail(fail) + case Right(v) => v + } + } + } +} diff --git a/kyo-llm/shared/src/main/scala/kyo/llm/util/JsonSchema.scala b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/JsonSchema.scala similarity index 98% rename from kyo-llm/shared/src/main/scala/kyo/llm/util/JsonSchema.scala rename to kyo-llm-macros/shared/src/main/scala/kyo/llm/json/JsonSchema.scala index 931b9236f..71667cf12 100644 --- a/kyo-llm/shared/src/main/scala/kyo/llm/util/JsonSchema.scala +++ b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/JsonSchema.scala @@ -1,6 +1,5 @@ -package kyo.llm.util +package kyo.llm.json -import kyo.llm.ais._ import zio.schema._ import zio.json._ import zio.json.ast._ @@ -24,7 +23,7 @@ object JsonSchema { def desc(c: Chunk[Any]): List[(String, Json)] = c.collect { case desc(v) => - "description" -> Json.Str(p"$v") + "description" -> Json.Str(v) }.distinct.toList def convert(schema: Schema[_]): List[(String, Json)] = { diff --git a/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/desc.scala b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/desc.scala new file mode 100644 index 000000000..c4964eeba --- /dev/null +++ b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/desc.scala @@ -0,0 +1,5 @@ +package kyo.llm.json + +import scala.annotation.StaticAnnotation + +final case class desc(value: String) extends StaticAnnotation diff --git a/kyo-llm/shared/src/main/scala/kyo/llm/agents.scala b/kyo-llm/shared/src/main/scala/kyo/llm/agents.scala index 27e49174e..6b821dc84 100644 --- a/kyo-llm/shared/src/main/scala/kyo/llm/agents.scala +++ b/kyo-llm/shared/src/main/scala/kyo/llm/agents.scala @@ -11,7 +11,6 @@ import kyo.llm.contexts._ import kyo.concurrent.atomics._ import zio.schema.Schema import zio.schema.codec.JsonCodec -import kyo.llm.util.JsonSchema import scala.annotation.implicitNotFound package object agents { @@ -25,13 +24,9 @@ package object agents { name: String, description: String )(implicit - val input: ValueSchema[Input], - val output: ValueSchema[Output] - ) { - val schema = JsonSchema(input.get) - val decoder = JsonCodec.jsonDecoder(input.get) - val encoder = JsonCodec.jsonEncoder(output.get) - } + val input: Json[Input], + val output: Json[Output] + ) val info: Info @@ -51,16 +46,9 @@ package object agents { } private[kyo] def handle(ai: AI, v: String): String > AIs = - info.decoder.decodeJson(v) match { - case Left(error) => - AIs.fail( - "Invalid json input. **Correct any mistakes before retrying**. " + error - ) - case Right(value) => - run(ai, value.value).map { v => - info.encoder.encodeJson(Value(v)).toString() - } - } + info.input.decode(v) + .map(run(ai, _)) + .map(info.output.encode) } object Agents { @@ -79,8 +67,8 @@ package object agents { def disable[T, S](f: T > S): T > (AIs with S) = local.let(Set.empty)(f) - private[kyo] def resultAgent[T](implicit - t: ValueSchema[T] + private[kyo] def resultAgent[T]( + implicit t: Json[T] ): (Agent, Option[T] > AIs) > AIs = Atomics.initRef(Option.empty[T]).map { ref => val agent = diff --git a/kyo-llm/shared/src/main/scala/kyo/llm/ais.scala b/kyo-llm/shared/src/main/scala/kyo/llm/ais.scala index 04af45af8..146baef06 100644 --- a/kyo-llm/shared/src/main/scala/kyo/llm/ais.scala +++ b/kyo-llm/shared/src/main/scala/kyo/llm/ais.scala @@ -27,8 +27,11 @@ object ais { type AIs >: AIs.Effects <: AIs.Effects - type desc = kyo.llm.desc - val desc = kyo.llm.desc + type desc = kyo.llm.json.desc + val desc = kyo.llm.json.desc + + type Json[T] = kyo.llm.json.Json[T] + val Json = kyo.llm.json.Json implicit class PromptInterpolator(val sc: StringContext) extends AnyVal { def p(args: Any*): String = @@ -44,7 +47,7 @@ object ais { def save: Context > AIs = State.get.map(_.getOrElse(ref, Context.empty)) - def dump: Unit > (AIs with Consoles) = + def dump: Unit > AIs = save.map(_.dump).map(Consoles.println(_)) def restore(ctx: Context): Unit > AIs = @@ -99,10 +102,10 @@ object ais { Agents.get.map(eval) } - def gen[T](msg: String)(implicit t: ValueSchema[T]): T > AIs = + def gen[T](msg: String)(implicit t: Json[T]): T > AIs = userMessage(msg).andThen(gen[T]) - def gen[T](implicit t: ValueSchema[T]): T > AIs = { + def gen[T](implicit t: Json[T]): T > AIs = { Agents.resultAgent[T].map { case (resultAgent, result) => def eval(): T > AIs = fetch(Set(resultAgent), Some(resultAgent)).map { r => @@ -119,10 +122,10 @@ object ais { } } - def infer[T](msg: String)(implicit t: ValueSchema[T]): T > AIs = + def infer[T](msg: String)(implicit t: Json[T]): T > AIs = userMessage(msg).andThen(infer[T]) - def infer[T](implicit t: ValueSchema[T]): T > AIs = { + def infer[T](implicit t: Json[T]): T > AIs = { Agents.resultAgent[T].map { case (resultAgent, result) => def eval(agents: Set[Agent], constrain: Option[Agent] = None): T > AIs = fetch(agents, constrain).map { r => @@ -184,28 +187,28 @@ object ais { def ask(msg: String): String > AIs = init.map(_.ask(msg)) - def gen[T](msg: String)(implicit t: ValueSchema[T]): T > AIs = + def gen[T](msg: String)(implicit t: Json[T]): T > AIs = init.map(_.gen[T](msg)) - def infer[T](msg: String)(implicit t: ValueSchema[T]): T > AIs = + def infer[T](msg: String)(implicit t: Json[T]): T > AIs = init.map(_.infer[T](msg)) def ask(seed: String, msg: String): String > AIs = init(seed).map(_.ask(msg)) - def gen[T](seed: String, msg: String)(implicit t: ValueSchema[T]): T > AIs = + def gen[T](seed: String, msg: String)(implicit t: Json[T]): T > AIs = init(seed).map(_.gen[T](msg)) - def infer[T](seed: String, msg: String)(implicit t: ValueSchema[T]): T > AIs = + def infer[T](seed: String, msg: String)(implicit t: Json[T]): T > AIs = init(seed).map(_.infer[T](msg)) def ask(seed: String, reminder: String, msg: String): String > AIs = init(seed, reminder).map(_.ask(msg)) - def gen[T](seed: String, reminder: String, msg: String)(implicit t: ValueSchema[T]): T > AIs = + def gen[T](seed: String, reminder: String, msg: String)(implicit t: Json[T]): T > AIs = init(seed, reminder).map(_.gen[T](msg)) - def infer[T](seed: String, reminder: String, msg: String)(implicit t: ValueSchema[T]): T > AIs = + def infer[T](seed: String, reminder: String, msg: String)(implicit t: Json[T]): T > AIs = init(seed, reminder).map(_.infer[T](msg)) def restore(ctx: Context): AI > AIs = diff --git a/kyo-llm/shared/src/main/scala/kyo/llm/completions.scala b/kyo-llm/shared/src/main/scala/kyo/llm/completions.scala index c992503d2..4033ad7bb 100644 --- a/kyo-llm/shared/src/main/scala/kyo/llm/completions.scala +++ b/kyo-llm/shared/src/main/scala/kyo/llm/completions.scala @@ -5,7 +5,7 @@ import kyo.llm.configs._ import kyo.llm.contexts._ import kyo.llm.agents._ import kyo.llm.ais._ -import kyo.llm.util.JsonSchema +import kyo.llm.json._ import kyo.ios._ import kyo.requests._ import kyo.tries._ @@ -173,7 +173,7 @@ object completions { ToolDef(FunctionDef( p.info.description, p.info.name, - p.info.schema + p.info.input.schema )) ).toList) Request( diff --git a/kyo-llm/shared/src/main/scala/kyo/llm/index/tokens.scala b/kyo-llm/shared/src/main/scala/kyo/llm/index/tokens.scala index 93cf67216..ec26c860a 100644 --- a/kyo-llm/shared/src/main/scala/kyo/llm/index/tokens.scala +++ b/kyo-llm/shared/src/main/scala/kyo/llm/index/tokens.scala @@ -10,17 +10,15 @@ object tokens { private val encoding = Encodings.newLazyEncodingRegistry.getEncoding(EncodingType.CL100K_BASE) - private case class Concat(a: Tokens, b: Tokens) - type Tokens // = Unit | Array[Int] | Concat implicit class TokensOps(a: Tokens) { def append(s: String): Tokens = if (!s.isEmpty()) { - append(encoding.encode(s).toArray[Int].asInstanceOf[Tokens]) + append(encoding.encode(s).toArray.asInstanceOf[Tokens]) } else { a }