From 748957097827bd1f4c61c57b9d5e9d561812bcdb Mon Sep 17 00:00:00 2001 From: Adam Fraser Date: Sun, 24 Sep 2023 09:29:24 -0700 Subject: [PATCH] Do Not Continue Reading From Web Socket After Terminal Event (#2441) do not continue reading after terminal event --- .../src/main/scala/zio/http/TestChannel.scala | 12 ++++++ .../src/main/scala/zio/http/Channel.scala | 37 ++++++++++--------- .../scala/zio/http/WebSocketChannel.scala | 13 +++++++ 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala b/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala index 05489d2895..6e3e118fa1 100644 --- a/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala +++ b/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala @@ -14,6 +14,18 @@ case class TestChannel( promise.await def receive(implicit trace: Trace): Task[WebSocketChannelEvent] = in.take + def receiveAll[Env, Err](f: WebSocketChannelEvent => ZIO[Env, Err, Any])(implicit + trace: Trace, + ): ZIO[Env, Err, Unit] = { + lazy val loop: ZIO[Env, Err, Unit] = + in.take.flatMap { + case event @ ChannelEvent.ExceptionCaught(_) => f(event).unit + case event @ ChannelEvent.Unregistered => f(event).unit + case event => f(event) *> ZIO.yieldNow *> loop + } + + loop + } def send(in: WebSocketChannelEvent)(implicit trace: Trace): Task[Unit] = out.offer(in).unit def sendAll(in: Iterable[WebSocketChannelEvent])(implicit trace: Trace): Task[Unit] = diff --git a/zio-http/src/main/scala/zio/http/Channel.scala b/zio-http/src/main/scala/zio/http/Channel.scala index 348f630ecd..a18527bc37 100644 --- a/zio-http/src/main/scala/zio/http/Channel.scala +++ b/zio-http/src/main/scala/zio/http/Channel.scala @@ -36,6 +36,12 @@ trait Channel[-In, +Out] { self => */ def receive(implicit trace: Trace): Task[Out] + /** + * Reads all messages from the channel, handling them with the specified + * function. + */ + def receiveAll[Env, Err](f: Out => ZIO[Env, Err, Any])(implicit trace: Trace): ZIO[Env, Err, Unit] + /** * Send a message to the channel. */ @@ -57,15 +63,17 @@ trait Channel[-In, +Out] { self => */ final def contramap[In2](f: In2 => In): Channel[In2, Out] = new Channel[In2, Out] { - def awaitShutdown(implicit trace: Trace): UIO[Unit] = + def awaitShutdown(implicit trace: Trace): UIO[Unit] = self.awaitShutdown - def receive(implicit trace: Trace): Task[Out] = + def receive(implicit trace: Trace): Task[Out] = self.receive - def send(in: In2)(implicit trace: Trace): Task[Unit] = + def receiveAll[Env, Err](g: Out => ZIO[Env, Err, Any])(implicit trace: Trace): ZIO[Env, Err, Unit] = + self.receiveAll(g) + def send(in: In2)(implicit trace: Trace): Task[Unit] = self.send(f(in)) - def sendAll(in: Iterable[In2])(implicit trace: Trace): Task[Unit] = + def sendAll(in: Iterable[In2])(implicit trace: Trace): Task[Unit] = self.sendAll(in.map(f)) - def shutdown(implicit trace: Trace): UIO[Unit] = + def shutdown(implicit trace: Trace): UIO[Unit] = self.shutdown } @@ -75,22 +83,17 @@ trait Channel[-In, +Out] { self => */ final def map[Out2](f: Out => Out2)(implicit trace: Trace): Channel[In, Out2] = new Channel[In, Out2] { - def awaitShutdown(implicit trace: Trace): UIO[Unit] = + def awaitShutdown(implicit trace: Trace): UIO[Unit] = self.awaitShutdown - def receive(implicit trace: Trace): Task[Out2] = + def receive(implicit trace: Trace): Task[Out2] = self.receive.map(f) - def send(in: In)(implicit trace: Trace): Task[Unit] = + def receiveAll[Env, Err](g: Out2 => ZIO[Env, Err, Any])(implicit trace: Trace): ZIO[Env, Err, Unit] = + self.receiveAll(f andThen g) + def send(in: In)(implicit trace: Trace): Task[Unit] = self.send(in) - def sendAll(in: Iterable[In])(implicit trace: Trace): Task[Unit] = + def sendAll(in: Iterable[In])(implicit trace: Trace): Task[Unit] = self.sendAll(in) - def shutdown(implicit trace: Trace): UIO[Unit] = + def shutdown(implicit trace: Trace): UIO[Unit] = self.shutdown } - - /** - * Reads all messages from the channel, handling them with the specified - * function. - */ - final def receiveAll[Env](f: Out => ZIO[Env, Throwable, Any])(implicit trace: Trace): ZIO[Env, Throwable, Nothing] = - receive.flatMap(f).forever } diff --git a/zio-http/src/main/scala/zio/http/WebSocketChannel.scala b/zio-http/src/main/scala/zio/http/WebSocketChannel.scala index 434b6ae5d5..6bc4bd3e42 100644 --- a/zio-http/src/main/scala/zio/http/WebSocketChannel.scala +++ b/zio-http/src/main/scala/zio/http/WebSocketChannel.scala @@ -38,6 +38,19 @@ private[http] object WebSocketChannel { def receive(implicit trace: Trace): Task[WebSocketChannelEvent] = queue.take + def receiveAll[Env, Err](f: WebSocketChannelEvent => ZIO[Env, Err, Any])(implicit + trace: Trace, + ): ZIO[Env, Err, Unit] = { + lazy val loop: ZIO[Env, Err, Unit] = + queue.take.flatMap { + case event @ ChannelEvent.ExceptionCaught(_) => f(event).unit + case event @ ChannelEvent.Unregistered => f(event).unit + case event => f(event) *> ZIO.yieldNow *> loop + } + + loop + } + def send(in: WebSocketChannelEvent)(implicit trace: Trace): Task[Unit] = { in match { case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message))