Skip to content

Commit

Permalink
Sending to websocket before handshake completed (zio#3028)
Browse files Browse the repository at this point in the history
* Sending to websocket before handshake completed

Ensure any attempts to send to a websocket wait until the handshake has
been completed otherwise they will not be received.

* ensure test has a timeout

* log warning when send before handshake complete
  • Loading branch information
guymers authored Aug 29, 2024
1 parent e3930d8 commit b4dde6a
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,38 +36,36 @@ import io.netty.handler.codec.http.websocketx.{WebSocketFrame => JWebSocketFrame
private[zio] final class WebSocketAppHandler(
zExec: NettyRuntime,
queue: Queue[WebSocketChannelEvent],
handshakeCompleted: Promise[Nothing, Boolean],
onComplete: Option[Promise[Throwable, ChannelState]],
)(implicit trace: Trace)
extends SimpleChannelInboundHandler[JWebSocketFrame] {

implicit private val unsafeClass: Unsafe = Unsafe.unsafe

private def dispatch(
ctx: ChannelHandlerContext,
event: ChannelEvent[JWebSocketFrame],
close: Boolean = false,
): Unit = {
private def dispatch(event: ChannelEvent[JWebSocketFrame]): Unit = {
// IMPORTANT: Offering to the queue must be run synchronously to avoid messages being added in the wrong order
// Since the queue is unbounded, this will not block the event loop
// TODO: We need to come up with a design that doesn't involve running an effect to offer to the queue
zExec.unsafeRunSync(queue.offer(event.map(frameFromNetty)))
onComplete match {
case Some(promise) if close => promise.unsafe.done(Exit.succeed(ChannelState.Invalid))
case _ => ()
}
val _ = zExec.unsafeRunSync(queue.offer(event.map(frameFromNetty)))
}

override def channelRead0(ctx: ChannelHandlerContext, msg: JWebSocketFrame): Unit =
dispatch(ctx, ChannelEvent.read(msg))
dispatch(ChannelEvent.read(msg))

override def channelRegistered(ctx: ChannelHandlerContext): Unit =
dispatch(ctx, ChannelEvent.registered)
dispatch(ChannelEvent.registered)

override def channelUnregistered(ctx: ChannelHandlerContext): Unit =
dispatch(ctx, ChannelEvent.unregistered, close = true)
override def channelUnregistered(ctx: ChannelHandlerContext): Unit = {
dispatch(ChannelEvent.unregistered)
onComplete match {
case Some(promise) => promise.unsafe.done(Exit.succeed(ChannelState.Invalid))
case None => ()
}
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
dispatch(ctx, ChannelEvent.exceptionCaught(cause))
dispatch(ChannelEvent.exceptionCaught(cause))
onComplete match {
case Some(promise) => promise.unsafe.done(Exit.fail(cause))
case None => ()
Expand All @@ -77,9 +75,11 @@ private[zio] final class WebSocketAppHandler(
override def userEventTriggered(ctx: ChannelHandlerContext, msg: AnyRef): Unit = {
msg match {
case _: WebSocketServerProtocolHandler.HandshakeComplete | ClientHandshakeStateEvent.HANDSHAKE_COMPLETE =>
dispatch(ctx, ChannelEvent.userEventTriggered(UserEvent.HandshakeComplete))
handshakeCompleted.unsafe.succeed(true)
dispatch(ChannelEvent.userEventTriggered(UserEvent.HandshakeComplete))
case ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT | ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT =>
dispatch(ctx, ChannelEvent.userEventTriggered(UserEvent.HandshakeTimeout))
handshakeCompleted.unsafe.succeed(false)
dispatch(ChannelEvent.userEventTriggered(UserEvent.HandshakeTimeout))
case _ => super.userEventTriggered(ctx, msg)
}
}
Expand Down
20 changes: 16 additions & 4 deletions zio-http/jvm/src/main/scala/zio/http/netty/WebSocketChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ private[http] object WebSocketChannel {
def make(
nettyChannel: NettyChannel[JWebSocketFrame],
queue: Queue[WebSocketChannelEvent],
handshakeCompleted: Promise[Nothing, Boolean],
): WebSocketChannel =
new WebSocketChannel {
def awaitShutdown(implicit trace: Trace): UIO[Unit] =
Expand All @@ -51,14 +52,14 @@ private[http] object WebSocketChannel {
}

def send(in: WebSocketChannelEvent)(implicit trace: Trace): Task[Unit] = {
in match {
sendAwaitHandshakeCompleted *> (in match {
case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message))
case _ => ZIO.unit
}
})
}

def sendAll(in: Iterable[WebSocketChannelEvent])(implicit trace: Trace): Task[Unit] =
ZIO.suspendSucceed {
sendAwaitHandshakeCompleted *> ZIO.suspendSucceed {
val iterator = in.iterator.collect { case Read(message) => message }

ZIO.whileLoop(iterator.hasNext) {
Expand All @@ -67,7 +68,18 @@ private[http] object WebSocketChannel {
else nettyChannel.writeAndFlush(frameToNetty(message))
}(_ => ())
}
def shutdown(implicit trace: Trace): UIO[Unit] =

private def sendAwaitHandshakeCompleted: UIO[Unit] = for {
_ <- ZIO
.logWarning(
"WebSocket send before handshake completed, waiting for it to complete",
)
.unlessZIO(handshakeCompleted.isDone)
successful <- handshakeCompleted.await
_ <- ZIO.interrupt.when(!successful)
} yield ()

def shutdown(implicit trace: Trace): UIO[Unit] =
nettyChannel.close(false).orDie
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,10 @@ final case class NettyClientDriver private[netty] (
webSocketConfig: WebSocketConfig,
)(implicit trace: Trace): RIO[Scope, ChannelInterface] = {
for {
queue <- Queue.unbounded[WebSocketChannelEvent]
queue <- Queue.unbounded[WebSocketChannelEvent]
handshakeCompleted <- Promise.make[Nothing, Boolean]
nettyChannel = NettyChannel.make[JWebSocketFrame](channel)
webSocketChannel = WebSocketChannel.make(nettyChannel, queue)
webSocketChannel = WebSocketChannel.make(nettyChannel, queue, handshakeCompleted)
app = createSocketApp()
_ <- app.handler.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped
} yield {
Expand All @@ -149,7 +150,7 @@ final case class NettyClientDriver private[netty] (
// Handles the heavy lifting required to upgrade the connection to a WebSocket connection

val webSocketClientProtocol = new WebSocketClientProtocolHandler(config)
val webSocket = new WebSocketAppHandler(nettyRuntime, queue, Some(onComplete))
val webSocket = new WebSocketAppHandler(nettyRuntime, queue, handshakeCompleted, Some(onComplete))

pipeline.addLast(Names.WebSocketClientProtocolHandler, webSocketClientProtocol)
pipeline.addLast(Names.WebSocketHandler, webSocket)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,40 +265,39 @@ private[zio] final case class ServerInboundHandler(
request: Request,
webSocketApp: WebSocketApp[Any],
runtime: NettyRuntime,
): Task[Unit] = {
Queue
): Task[Unit] = for {
handshakeCompleted <- Promise.make[Nothing, Boolean]
queue <- Queue
.unbounded[WebSocketChannelEvent]
.tap { queue =>
ZIO.suspend {
val nettyChannel = NettyChannel.make[JWebSocketFrame](ctx.channel())
val webSocketChannel = WebSocketChannel.make(nettyChannel, queue)
val webSocketChannel = WebSocketChannel.make(nettyChannel, queue, handshakeCompleted)
webSocketApp.handler.runZIO(webSocketChannel).ignoreLogged.forkDaemon
}
}
.flatMap { queue =>
ZIO.attempt {
ctx
.channel()
.pipeline()
.addLast(
new WebSocketServerProtocolHandler(
NettySocketProtocol
.serverBuilder(webSocketApp.customConfig.getOrElse(config.webSocketConfig))
.build(),
),
)
.addLast(Names.WebSocketHandler, new WebSocketAppHandler(runtime, queue, None))

val jReq = new DefaultFullHttpRequest(
Conversions.versionToNetty(request.version),
Conversions.methodToNetty(request.method),
Conversions.urlToNetty(request.url),
)
jReq.headers().setAll(Conversions.headersToNetty(request.allHeaders))
ctx.channel().eventLoop().submit { () => ctx.fireChannelRead(jReq) }: Unit
}
}
}
_ <- ZIO.attempt {
ctx
.channel()
.pipeline()
.addLast(
new WebSocketServerProtocolHandler(
NettySocketProtocol
.serverBuilder(webSocketApp.customConfig.getOrElse(config.webSocketConfig))
.build(),
),
)
.addLast(Names.WebSocketHandler, new WebSocketAppHandler(runtime, queue, handshakeCompleted, None))

val jReq = new DefaultFullHttpRequest(
Conversions.versionToNetty(request.version),
Conversions.methodToNetty(request.method),
Conversions.urlToNetty(request.url),
)
jReq.headers().setAll(Conversions.headersToNetty(request.allHeaders))
ctx.channel().eventLoop().submit { () => ctx.fireChannelRead(jReq) }
}
} yield ()

private def writeResponse(
ctx: ChannelHandlerContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import zio.test._

import zio.stream.ZStream

object ServerSentEventSpec extends ZIOSpecDefault {
object ServerSentEventSpec extends ZIOHttpSpec {

val stream: ZStream[Any, Nothing, ServerSentEvent[String]] =
ZStream.repeatWithSchedule(ServerSentEvent(ISO_LOCAL_TIME.format(LocalDateTime.now)), Schedule.spaced(1.second))
Expand Down
25 changes: 25 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,31 @@ object WebSocketSpec extends HttpRunnableSpec {
result <- queue2.takeAll
} yield assertTrue(result == Chunk("1", "2", "3", "4", "5"))
},
test("send waits for handshake to complete") {
for {
url <- DynamicServer.httpURL
id <- DynamicServer.deploy {
Handler.webSocket { channel =>
channel.send(Read(WebSocketFrame.text("immediate")))
}.toRoutes
}

queue <- Queue.unbounded[String]
result <- ZIO.scoped {
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text(s)) =>
queue.offer(s)
case _ =>
ZIO.unit
}
}.connect(url, Headers(DynamicServer.APP_ID, id)) *>
queue.take
}
} yield {
assertTrue(result == "immediate")
}
},
)

private val withStreamingEnabled =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import zio.schema.{DeriveSchema, Schema}
import zio.http._
import zio.http.codec.HttpCodec

object ServerSentEventEndpointSpec extends ZIOSpecDefault {
object ServerSentEventEndpointSpec extends ZIOHttpSpec {

object StringPayload {
val sseEndpoint: Endpoint[Unit, Unit, ZNothing, ZStream[Any, Nothing, ServerSentEvent[String]], AuthType.None] =
Expand Down

0 comments on commit b4dde6a

Please sign in to comment.