Skip to content

Commit

Permalink
Merge msgpack serializers
Browse files Browse the repository at this point in the history
- There was very little performance difference between serializers so
  the `fast` serializer was entirely scrapped.
- The current serializer buffers the output in 4KiB segments before
  emitting it. This change brought a significant speedup.
  • Loading branch information
jarmuszz committed Sep 5, 2024
1 parent aa8658a commit cd9782e
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,19 @@ class MsgPackItemSerializerBenchmarks {


@Benchmark
def compressed() =
def serialize() =
Stream
.emits(msgpackItems)
.through(fs2.data.msgpack.low.bytes[SyncIO](true, false))
.through(fs2.data.msgpack.low.bytes[SyncIO](false))
.compile
.drain
.unsafeRunSync()

@Benchmark
def fast() =
def withValidation() =
Stream
.emits(msgpackItems)
.through(fs2.data.msgpack.low.bytes[SyncIO](false, false))
.through(fs2.data.msgpack.low.bytes[SyncIO](true))
.compile
.drain
.unsafeRunSync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,181 +23,225 @@ package internal
import scodec.bits._

private[low] object ItemSerializer {
def compressed: MsgpackItem => ByteVector = {
/** Raised by the serializer when an item is too large for every encoding of its MessagePack format. */
class MalformedItemError extends Error("item exceeds the maximum size of its format")

/** Raised for a `Str` item that is too large to be encoded. */
class MalformedStringError extends MalformedItemError

/** Raised for a `Bin` item that is too large to be encoded. */
class MalformedBinError extends MalformedItemError

/** Raised for a `SignedInt` item wider than 8 bytes. */
class MalformedIntError extends MalformedItemError

/** Raised for an `UnsignedInt` item wider than 8 bytes. */
class MalformedUintError extends MalformedItemError

/** Checks whether `x`, read as an unsigned 32-bit integer, fits in `n` bits.
  *
  * Note that `n` is a number of bits, not bytes: `fitsIn(x, 4)` tests `0 <= x <= 15`.
  *
  * @param x the value to test; negative values are treated as large unsigned values
  * @param n the width in bits
  * @return `true` if the unsigned value of `x` is at most `2^n - 1`
  */
@inline
private def fitsIn(x: Int, n: Long): Boolean =
  if (n >= 32) true // an Int never needs more than 32 bits
  else if (n <= 0) x == 0 // only zero fits in no bits at all
  else java.lang.Integer.compareUnsigned(x, (1 << n.toInt) - 1) <= 0

/** State threaded through the serialization of a single chunk of items.
  *
  * @param out   output buffer holding the bytes produced so far
  * @param chunk chunk of items currently being serialized
  * @param idx   index into `chunk` of the next item to serialize
  * @param rest  remainder of the input stream that follows `chunk`
  */
private final case class SerializationContext[F[_]](out: Out[F],
                                                    chunk: Chunk[MsgpackItem],
                                                    idx: Int,
                                                    rest: Stream[F, MsgpackItem])

/** Accumulates serialized bytes and hands them to [[Pull.output]] in segments of roughly 4KiB.
  *
  * Instances are immutable: every operation returns a new [[Out]] holding the updated buffer.
  *
  * @param contents bytes buffered so far and not yet emitted
  */
private class Out[F[_]](contents: Chunk[Byte]) {
  private val limit = 4096

  /** Appends `bv` to the buffer.
    *
    * If the buffer had already reached the limit before this call, it is emitted first and `bv` starts a new
    * buffer; otherwise nothing is emitted.
    */
  @inline
  def push(bv: ByteVector): Pull[F, Byte, Out[F]] = {
    val incoming = Chunk.byteVector(bv)
    if (contents.size < limit)
      Pull.pure(new Out[F](contents ++ incoming))
    else
      Pull.output(contents).as(new Out[F](incoming))
  }

  /** Splits `bv` into segments of at most `limit` bytes and emits them one after another, so that the buffer
    * never grows beyond the limit during the operation. The last segment is kept as the new buffer.
    *
    * Use this instead of [[Out.push]] when `bv` may significantly exceed 4KiB.
    */
  def pushBuffered(bv: ByteVector): Pull[F, Byte, Out[F]] = {
    val cap = limit.toLong

    // Emits `pending` and moves on to the next segment of `remaining`; once nothing remains,
    // `pending` is not emitted but becomes the contents of the resulting buffer.
    def drain(pending: Chunk[Byte], remaining: ByteVector): Pull[F, Byte, Out[F]] =
      if (remaining.nonEmpty) {
        val (segment, tail) = remaining.splitAt(cap)
        Pull.output(pending) >> drain(Chunk.byteVector(segment), tail)
      } else
        Pull.pure(new Out[F](pending))

    if (bv.isEmpty)
      push(bv)
    else if (contents.size >= limit) {
      // The buffer is already full: emit it and start over with the first segment of `bv`.
      val (first, tail) = bv.splitAt(cap)
      Pull.output(contents) >> drain(Chunk.byteVector(first), tail)
    } else {
      // Top the buffer up to the limit with the beginning of `bv` before segmenting the rest.
      val room = cap - contents.size
      val (first, tail) = bv.splitAt(room)
      drain(contents ++ Chunk.byteVector(first), tail)
    }
  }

  /** Outputs the whole buffer. */
  @inline
  def flush: Pull[F, Byte, Unit] = Pull.output(contents)
}

@inline
private def step[F[_]: RaiseThrowable](o: Out[F], item: MsgpackItem): Pull[F, Byte, Out[F]] = item match {
case MsgpackItem.UnsignedInt(bytes) =>
val bs = bytes.dropWhile(_ == 0)
if (bs.size <= 1)
ByteVector(Headers.Uint8).buffer ++ bs.padLeft(1)
o.push(ByteVector(Headers.Uint8) ++ bs.padLeft(1))
else if (bs.size <= 2)
ByteVector(Headers.Uint16).buffer ++ bs.padLeft(2)
o.push(ByteVector(Headers.Uint16) ++ bs.padLeft(2))
else if (bs.size <= 4)
ByteVector(Headers.Uint32).buffer ++ bs.padLeft(4)
o.push(ByteVector(Headers.Uint32) ++ bs.padLeft(4))
else if (bs.size <= 8)
o.push(ByteVector(Headers.Uint64) ++ bs.padLeft(8))
else
ByteVector(Headers.Uint64).buffer ++ bs.padLeft(8)
Pull.raiseError(new MalformedUintError)

case MsgpackItem.SignedInt(bytes) =>
val bs = bytes.dropWhile(_ == 0)
if (bs.size <= 1)
// positive fixint or negative fixint
if ((bs & hex"7f") == bs || (bs & hex"c0") == hex"c0")
bs.padLeft(1)
o.push(bs.padLeft(1))
else
ByteVector(Headers.Int8).buffer ++ bs.padLeft(1)
o.push(ByteVector(Headers.Int8) ++ bs.padLeft(1))
else if (bs.size <= 2)
ByteVector(Headers.Int16).buffer ++ bs.padLeft(2)
o.push(ByteVector(Headers.Int16) ++ bs.padLeft(2))
else if (bs.size <= 4)
ByteVector(Headers.Int32).buffer ++ bs.padLeft(4)
o.push(ByteVector(Headers.Int32) ++ bs.padLeft(4))
else if (bs.size <= 8)
o.push(ByteVector(Headers.Int64) ++ bs.padLeft(8))
else
ByteVector(Headers.Int64).buffer ++ bs.padLeft(8)
Pull.raiseError(new MalformedIntError)

case MsgpackItem.Float32(float) =>
ByteVector(Headers.Float32).buffer ++ ByteVector.fromInt(java.lang.Float.floatToIntBits(float))
o.push(ByteVector(Headers.Float32) ++ ByteVector.fromInt(java.lang.Float.floatToIntBits(float)))

case MsgpackItem.Float64(double) =>
ByteVector(Headers.Float64).buffer ++ ByteVector.fromLong(java.lang.Double.doubleToLongBits(double))
o.push(ByteVector(Headers.Float64) ++ ByteVector.fromLong(java.lang.Double.doubleToLongBits(double)))

case MsgpackItem.Str(bytes) =>
if (bytes.size <= 31) {
ByteVector.fromByte((0xa0 | bytes.size).toByte).buffer ++ bytes
o.push(ByteVector.fromByte((0xa0 | bytes.size).toByte) ++ bytes)
} else if (bytes.size <= Math.pow(2, 8) - 1) {
val size = ByteVector.fromByte(bytes.size.toByte)
ByteVector(Headers.Str8).buffer ++ size ++ bytes
o.push(ByteVector(Headers.Str8) ++ size ++ bytes)
} else if (bytes.size <= Math.pow(2, 16) - 1) {
val size = ByteVector.fromShort(bytes.size.toShort)
ByteVector(Headers.Str16).buffer ++ size ++ bytes
} else {
o.push(ByteVector(Headers.Str16) ++ size ++ bytes)
} else if (fitsIn(bytes.size.toInt, 32)) {
val size = ByteVector.fromInt(bytes.size.toInt)
ByteVector(Headers.Str32).buffer ++ size ++ bytes
/* Max length of str32 (incl. type and length info) is 2^32 + 4 bytes
* which is more than Chunk can handle at once
*/
o.pushBuffered(ByteVector(Headers.Str32) ++ size ++ bytes)
} else {
Pull.raiseError(new MalformedStringError)
}

case MsgpackItem.Bin(bytes) =>
if (bytes.size <= Math.pow(2, 8) - 1) {
val size = ByteVector.fromByte(bytes.size.toByte)
ByteVector(Headers.Bin8).buffer ++ size ++ bytes
o.push(ByteVector(Headers.Bin8) ++ size ++ bytes)
} else if (bytes.size <= Math.pow(2, 16) - 1) {
val size = ByteVector.fromShort(bytes.size.toShort)
ByteVector(Headers.Bin16).buffer ++ size ++ bytes
} else {
o.push(ByteVector(Headers.Bin16) ++ size ++ bytes)
} else if (fitsIn(bytes.size.toInt, 32)) {
val size = ByteVector.fromInt(bytes.size.toInt)
ByteVector(Headers.Bin32).buffer ++ size ++ bytes
/* Max length of bin32 (incl. type and length info) is 2^32 + 4 bytes
* which is more than Chunk can handle at once
*/
o.pushBuffered(ByteVector(Headers.Bin32) ++ size ++ bytes)
} else {
Pull.raiseError(new MalformedBinError)
}

case MsgpackItem.Array(size) =>
if (size <= 15) {
ByteVector.fromByte((0x90 | size).toByte)
if (fitsIn(size, 4)) {
o.push(ByteVector.fromByte((0x90 | size).toByte))
} else if (size <= Math.pow(2, 16) - 1) {
val s = ByteVector.fromShort(size.toShort)
ByteVector(Headers.Array16).buffer ++ s
o.push(ByteVector(Headers.Array16) ++ s)
} else {
val s = ByteVector.fromInt(size)
ByteVector(Headers.Array32).buffer ++ s
o.push(ByteVector(Headers.Array32) ++ s)
}

case MsgpackItem.Map(size) =>
if (size <= 15) {
ByteVector.fromByte((0x80 | size).toByte)
o.push(ByteVector.fromByte((0x80 | size).toByte))
} else if (size <= Math.pow(2, 16) - 1) {
val s = ByteVector.fromShort(size.toShort)
ByteVector(Headers.Map16).buffer ++ s
o.push(ByteVector(Headers.Map16) ++ s)
} else {
val s = ByteVector.fromInt(size)
ByteVector(Headers.Map32).buffer ++ s
o.push(ByteVector(Headers.Map32) ++ s)
}

case MsgpackItem.Extension(tpe, bytes) =>
val bs = bytes.dropWhile(_ == 0)
if (bs.size <= 1) {
(ByteVector(Headers.FixExt1).buffer :+ tpe) ++ bs.padLeft(1)
o.push((ByteVector(Headers.FixExt1) :+ tpe) ++ bs.padLeft(1))
} else if (bs.size <= 2) {
(ByteVector(Headers.FixExt2).buffer :+ tpe) ++ bs.padLeft(2)
o.push((ByteVector(Headers.FixExt2) :+ tpe) ++ bs.padLeft(2))
} else if (bs.size <= 4) {
(ByteVector(Headers.FixExt4).buffer :+ tpe) ++ bs.padLeft(4)
o.push((ByteVector(Headers.FixExt4) :+ tpe) ++ bs.padLeft(4))
} else if (bs.size <= 8) {
(ByteVector(Headers.FixExt8).buffer :+ tpe) ++ bs.padLeft(8)
o.push((ByteVector(Headers.FixExt8) :+ tpe) ++ bs.padLeft(8))
} else if (bs.size <= 16) {
(ByteVector(Headers.FixExt16).buffer :+ tpe) ++ bs.padLeft(16)
o.push((ByteVector(Headers.FixExt16) :+ tpe) ++ bs.padLeft(16))
} else if (bs.size <= Math.pow(2, 8) - 1) {
val size = ByteVector.fromByte(bs.size.toByte)
(ByteVector(Headers.Ext8).buffer ++ size :+ tpe) ++ bs
o.push((ByteVector(Headers.Ext8) ++ size :+ tpe) ++ bs)
} else if (bs.size <= Math.pow(2, 16) - 1) {
val size = ByteVector.fromShort(bs.size.toShort)
(ByteVector(Headers.Ext16).buffer ++ size :+ tpe) ++ bs
o.push((ByteVector(Headers.Ext16) ++ size :+ tpe) ++ bs)
} else {
val size = ByteVector.fromInt(bs.size.toInt)
(ByteVector(Headers.Ext32).buffer ++ size :+ tpe) ++ bs
/* Max length of ext32 (incl. type and length info) is 2^32 + 5 bytes
* which is more than Chunk can handle at once.
*/
o.pushBuffered((ByteVector(Headers.Ext32) ++ size :+ tpe) ++ bs)
}

case MsgpackItem.Timestamp32(seconds) =>
(ByteVector(Headers.FixExt4).buffer :+ Headers.Timestamp.toByte) ++ ByteVector.fromInt(seconds)
o.push((ByteVector(Headers.FixExt4) :+ Headers.Timestamp.toByte) ++ ByteVector.fromInt(seconds))

case MsgpackItem.Timestamp64(combined) =>
(ByteVector(Headers.FixExt8).buffer :+ Headers.Timestamp.toByte) ++ ByteVector.fromLong(combined)
o.push((ByteVector(Headers.FixExt8) :+ Headers.Timestamp.toByte) ++ ByteVector.fromLong(combined))

case MsgpackItem.Timestamp96(nanoseconds, seconds) =>
val ns = ByteVector.fromInt(nanoseconds)
val s = ByteVector.fromLong(seconds)
(ByteVector(Headers.Ext8).buffer :+ 12 :+ Headers.Timestamp.toByte) ++ ns ++ s
o.push((ByteVector(Headers.Ext8) :+ 12 :+ Headers.Timestamp.toByte) ++ ns ++ s)

case MsgpackItem.Nil =>
ByteVector(Headers.Nil)
o.push(ByteVector(Headers.Nil))

case MsgpackItem.False =>
ByteVector(Headers.False)
o.push(ByteVector(Headers.False))

case MsgpackItem.True =>
ByteVector(Headers.True)
o.push(ByteVector(Headers.True))
}

def fast: MsgpackItem => ByteVector = {
case item: MsgpackItem.UnsignedInt =>
ByteVector(Headers.Uint64) ++ item.bytes.padLeft(8)

case item: MsgpackItem.SignedInt =>
ByteVector(Headers.Int64) ++ item.bytes.padLeft(8)

case item: MsgpackItem.Float32 =>
ByteVector(Headers.Float32) ++ ByteVector.fromInt(java.lang.Float.floatToIntBits(item.v))

case item: MsgpackItem.Float64 =>
ByteVector(Headers.Float64) ++ ByteVector.fromLong(java.lang.Double.doubleToLongBits(item.v))

case item: MsgpackItem.Str =>
val size = ByteVector.fromInt(item.bytes.size.toInt)
ByteVector(Headers.Str32) ++ size ++ item.bytes

case item: MsgpackItem.Bin =>
val size = ByteVector.fromInt(item.bytes.size.toInt)
ByteVector(Headers.Bin32) ++ size ++ item.bytes

case item: MsgpackItem.Array =>
ByteVector(Headers.Array32) ++ ByteVector.fromInt(item.size)

case item: MsgpackItem.Map =>
ByteVector(Headers.Map32) ++ ByteVector.fromInt(item.size)

case item: MsgpackItem.Extension =>
val size = ByteVector.fromInt(item.bytes.size.toInt)
val t = ByteVector(item.tpe)
ByteVector(Headers.Ext32) ++ size ++ t ++ item.bytes

case item: MsgpackItem.Timestamp32 =>
ByteVector(Headers.FixExt4) ++ hex"ff" ++ ByteVector.fromInt(item.seconds)

case item: MsgpackItem.Timestamp64 =>
ByteVector(Headers.FixExt8) ++ hex"ff" ++ ByteVector.fromLong(item.combined)

case item: MsgpackItem.Timestamp96 =>
val ns = ByteVector.fromInt(item.nanoseconds)
val s = ByteVector.fromLong(item.seconds)
ByteVector(Headers.Ext8) ++ hex"0c" ++ hex"ff" ++ ns ++ s

case MsgpackItem.Nil =>
ByteVector(Headers.Nil)
/** Serializes the items of `ctx.chunk` from index `ctx.idx` onwards, one at a time.
  *
  * @return the context with its index at the end of the chunk and its buffer updated by each serialized item
  */
private def stepChunk[F[_]: RaiseThrowable](ctx: SerializationContext[F]): Pull[F, Byte, SerializationContext[F]] =
  if (ctx.idx < ctx.chunk.size)
    step(ctx.out, ctx.chunk(ctx.idx)).flatMap { updated =>
      stepChunk(ctx.copy(out = updated, idx = ctx.idx + 1))
    }
  else
    Pull.pure(ctx)

case MsgpackItem.False =>
ByteVector(Headers.False)
def pipe[F[_]: RaiseThrowable]: Pipe[F, MsgpackItem, Byte] = { stream =>
// Serializes the input chunk by chunk, carrying the output buffer along, and flushes
// whatever is still buffered once the input is exhausted.
def go(out: Out[F], rest: Stream[F, MsgpackItem]): Pull[F, Byte, Unit] =
  rest.pull.uncons.flatMap {
    case Some((chunk, tail)) =>
      stepChunk(SerializationContext(out, chunk, 0, tail)).flatMap { ctx =>
        go(ctx.out, ctx.rest)
      }
    case None => out.flush
  }

case MsgpackItem.True =>
ByteVector(Headers.True)
go(new Out(Chunk.empty), stream).stream
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ private[low] object ItemValidator {
Pull.pure(None)

case MsgpackItem.Array(size) =>
if (size < 0)
Pull.raiseError(new ValidationErrorAt(position, s"Array has a negative size ${size}"))
else if (size == 0)
if (size == 0)
Pull.pure(None)
else
Pull.pure(Some(Expect(size, position)))
Expand Down
31 changes: 7 additions & 24 deletions msgpack/src/main/scala/fs2/data/msgpack/low/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,16 @@ package object low {
/** Pipe turning a stream of bytes into a stream of [[MsgpackItem]]s; delegates to `ItemParser.pipe`. */
def items[F[_]](implicit F: RaiseThrowable[F]): Pipe[F, Byte, MsgpackItem] =
ItemParser.pipe[F]

/** Alias for `bytes(compressed = true, validated = true)`
/** Alias for `bytes(validated = true)`
*/
def toBinary[F[_]: RaiseThrowable]: Pipe[F, MsgpackItem, Byte] =
bytes(true, true)
bytes(true)

def bytes[F[_]](compressed: Boolean, validated: Boolean)(implicit
F: RaiseThrowable[F]): Pipe[F, MsgpackItem, Byte] = { in =>
in
.through { if (validated) ItemValidator.simple else ItemValidator.none }
.flatMap { x =>
val bytes =
if (compressed)
ItemSerializer.compressed(x)
else
ItemSerializer.fast(x)

/* Maximum size of a `ByteVector` is bigger than the one of a `Chunk` (Long vs Int). The `Chunk.byteVector`
* function returns `Chunk.empty` if it encounters a `ByteVector` that won't fit in a `Chunk`. We have to work
* around this behaviour and explicitly check the `ByteVector` size.
*/
if (bytes.size <= Int.MaxValue) {
Stream.chunk(Chunk.byteVector(bytes))
} else {
val (lhs, rhs) = bytes.splitAt(Int.MaxValue)
Stream.chunk(Chunk.byteVector(lhs)) ++ Stream.chunk(Chunk.byteVector(rhs))
}
}
/** Pipe serializing a stream of [[MsgpackItem]]s into bytes.
  *
  * @param validated when `true`, items go through `ItemValidator.simple` before being serialized
  */
def bytes[F[_]: RaiseThrowable](validated: Boolean): Pipe[F, MsgpackItem, Byte] = { in =>
  val checked: Stream[F, MsgpackItem] =
    if (validated) in.through(ItemValidator.simple) else in
  checked.through(ItemSerializer.pipe)
}

def validated[F[_]](implicit F: RaiseThrowable[F]): Pipe[F, MsgpackItem, MsgpackItem] =
Expand Down
Loading

0 comments on commit cd9782e

Please sign in to comment.