From 6734ca93920862f7ad31cd51566145b24f026c78 Mon Sep 17 00:00:00 2001 From: Flavio Brasil Date: Sun, 17 Dec 2023 02:28:01 -0800 Subject: [PATCH] ais: const field schema support (scala 3 only) --- .../scala-3/kyo/llm/json/JsonDerive.scala | 32 ++++++++- .../src/main/scala/kyo/llm/json/Json.scala | 16 ++--- .../main/scala/kyo/llm/json/JsonSchema.scala | 70 ++++++++++++------- .../src/main/scala/kyo/llm/completions.scala | 2 +- .../src/main/scala/kyo/llm/thoughts.scala | 38 ++++++++-- 5 files changed, 116 insertions(+), 42 deletions(-) diff --git a/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/json/JsonDerive.scala b/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/json/JsonDerive.scala index 5d03bbf60..2e45d0121 100644 --- a/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/json/JsonDerive.scala +++ b/kyo-llm-macros/shared/src/main/scala-3/kyo/llm/json/JsonDerive.scala @@ -3,11 +3,37 @@ package kyo.llm.json import kyo._ import kyo.ios._ import zio.schema.codec.JsonCodec -import scala.compiletime.constValue -import zio.schema._ +import scala.compiletime._ +import zio.schema.{Schema => ZSchema, _} import zio.Chunk trait JsonDerive { - inline implicit def deriveJson[T]: Json[T] = + + inline implicit def deriveJson[T]: Json[T] = { + import JsonDerive._ Json.fromZio(DeriveSchema.gen) + } +} + +object JsonDerive { + inline implicit def constStringZSchema[T <: String]: ZSchema[T] = + const(StandardType.StringType, compiletime.constValue[T]) + + inline implicit def constIntZSchema[T <: Int]: ZSchema[T] = + const(StandardType.IntType, compiletime.constValue[T]) + + inline implicit def constLongZSchema[T <: Long]: ZSchema[T] = + const(StandardType.LongType, compiletime.constValue[T]) + + inline implicit def constDoubleZSchema[T <: Double]: ZSchema[T] = + const(StandardType.DoubleType, compiletime.constValue[T]) + + inline implicit def constFloatZSchema[T <: Float]: ZSchema[T] = + const(StandardType.FloatType, compiletime.constValue[T]) + + inline implicit def constBoolZSchema[T <: Boolean]: ZSchema[T] = + const(StandardType.BoolType, compiletime.constValue[T]) + + private def const[T](t: StandardType[_], v: Any): ZSchema[T] = + ZSchema.Primitive(t, Chunk(Schema.Const(v))).asInstanceOf[ZSchema[T]] } diff --git a/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala index 58a59bebe..d2a68c0e0 100644 --- a/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala +++ b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala @@ -3,18 +3,18 @@ package kyo.llm.json import kyo._ import kyo.ios._ import zio.schema.codec.JsonCodec -import zio.schema._ +import zio.schema.{Schema => ZSchema, _} import zio.Chunk trait Json[T] { - def schema: JsonSchema + def schema: Schema def encode(v: T): String > IOs def decode(s: String): T > IOs } object Json extends JsonDerive { - def schema[T](implicit j: Json[T]): JsonSchema = + def schema[T](implicit j: Json[T]): Schema = j.schema def encode[T](v: T)(implicit j: Json[T]): String > IOs = @@ -24,13 +24,13 @@ object Json extends JsonDerive { j.decode(s) implicit def primitive[T](implicit t: StandardType[T]): Json[T] = - fromZio(Schema.Primitive(t, Chunk.empty)) + fromZio(ZSchema.Primitive(t, Chunk.empty)) - def fromZio[T](z: Schema[T]) = + def fromZio[T](z: ZSchema[T]) = new Json[T] { - lazy val schema: JsonSchema = JsonSchema(z) - private lazy val decoder = JsonCodec.jsonDecoder(z) - private lazy val encoder = JsonCodec.jsonEncoder(z) + lazy val schema: Schema = Schema(z) + private lazy val decoder = JsonCodec.jsonDecoder(z) + private lazy val encoder = JsonCodec.jsonEncoder(z) def encode(v: T): String > IOs = IOs(encoder.encodeJson(v).toString) diff --git a/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/JsonSchema.scala b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/JsonSchema.scala index 71667cf12..cc0eb7bf7 100644 --- a/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/JsonSchema.scala +++ b/kyo-llm-macros/shared/src/main/scala/kyo/llm/json/JsonSchema.scala @@ -1,24 +1,26 @@ package kyo.llm.json -import zio.schema._ +import zio.schema.{Schema => ZSchema, _} import zio.json._ import zio.json.ast._ import zio.json.internal.Write import scala.annotation.StaticAnnotation import zio.Chunk -case class JsonSchema(data: List[(String, Json)]) +case class Schema(data: List[(String, Json)]) -object JsonSchema { +object Schema { - implicit val jsonSchemaEncoder: JsonEncoder[JsonSchema] = new JsonEncoder[JsonSchema] { - override def unsafeEncode(js: JsonSchema, indent: Option[Int], out: Write): Unit = { + case class Const[T](v: T) + + implicit val jsonSchemaEncoder: JsonEncoder[Schema] = new JsonEncoder[Schema] { + override def unsafeEncode(js: Schema, indent: Option[Int], out: Write): Unit = { implicitly[JsonEncoder[Json.Obj]].unsafeEncode(Json.Obj(js.data.toSeq: _*), indent, out) } } - def apply(schema: Schema[_]): JsonSchema = - new JsonSchema(convert(schema)) + def apply(schema: ZSchema[_]): Schema = + new Schema(convert(schema)) def desc(c: Chunk[Any]): List[(String, Json)] = c.collect { @@ -26,42 +28,62 @@ object JsonSchema { "description" -> Json.Str(v) }.distinct.toList - def convert(schema: Schema[_]): List[(String, Json)] = { + def convert(schema: ZSchema[_]): List[(String, Json)] = { def desc = this.desc(schema.annotations) schema match { - case Schema.Primitive(StandardType.StringType, _) => + + case ZSchema.Primitive(StandardType.StringType, Chunk(Const(v))) => + desc ++ List("const" -> Json.Str(v.asInstanceOf[String])) + + case ZSchema.Primitive(StandardType.StringType, _) => desc ++ List("type" -> Json.Str("string")) - case Schema.Primitive(StandardType.IntType, _) => + case ZSchema.Primitive(StandardType.IntType, Chunk(Const(v))) => + desc ++ List("const" -> Json.Num(v.asInstanceOf[Int])) + + case ZSchema.Primitive(StandardType.IntType, _) => desc ++ List("type" -> Json.Str("integer"), "format" -> Json.Str("int32")) - case Schema.Primitive(StandardType.LongType, _) => + case ZSchema.Primitive(StandardType.LongType, Chunk(Const(v))) => + desc ++ List("const" -> Json.Num(v.asInstanceOf[Long])) + + case ZSchema.Primitive(StandardType.LongType, _) => desc ++ List("type" -> Json.Str("integer"), "format" -> Json.Str("int64")) - case Schema.Primitive(StandardType.DoubleType, _) => + case ZSchema.Primitive(StandardType.DoubleType, Chunk(Const(v))) => + desc ++ List("const" -> Json.Num(v.asInstanceOf[Double])) + + case ZSchema.Primitive(StandardType.DoubleType, _) => desc ++ List("type" -> Json.Str("number")) - case Schema.Primitive(StandardType.FloatType, _) => + case ZSchema.Primitive(StandardType.FloatType, Chunk(Const(v))) => + desc ++ List("const" -> Json.Num(v.asInstanceOf[Float])) + + case ZSchema.Primitive(StandardType.FloatType, _) => desc ++ List("type" -> Json.Str("number"), "format" -> Json.Str("float")) - case Schema.Primitive(StandardType.BoolType, _) => + case ZSchema.Primitive(StandardType.BoolType, Chunk(Const(v))) => + desc ++ List("const" -> Json.Bool(v.asInstanceOf[Boolean])) + + case ZSchema.Primitive(StandardType.BoolType, _) => desc ++ List("type" -> Json.Str("boolean")) - case Schema.Optional(innerSchema, _) => + case ZSchema.Optional(innerSchema, _) => convert(innerSchema) - case Schema.Sequence(innerSchema, _, _, _, _) => + case ZSchema.Sequence(innerSchema, _, _, _, _) => List("type" -> Json.Str("array"), "items" -> Json.Obj(convert(innerSchema): _*)) - case schema: Schema.Enum[_] => + case schema: ZSchema.Enum[_] => val cases = schema.cases.map { c => val caseProperties = c.schema match { - case record: Schema.Record[_] => + case record: ZSchema.Record[_] => val fields = record.fields.map { field => field.name -> Json.Obj(convert(field.schema): _*) } val requiredFields = record.fields.collect { - case field if !field.schema.isInstanceOf[Schema.Optional[_]] => Json.Str(field.name) + case field if !field.schema.isInstanceOf[ZSchema.Optional[_]] => + Json.Str(field.name) } Json.Obj( "type" -> Json.Str("object"), @@ -78,14 +100,14 @@ object JsonSchema { "properties" -> Json.Obj(cases: _*) ) - case schema: Schema.Record[_] => + case schema: ZSchema.Record[_] => val properties = schema.fields.foldLeft(List.empty[(String, Json)]) { (acc, field) => acc :+ (field.name -> Json.Obj( (this.desc(field.annotations) ++ convert(field.schema)): _* )) } val requiredFields = schema.fields.collect { - case field if !field.schema.isInstanceOf[Schema.Optional[_]] => Json.Str(field.name) + case field if !field.schema.isInstanceOf[ZSchema.Optional[_]] => Json.Str(field.name) } desc ++ List( "type" -> Json.Str("object"), @@ -93,9 +115,9 @@ object JsonSchema { "required" -> Json.Arr(requiredFields: _*) ) - case Schema.Map(keySchema, valueSchema, _) => + case ZSchema.Map(keySchema, valueSchema, _) => keySchema match { - case Schema.Primitive(tpe, _) if (tpe == StandardType.StringType) => + case ZSchema.Primitive(tpe, _) if (tpe == StandardType.StringType) => List( "type" -> Json.Str("object"), "additionalProperties" -> Json.Obj(convert(valueSchema): _*) @@ -104,7 +126,7 @@ object JsonSchema { throw new UnsupportedOperationException("Non-string map keys are not supported") } - case schema: Schema.Lazy[_] => + case schema: ZSchema.Lazy[_] => convert(schema.schema) case _ => diff --git a/kyo-llm/shared/src/main/scala/kyo/llm/completions.scala b/kyo-llm/shared/src/main/scala/kyo/llm/completions.scala index 4033ad7bb..b9b8762e3 100644 --- a/kyo-llm/shared/src/main/scala/kyo/llm/completions.scala +++ b/kyo-llm/shared/src/main/scala/kyo/llm/completions.scala @@ -80,7 +80,7 @@ object completions { case class FunctionCall(arguments: String, name: String) case class ToolCall(id: String, function: FunctionCall, `type`: String = "function") - case class FunctionDef(description: String, name: String, parameters: JsonSchema) + case class FunctionDef(description: String, name: String, parameters: Schema) case class ToolDef(function: FunctionDef, `type`: String = "function") sealed trait Entry diff --git a/kyo-llm/shared/src/main/scala/kyo/llm/thoughts.scala b/kyo-llm/shared/src/main/scala/kyo/llm/thoughts.scala index e61dbcf13..e6acd138b 100644 --- a/kyo-llm/shared/src/main/scala/kyo/llm/thoughts.scala +++ b/kyo-llm/shared/src/main/scala/kyo/llm/thoughts.scala @@ -56,12 +56,7 @@ object thoughts { ) } - case class NonEmptyString( - `I understand I can not generate an empty string in the next field`: Boolean, - nonEmptyString: String - ) { - def string: String = nonEmptyString - } + type NonEmpty = "Non-empty" case class Constrain[T, C <: String]( @desc("Constraints to consider when generating the value") @@ -69,3 +64,34 @@ object thoughts { value: T ) } + +// object tt extends KyoLLMApp { + +// import thoughts._ +// import kyo.llm.ais._ + +// case class CQA( +// @desc("Excerpt from the input text") +// excerpt: Constrain[String, NonEmpty], +// @desc("An elaborate question regarding the excerpt") +// question: Constrain[String, NonEmpty], +// @desc("A comprehensive answer") +// answer: Constrain[String, NonEmpty] +// ) +// case class Req( +// reasoning: Reasoning, +// @desc("Comprehensive set of questions covering all information in the input text") +// questions: Collect[Constrain[String, NonEmpty]], +// @desc("Process each question") +// processed: Collect[CQA] +// ) + +// run { +// AIs.gen[Req](text) +// } + +// def text = +// p""" +// General relativity is a theory of gravitation developed by Einstein in the years 1907–1915. The development of general relativity began with the equivalence principle, under which the states of accelerated motion and being at rest in a gravitational field (for example, when standing on the surface of the Earth) are physically identical. The upshot of this is that free fall is inertial motion: an object in free fall is falling because that is how objects move when there is no force being exerted on them, instead of this being due to the force of gravity as is the case in classical mechanics. This is incompatible with classical mechanics and special relativity because in those theories inertially moving objects cannot accelerate with respect to each other, but objects in free fall do so. To resolve this difficulty Einstein first proposed that spacetime is curved. Einstein discussed his idea with mathematician Marcel Grossmann and they concluded that general relativity could be formulated in the context of Riemannian geometry which had been developed in the 1800s.[10] In 1915, he devised the Einstein field equations which relate the curvature of spacetime with the mass, energy, and any momentum within it. +// """ +// }