diff --git a/zio-http/src/main/scala/zio/http/netty/NettyBodyWriter.scala b/zio-http/src/main/scala/zio/http/netty/NettyBodyWriter.scala index 3b23123c1a..1fd1b23501 100644 --- a/zio-http/src/main/scala/zio/http/netty/NettyBodyWriter.scala +++ b/zio-http/src/main/scala/zio/http/netty/NettyBodyWriter.scala @@ -29,7 +29,9 @@ import io.netty.channel._ import io.netty.handler.codec.http.{DefaultHttpContent, LastHttpContent} object NettyBodyWriter { - def writeAndFlush(body: Body, ctx: ChannelHandlerContext)(implicit trace: Trace): Option[Task[Unit]] = + def writeAndFlush(body: Body, contentLength: Option[Long], ctx: ChannelHandlerContext)(implicit + trace: Trace, + ): Option[Task[Unit]] = body match { case body: ByteBufBody => ctx.write(body.byteBuf) @@ -66,14 +68,44 @@ object NettyBodyWriter { None case StreamBody(stream, _, _) => Some( - stream.chunks.mapZIO { bytes => - NettyFutureExecutor.executed { - ctx.writeAndFlush(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.toArray))) - } - }.runDrain.zipRight { - NettyFutureExecutor.executed { - ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) - } + contentLength match { + case Some(length) => + stream.chunks + .runFoldZIO(length) { (remaining, bytes) => + remaining - bytes.size match { + case 0L => + NettyFutureExecutor.executed { + // Flushes the last body content and LastHttpContent together to avoid race conditions. + ctx.write(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.toArray))) + ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) + }.as(0L) + + case n => + NettyFutureExecutor.executed { + ctx.writeAndFlush(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.toArray))) + }.as(n) + } + } + .flatMap { + case 0L => ZIO.unit + case remaining => + val actualLength = length - remaining + ZIO.logWarning(s"Expected Content-Length of $length, but sent $actualLength bytes") *> + NettyFutureExecutor.executed { + ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) + } + } + + case None => + stream.chunks.mapZIO { bytes => + NettyFutureExecutor.executed { + ctx.writeAndFlush(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.toArray))) + } + }.runDrain.zipRight { + NettyFutureExecutor.executed { + ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) + } + } }, ) case ChunkBody(data, _, _) => diff --git a/zio-http/src/main/scala/zio/http/netty/NettyResponseEncoder.scala b/zio-http/src/main/scala/zio/http/netty/NettyResponseEncoder.scala index f5e1cc14cf..0e6fcc828f 100644 --- a/zio-http/src/main/scala/zio/http/netty/NettyResponseEncoder.scala +++ b/zio-http/src/main/scala/zio/http/netty/NettyResponseEncoder.scala @@ -39,10 +39,6 @@ private[zio] object NettyResponseEncoder { fastEncode(response, bytes) } else { val jHeaders = Conversions.headersToNetty(response.headers) - // Prevent client from closing connection before server writes EMPTY_LAST_CONTENT. - if (response.body.isInstanceOf[Body.StreamBody]) { - jHeaders.remove(HttpHeaderNames.CONTENT_LENGTH) - } val jStatus = Conversions.statusToNetty(response.status) val hasContentLength = jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH) if (!hasContentLength) jHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) diff --git a/zio-http/src/main/scala/zio/http/netty/client/ClientInboundHandler.scala b/zio-http/src/main/scala/zio/http/netty/client/ClientInboundHandler.scala index 4f57b773c1..a4ee256aa9 100644 --- a/zio-http/src/main/scala/zio/http/netty/client/ClientInboundHandler.scala +++ b/zio-http/src/main/scala/zio/http/netty/client/ClientInboundHandler.scala @@ -55,7 +55,7 @@ final class ClientInboundHandler( ctx.writeAndFlush(fullRequest) case _: HttpRequest => ctx.write(jReq) - NettyBodyWriter.writeAndFlush(req.body, ctx).foreach { effect => + NettyBodyWriter.writeAndFlush(req.body, None, ctx).foreach { effect => rtm.run(ctx, NettyRuntime.noopEnsuring)(effect)(Unsafe.unsafe, trace) } } diff --git a/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index 5df9e87a94..f44c549d6a 100644 --- a/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -172,9 +172,13 @@ private[zio] final case class ServerInboundHandler( val jResponse = NettyResponseEncoder.encode(ctx, response, runtime) // setServerTime(time, response, jResponse) ctx.writeAndFlush(jResponse) - if (!jResponse.isInstanceOf[FullHttpResponse]) - NettyBodyWriter.writeAndFlush(response.body, ctx) - else + if (!jResponse.isInstanceOf[FullHttpResponse]) { + val contentLength = jResponse.headers.get(HttpHeaderNames.CONTENT_LENGTH) match { + case null => None + case value => Some(value.toLong) + } + NettyBodyWriter.writeAndFlush(response.body, contentLength, ctx) + } else None } } diff --git a/zio-http/src/test/scala/zio/http/StaticFileServerSpec.scala b/zio-http/src/test/scala/zio/http/StaticFileServerSpec.scala index 445b680413..7499ab3c3f 100644 --- a/zio-http/src/test/scala/zio/http/StaticFileServerSpec.scala +++ b/zio-http/src/test/scala/zio/http/StaticFileServerSpec.scala @@ -100,6 +100,10 @@ object StaticFileServerSpec extends HttpRunnableSpec { val res = resourceOk.run().map(_.status) assertZIO(res)(equalTo(Status.Ok)) }, + test("should have content-length") { + val res = resourceOk.run().map(_.header(Header.ContentLength)) + assertZIO(res)(isSome(equalTo(Header.ContentLength(7L)))) + }, test("should have content") { val res = resourceOk.run().flatMap(_.body.asString) assertZIO(res)(equalTo("foo\nbar")) diff --git a/zio-http/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala b/zio-http/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala index be95cb392d..11a93802d2 100644 --- a/zio-http/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala +++ b/zio-http/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala @@ -1,9 +1,8 @@ package zio.http.netty import zio._ -import zio.test.Assertion._ import zio.test.TestAspect.withLiveClock -import zio.test.{Spec, TestEnvironment, assert} +import zio.test.{Spec, TestEnvironment, assertTrue} import zio.stream.{ZStream, ZStreamAspect} @@ -20,7 +19,8 @@ object NettyStreamBodySpec extends HttpRunnableSpec { handler( http.Response( status = Status.Ok, - // Content-Length header will be removed when the body is a stream + // content length header is important, + // in this case the server will not use chunked transfer encoding even if response is a stream headers = Headers(Header.ContentLength(len)), body = Body.fromStream(streams.next()), ), @@ -77,7 +77,7 @@ object NettyStreamBodySpec extends HttpRunnableSpec { client <- ZIO.service[Client] firstResponse <- makeRequest(client, port) firstResponseBodyReceive <- firstResponse.body.asStream.chunks.mapZIO { chunk => - atLeastOneChunkReceived.succeed(()) *> ZIO.succeed(chunk.asString) + atLeastOneChunkReceived.succeed(()).as(chunk.asString) }.runCollect.fork _ <- firstResponseQueue.offerAll(message.getBytes.toList) _ <- atLeastOneChunkReceived.await @@ -91,23 +91,18 @@ object NettyStreamBodySpec extends HttpRunnableSpec { secondResponse <- makeRequest(client, port) secondResponseBody <- secondResponse.body.asStream.chunks.map(_.asString).runCollect firstResponseBody <- firstResponseBodyReceive.join - - assertFirst = - assert(firstResponse.status)(equalTo(Status.Ok)) && - assert(firstResponse.headers.get(Header.ContentLength))(isNone) && - assert(firstResponse.headers.get(Header.TransferEncoding))( - isSome(equalTo(Header.TransferEncoding.Chunked)), - ) && - assert(firstResponseBody.reduce(_ + _))(equalTo(message)) - - assertSecond = - assert(secondResponse.status)(equalTo(Status.Ok)) && - assert(secondResponse.headers.get(Header.ContentLength))(isNone) && - assert(secondResponse.headers.get(Header.TransferEncoding))( - isSome(equalTo(Header.TransferEncoding.Chunked)), - ) && - assert(secondResponseBody)(equalTo(Chunk(message, ""))) - } yield assertFirst && assertSecond + value = + firstResponse.status == Status.Ok && + // since response has not chunked transfer encoding header we can't guarantee that + // received chunks will be the same as it was transferred. So we need to check the whole body + firstResponseBody.reduce(_ + _) == message && + secondResponse.status == Status.Ok && + secondResponseBody == Chunk(message) + } yield { + assertTrue( + value, + ) + } }, ).provide( singleConnectionClient,