diff --git a/.jvmopts b/.jvmopts index cdfceeb7..80e9449d 100644 --- a/.jvmopts +++ b/.jvmopts @@ -1,4 +1,5 @@ -Xmx4G +-Xss2m -XX:ReservedCodeCacheSize=256m -XX:MaxMetaspaceSize=3G diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpRequestContext.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpRequestContext.scala deleted file mode 100644 index d9c55e10..00000000 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpRequestContext.scala +++ /dev/null @@ -1,6 +0,0 @@ -package izumi.idealingua.runtime.rpc.http4s - -import org.http4s.AuthedRequest - -// we can't make it a case class, see https://github.com/scala/bug/issues/11239 -class HttpRequestContext[F[_], Ctx](val request: AuthedRequest[F, Ctx], val context: Ctx) diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpServer.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpServer.scala index 7705c577..6ef35614 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpServer.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/HttpServer.scala @@ -1,6 +1,5 @@ package izumi.idealingua.runtime.rpc.http4s -import _root_.io.circe.parser.* import cats.data.OptionT import cats.effect.Async import cats.effect.std.Queue @@ -11,31 +10,29 @@ import io.circe.{Json, Printer} import izumi.functional.bio.Exit.{Error, Interruption, Success, Termination} import izumi.functional.bio.{Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2} import izumi.fundamentals.platform.language.Quirks +import izumi.fundamentals.platform.language.Quirks.Discarder import izumi.fundamentals.platform.time.IzTime import izumi.idealingua.runtime.rpc.* import izumi.idealingua.runtime.rpc.http4s.HttpServer.{ServerWsRpcHandler, WsResponseMarker} +import izumi.idealingua.runtime.rpc.http4s.context.{HttpContextExtractor, WsContextExtractor} import izumi.idealingua.runtime.rpc.http4s.ws.* -import izumi.idealingua.runtime.rpc.http4s.ws.WsClientSession.WsClientSessionImpl -import izumi.idealingua.runtime.rpc.http4s.ws.WsContextProvider.WsAuthResult import logstage.LogIO2 import org.http4s.* import org.http4s.dsl.Http4sDsl -import org.http4s.server.AuthMiddleware import org.http4s.server.websocket.WebSocketBuilder2 import org.http4s.websocket.WebSocketFrame +import org.typelevel.ci.CIString import org.typelevel.vault.Key import java.time.ZonedDateTime import java.util.concurrent.RejectedExecutionException import scala.concurrent.duration.DurationInt -class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, RequestCtx, MethodCtx, ClientId]( - val muxer: IRTServerMultiplexor[F, RequestCtx], - val codec: IRTClientMultiplexor[F], - val contextProvider: AuthMiddleware[F[Throwable, _], RequestCtx], - val wsContextProvider: WsContextProvider[F, RequestCtx, ClientId], - val wsSessionStorage: WsSessionsStorage[F, RequestCtx, ClientId], - val listeners: Seq[WsSessionListener[F, ClientId]], +class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, AuthCtx]( + val contextServices: Set[IRTContextServices.AnyContext[F, AuthCtx]], + val httpContextExtractor: HttpContextExtractor[AuthCtx], + val wsContextExtractor: WsContextExtractor[AuthCtx], + val wsSessionsStorage: WsSessionsStorage[F, AuthCtx], dsl: Http4sDsl[F[Throwable, _]], logger: LogIO2[F], printer: Printer, @@ -43,74 +40,25 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, RequestCtx, ) { import dsl.* + protected val serverMuxer: IRTServerMultiplexor[F, AuthCtx] = IRTServerMultiplexor.combine(contextServices.map(_.authorizedMuxer)) + protected val wsContextsSessions: Set[WsContextSessions.AnyContext[F, AuthCtx]] = contextServices.map(_.authorizedWsSessions) + // WS Response attribute key, to differ from usual HTTP responses private val wsAttributeKey = UnsafeRun2[F].unsafeRun(Key.newKey[F[Throwable, _], WsResponseMarker.type]) - protected def loggingMiddle(service: HttpRoutes[F[Throwable, _]]): HttpRoutes[F[Throwable, _]] = { - cats.data.Kleisli { - (req: Request[F[Throwable, _]]) => - OptionT.apply { - (for { - _ <- logger.trace(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: initiated") - resp <- service(req).value - _ <- F.traverse(resp) { - case Status.Successful(resp) => - logger.debug(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: success, ${resp.status.code -> "code"} ${resp.status.reason -> "reason"}") - case resp if resp.attributes.contains(wsAttributeKey) => - logger.debug(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: websocket request") - case resp => - logger.info(s"${req.method.name -> "method"} ${req.pathInfo -> "uri"}: rejection, ${resp.status.code -> "code"} ${resp.status.reason -> "reason"}") - } - } yield resp).tapError { - cause => - logger.error(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: failure, $cause") - } - } - } - } - def service(ws: WebSocketBuilder2[F[Throwable, _]]): HttpRoutes[F[Throwable, _]] = { - val svc = AuthedRoutes.of(router(ws)) - val aservice: HttpRoutes[F[Throwable, _]] = contextProvider(svc) - loggingMiddle(aservice) - } - - protected def router(ws: WebSocketBuilder2[F[Throwable, _]]): PartialFunction[AuthedRequest[F[Throwable, _], RequestCtx], F[Throwable, Response[F[Throwable, _]]]] = { - case request @ GET -> Root / "ws" as ctx => - setupWs(request, ctx, ws) - - case request @ GET -> Root / service / method as ctx => - val methodId = IRTMethodId(IRTServiceId(service), IRTMethodName(method)) - processHttpRequest(new HttpRequestContext(request, ctx), body = "{}", methodId) - - case request @ POST -> Root / service / method as ctx => - val methodId = IRTMethodId(IRTServiceId(service), IRTMethodName(method)) - request.req.decode[String] { - body => - processHttpRequest(new HttpRequestContext(request, ctx), body, methodId) - } + val svc = HttpRoutes.of(router(ws)) + loggingMiddle(svc) } - protected def handleWsClose(context: WsClientSession[F, RequestCtx, ClientId]): F[Throwable, Unit] = { - logger.debug(s"WS Session: Websocket client disconnected ${context.id}.") *> - context.finish() - } - - protected def globalWsListener: WsSessionListener[F, ClientId] = new WsSessionListener[F, ClientId] { - def onSessionOpened(context: WsClientId[ClientId]): F[Throwable, Unit] = { - logger.debug(s"WS Session: opened ${context.id}.") - } - def onClientIdUpdate(context: WsClientId[ClientId], old: WsClientId[ClientId]): F[Throwable, Unit] = { - logger.debug(s"WS Session: Id updated to ${context.id}, was: ${old.id}.") - } - def onSessionClosed(context: WsClientId[ClientId]): F[Throwable, Unit] = { - logger.debug(s"WS Session: closed ${context.id}.") - } + protected def router(ws: WebSocketBuilder2[F[Throwable, _]]): PartialFunction[Request[F[Throwable, _]], F[Throwable, Response[F[Throwable, _]]]] = { + case request @ GET -> Root / "ws" => setupWs(request, ws) + case request @ GET -> Root / service / method => processHttpRequest(request, service, method)("{}") + case request @ POST -> Root / service / method => request.decode[String](processHttpRequest(request, service, method)) } protected def setupWs( - request: AuthedRequest[F[Throwable, _], RequestCtx], - initialContext: RequestCtx, + request: Request[F[Throwable, _]], ws: WebSocketBuilder2[F[Throwable, _]], ): F[Throwable, Response[F[Throwable, _]]] = { Quirks.discard(request) @@ -120,33 +68,39 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, RequestCtx, .evalMap(_ => logger.debug("WS Server: Sending ping frame.").as(WebSocketFrame.Ping())) } for { - outQueue <- Queue.unbounded[F[Throwable, _], WebSocketFrame] - listenersWithGlobal = Seq(globalWsListener) ++ listeners - context = new WsClientSessionImpl(outQueue, initialContext, listenersWithGlobal, wsSessionStorage, printer, logger) - _ <- context.start() - outStream = Stream.fromQueueUnterminated(outQueue).merge(pingStream) + outQueue <- Queue.unbounded[F[Throwable, _], WebSocketFrame] + authContext <- F.syncThrowable(httpContextExtractor.extract(request)) + clientSession = new WsClientSession.Queued(outQueue, authContext, wsContextsSessions, wsSessionsStorage, wsContextExtractor, logger, printer) + _ <- clientSession.start(onWsConnected) + + outStream = Stream.fromQueueUnterminated(outQueue).merge(pingStream) inStream = { (inputStream: Stream[F[Throwable, _], WebSocketFrame]) => inputStream.evalMap { - processWsRequest(context, IzTime.utcNow)(_).flatMap { + processWsRequest(clientSession, IzTime.utcNow)(_).flatMap { case Some(v) => outQueue.offer(WebSocketFrame.Text(v)) case None => F.unit } } } - response <- ws.withOnClose(handleWsClose(context)).build(outStream, inStream) + wsSessionIdHeader = Header.Raw(HttpServer.`X-Ws-Session-Id`, clientSession.sessionId.sessionId.toString) + + response <- ws + .withOnClose(handleWsClose(clientSession)) + .withHeaders(Headers(wsSessionIdHeader)) + .build(outStream, inStream) } yield { response.withAttribute(wsAttributeKey, WsResponseMarker) } } protected def processWsRequest( - context: WsClientSession[F, RequestCtx, ClientId], + clientSession: WsClientSession[F, AuthCtx], requestTime: ZonedDateTime, )(frame: WebSocketFrame ): F[Throwable, Option[String]] = { (frame match { - case WebSocketFrame.Text(msg, _) => wsHandler(context).processRpcMessage(msg) + case WebSocketFrame.Text(msg, _) => wsHandler(clientSession).processRpcMessage(msg) case WebSocketFrame.Close(_) => F.pure(None) case _: WebSocketFrame.Pong => onWsHeartbeat(requestTime).as(None) case unknownMessage => @@ -157,101 +111,129 @@ class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, RequestCtx, }).map(_.map(p => printer.print(p.asJson))) } - protected def wsHandler(context: WsClientSession[F, RequestCtx, ClientId]): WsRpcHandler[F, RequestCtx] = { - new ServerWsRpcHandler(muxer, wsContextProvider, context, logger) + protected def wsHandler(clientSession: WsClientSession[F, AuthCtx]): WsRpcHandler[F, AuthCtx] = { + new ServerWsRpcHandler(clientSession, serverMuxer, wsContextExtractor, logger) + } + + protected def handleWsClose(session: WsClientSession[F, AuthCtx]): F[Throwable, Unit] = { + logger.debug(s"WS Session: Websocket client disconnected ${session.sessionId}.") *> + session.finish(onWsDisconnected) + } + + protected def onWsConnected(authContext: AuthCtx): F[Throwable, Unit] = { + authContext.discard() + F.unit } protected def onWsHeartbeat(requestTime: ZonedDateTime): F[Throwable, Unit] = { logger.debug(s"WS Session: pong frame at $requestTime") } + protected def onWsDisconnected(authContext: AuthCtx): F[Throwable, Unit] = { + authContext.discard() + F.unit + } + protected def processHttpRequest( - context: HttpRequestContext[F[Throwable, _], RequestCtx], - body: String, - method: IRTMethodId, + request: Request[F[Throwable, _]], + serviceName: String, + methodName: String, + )(body: String ): F[Throwable, Response[F[Throwable, _]]] = { - val ioR = for { - parsed <- F.fromEither(parse(body)) - maybeResult <- muxer.doInvoke(parsed, context.context, method) - } yield { - maybeResult - } - - ioR.sandboxExit.flatMap(handleHttpResult(context, method, _)) + val methodId = IRTMethodId(IRTServiceId(serviceName), IRTMethodName(methodName)) + (for { + authContext <- F.syncThrowable(httpContextExtractor.extract(request)) + parsedBody <- F.fromEither(io.circe.parser.parse(body)).leftMap(err => new IRTDecodingException(s"Can not parse JSON body '$body'.", Some(err))) + invokeRes <- serverMuxer.invokeMethod(methodId)(authContext, parsedBody) + } yield invokeRes).sandboxExit.flatMap(handleHttpResult(request, methodId)) } - private def handleHttpResult( - context: HttpRequestContext[F[Throwable, _], RequestCtx], + protected def handleHttpResult( + request: Request[F[Throwable, _]], method: IRTMethodId, - result: Exit[Throwable, Option[Json]], + )(result: Exit[Throwable, Json] ): F[Throwable, Response[F[Throwable, _]]] = { result match { - case Success(Some(value)) => - dsl.Ok(printer.print(value)) + case Success(res) => + Ok(printer.print(res)) - case Success(None) => - logger.warn(s"${context -> null}: No service handler for $method") *> - dsl.NotFound() + case Error(err: IRTMissingHandlerException, _) => + logger.warn(s"HTTP Request execution failed - no method handler for $method: $err") *> + NotFound() case Error(error: circe.Error, trace) => - logger.info(s"${context -> null}: Parsing failure while handling $method: $error $trace") *> - dsl.BadRequest() + logger.warn(s"HTTP Request execution failed - parsing failure while handling $method:\n${error.getMessage -> "error"}\n$trace") *> + BadRequest() case Error(error: IRTDecodingException, trace) => - logger.info(s"${context -> null}: Parsing failure while handling $method: $error $trace") *> - dsl.BadRequest() + logger.warn(s"HTTP Request execution failed - parsing failure while handling $method:\n$error\n$trace") *> + BadRequest() case Error(error: IRTLimitReachedException, trace) => - logger.debug(s"${context -> null}: Request failed because of request limit reached $method: $error $trace") *> - dsl.TooManyRequests() + logger.debug(s"HTTP Request failed - request limit reached $method:\n$error\n$trace") *> + TooManyRequests() case Error(error: IRTUnathorizedRequestContextException, trace) => - // workarount because implicits conflict - logger.debug(s"${context -> null}: Request failed because of unexpected request context reached $method: $error $trace") *> - dsl.Forbidden().map(_.copy(status = dsl.Unauthorized)) + logger.debug(s"HTTP Request failed - unauthorized $method call:\n$error\n$trace") *> + F.pure(Response(status = Status.Unauthorized)) case Error(error, trace) => - logger.info(s"${context -> null}: Unexpected failure while handling $method: $error $trace") *> - dsl.InternalServerError() + logger.warn(s"HTTP Request unexpectedly failed while handling $method:\n$error\n$trace") *> + InternalServerError() case Termination(_, (cause: IRTHttpFailureException) :: _, trace) => - logger.debug(s"${context -> null}: Request rejected, $method, ${context.request}, $cause, $trace") *> + logger.error(s"HTTP Request rejected - $method, $request:\n$cause\n$trace") *> F.pure(Response(status = cause.status)) case Termination(_, (cause: RejectedExecutionException) :: _, trace) => - logger.warn(s"${context -> null}: Not enough capacity to handle $method: $cause $trace") *> - dsl.TooManyRequests() + logger.warn(s"HTTP Request rejected - Not enough capacity to handle $method:\n$cause\n$trace") *> + TooManyRequests() case Termination(cause, _, trace) => - logger.error(s"${context -> null}: Execution failed, termination, $method, ${context.request}, $cause, $trace") *> - dsl.InternalServerError() + logger.error(s"HTTP Request execution failed, termination, $method, $request:\n$cause\n$trace") *> + InternalServerError() case Interruption(cause, _, trace) => - logger.info(s"${context -> null}: Unexpected interruption while handling $method: $cause $trace") *> - dsl.InternalServerError() + logger.error(s"HTTP Request unexpectedly interrupted while handling $method:\n$cause\n$trace") *> + InternalServerError() } } + protected def loggingMiddle(service: HttpRoutes[F[Throwable, _]]): HttpRoutes[F[Throwable, _]] = { + cats.data.Kleisli { + (req: Request[F[Throwable, _]]) => + OptionT.apply { + (for { + _ <- logger.trace(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: initiated") + resp <- service(req).value + _ <- F.traverse(resp) { + case Status.Successful(resp) => + logger.debug(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: success, ${resp.status.code -> "code"} ${resp.status.reason -> "reason"}") + case resp if resp.attributes.contains(wsAttributeKey) => + logger.debug(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: websocket request") + case resp => + logger.info(s"${req.method.name -> "method"} ${req.pathInfo -> "uri"}: rejection, ${resp.status.code -> "code"} ${resp.status.reason -> "reason"}") + } + } yield resp).tapError { + cause => + logger.error(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: failure, $cause") + } + } + } + } } object HttpServer { + val `X-Ws-Session-Id`: CIString = CIString("X-Ws-Session-Id") case object WsResponseMarker - class ServerWsRpcHandler[F[+_, +_]: IO2, RequestCtx, ClientId]( - muxer: IRTServerMultiplexor[F, RequestCtx], - wsContextProvider: WsContextProvider[F, RequestCtx, ClientId], - context: WsClientSession[F, RequestCtx, ClientId], + class ServerWsRpcHandler[F[+_, +_]: IO2, AuthCtx]( + clientSession: WsClientSession[F, AuthCtx], + muxer: IRTServerMultiplexor[F, AuthCtx], + wsContextExtractor: WsContextExtractor[AuthCtx], logger: LogIO2[F], - ) extends WsRpcHandler[F, RequestCtx](muxer, context, logger) { - override def handlePacket(packet: RpcPacket): F[Throwable, Unit] = { - wsContextProvider.toId(context.initialContext, context.id, packet).flatMap(context.updateId) - } - override def handleAuthRequest(packet: RpcPacket): F[Throwable, Option[RpcPacket]] = { - wsContextProvider.handleAuthorizationPacket(context.id, context.initialContext, packet).flatMap { - case WsAuthResult(id, packet) => context.updateId(id).as(Some(packet)) - } - } - override def extractContext(packet: RpcPacket): F[Throwable, RequestCtx] = { - wsContextProvider.toContext(context.id, context.initialContext, packet) + ) extends WsRpcHandler[F, AuthCtx](muxer, clientSession, logger) { + override protected def updateRequestCtx(packet: RpcPacket): F[Throwable, AuthCtx] = { + clientSession.updateRequestCtx(wsContextExtractor.extract(clientSession.sessionId, packet)) } } } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTAuthenticator.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTAuthenticator.scala new file mode 100644 index 00000000..4372ec46 --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTAuthenticator.scala @@ -0,0 +1,19 @@ +package izumi.idealingua.runtime.rpc.http4s + +import io.circe.Json +import izumi.functional.bio.{Applicative2, F} +import izumi.idealingua.runtime.rpc.IRTMethodId +import org.http4s.Headers + +import java.net.InetAddress + +abstract class IRTAuthenticator[F[_, _], AuthCtx, RequestCtx] { + def authenticate(authContext: AuthCtx, body: Option[Json], methodId: Option[IRTMethodId]): F[Nothing, Option[RequestCtx]] +} + +object IRTAuthenticator { + def unit[F[+_, +_]: Applicative2, C]: IRTAuthenticator[F, C, Unit] = new IRTAuthenticator[F, C, Unit] { + override def authenticate(authContext: C, body: Option[Json], methodId: Option[IRTMethodId]): F[Nothing, Option[Unit]] = F.pure(Some(())) + } + final case class AuthContext(headers: Headers, networkAddress: Option[InetAddress]) +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTContextServices.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTContextServices.scala new file mode 100644 index 00000000..91c0e7fa --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTContextServices.scala @@ -0,0 +1,68 @@ +package izumi.idealingua.runtime.rpc.http4s + +import izumi.functional.bio.{IO2, Monad2} +import izumi.idealingua.runtime.rpc.http4s.ws.WsContextSessions +import izumi.idealingua.runtime.rpc.{IRTServerMiddleware, IRTServerMultiplexor} +import izumi.reflect.Tag + +trait IRTContextServices[F[+_, +_], AuthCtx, RequestCtx, WsCtx] { + def name: String + def authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx] + def serverMuxer: IRTServerMultiplexor[F, RequestCtx] + def middlewares: Set[IRTServerMiddleware[F, RequestCtx]] + def wsSessions: WsContextSessions[F, RequestCtx, WsCtx] + + def authorizedMuxer(implicit io2: IO2[F]): IRTServerMultiplexor[F, AuthCtx] = { + val withMiddlewares: IRTServerMultiplexor[F, RequestCtx] = middlewares.toList.sortBy(_.priority).foldLeft(serverMuxer) { + case (muxer, middleware) => muxer.wrap(middleware) + } + val authorized: IRTServerMultiplexor[F, AuthCtx] = withMiddlewares.contramap { + case (authCtx, body, methodId) => authenticator.authenticate(authCtx, Some(body), Some(methodId)) + } + authorized + } + def authorizedWsSessions(implicit M: Monad2[F]): WsContextSessions[F, AuthCtx, WsCtx] = { + val authorized: WsContextSessions[F, AuthCtx, WsCtx] = wsSessions.contramap { + authCtx => + authenticator.authenticate(authCtx, None, None) + } + authorized + } +} + +object IRTContextServices { + type AnyContext[F[+_, +_], AuthCtx] = IRTContextServices[F, AuthCtx, ?, ?] + type AnyWsContext[F[+_, +_], AuthCtx, RequestCtx] = IRTContextServices[F, AuthCtx, RequestCtx, ?] + + def tagged[F[+_, +_], AuthCtx, RequestCtx: Tag, WsCtx: Tag]( + authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx], + serverMuxer: IRTServerMultiplexor[F, RequestCtx], + middlewares: Set[IRTServerMiddleware[F, RequestCtx]], + wsSessions: WsContextSessions[F, RequestCtx, WsCtx], + ): Tagged[F, AuthCtx, RequestCtx, WsCtx] = Tagged(authenticator, serverMuxer, middlewares, wsSessions) + + def named[F[+_, +_], AuthCtx, RequestCtx, WsCtx]( + name: String + )(authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx], + serverMuxer: IRTServerMultiplexor[F, RequestCtx], + middlewares: Set[IRTServerMiddleware[F, RequestCtx]], + wsSessions: WsContextSessions[F, RequestCtx, WsCtx], + ): Named[F, AuthCtx, RequestCtx, WsCtx] = Named(name, authenticator, serverMuxer, middlewares, wsSessions) + + final case class Named[F[+_, +_], AuthCtx, RequestCtx, WsCtx]( + name: String, + authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx], + serverMuxer: IRTServerMultiplexor[F, RequestCtx], + middlewares: Set[IRTServerMiddleware[F, RequestCtx]], + wsSessions: WsContextSessions[F, RequestCtx, WsCtx], + ) extends IRTContextServices[F, AuthCtx, RequestCtx, WsCtx] + + final case class Tagged[F[+_, +_], AuthCtx, RequestCtx: Tag, WsCtx: Tag]( + authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx], + serverMuxer: IRTServerMultiplexor[F, RequestCtx], + middlewares: Set[IRTServerMiddleware[F, RequestCtx]], + wsSessions: WsContextSessions[F, RequestCtx, WsCtx], + ) extends IRTContextServices[F, AuthCtx, RequestCtx, WsCtx] { + override def name: String = s"${Tag[RequestCtx].tag}:${Tag[WsCtx].tag}" + } +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTHttpFailureException.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTHttpFailureException.scala index f650abaa..b8cfd158 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTHttpFailureException.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/IRTHttpFailureException.scala @@ -7,8 +7,7 @@ abstract class IRTHttpFailureException( message: String, val status: Status, cause: Option[Throwable] = None, -) extends RuntimeException(message, cause.orNull) - with IRTTransportException +) extends IRTTransportException(message, cause) case class IRTUnexpectedHttpStatus(override val status: Status) extends IRTHttpFailureException(s"Unexpected http status: $status", status) case class IRTNoCredentialsException(override val status: Status) extends IRTHttpFailureException("No valid credentials", status) diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcher.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcher.scala index 56462146..223b0ccd 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcher.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcher.scala @@ -5,7 +5,7 @@ import izumi.functional.bio.{Async2, F, Temporal2} import izumi.idealingua.runtime.rpc.* import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcher.IRTDispatcherWs import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.WsRpcClientConnection -import izumi.idealingua.runtime.rpc.http4s.ws.RawResponse +import izumi.idealingua.runtime.rpc.http4s.ws.{RawResponse, WsSessionId} import logstage.LogIO2 import java.util.concurrent.TimeoutException @@ -18,6 +18,10 @@ class WsRpcDispatcher[F[+_, +_]: Async2]( logger: LogIO2[F], ) extends IRTDispatcherWs[F] { + override def sessionId: Option[WsSessionId] = { + connection.sessionId + } + override def authorize(headers: Map[String, String]): F[Throwable, Unit] = { connection.authorize(headers, timeout) } @@ -62,6 +66,7 @@ class WsRpcDispatcher[F[+_, +_]: Async2]( object WsRpcDispatcher { trait IRTDispatcherWs[F[_, _]] extends IRTDispatcher[F] { + def sessionId: Option[WsSessionId] def authorize(headers: Map[String, String]): F[Throwable, Unit] } } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcherFactory.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcherFactory.scala index 170a3350..4111d1fd 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcherFactory.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcherFactory.scala @@ -5,10 +5,13 @@ import io.circe.{Json, Printer} import izumi.functional.bio.{Async2, Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2} import izumi.functional.lifecycle.Lifecycle import izumi.fundamentals.platform.language.Quirks.Discarder +import izumi.fundamentals.platform.uuid.UUIDGen import izumi.idealingua.runtime.rpc.* +import izumi.idealingua.runtime.rpc.http4s.HttpServer import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcher.IRTDispatcherWs -import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.{ClientWsRpcHandler, WsRpcClientConnection, WsRpcContextProvider, fromNettyFuture} -import izumi.idealingua.runtime.rpc.http4s.ws.{RawResponse, WsRequestState, WsRpcHandler} +import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.{ClientWsRpcHandler, WsRpcClientConnection, fromNettyFuture} +import izumi.idealingua.runtime.rpc.http4s.context.WsContextExtractor +import izumi.idealingua.runtime.rpc.http4s.ws.{RawResponse, WsRequestState, WsRpcHandler, WsSessionId} import izumi.logstage.api.IzLogger import logstage.LogIO2 import org.asynchttpclient.netty.ws.NettyWebSocket @@ -16,9 +19,11 @@ import org.asynchttpclient.ws.{WebSocket, WebSocketListener, WebSocketUpgradeHan import org.asynchttpclient.{DefaultAsyncHttpClient, DefaultAsyncHttpClientConfig} import org.http4s.Uri +import java.util.UUID import java.util.concurrent.atomic.AtomicReference import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.jdk.CollectionConverters.* +import scala.util.Try class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRun2]( codec: IRTClientMultiplexor[F], @@ -29,32 +34,49 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu def connect[ServerContext]( uri: Uri, - muxer: IRTServerMultiplexor[F, ServerContext], - contextProvider: WsRpcContextProvider[ServerContext], + serverMuxer: IRTServerMultiplexor[F, ServerContext], + wsContextExtractor: WsContextExtractor[ServerContext], + headers: Map[String, String] = Map.empty, ): Lifecycle[F[Throwable, _], WsRpcClientConnection[F]] = { for { - client <- WsRpcDispatcherFactory.asyncHttpClient[F] - requestState <- Lifecycle.liftF(F.syncThrowable(WsRequestState.create[F])) - listener <- Lifecycle.liftF(F.syncThrowable(createListener(muxer, contextProvider, requestState, dispatcherLogger(uri, logger)))) - handler <- Lifecycle.liftF(F.syncThrowable(new WebSocketUpgradeHandler(List(listener).asJava))) + client <- createAsyncHttpClient() + wsRequestState <- Lifecycle.liftF(F.syncThrowable(WsRequestState.create[F])) + listener <- Lifecycle.liftF(F.syncThrowable(createListener(serverMuxer, wsRequestState, wsContextExtractor, dispatcherLogger(uri, logger)))) + handler <- Lifecycle.liftF(F.syncThrowable(new WebSocketUpgradeHandler(List(listener).asJava))) nettyWebSocket <- Lifecycle.make( - F.fromFutureJava(client.prepareGet(uri.toString()).execute(handler).toCompletableFuture) + F.fromFutureJava { + client + .prepareGet(uri.toString()) + .setSingleHeaders(headers.asJava) + .execute(handler).toCompletableFuture + } )(nettyWebSocket => fromNettyFuture(nettyWebSocket.sendCloseFrame()).void) + sessionId = Option(nettyWebSocket.getUpgradeHeaders.get(HttpServer.`X-Ws-Session-Id`.toString)) + .flatMap(str => Try(WsSessionId(UUID.fromString(str))).toOption) // fill promises before closing WS connection, potentially giving a chance to send out an error response before closing - _ <- Lifecycle.make(F.unit)(_ => requestState.clear()) + _ <- Lifecycle.make(F.unit)(_ => wsRequestState.clear()) } yield { - new WsRpcClientConnection.Netty(nettyWebSocket, requestState, printer) + new WsRpcClientConnection.Netty(nettyWebSocket, wsRequestState, printer, sessionId) } } + def connectSimple( + uri: Uri, + serverMuxer: IRTServerMultiplexor[F, Unit], + headers: Map[String, String] = Map.empty, + ): Lifecycle[F[Throwable, _], WsRpcClientConnection[F]] = { + connect(uri, serverMuxer, WsContextExtractor.unit, headers) + } + def dispatcher[ServerContext]( uri: Uri, - muxer: IRTServerMultiplexor[F, ServerContext], - contextProvider: WsRpcContextProvider[ServerContext], + serverMuxer: IRTServerMultiplexor[F, ServerContext], + wsContextExtractor: WsContextExtractor[ServerContext], + headers: Map[String, String] = Map.empty, tweakRequest: RpcPacket => RpcPacket = identity, timeout: FiniteDuration = 30.seconds, ): Lifecycle[F[Throwable, _], IRTDispatcherWs[F]] = { - connect(uri, muxer, contextProvider).map { + connect(uri, serverMuxer, wsContextExtractor, headers).map { new WsRpcDispatcher(_, timeout, codec, dispatcherLogger(uri, logger)) { override protected def buildRequest(rpcPacketId: RpcPacketId, method: IRTMethodId, body: Json): RpcPacket = { tweakRequest(super.buildRequest(rpcPacketId, method, body)) @@ -63,22 +85,32 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu } } + def dispatcherSimple( + uri: Uri, + serverMuxer: IRTServerMultiplexor[F, Unit], + headers: Map[String, String] = Map.empty, + tweakRequest: RpcPacket => RpcPacket = identity, + timeout: FiniteDuration = 30.seconds, + ): Lifecycle[F[Throwable, _], IRTDispatcherWs[F]] = { + dispatcher(uri, serverMuxer, WsContextExtractor.unit, headers, tweakRequest, timeout) + } + protected def wsHandler[ServerContext]( + serverMuxer: IRTServerMultiplexor[F, ServerContext], + wsRequestState: WsRequestState[F], + wsContextExtractor: WsContextExtractor[ServerContext], logger: LogIO2[F], - muxer: IRTServerMultiplexor[F, ServerContext], - contextProvider: WsRpcContextProvider[ServerContext], - requestState: WsRequestState[F], ): WsRpcHandler[F, ServerContext] = { - new ClientWsRpcHandler(muxer, requestState, contextProvider, logger) + new ClientWsRpcHandler(serverMuxer, wsRequestState, wsContextExtractor, logger) } protected def createListener[ServerContext]( - muxer: IRTServerMultiplexor[F, ServerContext], - contextProvider: WsRpcContextProvider[ServerContext], - requestState: WsRequestState[F], + serverMuxer: IRTServerMultiplexor[F, ServerContext], + wsRequestState: WsRequestState[F], + wsContextExtractor: WsContextExtractor[ServerContext], logger: LogIO2[F], ): WebSocketListener = new WebSocketListener() { - private val handler = wsHandler(logger, muxer, contextProvider, requestState) + private val handler = wsHandler(serverMuxer, wsRequestState, wsContextExtractor, logger) private val socketRef = new AtomicReference[Option[WebSocket]](None) override def onOpen(websocket: WebSocket): Unit = { @@ -102,7 +134,7 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int): Unit = { UnsafeRun2[F].unsafeRunAsync(handler.processRpcMessage(payload)) { exit => - val maybeResponse = exit match { + val maybeResponse: Option[RpcPacket] = exit match { case Exit.Success(response) => response case Exit.Error(error, _) => handleWsError(List(error), "errored") case Exit.Termination(error, _, _) => handleWsError(List(error), "terminated") @@ -133,10 +165,8 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu Some(RpcPacket.rpcCritical(message, None)) } } -} -object WsRpcDispatcherFactory { - def asyncHttpClient[F[+_, +_]: IO2]: Lifecycle[F[Throwable, _], DefaultAsyncHttpClient] = { + protected def createAsyncHttpClient(): Lifecycle[F[Throwable, _], DefaultAsyncHttpClient] = { Lifecycle.fromAutoCloseable(F.syncThrowable { new DefaultAsyncHttpClient( new DefaultAsyncHttpClientConfig.Builder() @@ -153,26 +183,30 @@ object WsRpcDispatcherFactory { ) }) } +} + +object WsRpcDispatcherFactory { - class ClientWsRpcHandler[F[+_, +_]: IO2, ServerCtx]( - muxer: IRTServerMultiplexor[F, ServerCtx], + class ClientWsRpcHandler[F[+_, +_]: IO2, RequestCtx]( + muxer: IRTServerMultiplexor[F, RequestCtx], requestState: WsRequestState[F], - contextProvider: WsRpcContextProvider[ServerCtx], + wsContextExtractor: WsContextExtractor[RequestCtx], logger: LogIO2[F], - ) extends WsRpcHandler[F, ServerCtx](muxer, requestState, logger) { - override def handlePacket(packet: RpcPacket): F[Throwable, Unit] = { - F.unit - } - override def handleAuthRequest(packet: RpcPacket): F[Throwable, Option[RpcPacket]] = { - F.pure(None) - } - override def extractContext(packet: RpcPacket): F[Throwable, ServerCtx] = { - F.sync(contextProvider.toContext(packet)) + ) extends WsRpcHandler[F, RequestCtx](muxer, requestState, logger) { + private val wsSessionId: WsSessionId = WsSessionId(UUIDGen.getTimeUUID()) + private val requestCtxRef: AtomicReference[RequestCtx] = new AtomicReference() + override protected def updateRequestCtx(packet: RpcPacket): F[Throwable, RequestCtx] = F.sync { + val updated = wsContextExtractor.extract(wsSessionId, packet) + requestCtxRef.updateAndGet { + case null => updated + case previous => wsContextExtractor.merge(previous, updated) + } } } trait WsRpcClientConnection[F[_, _]] { private[clients] def requestAndAwait(id: RpcPacketId, packet: RpcPacket, method: Option[IRTMethodId], timeout: FiniteDuration): F[Throwable, Option[RawResponse]] + def sessionId: Option[WsSessionId] def authorize(headers: Map[String, String], timeout: FiniteDuration = 30.seconds): F[Throwable, Unit] } object WsRpcClientConnection { @@ -180,6 +214,7 @@ object WsRpcDispatcherFactory { nettyWebSocket: NettyWebSocket, requestState: WsRequestState[F], printer: Printer, + val sessionId: Option[WsSessionId], ) extends WsRpcClientConnection[F] { override def authorize(headers: Map[String, String], timeout: FiniteDuration): F[Throwable, Unit] = { @@ -205,13 +240,6 @@ object WsRpcDispatcherFactory { } } - trait WsRpcContextProvider[Ctx] { - def toContext(packet: RpcPacket): Ctx - } - object WsRpcContextProvider { - def unit: WsRpcContextProvider[Unit] = _ => () - } - private def fromNettyFuture[F[+_, +_]: Async2, A](mkNettyFuture: => io.netty.util.concurrent.Future[A]): F[Throwable, A] = { F.syncThrowable(mkNettyFuture).flatMap { nettyFuture => diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/context/HttpContextExtractor.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/context/HttpContextExtractor.scala new file mode 100644 index 00000000..bf3373f8 --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/context/HttpContextExtractor.scala @@ -0,0 +1,22 @@ +package izumi.idealingua.runtime.rpc.http4s.context + +import izumi.idealingua.runtime.rpc.http4s.IRTAuthenticator.AuthContext +import org.http4s.Request +import org.http4s.headers.`X-Forwarded-For` + +trait HttpContextExtractor[RequestCtx] { + def extract[F[_, _]](request: Request[F[Throwable, _]]): RequestCtx +} + +object HttpContextExtractor { + def authContext: HttpContextExtractor[AuthContext] = new HttpContextExtractor[AuthContext] { + override def extract[F[_, _]](request: Request[F[Throwable, _]]): AuthContext = { + val networkAddress = request.headers + .get[`X-Forwarded-For`] + .flatMap(_.values.head.map(_.toInetAddress)) + .orElse(request.remote.map(_.host.toInetAddress)) + val headers = request.headers + AuthContext(headers, networkAddress) + } + } +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/context/WsContextExtractor.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/context/WsContextExtractor.scala new file mode 100644 index 00000000..5e7f233d --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/context/WsContextExtractor.scala @@ -0,0 +1,30 @@ +package izumi.idealingua.runtime.rpc.http4s.context + +import izumi.idealingua.runtime.rpc.RpcPacket +import izumi.idealingua.runtime.rpc.http4s.IRTAuthenticator.AuthContext +import izumi.idealingua.runtime.rpc.http4s.ws.WsSessionId +import org.http4s.{Header, Headers} +import org.typelevel.ci.CIString + +trait WsContextExtractor[RequestCtx] { + def extract(wsSessionId: WsSessionId, packet: RpcPacket): RequestCtx + def merge(previous: RequestCtx, updated: RequestCtx): RequestCtx +} + +object WsContextExtractor { + def unit: WsContextExtractor[Unit] = new WsContextExtractor[Unit] { + override def extract(wsSessionId: WsSessionId, packet: RpcPacket): Unit = () + override def merge(previous: Unit, updated: Unit): Unit = () + } + def authContext: WsContextExtractor[AuthContext] = new WsContextExtractor[AuthContext] { + override def extract(wsSessionId: WsSessionId, packet: RpcPacket): AuthContext = { + val headersMap = packet.headers.getOrElse(Map.empty) + val headers = Headers.apply(headersMap.toSeq.map { case (k, v) => Header.Raw(CIString(k), v) }) + AuthContext(headers, None) + } + + override def merge(previous: AuthContext, updated: AuthContext): AuthContext = { + AuthContext(previous.headers ++ updated.headers, updated.networkAddress.orElse(previous.networkAddress)) + } + } +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/context/WsIdExtractor.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/context/WsIdExtractor.scala new file mode 100644 index 00000000..364d1feb --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/context/WsIdExtractor.scala @@ -0,0 +1,11 @@ +package izumi.idealingua.runtime.rpc.http4s.context + +trait WsIdExtractor[RequestCtx, WsCtx] { + def extract(ctx: RequestCtx, previous: Option[WsCtx]): Option[WsCtx] +} + +object WsIdExtractor { + def id[C]: WsIdExtractor[C, C] = (c, _) => Some(c) + def widen[C, C0 >: C]: WsIdExtractor[C, C0] = (c, _) => Some(c) + def unit[C]: WsIdExtractor[C, Unit] = (_, _) => Some(()) +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientId.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientId.scala deleted file mode 100644 index a4b1582a..00000000 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientId.scala +++ /dev/null @@ -1,9 +0,0 @@ -package izumi.idealingua.runtime.rpc.http4s.ws - -import java.util.UUID - -case class WsSessionId(sessionId: UUID) extends AnyVal - -case class WsClientId[ClientId](sessionId: WsSessionId, id: Option[ClientId]) { - override def toString: String = s"${id.getOrElse("?")} / ${sessionId.sessionId}" -} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala index 636c41ef..addb1b33 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala @@ -1,13 +1,15 @@ package izumi.idealingua.runtime.rpc.http4s.ws import cats.effect.std.Queue -import io.circe.syntax.* +import io.circe.syntax.EncoderOps import io.circe.{Json, Printer} -import izumi.functional.bio.{F, IO2, Primitives2, Temporal2} +import izumi.functional.bio.{Applicative2, F, IO2, Primitives2, Temporal2} import izumi.fundamentals.platform.time.IzTime import izumi.fundamentals.platform.uuid.UUIDGen import izumi.idealingua.runtime.rpc.* -import izumi.idealingua.runtime.rpc.http4s.ws.WsRpcHandler.WsClientResponder +import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.ClientWsRpcHandler +import izumi.idealingua.runtime.rpc.http4s.context.WsContextExtractor +import izumi.idealingua.runtime.rpc.http4s.ws.WsRpcHandler.WsResponder import logstage.LogIO2 import org.http4s.websocket.WebSocketFrame import org.http4s.websocket.WebSocketFrame.Text @@ -17,43 +19,71 @@ import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference import scala.concurrent.duration.* -trait WsClientSession[F[+_, +_], RequestCtx, ClientId] extends WsClientResponder[F] { - def id: WsClientId[ClientId] - def initialContext: RequestCtx - - def updateId(maybeNewId: Option[ClientId]): F[Throwable, Unit] - def outQueue: Queue[F[Throwable, _], WebSocketFrame] +trait WsClientSession[F[+_, +_], SessionCtx] extends WsResponder[F] { + def sessionId: WsSessionId def requestAndAwaitResponse(method: IRTMethodId, data: Json, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] - def finish(): F[Throwable, Unit] + def updateRequestCtx(newContext: SessionCtx): F[Throwable, SessionCtx] + + def start(onStart: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit] + def finish(onFinish: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit] } object WsClientSession { - class WsClientSessionImpl[F[+_, +_]: IO2: Temporal2: Primitives2, RequestCtx, ClientId]( - val outQueue: Queue[F[Throwable, _], WebSocketFrame], - val initialContext: RequestCtx, - listeners: Seq[WsSessionListener[F, ClientId]], - wsSessionStorage: WsSessionsStorage[F, RequestCtx, ClientId], - printer: Printer, + + def empty[F[+_, +_]: Applicative2, Ctx](wsSessionId: WsSessionId): WsClientSession[F, Ctx] = new WsClientSession[F, Ctx] { + override def sessionId: WsSessionId = wsSessionId + override def requestAndAwaitResponse(method: IRTMethodId, data: Json, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] = F.pure(None) + override def updateRequestCtx(newContext: Ctx): F[Throwable, Ctx] = F.pure(newContext) + override def start(onStart: Ctx => F[Throwable, Unit]): F[Throwable, Unit] = F.unit + override def finish(onFinish: Ctx => F[Throwable, Unit]): F[Throwable, Unit] = F.unit + override def responseWith(id: RpcPacketId, response: RawResponse): F[Throwable, Unit] = F.unit + override def responseWithData(id: RpcPacketId, data: Json): F[Throwable, Unit] = F.unit + } + + abstract class Base[F[+_, +_]: IO2: Temporal2: Primitives2, SessionCtx]( + initialContext: SessionCtx, + wsSessionsContext: Set[WsContextSessions.AnyContext[F, SessionCtx]], + wsSessionStorage: WsSessionsStorage[F, SessionCtx], + wsContextExtractor: WsContextExtractor[SessionCtx], logger: LogIO2[F], - ) extends WsClientSession[F, RequestCtx, ClientId] { - private val openingTime: ZonedDateTime = IzTime.utcNow - private val sessionId: WsSessionId = WsSessionId(UUIDGen.getTimeUUID()) - private val clientId: AtomicReference[Option[ClientId]] = new AtomicReference[Option[ClientId]](None) - private val requestState: WsRequestState[F] = WsRequestState.create[F] + ) extends WsClientSession[F, SessionCtx] { + private val requestCtxRef = new AtomicReference[SessionCtx](initialContext) + private val openingTime: ZonedDateTime = IzTime.utcNow + + protected val requestState: WsRequestState[F] = WsRequestState.create[F] + protected def sendMessage(message: RpcPacket): F[Throwable, Unit] + protected def sendCloseMessage(): F[Throwable, Unit] - def id: WsClientId[ClientId] = WsClientId(sessionId, clientId.get()) + override val sessionId: WsSessionId = WsSessionId(UUIDGen.getTimeUUID()) + + override def updateRequestCtx(newContext: SessionCtx): F[Throwable, SessionCtx] = { + for { + contexts <- F.sync { + requestCtxRef.synchronized { + val oldContext = requestCtxRef.get() + val updatedContext = requestCtxRef.updateAndGet { + old => + wsContextExtractor.merge(old, newContext) + } + oldContext -> updatedContext + } + } + (oldContext, updatedContext) = contexts + _ <- F.when(oldContext != updatedContext) { + F.traverse_(wsSessionsContext)(_.updateSession(sessionId, Some(updatedContext))) + } + } yield updatedContext + } def requestAndAwaitResponse(method: IRTMethodId, data: Json, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] = { val id = RpcPacketId.random() val request = RpcPacket.buzzerRequest(id, method, data) for { - _ <- logger.debug(s"WS Session: enqueue $request with $id to request state & send queue.") - response <- requestState.requestAndAwait(id, Some(method), timeout) { - outQueue.offer(Text(printer.print(request.asJson))) - } - _ <- logger.debug(s"WS Session: $method, ${id -> "id"}: cleaning request state.") + _ <- logger.debug(s"WS Session: enqueue $request with $id to request state & send queue.") + response <- requestState.requestAndAwait(id, Some(method), timeout)(sendMessage(request)) + _ <- logger.debug(s"WS Session: $method, ${id -> "id"}: cleaning request state.") } yield response } @@ -65,28 +95,23 @@ object WsClientSession { requestState.responseWithData(id, data) } - override def updateId(maybeNewId: Option[ClientId]): F[Throwable, Unit] = { - for { - old <- F.sync(id) - _ <- F.sync(clientId.set(maybeNewId)) - current <- F.sync(id) - _ <- F.when(old != current)(F.traverse_(listeners)(_.onClientIdUpdate(current, old))) - } yield () - } - - override def finish(): F[Throwable, Unit] = { - F.fromEither(WebSocketFrame.Close(1000)).flatMap(outQueue.offer(_)) *> - wsSessionStorage.deleteClient(sessionId) *> - F.traverse_(listeners)(_.onSessionClosed(id)) *> - requestState.clear() + override def finish(onFinish: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit] = { + val requestCtx = requestCtxRef.get() + sendCloseMessage() *> + requestState.clear() *> + wsSessionStorage.deleteSession(sessionId) *> + F.traverse_(wsSessionsContext)(_.updateSession(sessionId, None)) *> + onFinish(requestCtx) } - protected[http4s] def start(): F[Throwable, Unit] = { - wsSessionStorage.addClient(this) *> - F.traverse_(listeners)(_.onSessionOpened(id)) + override def start(onStart: SessionCtx => F[Throwable, Unit]): F[Throwable, Unit] = { + val requestCtx = requestCtxRef.get() + wsSessionStorage.addSession(this) *> + F.traverse_(wsSessionsContext)(_.updateSession(sessionId, Some(requestCtx))) *> + onStart(requestCtx) } - override def toString: String = s"[${id.toString}, ${duration().toSeconds}s]" + override def toString: String = s"[$sessionId, ${duration().toSeconds}s]" private[this] def duration(): FiniteDuration = { val now = IzTime.utcNow @@ -94,4 +119,42 @@ object WsClientSession { FiniteDuration(d.toNanos, TimeUnit.NANOSECONDS) } } + + final class Dummy[F[+_, +_]: IO2: Temporal2: Primitives2, SessionCtx]( + initialContext: SessionCtx, + muxer: IRTServerMultiplexor[F, Unit], + wsSessionsContext: Set[WsContextSessions.AnyContext[F, SessionCtx]], + wsSessionStorage: WsSessionsStorage[F, SessionCtx], + wsContextExtractor: WsContextExtractor[SessionCtx], + logger: LogIO2[F], + ) extends Base[F, SessionCtx](initialContext, wsSessionsContext, wsSessionStorage, wsContextExtractor, logger) { + private val clientHandler = new ClientWsRpcHandler(muxer, requestState, WsContextExtractor.unit, logger) + override protected def sendMessage(message: RpcPacket): F[Throwable, Unit] = { + clientHandler.processRpcPacket(message).flatMap { + case Some(RpcPacket(_, Some(json), None, Some(ref), _, _, _)) => + // discard any errors here (it's only possible to fail if the packet reference is missing) + requestState.responseWithData(ref, json).attempt.void + case _ => + F.unit + } + } + override protected def sendCloseMessage(): F[Throwable, Unit] = F.unit + } + + final class Queued[F[+_, +_]: IO2: Temporal2: Primitives2, SessionCtx]( + outQueue: Queue[F[Throwable, _], WebSocketFrame], + initialContext: SessionCtx, + wsSessionsContext: Set[WsContextSessions.AnyContext[F, SessionCtx]], + wsSessionStorage: WsSessionsStorage[F, SessionCtx], + wsContextExtractor: WsContextExtractor[SessionCtx], + logger: LogIO2[F], + printer: Printer, + ) extends Base[F, SessionCtx](initialContext, wsSessionsContext, wsSessionStorage, wsContextExtractor, logger) { + override protected def sendMessage(message: RpcPacket): F[Throwable, Unit] = { + outQueue.offer(Text(printer.print(message.asJson))) + } + override protected def sendCloseMessage(): F[Throwable, Unit] = { + F.fromEither(WebSocketFrame.Close(1000)).flatMap(outQueue.offer(_)) + } + } } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsContextProvider.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsContextProvider.scala deleted file mode 100644 index 92a1e2da..00000000 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsContextProvider.scala +++ /dev/null @@ -1,41 +0,0 @@ -package izumi.idealingua.runtime.rpc.http4s.ws - -import izumi.functional.bio.{Applicative2, F} -import izumi.fundamentals.platform.language.Quirks -import izumi.idealingua.runtime.rpc.http4s.ws.WsContextProvider.WsAuthResult -import izumi.idealingua.runtime.rpc.{RPCPacketKind, RpcPacket} - -trait WsContextProvider[F[+_, +_], RequestCtx, ClientId] { - def toContext(id: WsClientId[ClientId], initial: RequestCtx, packet: RpcPacket): F[Throwable, RequestCtx] - - def toId(initial: RequestCtx, currentId: WsClientId[ClientId], packet: RpcPacket): F[Throwable, Option[ClientId]] - - // TODO: we use this to mangle with authorization but it's dirty - def handleAuthorizationPacket(id: WsClientId[ClientId], initial: RequestCtx, packet: RpcPacket): F[Throwable, WsAuthResult[ClientId]] -} - -object WsContextProvider { - - final case class WsAuthResult[ClientId](client: Option[ClientId], response: RpcPacket) - - class IdContextProvider[F[+_, +_]: Applicative2, RequestCtx, ClientId] extends WsContextProvider[F, RequestCtx, ClientId] { - override def handleAuthorizationPacket( - id: WsClientId[ClientId], - initial: RequestCtx, - packet: RpcPacket, - ): F[Throwable, WsAuthResult[ClientId]] = { - val res = RpcPacket(RPCPacketKind.RpcResponse, None, None, packet.id, None, None, None) - F.pure(WsAuthResult[ClientId](None, res)) - } - - override def toContext(id: WsClientId[ClientId], initial: RequestCtx, packet: RpcPacket): F[Throwable, RequestCtx] = { - Quirks.discard(packet, id) - F.pure(initial) - } - - override def toId(initial: RequestCtx, currentId: WsClientId[ClientId], packet: RpcPacket): F[Throwable, Option[ClientId]] = { - Quirks.discard(initial, packet) - F.pure(None) - } - } -} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsContextSessions.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsContextSessions.scala new file mode 100644 index 00000000..8d903e53 --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsContextSessions.scala @@ -0,0 +1,55 @@ +package izumi.idealingua.runtime.rpc.http4s.ws + +import izumi.functional.bio.{Applicative2, F, IO2, Monad2} +import izumi.idealingua.runtime.rpc.http4s.context.WsIdExtractor + +trait WsContextSessions[F[+_, +_], RequestCtx, WsCtx] { + self => + /** Updates session context if can be updated and sends callbacks to Set[WsSessionListener] if context were changed. */ + def updateSession(wsSessionId: WsSessionId, requestContext: Option[RequestCtx]): F[Throwable, Unit] + /** Contramap with F over context type. Useful for authentications and context extensions. */ + final def contramap[C](updateCtx: C => F[Throwable, Option[RequestCtx]])(implicit M: Monad2[F]): WsContextSessions[F, C, WsCtx] = new WsContextSessions[F, C, WsCtx] { + override def updateSession(wsSessionId: WsSessionId, requestContext: Option[C]): F[Throwable, Unit] = { + F.traverse(requestContext)(updateCtx).flatMap(mbCtx => self.updateSession(wsSessionId, mbCtx.flatten)) + } + } +} + +object WsContextSessions { + type AnyContext[F[+_, +_], RequestCtx] = WsContextSessions[F, RequestCtx, ?] + + def unit[F[+_, +_]: Applicative2, RequestCtx]: WsContextSessions[F, RequestCtx, Unit] = new Empty + + final class Empty[F[+_, +_]: Applicative2, RequestCtx, WsCtx] extends WsContextSessions[F, RequestCtx, WsCtx] { + override def updateSession(wsSessionId: WsSessionId, requestContext: Option[RequestCtx]): F[Throwable, Unit] = F.unit + } + + class WsContextSessionsImpl[F[+_, +_]: IO2, RequestCtx, WsCtx]( + wsContextStorage: WsContextStorage[F, WsCtx], + globalWsListeners: Set[WsSessionListener.Global[F]], + wsSessionListeners: Set[WsSessionListener[F, RequestCtx, WsCtx]], + wsIdExtractor: WsIdExtractor[RequestCtx, WsCtx], + ) extends WsContextSessions[F, RequestCtx, WsCtx] { + override def updateSession(wsSessionId: WsSessionId, requestContext: Option[RequestCtx]): F[Throwable, Unit] = { + for { + ctxUpdate <- wsContextStorage.updateContext(wsSessionId) { + mbPrevCtx => + requestContext.flatMap(wsIdExtractor.extract(_, mbPrevCtx)) + } + _ <- (requestContext, ctxUpdate.previous, ctxUpdate.updated) match { + case (Some(ctx), Some(previous), Some(updated)) if previous != updated => + F.traverse_(wsSessionListeners)(_.onSessionUpdated(wsSessionId, ctx, previous, updated)) *> + F.traverse_(globalWsListeners)(_.onSessionUpdated(wsSessionId, ctx, previous, updated)) + case (Some(ctx), None, Some(updated)) => + F.traverse_(wsSessionListeners)(_.onSessionOpened(wsSessionId, ctx, updated)) *> + F.traverse_(globalWsListeners)(_.onSessionOpened(wsSessionId, ctx, updated)) + case (_, Some(prev), None) => + F.traverse_(wsSessionListeners)(_.onSessionClosed(wsSessionId, prev)) *> + F.traverse_(globalWsListeners)(_.onSessionClosed(wsSessionId, prev)) + case _ => + F.unit + } + } yield () + } + } +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsContextStorage.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsContextStorage.scala new file mode 100644 index 00000000..ef99bacf --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsContextStorage.scala @@ -0,0 +1,112 @@ +package izumi.idealingua.runtime.rpc.http4s.ws + +import izumi.functional.bio.{F, IO2} +import izumi.idealingua.runtime.rpc.http4s.ws.WsContextStorage.{WsContextSessionId, WsCtxUpdate} +import izumi.idealingua.runtime.rpc.{IRTClientMultiplexor, IRTDispatcher} + +import java.util.concurrent.ConcurrentHashMap +import scala.concurrent.duration.{DurationInt, FiniteDuration} +import scala.jdk.CollectionConverters.* + +/** Sessions storage based on WS context. + * Supports [one session - one context] and [one context - many sessions mapping] + * It is possible to support [one sessions - many contexts] mapping (for generic context storages), + * but in such case we would able to choose one from many to update session context data. + */ +trait WsContextStorage[F[+_, +_], WsCtx] { + def getContext(wsSessionId: WsSessionId): F[Throwable, Option[WsCtx]] + def allSessions(): F[Throwable, Set[WsContextSessionId[WsCtx]]] + /** Updates session context using [updateCtx] function (maybeOldContext => maybeNewContext) */ + def updateContext(wsSessionId: WsSessionId)(updateCtx: Option[WsCtx] => Option[WsCtx]): F[Throwable, WsCtxUpdate[WsCtx]] + + def getSessions(ctx: WsCtx): F[Throwable, List[WsClientSession[F, ?]]] + def dispatchersFor(ctx: WsCtx, codec: IRTClientMultiplexor[F], timeout: FiniteDuration = 20.seconds): F[Throwable, List[IRTDispatcher[F]]] +} + +object WsContextStorage { + final case class WsCtxUpdate[WsCtx](previous: Option[WsCtx], updated: Option[WsCtx]) + final case class WsContextSessionId[WsCtx](sessionId: WsSessionId, ctx: WsCtx) + + class WsContextStorageImpl[F[+_, +_]: IO2, WsCtx]( + wsSessionsStorage: WsSessionsStorage[F, ?] + ) extends WsContextStorage[F, WsCtx] { + private[this] val sessionToId = new ConcurrentHashMap[WsSessionId, WsCtx]() + private[this] val idToSessions = new ConcurrentHashMap[WsCtx, Set[WsSessionId]]() + + override def allSessions(): F[Throwable, Set[WsContextSessionId[WsCtx]]] = F.sync { + sessionToId.asScala.map { case (s, c) => WsContextSessionId(s, c) }.toSet + } + + override def getContext(wsSessionId: WsSessionId): F[Throwable, Option[WsCtx]] = F.sync { + Option(sessionToId.get(wsSessionId)) + } + + override def updateContext(wsSessionId: WsSessionId)(updateCtx: Option[WsCtx] => Option[WsCtx]): F[Nothing, WsCtxUpdate[WsCtx]] = { + updateCtxImpl(wsSessionId)(updateCtx) + } + + override def getSessions(ctx: WsCtx): F[Throwable, List[WsClientSession[F, ?]]] = { + F.sync(synchronized(Option(idToSessions.get(ctx)).getOrElse(Set.empty).toList)).flatMap { + sessions => + F.traverse[Throwable, WsSessionId, Option[WsClientSession[F, ?]]](sessions) { + wsSessionId => wsSessionsStorage.getSession(wsSessionId) + }.map(_.flatten) + } + } + + override def dispatchersFor(ctx: WsCtx, codec: IRTClientMultiplexor[F], timeout: FiniteDuration): F[Throwable, List[IRTDispatcher[F]]] = { + F.sync(synchronized(Option(idToSessions.get(ctx)).getOrElse(Set.empty)).toList).flatMap { + sessions => + F.traverse[Throwable, WsSessionId, Option[IRTDispatcher[F]]](sessions) { + wsSessionId => wsSessionsStorage.dispatcherForSession(wsSessionId, codec, timeout) + }.map(_.flatten) + } + } + + @inline private final def updateCtxImpl( + wsSessionId: WsSessionId + )(updateCtx: Option[WsCtx] => Option[WsCtx] + ): F[Nothing, WsCtxUpdate[WsCtx]] = F.sync { + synchronized { + val mbPrevCtx = Option(sessionToId.get(wsSessionId)) + val mbNewCtx = updateCtx(mbPrevCtx) + (mbNewCtx, mbPrevCtx) match { + case (Some(updCtx), mbPrevCtx) => + mbPrevCtx.foreach(removeSessionFromCtx(_, wsSessionId)) + sessionToId.put(wsSessionId, updCtx) + addSessionToCtx(updCtx, wsSessionId) + () + case (None, Some(prevCtx)) => + sessionToId.remove(wsSessionId) + removeSessionFromCtx(prevCtx, wsSessionId) + () + case _ => + () + } + WsCtxUpdate(mbPrevCtx, mbNewCtx) + } + } + + @inline private final def addSessionToCtx(wsCtx: WsCtx, wsSessionId: WsSessionId): Unit = { + idToSessions.compute( + wsCtx, + { + case (_, null) => Set(wsSessionId) + case (_, s) => s + wsSessionId + }, + ) + () + } + + @inline private final def removeSessionFromCtx(wsCtx: WsCtx, wsSessionId: WsSessionId): Unit = { + idToSessions.compute( + wsCtx, + { + case (_, null) => null + case (_, s) => Option(s - wsSessionId).filter(_.nonEmpty).orNull + }, + ) + () + } + } +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRequestState.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRequestState.scala index a094cf4e..39142463 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRequestState.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRequestState.scala @@ -5,7 +5,7 @@ import izumi.functional.bio.{Clock1, Clock2, F, IO2, Primitives2, Promise2, Temp import izumi.fundamentals.platform.language.Quirks.* import izumi.idealingua.runtime.rpc.* import izumi.idealingua.runtime.rpc.http4s.ws.RawResponse.BadRawResponse -import izumi.idealingua.runtime.rpc.http4s.ws.WsRpcHandler.WsClientResponder +import izumi.idealingua.runtime.rpc.http4s.ws.WsRpcHandler.WsResponder import java.time.OffsetDateTime import java.time.temporal.ChronoUnit @@ -14,7 +14,7 @@ import scala.collection.mutable import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters.* -trait WsRequestState[F[_, _]] extends WsClientResponder[F] { +trait WsRequestState[F[_, _]] extends WsResponder[F] { def requestAndAwait[A]( id: RpcPacketId, methodId: Option[IRTMethodId], diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRpcHandler.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRpcHandler.scala index 90c1c8d2..100d64f5 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRpcHandler.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRpcHandler.scala @@ -5,39 +5,43 @@ import izumi.functional.bio.Exit.Success import izumi.functional.bio.{Exit, F, IO2} import izumi.fundamentals.platform.language.Quirks.Discarder import izumi.idealingua.runtime.rpc.* -import izumi.idealingua.runtime.rpc.http4s.ws.WsRpcHandler.WsClientResponder +import izumi.idealingua.runtime.rpc.http4s.ws.WsRpcHandler.WsResponder import logstage.LogIO2 abstract class WsRpcHandler[F[+_, +_]: IO2, RequestCtx]( muxer: IRTServerMultiplexor[F, RequestCtx], - responder: WsClientResponder[F], + responder: WsResponder[F], logger: LogIO2[F], ) { - protected def handlePacket(packet: RpcPacket): F[Throwable, Unit] - protected def handleAuthRequest(packet: RpcPacket): F[Throwable, Option[RpcPacket]] - protected def extractContext(packet: RpcPacket): F[Throwable, RequestCtx] + /** Update context based on RpcPacket (or extract). + * Called on each RpcPacket messages before packet handling + */ + protected def updateRequestCtx(packet: RpcPacket): F[Throwable, RequestCtx] - protected def handleAuthResponse(ref: RpcPacketId, packet: RpcPacket): F[Throwable, Option[RpcPacket]] = { - packet.discard() - responder.responseWith(ref, RawResponse.EmptyRawResponse()).as(None) + def processRpcMessage(message: String): F[Throwable, Option[RpcPacket]] = { + for { + packet <- F + .fromEither(io.circe.parser.decode[RpcPacket](message)) + .leftMap(err => new IRTDecodingException(s"Can not decode Rpc Packet '$message'.\nError: $err.")) + response <- processRpcPacket(packet) + } yield response } - def processRpcMessage(message: String): F[Throwable, Option[RpcPacket]] = { + def processRpcPacket(packet: RpcPacket): F[Throwable, Option[RpcPacket]] = { for { - packet <- F.fromEither(io.circe.parser.decode[RpcPacket](message)) - _ <- handlePacket(packet) + requestCtx <- updateRequestCtx(packet) response <- packet match { // auth case RpcPacket(RPCPacketKind.RpcRequest, None, _, _, _, _, _) => - handleAuthRequest(packet) + handleAuthRequest(requestCtx, packet) case RpcPacket(RPCPacketKind.RpcResponse, None, _, Some(ref), _, _, _) => handleAuthResponse(ref, packet) // rpc case RpcPacket(RPCPacketKind.RpcRequest, Some(data), Some(id), _, Some(service), Some(method), _) => - handleWsRequest(packet, data, IRTMethodId(IRTServiceId(service), IRTMethodName(method)))( + handleWsRequest(IRTMethodId(IRTServiceId(service), IRTMethodName(method)), requestCtx, data)( onSuccess = RpcPacket.rpcResponse(id, _), onFail = RpcPacket.rpcFail(Some(id), _), ) @@ -50,7 +54,7 @@ abstract class WsRpcHandler[F[+_, +_]: IO2, RequestCtx]( // buzzer case RpcPacket(RPCPacketKind.BuzzRequest, Some(data), Some(id), _, Some(service), Some(method), _) => - handleWsRequest(packet, data, IRTMethodId(IRTServiceId(service), IRTMethodName(method)))( + handleWsRequest(IRTMethodId(IRTServiceId(service), IRTMethodName(method)), requestCtx, data)( onSuccess = RpcPacket.buzzerResponse(id, _), onFail = RpcPacket.buzzerFail(Some(id), _), ) @@ -78,36 +82,61 @@ abstract class WsRpcHandler[F[+_, +_]: IO2, RequestCtx]( } protected def handleWsRequest( - input: RpcPacket, - data: Json, methodId: IRTMethodId, + requestCtx: RequestCtx, + data: Json, )(onSuccess: Json => RpcPacket, onFail: String => RpcPacket, ): F[Throwable, Option[RpcPacket]] = { - for { - userCtx <- extractContext(input) - res <- muxer.doInvoke(data, userCtx, methodId).sandboxExit.flatMap { - case Success(Some(res)) => - F.pure(Some(onSuccess(res))) - - case Success(None) => - logger.error(s"WS request errored: No rpc handler for $methodId").as(Some(onFail("No rpc handler."))) - - case Exit.Termination(exception, allExceptions, trace) => - logger.error(s"WS request terminated, $exception, $allExceptions, $trace").as(Some(onFail(exception.getMessage))) + muxer.invokeMethod(methodId)(requestCtx, data).sandboxExit.flatMap { + case Success(res) => + F.pure(Some(onSuccess(res))) + + case Exit.Error(error: IRTMissingHandlerException, trace) => + logger + .error(s"WS Request failed - no method handler for $methodId:\n$error\n$trace") + .as(Some(onFail("Not Found."))) + + case Exit.Error(error: IRTUnathorizedRequestContextException, trace) => + logger + .warn(s"WS Request failed - unauthorized $methodId call:\n$error\n$trace") + .as(Some(onFail("Unauthorized."))) + + case Exit.Error(error: IRTDecodingException, trace) => + logger + .warn(s"WS Request failed - decoding failed:\n$error\n$trace") + .as(Some(onFail("BadRequest."))) + + case Exit.Termination(exception, allExceptions, trace) => + logger + .error(s"WS Request terminated:\n$exception\n$allExceptions\n$trace") + .as(Some(onFail(exception.getMessage))) + + case Exit.Error(exception, trace) => + logger + .error(s"WS Request unexpectedly failed:\n$exception\n$trace") + .as(Some(onFail(exception.getMessage))) + + case Exit.Interruption(exception, allExceptions, trace) => + logger + .error(s"WS Request unexpectedly interrupted:\n$exception\n$allExceptions\n$trace") + .as(Some(onFail(exception.getMessage))) + } + } - case Exit.Error(exception, trace) => - logger.error(s"WS request failed, $exception $trace").as(Some(onFail(exception.getMessage))) + protected def handleAuthRequest(requestCtx: RequestCtx, packet: RpcPacket): F[Throwable, Option[RpcPacket]] = { + requestCtx.discard() + F.pure(Some(RpcPacket(RPCPacketKind.RpcResponse, None, None, packet.id, None, None, None))) + } - case Exit.Interruption(exception, allExceptions, trace) => - logger.error(s"WS request interrupted, $exception $allExceptions $trace").as(Some(onFail(exception.getMessage))) - } - } yield res + protected def handleAuthResponse(ref: RpcPacketId, packet: RpcPacket): F[Throwable, Option[RpcPacket]] = { + packet.discard() + responder.responseWith(ref, RawResponse.EmptyRawResponse()).as(None) } } object WsRpcHandler { - trait WsClientResponder[F[_, _]] { + trait WsResponder[F[_, _]] { def responseWith(id: RpcPacketId, response: RawResponse): F[Throwable, Unit] def responseWithData(id: RpcPacketId, data: Json): F[Throwable, Unit] } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionId.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionId.scala new file mode 100644 index 00000000..6ad56c61 --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionId.scala @@ -0,0 +1,5 @@ +package izumi.idealingua.runtime.rpc.http4s.ws + +import java.util.UUID + +final case class WsSessionId(sessionId: UUID) extends AnyVal diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionListener.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionListener.scala index 2e20d04c..ae227e68 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionListener.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionListener.scala @@ -1,18 +1,19 @@ package izumi.idealingua.runtime.rpc.http4s.ws -import izumi.functional.bio.Applicative2 +import izumi.functional.bio.{Applicative2, F} -trait WsSessionListener[F[+_, +_], ClientId] { - def onSessionOpened(context: WsClientId[ClientId]): F[Throwable, Unit] - def onClientIdUpdate(context: WsClientId[ClientId], old: WsClientId[ClientId]): F[Throwable, Unit] - def onSessionClosed(context: WsClientId[ClientId]): F[Throwable, Unit] +trait WsSessionListener[F[_, _], -RequestCtx, -WsCtx] { + def onSessionOpened(sessionId: WsSessionId, reqCtx: RequestCtx, wsCtx: WsCtx): F[Throwable, Unit] + def onSessionUpdated(sessionId: WsSessionId, reqCtx: RequestCtx, prevStx: WsCtx, newCtx: WsCtx): F[Throwable, Unit] + def onSessionClosed(sessionId: WsSessionId, wsCtx: WsCtx): F[Throwable, Unit] } object WsSessionListener { - def empty[F[+_, +_]: Applicative2, ClientId]: WsSessionListener[F, ClientId] = new WsSessionListener[F, ClientId] { - import izumi.functional.bio.F - override def onSessionOpened(context: WsClientId[ClientId]): F[Throwable, Unit] = F.unit - override def onClientIdUpdate(context: WsClientId[ClientId], old: WsClientId[ClientId]): F[Throwable, Unit] = F.unit - override def onSessionClosed(context: WsClientId[ClientId]): F[Throwable, Unit] = F.unit + type Global[F[_, _]] = WsSessionListener[F, Any, Any] + + def empty[F[+_, +_]: Applicative2, RequestCtx, WsCtx]: WsSessionListener[F, RequestCtx, WsCtx] = new WsSessionListener[F, RequestCtx, WsCtx] { + override def onSessionOpened(sessionId: WsSessionId, reqCtx: RequestCtx, wsCtx: WsCtx): F[Throwable, Unit] = F.unit + override def onSessionUpdated(sessionId: WsSessionId, reqCtx: RequestCtx, prevStx: WsCtx, newCtx: WsCtx): F[Throwable, Unit] = F.unit + override def onSessionClosed(sessionId: WsSessionId, wsCtx: WsCtx): F[Throwable, Unit] = F.unit } } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionsStorage.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionsStorage.scala index fcd030c4..b82c732f 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionsStorage.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsSessionsStorage.scala @@ -8,58 +8,57 @@ import java.util.concurrent.{ConcurrentHashMap, TimeoutException} import scala.concurrent.duration.* import scala.jdk.CollectionConverters.* -trait WsSessionsStorage[F[+_, +_], RequestCtx, ClientId] { - def addClient(ctx: WsClientSession[F, RequestCtx, ClientId]): F[Throwable, WsClientSession[F, RequestCtx, ClientId]] - def deleteClient(id: WsSessionId): F[Throwable, Option[WsClientSession[F, RequestCtx, ClientId]]] - def allClients(): F[Throwable, Seq[WsClientSession[F, RequestCtx, ClientId]]] - - def dispatcherForSession(id: WsSessionId, timeout: FiniteDuration = 20.seconds): F[Throwable, Option[IRTDispatcher[F]]] - def dispatcherForClient(id: ClientId, timeout: FiniteDuration = 20.seconds): F[Throwable, Option[IRTDispatcher[F]]] +trait WsSessionsStorage[F[+_, +_], SessionCtx] { + def addSession(session: WsClientSession[F, SessionCtx]): F[Throwable, WsClientSession[F, SessionCtx]] + def deleteSession(sessionId: WsSessionId): F[Throwable, Option[WsClientSession[F, SessionCtx]]] + def allSessions(): F[Throwable, Seq[WsClientSession[F, SessionCtx]]] + def getSession(sessionId: WsSessionId): F[Throwable, Option[WsClientSession[F, SessionCtx]]] + + def dispatcherForSession( + sessionId: WsSessionId, + codec: IRTClientMultiplexor[F], + timeout: FiniteDuration = 20.seconds, + ): F[Throwable, Option[IRTDispatcher[F]]] } object WsSessionsStorage { - class WsSessionsStorageImpl[F[+_, +_]: IO2, RequestContext, ClientId]( - logger: LogIO2[F], - codec: IRTClientMultiplexor[F], - ) extends WsSessionsStorage[F, RequestContext, ClientId] { - - protected val sessions = new ConcurrentHashMap[WsSessionId, WsClientSession[F, RequestContext, ClientId]]() + class WsSessionsStorageImpl[F[+_, +_]: IO2, SessionCtx](logger: LogIO2[F]) extends WsSessionsStorage[F, SessionCtx] { + protected val sessions = new ConcurrentHashMap[WsSessionId, WsClientSession[F, SessionCtx]]() - override def addClient(ctx: WsClientSession[F, RequestContext, ClientId]): F[Throwable, WsClientSession[F, RequestContext, ClientId]] = { + override def addSession(session: WsClientSession[F, SessionCtx]): F[Throwable, WsClientSession[F, SessionCtx]] = { for { - _ <- logger.debug(s"Adding a client with session - ${ctx.id}") - _ <- F.sync(sessions.put(ctx.id.sessionId, ctx)) - } yield ctx + _ <- logger.debug(s"Adding a client with session - ${session.sessionId}") + _ <- F.sync(sessions.put(session.sessionId, session)) + } yield session } - override def deleteClient(id: WsSessionId): F[Throwable, Option[WsClientSession[F, RequestContext, ClientId]]] = { + override def deleteSession(sessionId: WsSessionId): F[Throwable, Option[WsClientSession[F, SessionCtx]]] = { for { - _ <- logger.debug(s"Deleting a client with session - $id") - res <- F.sync(Option(sessions.remove(id))) + _ <- logger.debug(s"Deleting a client with session - $sessionId") + res <- F.sync(Option(sessions.remove(sessionId))) } yield res } - override def allClients(): F[Throwable, Seq[WsClientSession[F, RequestContext, ClientId]]] = F.sync { - sessions.values().asScala.toSeq + override def getSession(sessionId: WsSessionId): F[Throwable, Option[WsClientSession[F, SessionCtx]]] = { + F.sync(Option(sessions.get(sessionId))) } - override def dispatcherForClient(clientId: ClientId, timeout: FiniteDuration): F[Throwable, Option[WsClientDispatcher[F, RequestContext, ClientId]]] = { - F.sync(sessions.values().asScala.find(_.id.id.contains(clientId))).flatMap { - F.traverse(_) { - session => - dispatcherForSession(session.id.sessionId, timeout) - }.map(_.flatten) - } + override def allSessions(): F[Throwable, Seq[WsClientSession[F, SessionCtx]]] = F.sync { + sessions.values().asScala.toSeq } - override def dispatcherForSession(id: WsSessionId, timeout: FiniteDuration): F[Throwable, Option[WsClientDispatcher[F, RequestContext, ClientId]]] = F.sync { - Option(sessions.get(id)).map(new WsClientDispatcher(_, codec, logger, timeout)) + override def dispatcherForSession( + sessionId: WsSessionId, + codec: IRTClientMultiplexor[F], + timeout: FiniteDuration, + ): F[Throwable, Option[WsClientDispatcher[F, SessionCtx]]] = F.sync { + Option(sessions.get(sessionId)).map(new WsClientDispatcher(_, codec, logger, timeout)) } } - class WsClientDispatcher[F[+_, +_]: IO2, RequestContext, ClientId]( - session: WsClientSession[F, RequestContext, ClientId], + class WsClientDispatcher[F[+_, +_]: IO2, RequestCtx]( + session: WsClientSession[F, RequestCtx], codec: IRTClientMultiplexor[F], logger: LogIO2[F], timeout: FiniteDuration, @@ -70,7 +69,7 @@ object WsSessionsStorage { response <- session.requestAndAwaitResponse(request.method, json, timeout) res <- response match { case Some(value: RawResponse.EmptyRawResponse) => - F.fail(new IRTGenericFailure(s"${request.method -> "method"}: empty response: $value")) + F.fail(new IRTGenericFailure(s"${request.method}: empty response: $value")) case Some(value: RawResponse.GoodRawResponse) => logger.debug(s"WS Session: ${request.method -> "method"}: Have response: $value.") *> @@ -78,15 +77,14 @@ object WsSessionsStorage { case Some(value: RawResponse.BadRawResponse) => logger.debug(s"WS Session: ${request.method -> "method"}: Generic failure response: ${value.error}.") *> - F.fail(new IRTGenericFailure(s"${request.method -> "method"}: generic failure: ${value.error}")) + F.fail(new IRTGenericFailure(s"${request.method}: generic failure: ${value.error}")) case None => logger.warn(s"WS Session: ${request.method -> "method"}: Timeout exception $timeout.") *> - F.fail(new TimeoutException(s"${request.method -> "method"}: No response in $timeout")) + F.fail(new TimeoutException(s"${request.method}: No response in $timeout")) } } yield res } } - } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/Http4sTransportTest.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/Http4sTransportTest.scala index c0d1f0ca..abbc3a4c 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/Http4sTransportTest.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/Http4sTransportTest.scala @@ -1,92 +1,363 @@ package izumi.idealingua.runtime.rpc.http4s -import io.circe.Json -import izumi.functional.bio.Exit.{Error, Interruption, Success, Termination} -import izumi.functional.bio.{Exit, F} -import izumi.fundamentals.platform.language.Quirks.* +import cats.effect.Async +import io.circe.{Json, Printer} +import izumi.functional.bio.Exit.{Error, Success, Termination} +import izumi.functional.bio.UnsafeRun2.FailureHandler +import izumi.functional.bio.impl.{AsyncZio, PrimitivesZio} +import izumi.functional.bio.{Async2, Exit, F, Primitives2, Temporal2, UnsafeRun2} +import izumi.functional.lifecycle.Lifecycle +import izumi.fundamentals.platform.network.IzSockets import izumi.idealingua.runtime.rpc.* +import izumi.idealingua.runtime.rpc.http4s.Http4sTransportTest.{Ctx, IO2R} +import izumi.idealingua.runtime.rpc.http4s.IRTAuthenticator.AuthContext import izumi.idealingua.runtime.rpc.http4s.clients.HttpRpcDispatcher.IRTDispatcherRaw -import izumi.idealingua.runtime.rpc.http4s.ws.{RawResponse, WsRequestState} -import izumi.r2.idealingua.test.generated.{GreeterServiceClientWrapped, GreeterServiceMethods} +import izumi.idealingua.runtime.rpc.http4s.clients.{HttpRpcDispatcher, HttpRpcDispatcherFactory, WsRpcDispatcher, WsRpcDispatcherFactory} +import izumi.idealingua.runtime.rpc.http4s.context.{HttpContextExtractor, WsContextExtractor} +import izumi.idealingua.runtime.rpc.http4s.fixtures.TestServices +import izumi.idealingua.runtime.rpc.http4s.ws.{RawResponse, WsClientSession, WsRequestState} +import izumi.logstage.api.routing.{ConfigurableLogRouter, StaticLogRouter} +import izumi.logstage.api.{IzLogger, Log} +import izumi.r2.idealingua.test.generated.* +import logstage.LogIO2 import org.http4s.* import org.http4s.blaze.server.* +import org.http4s.dsl.Http4sDsl import org.http4s.headers.Authorization import org.http4s.server.Router import org.scalatest.wordspec.AnyWordSpec -import zio.interop.catz.asyncInstance -import zio.{IO, ZIO} -import java.util.Base64 +import java.util.concurrent.Executors +import scala.concurrent.ExecutionContext.global import scala.concurrent.duration.DurationInt -class Http4sTransportTest extends AnyWordSpec { +final class Http4sTransportTest + extends Http4sTransportTestBase[zio.IO]()( + async2 = AsyncZio, + primitives2 = PrimitivesZio, + temporal2 = AsyncZio, + unsafeRun2 = IO2R, + asyncThrowable = zio.interop.catz.asyncInstance, + ) +object Http4sTransportTest { + final val izLogger: IzLogger = makeLogger() + final val handler: FailureHandler.Custom = UnsafeRun2.FailureHandler.Custom(message => izLogger.trace(s"Fiber failed: $message")) + final val IO2R: UnsafeRun2[zio.IO] = UnsafeRun2.createZIO( + handler = handler, + customCpuPool = Some( + zio.Executor.fromJavaExecutor( + Executors.newFixedThreadPool(2) + ) + ), + ) + + final class Ctx[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRun2](implicit asyncThrowable: Async[F[Throwable, _]]) { + private val logger: LogIO2[F] = LogIO2.fromLogger(izLogger) + private val printer: Printer = Printer.noSpaces.copy(dropNullValues = true) + private val dsl: Http4sDsl[F[Throwable, _]] = Http4sDsl.apply[F[Throwable, _]] + + def badAuth(): Header.ToRaw = Authorization(Credentials.Token(AuthScheme.Bearer, "token")) + def publicAuth(user: String): Header.ToRaw = Authorization(BasicCredentials(user, "public")) + def protectedAuth(user: String): Header.ToRaw = Authorization(BasicCredentials(user, "protected")) + def privateAuth(user: String): Header.ToRaw = Authorization(BasicCredentials(user, "private")) + + def withServer(f: HttpServerContext[F] => F[Throwable, Unit]): Unit = { + executeF { + (for { + testServices <- Lifecycle.liftF(F.syncThrowable(new TestServices[F](logger))) + ioService <- Lifecycle.liftF { + F.syncThrowable { + new HttpServer[F, AuthContext]( + contextServices = testServices.Server.contextServices, + httpContextExtractor = HttpContextExtractor.authContext, + wsContextExtractor = WsContextExtractor.authContext, + wsSessionsStorage = testServices.Server.wsStorage, + dsl = dsl, + logger = logger, + printer = printer, + ) + } + } + addr <- Lifecycle.liftF(F.sync(IzSockets.temporaryServerAddress())) + port = addr.getPort + host = addr.getHostName + _ <- Lifecycle.fromCats { + BlazeServerBuilder[F[Throwable, _]] + .bindHttp(port, host) + .withHttpWebSocketApp(ws => Router("/" -> ioService.service(ws)).orNotFound) + .resource + } + execCtx = HttpExecutionContext(global) + baseUri = Uri(Some(Uri.Scheme.http), Some(Uri.Authority(host = Uri.RegName(host), port = Some(port)))) + wsUri = Uri.unsafeFromString(s"ws://$host:$port/ws") + } yield HttpServerContext(baseUri, wsUri, testServices, execCtx, printer, logger)).use(f) + } + } + + def executeF(io: F[Throwable, Unit]): Unit = { + UnsafeRun2[F].unsafeRunSync(io) match { + case Success(()) => () + case failure: Exit.Failure[?] => throw failure.trace.toThrowable + } + } + } + + final case class HttpServerContext[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRun2]( + baseUri: Uri, + wsUri: Uri, + testServices: TestServices[F], + execCtx: HttpExecutionContext, + printer: Printer, + logger: LogIO2[F], + )(implicit asyncThrowable: Async[F[Throwable, _]] + ) { + val httpClientFactory: HttpRpcDispatcherFactory[F] = { + new HttpRpcDispatcherFactory[F](testServices.Client.codec, execCtx, printer, logger) + } + def httpRpcClientDispatcher(headers: Headers): HttpRpcDispatcher.IRTDispatcherRaw[F] = { + httpClientFactory.dispatcher(baseUri, headers) + } + + val wsClientFactory: WsRpcDispatcherFactory[F] = { + new WsRpcDispatcherFactory[F](testServices.Client.codec, printer, logger, izLogger) + } + def wsRpcClientDispatcher(headers: Map[String, String] = Map.empty): Lifecycle[F[Throwable, _], WsRpcDispatcher.IRTDispatcherWs[F]] = { + wsClientFactory.dispatcherSimple(wsUri, testServices.Client.buzzerMultiplexor, headers) + } + } + + private def makeLogger(): IzLogger = { + val router = ConfigurableLogRouter( + Log.Level.Debug, + levels = Map( + "io.netty" -> Log.Level.Error, + "org.http4s.blaze.channel.nio1" -> Log.Level.Error, + "org.http4s" -> Log.Level.Error, + "org.asynchttpclient" -> Log.Level.Error, + ), + ) + + val out = IzLogger(router) + StaticLogRouter.instance.setup(router) + out + } +} + +abstract class Http4sTransportTestBase[F[+_, +_]]( + implicit + async2: Async2[F], + primitives2: Primitives2[F], + temporal2: Temporal2[F], + unsafeRun2: UnsafeRun2[F], + asyncThrowable: Async[F[Throwable, _]], +) extends AnyWordSpec { + private val ctx = new Ctx[F] + + import ctx.* import fixtures.* - import Http4sTestContext.* - import RT.* "Http4s transport" should { "support http" in { withServer { - for { - // with credentials - httpClient1 <- F.sync(httpRpcClientDispatcher(Headers(Authorization(BasicCredentials("user", "pass"))))) - greeterClient1 = new GreeterServiceClientWrapped(httpClient1) - _ <- greeterClient1.greet("John", "Smith").map(res => assert(res == "Hi, John Smith!")) - _ <- greeterClient1.alternative().either.map(res => assert(res == Right("value"))) - _ <- checkBadBody("{}", httpClient1) - _ <- checkBadBody("{unparseable", httpClient1) - - // without credentials - greeterClient2 <- F.sync(httpRpcClientDispatcher(Headers())).map(new GreeterServiceClientWrapped(_)) - _ <- F.sandboxExit(greeterClient2.alternative()).map { - case Termination(exception: IRTUnexpectedHttpStatus, _, _) => assert(exception.status == Status.Forbidden) - case o => fail(s"Expected IRTGenericFailure but got $o") - } + ctx => + for { + // with credentials + privateClient <- F.sync(ctx.httpRpcClientDispatcher(Headers(privateAuth("user1")))) + protectedClient <- F.sync(ctx.httpRpcClientDispatcher(Headers(protectedAuth("user2")))) + publicClient <- F.sync(ctx.httpRpcClientDispatcher(Headers(publicAuth("user3")))) + publicOrcClient <- F.sync(ctx.httpRpcClientDispatcher(Headers(publicAuth("orc")))) - // with bad credentials - greeterClient2 <- F.sync(httpRpcClientDispatcher(Headers(Authorization(BasicCredentials("user", "badpass"))))).map(new GreeterServiceClientWrapped(_)) - _ <- F.sandboxExit(greeterClient2.alternative()).map { - case Termination(exception: IRTUnexpectedHttpStatus, _, _) => assert(exception.status == Status.Unauthorized) - case o => fail(s"Expected IRTGenericFailure but got $o") - } - } yield () + // Private API test + _ <- new PrivateTestServiceWrappedClient(privateClient) + .test("test").map(res => assert(res.startsWith("Private"))) + _ <- checkUnauthorizedHttpCall(new PrivateTestServiceWrappedClient(protectedClient).test("test")) + _ <- checkUnauthorizedHttpCall(new ProtectedTestServiceWrappedClient(publicClient).test("test")) + + // Protected API test + _ <- new ProtectedTestServiceWrappedClient(protectedClient) + .test("test").map(res => assert(res.startsWith("Protected"))) + _ <- checkUnauthorizedHttpCall(new ProtectedTestServiceWrappedClient(privateClient).test("test")) + _ <- checkUnauthorizedHttpCall(new ProtectedTestServiceWrappedClient(publicClient).test("test")) + + // Public API test + _ <- new GreeterServiceClientWrapped(protectedClient) + .greet("Protected", "Client").map(res => assert(res == "Hi, Protected Client!")) + _ <- new GreeterServiceClientWrapped(privateClient) + .greet("Protected", "Client").map(res => assert(res == "Hi, Protected Client!")) + greaterClient = new GreeterServiceClientWrapped(publicClient) + _ <- greaterClient.greet("John", "Smith").map(res => assert(res == "Hi, John Smith!")) + _ <- greaterClient.alternative().attempt.map(res => assert(res == Right("value"))) + + // middleware test + _ <- checkUnauthorizedHttpCall(new GreeterServiceClientWrapped(publicOrcClient).greet("Orc", "Smith")) + + // bad body test + _ <- checkBadBody("{}", publicClient) + _ <- checkBadBody("{unparseable", publicClient) + } yield () } } "support websockets" in { withServer { - wsRpcClientDispatcher().use { - dispatcher => - for { - id1 <- ZIO.succeed(s"Basic ${Base64.getEncoder.encodeToString("user:pass".getBytes)}") - _ <- dispatcher.authorize(Map("Authorization" -> id1)) - greeterClient = new GreeterServiceClientWrapped(dispatcher) - _ <- greeterClient.greet("John", "Smith").map(res => assert(res == "Hi, John Smith!")) - _ <- greeterClient.alternative().either.map(res => assert(res == Right("value"))) - buzzers <- ioService.wsSessionStorage.dispatcherForClient(id1) - _ = assert(buzzers.nonEmpty) - _ <- ZIO.foreach(buzzers) { - buzzer => - val client = new GreeterServiceClientWrapped(buzzer) - client.greet("John", "Buzzer").map(res => assert(res == "Hi, John Buzzer!")) - } - _ <- dispatcher.authorize(Map("Authorization" -> s"Basic ${Base64.getEncoder.encodeToString("user:badpass".getBytes)}")) - _ <- F.sandboxExit(greeterClient.alternative()).map { - case Termination(_: IRTGenericFailure, _, _) => - case o => F.fail(s"Expected IRTGenericFailure but got $o") - } - } yield () - } + ctx => + import ctx.testServices.{Client, Server} + ctx.wsRpcClientDispatcher().use { + dispatcher => + for { + publicHeaders <- F.pure(Map("Authorization" -> publicAuth("user").values.head.value)) + privateHeaders <- F.pure(Map("Authorization" -> privateAuth("user").values.head.value)) + protectedHeaders <- F.pure(Map("Authorization" -> protectedAuth("user").values.head.value)) + protectedHeaders2 <- F.pure(Map("Authorization" -> protectedAuth("John").values.head.value)) + badHeaders <- F.pure(Map("Authorization" -> badAuth().values.head.value)) + + publicClient = new GreeterServiceClientWrapped[F](dispatcher) + privateClient = new PrivateTestServiceWrappedClient[F](dispatcher) + protectedClient = new ProtectedTestServiceWrappedClient[F](dispatcher) + + // session id is set + sessionId <- F.fromOption(new RuntimeException("Missing Ws Session Id."))(dispatcher.sessionId) + _ <- Server.wsStorage.getSession(sessionId).fromOption(new RuntimeException("Missing Ws Session.")) + + // no dispatchers yet + _ <- Server.protectedWsStorage.dispatchersFor(ProtectedContext("user"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.privateWsStorage.dispatchersFor(PrivateContext("user"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.publicWsStorage.dispatchersFor(PublicContext("user"), Client.codec).map(b => assert(b.isEmpty)) + // all listeners are empty + _ = assert(Server.protectedWsListener.connectedContexts.isEmpty) + _ = assert(Server.privateWsListener.connectedContexts.isEmpty) + _ = assert(Server.publicWsListener.connectedContexts.isEmpty) + + // public authorization + _ <- dispatcher.authorize(publicHeaders) + // protected and private listeners are empty + _ = assert(Server.protectedWsListener.connectedContexts.isEmpty) + _ = assert(Server.privateWsListener.connectedContexts.isEmpty) + _ = assert(Server.publicWsListener.connectedContexts.contains(PublicContext("user"))) + // protected and private sessions are empty + _ <- Server.protectedWsStorage.dispatchersFor(ProtectedContext("user"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.privateWsStorage.dispatchersFor(PrivateContext("user"), Client.codec).map(b => assert(b.isEmpty)) + // public dispatcher works as expected + publicContextBuzzer <- Server.publicWsStorage + .dispatchersFor(PublicContext("user"), Client.codec).map(_.headOption) + .fromOption(new RuntimeException("Missing Buzzer")) + _ <- new GreeterServiceClientWrapped(publicContextBuzzer).greet("John", "Buzzer").map(res => assert(res == "Hi, John Buzzer!")) + _ <- publicClient.greet("John", "Smith").map(res => assert(res == "Hi, John Smith!")) + _ <- publicClient.alternative().attempt.map(res => assert(res == Right("value"))) + _ <- checkUnauthorizedWsCall(privateClient.test("")) + _ <- checkUnauthorizedWsCall(protectedClient.test("")) + + // re-authorize with private + _ <- dispatcher.authorize(privateHeaders) + // protected listener is empty + _ = assert(Server.protectedWsListener.connectedContexts.isEmpty) + _ = assert(Server.privateWsListener.connectedContexts.contains(PrivateContext("user"))) + _ = assert(Server.publicWsListener.connectedContexts.contains(PublicContext("user"))) + // protected sessions is empty + _ <- Server.protectedWsStorage.dispatchersFor(ProtectedContext("user"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.privateWsStorage.dispatchersFor(PrivateContext("user"), Client.codec).map(b => assert(b.nonEmpty)) + _ <- Server.publicWsStorage.dispatchersFor(PublicContext("user"), Client.codec).map(b => assert(b.nonEmpty)) + _ <- privateClient.test("test").map(res => assert(res.startsWith("Private"))) + _ <- publicClient.greet("John", "Smith").map(res => assert(res == "Hi, John Smith!")) + _ <- checkUnauthorizedWsCall(protectedClient.test("")) + + // re-authorize with protected + _ <- dispatcher.authorize(protectedHeaders) + // private listener is empty + _ = assert(Server.protectedWsListener.connectedContexts.contains(ProtectedContext("user"))) + _ = assert(Server.privateWsListener.connectedContexts.isEmpty) + _ = assert(Server.publicWsListener.connectedContexts.contains(PublicContext("user"))) + // private sessions is empty + _ <- Server.protectedWsStorage.dispatchersFor(ProtectedContext("user"), Client.codec).map(b => assert(b.nonEmpty)) + _ <- Server.privateWsStorage.dispatchersFor(PrivateContext("user"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.publicWsStorage.dispatchersFor(PublicContext("user"), Client.codec).map(b => assert(b.nonEmpty)) + _ <- protectedClient.test("test").map(res => assert(res.startsWith("Protected"))) + _ <- publicClient.greet("John", "Smith").map(res => assert(res == "Hi, John Smith!")) + _ <- checkUnauthorizedWsCall(privateClient.test("")) + + // auth session context update + _ <- dispatcher.authorize(protectedHeaders2) + // session and listeners notified + _ = assert(Server.protectedWsListener.connectedContexts.contains(ProtectedContext("John"))) + _ = assert(Server.protectedWsListener.connectedContexts.size == 1) + _ = assert(Server.publicWsListener.connectedContexts.contains(PublicContext("John"))) + _ = assert(Server.publicWsListener.connectedContexts.size == 1) + _ = assert(Server.privateWsListener.connectedContexts.isEmpty) + _ <- Server.privateWsStorage.dispatchersFor(PrivateContext("John"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.publicWsStorage.dispatchersFor(PublicContext("user"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.publicWsStorage.dispatchersFor(PublicContext("John"), Client.codec).map(b => assert(b.nonEmpty)) + _ <- Server.protectedWsStorage.dispatchersFor(ProtectedContext("user"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.protectedWsStorage.dispatchersFor(ProtectedContext("John"), Client.codec).map(b => assert(b.nonEmpty)) + + // bad authorization + _ <- dispatcher.authorize(badHeaders) + _ <- checkUnauthorizedWsCall(publicClient.alternative()) + } yield () + } + } + } + + "support websockets request auth" in { + withServer { + ctx => + import ctx.testServices.{Client, Server} + for { + privateHeaders <- F.pure(Map("Authorization" -> privateAuth("user").values.head.value)) + _ <- ctx.wsRpcClientDispatcher(privateHeaders).use { + dispatcher => + val publicClient = new GreeterServiceClientWrapped[F](dispatcher) + val privateClient = new PrivateTestServiceWrappedClient[F](dispatcher) + val protectedClient = new ProtectedTestServiceWrappedClient[F](dispatcher) + for { + _ <- Server.protectedWsStorage.dispatchersFor(ProtectedContext("user"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.privateWsStorage.dispatchersFor(PrivateContext("user"), Client.codec).map(b => assert(b.nonEmpty)) + _ <- Server.publicWsStorage.dispatchersFor(PublicContext("user"), Client.codec).map(b => assert(b.nonEmpty)) + _ = assert(Server.protectedWsListener.connectedContexts.isEmpty) + _ = assert(Server.privateWsListener.connectedContexts.size == 1) + _ = assert(Server.publicWsListener.connectedContexts.size == 1) + + _ <- privateClient.test("test").map(res => assert(res.startsWith("Private"))) + _ <- publicClient.greet("John", "Smith").map(res => assert(res == "Hi, John Smith!")) + _ <- checkUnauthorizedWsCall(protectedClient.test("")) + } yield () + } + } yield () + } + } + + "support websockets multiple sessions on same context" in { + withServer { + ctx => + import ctx.testServices.{Client, Server} + for { + privateHeaders <- F.pure(Map("Authorization" -> privateAuth("user").values.head.value)) + _ <- { + for { + c1 <- ctx.wsRpcClientDispatcher(privateHeaders) + c2 <- ctx.wsRpcClientDispatcher(privateHeaders) + } yield (c1, c2) + }.use { + case (_, _) => + for { + _ <- Server.protectedWsStorage.dispatchersFor(ProtectedContext("user"), Client.codec).map(b => assert(b.isEmpty)) + _ <- Server.privateWsStorage.dispatchersFor(PrivateContext("user"), Client.codec).map(b => assert(b.size == 2)) + _ <- Server.publicWsStorage.dispatchersFor(PublicContext("user"), Client.codec).map(b => assert(b.size == 2)) + _ = assert(Server.protectedWsListener.connected.isEmpty) + _ = assert(Server.privateWsListener.connected.size == 2) + _ = assert(Server.publicWsListener.connected.size == 2) + } yield () + } + } yield () } } "support request state clean" in { - executeIO { - val rs = new WsRequestState.Default[IO]() + executeF { + val rs = new WsRequestState.Default[F]() for { - id1 <- ZIO.succeed(RpcPacketId.random()) - id2 <- ZIO.succeed(RpcPacketId.random()) + id1 <- F.pure(RpcPacketId.random()) + id2 <- F.pure(RpcPacketId.random()) _ <- rs.registerRequest(id1, None, 0.minutes) _ <- rs.registerRequest(id2, None, 5.minutes) _ <- F.attempt(rs.awaitResponse(id1, 5.seconds)).map { @@ -99,38 +370,60 @@ class Http4sTransportTest extends AnyWordSpec { } yield () } } - } - def withServer(f: IO[Throwable, Any]): Unit = { - executeIO { - BlazeServerBuilder[IO[Throwable, _]] - .bindHttp(port, host) - .withHttpWebSocketApp(ws => Router("/" -> ioService.service(ws)).orNotFound) - .resource - .use(_ => f) - .unit + "support dummy ws client" in { + // server not used here + // but we need to construct test contexts + withServer { + ctx => + import ctx.testServices.{Client, Server} + val client = new WsClientSession.Dummy[F, AuthContext]( + AuthContext(Headers(publicAuth("user")), None), + Client.buzzerMultiplexor, + Server.contextServices.map(_.authorizedWsSessions), + Server.wsStorage, + WsContextExtractor.authContext, + ctx.logger, + ) + for { + _ <- client.start(_ => F.unit) + _ = assert(Server.protectedWsListener.connected.isEmpty) + _ = assert(Server.privateWsListener.connected.isEmpty) + _ = assert(Server.publicWsListener.connected.size == 1) + dispatcher <- Server.publicWsStorage + .dispatchersFor(PublicContext("user"), Client.codec).map(_.headOption) + .fromOption(new RuntimeException("Missing dispatcher")) + _ <- new GreeterServiceClientWrapped(dispatcher) + .greet("John", "Buzzer") + .map(res => assert(res == "Hi, John Buzzer!")) + _ <- client.finish(_ => F.unit) + _ = assert(Server.protectedWsListener.connected.isEmpty) + _ = assert(Server.privateWsListener.connected.isEmpty) + _ = assert(Server.publicWsListener.connected.isEmpty) + } yield () + } } } - def executeIO(io: IO[Throwable, Any]): Unit = { - IO2R.unsafeRunSync(io.unit) match { - case Success(()) => () - case failure: Exit.Failure[?] => throw failure.trace.toThrowable - } + def checkUnauthorizedHttpCall[E, A](call: F[E, A]): F[Throwable, Unit] = { + call.sandboxExit.map { + case Termination(exception: IRTUnexpectedHttpStatus, _, _) => assert(exception.status == Status.Unauthorized) + case o => fail(s"Expected Unauthorized status but got $o") + }.void } - def checkBadBody(body: String, disp: IRTDispatcherRaw[IO]): ZIO[Any, Nothing, Unit] = { - F.sandboxExit(disp.dispatchRaw(GreeterServiceMethods.greet.id, body)).map { - case Error(value: IRTUnexpectedHttpStatus, _) => - assert(value.status == Status.BadRequest).discard() - case Error(value, _) => - fail(s"Unexpected error: $value") - case Success(value) => - fail(s"Unexpected success: $value") - case Termination(exception, _, _) => - fail("Unexpected failure", exception) - case Interruption(value, _, _) => - fail(s"Interrupted: $value") - } + def checkUnauthorizedWsCall[E, A](call: F[E, A]): F[Throwable, Unit] = { + call.sandboxExit.map { + case Termination(f: IRTGenericFailure, _, _) => assert(f.getMessage.contains("""{"cause":"Unauthorized."}""")) + case o => fail(s"Expected IRTGenericFailure with Unauthorized message but got $o") + }.void + } + + def checkBadBody(body: String, disp: IRTDispatcherRaw[F]): F[Nothing, Unit] = { + disp + .dispatchRaw(GreeterServiceMethods.greet.id, body).sandboxExit.map { + case Error(value: IRTUnexpectedHttpStatus, _) => assert(value.status == Status.BadRequest) + case o => fail(s"Expected IRTUnexpectedHttpStatus with BadRequest but got $o") + }.void } } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/DummyAuthorizingDispatcher.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/DummyAuthorizingDispatcher.scala deleted file mode 100644 index 91a6f5c1..00000000 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/DummyAuthorizingDispatcher.scala +++ /dev/null @@ -1,34 +0,0 @@ -package izumi.idealingua.runtime.rpc.http4s.fixtures - -import izumi.functional.bio.IO2 -import izumi.idealingua.runtime.rpc._ -import izumi.idealingua.runtime.rpc.http4s.{IRTBadCredentialsException, IRTNoCredentialsException} -import org.http4s.{BasicCredentials, Status} - -final class DummyAuthorizingDispatcher[F[+_, +_]: IO2, Ctx](proxied: IRTWrappedService[F, Ctx]) extends IRTWrappedService[F, Ctx] { - override def serviceId: IRTServiceId = proxied.serviceId - - override def allMethods: Map[IRTMethodId, IRTMethodWrapper[F, Ctx]] = proxied.allMethods.mapValues { - method => - new IRTMethodWrapper[F, Ctx] { - val R: IO2[F] = implicitly - - override val signature: IRTMethodSignature = method.signature - override val marshaller: IRTCirceMarshaller = method.marshaller - - override def invoke(ctx: Ctx, input: signature.Input): F[Nothing, signature.Output] = { - ctx match { - case DummyRequestContext(_, Some(BasicCredentials(user, pass))) => - if (user == "user" && pass == "pass") { - method.invoke(ctx, input.asInstanceOf[method.signature.Input]).map(_.asInstanceOf[signature.Output]) - } else { - R.terminate(IRTBadCredentialsException(Status.Unauthorized)) - } - - case _ => - R.terminate(IRTNoCredentialsException(Status.Forbidden)) - } - } - } - }.toMap -} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/DummyRequestContext.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/DummyRequestContext.scala deleted file mode 100644 index cf6a75b8..00000000 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/DummyRequestContext.scala +++ /dev/null @@ -1,6 +0,0 @@ -package izumi.idealingua.runtime.rpc.http4s.fixtures - -import com.comcast.ip4s.IpAddress -import org.http4s.Credentials - -final case class DummyRequestContext(ip: IpAddress, credentials: Option[Credentials]) diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/DummyServices.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/DummyServices.scala deleted file mode 100644 index 9843ae77..00000000 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/DummyServices.scala +++ /dev/null @@ -1,29 +0,0 @@ -package izumi.idealingua.runtime.rpc.http4s.fixtures - -import izumi.functional.bio.IO2 -import izumi.idealingua.runtime.rpc.* -import izumi.r2.idealingua.test.generated.{GreeterServiceClientWrapped, GreeterServiceServerWrapped} -import izumi.r2.idealingua.test.impls.AbstractGreeterServer - -class DummyServices[F[+_, +_]: IO2, Ctx] { - - object Server { - private val greeterService = new AbstractGreeterServer.Impl[F, Ctx] - private val greeterDispatcher = new GreeterServiceServerWrapped(greeterService) - private val dispatchers: Set[IRTWrappedService[F, Ctx]] = Set(greeterDispatcher).map(d => new DummyAuthorizingDispatcher(d)) - val multiplexor = new IRTServerMultiplexorImpl[F, Ctx, Ctx](dispatchers, ContextExtender.id) - - private val clients: Set[IRTWrappedClient] = Set(GreeterServiceClientWrapped) - val codec = new IRTClientMultiplexorImpl[F](clients) - } - - object Client { - private val greeterService = new AbstractGreeterServer.Impl[F, Unit] - private val greeterDispatcher = new GreeterServiceServerWrapped(greeterService) - private val dispatchers: Set[IRTWrappedService[F, Unit]] = Set(greeterDispatcher) - - private val clients: Set[IRTWrappedClient] = Set(GreeterServiceClientWrapped) - val codec = new IRTClientMultiplexorImpl[F](clients) - val buzzerMultiplexor = new IRTServerMultiplexorImpl[F, Unit, Unit](dispatchers, ContextExtender.id) - } -} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/Http4sTestContext.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/Http4sTestContext.scala deleted file mode 100644 index c988b71d..00000000 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/Http4sTestContext.scala +++ /dev/null @@ -1,118 +0,0 @@ -package izumi.idealingua.runtime.rpc.http4s.fixtures - -import cats.data.{Kleisli, OptionT} -import com.comcast.ip4s.* -import izumi.functional.lifecycle.Lifecycle -import izumi.fundamentals.platform.language.Quirks -import izumi.fundamentals.platform.network.IzSockets -import izumi.idealingua.runtime.rpc.http4s.* -import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.WsRpcContextProvider -import izumi.idealingua.runtime.rpc.http4s.clients.{HttpRpcDispatcher, HttpRpcDispatcherFactory, WsRpcDispatcher, WsRpcDispatcherFactory} -import izumi.idealingua.runtime.rpc.http4s.ws.WsContextProvider.WsAuthResult -import izumi.idealingua.runtime.rpc.http4s.ws.WsSessionsStorage.WsSessionsStorageImpl -import izumi.idealingua.runtime.rpc.http4s.ws.{WsClientId, WsContextProvider, WsSessionListener} -import izumi.idealingua.runtime.rpc.{RPCPacketKind, RpcPacket} -import org.http4s.* -import org.http4s.headers.Authorization -import org.http4s.server.AuthMiddleware -import zio.interop.catz.* -import zio.{IO, ZIO} - -object Http4sTestContext { - import RT.IO2R - - // - final val addr = IzSockets.temporaryServerAddress() - final val port = addr.getPort - final val host = addr.getHostName - final val baseUri = Uri(Some(Uri.Scheme.http), Some(Uri.Authority(host = Uri.RegName(host), port = Some(port)))) - final val wsUri = Uri.unsafeFromString(s"ws://$host:$port/ws") - - // -// -// import RT.rt -// import rt.* - - final val demo = new DummyServices[IO, DummyRequestContext]() - - // - final val authUser: Kleisli[OptionT[IO[Throwable, _], _], Request[IO[Throwable, _]], DummyRequestContext] = - Kleisli { - (request: Request[IO[Throwable, _]]) => - val context = DummyRequestContext(request.remoteAddr.getOrElse(ipv4"0.0.0.0"), request.headers.get[Authorization].map(_.credentials)) - OptionT.liftF(ZIO.attempt(context)) - } - - final val wsContextProvider: WsContextProvider[IO, DummyRequestContext, String] = new WsContextProvider[IO, DummyRequestContext, String] { - override def toContext(id: WsClientId[String], initial: DummyRequestContext, packet: RpcPacket): zio.IO[Throwable, DummyRequestContext] = { - ZIO.succeed { - val fromState = id.id.map(header => Map("Authorization" -> header)).getOrElse(Map.empty) - val allHeaders = fromState ++ packet.headers.getOrElse(Map.empty) - val creds = allHeaders.get("Authorization").flatMap(Authorization.parse(_).toOption).map(_.credentials) - DummyRequestContext(initial.ip, creds.orElse(initial.credentials)) - } - } - - override def toId(initial: DummyRequestContext, currentId: WsClientId[String], packet: RpcPacket): zio.IO[Throwable, Option[String]] = { - ZIO.attempt { - val fromState = currentId.id.map(header => Map("Authorization" -> header)).getOrElse(Map.empty) - val allHeaders = fromState ++ packet.headers.getOrElse(Map.empty) - allHeaders.get("Authorization") - } - } - - override def handleAuthorizationPacket( - id: WsClientId[String], - initial: DummyRequestContext, - packet: RpcPacket, - ): IO[Throwable, WsAuthResult[String]] = { - Quirks.discard(id, initial) - - packet.headers.flatMap(_.get("Authorization")) match { - case Some(value) if value.isEmpty => - // here we may clear internal state - ZIO.succeed(WsAuthResult(None, RpcPacket(RPCPacketKind.RpcResponse, None, None, packet.id, None, None, None))) - - case Some(_) => - toId(initial, id, packet).flatMap { - case Some(header) => - // here we may set internal state - ZIO.succeed(WsAuthResult(Some(header), RpcPacket(RPCPacketKind.RpcResponse, None, None, packet.id, None, None, None))) - - case None => - ZIO.succeed(WsAuthResult(None, RpcPacket.rpcFail(packet.id, "Authorization failed"))) - } - - case None => - ZIO.succeed(WsAuthResult(None, RpcPacket(RPCPacketKind.RpcResponse, None, None, packet.id, None, None, None))) - } - } - } - - final val storage = new WsSessionsStorageImpl[IO, DummyRequestContext, String](RT.logger, demo.Server.codec) - final val ioService = new HttpServer[IO, DummyRequestContext, DummyRequestContext, String]( - demo.Server.multiplexor, - demo.Server.codec, - AuthMiddleware(authUser), - wsContextProvider, - storage, - Seq(WsSessionListener.empty[IO, String]), - RT.dsl, - RT.logger, - RT.printer, - ) - - final val httpClientFactory: HttpRpcDispatcherFactory[IO] = { - new HttpRpcDispatcherFactory[IO](demo.Client.codec, RT.execCtx, RT.printer, RT.logger) - } - final def httpRpcClientDispatcher(headers: Headers): HttpRpcDispatcher.IRTDispatcherRaw[IO] = { - httpClientFactory.dispatcher(baseUri, headers) - } - - final val wsClientFactory: WsRpcDispatcherFactory[IO] = { - new WsRpcDispatcherFactory[IO](demo.Client.codec, RT.printer, RT.logger, RT.izLogger) - } - final def wsRpcClientDispatcher(): Lifecycle[IO[Throwable, _], WsRpcDispatcher.IRTDispatcherWs[IO]] = { - wsClientFactory.dispatcher(wsUri, demo.Client.buzzerMultiplexor, WsRpcContextProvider.unit) - } -} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/LoggingWsListener.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/LoggingWsListener.scala new file mode 100644 index 00000000..81107241 --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/LoggingWsListener.scala @@ -0,0 +1,25 @@ +package izumi.idealingua.runtime.rpc.http4s.fixtures + +import izumi.functional.bio.{F, IO2} +import izumi.idealingua.runtime.rpc.http4s.ws.{WsSessionId, WsSessionListener} + +import scala.collection.mutable + +final class LoggingWsListener[F[+_, +_]: IO2, RequestCtx, WsCtx] extends WsSessionListener[F, RequestCtx, WsCtx] { + private val connections = mutable.Set.empty[(WsSessionId, WsCtx)] + def connected: Set[(WsSessionId, WsCtx)] = connections.toSet + def connectedContexts: Set[WsCtx] = connections.map(_._2).toSet + + override def onSessionOpened(sessionId: WsSessionId, reqCtx: RequestCtx, wsCtx: WsCtx): F[Throwable, Unit] = F.sync { + connections.add(sessionId -> wsCtx) + }.void + + override def onSessionUpdated(sessionId: WsSessionId, reqCtx: RequestCtx, prevStx: WsCtx, newCtx: WsCtx): F[Throwable, Unit] = F.sync { + connections.remove(sessionId -> prevStx) + connections.add(sessionId -> newCtx) + }.void + + override def onSessionClosed(sessionId: WsSessionId, wsCtx: WsCtx): F[Throwable, Unit] = F.sync { + connections.remove(sessionId -> wsCtx) + }.void +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/RT.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/RT.scala deleted file mode 100644 index e80b3b7e..00000000 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/RT.scala +++ /dev/null @@ -1,48 +0,0 @@ -package izumi.idealingua.runtime.rpc.http4s.fixtures - -import io.circe.Printer -import izumi.functional.bio.UnsafeRun2 -import izumi.idealingua.runtime.rpc.http4s.HttpExecutionContext -import izumi.logstage.api.routing.{ConfigurableLogRouter, StaticLogRouter} -import izumi.logstage.api.{IzLogger, Log} -import logstage.LogIO -import org.http4s.dsl.Http4sDsl -import zio.IO - -import java.util.concurrent.{Executor, Executors} -import scala.concurrent.ExecutionContext.global - -object RT { - final val izLogger = makeLogger() - final val logger = LogIO.fromLogger[IO[Nothing, _]](makeLogger()) - final val printer: Printer = Printer.noSpaces.copy(dropNullValues = true) - - final val handler = UnsafeRun2.FailureHandler.Custom(message => izLogger.warn(s"Fiber failed: $message")) - implicit val IO2R: UnsafeRun2[zio.IO] = UnsafeRun2.createZIO( - handler = handler, - customCpuPool = Some( - zio.Executor.fromJavaExecutor( - Executors.newFixedThreadPool(2) - ) - ), - ) - final val dsl = Http4sDsl.apply[zio.IO[Throwable, _]] - final val execCtx = HttpExecutionContext(global) - - private def makeLogger(): IzLogger = { - val router = ConfigurableLogRouter( - Log.Level.Debug, - levels = Map( - "org.http4s" -> Log.Level.Warn, - "org.http4s.server.blaze" -> Log.Level.Error, - "org.http4s.blaze.channel.nio1" -> Log.Level.Crit, - "izumi.idealingua.runtime.rpc.http4s" -> Log.Level.Crit, - ), - ) - - val out = IzLogger(router) - StaticLogRouter.instance.setup(router) - out - } - -} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/TestContext.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/TestContext.scala new file mode 100644 index 00000000..84317709 --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/TestContext.scala @@ -0,0 +1,17 @@ +package izumi.idealingua.runtime.rpc.http4s.fixtures + +sealed trait TestContext { + def user: String +} + +final case class PrivateContext(user: String) extends TestContext { + override def toString: String = s"private: $user" +} + +final case class ProtectedContext(user: String) extends TestContext { + override def toString: String = s"protected: $user" +} + +final case class PublicContext(user: String) extends TestContext { + override def toString: String = s"public: $user" +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/TestDispatcher.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/TestDispatcher.scala deleted file mode 100644 index 36d6e010..00000000 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/TestDispatcher.scala +++ /dev/null @@ -1,18 +0,0 @@ -package izumi.idealingua.runtime.rpc.http4s.fixtures - -import java.util.concurrent.atomic.AtomicReference - -import org.http4s.headers.Authorization -import org.http4s.{BasicCredentials, Header} - -trait TestDispatcher { - val creds = new AtomicReference[Seq[Header.ToRaw]](Seq.empty) - - def setupCredentials(login: String, password: String): Unit = { - creds.set(Seq(Authorization(BasicCredentials(login, password)))) - } - - def cancelCredentials(): Unit = { - creds.set(Seq.empty) - } -} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/TestServices.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/TestServices.scala new file mode 100644 index 00000000..5b99521b --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/fixtures/TestServices.scala @@ -0,0 +1,177 @@ +package izumi.idealingua.runtime.rpc.http4s.fixtures + +import io.circe.Json +import izumi.functional.bio.{F, IO2} +import izumi.idealingua.runtime.rpc.* +import izumi.idealingua.runtime.rpc.http4s.IRTAuthenticator.AuthContext +import izumi.idealingua.runtime.rpc.http4s.context.WsIdExtractor +import izumi.idealingua.runtime.rpc.http4s.ws.* +import izumi.idealingua.runtime.rpc.http4s.ws.WsContextSessions.WsContextSessionsImpl +import izumi.idealingua.runtime.rpc.http4s.ws.WsContextStorage.WsContextStorageImpl +import izumi.idealingua.runtime.rpc.http4s.ws.WsSessionsStorage.WsSessionsStorageImpl +import izumi.idealingua.runtime.rpc.http4s.{IRTAuthenticator, IRTContextServices} +import izumi.r2.idealingua.test.generated.* +import izumi.r2.idealingua.test.impls.AbstractGreeterServer +import logstage.LogIO2 +import org.http4s.BasicCredentials +import org.http4s.headers.Authorization + +class TestServices[F[+_, +_]: IO2]( + logger: LogIO2[F] +) { + + object Server { + def userBlacklistMiddleware[C <: TestContext]( + rejectedNames: Set[String] + ): IRTServerMiddleware[F, C] = new IRTServerMiddleware[F, C] { + override def priority: Int = 0 + override def apply(methodId: IRTMethodId)(context: C, parsedBody: Json)(next: => F[Throwable, Json]): F[Throwable, Json] = { + F.ifThenElse(rejectedNames.contains(context.user))( + F.fail(new IRTUnathorizedRequestContextException(s"Rejected for users: $rejectedNames.")), + next, + ) + } + } + final val wsStorage: WsSessionsStorage[F, AuthContext] = new WsSessionsStorageImpl[F, AuthContext](logger) + final val globalWsListeners = Set( + new WsSessionListener[F, Any, Any] { + override def onSessionOpened(sessionId: WsSessionId, reqCtx: Any, wsCtx: Any): F[Throwable, Unit] = { + logger.debug(s"WS Session: $sessionId opened $wsCtx on $reqCtx.") + } + override def onSessionUpdated(sessionId: WsSessionId, reqCtx: Any, prevStx: Any, newCtx: Any): F[Throwable, Unit] = { + logger.debug(s"WS Session: $sessionId updated $newCtx from $prevStx on $reqCtx.") + } + override def onSessionClosed(sessionId: WsSessionId, wsCtx: Any): F[Throwable, Unit] = { + logger.debug(s"WS Session: $sessionId closed $wsCtx .") + } + } + ) + // PRIVATE + final val privateAuth = new IRTAuthenticator[F, AuthContext, PrivateContext] { + override def authenticate(authContext: AuthContext, body: Option[Json], method: Option[IRTMethodId]): F[Nothing, Option[PrivateContext]] = F.sync { + authContext.headers.get[Authorization].map(_.credentials).collect { + case BasicCredentials(user, "private") => PrivateContext(user) + } + } + } + final val privateWsListener: LoggingWsListener[F, PrivateContext, TestContext] = { + new LoggingWsListener[F, PrivateContext, TestContext] + } + final val privateWsStorage: WsContextStorage[F, PrivateContext] = new WsContextStorageImpl(wsStorage) + final val privateWsSession: WsContextSessions[F, PrivateContext, PrivateContext] = { + new WsContextSessionsImpl( + wsContextStorage = privateWsStorage, + globalWsListeners = globalWsListeners, + wsSessionListeners = Set(privateWsListener), + wsIdExtractor = WsIdExtractor.id, + ) + } + final val privateService: IRTWrappedService[F, PrivateContext] = { + new PrivateTestServiceWrappedServer[F, PrivateContext]( + new PrivateTestServiceServer[F, PrivateContext] { + def test(ctx: PrivateContext, str: String): Just[String] = F.pure(s"Private: $str") + } + ) + } + final val privateServices: IRTContextServices[F, AuthContext, PrivateContext, PrivateContext] = { + IRTContextServices.tagged[F, AuthContext, PrivateContext, PrivateContext]( + authenticator = privateAuth, + serverMuxer = new IRTServerMultiplexor.FromServices(Set(privateService)), + middlewares = Set.empty, + wsSessions = privateWsSession, + ) + } + + // PROTECTED + final val protectedAuth = new IRTAuthenticator[F, AuthContext, ProtectedContext] { + override def authenticate(authContext: AuthContext, body: Option[Json], method: Option[IRTMethodId]): F[Nothing, Option[ProtectedContext]] = F.sync { + authContext.headers.get[Authorization].map(_.credentials).collect { + case BasicCredentials(user, "protected") => ProtectedContext(user) + } + } + } + final val protectedWsListener: LoggingWsListener[F, ProtectedContext, TestContext] = { + new LoggingWsListener[F, ProtectedContext, TestContext] + } + final val protectedWsStorage: WsContextStorage[F, ProtectedContext] = new WsContextStorageImpl(wsStorage) + final val protectedWsSession: WsContextSessions[F, ProtectedContext, ProtectedContext] = { + new WsContextSessionsImpl[F, ProtectedContext, ProtectedContext]( + wsContextStorage = protectedWsStorage, + globalWsListeners = globalWsListeners, + wsSessionListeners = Set(protectedWsListener), + wsIdExtractor = WsIdExtractor.id, + ) + } + final val protectedService: IRTWrappedService[F, ProtectedContext] = { + new ProtectedTestServiceWrappedServer[F, ProtectedContext]( + new ProtectedTestServiceServer[F, ProtectedContext] { + def test(ctx: ProtectedContext, str: String): Just[String] = F.pure(s"Protected: $str") + } + ) + } + final val protectedServices: IRTContextServices[F, AuthContext, ProtectedContext, ProtectedContext] = { + IRTContextServices.tagged[F, AuthContext, ProtectedContext, ProtectedContext]( + authenticator = protectedAuth, + serverMuxer = new IRTServerMultiplexor.FromServices(Set(protectedService)), + middlewares = Set.empty, + wsSessions = protectedWsSession, + ) + } + + // PUBLIC + final val publicAuth = new IRTAuthenticator[F, AuthContext, PublicContext] { + override def authenticate(authContext: AuthContext, body: Option[Json], method: Option[IRTMethodId]): F[Nothing, Option[PublicContext]] = F.sync { + authContext.headers.get[Authorization].map(_.credentials).collect { + case BasicCredentials(user, _) => PublicContext(user) + } + } + } + final val publicWsListener: LoggingWsListener[F, PublicContext, TestContext] = { + new LoggingWsListener[F, PublicContext, TestContext] + } + final val publicWsStorage: WsContextStorage[F, PublicContext] = new WsContextStorageImpl(wsStorage) + final val publicWsSession: WsContextSessions[F, PublicContext, PublicContext] = { + new WsContextSessionsImpl( + wsContextStorage = publicWsStorage, + globalWsListeners = globalWsListeners, + wsSessionListeners = Set(publicWsListener), + wsIdExtractor = WsIdExtractor.id, + ) + } + final val publicService: IRTWrappedService[F, PublicContext] = { + new GreeterServiceServerWrapped[F, PublicContext]( + new AbstractGreeterServer.Impl[F, PublicContext] + ) + } + final val publicServices: IRTContextServices[F, AuthContext, PublicContext, PublicContext] = { + IRTContextServices.tagged[F, AuthContext, PublicContext, PublicContext]( + authenticator = publicAuth, + serverMuxer = new IRTServerMultiplexor.FromServices(Set(publicService)), + middlewares = Set(userBlacklistMiddleware(Set("orc"))), + wsSessions = publicWsSession, + ) + } + + final val contextServices: Set[IRTContextServices.AnyContext[F, AuthContext]] = { + Set[IRTContextServices.AnyContext[F, AuthContext]]( + privateServices, + protectedServices, + publicServices, + ) + } + } + + object Client { + private val greeterService: AbstractGreeterServer[F, Unit] = new AbstractGreeterServer.Impl[F, Unit] + private val greeterDispatcher: GreeterServiceServerWrapped[F, Unit] = new GreeterServiceServerWrapped[F, Unit](greeterService) + private val dispatchers: Set[IRTWrappedService[F, Unit]] = Set[IRTWrappedService[F, Unit]](greeterDispatcher) + + private val clients: Set[IRTWrappedClient] = Set[IRTWrappedClient]( + GreeterServiceClientWrapped, + ProtectedTestServiceWrappedClient, + PrivateTestServiceWrappedClient, + ) + val codec: IRTClientMultiplexorImpl[F] = new IRTClientMultiplexorImpl[F](clients) + val buzzerMultiplexor: IRTServerMultiplexor[F, Unit] = new IRTServerMultiplexor.FromServices[F, Unit](dispatchers) + } +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMethod.scala b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMethod.scala new file mode 100644 index 00000000..7fc6b4da --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMethod.scala @@ -0,0 +1,51 @@ +package izumi.idealingua.runtime.rpc + +import io.circe.Json +import izumi.functional.bio.{Error2, Exit, F, IO2} + +trait IRTServerMethod[F[+_, +_], C] { + self => + def methodId: IRTMethodId + def invoke(context: C, parsedBody: Json): F[Throwable, Json] + + /** Contramap eval on context C2 -> C. If context is missing IRTUnathorizedRequestContextException will raise. */ + final def contramap[C2](updateContext: (C2, Json, IRTMethodId) => F[Throwable, Option[C]])(implicit E: Error2[F]): IRTServerMethod[F, C2] = new IRTServerMethod[F, C2] { + override def methodId: IRTMethodId = self.methodId + override def invoke(context: C2, parsedBody: Json): F[Throwable, Json] = { + updateContext(context, parsedBody, methodId) + .fromOption(new IRTUnathorizedRequestContextException(s"Unauthorized $methodId call. Context: $context.")) + .flatMap(self.invoke(_, parsedBody)) + } + } + + /** Wrap invocation with function '(Context, Body)(Method.Invoke) => Result' . */ + final def wrap(middleware: (C, Json) => F[Throwable, Json] => F[Throwable, Json]): IRTServerMethod[F, C] = new IRTServerMethod[F, C] { + override def methodId: IRTMethodId = self.methodId + override def invoke(context: C, parsedBody: Json): F[Throwable, Json] = { + middleware(context, parsedBody)(self.invoke(context, parsedBody)) + } + } +} + +object IRTServerMethod { + def apply[F[+_, +_]: IO2, C](method: IRTMethodWrapper[F, C]): IRTServerMethod[F, C] = FromWrapper.apply(method) + + final case class FromWrapper[F[+_, +_]: IO2, C](method: IRTMethodWrapper[F, C]) extends IRTServerMethod[F, C] { + override def methodId: IRTMethodId = method.signature.id + @inline override def invoke(context: C, parsedBody: Json): F[Throwable, Json] = { + val methodId = method.signature.id + for { + requestBody <- F.syncThrowable(method.marshaller.decodeRequest[F].apply(IRTJsonBody(methodId, parsedBody))).flatten.sandbox.catchAll { + case Exit.Interruption(decodingFailure, _, trace) => + F.fail(new IRTDecodingException(s"$methodId: Failed to decode JSON '${parsedBody.noSpaces}'.\nTrace: $trace", Some(decodingFailure))) + case Exit.Termination(_, exceptions, trace) => + F.fail(new IRTDecodingException(s"$methodId: Failed to decode JSON '${parsedBody.noSpaces}'.\nTrace: $trace", exceptions.headOption)) + case Exit.Error(decodingFailure, trace) => + F.fail(new IRTDecodingException(s"$methodId: Failed to decode JSON '${parsedBody.noSpaces}'.\nTrace: $trace", Some(decodingFailure))) + } + result <- F.syncThrowable(method.invoke(context, requestBody.value.asInstanceOf[method.signature.Input])).flatten + encoded <- F.syncThrowable(method.marshaller.encodeResponse.apply(IRTResBody(result))) + } yield encoded + } + } +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMiddleware.scala b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMiddleware.scala new file mode 100644 index 00000000..b22bd29e --- /dev/null +++ b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMiddleware.scala @@ -0,0 +1,13 @@ +package izumi.idealingua.runtime.rpc + +import io.circe.Json + +trait IRTServerMiddleware[F[_, _], C] { + def priority: Int + def apply( + methodId: IRTMethodId + )(context: C, + parsedBody: Json, + )(next: => F[Throwable, Json] + ): F[Throwable, Json] +} diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMultiplexor.scala b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMultiplexor.scala index d0d07242..52a20851 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMultiplexor.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTServerMultiplexor.scala @@ -1,53 +1,46 @@ package izumi.idealingua.runtime.rpc import io.circe.Json -import izumi.functional.bio.{Exit, F, IO2} +import izumi.functional.bio.{Error2, F, IO2} -trait ContextExtender[-Ctx, +Ctx2] { - def extend(context: Ctx, body: Json, irtMethodId: IRTMethodId): Ctx2 -} - -object ContextExtender { - def id[Ctx]: ContextExtender[Ctx, Ctx] = (context, _, _) => context -} - -trait IRTServerMultiplexor[F[+_, +_], -C] { - def doInvoke(parsedBody: Json, context: C, toInvoke: IRTMethodId): F[Throwable, Option[Json]] -} +trait IRTServerMultiplexor[F[+_, +_], C] { + self => + def methods: Map[IRTMethodId, IRTServerMethod[F, C]] -class IRTServerMultiplexorImpl[F[+_, +_]: IO2, -C, -C2]( - list: Set[IRTWrappedService[F, C2]], - extender: ContextExtender[C, C2], -) extends IRTServerMultiplexor[F, C] { - val services: Map[IRTServiceId, IRTWrappedService[F, C2]] = list.map(s => s.serviceId -> s).toMap + def invokeMethod(method: IRTMethodId)(context: C, parsedBody: Json)(implicit E: Error2[F]): F[Throwable, Json] = { + F.fromOption(new IRTMissingHandlerException(s"Method $method not found.", parsedBody))(methods.get(method)) + .flatMap(_.invoke(context, parsedBody)) + } - def doInvoke(parsedBody: Json, context: C, toInvoke: IRTMethodId): F[Throwable, Option[Json]] = { - (for { - service <- services.get(toInvoke.service) - method <- service.allMethods.get(toInvoke) - } yield method) match { - case Some(value) => - invoke(extender.extend(context, parsedBody, toInvoke), toInvoke, value, parsedBody).map(Some.apply) - case None => - F.pure(None) + /** Contramap eval on context C2 -> C. If context is missing IRTUnathorizedRequestContextException will raise. */ + final def contramap[C2]( + updateContext: (C2, Json, IRTMethodId) => F[Throwable, Option[C]] + )(implicit io2: IO2[F] + ): IRTServerMultiplexor[F, C2] = { + val mappedMethods = self.methods.map { case (k, v) => k -> v.contramap(updateContext) } + new IRTServerMultiplexor.FromMethods(mappedMethods) + } + /** Wrap invocation with function '(Context, Body)(Method.Invoke) => Result' . */ + final def wrap(middleware: IRTServerMiddleware[F, C]): IRTServerMultiplexor[F, C] = { + val wrappedMethods = self.methods.map { + case (methodId, method) => + val wrappedMethod: IRTServerMethod[F, C] = method.wrap { + case (ctx, body) => + next => middleware(method.methodId)(ctx, body)(next) + } + methodId -> wrappedMethod } + new IRTServerMultiplexor.FromMethods(wrappedMethods) } +} - @inline private[this] def invoke(context: C2, toInvoke: IRTMethodId, method: IRTMethodWrapper[F, C2], parsedBody: Json): F[Throwable, Json] = { - for { - decodeAction <- F.syncThrowable(method.marshaller.decodeRequest[F].apply(IRTJsonBody(toInvoke, parsedBody))) - safeDecoded <- decodeAction.sandbox.catchAll { - case Exit.Interruption(decodingFailure, _, trace) => - F.fail(new IRTDecodingException(s"$toInvoke: Failed to decode JSON ${parsedBody.toString()} $trace", Some(decodingFailure))) - case Exit.Termination(_, exceptions, trace) => - F.fail(new IRTDecodingException(s"$toInvoke: Failed to decode JSON ${parsedBody.toString()} $trace", exceptions.headOption)) - case Exit.Error(decodingFailure, trace) => - F.fail(new IRTDecodingException(s"$toInvoke: Failed to decode JSON ${parsedBody.toString()} $trace", Some(decodingFailure))) - } - casted = safeDecoded.value.asInstanceOf[method.signature.Input] - resultAction <- F.syncThrowable(method.invoke(context, casted)) - safeResult <- resultAction - encoded <- F.syncThrowable(method.marshaller.encodeResponse.apply(IRTResBody(safeResult))) - } yield encoded +object IRTServerMultiplexor { + def combine[F[+_, +_], C](multiplexors: Iterable[IRTServerMultiplexor[F, C]]): IRTServerMultiplexor[F, C] = { + new FromMethods(multiplexors.flatMap(_.methods).toMap) } + + class FromMethods[F[+_, +_], C](val methods: Map[IRTMethodId, IRTServerMethod[F, C]]) extends IRTServerMultiplexor[F, C] + + class FromServices[F[+_, +_]: IO2, C](val services: Set[IRTWrappedService[F, C]]) + extends FromMethods[F, C](services.flatMap(_.allMethods.map { case (k, v) => k -> IRTServerMethod(v) }).toMap) } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTTransportException.scala b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTTransportException.scala index a1725e49..5596a0be 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTTransportException.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/IRTTransportException.scala @@ -1,17 +1,17 @@ package izumi.idealingua.runtime.rpc -trait IRTTransportException +abstract class IRTTransportException(message: String, cause: Option[Throwable]) extends RuntimeException(message, cause.orNull) -class IRTUnparseableDataException(message: String, cause: Option[Throwable] = None) extends RuntimeException(message, cause.orNull) with IRTTransportException +class IRTUnparseableDataException(message: String, cause: Option[Throwable] = None) extends IRTTransportException(message, cause) -class IRTDecodingException(message: String, cause: Option[Throwable] = None) extends RuntimeException(message, cause.orNull) with IRTTransportException +class IRTDecodingException(message: String, cause: Option[Throwable] = None) extends IRTTransportException(message, cause) -class IRTTypeMismatchException(message: String, val v: Any, cause: Option[Throwable] = None) extends RuntimeException(message, cause.orNull) with IRTTransportException +class IRTTypeMismatchException(message: String, val v: Any, cause: Option[Throwable] = None) extends IRTTransportException(message, cause) -class IRTMissingHandlerException(message: String, val v: Any, cause: Option[Throwable] = None) extends RuntimeException(message, cause.orNull) with IRTTransportException +class IRTMissingHandlerException(message: String, val v: Any, cause: Option[Throwable] = None) extends IRTTransportException(message, cause) -class IRTLimitReachedException(message: String, cause: Option[Throwable] = None) extends RuntimeException(message, cause.orNull) with IRTTransportException +class IRTLimitReachedException(message: String, cause: Option[Throwable] = None) extends IRTTransportException(message, cause) -class IRTUnathorizedRequestContextException(message: String, cause: Option[Throwable] = None) extends RuntimeException(message, cause.orNull) with IRTTransportException +class IRTUnathorizedRequestContextException(message: String, cause: Option[Throwable] = None) extends IRTTransportException(message, cause) -class IRTGenericFailure(message: String, cause: Option[Throwable] = None) extends RuntimeException(message, cause.orNull) with IRTTransportException +class IRTGenericFailure(message: String, cause: Option[Throwable] = None) extends IRTTransportException(message, cause) diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/packets.scala b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/packets.scala index 386af425..4cfde254 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/packets.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-scala/src/main/scala/izumi/idealingua/runtime/rpc/packets.scala @@ -78,7 +78,6 @@ object RpcPacketId { def random(): RpcPacketId = RpcPacketId(UUIDGen.getTimeUUID().toString) implicit def dec0: Decoder[RpcPacketId] = Decoder.decodeString.map(RpcPacketId.apply) - implicit def enc0: Encoder[RpcPacketId] = Encoder.encodeString.contramap(_.v) } @@ -94,6 +93,14 @@ case class RpcPacket( def withHeaders(h: Map[String, String]): RpcPacket = { copy(headers = Option(h).filter(_.nonEmpty)) } + def methodId: Option[IRTMethodId] = { + for { + m <- method + s <- service + } yield { + IRTMethodId(IRTServiceId(s), IRTMethodName(m)) + } + } } object RpcPacket { @@ -118,7 +125,7 @@ object RpcPacket { } def rpcFail(ref: Option[RpcPacketId], cause: String): RpcPacket = { - RpcPacket(RPCPacketKind.RpcFail, Some(Map("cause" -> cause).asJson), None, ref, None, None, None) + RpcPacket(RPCPacketKind.RpcFail, Some(Json.obj("cause" -> Json.fromString(cause))), None, ref, None, None, None) } def buzzerRequest(id: RpcPacketId, method: IRTMethodId, data: Json): RpcPacket = { @@ -130,6 +137,6 @@ object RpcPacket { } def buzzerFail(ref: Option[RpcPacketId], cause: String): RpcPacket = { - RpcPacket(RPCPacketKind.BuzzFailure, Some(Map("cause" -> cause).asJson), None, ref, None, None, None) + RpcPacket(RPCPacketKind.BuzzFailure, Some(Json.obj("cause" -> Json.fromString(cause))), None, ref, None, None, None) } } diff --git a/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/GreeterRunnerExample.scala b/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/GreeterRunnerExample.scala index 9fd6a2b6..1218933b 100644 --- a/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/GreeterRunnerExample.scala +++ b/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/GreeterRunnerExample.scala @@ -1,14 +1,14 @@ package izumi.r2.idealingua.test import _root_.io.circe.syntax.* -import izumi.idealingua.runtime.rpc.{ContextExtender, IRTServerMultiplexorImpl} +import izumi.idealingua.runtime.rpc.IRTServerMultiplexor import izumi.r2.idealingua.test.generated.GreeterServiceServerWrapped import zio.* object GreeterRunnerExample { def main(args: Array[String]): Unit = { val greeter = new GreeterServiceServerWrapped[IO, Unit](new impls.AbstractGreeterServer.Impl[IO, Unit]()) - val multiplexor = new IRTServerMultiplexorImpl[IO, Unit, Unit](Set(greeter), ContextExtender.id) + val multiplexor = new IRTServerMultiplexor.FromServices[IO, Unit](Set(greeter)) val req1 = new greeter.greet.signature.Input("John", "Doe") val json1 = req1.asJson @@ -18,8 +18,8 @@ object GreeterRunnerExample { val json2 = req2.asJson println(json2) - val invoked1 = multiplexor.doInvoke(json1, (), greeter.greet.signature.id) - val invoked2 = multiplexor.doInvoke(json1, (), greeter.alternative.signature.id) + val invoked1 = multiplexor.invokeMethod(greeter.greet.signature.id)((), json1) + val invoked2 = multiplexor.invokeMethod(greeter.alternative.signature.id)((), json1) implicit val unsafe: Unsafe = Unsafe.unsafe(identity) println(zio.Runtime.default.unsafe.run(invoked1).getOrThrowFiberFailure()) diff --git a/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/generated/PrivateTestServiceServer.scala b/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/generated/PrivateTestServiceServer.scala new file mode 100644 index 00000000..7c6cec8c --- /dev/null +++ b/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/generated/PrivateTestServiceServer.scala @@ -0,0 +1,107 @@ +package izumi.r2.idealingua.test.generated + +import io.circe.* +import io.circe.generic.semiauto.* +import io.circe.syntax.* +import izumi.functional.bio.IO2 as IRTIO2 +import izumi.idealingua.runtime.rpc.* + +trait PrivateTestServiceServer[Or[+_, +_], C] { + type Just[+T] = Or[Nothing, T] + def test(ctx: C, str: String): Just[String] +} + +trait PrivateTestServiceClient[Or[+_, +_]] { + type Just[+T] = Or[Nothing, T] + def test(str: String): Just[String] +} + +class PrivateTestServiceWrappedClient[Or[+_, +_]: IRTIO2](_dispatcher: IRTDispatcher[Or]) extends PrivateTestServiceClient[Or] { + final val _F: IRTIO2[Or] = implicitly + import izumi.r2.idealingua.test.generated.PrivateTestService as _M + def test(str: String): Just[String] = { + _F.redeem(_dispatcher.dispatch(IRTMuxRequest(IRTReqBody(new _M.test.Input(str)), _M.test.id)))( + { + err => _F.terminate(err) + }, + { + case IRTMuxResponse(IRTResBody(v: _M.test.Output), method) if method == _M.test.id => + _F.pure(v.value) + case v => + val id = "PrivateTestService.PrivateTestServiceWrappedClient.test" + val expected = classOf[_M.test.Input].toString + _F.terminate(new IRTTypeMismatchException(s"Unexpected type in $id: $v, expected $expected got ${v.getClass}", v, None)) + }, + ) + } +} + +object PrivateTestServiceWrappedClient extends IRTWrappedClient { + val allCodecs: Map[IRTMethodId, IRTCirceMarshaller] = { + Map(PrivateTestService.test.id -> PrivateTestServiceCodecs.test) + } +} + +class PrivateTestServiceWrappedServer[Or[+_, +_]: IRTIO2, C](_service: PrivateTestServiceServer[Or, C]) extends IRTWrappedService[Or, C] { + final val _F: IRTIO2[Or] = implicitly + final val serviceId: IRTServiceId = PrivateTestService.serviceId + val allMethods: Map[IRTMethodId, IRTMethodWrapper[Or, C]] = { + Seq[IRTMethodWrapper[Or, C]](test).map(m => m.signature.id -> m).toMap + } + object test extends IRTMethodWrapper[Or, C] { + import PrivateTestService.test.* + val signature: PrivateTestService.test.type = PrivateTestService.test + val marshaller: PrivateTestServiceCodecs.test.type = PrivateTestServiceCodecs.test + def invoke(ctx: C, input: Input): Just[Output] = { + assert(ctx.asInstanceOf[_root_.scala.AnyRef] != null && input.asInstanceOf[_root_.scala.AnyRef] != null) + _F.map(_service.test(ctx, input.str))(v => new Output(v)) + } + } +} + +object PrivateTestServiceWrappedServer + +object PrivateTestService { + final val serviceId: IRTServiceId = IRTServiceId("PrivateTestService") + object test extends IRTMethodSignature { + final val id: IRTMethodId = IRTMethodId(serviceId, IRTMethodName("test")) + type Input = TestInput + type Output = TestOutput + } + final case class TestInput(str: String) + object TestInput { + implicit val encodeTestInput: Encoder.AsObject[TestInput] = deriveEncoder[TestInput] + implicit val decodeTestInput: Decoder[TestInput] = deriveDecoder[TestInput] + } + final case class TestOutput(value: String) + object TestOutput { + implicit val encodeUnwrappedTestOutput: Encoder[TestOutput] = Encoder.instance { + v => v.value.asJson + } + implicit val decodeUnwrappedTestOutput: Decoder[TestOutput] = Decoder.instance { + v => v.as[String].map(d => TestOutput(d)) + } + } +} + +object PrivateTestServiceCodecs { + object test extends IRTCirceMarshaller { + import PrivateTestService.test.* + def encodeRequest: PartialFunction[IRTReqBody, Json] = { + case IRTReqBody(value: Input) => + value.asJson + } + def decodeRequest[Or[+_, +_]: IRTIO2]: PartialFunction[IRTJsonBody, Or[DecodingFailure, IRTReqBody]] = { + case IRTJsonBody(m, packet) if m == id => + this.decoded[Or, IRTReqBody](packet.as[Input].map(v => IRTReqBody(v))) + } + def encodeResponse: PartialFunction[IRTResBody, Json] = { + case IRTResBody(value: Output) => + value.asJson + } + def decodeResponse[Or[+_, +_]: IRTIO2]: PartialFunction[IRTJsonBody, Or[DecodingFailure, IRTResBody]] = { + case IRTJsonBody(m, packet) if m == id => + decoded[Or, IRTResBody](packet.as[Output].map(v => IRTResBody(v))) + } + } +} diff --git a/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/generated/ProtectedTestServiceServer.scala b/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/generated/ProtectedTestServiceServer.scala new file mode 100644 index 00000000..fed8f8e5 --- /dev/null +++ b/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/generated/ProtectedTestServiceServer.scala @@ -0,0 +1,107 @@ +package izumi.r2.idealingua.test.generated + +import io.circe.* +import io.circe.generic.semiauto.* +import io.circe.syntax.* +import izumi.functional.bio.IO2 as IRTIO2 +import izumi.idealingua.runtime.rpc.* + +trait ProtectedTestServiceServer[Or[+_, +_], C] { + type Just[+T] = Or[Nothing, T] + def test(ctx: C, str: String): Just[String] +} + +trait ProtectedTestServiceClient[Or[+_, +_]] { + type Just[+T] = Or[Nothing, T] + def test(str: String): Just[String] +} + +class ProtectedTestServiceWrappedClient[Or[+_, +_]: IRTIO2](_dispatcher: IRTDispatcher[Or]) extends ProtectedTestServiceClient[Or] { + final val _F: IRTIO2[Or] = implicitly + import izumi.r2.idealingua.test.generated.ProtectedTestService as _M + def test(str: String): Just[String] = { + _F.redeem(_dispatcher.dispatch(IRTMuxRequest(IRTReqBody(new _M.test.Input(str)), _M.test.id)))( + { + err => _F.terminate(err) + }, + { + case IRTMuxResponse(IRTResBody(v: _M.test.Output), method) if method == _M.test.id => + _F.pure(v.value) + case v => + val id = "ProtectedTestService.ProtectedTestServiceWrappedClient.test" + val expected = classOf[_M.test.Input].toString + _F.terminate(new IRTTypeMismatchException(s"Unexpected type in $id: $v, expected $expected got ${v.getClass}", v, None)) + }, + ) + } +} + +object ProtectedTestServiceWrappedClient extends IRTWrappedClient { + val allCodecs: Map[IRTMethodId, IRTCirceMarshaller] = { + Map(ProtectedTestService.test.id -> ProtectedTestServiceCodecs.test) + } +} + +class ProtectedTestServiceWrappedServer[Or[+_, +_]: IRTIO2, C](_service: ProtectedTestServiceServer[Or, C]) extends IRTWrappedService[Or, C] { + final val _F: IRTIO2[Or] = implicitly + final val serviceId: IRTServiceId = ProtectedTestService.serviceId + val allMethods: Map[IRTMethodId, IRTMethodWrapper[Or, C]] = { + Seq[IRTMethodWrapper[Or, C]](test).map(m => m.signature.id -> m).toMap + } + object test extends IRTMethodWrapper[Or, C] { + import ProtectedTestService.test.* + val signature: ProtectedTestService.test.type = ProtectedTestService.test + val marshaller: ProtectedTestServiceCodecs.test.type = ProtectedTestServiceCodecs.test + def invoke(ctx: C, input: Input): Just[Output] = { + assert(ctx.asInstanceOf[_root_.scala.AnyRef] != null && input.asInstanceOf[_root_.scala.AnyRef] != null) + _F.map(_service.test(ctx, input.str))(v => new Output(v)) + } + } +} + +object ProtectedTestServiceWrappedServer + +object ProtectedTestService { + final val serviceId: IRTServiceId = IRTServiceId("ProtectedTestService") + object test extends IRTMethodSignature { + final val id: IRTMethodId = IRTMethodId(serviceId, IRTMethodName("test")) + type Input = TestInput + type Output = TestOutput + } + final case class TestInput(str: String) + object TestInput { + implicit val encodeTestInput: Encoder.AsObject[TestInput] = deriveEncoder[TestInput] + implicit val decodeTestInput: Decoder[TestInput] = deriveDecoder[TestInput] + } + final case class TestOutput(value: String) + object TestOutput { + implicit val encodeUnwrappedTestOutput: Encoder[TestOutput] = Encoder.instance { + v => v.value.asJson + } + implicit val decodeUnwrappedTestOutput: Decoder[TestOutput] = Decoder.instance { + v => v.as[String].map(d => TestOutput(d)) + } + } +} + +object ProtectedTestServiceCodecs { + object test extends IRTCirceMarshaller { + import ProtectedTestService.test.* + def encodeRequest: PartialFunction[IRTReqBody, Json] = { + case IRTReqBody(value: Input) => + value.asJson + } + def decodeRequest[Or[+_, +_]: IRTIO2]: PartialFunction[IRTJsonBody, Or[DecodingFailure, IRTReqBody]] = { + case IRTJsonBody(m, packet) if m == id => + this.decoded[Or, IRTReqBody](packet.as[Input].map(v => IRTReqBody(v))) + } + def encodeResponse: PartialFunction[IRTResBody, Json] = { + case IRTResBody(value: Output) => + value.asJson + } + def decodeResponse[Or[+_, +_]: IRTIO2]: PartialFunction[IRTJsonBody, Or[DecodingFailure, IRTResBody]] = { + case IRTJsonBody(m, packet) if m == id => + decoded[Or, IRTResBody](packet.as[Output].map(v => IRTResBody(v))) + } + } +} diff --git a/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/impls/AbstractGreeterServer.scala b/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/impls/AbstractGreeterServer.scala index 445a818e..c88ddd98 100644 --- a/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/impls/AbstractGreeterServer.scala +++ b/idealingua-v1/idealingua-v1-test-defs/src/main/scala/izumi/r2/idealingua/test/impls/AbstractGreeterServer.scala @@ -1,27 +1,13 @@ package izumi.r2.idealingua.test.impls -import izumi.functional.bio.IO2 -import izumi.r2.idealingua.test.generated._ +import izumi.functional.bio.{F, IO2} +import izumi.r2.idealingua.test.generated.* abstract class AbstractGreeterServer[F[+_, +_]: IO2, C] extends GreeterServiceServer[F, C] { - - val R: IO2[F] = implicitly - - override def greet(ctx: C, name: String, surname: String): Just[String] = R.pure { - s"Hi, $name $surname!" - } - - override def sayhi(ctx: C): Just[String] = R.pure { - "Hi!" - } - - override def alternative(ctx: C): F[Long, String] = R.fromEither { - Right("value") - } - - override def nothing(ctx: C): F[Nothing, String] = R.pure { - "" - } + override def greet(ctx: C, name: String, surname: String): Just[String] = F.pure(s"Hi, $name $surname!") + override def sayhi(ctx: C): Just[String] = F.pure(s"Hi! With $ctx.") + override def alternative(ctx: C): F[Long, String] = F.fromEither(Right("value")) + override def nothing(ctx: C): F[Nothing, String] = F.pure("") } object AbstractGreeterServer {