Skip to content

Commit

Permalink
ais: const field schema support (scala 3 only)
Browse files Browse the repository at this point in the history
  • Loading branch information
fwbrasil committed Dec 17, 2023
1 parent 550be40 commit 6734ca9
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,37 @@ package kyo.llm.json
import kyo._
import kyo.ios._
import zio.schema.codec.JsonCodec
import scala.compiletime.constValue
import zio.schema._
import scala.compiletime._
import zio.schema.{Schema => ZSchema, _}
import zio.Chunk

trait JsonDerive {
inline implicit def deriveJson[T]: Json[T] =

inline implicit def deriveJson[T]: Json[T] = {
import JsonDerive._
Json.fromZio(DeriveSchema.gen)
}
}

object JsonDerive {
inline implicit def constStringZSchema[T <: String]: ZSchema[T] =
const(StandardType.StringType, compiletime.constValue[T])

inline implicit def constIntZSchema[T <: Int]: ZSchema[T] =
const(StandardType.IntType, compiletime.constValue[T])

inline implicit def constLongZSchema[T <: Long]: ZSchema[T] =
const(StandardType.LongType, compiletime.constValue[T])

inline implicit def constDoubleZSchema[T <: Double]: ZSchema[T] =
const(StandardType.DoubleType, compiletime.constValue[T])

inline implicit def constFloatZSchema[T <: Float]: ZSchema[T] =
const(StandardType.FloatType, compiletime.constValue[T])

inline implicit def constBoolZSchema[T <: Boolean]: ZSchema[T] =
const(StandardType.BoolType, compiletime.constValue[T])

private def const[T](t: StandardType[_], v: Any): ZSchema[T] =
ZSchema.Primitive(t, Chunk(Schema.Const(v))).asInstanceOf[ZSchema[T]]
}
16 changes: 8 additions & 8 deletions kyo-llm-macros/shared/src/main/scala/kyo/llm/json/Json.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ package kyo.llm.json
import kyo._
import kyo.ios._
import zio.schema.codec.JsonCodec
import zio.schema._
import zio.schema.{Schema => ZSchema, _}
import zio.Chunk

trait Json[T] {
def schema: JsonSchema
def schema: Schema
def encode(v: T): String > IOs
def decode(s: String): T > IOs
}

object Json extends JsonDerive {

def schema[T](implicit j: Json[T]): JsonSchema =
def schema[T](implicit j: Json[T]): Schema =
j.schema

def encode[T](v: T)(implicit j: Json[T]): String > IOs =
Expand All @@ -24,13 +24,13 @@ object Json extends JsonDerive {
j.decode(s)

implicit def primitive[T](implicit t: StandardType[T]): Json[T] =
fromZio(Schema.Primitive(t, Chunk.empty))
fromZio(ZSchema.Primitive(t, Chunk.empty))

def fromZio[T](z: Schema[T]) =
def fromZio[T](z: ZSchema[T]) =
new Json[T] {
lazy val schema: JsonSchema = JsonSchema(z)
private lazy val decoder = JsonCodec.jsonDecoder(z)
private lazy val encoder = JsonCodec.jsonEncoder(z)
lazy val schema: Schema = Schema(z)
private lazy val decoder = JsonCodec.jsonDecoder(z)
private lazy val encoder = JsonCodec.jsonEncoder(z)

def encode(v: T): String > IOs =
IOs(encoder.encodeJson(v).toString)
Expand Down
70 changes: 46 additions & 24 deletions kyo-llm-macros/shared/src/main/scala/kyo/llm/json/JsonSchema.scala
Original file line number Diff line number Diff line change
@@ -1,67 +1,89 @@
package kyo.llm.json

import zio.schema._
import zio.schema.{Schema => ZSchema, _}
import zio.json._
import zio.json.ast._
import zio.json.internal.Write
import scala.annotation.StaticAnnotation
import zio.Chunk

case class JsonSchema(data: List[(String, Json)])
case class Schema(data: List[(String, Json)])

object JsonSchema {
object Schema {

implicit val jsonSchemaEncoder: JsonEncoder[JsonSchema] = new JsonEncoder[JsonSchema] {
override def unsafeEncode(js: JsonSchema, indent: Option[Int], out: Write): Unit = {
case class Const[T](v: T)

implicit val jsonSchemaEncoder: JsonEncoder[Schema] = new JsonEncoder[Schema] {
override def unsafeEncode(js: Schema, indent: Option[Int], out: Write): Unit = {
implicitly[JsonEncoder[Json.Obj]].unsafeEncode(Json.Obj(js.data.toSeq: _*), indent, out)
}
}

def apply(schema: Schema[_]): JsonSchema =
new JsonSchema(convert(schema))
def apply(schema: ZSchema[_]): Schema =
new Schema(convert(schema))

def desc(c: Chunk[Any]): List[(String, Json)] =
c.collect {
case desc(v) =>
"description" -> Json.Str(v)
}.distinct.toList

def convert(schema: Schema[_]): List[(String, Json)] = {
def convert(schema: ZSchema[_]): List[(String, Json)] = {
def desc = this.desc(schema.annotations)
schema match {
case Schema.Primitive(StandardType.StringType, _) =>

case ZSchema.Primitive(StandardType.StringType, Chunk(Const(v))) =>
desc ++ List("const" -> Json.Str(v.asInstanceOf[String]))

case ZSchema.Primitive(StandardType.StringType, _) =>
desc ++ List("type" -> Json.Str("string"))

case Schema.Primitive(StandardType.IntType, _) =>
case ZSchema.Primitive(StandardType.IntType, Chunk(Const(v))) =>
desc ++ List("const" -> Json.Num(v.asInstanceOf[Int]))

case ZSchema.Primitive(StandardType.IntType, _) =>
desc ++ List("type" -> Json.Str("integer"), "format" -> Json.Str("int32"))

case Schema.Primitive(StandardType.LongType, _) =>
case ZSchema.Primitive(StandardType.LongType, Chunk(Const(v))) =>
desc ++ List("const" -> Json.Num(v.asInstanceOf[Long]))

case ZSchema.Primitive(StandardType.LongType, _) =>
desc ++ List("type" -> Json.Str("integer"), "format" -> Json.Str("int64"))

case Schema.Primitive(StandardType.DoubleType, _) =>
case ZSchema.Primitive(StandardType.DoubleType, Chunk(Const(v))) =>
desc ++ List("const" -> Json.Num(v.asInstanceOf[Double]))

case ZSchema.Primitive(StandardType.DoubleType, _) =>
desc ++ List("type" -> Json.Str("number"))

case Schema.Primitive(StandardType.FloatType, _) =>
case ZSchema.Primitive(StandardType.FloatType, Chunk(Const(v))) =>
desc ++ List("const" -> Json.Num(v.asInstanceOf[Float]))

case ZSchema.Primitive(StandardType.FloatType, _) =>
desc ++ List("type" -> Json.Str("number"), "format" -> Json.Str("float"))

case Schema.Primitive(StandardType.BoolType, _) =>
case ZSchema.Primitive(StandardType.BoolType, Chunk(Const(v))) =>
desc ++ List("const" -> Json.Bool(v.asInstanceOf[Boolean]))

case ZSchema.Primitive(StandardType.BoolType, _) =>
desc ++ List("type" -> Json.Str("boolean"))

case Schema.Optional(innerSchema, _) =>
case ZSchema.Optional(innerSchema, _) =>
convert(innerSchema)

case Schema.Sequence(innerSchema, _, _, _, _) =>
case ZSchema.Sequence(innerSchema, _, _, _, _) =>
List("type" -> Json.Str("array"), "items" -> Json.Obj(convert(innerSchema): _*))

case schema: Schema.Enum[_] =>
case schema: ZSchema.Enum[_] =>
val cases = schema.cases.map { c =>
val caseProperties = c.schema match {
case record: Schema.Record[_] =>
case record: ZSchema.Record[_] =>
val fields = record.fields.map { field =>
field.name -> Json.Obj(convert(field.schema): _*)
}
val requiredFields = record.fields.collect {
case field if !field.schema.isInstanceOf[Schema.Optional[_]] => Json.Str(field.name)
case field if !field.schema.isInstanceOf[ZSchema.Optional[_]] =>
Json.Str(field.name)
}
Json.Obj(
"type" -> Json.Str("object"),
Expand All @@ -78,24 +100,24 @@ object JsonSchema {
"properties" -> Json.Obj(cases: _*)
)

case schema: Schema.Record[_] =>
case schema: ZSchema.Record[_] =>
val properties = schema.fields.foldLeft(List.empty[(String, Json)]) { (acc, field) =>
acc :+ (field.name -> Json.Obj(
(this.desc(field.annotations) ++ convert(field.schema)): _*
))
}
val requiredFields = schema.fields.collect {
case field if !field.schema.isInstanceOf[Schema.Optional[_]] => Json.Str(field.name)
case field if !field.schema.isInstanceOf[ZSchema.Optional[_]] => Json.Str(field.name)
}
desc ++ List(
"type" -> Json.Str("object"),
"properties" -> Json.Obj(properties.toSeq: _*),
"required" -> Json.Arr(requiredFields: _*)
)

case Schema.Map(keySchema, valueSchema, _) =>
case ZSchema.Map(keySchema, valueSchema, _) =>
keySchema match {
case Schema.Primitive(tpe, _) if (tpe == StandardType.StringType) =>
case ZSchema.Primitive(tpe, _) if (tpe == StandardType.StringType) =>
List(
"type" -> Json.Str("object"),
"additionalProperties" -> Json.Obj(convert(valueSchema): _*)
Expand All @@ -104,7 +126,7 @@ object JsonSchema {
throw new UnsupportedOperationException("Non-string map keys are not supported")
}

case schema: Schema.Lazy[_] =>
case schema: ZSchema.Lazy[_] =>
convert(schema.schema)

case _ =>
Expand Down
2 changes: 1 addition & 1 deletion kyo-llm/shared/src/main/scala/kyo/llm/completions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ object completions {
case class FunctionCall(arguments: String, name: String)
case class ToolCall(id: String, function: FunctionCall, `type`: String = "function")

case class FunctionDef(description: String, name: String, parameters: JsonSchema)
case class FunctionDef(description: String, name: String, parameters: Schema)
case class ToolDef(function: FunctionDef, `type`: String = "function")

sealed trait Entry
Expand Down
38 changes: 32 additions & 6 deletions kyo-llm/shared/src/main/scala/kyo/llm/thoughts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,42 @@ object thoughts {
)
}

case class NonEmptyString(
`I understand I can not generate an empty string in the next field`: Boolean,
nonEmptyString: String
) {
def string: String = nonEmptyString
}
type NonEmpty = "Non-empty"

case class Constrain[T, C <: String](
@desc("Constraints to consider when generating the value")
constraints: C,
value: T
)
}

// object tt extends KyoLLMApp {

// import thoughts._
// import kyo.llm.ais._

// case class CQA(
// @desc("Excerpt from the input text")
// excerpt: Constrain[String, NonEmpty],
// @desc("An elaborate question regarding the excerpt")
// question: Constrain[String, NonEmpty],
// @desc("A comprehensive answer")
// answer: Constrain[String, NonEmpty]
// )
// case class Req(
// reasoning: Reasoning,
// @desc("Comprehensive set of questions covering all information in the input text")
// questions: Collect[Constrain[String, NonEmpty]],
// @desc("Process each question")
// processed: Collect[CQA]
// )

// run {
// AIs.gen[Req](text)
// }

// def text =
// p"""
// General relativity is a theory of gravitation developed by Einstein in the years 1907–1915. The development of general relativity began with the equivalence principle, under which the states of accelerated motion and being at rest in a gravitational field (for example, when standing on the surface of the Earth) are physically identical. The upshot of this is that free fall is inertial motion: an object in free fall is falling because that is how objects move when there is no force being exerted on them, instead of this being due to the force of gravity as is the case in classical mechanics. This is incompatible with classical mechanics and special relativity because in those theories inertially moving objects cannot accelerate with respect to each other, but objects in free fall do so. To resolve this difficulty Einstein first proposed that spacetime is curved. Einstein discussed his idea with mathematician Marcel Grossmann and they concluded that general relativity could be formulated in the context of Riemannian geometry which had been developed in the 1800s.[10] In 1915, he devised the Einstein field equations which relate the curvature of spacetime with the mass, energy, and any momentum within it.
// """
// }

0 comments on commit 6734ca9

Please sign in to comment.