Skip to content

Commit

Permalink
Test client socket improvements and tests (#1753)
Browse files Browse the repository at this point in the history
* Various cleanup. Add HttpSocket alias.

* Work on more user-facing/SocketApp oriented Channel

* More removals and notes

* More API auditing. Warn about long-running socket apps in TestClient.

* Start on Server failure contract app

* Complete Server failure test for Live Client

* Get rid of SocketApp Channel experimentation
I need to grok the backpressure that Tushar discussed more before taking another run at this.

* Fix broken refs to experiments

* Create/use general contract.
Temporarily quiet Netty exception output.

* Cleanup

* Restructure contract function

* Conver other test to new contract function

* Warn about unexpected events sent to server in TestClient

* TestClient test for broken Client app

* Cleanup for PR

* More cleanup

* Restore noisy Netty output

* rm HTTPSocket type alias
  • Loading branch information
swoogles authored Nov 11, 2022
1 parent ac2bf3c commit 0d0f2c0
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 110 deletions.
11 changes: 9 additions & 2 deletions zio-http-testkit/src/main/scala/zio/http/TestChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,22 @@ import zio.http.socket.{WebSocketChannel, WebSocketFrame}
case class TestChannel(counterpartEvents: Queue[ChannelEvent.Event[WebSocketFrame]]) extends WebSocketChannel {
override def autoRead(flag: Boolean)(implicit trace: Trace): UIO[Unit] = ???

override def awaitClose(implicit trace: Trace): UIO[Unit] = ???
override def awaitClose(implicit trace: Trace): UIO[Unit] =
close(true).orDie

override def close(await: Boolean)(implicit trace: Trace): Task[Unit] =
counterpartEvents.offer(ChannelEvent.ChannelUnregistered).unit

override def contramap[A1](f: A1 => WebSocketFrame): Channel[A1] = ???

override def flush(implicit trace: Trace): Task[Unit] = ???
override def flush(implicit trace: Trace): Task[Unit] =
// There's not queuing as would happen in a real Netty server, so this will always be a NoOp
ZIO.unit

// TODO Is this ID meaningful in a test?
// We can either:
// - Give it a random ID in `make`
// - Hardcode it to "TestChannel"
override def id(implicit trace: Trace): String = ???

override def isAutoRead(implicit trace: Trace): UIO[Boolean] = ???
Expand Down
41 changes: 28 additions & 13 deletions zio-http-testkit/src/main/scala/zio/http/TestClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package zio.http
import zio._
import zio.http.ChannelEvent.{ChannelUnregistered, UserEvent}
import zio.http.model.{Headers, Method, Scheme, Status, Version}
import zio.http.socket.{SocketApp, WebSocketFrame}
import zio.http.socket.{SocketApp, WebSocketChannelEvent, WebSocketFrame}

/**
* Enables tests that use a client without needing a live Server
Expand Down Expand Up @@ -65,7 +65,7 @@ final case class TestClient(behavior: Ref[HttpApp[Any, Throwable]], serverSocket
previousBehavior <- behavior.get
newBehavior = handler.andThen(_.provideEnvironment(r))
app: HttpApp[Any, Throwable] = Http.collectZIO(newBehavior)
_ <- behavior.set(previousBehavior ++ app)
_ <- behavior.set(previousBehavior.defaultWith(app))
} yield ()

val headers: Headers = Headers.empty
Expand Down Expand Up @@ -124,16 +124,24 @@ final case class TestClient(behavior: Ref[HttpApp[Any, Throwable]], serverSocket
} yield Response.status(Status.SwitchingProtocols)
}

private val warnLongRunning =
ZIO
.log("Socket Application is taking a long time to run. You might have logic that does not terminate.")
.delay(15.seconds)
.withClock(Clock.ClockLive) *> ZIO.never

private def eventLoop(name: String, channel: TestChannel, app: SocketApp[Any], otherChannel: TestChannel) =
(for {
pendEvent <- channel.pending
_ <- app.message.get.apply(ChannelEvent(otherChannel, pendEvent))
pendEvent <- channel.pending race warnLongRunning
_ <- app.message.get
.apply(ChannelEvent(otherChannel, pendEvent))
.tapError(e => ZIO.debug(s"Unexpected WebSocket $name error: " + e) *> otherChannel.close)
_ <- ZIO.when(pendEvent == ChannelUnregistered) {
otherChannel.close
}
} yield pendEvent).repeatWhileZIO(event => ZIO.succeed(shouldContinue(event)))
} yield pendEvent).repeatWhile(event => shouldContinue(event))

def shouldContinue(event: ChannelEvent.Event[WebSocketFrame]) =
private def shouldContinue(event: ChannelEvent.Event[WebSocketFrame]) =
event match {
case ChannelEvent.ExceptionCaught(_) => false
case ChannelEvent.ChannelRead(message) =>
Expand All @@ -150,12 +158,14 @@ final case class TestClient(behavior: Ref[HttpApp[Any, Throwable]], serverSocket
case ChannelEvent.ChannelUnregistered => false
}

def addSocketApp[Env1](
app: SocketApp[Env1],
def installSocketApp[Env1](
app: Http[Any, Throwable, WebSocketChannelEvent, Unit],
): ZIO[Env1, Nothing, Unit] =
for {
env <- ZIO.environment[Env1]
_ <- serverSocketBehavior.set(app.provideEnvironment(env))
_ <- serverSocketBehavior.set(
app.defaultWith(TestClient.warnOnUnrecognizedEvent).toSocketApp.provideEnvironment(env),
)
} yield ()
}

Expand Down Expand Up @@ -196,10 +206,10 @@ object TestClient {
): ZIO[R with TestClient, Nothing, Unit] =
ZIO.serviceWithZIO[TestClient](_.addHandler(handler))

def addSocketApp[Env1](
app: SocketApp[Env1],
): ZIO[TestClient with Env1, Nothing, Unit] =
ZIO.serviceWithZIO[TestClient](_.addSocketApp(app))
def installSocketApp(
app: Http[Any, Throwable, WebSocketChannelEvent, Unit],
): ZIO[TestClient, Nothing, Unit] =
ZIO.serviceWithZIO[TestClient](_.installSocketApp(app))

val layer: ZLayer[Any, Nothing, TestClient] =
ZLayer.scoped {
Expand All @@ -208,4 +218,9 @@ object TestClient {
socketBehavior <- Ref.make[SocketApp[Any]](SocketApp.apply(_ => ZIO.unit))
} yield TestClient(behavior, socketBehavior)
}

private val warnOnUnrecognizedEvent = Http.collectZIO[WebSocketChannelEvent] { case other =>
ZIO.fail(new Exception("Test Server received Unexpected event: " + other))
}

}
9 changes: 8 additions & 1 deletion zio-http-testkit/src/main/scala/zio/http/TestServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ final case class TestServer(driver: Driver, bindPort: Int) extends Server {
override def install[R](httpApp: HttpApp[R, Throwable], errorCallback: Option[ErrorCallback])(implicit
trace: zio.Trace,
): URIO[R, Unit] =
ZIO.environment[R].flatMap(driver.addApp(httpApp, _)) *> setErrorCallback(errorCallback)
ZIO
.environment[R]
.flatMap(
driver.addApp(
httpApp,
_,
),
) *> setErrorCallback(errorCallback)

private def setErrorCallback(errorCallback: Option[ErrorCallback]): UIO[Unit] =
driver
Expand Down
199 changes: 112 additions & 87 deletions zio-http-testkit/src/test/scala/zio/http/SocketContractSpec.scala
Original file line number Diff line number Diff line change
@@ -1,114 +1,139 @@
package zio.http

import zio.Console.printLine
import zio._
import zio.http.ChannelEvent.{ChannelRead, ChannelUnregistered, UserEvent, UserEventTriggered}
import zio.http.ServerConfig.LeakDetectionLevel
import zio.http.model.Status
import zio.http.netty.server.NettyDriver
import zio.http.socket._
import zio.test._

object SocketContractSpec extends ZIOSpecDefault {
val testServerConfig: ZLayer[Any, Nothing, ServerConfig] =
ZLayer.succeed(ServerConfig.default.port(0).leakDetection(LeakDetectionLevel.PARANOID))

val severTestLayer = testServerConfig >+> Server.live

val messageFilter: Http[Any, Nothing, WebSocketChannelEvent, (Channel[WebSocketFrame], String)] =
private val messageFilter: Http[Any, Nothing, WebSocketChannelEvent, (Channel[WebSocketFrame], String)] =
Http.collect[WebSocketChannelEvent] { case ChannelEvent(channel, ChannelRead(WebSocketFrame.Text(message))) =>
(channel, message)
}

val messageSocketServer: Http[Any, Throwable, WebSocketChannelEvent, Unit] = messageFilter >>>
Http.collectZIO[(WebSocketChannel, String)] {
case (ch, text) if text.contains("Hi Server") =>
ZIO.debug("Server got message: " + text) *> ch.close()
case (_, text) => // TODO remove?
ZIO.debug("Unrecognized message sent to server: " + text)
}

def channelSocketServer(p: Promise[Throwable, Unit]): Http[Any, Throwable, WebSocketChannelEvent, Unit] =
Http.collectZIO[WebSocketChannelEvent] {
case ChannelEvent(ch, UserEventTriggered(UserEvent.HandshakeComplete)) =>
ch.writeAndFlush(WebSocketFrame.text("Hi Client"))

case ChannelEvent(_, ChannelRead(WebSocketFrame.Close(status, reason))) =>
p.succeed(()) *>
Console.printLine("Closing channel with status: " + status + " and reason: " + reason)
case ChannelEvent(_, ChannelUnregistered) =>
p.succeed(()) *>
Console.printLine("Server Channel unregistered")
case ChannelEvent(ch, ChannelRead(WebSocketFrame.Text("Hi Server"))) =>
ch.write(WebSocketFrame.text("Hi Client"))

case ChannelEvent(_, other) =>
Console.printLine("Server Other: " + other)
}

val protocol = SocketProtocol.default.withSubProtocol(Some("json"))

val decoder = SocketDecoder.default.withExtensions(allowed = true)

def socketAppServer(p: Promise[Throwable, Unit]): SocketApp[Any] =
(messageSocketServer ++ channelSocketServer(p)).toSocketApp
.withDecoder(decoder)
.withProtocol(protocol)
private val warnOnUnrecognizedEvent = Http.collectZIO[WebSocketChannelEvent] { case other =>
ZIO.fail(new Exception("Unexpected event: " + other))
}

sys.props.put("ZIOHttpLogLevel", "DEBUG")
def spec =
suite("SocketOps")(
contract(
"Live",
ZIO.serviceWithZIO[Server](server =>
for {
p <- Promise.make[Throwable, Unit]
_ <- server.install(socketAppServer(p).toHttp)

} yield (server.port, p),
),
).provide(Client.default, Scope.default, TestServer.layer, NettyDriver.default, ServerConfig.liveOnOpenPort),
contract(
"Test", {
for {
p <- Promise.make[Throwable, Unit]
_ <- TestClient.addSocketApp(socketAppServer(p))

} yield (0, p)
contract("Successful Multi-message application") { p =>
def channelSocketServer: Http[Any, Throwable, WebSocketChannelEvent, Unit] =
Http
.collectZIO[WebSocketChannelEvent] {
case ChannelEvent(ch, UserEventTriggered(UserEvent.HandshakeComplete)) =>
ch.writeAndFlush(WebSocketFrame.text("Hi Client"))
case ChannelEvent(_, ChannelUnregistered) =>
p.succeed(()) *>
printLine("Server Channel unregistered")
case ChannelEvent(ch, ChannelRead(WebSocketFrame.Text("Hi Server"))) =>
ch.close()
case ChannelEvent(_, other) =>
printLine("Server Unexpected: " + other)
}
.defaultWith(warnOnUnrecognizedEvent)

val messageSocketServer: Http[Any, Throwable, WebSocketChannelEvent, Unit] = messageFilter >>>
Http.collectZIO[(WebSocketChannel, String)] {
case (ch, text) if text.contains("Hi Server") =>
printLine("Server got message: " + text) *> ch.close()
}

messageSocketServer
.defaultWith(channelSocketServer)
} { _ =>
val messageSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] = messageFilter >>>
Http.collectZIO[(WebSocketChannel, String)] {
case (ch, text) if text.contains("Hi Client") =>
ch.writeAndFlush(WebSocketFrame.text("Hi Server"), await = true)
}

val channelSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] =
Http.collectZIO[WebSocketChannelEvent] {
case ChannelEvent(_, ChannelUnregistered) =>
printLine("Client Channel unregistered")

case ChannelEvent(_, other) =>
printLine("Client received Unexpected event: " + other)
}

messageSocketClient.defaultWith(channelSocketClient)
},
contract("Application where server app fails")(_ =>
Http.collectZIO[WebSocketChannelEvent] {
case ChannelEvent(_, UserEventTriggered(UserEvent.HandshakeComplete)) =>
ZIO.fail(new Exception("Broken server"))
},
)
.provide(TestClient.layer, Scope.default),
)

def contract[R](name: String, serverSetup: ZIO[R, Nothing, (Int, Promise[Throwable, Unit])]) =
test(name) {
val messageSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] = messageFilter >>>
Http.collectZIO[(WebSocketChannel, String)] {
case (ch, text) if text.contains("Hi Client") =>
ch.writeAndFlush(WebSocketFrame.text("Hi Server"), await = true).debug("Client got message: " + text)
) { p =>
Http.collectZIO[WebSocketChannelEvent] { case ChannelEvent(ch, ChannelUnregistered) =>
printLine("Server failed and killed socket. Should complete promise.") *>
p.succeed(()).unit
}

val channelSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] =
},
contract("Application where client app fails")(p =>
Http.collectZIO[WebSocketChannelEvent] {
case ChannelEvent(_, ChannelUnregistered) =>
Console.printLine("Client Channel unregistered")

case ChannelEvent(_, other) =>
Console.printLine("Client received other event: " + other)
case ChannelEvent(_, UserEventTriggered(UserEvent.HandshakeComplete)) => ZIO.unit
case ChannelEvent(_, ChannelUnregistered) =>
printLine("Client failed and killed socket. Should complete promise.") *>
p.succeed(()).unit
},
) { _ =>
Http.collectZIO[WebSocketChannelEvent] {
case ChannelEvent(_, UserEventTriggered(UserEvent.HandshakeComplete)) =>
ZIO.fail(new Exception("Broken client"))
}
},
)

val httpSocketClient: Http[Any, Throwable, WebSocketChannelEvent, Unit] =
messageSocketClient ++ channelSocketClient

val socketAppClient: SocketApp[Any] =
httpSocketClient.toSocketApp
.withDecoder(decoder)
.withProtocol(protocol)
private def contract(
name: String,
)(
serverApp: Promise[Throwable, Unit] => Http[Any, Throwable, WebSocketChannelEvent, Unit],
)(clientApp: Promise[Throwable, Unit] => Http[Any, Throwable, WebSocketChannelEvent, Unit]) = {
suite(name)(
test("Live") {
for {
portAndPromise <- liveServerSetup(serverApp)
(port, promise) = portAndPromise
response <- ZIO.serviceWithZIO[Client](
_.socket(s"ws://localhost:$port/", clientApp(promise).toSocketApp),
)
_ <- promise.await.timeout(10.seconds)
} yield assertTrue(response.status == Status.SwitchingProtocols)
}.provide(Client.default, Scope.default, TestServer.layer, NettyDriver.default, ServerConfig.liveOnOpenPort),
test("Test") {
for {
portAndPromise <- testServerSetup(serverApp)
(port, promise) = portAndPromise
response <- ZIO.serviceWithZIO[Client](
_.socket(s"ws://localhost:$port/", clientApp(promise).toSocketApp),
)
_ <- promise.await.timeout(10.seconds)
} yield assertTrue(response.status == Status.SwitchingProtocols)
}.provide(TestClient.layer, Scope.default),
)
}

private def liveServerSetup(
serverApp: Promise[Throwable, Unit] => Http[Any, Throwable, WebSocketChannelEvent, Unit],
) =
ZIO.serviceWithZIO[Server](server =>
for {
portAndPromise <- serverSetup
response <- ZIO.serviceWithZIO[Client](_.socket(s"ws://localhost:${portAndPromise._1}/", socketAppClient))
_ <- portAndPromise._2.await
} yield assertTrue(response.status == Status.SwitchingProtocols)
}
p <- Promise.make[Throwable, Unit]
_ <- server.install(serverApp(p).toSocketApp.toHttp)
} yield (server.port, p),
)

private def testServerSetup(
serverApp: Promise[Throwable, Unit] => Http[Any, Throwable, WebSocketChannelEvent, Unit],
) =
for {
p <- Promise.make[Throwable, Unit]
_ <- TestClient.installSocketApp(serverApp(p))
} yield (0, p)

}
4 changes: 2 additions & 2 deletions zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object TestClientSpec extends ZIOSpecDefault {
),
suite("socket ops")(
test("happy path") {
val messageUnwrapper: Http[Any, Nothing, WebSocketChannelEvent, (Channel[WebSocketFrame], String)] =
val messageUnwrapper: Http[Any, Nothing, WebSocketChannelEvent, (WebSocketChannel, String)] =
Http.collect[WebSocketChannelEvent] {
case ChannelEvent(channel, ChannelRead(WebSocketFrame.Text(message))) =>
(channel, message)
Expand Down Expand Up @@ -88,7 +88,7 @@ object TestClientSpec extends ZIOSpecDefault {
messageSocketServer ++ channelSocketServer

for {
_ <- TestClient.addSocketApp(httpSocketServer.toSocketApp)
_ <- TestClient.installSocketApp(httpSocketServer)
response <- ZIO.serviceWithZIO[Client](_.socket(pathSuffix = "")(httpSocketClient.toSocketApp))
} yield assertTrue(response.status == Status.SwitchingProtocols)
},
Expand Down
12 changes: 7 additions & 5 deletions zio-http/src/main/scala/zio/http/package.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package zio

import zio.http.socket.WebSocketChannelEvent
import zio.stacktracer.TracingImplicits.disableAutoTrace // scalafix:ok;

package object http extends PathSyntax with RequestSyntax with RouteDecoderModule {
type HttpApp[-R, +E] = Http[R, E, Request, Response]
type UHttpApp = HttpApp[Any, Nothing]
type RHttpApp[-R] = HttpApp[R, Throwable]
type EHttpApp = HttpApp[Any, Throwable]
type UHttp[-A, +B] = Http[Any, Nothing, A, B]
type HttpApp[-R, +E] = Http[R, E, Request, Response]
type UHttpApp = HttpApp[Any, Nothing]
type RHttpApp[-R] = HttpApp[R, Throwable]
type EHttpApp = HttpApp[Any, Throwable]
type UHttp[-A, +B] = Http[Any, Nothing, A, B]

type ResponseZIO[-R, +E] = ZIO[R, E, Response]
type UMiddleware[+AIn, -BIn, -AOut, +BOut] = Middleware[Any, Nothing, AIn, BIn, AOut, BOut]

Expand Down

0 comments on commit 0d0f2c0

Please sign in to comment.