Skip to content

Commit

Permalink
Fix compilation of Recursive GADT schema derivation (#561)
Browse files Browse the repository at this point in the history
* Fix compilation of GADT schema derivation on scala 2

* Fix deriving schema for generic types on scala 3

* Add tests for generically deriving schemas for enums

* Update readme

* Fix

---------

Co-authored-by: Daniel Vigovszky <[email protected]>
  • Loading branch information
Jesse-Bakker and vigoo authored Nov 18, 2023
1 parent b0b7062 commit 7610e60
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ object DeriveSchema {

val JavaAnnotationTpe = typeOf[java.lang.annotation.Annotation]

lazy val optionType = typeOf[Option[_]]
lazy val listType = typeOf[List[_]]
lazy val setType = typeOf[Set[_]]
lazy val vectorType = typeOf[Vector[_]]
lazy val chunkType = typeOf[Chunk[_]]
lazy val eitherType = typeOf[Either[_, _]]
lazy val tuple2Type = typeOf[(_, _)]
lazy val tuple3Type = typeOf[(_, _, _)]
lazy val tuple4Type = typeOf[(_, _, _, _)]

val tpe = weakTypeOf[T]

def concreteType(seenFrom: Type, tpe: Type): Type =
Expand Down Expand Up @@ -62,7 +72,7 @@ object DeriveSchema {
s"Failed to derive schema for $tpe. Can only derive Schema for case class or sealed trait"
)

def directInferSchema(parentType: Type, schemaType: Type, stack: List[Frame[c.type]]): Tree =
def directInferSchema(parentType: Type, schemaType: Type, stack: List[Frame[c.type]]): Tree = {
stack
.find(_.tpe =:= schemaType)
.map {
Expand All @@ -83,20 +93,20 @@ object DeriveSchema {
case Nil =>
recurse(schemaType, stack)
case typeArg1 :: Nil =>
if (schemaType <:< c.typeOf[Option[_]])
if (schemaType <:< optionType)
q"_root_.zio.schema.Schema.option(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))"
else if (schemaType <:< typeOf[List[_]])
else if (schemaType <:< listType)
q"_root_.zio.schema.Schema.list(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))"
else if (schemaType <:< typeOf[Set[_]])
else if (schemaType <:< setType)
q"_root_.zio.schema.Schema.set(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))"
else if (schemaType <:< typeOf[Vector[_]])
else if (schemaType <:< vectorType)
q"_root_.zio.schema.Schema.vector(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))"
else if (schemaType <:< typeOf[Chunk[_]])
else if (schemaType <:< chunkType)
q"_root_.zio.schema.Schema.chunk(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))"
else
recurse(schemaType, stack)
case typeArg1 :: typeArg2 :: Nil =>
if (schemaType <:< typeOf[Either[_, _]])
if (schemaType <:< eitherType)
q"""_root_.zio.schema.Schema.either(
_root_.zio.schema.Schema.defer(${directInferSchema(
parentType,
Expand All @@ -110,7 +120,7 @@ object DeriveSchema {
)})
)
"""
else if (schemaType <:< typeOf[(_, _)])
else if (schemaType <:< tuple2Type)
q"""_root_.zio.schema.Schema.tuple2(
_root_.zio.schema.Schema.defer(${directInferSchema(
parentType,
Expand All @@ -127,7 +137,7 @@ object DeriveSchema {
else
recurse(schemaType, stack)
case typeArg1 :: typeArg2 :: typeArg3 :: Nil =>
if (schemaType <:< typeOf[(_, _, _)])
if (schemaType <:< tuple3Type)
q"""_root_.zio.schema.Schema.tuple3(
_root_.zio.schema.Schema.defer(${directInferSchema(
parentType,
Expand All @@ -150,7 +160,7 @@ object DeriveSchema {
else
recurse(schemaType, stack)
case typeArg1 :: typeArg2 :: typeArg3 :: typeArg4 :: Nil =>
if (schemaType <:< typeOf[(_, _, _)])
if (schemaType <:< tuple4Type)
q"""_root_.zio.schema.Schema.tuple4(
_root_.zio.schema.Schema.defer(${directInferSchema(
parentType,
Expand Down Expand Up @@ -183,6 +193,7 @@ object DeriveSchema {
}
}
}
}

def getFieldName(annotations: List[Tree]): Option[String] =
annotations.collectFirst {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,62 +52,57 @@ private case class DeriveSchema()(using val ctx: Quotes) {
val result = stack.find(typeRepr) match {
case Some(ref) =>
'{ Schema.defer(${ref.asExprOf[Schema[T]]}) }
case None =>
val summoned = Expr.summon[Schema[T]]
case None =>
val summoned = if (!top) Expr.summon[Schema[T]] else None
if (!top && summoned.isDefined) {
'{ Schema.defer(${summoned.get}) }.asExprOf[Schema[T]]
'{
Schema.defer(${
summoned.get
})
}.asExprOf[Schema[T]]
} else {
typeRepr.asType match {
case '[List[a]] =>
val schema = deriveSchema[a](stack)
'{ Schema.list(Schema.defer(${schema})) }.asExprOf[Schema[T]]
case '[scala.util.Either[a, b]] =>
val schemaA = deriveSchema[a](stack)
val schemaB = deriveSchema[b](stack)
'{ Schema.either(Schema.defer(${schemaA}), Schema.defer(${schemaB})) }.asExprOf[Schema[T]]
case '[Option[a]] =>
val schema = deriveSchema[a](stack)
// throw new Error(s"OPITOS ${schema.show}")
'{ Schema.option(Schema.defer($schema)) }.asExprOf[Schema[T]]
case '[scala.collection.Set[a]] =>
val schema = deriveSchema[a](stack)
'{ Schema.set(Schema.defer(${schema})) }.asExprOf[Schema[T]]
case '[Vector[a]] =>
val schema = deriveSchema[a](stack)
'{ Schema.vector(Schema.defer(${schema})) }.asExprOf[Schema[T]]
case '[scala.collection.Map[a, b]] =>
val schemaA = deriveSchema[a](stack)
val schemaB = deriveSchema[b](stack)
'{ Schema.map(Schema.defer(${schemaA}), Schema.defer(${schemaB})) }.asExprOf[Schema[T]]
case '[zio.Chunk[a]] =>
val schema = deriveSchema[a](stack)
'{ Schema.chunk(Schema.defer(${schema})) }.asExprOf[Schema[T]]
case _ =>
val summoned = if (!top) Expr.summon[Schema[T]] else None
summoned match {
case Some(schema) =>
// println(s"FOR TYPE ${typeRepr.show}")
// println(s"STACK ${stack.find(typeRepr)}")
// println(s"Found schema ${schema.show}")
schema
case _ =>
Mirror(typeRepr) match {
case Some(mirror) =>
mirror.mirrorType match {
case MirrorType.Sum =>
deriveEnum[T](mirror, stack)
case MirrorType.Product =>
deriveCaseClass[T](mirror, stack, top)
}
case None =>
val sym = typeRepr.typeSymbol
if (sym.isClassDef && sym.flags.is(Flags.Module)) {
deriveCaseObject[T](stack, top)
}
else {
report.errorAndAbort(s"Deriving schema for ${typeRepr.show} is not supported")
}
}
typeRepr.asType match {
case '[List[a]] =>
val schema = deriveSchema[a](stack)
'{ Schema.list(Schema.defer(${schema})) }.asExprOf[Schema[T]]
case '[scala.util.Either[a, b]] =>
val schemaA = deriveSchema[a](stack)
val schemaB = deriveSchema[b](stack)
'{ Schema.either(Schema.defer(${schemaA}), Schema.defer(${schemaB})) }.asExprOf[Schema[T]]
case '[Option[a]] =>
val schema = deriveSchema[a](stack)
// throw new Error(s"OPITOS ${schema.show}")
'{ Schema.option(Schema.defer($schema)) }.asExprOf[Schema[T]]
case '[scala.collection.Set[a]] =>
val schema = deriveSchema[a](stack)
'{ Schema.set(Schema.defer(${schema})) }.asExprOf[Schema[T]]
case '[Vector[a]] =>
val schema = deriveSchema[a](stack)
'{ Schema.vector(Schema.defer(${schema})) }.asExprOf[Schema[T]]
case '[scala.collection.Map[a, b]] =>
val schemaA = deriveSchema[a](stack)
val schemaB = deriveSchema[b](stack)
'{ Schema.map(Schema.defer(${schemaA}), Schema.defer(${schemaB})) }.asExprOf[Schema[T]]
case '[zio.Chunk[a]] =>
val schema = deriveSchema[a](stack)
'{ Schema.chunk(Schema.defer(${schema})) }.asExprOf[Schema[T]]
case _ =>
Mirror(typeRepr) match {
case Some(mirror) =>
mirror.mirrorType match {
case MirrorType.Sum =>
deriveEnum[T](mirror, stack)
case MirrorType.Product =>
deriveCaseClass[T](mirror, stack, top)
}
case None =>
val sym = typeRepr.typeSymbol
if (sym.isClassDef && sym.flags.is(Flags.Module)) {
deriveCaseObject[T](stack, top)
}
else {
report.errorAndAbort(s"Deriving schema for ${typeRepr.show} is not supported")
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
case class Leaf[A](value: A) extends Tree[A]
case object Root extends Tree[Nothing]

implicit def schema[A: Schema]: Schema[Tree[A]] = DeriveSchema.gen[Tree[A]]
}

sealed trait RBTree[+A, +B]
Expand All @@ -173,6 +175,8 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
case class Branch[A, B](left: RBTree[A, B], right: RBTree[A, B]) extends RBTree[A, B]
case class RLeaf[A](value: A) extends RBTree[A, Nothing]
case class BLeaf[B](value: B) extends RBTree[Nothing, B]

implicit def schema[A: Schema, B: Schema]: Schema[RBTree[A, B]] = DeriveSchema.gen[RBTree[A, B]]
}

sealed trait AdtWithTypeParameters[+Param1, +Param2]
Expand Down Expand Up @@ -400,10 +404,16 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
val derived: Schema[Tree[Recursive]] = DeriveSchema.gen[Tree[Recursive]]
assert(derived)(anything)
},
test("correctly derives generic recursive Enum") {
assert(Schema[Tree[Recursive]].toString)(not(containsString("null")) && not(equalTo("$Lazy$")))
},
test("correctly derives recursive Enum with multiple type parameters") {
val derived: Schema[RBTree[String, Int]] = DeriveSchema.gen[RBTree[String, Int]]
assert(derived)(anything)
},
test("correctly derives generic recursive Enum with multiple type parameters") {
assert(Schema[RBTree[String, Int]].toString)(not(containsString("null")) && not(equalTo("$Lazy$")))
},
test("correctly derives schema with unused type parameters") {
val derived: Schema[AdtWithTypeParameters[Int, Int]] = DeriveSchema.gen[AdtWithTypeParameters[Int, Int]]
assert(derived)(anything)
Expand Down

0 comments on commit 7610e60

Please sign in to comment.