Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: WebSocketConfig.withForwardCloseFrames does not forward Close frames #2375 #2395

Merged
merged 10 commits into from
Oct 8, 2023
9 changes: 5 additions & 4 deletions zio-http/src/main/scala/zio/http/WebSocketConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ final case class WebSocketConfig(
) { self =>

/**
* Close frame to send, when close frame was not send manually.
* Close frame to send, when close frame was not sent manually.
*/
def closeFrame(code: Int, reason: String): WebSocketConfig =
self.copy(sendCloseFrame = WebSocketConfig.CloseStatus.Custom(code, reason))

/**
* Close frame to send, when close frame was not send manually.
* Close frame to send, when close frame was not sent manually.
*/
def closeStatus(status: WebSocketConfig.CloseStatus): WebSocketConfig = self.copy(sendCloseFrame = status)

Expand All @@ -53,9 +53,10 @@ final case class WebSocketConfig(
self.copy(forceCloseTimeoutMillis = duration.toMillis)

/**
* Close frames should be forwarded
* Close frames should be forwarded instead of handled solely by Netty,
* invisibly to the Websocket Client
*/
def forwardCloseFrames(forward: Boolean): WebSocketConfig = self.copy(handleCloseFrames = forward)
def forwardCloseFrames(forward: Boolean): WebSocketConfig = self.copy(handleCloseFrames = !forward)

/**
* Pong frames should be forwarded
Expand Down
111 changes: 111 additions & 0 deletions zio-http/src/test/scala/zio/http/WebSocketConfig.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package zio.http

import zio._
// import zio.test.Assertion.equalTo
import zio.test.TestAspect._
import zio.test.{TestClock, assertCompletes, assertTrue, assertZIO, testClock}

// import zio.http.ChannelEvent.UserEvent.HandshakeComplete
import zio.http.ChannelEvent.{Read, Unregistered, UserEvent, UserEventTriggered}
import zio.http.Client
import zio.http.DnsResolver
import zio.http.internal.{DynamicServer, HttpRunnableSpec, severTestLayer}
import zio.http.Client
import zio.http.netty.NettyConfig
import zio.http.Client
import zio.http.Client
// import zio.http.netty.NettyConfig

object WebSocketConfigSpec extends HttpRunnableSpec {

val closeFrame = Read(WebSocketFrame.Close(1000, Some("goodbye")))

private val webSocketConfigSpec = suite("WebSocketConfigSpec")(
test("Close frames are received when WebSocketConfig.forwardCloseFrames is true") {
for {
msg <- MessageCollector.make[WebSocketChannelEvent]
url <- DynamicServer.wsURL
id <- DynamicServer.deploy {
Handler.webSocket { channel =>
channel.receiveAll {
case UserEventTriggered(UserEvent.HandshakeComplete) =>
channel.send(closeFrame)
case _ => ZIO.unit
}
}.toHttpAppWS
}

res <- ZIO.scoped {
Handler.webSocket { channel =>
channel.receiveAll {
case event @ Read(WebSocketFrame.Close(_, _)) =>
msg.add(event, true)
case _ => ZIO.unit
}
}.connect(url, Headers(DynamicServer.APP_ID, id)) *> {
for {
events <- msg.await
expected = List(closeFrame)
} yield assertTrue(events == expected)
}
}
} yield res
},
)

def clientWithCloseFrames =
ZLayer.succeed(
ZClient.Config.default.webSocketConfig(
WebSocketConfig.default
.forwardCloseFrames(true),
),
) ++
ZLayer.succeed(NettyConfig.default) ++
DnsResolver.default >>>
Client.live

override def spec = suite("Server") {
ZIO.scoped {
serve.as(List(webSocketConfigSpec))
}
}
.provideShared(
DynamicServer.live,
severTestLayer,
clientWithCloseFrames,
Scope.default,
) @@
timeout(30 seconds) @@
diagnose(30.seconds) @@
withLiveClock @@
sequential

final class MessageCollector[A](ref: Ref[List[A]], promise: Promise[Nothing, Unit]) {
def add(a: A, isDone: Boolean = false): UIO[Unit] = ref.update(_ :+ a) <* promise.succeed(()).when(isDone)
def await: UIO[List[A]] = promise.await *> ref.get
def done: UIO[Boolean] = promise.succeed(())
}

object MessageCollector {
def make[A]: ZIO[Any, Nothing, MessageCollector[A]] = for {
ref <- Ref.make(List.empty[A])
prm <- Promise.make[Nothing, Unit]
} yield new MessageCollector(ref, prm)
}
}
6 changes: 4 additions & 2 deletions zio-http/src/test/scala/zio/http/WebSocketSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import zio.test.{TestClock, assertCompletes, assertTrue, assertZIO, testClock}

import zio.http.ChannelEvent.UserEvent.HandshakeComplete
import zio.http.ChannelEvent.{Read, Unregistered, UserEvent, UserEventTriggered}
import zio.http.internal.{DynamicServer, HttpRunnableSpec, serverTestLayer}
import zio.http.DnsResolver
import zio.http.internal.{DynamicServer, HttpRunnableSpec, severTestLayer}
import zio.http.netty.NettyConfig

object WebSocketSpec extends HttpRunnableSpec {

Expand Down Expand Up @@ -53,7 +55,7 @@ object WebSocketSpec extends HttpRunnableSpec {
channel.shutdown
case _ =>
ZIO.unit
}
}DnsResolver
}.connect(url, Headers(DynamicServer.APP_ID, id)) *> {
for {
events <- msg.await
Expand Down