Skip to content

Commit

Permalink
Fixes #5: Switch to SortedMap for generic record schemas (#55)
Browse files Browse the repository at this point in the history
* Fixes #5: Use SortedMap for generic Record and Enumeration schemas

Add UnorderedRecord for case classes with arity > 22 and CaseObject for objects

Use ListMap for case class encoding to guarantee preserved insert order

Use 2.12 compatible SortedMap and ListMap

* Resolve conflicts after rebase

* Use ListMap for generic structure encoding

* Merge upstream/main

* work around yet another jdk8 bug

Co-authored-by: thinkharder <[email protected]>
  • Loading branch information
thinkharderdev and thinkharderdev authored May 8, 2021
1 parent fb643ba commit fb5e766
Show file tree
Hide file tree
Showing 11 changed files with 362 additions and 201 deletions.
165 changes: 96 additions & 69 deletions zio-schema/shared/src/main/scala/zio/schema/Schema.scala

Large diffs are not rendered by default.

41 changes: 30 additions & 11 deletions zio-schema/shared/src/main/scala/zio/schema/codec/JsonCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package zio.schema.codec
import java.nio.CharBuffer
import java.nio.charset.StandardCharsets

import scala.collection.immutable.ListMap

import zio.json.JsonCodec._
import zio.json.JsonDecoder.{ JsonError, UnsafeJson }
import zio.json.internal.{ Lexer, RetractReader, StringMatrix, Write }
Expand Down Expand Up @@ -45,6 +47,10 @@ object JsonCodec extends Codec {

protected[codec] val unitCodec: ZJsonCodec[Unit] = ZJsonCodec(unitEncoder, unitDecoder)

protected[codec] def objectEncoder[Z]: JsonEncoder[Z] = { (_: Z, _: Option[Int], out: Write) =>
out.write("{}")
}

protected[codec] def failDecoder[A](message: String): JsonDecoder[A] =
(trace: List[JsonDecoder.JsonError], _: RetractReader) => throw UnsafeJson(JsonError.Message(message) :: trace)

Expand Down Expand Up @@ -103,9 +109,10 @@ object JsonCodec extends Codec {
case Schema.Tuple(l, r) => schemaEncoder(l).both(schemaEncoder(r))
case Schema.Optional(schema) => JsonEncoder.option(schemaEncoder(schema))
case Schema.Fail(_) => unitEncoder.contramap(_ => ())
case Schema.Record(structure) => recordEncoder(structure)
case Schema.GenericRecord(structure) => recordEncoder(structure)
case Schema.Enumeration(structure) => enumerationEncoder(structure)
case EitherSchema(left, right) => JsonEncoder.either(schemaEncoder(left), schemaEncoder(right))
case Schema.CaseObject(_) => objectEncoder[A]
case Schema.CaseClass1(f, _, ext) => caseClassEncoder(f -> ext)
case Schema.CaseClass2(f1, f2, _, ext1, ext2) => caseClassEncoder(f1 -> ext1, f2 -> ext2)
case Schema.CaseClass3(f1, f2, f3, _, ext1, ext2, ext3) => caseClassEncoder(f1 -> ext1, f2 -> ext2, f3 -> ext3)
Expand Down Expand Up @@ -879,8 +886,8 @@ object JsonCodec extends Codec {
}
}

private def recordEncoder(structure: Map[String, Schema[_]]): JsonEncoder[Map[String, _]] = {
(value: Map[String, _], indent: Option[Int], out: Write) =>
private def recordEncoder(structure: ListMap[String, Schema[_]]): JsonEncoder[ListMap[String, _]] = {
(value: ListMap[String, _], indent: Option[Int], out: Write) =>
{
if (structure.isEmpty) {
out.write("{}")
Expand Down Expand Up @@ -910,8 +917,8 @@ object JsonCodec extends Codec {
}
}

private def enumerationEncoder(structure: Map[String, Schema[_]]): JsonEncoder[Map[String, _]] = {
(a: Map[String, _], indent: Option[Int], out: Write) =>
private def enumerationEncoder(structure: ListMap[String, Schema[_]]): JsonEncoder[ListMap[String, _]] = {
(a: ListMap[String, _], indent: Option[Int], out: Write) =>
{
if (structure.isEmpty) {
out.write("{}")
Expand Down Expand Up @@ -956,9 +963,10 @@ object JsonCodec extends Codec {
case Schema.Transform(codec, f, _) => schemaDecoder(codec).mapOrFail(f)
case Schema.Sequence(codec, f, _) => JsonDecoder.chunk(schemaDecoder(codec)).map(f)
case Schema.Fail(message) => failDecoder(message)
case Schema.Record(structure) => recordDecoder(structure)
case Schema.GenericRecord(structure) => recordDecoder(structure)
case Schema.Enumeration(structure) => enumerationDecoder(structure)
case EitherSchema(left, right) => JsonDecoder.either(schemaDecoder(left), schemaDecoder(right))
case Schema.CaseObject(instance) => caseObjectDecoder(instance)
case s @ Schema.CaseClass1(_, _, _) => caseClass1Decoder(s)
case s @ Schema.CaseClass2(_, _, _, _, _) => caseClass2Decoder(s)
case s @ Schema.CaseClass3(_, _, _, _, _, _, _) => caseClass3Decoder(s)
Expand Down Expand Up @@ -1302,6 +1310,16 @@ object JsonCodec extends Codec {
}
}

private def caseObjectDecoder[Z](instance: Z): JsonDecoder[Z] = { (trace: List[JsonError], in: RetractReader) =>
{
Lexer.char(trace, in, '{')
if (!Lexer.firstField(trace, in))
instance
else
throw UnsafeJson(JsonError.Message("invalid field") :: trace)
}
}

private def caseClass1Decoder[A, Z](schema: Schema.CaseClass1[A, Z]): JsonDecoder[Z] = {
(trace: List[JsonError], in: RetractReader) =>
{
Expand Down Expand Up @@ -2267,24 +2285,25 @@ object JsonCodec extends Codec {
buffer
}

private def recordDecoder(structure: Map[String, Schema[_]]): JsonDecoder[Map[String, Any]] = {
private def recordDecoder(structure: ListMap[String, Schema[_]]): JsonDecoder[ListMap[String, Any]] = {
(trace: List[JsonError], in: RetractReader) =>
{
val builder: ChunkBuilder[(String, Any)] = zio.ChunkBuilder.make[(String, Any)](structure.size)
Lexer.char(trace, in, '{')
if (Lexer.firstField(trace, in))
if (Lexer.firstField(trace, in)) {
do {
val field = Lexer.string(trace, in).toString
val trace_ = JsonError.ObjectAccess(field) :: trace
Lexer.char(trace_, in, ':')
val value = schemaDecoder(structure(field)).unsafeDecode(trace_, in)
builder += ((JsonFieldDecoder.string.unsafeDecodeField(trace_, field), value))
} while (Lexer.nextField(trace, in))
builder.result().toMap
}
(ListMap.newBuilder[String, Any] ++= builder.result()).result()
}
}

private def enumerationDecoder(structure: Map[String, Schema[_]]): JsonDecoder[Map[String, Any]] = {
private def enumerationDecoder(structure: ListMap[String, Schema[_]]): JsonDecoder[ListMap[String, Any]] = {
(trace: List[JsonError], in: RetractReader) =>
{
val builder: ChunkBuilder[(String, Any)] = zio.ChunkBuilder.make[(String, Any)](structure.size)
Expand All @@ -2296,7 +2315,7 @@ object JsonCodec extends Codec {
val value = schemaDecoder(structure(field)).unsafeDecode(trace_, in)
builder += ((JsonFieldDecoder.string.unsafeDecodeField(trace_, field), value))
}
builder.result().toMap
(ListMap.newBuilder[String, Any] ++= builder.result()).result()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import java.nio.{ ByteBuffer, ByteOrder }
import java.time._

import scala.annotation.tailrec
import scala.collection.immutable.ListMap

import zio.schema._
import zio.schema.codec.ProtobufCodec.Protobuf.WireType.LengthDelimited
Expand Down Expand Up @@ -47,11 +48,11 @@ object ProtobufCodec extends Codec {
}

def flatFields(
structure: Map[String, Schema[_]],
structure: ListMap[String, Schema[_]],
nextFieldNumber: Int = 1
): Map[Int, (String, Schema[_])] =
): ListMap[Int, (String, Schema[_])] =
structure.toSeq
.foldLeft((nextFieldNumber, Map[Int, (String, Schema[_])]())) { (numAndMap, fieldAndSchema) =>
.foldLeft((nextFieldNumber, ListMap[Int, (String, Schema[_])]())) { (numAndMap, fieldAndSchema) =>
nestedFields(fieldAndSchema._1, fieldAndSchema._2, nextFieldNumber) match {
case Some(fields) => (numAndMap._1 + fields.size, numAndMap._2 ++ fields)
case None => (numAndMap._1 + 1, numAndMap._2 + (numAndMap._1 -> fieldAndSchema))
Expand All @@ -63,7 +64,7 @@ object ProtobufCodec extends Codec {
baseField: String,
schema: Schema[_],
nextFieldNumber: Int
): Option[Map[Int, (String, Schema[_])]] =
): Option[ListMap[Int, (String, Schema[_])]] =
schema match {
case Schema.Transform(codec, f, g) =>
nestedFields(baseField, codec, nextFieldNumber).map(_.map {
Expand All @@ -73,32 +74,31 @@ object ProtobufCodec extends Codec {
case _ => None
}

def tupleSchema[A, B](first: Schema[A], second: Schema[B]): Schema[Map[String, _]] =
Schema.record(Map("first" -> first, "second" -> second))
def tupleSchema[A, B](first: Schema[A], second: Schema[B]): Schema[ListMap[String, _]] =
Schema.record(ListMap("first" -> first, "second" -> second))

def singleSchema[A](codec: Schema[A]): Schema[Map[String, _]] = Schema.record(Map("value" -> codec))
def singleSchema[A](codec: Schema[A]): Schema[ListMap[String, _]] = Schema.record(ListMap("value" -> codec))

def monthDayStructure(): Map[String, Schema[Int]] =
Map("month" -> Schema.Primitive(StandardType.IntType), "day" -> Schema.Primitive(StandardType.IntType))
def monthDayStructure(): ListMap[String, Schema[Int]] =
ListMap("month" -> Schema.Primitive(StandardType.IntType), "day" -> Schema.Primitive(StandardType.IntType))

def periodStructure(): Map[String, Schema[Int]] = Map(
def periodStructure(): ListMap[String, Schema[Int]] = ListMap(
"years" -> Schema.Primitive(StandardType.IntType),
"months" -> Schema.Primitive(StandardType.IntType),
"days" -> Schema.Primitive(StandardType.IntType)
)

def yearMonthStructure(): Map[String, Schema[Int]] =
Map("year" -> Schema.Primitive(StandardType.IntType), "month" -> Schema.Primitive(StandardType.IntType))
def yearMonthStructure(): ListMap[String, Schema[Int]] =
ListMap("year" -> Schema.Primitive(StandardType.IntType), "month" -> Schema.Primitive(StandardType.IntType))

def durationStructure(): Map[String, Schema[_]] =
Map("seconds" -> Schema.Primitive(StandardType.LongType), "nanos" -> Schema.Primitive(StandardType.IntType))
def durationStructure(): ListMap[String, Schema[_]] =
ListMap("seconds" -> Schema.Primitive(StandardType.LongType), "nanos" -> Schema.Primitive(StandardType.IntType))

/**
* Used when encoding sequence of values to decide whether each value need its own key or values can be packed together without keys (for example numbers).
*/
@scala.annotation.tailrec
def canBePacked(schema: Schema[_]): Boolean = schema match {
case _: Schema.Record => false
case Schema.Sequence(element, _, _) => canBePacked(element)
case _: Schema.Enumeration => false
case Schema.Transform(codec, _, _) => canBePacked(codec)
Expand Down Expand Up @@ -149,16 +149,17 @@ object ProtobufCodec extends Codec {

def encode[A](fieldNumber: Option[Int], schema: Schema[A], value: A): Chunk[Byte] =
(schema, value) match {
case (Schema.Record(structure), v: Map[String, _]) => encodeRecord(fieldNumber, structure, v)
case (Schema.Sequence(element, _, g), v) => encodeSequence(fieldNumber, element, g(v))
case (Schema.Enumeration(structure), v: Map[String, _]) => encodeEnumeration(fieldNumber, structure, v)
case (Schema.Transform(codec, _, g), _) => g(value).map(encode(fieldNumber, codec, _)).getOrElse(Chunk.empty)
case (Schema.Primitive(standardType), v) => encodePrimitive(fieldNumber, standardType, v)
case (Schema.Tuple(left, right), v @ (_, _)) => encodeTuple(fieldNumber, left, right, v)
case (Schema.Optional(codec), v: Option[_]) => encodeOptional(fieldNumber, codec, v)
case (Schema.EitherSchema(left, right), v: Either[_, _]) => encodeEither(fieldNumber, left, right, v)
case (Schema.CaseClass1(f, _, ext), v) => encodeCaseClass(fieldNumber, v, f -> ext)
case (Schema.CaseClass2(f1, f2, _, ext1, ext2), v) => encodeCaseClass(fieldNumber, v, f1 -> ext1, f2 -> ext2)
case (Schema.GenericRecord(structure), v: Map[String, _]) => encodeRecord(fieldNumber, structure, v)
case (Schema.Sequence(element, _, g), v) => encodeSequence(fieldNumber, element, g(v))
case (Schema.Enumeration(structure), v: Map[String, _]) => encodeEnumeration(fieldNumber, structure, v)
case (Schema.Transform(codec, _, g), _) => g(value).map(encode(fieldNumber, codec, _)).getOrElse(Chunk.empty)
case (Schema.Primitive(standardType), v) => encodePrimitive(fieldNumber, standardType, v)
case (Schema.Tuple(left, right), v @ (_, _)) => encodeTuple(fieldNumber, left, right, v)
case (Schema.Optional(codec), v: Option[_]) => encodeOptional(fieldNumber, codec, v)
case (Schema.EitherSchema(left, right), v: Either[_, _]) => encodeEither(fieldNumber, left, right, v)
case (Schema.CaseObject(_), _) => encodeCaseObject(fieldNumber)
case (Schema.CaseClass1(f, _, ext), v) => encodeCaseClass(fieldNumber, v, f -> ext)
case (Schema.CaseClass2(f1, f2, _, ext1, ext2), v) => encodeCaseClass(fieldNumber, v, f1 -> ext1, f2 -> ext2)
case (Schema.CaseClass3(f1, f2, f3, _, ext1, ext2, ext3), v) =>
encodeCaseClass(fieldNumber, v, f1 -> ext1, f2 -> ext2, f3 -> ext3)
case (Schema.CaseClass4(f1, f2, f3, f4, _, ext1, ext2, ext3, ext4), v) =>
Expand Down Expand Up @@ -1010,6 +1011,9 @@ object ProtobufCodec extends Codec {
encodeKey(WireType.LengthDelimited(encoded.size), fieldNumber) ++ encoded
}

private def encodeCaseObject[Z](fieldNumber: Option[Int]): Chunk[Byte] =
encodeKey(WireType.LengthDelimited(0), fieldNumber)

private def encodeCaseClass[Z](
fieldNumber: Option[Int],
value: Z,
Expand All @@ -1028,8 +1032,8 @@ object ProtobufCodec extends Codec {

private def encodeRecord(
fieldNumber: Option[Int],
structure: Map[String, Schema[_]],
data: Map[String, _]
structure: ListMap[String, Schema[_]],
data: ListMap[String, _]
): Chunk[Byte] = {
val encodedRecord = Chunk
.fromIterable(flatFields(structure).toSeq.map {
Expand Down Expand Up @@ -1115,23 +1119,23 @@ object ProtobufCodec extends Codec {
case (StandardType.Month, v: Month) =>
encodePrimitive(fieldNumber, StandardType.IntType, v.getValue)
case (StandardType.MonthDay, v: MonthDay) =>
encodeRecord(fieldNumber, monthDayStructure(), Map("month" -> v.getMonthValue, "day" -> v.getDayOfMonth))
encodeRecord(fieldNumber, monthDayStructure(), ListMap("month" -> v.getMonthValue, "day" -> v.getDayOfMonth))
case (StandardType.Period, v: Period) =>
encodeRecord(
fieldNumber,
periodStructure(),
Map("years" -> v.getYears, "months" -> v.getMonths, "days" -> v.getDays)
ListMap("years" -> v.getYears, "months" -> v.getMonths, "days" -> v.getDays)
)
case (StandardType.Year, v: Year) =>
encodePrimitive(fieldNumber, StandardType.IntType, v.getValue)
case (StandardType.YearMonth, v: YearMonth) =>
encodeRecord(fieldNumber, yearMonthStructure(), Map("year" -> v.getYear, "month" -> v.getMonthValue))
encodeRecord(fieldNumber, yearMonthStructure(), ListMap("year" -> v.getYear, "month" -> v.getMonthValue))
case (StandardType.ZoneId, v: ZoneId) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.getId)
case (StandardType.ZoneOffset, v: ZoneOffset) =>
encodePrimitive(fieldNumber, StandardType.IntType, v.getTotalSeconds)
case (StandardType.Duration(_), v: Duration) =>
encodeRecord(fieldNumber, durationStructure(), Map("seconds" -> v.getSeconds, "nanos" -> v.getNano))
encodeRecord(fieldNumber, durationStructure(), ListMap("seconds" -> v.getSeconds, "nanos" -> v.getNano))
case (StandardType.Instant(formatter), v: Instant) =>
encodePrimitive(fieldNumber, StandardType.StringType, formatter.format(v))
case (StandardType.LocalDate(formatter), v: LocalDate) =>
Expand Down Expand Up @@ -1159,7 +1163,7 @@ object ProtobufCodec extends Codec {
encode(
fieldNumber,
tupleSchema(left, right),
Map[String, Any]("first" -> tuple._1, "second" -> tuple._2)
ListMap[String, Any]("first" -> tuple._1, "second" -> tuple._2)
)

private def encodeEither[A, B](
Expand All @@ -1182,7 +1186,7 @@ object ProtobufCodec extends Codec {
encode(
fieldNumber,
singleSchema(schema),
Map("value" -> v)
ListMap("value" -> v)
)
case None => Chunk.empty
}
Expand Down Expand Up @@ -1292,7 +1296,7 @@ object ProtobufCodec extends Codec {

private def decoder[A](schema: Schema[A]): Decoder[A] =
schema match {
case Schema.Record(structure) => recordDecoder(flatFields(structure))
case Schema.GenericRecord(structure) => recordDecoder(flatFields(structure))
case Schema.Sequence(element, f, _) =>
if (canBePacked(element)) packedSequenceDecoder(element).map(f) else nonPackedSequenceDecoder(element).map(f)
case Schema.Enumeration(structure) => enumerationDecoder(flatFields(structure))
Expand All @@ -1302,6 +1306,7 @@ object ProtobufCodec extends Codec {
case Schema.Optional(codec) => optionalDecoder(codec)
case Schema.Fail(message) => fail(message)
case Schema.EitherSchema(left, right) => eitherDecoder(left, right)
case Schema.CaseObject(instance) => caseObjectDecoder(instance)
case s: Schema.CaseClass1[_, A] => caseClass1Decoder(s)
case s: Schema.CaseClass2[_, _, A] => caseClass2Decoder(s)
case s: Schema.CaseClass3[_, _, _, A] => caseClass3Decoder(s)
Expand Down Expand Up @@ -1350,6 +1355,11 @@ object ProtobufCodec extends Codec {
fail(s"Schema doesn't contain field number $fieldNumber.")
}

private def caseObjectDecoder[Z](instance: Z): Decoder[Z] = keyDecoder.flatMap {
case (LengthDelimited(0), _) => succeed(instance)
case _ => fail(s"Expected length-delimited field with length 0")
}

private def caseClass1Decoder[A, Z](schema: Schema.CaseClass1[A, Z]): Decoder[Z] =
unsafeDecodeFields(Array.ofDim[Any](1), schema.field).flatMap { buffer =>
if (buffer(0) == null)
Expand Down Expand Up @@ -2283,21 +2293,21 @@ object ProtobufCodec extends Codec {
}
}

private def enumerationDecoder(fields: Map[Int, (String, Schema[_])]): Decoder[Map[String, _]] =
private def enumerationDecoder(fields: ListMap[Int, (String, Schema[_])]): Decoder[ListMap[String, _]] =
keyDecoder.flatMap {
case (_, fieldNumber) =>
if (fields.contains(fieldNumber)) {
val (fieldName, schema) = fields(fieldNumber)

decoder(schema).map(fieldValue => Map(fieldName -> fieldValue))
decoder(schema).map(fieldValue => ListMap(fieldName -> fieldValue))
} else {
fail(s"Schema doesn't contain field number $fieldNumber.")
}
}

private def recordDecoder(fields: Map[Int, (String, Schema[_])]): Decoder[Map[String, _]] =
private def recordDecoder(fields: ListMap[Int, (String, Schema[_])]): Decoder[ListMap[String, _]] =
if (fields.isEmpty)
Decoder.succeed(Map())
Decoder.succeed(ListMap.empty)
else
keyDecoder.flatMap {
case (wt, fieldNumber) =>
Expand Down
Loading

0 comments on commit fb5e766

Please sign in to comment.