diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketAppHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketAppHandler.scala index e56d5f9dcd..61d8a5dc8b 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketAppHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketAppHandler.scala @@ -36,38 +36,36 @@ import io.netty.handler.codec.http.websocketx.{WebSocketFrame => JWebSocketFrame private[zio] final class WebSocketAppHandler( zExec: NettyRuntime, queue: Queue[WebSocketChannelEvent], + handshakeCompleted: Promise[Nothing, Boolean], onComplete: Option[Promise[Throwable, ChannelState]], )(implicit trace: Trace) extends SimpleChannelInboundHandler[JWebSocketFrame] { implicit private val unsafeClass: Unsafe = Unsafe.unsafe - private def dispatch( - ctx: ChannelHandlerContext, - event: ChannelEvent[JWebSocketFrame], - close: Boolean = false, - ): Unit = { + private def dispatch(event: ChannelEvent[JWebSocketFrame]): Unit = { // IMPORTANT: Offering to the queue must be run synchronously to avoid messages being added in the wrong order // Since the queue is unbounded, this will not block the event loop // TODO: We need to come up with a design that doesn't involve running an effect to offer to the queue - zExec.unsafeRunSync(queue.offer(event.map(frameFromNetty))) - onComplete match { - case Some(promise) if close => promise.unsafe.done(Exit.succeed(ChannelState.Invalid)) - case _ => () - } + val _ = zExec.unsafeRunSync(queue.offer(event.map(frameFromNetty))) } override def channelRead0(ctx: ChannelHandlerContext, msg: JWebSocketFrame): Unit = - dispatch(ctx, ChannelEvent.read(msg)) + dispatch(ChannelEvent.read(msg)) override def channelRegistered(ctx: ChannelHandlerContext): Unit = - dispatch(ctx, ChannelEvent.registered) + dispatch(ChannelEvent.registered) - override def channelUnregistered(ctx: ChannelHandlerContext): Unit = - dispatch(ctx, ChannelEvent.unregistered, close = true) + override def channelUnregistered(ctx: ChannelHandlerContext): Unit = { + dispatch(ChannelEvent.unregistered) + onComplete match { + case Some(promise) => promise.unsafe.done(Exit.succeed(ChannelState.Invalid)) + case None => () + } + } override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - dispatch(ctx, ChannelEvent.exceptionCaught(cause)) + dispatch(ChannelEvent.exceptionCaught(cause)) onComplete match { case Some(promise) => promise.unsafe.done(Exit.fail(cause)) case None => () @@ -77,9 +75,11 @@ private[zio] final class WebSocketAppHandler( override def userEventTriggered(ctx: ChannelHandlerContext, msg: AnyRef): Unit = { msg match { case _: WebSocketServerProtocolHandler.HandshakeComplete | ClientHandshakeStateEvent.HANDSHAKE_COMPLETE => - dispatch(ctx, ChannelEvent.userEventTriggered(UserEvent.HandshakeComplete)) + handshakeCompleted.unsafe.succeed(true) + dispatch(ChannelEvent.userEventTriggered(UserEvent.HandshakeComplete)) case ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT | ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT => - dispatch(ctx, ChannelEvent.userEventTriggered(UserEvent.HandshakeTimeout)) + handshakeCompleted.unsafe.succeed(false) + dispatch(ChannelEvent.userEventTriggered(UserEvent.HandshakeTimeout)) case _ => super.userEventTriggered(ctx, msg) } } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketChannel.scala b/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketChannel.scala index af80fade0f..eba6dee221 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketChannel.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketChannel.scala @@ -29,6 +29,7 @@ private[http] object WebSocketChannel { def make( nettyChannel: NettyChannel[JWebSocketFrame], queue: Queue[WebSocketChannelEvent], + handshakeCompleted: Promise[Nothing, Boolean], ): WebSocketChannel = new WebSocketChannel { def awaitShutdown(implicit trace: Trace): UIO[Unit] = @@ -51,14 +52,14 @@ private[http] object WebSocketChannel { } def send(in: WebSocketChannelEvent)(implicit trace: Trace): Task[Unit] = { - in match { + sendAwaitHandshakeCompleted *> (in match { case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message)) case _ => ZIO.unit - } + }) } def sendAll(in: Iterable[WebSocketChannelEvent])(implicit trace: Trace): Task[Unit] = - ZIO.suspendSucceed { + sendAwaitHandshakeCompleted *> ZIO.suspendSucceed { val iterator = in.iterator.collect { case Read(message) => message } ZIO.whileLoop(iterator.hasNext) { @@ -67,7 +68,18 @@ private[http] object WebSocketChannel { else nettyChannel.writeAndFlush(frameToNetty(message)) }(_ => ()) } - def shutdown(implicit trace: Trace): UIO[Unit] = + + private def sendAwaitHandshakeCompleted: UIO[Unit] = for { + _ <- ZIO + .logWarning( + "WebSocket send before handshake completed, waiting for it to complete", + ) + .unlessZIO(handshakeCompleted.isDone) + successful <- handshakeCompleted.await + _ <- ZIO.interrupt.when(!successful) + } yield () + + def shutdown(implicit trace: Trace): UIO[Unit] = nettyChannel.close(false).orDie } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala index 6ca16548a9..e12606b164 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala @@ -125,9 +125,10 @@ final case class NettyClientDriver private[netty] ( webSocketConfig: WebSocketConfig, )(implicit trace: Trace): RIO[Scope, ChannelInterface] = { for { - queue <- Queue.unbounded[WebSocketChannelEvent] + queue <- Queue.unbounded[WebSocketChannelEvent] + handshakeCompleted <- Promise.make[Nothing, Boolean] nettyChannel = NettyChannel.make[JWebSocketFrame](channel) - webSocketChannel = WebSocketChannel.make(nettyChannel, queue) + webSocketChannel = WebSocketChannel.make(nettyChannel, queue, handshakeCompleted) app = createSocketApp() _ <- app.handler.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped } yield { @@ -149,7 +150,7 @@ final case class NettyClientDriver private[netty] ( // Handles the heavy lifting required to upgrade the connection to a WebSocket connection val webSocketClientProtocol = new WebSocketClientProtocolHandler(config) - val webSocket = new WebSocketAppHandler(nettyRuntime, queue, Some(onComplete)) + val webSocket = new WebSocketAppHandler(nettyRuntime, queue, handshakeCompleted, Some(onComplete)) pipeline.addLast(Names.WebSocketClientProtocolHandler, webSocketClientProtocol) pipeline.addLast(Names.WebSocketHandler, webSocket) diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index a91f638a0d..07a3a84448 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -265,40 +265,39 @@ private[zio] final case class ServerInboundHandler( request: Request, webSocketApp: WebSocketApp[Any], runtime: NettyRuntime, - ): Task[Unit] = { - Queue + ): Task[Unit] = for { + handshakeCompleted <- Promise.make[Nothing, Boolean] + queue <- Queue .unbounded[WebSocketChannelEvent] .tap { queue => ZIO.suspend { val nettyChannel = NettyChannel.make[JWebSocketFrame](ctx.channel()) - val webSocketChannel = WebSocketChannel.make(nettyChannel, queue) + val webSocketChannel = WebSocketChannel.make(nettyChannel, queue, handshakeCompleted) webSocketApp.handler.runZIO(webSocketChannel).ignoreLogged.forkDaemon } } - .flatMap { queue => - ZIO.attempt { - ctx - .channel() - .pipeline() - .addLast( - new WebSocketServerProtocolHandler( - NettySocketProtocol - .serverBuilder(webSocketApp.customConfig.getOrElse(config.webSocketConfig)) - .build(), - ), - ) - .addLast(Names.WebSocketHandler, new WebSocketAppHandler(runtime, queue, None)) - - val jReq = new DefaultFullHttpRequest( - Conversions.versionToNetty(request.version), - Conversions.methodToNetty(request.method), - Conversions.urlToNetty(request.url), - ) - jReq.headers().setAll(Conversions.headersToNetty(request.allHeaders)) - ctx.channel().eventLoop().submit { () => ctx.fireChannelRead(jReq) }: Unit - } - } - } + _ <- ZIO.attempt { + ctx + .channel() + .pipeline() + .addLast( + new WebSocketServerProtocolHandler( + NettySocketProtocol + .serverBuilder(webSocketApp.customConfig.getOrElse(config.webSocketConfig)) + .build(), + ), + ) + .addLast(Names.WebSocketHandler, new WebSocketAppHandler(runtime, queue, handshakeCompleted, None)) + + val jReq = new DefaultFullHttpRequest( + Conversions.versionToNetty(request.version), + Conversions.methodToNetty(request.method), + Conversions.urlToNetty(request.url), + ) + jReq.headers().setAll(Conversions.headersToNetty(request.allHeaders)) + ctx.channel().eventLoop().submit { () => ctx.fireChannelRead(jReq) } + } + } yield () private def writeResponse( ctx: ChannelHandlerContext, diff --git a/zio-http/jvm/src/test/scala/zio/http/ServerSentEventSpec.scala b/zio-http/jvm/src/test/scala/zio/http/ServerSentEventSpec.scala index 109a5a028e..41ac74045f 100644 --- a/zio-http/jvm/src/test/scala/zio/http/ServerSentEventSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/ServerSentEventSpec.scala @@ -10,7 +10,7 @@ import zio.test._ import zio.stream.ZStream -object ServerSentEventSpec extends ZIOSpecDefault { +object ServerSentEventSpec extends ZIOHttpSpec { val stream: ZStream[Any, Nothing, ServerSentEvent[String]] = ZStream.repeatWithSchedule(ServerSentEvent(ISO_LOCAL_TIME.format(LocalDateTime.now)), Schedule.spaced(1.second)) diff --git a/zio-http/jvm/src/test/scala/zio/http/WebSocketConfig.scala b/zio-http/jvm/src/test/scala/zio/http/WebSocketConfigSpec.scala similarity index 100% rename from zio-http/jvm/src/test/scala/zio/http/WebSocketConfig.scala rename to zio-http/jvm/src/test/scala/zio/http/WebSocketConfigSpec.scala diff --git a/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala b/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala index 3ea000d991..526db4629a 100644 --- a/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala @@ -196,6 +196,31 @@ object WebSocketSpec extends HttpRunnableSpec { result <- queue2.takeAll } yield assertTrue(result == Chunk("1", "2", "3", "4", "5")) }, + test("send waits for handshake to complete") { + for { + url <- DynamicServer.httpURL + id <- DynamicServer.deploy { + Handler.webSocket { channel => + channel.send(Read(WebSocketFrame.text("immediate"))) + }.toRoutes + } + + queue <- Queue.unbounded[String] + result <- ZIO.scoped { + Handler.webSocket { channel => + channel.receiveAll { + case Read(WebSocketFrame.Text(s)) => + queue.offer(s) + case _ => + ZIO.unit + } + }.connect(url, Headers(DynamicServer.APP_ID, id)) *> + queue.take + } + } yield { + assertTrue(result == "immediate") + } + }, ) private val withStreamingEnabled = diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/ServerSentEventEndpointSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/ServerSentEventEndpointSpec.scala index 94389ffdc2..77a14a4f4e 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/ServerSentEventEndpointSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/ServerSentEventEndpointSpec.scala @@ -15,7 +15,7 @@ import zio.schema.{DeriveSchema, Schema} import zio.http._ import zio.http.codec.HttpCodec -object ServerSentEventEndpointSpec extends ZIOSpecDefault { +object ServerSentEventEndpointSpec extends ZIOHttpSpec { object StringPayload { val sseEndpoint: Endpoint[Unit, Unit, ZNothing, ZStream[Any, Nothing, ServerSentEvent[String]], AuthType.None] =