Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the Client's performance for non-stream bodies #2919

Merged
merged 17 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,40 @@ jobs:
name: Jmh_Main_CachedDateHeaderBenchmark
path: Main_CachedDateHeaderBenchmark.txt

Jmh_ClientBenchmark:
name: Jmh ClientBenchmark
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
strategy:
matrix:
os: [ubuntu-latest]
scala: [2.13.14]
java: [temurin@8]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
with:
path: zio-http

- uses: actions/setup-java@v2
with:
distribution: temurin
java-version: 11

- name: Benchmark_Main
id: Benchmark_Main
env:
GITHUB_TOKEN: ${{secrets.ACTIONS_PAT}}
run: |
cd zio-http
sed -i -e '$aaddSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.7")' project/plugins.sbt
cat > Main_ClientBenchmark.txt
sbt -no-colors -v "zioHttpBenchmarks/jmh:run -i 3 -wi 3 -f1 -t1 ClientBenchmark" | grep -e "thrpt" -e "avgt" >> ../Main_ClientBenchmark.txt

- uses: actions/upload-artifact@v3
with:
name: Jmh_Main_ClientBenchmark
path: Main_ClientBenchmark.txt

Jmh_CookieDecodeBenchmark:
name: Jmh CookieDecodeBenchmark
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
Expand Down Expand Up @@ -627,7 +661,7 @@ jobs:

Jmh_cache:
name: Cache Jmh benchmarks
needs: [Jmh_CachedDateHeaderBenchmark, Jmh_CookieDecodeBenchmark, Jmh_EndpointBenchmark, Jmh_HttpCollectEval, Jmh_HttpCombineEval, Jmh_HttpNestedFlatMapEval, Jmh_HttpRouteTextPerf, Jmh_ProbeContentTypeBenchmark, Jmh_SchemeDecodeBenchmark, Jmh_ServerInboundHandlerBenchmark, Jmh_UtilBenchmark]
needs: [Jmh_CachedDateHeaderBenchmark, Jmh_ClientBenchmark, Jmh_CookieDecodeBenchmark, Jmh_EndpointBenchmark, Jmh_HttpCollectEval, Jmh_HttpCombineEval, Jmh_HttpNestedFlatMapEval, Jmh_HttpRouteTextPerf, Jmh_ProbeContentTypeBenchmark, Jmh_SchemeDecodeBenchmark, Jmh_ServerInboundHandlerBenchmark, Jmh_UtilBenchmark]
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
strategy:
matrix:
Expand All @@ -643,6 +677,13 @@ jobs:
- name: Format_Main_CachedDateHeaderBenchmark
run: cat Main_CachedDateHeaderBenchmark.txt >> Main_benchmarks.txt

- uses: actions/download-artifact@v3
with:
name: Jmh_Main_ClientBenchmark

- name: Format_Main_ClientBenchmark
run: cat Main_ClientBenchmark.txt >> Main_benchmarks.txt

- uses: actions/download-artifact@v3
with:
name: Jmh_Main_CookieDecodeBenchmark
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package zhttp.benchmarks

import java.util.concurrent.TimeUnit

import scala.annotation.nowarn

import zio._

import zio.http._

import org.openjdk.jmh.annotations._

@nowarn
@State(org.openjdk.jmh.annotations.Scope.Benchmark)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
@Warmup(iterations = 3, time = 3)
@Measurement(iterations = 3, time = 3)
@Fork(1)
class ClientBenchmark {
private val random = scala.util.Random
random.setSeed(42)

private implicit val unsafe: Unsafe = Unsafe.unsafe(identity)

@Param(Array("small", "large"))
var path: String = _

private val smallString = "Hello!".getBytes
private val largeString = random.alphanumeric.take(10000).mkString.getBytes

private val smallRequest = Request(url = url"http://0.0.0.0:8080/small")
private val largeRequest = Request(url = url"http://0.0.0.0:8080/large")

private val smallResponse = Response(status = Status.Ok, body = Body.fromArray(smallString))
private val largeResponse = Response(status = Status.Ok, body = Body.fromArray(largeString))

private val smallRoute = Route.route(Method.GET / "small")(handler(smallResponse))
private val largeRoute = Route.route(Method.GET / "large")(handler(largeResponse))

private val shutdownResponse = Response.text("shutting down")

private def shutdownRoute(shutdownSignal: Promise[Nothing, Unit]) =
Route.route(Method.GET / "shutdown")(handler(shutdownSignal.succeed(()).as(shutdownResponse)))

private def http(shutdownSignal: Promise[Nothing, Unit]) =
Routes(smallRoute, largeRoute, shutdownRoute(shutdownSignal))

private val rtm = Runtime.unsafe.fromLayer(ZClient.default)
private val runtime = rtm.unsafe

private def run(f: RIO[Client, Any]): Any = runtime.run(f).getOrThrow()

@Setup(Level.Trial)
def setup(): Unit = {
val startServer: Task[Unit] = (for {
shutdownSignal <- Promise.make[Nothing, Unit]
fiber <- Server.serve(http(shutdownSignal)).fork
_ <- shutdownSignal.await *> fiber.interrupt
} yield ()).provideLayer(Server.default)

val waitForServerStarted: Task[Unit] = (for {
client <- ZIO.service[Client]
_ <- client.request(smallRequest)
} yield ()).provide(ZClient.default, zio.Scope.default)

run(startServer.forkDaemon *> waitForServerStarted.retry(Schedule.fixed(1.second)))
}

@TearDown(Level.Trial)
def tearDown(): Unit = {
val stopServer = (for {
client <- ZIO.service[Client]
_ <- client.request(Request(url = url"http://localhost:8080/shutdown"))
} yield ()).provide(ZClient.default, zio.Scope.default)
run(stopServer)
rtm.shutdown0()
}

@Benchmark
@OperationsPerInvocation(100)
def zhttpChunkBenchmark(): Any = run {
val req = if (path == "small") smallRequest else largeRequest
ZIO.serviceWithZIO[Client] { client =>
ZIO.scoped(client.request(req).flatMap(_.body.asChunk)).repeatN(100)
}
}

@Benchmark
@OperationsPerInvocation(100)
def zhttpStreamToChunkBenchmark(): Any = run {
val req = if (path == "small") smallRequest else largeRequest
ZIO.serviceWithZIO[Client] { client =>
ZIO.scoped(client.request(req).flatMap(_.body.asStream.runCollect)).repeatN(100)
}
}
}
85 changes: 55 additions & 30 deletions zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,50 @@ package zio.http.netty

import java.io.IOException

import scala.collection.mutable

import zio.Chunk
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Chunk, ChunkBuilder}

import zio.http.netty.AsyncBodyReader.State
import zio.http.netty.NettyBody.UnsafeAsync

import io.netty.buffer.ByteBufUtil
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http.{HttpContent, LastHttpContent}

abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](true) {
import zio.http.netty.AsyncBodyReader._

private var state: State = State.Buffering
private val buffer = new mutable.ArrayBuilder.ofByte()
private var previousAutoRead: Boolean = false
private var readingDone: Boolean = false
private var ctx: ChannelHandlerContext = _

private var state: State = State.Buffering
private val buffer: ChunkBuilder[(Chunk[Byte], Boolean)] = ChunkBuilder.make[(Chunk[Byte], Boolean)]()
private var previousAutoRead: Boolean = false
private var ctx: ChannelHandlerContext = _
private def result(buffer: mutable.ArrayBuilder.ofByte): Chunk[Byte] = {
val arr = buffer.result()
Chunk.ByteArray(arr, 0, arr.length)
}

private[zio] def connect(callback: UnsafeAsync): Unit = {
val buffer0 = buffer // Avoid reading it from the heap in the synchronized block
this.synchronized {
state match {
case State.Buffering =>
val result: Chunk[(Chunk[Byte], Boolean)] = buffer.result()
val readingDone: Boolean = result.lastOption match {
case None => false
case Some((_, isLast)) => isLast
}
buffer.clear() // GC

if (ctx.channel.isOpen || readingDone) {
state = State.Direct(callback)
result.foreach { case (chunk, isLast) =>
callback(chunk, isLast)
state = State.Direct(callback)

if (readingDone) {
callback(result(buffer0), isLast = true)
} else if (ctx.channel().isOpen) {
callback match {
case UnsafeAsync.Aggregating(bufSize) => buffer.sizeHint(bufSize)
case cb => cb(result(buffer0), isLast = false)
}
ctx.read(): Unit
} else {
throw new IllegalStateException("Attempting to read from a closed channel, which will never finish")
}

case State.Direct(_) =>
case _ =>
throw new IllegalStateException("Cannot connect twice")
}
}
Expand All @@ -76,22 +81,36 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
ctx: ChannelHandlerContext,
msg: HttpContent,
): Unit = {
val isLast = msg.isInstanceOf[LastHttpContent]
val chunk = Chunk.fromArray(ByteBufUtil.getBytes(msg.content()))
val buffer0 = buffer // Avoid reading it from the heap in the synchronized block

this.synchronized {
val isLast = msg.isInstanceOf[LastHttpContent]
val content = ByteBufUtil.getBytes(msg.content())

state match {
case State.Buffering =>
buffer += ((chunk, isLast))
case State.Direct(callback) =>
callback(chunk, isLast)
ctx.read()
case State.Buffering =>
// `connect` method hasn't been called yet, add all incoming content to the buffer
buffer0.addAll(content)
case State.Direct(callback) if isLast && buffer0.knownSize == 0 =>
// Buffer is empty, we can just use the array directly
callback(Chunk.fromArray(content), isLast = true)
case State.Direct(callback: UnsafeAsync.Aggregating) =>
// We're aggregating the full response, only call the callback on the last message
buffer0.addAll(content)
if (isLast) callback(result(buffer0), isLast = true)
case State.Direct(callback) =>
// We're streaming, emit chunks as they come
callback(Chunk.fromArray(content), isLast)
}
}

if (isLast) {
ctx.channel().pipeline().remove(this)
}: Unit
if (isLast) {
readingDone = true
ctx.channel().pipeline().remove(this)
} else {
ctx.read()
}
()
}
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
Expand Down Expand Up @@ -125,4 +144,10 @@ object AsyncBodyReader {

final case class Direct(callback: UnsafeAsync) extends State
}

// For Scala 2.12. In Scala 2.13+, the methods directly implemented on ArrayBuilder[Byte] are selected over syntax.
private implicit class ByteArrayBuilderOps[A](private val self: mutable.ArrayBuilder[Byte]) extends AnyVal {
def addAll(as: Array[Byte]): Unit = self ++= as
def knownSize: Int = -1
}
}
Loading
Loading