Skip to content

Commit

Permalink
Fix memory leak in netty connection pool (#2907)
Browse files Browse the repository at this point in the history
* Fix memory leak in netty connection pool

* Use a forked effect to monitor the close future for client requests

* Use Netty's future listener again

* fmt

* Remove suspendSucceed
  • Loading branch information
kyri-petrou authored Jun 18, 2024
1 parent 36df30f commit a331963
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
case None => false
case Some((_, isLast)) => isLast
}
buffer.clear() // GC

if (ctx.channel.isOpen || readingDone) {
state = State.Direct(callback)
Expand Down
10 changes: 9 additions & 1 deletion zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,17 @@ object NettyBody extends BodyEncoding {
} catch {
case e: Throwable => emit(ZIO.fail(Option(e)))
},
4096,
streamBufferSize,
)

// No need to create a large buffer when we know the response is small
private[this] def streamBufferSize: Int = {
val cl = knownContentLength.getOrElse(4096L)
if (cl <= 16L) 16
else if (cl >= 4096) 4096
else Integer.highestOneBit(cl.toInt - 1) << 1 // Round to next power of 2
}

override def isComplete: Boolean = false

override def isEmpty: Boolean = false
Expand Down
231 changes: 114 additions & 117 deletions zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package zio.http.netty.client

import scala.collection.mutable

import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

Expand All @@ -28,10 +26,11 @@ import zio.http.netty._
import zio.http.netty.model.Conversions
import zio.http.netty.socket.NettySocketProtocol

import io.netty.channel.{Channel, ChannelFactory, ChannelFuture, ChannelHandler, EventLoopGroup}
import io.netty.channel.{Channel, ChannelFactory, ChannelFuture, EventLoopGroup}
import io.netty.handler.codec.PrematureChannelClosureException
import io.netty.handler.codec.http.websocketx.{WebSocketClientProtocolHandler, WebSocketFrame => JWebSocketFrame}
import io.netty.handler.codec.http.{FullHttpRequest, HttpObjectAggregator, HttpRequest}
import io.netty.handler.codec.http.{FullHttpRequest, HttpObjectAggregator}
import io.netty.util.concurrent.GenericFutureListener

final case class NettyClientDriver private[netty] (
channelFactory: ChannelFactory[Channel],
Expand All @@ -50,129 +49,123 @@ final case class NettyClientDriver private[netty] (
enableKeepAlive: Boolean,
createSocketApp: () => WebSocketApp[Any],
webSocketConfig: WebSocketConfig,
)(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface] = {
val f = NettyRequestEncoder.encode(req).flatMap { jReq =>
for {
_ <- Scope.addFinalizer {
ZIO.attempt {
jReq match {
case fullRequest: FullHttpRequest =>
if (fullRequest.refCnt() > 0)
fullRequest.release(fullRequest.refCnt())
case _ =>
}
}.ignore
}
queue <- Queue.unbounded[WebSocketChannelEvent]
nettyChannel = NettyChannel.make[JWebSocketFrame](channel)
webSocketChannel = WebSocketChannel.make(nettyChannel, queue)
app = createSocketApp()
_ <- app.handler.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped
} yield {
val pipeline = channel.pipeline()
val toRemove: mutable.Set[ChannelHandler] = new mutable.HashSet[ChannelHandler]()

if (location.scheme.isWebSocket) {
val httpObjectAggregator = new HttpObjectAggregator(Int.MaxValue)
val inboundHandler = new WebSocketClientInboundHandler(onResponse, onComplete)

pipeline.addLast(Names.HttpObjectAggregator, httpObjectAggregator)
pipeline.addLast(Names.ClientInboundHandler, inboundHandler)

toRemove.add(httpObjectAggregator)
toRemove.add(inboundHandler)

val headers = Conversions.headersToNetty(req.headers)
val config = NettySocketProtocol
.clientBuilder(app.customConfig.getOrElse(webSocketConfig))
.customHeaders(headers)
.webSocketUri(req.url.encode)
.build()

// Handles the heavy lifting required to upgrade the connection to a WebSocket connection

val webSocketClientProtocol = new WebSocketClientProtocolHandler(config)
val webSocket = new WebSocketAppHandler(nettyRuntime, queue, Some(onComplete))
)(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface] =
if (location.scheme.isWebSocket)
requestWebsocket(channel, req, onResponse, onComplete, createSocketApp, webSocketConfig)
else
requestHttp(channel, req, onResponse, onComplete, enableKeepAlive)

pipeline.addLast(Names.WebSocketClientProtocolHandler, webSocketClientProtocol)
pipeline.addLast(Names.WebSocketHandler, webSocket)

toRemove.add(webSocketClientProtocol)
toRemove.add(webSocket)

pipeline.fireChannelRegistered()
pipeline.fireChannelActive()

new ChannelInterface {
override def resetChannel: ZIO[Any, Throwable, ChannelState] =
ZIO.succeed(
ChannelState.Invalid,
) // channel becomes invalid - reuse of websocket channels not supported currently

override def interrupt: ZIO[Any, Throwable, Unit] =
NettyFutureExecutor.executed(channel.disconnect())
}
} else {
val clientInbound =
new ClientInboundHandler(
nettyRuntime,
req,
jReq,
onResponse,
onComplete,
enableKeepAlive,
)

pipeline.addLast(Names.ClientInboundHandler, clientInbound)
toRemove.add(clientInbound)

val clientFailureHandler = new ClientFailureHandler(onResponse, onComplete)
pipeline.addLast(Names.ClientFailureHandler, clientFailureHandler)
toRemove.add(clientFailureHandler)

pipeline.fireChannelRegistered()
pipeline.fireUserEventTriggered(ClientInboundHandler.SendRequest)

val frozenToRemove = toRemove.toSet

new ChannelInterface {
override def resetChannel: ZIO[Any, Throwable, ChannelState] =
ZIO.attempt {
frozenToRemove.foreach(pipeline.remove)
ChannelState.Reusable // channel can be reused
}

override def interrupt: ZIO[Any, Throwable, Unit] =
NettyFutureExecutor.executed(channel.disconnect())
private def requestHttp(
channel: Channel,
req: Request,
onResponse: Promise[Throwable, Response],
onComplete: Promise[Throwable, ChannelState],
enableKeepAlive: Boolean,
)(implicit trace: Trace): RIO[Scope, ChannelInterface] =
NettyRequestEncoder
.encode(req)
.tapSome { case fullReq: FullHttpRequest =>
Scope.addFinalizer {
ZIO.succeed {
val refCount = fullReq.refCnt()
if (refCount > 0) fullReq.release(refCount) else ()
}
}
}
}
.map { jReq =>
val closeListener: GenericFutureListener[ChannelFuture] = { (_: ChannelFuture) =>
// If onComplete was already set, it means another fiber is already in the process of fulfilling the promises
// so we don't need to fulfill `onResponse`
nettyRuntime.unsafeRunSync {
onComplete.interrupt && onResponse.fail(NettyClientDriver.PrematureChannelClosure)
}(Unsafe.unsafe, trace): Unit
}

f.ensuring {
// If the channel was closed and the promises were not completed, this will lead to the request hanging so we need
// to listen to the close future and complete the promises
ZIO.unless(location.scheme.isWebSocket) {
ZIO.succeedUnsafe { implicit u =>
channel.closeFuture().addListener { (_: ChannelFuture) =>
// If onComplete was already set, it means another fiber is already in the process of fulfilling the promises
// so we don't need to fulfill `onResponse`
nettyRuntime.unsafeRunSync {
ZIO
.whenZIO(onComplete.interrupt)(
onResponse.fail(
new PrematureChannelClosureException(
"Channel closed while executing the request. This is likely caused due to a client connection misconfiguration",
),
),
)
.unit
val pipeline = channel.pipeline()

pipeline.addLast(
Names.ClientInboundHandler,
new ClientInboundHandler(nettyRuntime, req, jReq, onResponse, onComplete, enableKeepAlive),
)

pipeline.addLast(
Names.ClientFailureHandler,
new ClientFailureHandler(onResponse, onComplete),
)

pipeline
.fireChannelRegistered()
.fireUserEventTriggered(ClientInboundHandler.SendRequest)

channel.closeFuture().addListener(closeListener)
new ChannelInterface {
override def resetChannel: ZIO[Any, Throwable, ChannelState] = {
ZIO.attempt {
channel.closeFuture().removeListener(closeListener)
pipeline.remove(Names.ClientInboundHandler)
pipeline.remove(Names.ClientFailureHandler)
ChannelState.Reusable // channel can be reused
}
}

override def interrupt: ZIO[Any, Throwable, Unit] =
ZIO.suspendSucceed {
channel.closeFuture().removeListener(closeListener)
NettyFutureExecutor.executed(channel.disconnect())
}
}
}
}

private def requestWebsocket(
channel: Channel,
req: Request,
onResponse: Promise[Throwable, Response],
onComplete: Promise[Throwable, ChannelState],
createSocketApp: () => WebSocketApp[Any],
webSocketConfig: WebSocketConfig,
)(implicit trace: Trace): RIO[Scope, ChannelInterface] = {
for {
queue <- Queue.unbounded[WebSocketChannelEvent]
nettyChannel = NettyChannel.make[JWebSocketFrame](channel)
webSocketChannel = WebSocketChannel.make(nettyChannel, queue)
app = createSocketApp()
_ <- app.handler.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped
} yield {
val pipeline = channel.pipeline()

val httpObjectAggregator = new HttpObjectAggregator(Int.MaxValue)
val inboundHandler = new WebSocketClientInboundHandler(onResponse, onComplete)

pipeline.addLast(Names.HttpObjectAggregator, httpObjectAggregator)
pipeline.addLast(Names.ClientInboundHandler, inboundHandler)

val headers = Conversions.headersToNetty(req.headers)
val config = NettySocketProtocol
.clientBuilder(app.customConfig.getOrElse(webSocketConfig))
.customHeaders(headers)
.webSocketUri(req.url.encode)
.build()

// Handles the heavy lifting required to upgrade the connection to a WebSocket connection

val webSocketClientProtocol = new WebSocketClientProtocolHandler(config)
val webSocket = new WebSocketAppHandler(nettyRuntime, queue, Some(onComplete))

pipeline.addLast(Names.WebSocketClientProtocolHandler, webSocketClientProtocol)
pipeline.addLast(Names.WebSocketHandler, webSocket)

pipeline.fireChannelRegistered()
pipeline.fireChannelActive()

new ChannelInterface {
override def resetChannel: ZIO[Any, Throwable, ChannelState] =
// channel becomes invalid - reuse of websocket channels not supported currently
Exit.succeed(ChannelState.Invalid)

override def interrupt: ZIO[Any, Throwable, Unit] =
NettyFutureExecutor.executed(channel.disconnect())
}
}
}

override def createConnectionPool(dnsResolver: DnsResolver, config: ConnectionPoolConfig)(implicit
Expand All @@ -196,4 +189,8 @@ object NettyClientDriver {
} yield NettyClientDriver(channelFactory, eventLoopGroup, nettyRuntime)
}

private val PrematureChannelClosure = new PrematureChannelClosureException(
"Channel closed while executing the request. This is likely caused due to a client connection misconfiguration",
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,37 @@ object NettyConnectionPoolSpec extends HttpRunnableSpec {
serverTestLayer,
) @@ withLiveClock @@ nonFlaky(10)

private def connectionPoolIssuesSpec = {
suite("ConnectionPoolIssuesSpec")(
test("Reusing connections doesn't cause memory leaks") {
Random.nextString(1024 * 1024).flatMap { text =>
val resp = Response.text(text)
Handler
.succeed(resp)
.toRoutes
.deployAndRequest { client =>
ZIO.foreachParDiscard(0 to 10) { _ =>
ZIO
.scoped[Any](client.request(Request()).flatMap(_.body.asArray))
.repeatN(200)
}
}(Request())
.as(assertCompletes)
}
},
)
}.provide(
ZLayer(appKeepAliveEnabled.unit),
DynamicServer.live,
serverTestLayer,
Client.customized,
ZLayer.succeed(ZClient.Config.default.dynamicConnectionPool(1, 512, 60.seconds)),
NettyClientDriver.live,
DnsResolver.default,
ZLayer.succeed(NettyConfig.defaultWithFastShutdown),
Scope.default,
)

def connectionPoolSpec: Spec[Any, Throwable] =
suite("ConnectionPool")(
suite("fixed")(
Expand Down Expand Up @@ -310,7 +341,7 @@ object NettyConnectionPoolSpec extends HttpRunnableSpec {
)

override def spec: Spec[Scope, Throwable] = {
connectionPoolSpec @@ sequential @@ withLiveClock
(connectionPoolSpec + connectionPoolIssuesSpec) @@ sequential @@ withLiveClock
}

}

0 comments on commit a331963

Please sign in to comment.