diff --git a/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala b/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala index 3dedf98413..f4e2777cb7 100644 --- a/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala +++ b/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala @@ -6,15 +6,22 @@ import zio.http.socket.{WebSocketChannel, WebSocketFrame} case class TestChannel(counterpartEvents: Queue[ChannelEvent.Event[WebSocketFrame]]) extends WebSocketChannel { override def autoRead(flag: Boolean)(implicit trace: Trace): UIO[Unit] = ??? - override def awaitClose(implicit trace: Trace): UIO[Unit] = ??? + override def awaitClose(implicit trace: Trace): UIO[Unit] = + close(true).orDie override def close(await: Boolean)(implicit trace: Trace): Task[Unit] = counterpartEvents.offer(ChannelEvent.ChannelUnregistered).unit override def contramap[A1](f: A1 => WebSocketFrame): Channel[A1] = ??? - override def flush(implicit trace: Trace): Task[Unit] = ??? + override def flush(implicit trace: Trace): Task[Unit] = + // There's not queuing as would happen in a real Netty server, so this will always be a NoOp + ZIO.unit + // TODO Is this ID meaningful in a test? + // We can either: + // - Give it a random ID in `make` + // - Hardcode it to "TestChannel" override def id(implicit trace: Trace): String = ??? override def isAutoRead(implicit trace: Trace): UIO[Boolean] = ??? diff --git a/zio-http-testkit/src/main/scala/zio/http/TestClient.scala b/zio-http-testkit/src/main/scala/zio/http/TestClient.scala index f2c46f6ee8..92c52e9c61 100644 --- a/zio-http-testkit/src/main/scala/zio/http/TestClient.scala +++ b/zio-http-testkit/src/main/scala/zio/http/TestClient.scala @@ -3,7 +3,7 @@ package zio.http import zio._ import zio.http.ChannelEvent.{ChannelUnregistered, UserEvent} import zio.http.model.{Headers, Method, Scheme, Status, Version} -import zio.http.socket.{SocketApp, WebSocketFrame} +import zio.http.socket.{SocketApp, WebSocketChannelEvent, WebSocketFrame} /** * Enables tests that use a client without needing a live Server @@ -65,7 +65,7 @@ final case class TestClient(behavior: Ref[HttpApp[Any, Throwable]], serverSocket previousBehavior <- behavior.get newBehavior = handler.andThen(_.provideEnvironment(r)) app: HttpApp[Any, Throwable] = Http.collectZIO(newBehavior) - _ <- behavior.set(previousBehavior ++ app) + _ <- behavior.set(previousBehavior.defaultWith(app)) } yield () val headers: Headers = Headers.empty @@ -124,16 +124,24 @@ final case class TestClient(behavior: Ref[HttpApp[Any, Throwable]], serverSocket } yield Response.status(Status.SwitchingProtocols) } + private val warnLongRunning = + ZIO + .log("Socket Application is taking a long time to run. You might have logic that does not terminate.") + .delay(15.seconds) + .withClock(Clock.ClockLive) *> ZIO.never + private def eventLoop(name: String, channel: TestChannel, app: SocketApp[Any], otherChannel: TestChannel) = (for { - pendEvent <- channel.pending - _ <- app.message.get.apply(ChannelEvent(otherChannel, pendEvent)) + pendEvent <- channel.pending race warnLongRunning + _ <- app.message.get + .apply(ChannelEvent(otherChannel, pendEvent)) + .tapError(e => ZIO.debug(s"Unexpected WebSocket $name error: " + e) *> otherChannel.close) _ <- ZIO.when(pendEvent == ChannelUnregistered) { otherChannel.close } - } yield pendEvent).repeatWhileZIO(event => ZIO.succeed(shouldContinue(event))) + } yield pendEvent).repeatWhile(event => shouldContinue(event)) - def shouldContinue(event: ChannelEvent.Event[WebSocketFrame]) = + private def shouldContinue(event: ChannelEvent.Event[WebSocketFrame]) = event match { case ChannelEvent.ExceptionCaught(_) => false case ChannelEvent.ChannelRead(message) => @@ -150,12 +158,14 @@ final case class TestClient(behavior: Ref[HttpApp[Any, Throwable]], serverSocket case ChannelEvent.ChannelUnregistered => false } - def addSocketApp[Env1]( - app: SocketApp[Env1], + def installSocketApp[Env1]( + app: Http[Any, Throwable, WebSocketChannelEvent, Unit], ): ZIO[Env1, Nothing, Unit] = for { env <- ZIO.environment[Env1] - _ <- serverSocketBehavior.set(app.provideEnvironment(env)) + _ <- serverSocketBehavior.set( + app.defaultWith(TestClient.warnOnUnrecognizedEvent).toSocketApp.provideEnvironment(env), + ) } yield () } @@ -196,10 +206,10 @@ object TestClient { ): ZIO[R with TestClient, Nothing, Unit] = ZIO.serviceWithZIO[TestClient](_.addHandler(handler)) - def addSocketApp[Env1]( - app: SocketApp[Env1], - ): ZIO[TestClient with Env1, Nothing, Unit] = - ZIO.serviceWithZIO[TestClient](_.addSocketApp(app)) + def installSocketApp( + app: Http[Any, Throwable, WebSocketChannelEvent, Unit], + ): ZIO[TestClient, Nothing, Unit] = + ZIO.serviceWithZIO[TestClient](_.installSocketApp(app)) val layer: ZLayer[Any, Nothing, TestClient] = ZLayer.scoped { @@ -208,4 +218,9 @@ object TestClient { socketBehavior <- Ref.make[SocketApp[Any]](SocketApp.apply(_ => ZIO.unit)) } yield TestClient(behavior, socketBehavior) } + + private val warnOnUnrecognizedEvent = Http.collectZIO[WebSocketChannelEvent] { case other => + ZIO.fail(new Exception("Test Server received Unexpected event: " + other)) + } + } diff --git a/zio-http-testkit/src/main/scala/zio/http/TestServer.scala b/zio-http-testkit/src/main/scala/zio/http/TestServer.scala index f13918d214..e88dbbb1d5 100644 --- a/zio-http-testkit/src/main/scala/zio/http/TestServer.scala +++ b/zio-http-testkit/src/main/scala/zio/http/TestServer.scala @@ -79,7 +79,14 @@ final case class TestServer(driver: Driver, bindPort: Int) extends Server { override def install[R](httpApp: HttpApp[R, Throwable], errorCallback: Option[ErrorCallback])(implicit trace: zio.Trace, ): URIO[R, Unit] = - ZIO.environment[R].flatMap(driver.addApp(httpApp, _)) *> setErrorCallback(errorCallback) + ZIO + .environment[R] + .flatMap( + driver.addApp( + httpApp, + _, + ), + ) *> setErrorCallback(errorCallback) private def setErrorCallback(errorCallback: Option[ErrorCallback]): UIO[Unit] = driver diff --git a/zio-http-testkit/src/test/scala/zio/http/SocketContractSpec.scala b/zio-http-testkit/src/test/scala/zio/http/SocketContractSpec.scala index 5a879606b4..ff374fd857 100644 --- a/zio-http-testkit/src/test/scala/zio/http/SocketContractSpec.scala +++ b/zio-http-testkit/src/test/scala/zio/http/SocketContractSpec.scala @@ -1,114 +1,139 @@ package zio.http +import zio.Console.printLine import zio._ import zio.http.ChannelEvent.{ChannelRead, ChannelUnregistered, UserEvent, UserEventTriggered} -import zio.http.ServerConfig.LeakDetectionLevel import zio.http.model.Status import zio.http.netty.server.NettyDriver import zio.http.socket._ import zio.test._ object SocketContractSpec extends ZIOSpecDefault { - val testServerConfig: ZLayer[Any, Nothing, ServerConfig] = - ZLayer.succeed(ServerConfig.default.port(0).leakDetection(LeakDetectionLevel.PARANOID)) - val severTestLayer = testServerConfig >+> Server.live - - val messageFilter: Http[Any, Nothing, WebSocketChannelEvent, (Channel[WebSocketFrame], String)] = + private val messageFilter: Http[Any, Nothing, WebSocketChannelEvent, (Channel[WebSocketFrame], String)] = Http.collect[WebSocketChannelEvent] { case ChannelEvent(channel, ChannelRead(WebSocketFrame.Text(message))) => (channel, message) } - val messageSocketServer: Http[Any, Throwable, WebSocketChannelEvent, Unit] = messageFilter >>> - Http.collectZIO[(WebSocketChannel, String)] { - case (ch, text) if text.contains("Hi Server") => - ZIO.debug("Server got message: " + text) *> ch.close() - case (_, text) => // TODO remove? - ZIO.debug("Unrecognized message sent to server: " + text) - } - - def channelSocketServer(p: Promise[Throwable, Unit]): Http[Any, Throwable, WebSocketChannelEvent, Unit] = - Http.collectZIO[WebSocketChannelEvent] { - case ChannelEvent(ch, UserEventTriggered(UserEvent.HandshakeComplete)) => - ch.writeAndFlush(WebSocketFrame.text("Hi Client")) - - case ChannelEvent(_, ChannelRead(WebSocketFrame.Close(status, reason))) => - p.succeed(()) *> - Console.printLine("Closing channel with status: " + status + " and reason: " + reason) - case ChannelEvent(_, ChannelUnregistered) => - p.succeed(()) *> - Console.printLine("Server Channel unregistered") - case ChannelEvent(ch, ChannelRead(WebSocketFrame.Text("Hi Server"))) => - ch.write(WebSocketFrame.text("Hi Client")) - - case ChannelEvent(_, other) => - Console.printLine("Server Other: " + other) - } - - val protocol = SocketProtocol.default.withSubProtocol(Some("json")) - - val decoder = SocketDecoder.default.withExtensions(allowed = true) - - def socketAppServer(p: Promise[Throwable, Unit]): SocketApp[Any] = - (messageSocketServer ++ channelSocketServer(p)).toSocketApp - .withDecoder(decoder) - .withProtocol(protocol) + private val warnOnUnrecognizedEvent = Http.collectZIO[WebSocketChannelEvent] { case other => + ZIO.fail(new Exception("Unexpected event: " + other)) + } - sys.props.put("ZIOHttpLogLevel", "DEBUG") def spec = suite("SocketOps")( - contract( - "Live", - ZIO.serviceWithZIO[Server](server => - for { - p <- Promise.make[Throwable, Unit] - _ <- server.install(socketAppServer(p).toHttp) - - } yield (server.port, p), - ), - ).provide(Client.default, Scope.default, TestServer.layer, NettyDriver.default, ServerConfig.liveOnOpenPort), - contract( - "Test", { - for { - p <- Promise.make[Throwable, Unit] - _ <- TestClient.addSocketApp(socketAppServer(p)) - - } yield (0, p) + contract("Successful Multi-message application") { p => + def channelSocketServer: Http[Any, Throwable, WebSocketChannelEvent, Unit] = + Http + .collectZIO[WebSocketChannelEvent] { + case ChannelEvent(ch, UserEventTriggered(UserEvent.HandshakeComplete)) => + ch.writeAndFlush(WebSocketFrame.text("Hi Client")) + case ChannelEvent(_, ChannelUnregistered) => + p.succeed(()) *> + printLine("Server Channel unregistered") + case ChannelEvent(ch, ChannelRead(WebSocketFrame.Text("Hi Server"))) => + ch.close() + case ChannelEvent(_, other) => + printLine("Server Unexpected: " + other) + } + .defaultWith(warnOnUnrecognizedEvent) + + val messageSocketServer: Http[Any, Throwable, WebSocketChannelEvent, Unit] = messageFilter >>> + Http.collectZIO[(WebSocketChannel, String)] { + case (ch, text) if text.contains("Hi Server") => + printLine("Server got message: " + text) *> ch.close() + } + + messageSocketServer + .defaultWith(channelSocketServer) + } { _ => + val messageSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] = messageFilter >>> + Http.collectZIO[(WebSocketChannel, String)] { + case (ch, text) if text.contains("Hi Client") => + ch.writeAndFlush(WebSocketFrame.text("Hi Server"), await = true) + } + + val channelSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] = + Http.collectZIO[WebSocketChannelEvent] { + case ChannelEvent(_, ChannelUnregistered) => + printLine("Client Channel unregistered") + + case ChannelEvent(_, other) => + printLine("Client received Unexpected event: " + other) + } + + messageSocketClient.defaultWith(channelSocketClient) + }, + contract("Application where server app fails")(_ => + Http.collectZIO[WebSocketChannelEvent] { + case ChannelEvent(_, UserEventTriggered(UserEvent.HandshakeComplete)) => + ZIO.fail(new Exception("Broken server")) }, - ) - .provide(TestClient.layer, Scope.default), - ) - - def contract[R](name: String, serverSetup: ZIO[R, Nothing, (Int, Promise[Throwable, Unit])]) = - test(name) { - val messageSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] = messageFilter >>> - Http.collectZIO[(WebSocketChannel, String)] { - case (ch, text) if text.contains("Hi Client") => - ch.writeAndFlush(WebSocketFrame.text("Hi Server"), await = true).debug("Client got message: " + text) + ) { p => + Http.collectZIO[WebSocketChannelEvent] { case ChannelEvent(ch, ChannelUnregistered) => + printLine("Server failed and killed socket. Should complete promise.") *> + p.succeed(()).unit } - - val channelSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] = + }, + contract("Application where client app fails")(p => Http.collectZIO[WebSocketChannelEvent] { - case ChannelEvent(_, ChannelUnregistered) => - Console.printLine("Client Channel unregistered") - - case ChannelEvent(_, other) => - Console.printLine("Client received other event: " + other) + case ChannelEvent(_, UserEventTriggered(UserEvent.HandshakeComplete)) => ZIO.unit + case ChannelEvent(_, ChannelUnregistered) => + printLine("Client failed and killed socket. Should complete promise.") *> + p.succeed(()).unit + }, + ) { _ => + Http.collectZIO[WebSocketChannelEvent] { + case ChannelEvent(_, UserEventTriggered(UserEvent.HandshakeComplete)) => + ZIO.fail(new Exception("Broken client")) } + }, + ) - val httpSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] = - messageSocketClient ++ channelSocketClient - - val socketAppClient: SocketApp[Any] = - httpSocketClient.toSocketApp - .withDecoder(decoder) - .withProtocol(protocol) + private def contract( + name: String, + )( + serverApp: Promise[Throwable, Unit] => Http[Any, Throwable, WebSocketChannelEvent, Unit], + )(clientApp: Promise[Throwable, Unit] => Http[Any, Throwable, WebSocketChannelEvent, Unit]) = { + suite(name)( + test("Live") { + for { + portAndPromise <- liveServerSetup(serverApp) + (port, promise) = portAndPromise + response <- ZIO.serviceWithZIO[Client]( + _.socket(s"ws://localhost:$port/", clientApp(promise).toSocketApp), + ) + _ <- promise.await.timeout(10.seconds) + } yield assertTrue(response.status == Status.SwitchingProtocols) + }.provide(Client.default, Scope.default, TestServer.layer, NettyDriver.default, ServerConfig.liveOnOpenPort), + test("Test") { + for { + portAndPromise <- testServerSetup(serverApp) + (port, promise) = portAndPromise + response <- ZIO.serviceWithZIO[Client]( + _.socket(s"ws://localhost:$port/", clientApp(promise).toSocketApp), + ) + _ <- promise.await.timeout(10.seconds) + } yield assertTrue(response.status == Status.SwitchingProtocols) + }.provide(TestClient.layer, Scope.default), + ) + } + private def liveServerSetup( + serverApp: Promise[Throwable, Unit] => Http[Any, Throwable, WebSocketChannelEvent, Unit], + ) = + ZIO.serviceWithZIO[Server](server => for { - portAndPromise <- serverSetup - response <- ZIO.serviceWithZIO[Client](_.socket(s"ws://localhost:${portAndPromise._1}/", socketAppClient)) - _ <- portAndPromise._2.await - } yield assertTrue(response.status == Status.SwitchingProtocols) - } + p <- Promise.make[Throwable, Unit] + _ <- server.install(serverApp(p).toSocketApp.toHttp) + } yield (server.port, p), + ) + + private def testServerSetup( + serverApp: Promise[Throwable, Unit] => Http[Any, Throwable, WebSocketChannelEvent, Unit], + ) = + for { + p <- Promise.make[Throwable, Unit] + _ <- TestClient.installSocketApp(serverApp(p)) + } yield (0, p) } diff --git a/zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala b/zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala index d38c5949c7..f05d51cf02 100644 --- a/zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala +++ b/zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala @@ -55,7 +55,7 @@ object TestClientSpec extends ZIOSpecDefault { ), suite("socket ops")( test("happy path") { - val messageUnwrapper: Http[Any, Nothing, WebSocketChannelEvent, (Channel[WebSocketFrame], String)] = + val messageUnwrapper: Http[Any, Nothing, WebSocketChannelEvent, (WebSocketChannel, String)] = Http.collect[WebSocketChannelEvent] { case ChannelEvent(channel, ChannelRead(WebSocketFrame.Text(message))) => (channel, message) @@ -88,7 +88,7 @@ object TestClientSpec extends ZIOSpecDefault { messageSocketServer ++ channelSocketServer for { - _ <- TestClient.addSocketApp(httpSocketServer.toSocketApp) + _ <- TestClient.installSocketApp(httpSocketServer) response <- ZIO.serviceWithZIO[Client](_.socket(pathSuffix = "")(httpSocketClient.toSocketApp)) } yield assertTrue(response.status == Status.SwitchingProtocols) }, diff --git a/zio-http/src/main/scala/zio/http/package.scala b/zio-http/src/main/scala/zio/http/package.scala index b22bf55cd0..8b1426db0e 100644 --- a/zio-http/src/main/scala/zio/http/package.scala +++ b/zio-http/src/main/scala/zio/http/package.scala @@ -1,13 +1,15 @@ package zio +import zio.http.socket.WebSocketChannelEvent import zio.stacktracer.TracingImplicits.disableAutoTrace // scalafix:ok; package object http extends PathSyntax with RequestSyntax with RouteDecoderModule { - type HttpApp[-R, +E] = Http[R, E, Request, Response] - type UHttpApp = HttpApp[Any, Nothing] - type RHttpApp[-R] = HttpApp[R, Throwable] - type EHttpApp = HttpApp[Any, Throwable] - type UHttp[-A, +B] = Http[Any, Nothing, A, B] + type HttpApp[-R, +E] = Http[R, E, Request, Response] + type UHttpApp = HttpApp[Any, Nothing] + type RHttpApp[-R] = HttpApp[R, Throwable] + type EHttpApp = HttpApp[Any, Throwable] + type UHttp[-A, +B] = Http[Any, Nothing, A, B] + type ResponseZIO[-R, +E] = ZIO[R, E, Response] type UMiddleware[+AIn, -BIn, -AOut, +BOut] = Middleware[Any, Nothing, AIn, BIn, AOut, BOut]