Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: better http contexts #480

Merged
merged 24 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .jvmopts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
-Xmx4G
-Xss2m
-XX:ReservedCodeCacheSize=256m
-XX:MaxMetaspaceSize=3G

Expand Down

This file was deleted.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package izumi.idealingua.runtime.rpc.http4s

import io.circe.Json
import izumi.functional.bio.{Applicative2, F}
import izumi.idealingua.runtime.rpc.IRTMethodId
import org.http4s.Headers

import java.net.InetAddress

abstract class IRTAuthenticator[F[_, _], AuthCtx, RequestCtx] {
def authenticate(authContext: AuthCtx, body: Option[Json], methodId: Option[IRTMethodId]): F[Nothing, Option[RequestCtx]]
}

object IRTAuthenticator {
def unit[F[+_, +_]: Applicative2, C]: IRTAuthenticator[F, C, Unit] = new IRTAuthenticator[F, C, Unit] {
override def authenticate(authContext: C, body: Option[Json], methodId: Option[IRTMethodId]): F[Nothing, Option[Unit]] = F.pure(Some(()))
}
final case class AuthContext(headers: Headers, networkAddress: Option[InetAddress])
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package izumi.idealingua.runtime.rpc.http4s

import izumi.functional.bio.{IO2, Monad2}
import izumi.idealingua.runtime.rpc.http4s.ws.WsContextSessions
import izumi.idealingua.runtime.rpc.{IRTServerMiddleware, IRTServerMultiplexor}
import izumi.reflect.Tag

trait IRTContextServices[F[+_, +_], AuthCtx, RequestCtx, WsCtx] {
def name: String
def authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx]
def serverMuxer: IRTServerMultiplexor[F, RequestCtx]
def middlewares: Set[IRTServerMiddleware[F, RequestCtx]]
def wsSessions: WsContextSessions[F, RequestCtx, WsCtx]

def authorizedMuxer(implicit io2: IO2[F]): IRTServerMultiplexor[F, AuthCtx] = {
val withMiddlewares: IRTServerMultiplexor[F, RequestCtx] = middlewares.toList.sortBy(_.priority).foldLeft(serverMuxer) {
case (muxer, middleware) => muxer.wrap(middleware)
}
val authorized: IRTServerMultiplexor[F, AuthCtx] = withMiddlewares.contramap {
case (authCtx, body, methodId) => authenticator.authenticate(authCtx, Some(body), Some(methodId))
}
authorized
}
def authorizedWsSessions(implicit M: Monad2[F]): WsContextSessions[F, AuthCtx, WsCtx] = {
val authorized: WsContextSessions[F, AuthCtx, WsCtx] = wsSessions.contramap {
authCtx =>
authenticator.authenticate(authCtx, None, None)
}
authorized
}
}

object IRTContextServices {
type AnyContext[F[+_, +_], AuthCtx] = IRTContextServices[F, AuthCtx, ?, ?]
type AnyWsContext[F[+_, +_], AuthCtx, RequestCtx] = IRTContextServices[F, AuthCtx, RequestCtx, ?]

def tagged[F[+_, +_], AuthCtx, RequestCtx: Tag, WsCtx: Tag](
authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx],
serverMuxer: IRTServerMultiplexor[F, RequestCtx],
middlewares: Set[IRTServerMiddleware[F, RequestCtx]],
wsSessions: WsContextSessions[F, RequestCtx, WsCtx],
): Tagged[F, AuthCtx, RequestCtx, WsCtx] = Tagged(authenticator, serverMuxer, middlewares, wsSessions)

def named[F[+_, +_], AuthCtx, RequestCtx, WsCtx](
name: String
)(authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx],
serverMuxer: IRTServerMultiplexor[F, RequestCtx],
middlewares: Set[IRTServerMiddleware[F, RequestCtx]],
wsSessions: WsContextSessions[F, RequestCtx, WsCtx],
): Named[F, AuthCtx, RequestCtx, WsCtx] = Named(name, authenticator, serverMuxer, middlewares, wsSessions)

final case class Named[F[+_, +_], AuthCtx, RequestCtx, WsCtx](
name: String,
authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx],
serverMuxer: IRTServerMultiplexor[F, RequestCtx],
middlewares: Set[IRTServerMiddleware[F, RequestCtx]],
wsSessions: WsContextSessions[F, RequestCtx, WsCtx],
) extends IRTContextServices[F, AuthCtx, RequestCtx, WsCtx]

final case class Tagged[F[+_, +_], AuthCtx, RequestCtx: Tag, WsCtx: Tag](
authenticator: IRTAuthenticator[F, AuthCtx, RequestCtx],
serverMuxer: IRTServerMultiplexor[F, RequestCtx],
middlewares: Set[IRTServerMiddleware[F, RequestCtx]],
wsSessions: WsContextSessions[F, RequestCtx, WsCtx],
) extends IRTContextServices[F, AuthCtx, RequestCtx, WsCtx] {
override def name: String = s"${Tag[RequestCtx].tag}:${Tag[WsCtx].tag}"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ abstract class IRTHttpFailureException(
message: String,
val status: Status,
cause: Option[Throwable] = None,
) extends RuntimeException(message, cause.orNull)
with IRTTransportException
) extends IRTTransportException(message, cause)

case class IRTUnexpectedHttpStatus(override val status: Status) extends IRTHttpFailureException(s"Unexpected http status: $status", status)
case class IRTNoCredentialsException(override val status: Status) extends IRTHttpFailureException("No valid credentials", status)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import izumi.functional.bio.{Async2, F, Temporal2}
import izumi.idealingua.runtime.rpc.*
import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcher.IRTDispatcherWs
import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.WsRpcClientConnection
import izumi.idealingua.runtime.rpc.http4s.ws.RawResponse
import izumi.idealingua.runtime.rpc.http4s.ws.{RawResponse, WsSessionId}
import logstage.LogIO2

import java.util.concurrent.TimeoutException
Expand All @@ -18,6 +18,10 @@ class WsRpcDispatcher[F[+_, +_]: Async2](
logger: LogIO2[F],
) extends IRTDispatcherWs[F] {

override def sessionId: Option[WsSessionId] = {
connection.sessionId
}

override def authorize(headers: Map[String, String]): F[Throwable, Unit] = {
connection.authorize(headers, timeout)
}
Expand Down Expand Up @@ -62,6 +66,7 @@ class WsRpcDispatcher[F[+_, +_]: Async2](

object WsRpcDispatcher {
trait IRTDispatcherWs[F[_, _]] extends IRTDispatcher[F] {
def sessionId: Option[WsSessionId]
def authorize(headers: Map[String, String]): F[Throwable, Unit]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,25 @@ import io.circe.{Json, Printer}
import izumi.functional.bio.{Async2, Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2}
import izumi.functional.lifecycle.Lifecycle
import izumi.fundamentals.platform.language.Quirks.Discarder
import izumi.fundamentals.platform.uuid.UUIDGen
import izumi.idealingua.runtime.rpc.*
import izumi.idealingua.runtime.rpc.http4s.HttpServer
import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcher.IRTDispatcherWs
import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.{ClientWsRpcHandler, WsRpcClientConnection, WsRpcContextProvider, fromNettyFuture}
import izumi.idealingua.runtime.rpc.http4s.ws.{RawResponse, WsRequestState, WsRpcHandler}
import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.{ClientWsRpcHandler, WsRpcClientConnection, fromNettyFuture}
import izumi.idealingua.runtime.rpc.http4s.context.WsContextExtractor
import izumi.idealingua.runtime.rpc.http4s.ws.{RawResponse, WsRequestState, WsRpcHandler, WsSessionId}
import izumi.logstage.api.IzLogger
import logstage.LogIO2
import org.asynchttpclient.netty.ws.NettyWebSocket
import org.asynchttpclient.ws.{WebSocket, WebSocketListener, WebSocketUpgradeHandler}
import org.asynchttpclient.{DefaultAsyncHttpClient, DefaultAsyncHttpClientConfig}
import org.http4s.Uri

import java.util.UUID
import java.util.concurrent.atomic.AtomicReference
import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.jdk.CollectionConverters.*
import scala.util.Try

class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRun2](
codec: IRTClientMultiplexor[F],
Expand All @@ -29,32 +34,49 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu

def connect[ServerContext](
uri: Uri,
muxer: IRTServerMultiplexor[F, ServerContext],
contextProvider: WsRpcContextProvider[ServerContext],
serverMuxer: IRTServerMultiplexor[F, ServerContext],
wsContextExtractor: WsContextExtractor[ServerContext],
headers: Map[String, String] = Map.empty,
): Lifecycle[F[Throwable, _], WsRpcClientConnection[F]] = {
for {
client <- WsRpcDispatcherFactory.asyncHttpClient[F]
requestState <- Lifecycle.liftF(F.syncThrowable(WsRequestState.create[F]))
listener <- Lifecycle.liftF(F.syncThrowable(createListener(muxer, contextProvider, requestState, dispatcherLogger(uri, logger))))
handler <- Lifecycle.liftF(F.syncThrowable(new WebSocketUpgradeHandler(List(listener).asJava)))
client <- createAsyncHttpClient()
wsRequestState <- Lifecycle.liftF(F.syncThrowable(WsRequestState.create[F]))
listener <- Lifecycle.liftF(F.syncThrowable(createListener(serverMuxer, wsRequestState, wsContextExtractor, dispatcherLogger(uri, logger))))
handler <- Lifecycle.liftF(F.syncThrowable(new WebSocketUpgradeHandler(List(listener).asJava)))
nettyWebSocket <- Lifecycle.make(
F.fromFutureJava(client.prepareGet(uri.toString()).execute(handler).toCompletableFuture)
F.fromFutureJava {
client
.prepareGet(uri.toString())
.setSingleHeaders(headers.asJava)
.execute(handler).toCompletableFuture
}
)(nettyWebSocket => fromNettyFuture(nettyWebSocket.sendCloseFrame()).void)
sessionId = Option(nettyWebSocket.getUpgradeHeaders.get(HttpServer.`X-Ws-Session-Id`.toString))
.flatMap(str => Try(WsSessionId(UUID.fromString(str))).toOption)
// fill promises before closing WS connection, potentially giving a chance to send out an error response before closing
_ <- Lifecycle.make(F.unit)(_ => requestState.clear())
_ <- Lifecycle.make(F.unit)(_ => wsRequestState.clear())
} yield {
new WsRpcClientConnection.Netty(nettyWebSocket, requestState, printer)
new WsRpcClientConnection.Netty(nettyWebSocket, wsRequestState, printer, sessionId)
}
}

def connectSimple(
uri: Uri,
serverMuxer: IRTServerMultiplexor[F, Unit],
headers: Map[String, String] = Map.empty,
): Lifecycle[F[Throwable, _], WsRpcClientConnection[F]] = {
connect(uri, serverMuxer, WsContextExtractor.unit, headers)
}

def dispatcher[ServerContext](
uri: Uri,
muxer: IRTServerMultiplexor[F, ServerContext],
contextProvider: WsRpcContextProvider[ServerContext],
serverMuxer: IRTServerMultiplexor[F, ServerContext],
wsContextExtractor: WsContextExtractor[ServerContext],
headers: Map[String, String] = Map.empty,
tweakRequest: RpcPacket => RpcPacket = identity,
timeout: FiniteDuration = 30.seconds,
): Lifecycle[F[Throwable, _], IRTDispatcherWs[F]] = {
connect(uri, muxer, contextProvider).map {
connect(uri, serverMuxer, wsContextExtractor, headers).map {
new WsRpcDispatcher(_, timeout, codec, dispatcherLogger(uri, logger)) {
override protected def buildRequest(rpcPacketId: RpcPacketId, method: IRTMethodId, body: Json): RpcPacket = {
tweakRequest(super.buildRequest(rpcPacketId, method, body))
Expand All @@ -63,22 +85,32 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu
}
}

def dispatcherSimple(
uri: Uri,
serverMuxer: IRTServerMultiplexor[F, Unit],
headers: Map[String, String] = Map.empty,
tweakRequest: RpcPacket => RpcPacket = identity,
timeout: FiniteDuration = 30.seconds,
): Lifecycle[F[Throwable, _], IRTDispatcherWs[F]] = {
dispatcher(uri, serverMuxer, WsContextExtractor.unit, headers, tweakRequest, timeout)
}

protected def wsHandler[ServerContext](
serverMuxer: IRTServerMultiplexor[F, ServerContext],
wsRequestState: WsRequestState[F],
wsContextExtractor: WsContextExtractor[ServerContext],
logger: LogIO2[F],
muxer: IRTServerMultiplexor[F, ServerContext],
contextProvider: WsRpcContextProvider[ServerContext],
requestState: WsRequestState[F],
): WsRpcHandler[F, ServerContext] = {
new ClientWsRpcHandler(muxer, requestState, contextProvider, logger)
new ClientWsRpcHandler(serverMuxer, wsRequestState, wsContextExtractor, logger)
}

protected def createListener[ServerContext](
muxer: IRTServerMultiplexor[F, ServerContext],
contextProvider: WsRpcContextProvider[ServerContext],
requestState: WsRequestState[F],
serverMuxer: IRTServerMultiplexor[F, ServerContext],
wsRequestState: WsRequestState[F],
wsContextExtractor: WsContextExtractor[ServerContext],
logger: LogIO2[F],
): WebSocketListener = new WebSocketListener() {
private val handler = wsHandler(logger, muxer, contextProvider, requestState)
private val handler = wsHandler(serverMuxer, wsRequestState, wsContextExtractor, logger)
private val socketRef = new AtomicReference[Option[WebSocket]](None)

override def onOpen(websocket: WebSocket): Unit = {
Expand All @@ -102,7 +134,7 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu
override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int): Unit = {
UnsafeRun2[F].unsafeRunAsync(handler.processRpcMessage(payload)) {
exit =>
val maybeResponse = exit match {
val maybeResponse: Option[RpcPacket] = exit match {
case Exit.Success(response) => response
case Exit.Error(error, _) => handleWsError(List(error), "errored")
case Exit.Termination(error, _, _) => handleWsError(List(error), "terminated")
Expand Down Expand Up @@ -133,10 +165,8 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu
Some(RpcPacket.rpcCritical(message, None))
}
}
}

object WsRpcDispatcherFactory {
def asyncHttpClient[F[+_, +_]: IO2]: Lifecycle[F[Throwable, _], DefaultAsyncHttpClient] = {
protected def createAsyncHttpClient(): Lifecycle[F[Throwable, _], DefaultAsyncHttpClient] = {
Lifecycle.fromAutoCloseable(F.syncThrowable {
new DefaultAsyncHttpClient(
new DefaultAsyncHttpClientConfig.Builder()
Expand All @@ -153,33 +183,38 @@ object WsRpcDispatcherFactory {
)
})
}
}

object WsRpcDispatcherFactory {

class ClientWsRpcHandler[F[+_, +_]: IO2, ServerCtx](
muxer: IRTServerMultiplexor[F, ServerCtx],
class ClientWsRpcHandler[F[+_, +_]: IO2, RequestCtx](
muxer: IRTServerMultiplexor[F, RequestCtx],
requestState: WsRequestState[F],
contextProvider: WsRpcContextProvider[ServerCtx],
wsContextExtractor: WsContextExtractor[RequestCtx],
logger: LogIO2[F],
) extends WsRpcHandler[F, ServerCtx](muxer, requestState, logger) {
override def handlePacket(packet: RpcPacket): F[Throwable, Unit] = {
F.unit
}
override def handleAuthRequest(packet: RpcPacket): F[Throwable, Option[RpcPacket]] = {
F.pure(None)
}
override def extractContext(packet: RpcPacket): F[Throwable, ServerCtx] = {
F.sync(contextProvider.toContext(packet))
) extends WsRpcHandler[F, RequestCtx](muxer, requestState, logger) {
private val wsSessionId: WsSessionId = WsSessionId(UUIDGen.getTimeUUID())
private val requestCtxRef: AtomicReference[RequestCtx] = new AtomicReference()
override protected def updateRequestCtx(packet: RpcPacket): F[Throwable, RequestCtx] = F.sync {
val updated = wsContextExtractor.extract(wsSessionId, packet)
requestCtxRef.updateAndGet {
case null => updated
case previous => wsContextExtractor.merge(previous, updated)
}
}
}

trait WsRpcClientConnection[F[_, _]] {
private[clients] def requestAndAwait(id: RpcPacketId, packet: RpcPacket, method: Option[IRTMethodId], timeout: FiniteDuration): F[Throwable, Option[RawResponse]]
def sessionId: Option[WsSessionId]
def authorize(headers: Map[String, String], timeout: FiniteDuration = 30.seconds): F[Throwable, Unit]
}
object WsRpcClientConnection {
class Netty[F[+_, +_]: Async2](
nettyWebSocket: NettyWebSocket,
requestState: WsRequestState[F],
printer: Printer,
val sessionId: Option[WsSessionId],
) extends WsRpcClientConnection[F] {

override def authorize(headers: Map[String, String], timeout: FiniteDuration): F[Throwable, Unit] = {
Expand All @@ -205,13 +240,6 @@ object WsRpcDispatcherFactory {
}
}

trait WsRpcContextProvider[Ctx] {
def toContext(packet: RpcPacket): Ctx
}
object WsRpcContextProvider {
def unit: WsRpcContextProvider[Unit] = _ => ()
}

private def fromNettyFuture[F[+_, +_]: Async2, A](mkNettyFuture: => io.netty.util.concurrent.Future[A]): F[Throwable, A] = {
F.syncThrowable(mkNettyFuture).flatMap {
nettyFuture =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package izumi.idealingua.runtime.rpc.http4s.context

import izumi.idealingua.runtime.rpc.http4s.IRTAuthenticator.AuthContext
import org.http4s.Request
import org.http4s.headers.`X-Forwarded-For`

trait HttpContextExtractor[RequestCtx] {
def extract[F[_, _]](request: Request[F[Throwable, _]]): RequestCtx
}

object HttpContextExtractor {
def authContext: HttpContextExtractor[AuthContext] = new HttpContextExtractor[AuthContext] {
override def extract[F[_, _]](request: Request[F[Throwable, _]]): AuthContext = {
val networkAddress = request.headers
.get[`X-Forwarded-For`]
.flatMap(_.values.head.map(_.toInetAddress))
.orElse(request.remote.map(_.host.toInetAddress))
val headers = request.headers
AuthContext(headers, networkAddress)
}
}
}
Loading