diff --git a/zio-http-testkit/src/main/scala/zio/http/TestServer.scala b/zio-http-testkit/src/main/scala/zio/http/TestServer.scala index d551f693dd..ebac563114 100644 --- a/zio-http-testkit/src/main/scala/zio/http/TestServer.scala +++ b/zio-http-testkit/src/main/scala/zio/http/TestServer.scala @@ -111,6 +111,19 @@ final case class TestServer(driver: Driver, bindPort: Int) extends Server { ), ) + override def installScoped[R](routes: Routes[R with Scope, Response])(implicit + trace: zio.Trace, + tag: EnvironmentTag[R], + ): URIO[R, Unit] = + ZIO + .environment[R] + .flatMap( + driver.addAppScoped( + routes, + _, + ), + ) + override def port: UIO[Int] = ZIO.succeed(bindPort) } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/server/NettyDriver.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/NettyDriver.scala index 334b186711..5ecac76685 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/server/NettyDriver.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/NettyDriver.scala @@ -62,12 +62,38 @@ private[zio] final case class NettyDriver( } yield StartResult(port, serverInboundHandler.inFlightRequests) def addApp[R](newApp: Routes[R, Response], env: ZEnvironment[R])(implicit trace: Trace): UIO[Unit] = + addAppImpl(asScoped = false, newApp, env) + + def addAppScoped[R](newApp: Routes[R with Scope, Response], env: ZEnvironment[R])(implicit trace: Trace): UIO[Unit] = + addAppImpl(asScoped = true, newApp, env) + + override def createClientDriver()(implicit trace: Trace): ZIO[Scope, Throwable, ClientDriver] = + for { + channelFactory <- ChannelFactories.Client.live.build + .provideSomeEnvironment[Scope](_ ++ ZEnvironment[ChannelType.Config](nettyConfig)) + nettyRuntime <- NettyRuntime.live.build + } yield NettyClientDriver(channelFactory.get, eventLoopGroups.worker, nettyRuntime.get) + + override def toString: String = s"NettyDriver($serverConfig)" + + private def addAppImpl[E, R <: E](asScoped: Boolean, newApp: Routes[R, Response], env: ZEnvironment[E])(implicit + trace: Trace, + ): UIO[Unit] = ZIO.fiberId.map { fiberId => var loop = true while (loop) { val oldAppAndRt = appRef.get() val (oldApp, oldRt) = oldAppAndRt - val updatedApp = (oldApp ++ newApp).asInstanceOf[Routes[Any, Response]] + val updatedApp = oldApp.fold( + oldUnscoped => { + if (asScoped) { + Right((oldUnscoped ++ newApp).asInstanceOf[Routes[Scope, Response]]) + } else { + Left((oldUnscoped ++ newApp).asInstanceOf[Routes[Any, Response]]) + } + }, + oldScoped => Right((oldScoped ++ newApp).asInstanceOf[Routes[Scope, Response]]), + ) val updatedEnv = oldRt.environment.unionAll(env) // Update the fiberRefs with the new environment to avoid doing this every time we run / fork a fiber val updatedFibRefs = oldRt.fiberRefs.updatedAs(fiberId)(FiberRef.currentEnvironment, updatedEnv) @@ -78,15 +104,6 @@ private[zio] final case class NettyDriver( } serverInboundHandler.refreshApp() } - - override def createClientDriver()(implicit trace: Trace): ZIO[Scope, Throwable, ClientDriver] = - for { - channelFactory <- ChannelFactories.Client.live.build - .provideSomeEnvironment[Scope](_ ++ ZEnvironment[ChannelType.Config](nettyConfig)) - nettyRuntime <- NettyRuntime.live.build - } yield NettyClientDriver(channelFactory.get, eventLoopGroups.worker, nettyRuntime.get) - - override def toString: String = s"NettyDriver($serverConfig)" } object NettyDriver { diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index 74340d825e..fcdd646b9f 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -48,8 +48,8 @@ private[zio] final case class ServerInboundHandler( implicit private val unsafe: Unsafe = Unsafe.unsafe - private var handler: Handler[Any, Nothing, Request, Response] = _ - private var runtime: NettyRuntime = _ + private var handle: Request => ZIO[Any, Nothing, Response] = _ + private var runtime: NettyRuntime = _ val inFlightRequests: LongAdder = new LongAdder() private val readClientCert = config.sslConfig.exists(_.includeClientCert) @@ -58,7 +58,15 @@ private[zio] final case class ServerInboundHandler( def refreshApp(): Unit = { val pair = appRef.get() - this.handler = pair._1.toHandler + this.handle = pair._1 match { + case Left(unscopedHandler) => + val handler = unscopedHandler.toHandler + handler.apply + case Right(scopedHandler) => + val handler = scopedHandler.toHandler + (req: Request) => ZIO.scoped(handler(req)) + } + this.runtime = new NettyRuntime(pair._2) } @@ -88,7 +96,7 @@ private[zio] final case class ServerInboundHandler( releaseRequest() } else { val req = makeZioRequest(ctx, jReq) - val exit = handler(req) + val exit = handle(req) if (attemptImmediateWrite(ctx, req.method, exit)) { releaseRequest() } else { diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/server/package.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/package.scala index 11da4dc374..ee7c995463 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/server/package.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/package.scala @@ -24,13 +24,14 @@ import java.util.concurrent.atomic.AtomicReference // scalafix:ok; import zio.stacktracer.TracingImplicits.disableAutoTrace package object server { - private[server] type RoutesRef = AtomicReference[(Routes[Any, Response], Runtime[Any])] + private[server] type RoutesRef = + AtomicReference[(Either[Routes[Any, Response], Routes[Scope, Response]], Runtime[Any])] private[server] object AppRef { val empty: UIO[RoutesRef] = { implicit val trace: Trace = Trace.empty // Environment will be populated when we `install` the app - ZIO.runtime[Any].map(rt => new AtomicReference((Routes.empty, rt.mapEnvironment(_ => ZEnvironment.empty)))) + ZIO.runtime[Any].map(rt => new AtomicReference((Left(Routes.empty), rt.mapEnvironment(_ => ZEnvironment.empty)))) } } diff --git a/zio-http/shared/src/main/scala/zio/http/Driver.scala b/zio-http/shared/src/main/scala/zio/http/Driver.scala index 0787c90658..222476ab1d 100644 --- a/zio-http/shared/src/main/scala/zio/http/Driver.scala +++ b/zio-http/shared/src/main/scala/zio/http/Driver.scala @@ -27,6 +27,7 @@ trait Driver { def start(implicit trace: Trace): RIO[Scope, StartResult] def addApp[R](newRoutes: Routes[R, Response], env: ZEnvironment[R])(implicit trace: Trace): UIO[Unit] + def addAppScoped[R](newRoutes: Routes[R with Scope, Response], env: ZEnvironment[R])(implicit trace: Trace): UIO[Unit] def createClientDriver()(implicit trace: Trace): ZIO[Scope, Throwable, ClientDriver] } diff --git a/zio-http/shared/src/main/scala/zio/http/Server.scala b/zio-http/shared/src/main/scala/zio/http/Server.scala index b1522afc29..b08e06fcb0 100644 --- a/zio-http/shared/src/main/scala/zio/http/Server.scala +++ b/zio-http/shared/src/main/scala/zio/http/Server.scala @@ -34,6 +34,14 @@ trait Server { */ def install[R](routes: Routes[R, Response])(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R, Unit] + /** + * Installs the given HTTP application into the server, providing a Scope for + * each request. + */ + def installScoped[R]( + routes: Routes[R with Scope, Response], + )(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R, Unit] + /** * The port on which the server is listening. * @@ -443,6 +451,15 @@ object Server extends ServerPlatformSpecific { ZIO.never } + def serveScoped[R]( + routes: Routes[R with Scope, Response], + )(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R with Server, Nothing] = { + ZIO.logInfo("Starting the server...") *> + ZIO.serviceWithZIO[Server](_.installScoped[R](routes)) *> + ZIO.logInfo("Server started") *> + ZIO.never + } + def serve[R]( route: Route[R, Response], routes: Route[R, Response]*, @@ -450,12 +467,25 @@ object Server extends ServerPlatformSpecific { serve(Routes(route, routes: _*)) } + def serveScoped[R]( + route: Route[R with Scope, Response], + routes: Route[R with Scope, Response]*, + )(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R with Server, Nothing] = { + serveScoped[R](Routes(route, routes: _*)) + } + def install[R]( routes: Routes[R, Response], )(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R with Server, Int] = { ZIO.serviceWithZIO[Server](_.install[R](routes)) *> ZIO.serviceWithZIO[Server](_.port) } + def installScoped[R]( + routes: Routes[R with Scope, Response], + )(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R with Server, Int] = { + ZIO.serviceWithZIO[Server](_.installScoped[R](routes)) *> ZIO.serviceWithZIO[Server](_.port) + } + private[http] val base: ZLayer[Driver & Config, Throwable, Server] = { implicit val trace: Trace = Trace.empty ZLayer.scoped { @@ -533,6 +563,16 @@ object Server extends ServerPlatformSpecific { _ <- ZIO.environment[R].flatMap(env => driver.addApp(routes, env.prune[R])) } yield () + override def installScoped[R](routes: Routes[R with Scope, Response])(implicit + trace: Trace, + tag: EnvironmentTag[R], + ): URIO[R, Unit] = + for { + _ <- initialInstall.succeed(()) + _ <- serverStarted.await.orDie + _ <- ZIO.environment[R].flatMap(env => driver.addAppScoped(routes, env.prune[R])) + } yield () + override def port: UIO[Int] = serverStarted.await.orDie }