Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the Client's performance for non-stream bodies #2919

Merged
merged 17 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 55 additions & 30 deletions zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,50 @@ package zio.http.netty

import java.io.IOException

import scala.collection.mutable

import zio.Chunk
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Chunk, ChunkBuilder}

import zio.http.netty.AsyncBodyReader.State
import zio.http.netty.NettyBody.UnsafeAsync

import io.netty.buffer.ByteBufUtil
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http.{HttpContent, LastHttpContent}

abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](true) {
import zio.http.netty.AsyncBodyReader._

private var state: State = State.Buffering
private val buffer = new mutable.ArrayBuilder.ofByte()
private var previousAutoRead: Boolean = false
private var readingDone: Boolean = false
private var ctx: ChannelHandlerContext = _

private var state: State = State.Buffering
private val buffer: ChunkBuilder[(Chunk[Byte], Boolean)] = ChunkBuilder.make[(Chunk[Byte], Boolean)]()
private var previousAutoRead: Boolean = false
private var ctx: ChannelHandlerContext = _
private def result(buffer: mutable.ArrayBuilder.ofByte): Chunk[Byte] = {
val arr = buffer.result()
Chunk.ByteArray(arr, 0, arr.length)
}

private[zio] def connect(callback: UnsafeAsync): Unit = {
val buffer0 = buffer // Avoid reading it from the heap in the synchronized block
this.synchronized {
state match {
case State.Buffering =>
val result: Chunk[(Chunk[Byte], Boolean)] = buffer.result()
val readingDone: Boolean = result.lastOption match {
case None => false
case Some((_, isLast)) => isLast
}
buffer.clear() // GC

if (ctx.channel.isOpen || readingDone) {
state = State.Direct(callback)
result.foreach { case (chunk, isLast) =>
callback(chunk, isLast)
state = State.Direct(callback)

if (readingDone) {
callback(result(buffer0), isLast = true)
} else if (ctx.channel().isOpen) {
callback match {
case UnsafeAsync.Aggregating(bufSize) => buffer.sizeHint(bufSize)
case cb => cb(result(buffer0), isLast = false)
}
ctx.read(): Unit
} else {
throw new IllegalStateException("Attempting to read from a closed channel, which will never finish")
}

case State.Direct(_) =>
case _ =>
throw new IllegalStateException("Cannot connect twice")
}
}
Expand All @@ -76,22 +81,36 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
ctx: ChannelHandlerContext,
msg: HttpContent,
): Unit = {
val isLast = msg.isInstanceOf[LastHttpContent]
val chunk = Chunk.fromArray(ByteBufUtil.getBytes(msg.content()))
val buffer0 = buffer // Avoid reading it from the heap in the synchronized block

this.synchronized {
val isLast = msg.isInstanceOf[LastHttpContent]
val content = ByteBufUtil.getBytes(msg.content())

state match {
case State.Buffering =>
buffer += ((chunk, isLast))
case State.Direct(callback) =>
callback(chunk, isLast)
ctx.read()
case State.Buffering =>
// `connect` method hasn't been called yet, add all incoming content to the buffer
buffer0.addAll(content)
case State.Direct(callback) if isLast && buffer0.knownSize == 0 =>
// Buffer is empty, we can just use the array directly
callback(Chunk.fromArray(content), isLast = true)
case State.Direct(callback: UnsafeAsync.Aggregating) =>
// We're aggregating the full response, only call the callback on the last message
buffer0.addAll(content)
if (isLast) callback(result(buffer0), isLast = true)
case State.Direct(callback) =>
// We're streaming, emit chunks as they come
callback(Chunk.fromArray(content), isLast)
}
}

if (isLast) {
ctx.channel().pipeline().remove(this)
}: Unit
if (isLast) {
readingDone = true
ctx.channel().pipeline().remove(this)
} else {
ctx.read()
}
()
}
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
Expand Down Expand Up @@ -125,4 +144,10 @@ object AsyncBodyReader {

final case class Direct(callback: UnsafeAsync) extends State
}

// For Scala 2.12. In Scala 2.13+, the methods directly implemented on ArrayBuilder[Byte] are selected over syntax.
private implicit class ByteArrayBuilderOps[A](private val self: mutable.ArrayBuilder[Byte]) extends AnyVal {
def addAll(as: Array[Byte]): Unit = self ++= as
def knownSize: Int = -1
}
}
69 changes: 46 additions & 23 deletions zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ package zio.http.netty

import java.nio.charset.Charset

import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Chunk, Task, Trace, Unsafe, ZIO}

import zio.stream.ZStream

import zio.http.Body.{UnsafeBytes, UnsafeWriteable}
import zio.http.Body.UnsafeBytes
import zio.http.internal.BodyEncoding
import zio.http.{Body, Boundary, Header, Headers, MediaType}
import zio.http.{Body, Boundary, MediaType}

import io.netty.buffer.{ByteBuf, ByteBufUtil}
import io.netty.channel.{Channel => JChannel}
import io.netty.util.AsciiString

object NettyBody extends BodyEncoding {

/**
Expand Down Expand Up @@ -73,7 +73,6 @@ object NettyBody extends BodyEncoding {
override val mediaType: Option[MediaType] = None,
override val boundary: Option[Boundary] = None,
) extends Body
with UnsafeWriteable
with UnsafeBytes {

override def asArray(implicit trace: Trace): Task[Array[Byte]] = ZIO.succeed(asciiString.array())
Expand Down Expand Up @@ -101,11 +100,10 @@ object NettyBody extends BodyEncoding {
}

private[zio] final case class ByteBufBody(
val byteBuf: ByteBuf,
byteBuf: ByteBuf,
override val mediaType: Option[MediaType] = None,
override val boundary: Option[Boundary] = None,
) extends Body
with UnsafeWriteable
with UnsafeBytes {

override def asArray(implicit trace: Trace): Task[Array[Byte]] = ZIO.succeed(ByteBufUtil.getBytes(byteBuf))
Expand Down Expand Up @@ -137,38 +135,37 @@ object NettyBody extends BodyEncoding {
knownContentLength: Option[Long],
override val mediaType: Option[MediaType] = None,
override val boundary: Option[Boundary] = None,
) extends Body
with UnsafeWriteable {
) extends Body {

override def asArray(implicit trace: Trace): Task[Array[Byte]] = asChunk.map(_.toArray)

override def asChunk(implicit trace: Trace): Task[Chunk[Byte]] = asStream.runCollect
override def asChunk(implicit trace: Trace): Task[Chunk[Byte]] =
ZIO.async { cb =>
try {
// Cap at 100kB as a precaution in case the server sends an invalid content length
unsafeAsync(UnsafeAsync.Aggregating(bufferSize(1024 * 100))(cb))
} catch {
case e: Throwable => cb(ZIO.fail(e))
}
}

override def asStream(implicit trace: Trace): ZStream[Any, Throwable, Byte] =
ZStream
.async[Any, Throwable, Byte](
emit =>
try {
unsafeAsync(new UnsafeAsync {
override def apply(message: Chunk[Byte], isLast: Boolean): Unit = {
emit(ZIO.succeed(message))
if (isLast) {
emit(ZIO.fail(None))
}
}
override def fail(cause: Throwable): Unit =
emit(ZIO.fail(Some(cause)))
})
unsafeAsync(new UnsafeAsync.Streaming(emit))
} catch {
case e: Throwable => emit(ZIO.fail(Option(e)))
},
streamBufferSize,
bufferSize(4096),
)

// No need to create a large buffer when we know the response is small
private[this] def streamBufferSize: Int = {
private[this] def bufferSize(maxSize: Int): Int = {
val cl = knownContentLength.getOrElse(4096L)
if (cl <= 16L) 16
else if (cl >= 4096) 4096
else if (cl >= maxSize) maxSize
else Integer.highestOneBit(cl.toInt - 1) << 1 // Round to next power of 2
}

Expand All @@ -188,4 +185,30 @@ object NettyBody extends BodyEncoding {
def apply(message: Chunk[Byte], isLast: Boolean): Unit
def fail(cause: Throwable): Unit
}

private[zio] object UnsafeAsync {
private val FailNone = Exit.fail(None)

final case class Aggregating(bufferInitialSize: Int)(callback: Task[Chunk[Byte]] => Unit)(implicit trace: Trace)
extends UnsafeAsync {

def apply(message: Chunk[Byte], isLast: Boolean): Unit = {
assert(isLast)
callback(ZIO.succeed(message))
}

def fail(cause: Throwable): Unit =
callback(ZIO.fail(cause))
}

final class Streaming(emit: ZStream.Emit[Any, Throwable, Byte, Unit])(implicit trace: Trace) extends UnsafeAsync {
def apply(message: Chunk[Byte], isLast: Boolean): Unit = {
if (message.nonEmpty) emit(ZIO.succeed(message))
if (isLast) emit(FailNone)
}

def fail(cause: Throwable): Unit =
emit(ZIO.fail(Some(cause)))
}
}
}
9 changes: 2 additions & 7 deletions zio-http/shared/src/main/scala/zio/http/Body.scala
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,6 @@ object Body {
def fromSocketApp(app: WebSocketApp[Any]): WebsocketBody =
WebsocketBody(app)

private[zio] trait UnsafeWriteable extends Body

private[zio] trait UnsafeBytes extends Body {
private[zio] def unsafeAsArray(implicit unsafe: Unsafe): Array[Byte]
}
Expand All @@ -352,7 +350,7 @@ object Body {
* Helper to create empty Body
*/

private[zio] object EmptyBody extends Body with UnsafeWriteable with UnsafeBytes {
private[zio] object EmptyBody extends Body with UnsafeBytes {

override def asArray(implicit trace: Trace): Task[Array[Byte]] = zioEmptyArray

Expand Down Expand Up @@ -383,7 +381,6 @@ object Body {
override val mediaType: Option[MediaType] = None,
override val boundary: Option[Boundary] = None,
) extends Body
with UnsafeWriteable
with UnsafeBytes { self =>

override def asArray(implicit trace: Trace): Task[Array[Byte]] = ZIO.succeed(data.toArray)
Expand Down Expand Up @@ -414,7 +411,6 @@ object Body {
override val mediaType: Option[MediaType] = None,
override val boundary: Option[Boundary] = None,
) extends Body
with UnsafeWriteable
with UnsafeBytes { self =>

override def asArray(implicit trace: Trace): Task[Array[Byte]] = ZIO.succeed(data)
Expand Down Expand Up @@ -446,8 +442,7 @@ object Body {
fileSize: Long,
override val mediaType: Option[MediaType] = None,
override val boundary: Option[Boundary] = None,
) extends Body
with UnsafeWriteable {
) extends Body {

override def asArray(implicit trace: Trace): Task[Array[Byte]] = ZIO.attemptBlocking {
Files.readAllBytes(file.toPath)
Expand Down
Loading