From a33196351e2523d2330255de6017152c61c54f6e Mon Sep 17 00:00:00 2001
From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com>
Date: Tue, 18 Jun 2024 22:27:28 +1000
Subject: [PATCH] Fix memory leak in netty connection pool (#2907)

* Fix memory leak in netty connection pool

* Use a forked effect to monitor the close future for client requests

* Use Netty's future listener again

* fmt

* Remove suspendSucceed
---
 .../zio/http/netty/AsyncBodyReader.scala      |   1 +
 .../main/scala/zio/http/netty/NettyBody.scala |  10 +-
 .../http/netty/client/NettyClientDriver.scala | 231 +++++++++---------
 .../client/NettyConnectionPoolSpec.scala      |  33 ++-
 4 files changed, 156 insertions(+), 119 deletions(-)

diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala b/zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
index db11dfa77b..fc97550cd7 100644
--- a/zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
+++ b/zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
@@ -44,6 +44,7 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
             case None              => false
             case Some((_, isLast)) => isLast
           }
+          buffer.clear() // GC
 
           if (ctx.channel.isOpen || readingDone) {
             state = State.Direct(callback)
diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala b/zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala
index ee58652262..86afacdfa1 100644
--- a/zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala
+++ b/zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala
@@ -161,9 +161,17 @@ object NettyBody extends BodyEncoding {
             } catch {
               case e: Throwable => emit(ZIO.fail(Option(e)))
             },
-          4096,
+          streamBufferSize,
         )
 
+    // No need to create a large buffer when we know the response is small
+    private[this] def streamBufferSize: Int = {
+      val cl = knownContentLength.getOrElse(4096L)
+      if (cl <= 16L) 16
+      else if (cl >= 4096) 4096
+      else Integer.highestOneBit(cl.toInt - 1) << 1 // Round to next power of 2
+    }
+
     override def isComplete: Boolean = false
 
     override def isEmpty: Boolean = false
diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala
index 0a0b4e44a5..f5d815961d 100644
--- a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala
+++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala
@@ -16,8 +16,6 @@
 
 package zio.http.netty.client
 
-import scala.collection.mutable
-
 import zio._
 import zio.stacktracer.TracingImplicits.disableAutoTrace
 
@@ -28,10 +26,11 @@ import zio.http.netty._
 import zio.http.netty.model.Conversions
 import zio.http.netty.socket.NettySocketProtocol
 
-import io.netty.channel.{Channel, ChannelFactory, ChannelFuture, ChannelHandler, EventLoopGroup}
+import io.netty.channel.{Channel, ChannelFactory, ChannelFuture, EventLoopGroup}
 import io.netty.handler.codec.PrematureChannelClosureException
 import io.netty.handler.codec.http.websocketx.{WebSocketClientProtocolHandler, WebSocketFrame => JWebSocketFrame}
-import io.netty.handler.codec.http.{FullHttpRequest, HttpObjectAggregator, HttpRequest}
+import io.netty.handler.codec.http.{FullHttpRequest, HttpObjectAggregator}
+import io.netty.util.concurrent.GenericFutureListener
 
 final case class NettyClientDriver private[netty] (
   channelFactory: ChannelFactory[Channel],
@@ -50,129 +49,123 @@ final case class NettyClientDriver private[netty] (
     enableKeepAlive: Boolean,
     createSocketApp: () => WebSocketApp[Any],
     webSocketConfig: WebSocketConfig,
-  )(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface] = {
-    val f = NettyRequestEncoder.encode(req).flatMap { jReq =>
-      for {
-        _     <- Scope.addFinalizer {
-          ZIO.attempt {
-            jReq match {
-              case fullRequest: FullHttpRequest =>
-                if (fullRequest.refCnt() > 0)
-                  fullRequest.release(fullRequest.refCnt())
-              case _                            =>
-            }
-          }.ignore
-        }
-        queue <- Queue.unbounded[WebSocketChannelEvent]
-        nettyChannel     = NettyChannel.make[JWebSocketFrame](channel)
-        webSocketChannel = WebSocketChannel.make(nettyChannel, queue)
-        app              = createSocketApp()
-        _ <- app.handler.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped
-      } yield {
-        val pipeline                              = channel.pipeline()
-        val toRemove: mutable.Set[ChannelHandler] = new mutable.HashSet[ChannelHandler]()
-
-        if (location.scheme.isWebSocket) {
-          val httpObjectAggregator = new HttpObjectAggregator(Int.MaxValue)
-          val inboundHandler       = new WebSocketClientInboundHandler(onResponse, onComplete)
-
-          pipeline.addLast(Names.HttpObjectAggregator, httpObjectAggregator)
-          pipeline.addLast(Names.ClientInboundHandler, inboundHandler)
-
-          toRemove.add(httpObjectAggregator)
-          toRemove.add(inboundHandler)
-
-          val headers = Conversions.headersToNetty(req.headers)
-          val config  = NettySocketProtocol
-            .clientBuilder(app.customConfig.getOrElse(webSocketConfig))
-            .customHeaders(headers)
-            .webSocketUri(req.url.encode)
-            .build()
-
-          // Handles the heavy lifting required to upgrade the connection to a WebSocket connection
-
-          val webSocketClientProtocol = new WebSocketClientProtocolHandler(config)
-          val webSocket               = new WebSocketAppHandler(nettyRuntime, queue, Some(onComplete))
+  )(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface] =
+    if (location.scheme.isWebSocket)
+      requestWebsocket(channel, req, onResponse, onComplete, createSocketApp, webSocketConfig)
+    else
+      requestHttp(channel, req, onResponse, onComplete, enableKeepAlive)
 
-          pipeline.addLast(Names.WebSocketClientProtocolHandler, webSocketClientProtocol)
-          pipeline.addLast(Names.WebSocketHandler, webSocket)
-
-          toRemove.add(webSocketClientProtocol)
-          toRemove.add(webSocket)
-
-          pipeline.fireChannelRegistered()
-          pipeline.fireChannelActive()
-
-          new ChannelInterface {
-            override def resetChannel: ZIO[Any, Throwable, ChannelState] =
-              ZIO.succeed(
-                ChannelState.Invalid,
-              ) // channel becomes invalid - reuse of websocket channels not supported currently
-
-            override def interrupt: ZIO[Any, Throwable, Unit] =
-              NettyFutureExecutor.executed(channel.disconnect())
-          }
-        } else {
-          val clientInbound =
-            new ClientInboundHandler(
-              nettyRuntime,
-              req,
-              jReq,
-              onResponse,
-              onComplete,
-              enableKeepAlive,
-            )
-
-          pipeline.addLast(Names.ClientInboundHandler, clientInbound)
-          toRemove.add(clientInbound)
-
-          val clientFailureHandler = new ClientFailureHandler(onResponse, onComplete)
-          pipeline.addLast(Names.ClientFailureHandler, clientFailureHandler)
-          toRemove.add(clientFailureHandler)
-
-          pipeline.fireChannelRegistered()
-          pipeline.fireUserEventTriggered(ClientInboundHandler.SendRequest)
-
-          val frozenToRemove = toRemove.toSet
-
-          new ChannelInterface {
-            override def resetChannel: ZIO[Any, Throwable, ChannelState] =
-              ZIO.attempt {
-                frozenToRemove.foreach(pipeline.remove)
-                ChannelState.Reusable // channel can be reused
-              }
-
-            override def interrupt: ZIO[Any, Throwable, Unit] =
-              NettyFutureExecutor.executed(channel.disconnect())
+  private def requestHttp(
+    channel: Channel,
+    req: Request,
+    onResponse: Promise[Throwable, Response],
+    onComplete: Promise[Throwable, ChannelState],
+    enableKeepAlive: Boolean,
+  )(implicit trace: Trace): RIO[Scope, ChannelInterface] =
+    NettyRequestEncoder
+      .encode(req)
+      .tapSome { case fullReq: FullHttpRequest =>
+        Scope.addFinalizer {
+          ZIO.succeed {
+            val refCount = fullReq.refCnt()
+            if (refCount > 0) fullReq.release(refCount) else ()
           }
         }
       }
-    }
+      .map { jReq =>
+        val closeListener: GenericFutureListener[ChannelFuture] = { (_: ChannelFuture) =>
+          // If onComplete was already set, it means another fiber is already in the process of fulfilling the promises
+          // so we don't need to fulfill `onResponse`
+          nettyRuntime.unsafeRunSync {
+            onComplete.interrupt && onResponse.fail(NettyClientDriver.PrematureChannelClosure)
+          }(Unsafe.unsafe, trace): Unit
+        }
 
-    f.ensuring {
-      // If the channel was closed and the promises were not completed, this will lead to the request hanging so we need
-      // to listen to the close future and complete the promises
-      ZIO.unless(location.scheme.isWebSocket) {
-        ZIO.succeedUnsafe { implicit u =>
-          channel.closeFuture().addListener { (_: ChannelFuture) =>
-            // If onComplete was already set, it means another fiber is already in the process of fulfilling the promises
-            // so we don't need to fulfill `onResponse`
-            nettyRuntime.unsafeRunSync {
-              ZIO
-                .whenZIO(onComplete.interrupt)(
-                  onResponse.fail(
-                    new PrematureChannelClosureException(
-                      "Channel closed while executing the request. This is likely caused due to a client connection misconfiguration",
-                    ),
-                  ),
-                )
-                .unit
+        val pipeline = channel.pipeline()
+
+        pipeline.addLast(
+          Names.ClientInboundHandler,
+          new ClientInboundHandler(nettyRuntime, req, jReq, onResponse, onComplete, enableKeepAlive),
+        )
+
+        pipeline.addLast(
+          Names.ClientFailureHandler,
+          new ClientFailureHandler(onResponse, onComplete),
+        )
+
+        pipeline
+          .fireChannelRegistered()
+          .fireUserEventTriggered(ClientInboundHandler.SendRequest)
+
+        channel.closeFuture().addListener(closeListener)
+        new ChannelInterface {
+          override def resetChannel: ZIO[Any, Throwable, ChannelState] = {
+            ZIO.attempt {
+              channel.closeFuture().removeListener(closeListener)
+              pipeline.remove(Names.ClientInboundHandler)
+              pipeline.remove(Names.ClientFailureHandler)
+              ChannelState.Reusable // channel can be reused
             }
           }
+
+          override def interrupt: ZIO[Any, Throwable, Unit] =
+            ZIO.suspendSucceed {
+              channel.closeFuture().removeListener(closeListener)
+              NettyFutureExecutor.executed(channel.disconnect())
+            }
         }
       }
-    }
 
+  private def requestWebsocket(
+    channel: Channel,
+    req: Request,
+    onResponse: Promise[Throwable, Response],
+    onComplete: Promise[Throwable, ChannelState],
+    createSocketApp: () => WebSocketApp[Any],
+    webSocketConfig: WebSocketConfig,
+  )(implicit trace: Trace): RIO[Scope, ChannelInterface] = {
+    for {
+      queue <- Queue.unbounded[WebSocketChannelEvent]
+      nettyChannel     = NettyChannel.make[JWebSocketFrame](channel)
+      webSocketChannel = WebSocketChannel.make(nettyChannel, queue)
+      app              = createSocketApp()
+      _ <- app.handler.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped
+    } yield {
+      val pipeline = channel.pipeline()
+
+      val httpObjectAggregator = new HttpObjectAggregator(Int.MaxValue)
+      val inboundHandler       = new WebSocketClientInboundHandler(onResponse, onComplete)
+
+      pipeline.addLast(Names.HttpObjectAggregator, httpObjectAggregator)
+      pipeline.addLast(Names.ClientInboundHandler, inboundHandler)
+
+      val headers = Conversions.headersToNetty(req.headers)
+      val config  = NettySocketProtocol
+        .clientBuilder(app.customConfig.getOrElse(webSocketConfig))
+        .customHeaders(headers)
+        .webSocketUri(req.url.encode)
+        .build()
+
+      // Handles the heavy lifting required to upgrade the connection to a WebSocket connection
+
+      val webSocketClientProtocol = new WebSocketClientProtocolHandler(config)
+      val webSocket               = new WebSocketAppHandler(nettyRuntime, queue, Some(onComplete))
+
+      pipeline.addLast(Names.WebSocketClientProtocolHandler, webSocketClientProtocol)
+      pipeline.addLast(Names.WebSocketHandler, webSocket)
+
+      pipeline.fireChannelRegistered()
+      pipeline.fireChannelActive()
+
+      new ChannelInterface {
+        override def resetChannel: ZIO[Any, Throwable, ChannelState] =
+          // channel becomes invalid - reuse of websocket channels not supported currently
+          Exit.succeed(ChannelState.Invalid)
+
+        override def interrupt: ZIO[Any, Throwable, Unit] =
+          NettyFutureExecutor.executed(channel.disconnect())
+      }
+    }
   }
 
   override def createConnectionPool(dnsResolver: DnsResolver, config: ConnectionPoolConfig)(implicit
@@ -196,4 +189,8 @@ object NettyClientDriver {
       } yield NettyClientDriver(channelFactory, eventLoopGroup, nettyRuntime)
     }
 
+  private val PrematureChannelClosure = new PrematureChannelClosureException(
+    "Channel closed while executing the request. This is likely caused due to a client connection misconfiguration",
+  )
+
 }
diff --git a/zio-http/jvm/src/test/scala/zio/http/netty/client/NettyConnectionPoolSpec.scala b/zio-http/jvm/src/test/scala/zio/http/netty/client/NettyConnectionPoolSpec.scala
index b7fe1b575a..e8f841aa7d 100644
--- a/zio-http/jvm/src/test/scala/zio/http/netty/client/NettyConnectionPoolSpec.scala
+++ b/zio-http/jvm/src/test/scala/zio/http/netty/client/NettyConnectionPoolSpec.scala
@@ -249,6 +249,37 @@ object NettyConnectionPoolSpec extends HttpRunnableSpec {
     serverTestLayer,
   ) @@ withLiveClock @@ nonFlaky(10)
 
+  private def connectionPoolIssuesSpec = {
+    suite("ConnectionPoolIssuesSpec")(
+      test("Reusing connections doesn't cause memory leaks") {
+        Random.nextString(1024 * 1024).flatMap { text =>
+          val resp = Response.text(text)
+          Handler
+            .succeed(resp)
+            .toRoutes
+            .deployAndRequest { client =>
+              ZIO.foreachParDiscard(0 to 10) { _ =>
+                ZIO
+                  .scoped[Any](client.request(Request()).flatMap(_.body.asArray))
+                  .repeatN(200)
+              }
+            }(Request())
+            .as(assertCompletes)
+        }
+      },
+    )
+  }.provide(
+    ZLayer(appKeepAliveEnabled.unit),
+    DynamicServer.live,
+    serverTestLayer,
+    Client.customized,
+    ZLayer.succeed(ZClient.Config.default.dynamicConnectionPool(1, 512, 60.seconds)),
+    NettyClientDriver.live,
+    DnsResolver.default,
+    ZLayer.succeed(NettyConfig.defaultWithFastShutdown),
+    Scope.default,
+  )
+
   def connectionPoolSpec: Spec[Any, Throwable] =
     suite("ConnectionPool")(
       suite("fixed")(
@@ -310,7 +341,7 @@ object NettyConnectionPoolSpec extends HttpRunnableSpec {
     )
 
   override def spec: Spec[Scope, Throwable] = {
-    connectionPoolSpec @@ sequential @@ withLiveClock
+    (connectionPoolSpec + connectionPoolIssuesSpec) @@ sequential @@ withLiveClock
   }
 
 }