diff --git a/protocol/chainlib/consumer_websocket_manager.go b/protocol/chainlib/consumer_websocket_manager.go index 8b012b080a..dba64cb4dd 100644 --- a/protocol/chainlib/consumer_websocket_manager.go +++ b/protocol/chainlib/consumer_websocket_manager.go @@ -23,7 +23,8 @@ var ( ) const ( - WebSocketRateLimitHeader = "x-lava-rate-limit" + WebSocketRateLimitHeader = "x-lava-websocket-rate-limit" + WebSocketOpenConnectionsLimitHeader = "x-lava-websocket-open-connections-limit" ) type ConsumerWebsocketManager struct { diff --git a/protocol/chainlib/jsonRPC.go b/protocol/chainlib/jsonRPC.go index 85d4002632..0ac7662305 100644 --- a/protocol/chainlib/jsonRPC.go +++ b/protocol/chainlib/jsonRPC.go @@ -4,11 +4,8 @@ import ( "context" "errors" "fmt" - "net" "net/http" "strconv" - "strings" - "sync" "time" "github.com/goccy/go-json" @@ -307,56 +304,6 @@ func (apip *JsonRPCChainParser) ChainBlockStats() (allowedBlockLagForQosSync int return apip.spec.AllowedBlockLagForQosSync, averageBlockTime, apip.spec.BlockDistanceForFinalizedData, apip.spec.BlocksInFinalizationProof } -// Will limit a certain amount of connections per IP -type WebsocketConnectionLimiter struct { - ipToNumberOfActiveConnections map[string]int64 - lock sync.RWMutex -} - -func (wcl *WebsocketConnectionLimiter) addIpConnectionAndGetCurrentAmount(ip string) int64 { - wcl.lock.Lock() - defer wcl.lock.Unlock() - // wether it exists or not we add 1. - wcl.ipToNumberOfActiveConnections[ip] += 1 - return wcl.ipToNumberOfActiveConnections[ip] -} - -func (wcl *WebsocketConnectionLimiter) decreaseIpConnectionAndGetCurrentAmount(ip string) { - wcl.lock.Lock() - defer wcl.lock.Unlock() - // wether it exists or not we add 1. - wcl.ipToNumberOfActiveConnections[ip] -= 1 - if wcl.ipToNumberOfActiveConnections[ip] == 0 { - delete(wcl.ipToNumberOfActiveConnections, ip) - } -} - -func (wcl *WebsocketConnectionLimiter) getKey(ip string, forwardedIp string) string { - returnedKey := "" - ipOriginal := net.ParseIP(ip) - if ipOriginal != nil { - returnedKey = ipOriginal.String() - } else { - ipPart, _, err := net.SplitHostPort(ip) - if err == nil { - returnedKey = ipPart - } - } - ips := strings.Split(forwardedIp, ",") - for _, ipStr := range ips { - ipParsed := net.ParseIP(strings.TrimSpace(ipStr)) - if ipParsed != nil { - returnedKey += SEP + ipParsed.String() - } else { - ipPart, _, err := net.SplitHostPort(ipStr) - if err == nil { - returnedKey += SEP + ipPart - } - } - } - return returnedKey -} - type JsonRPCChainListener struct { endpoint *lavasession.RPCEndpoint relaySender RelaySender @@ -400,20 +347,7 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con app := createAndSetupBaseAppListener(cmdFlags, apil.endpoint.HealthCheckPath, apil.healthReporter) app.Use("/ws", func(c *fiber.Ctx) error { - forwardedFor := c.Get(common.IP_FORWARDING_HEADER_NAME) - if forwardedFor == "" { - // If not present, fallback to c.IP() which retrieves the real IP - forwardedFor = c.IP() - } - // Store the X-Forwarded-For or real IP in the context - c.Locals(common.IP_FORWARDING_HEADER_NAME, forwardedFor) - - rateLimitString := c.Get(WebSocketRateLimitHeader) - rateLimit, err := strconv.ParseInt(rateLimitString, 10, 64) - if err != nil { - rateLimit = 0 - } - c.Locals(WebSocketRateLimitHeader, rateLimit) + apil.websocketConnectionLimiter.handleFiberRateLimitFlags(c) // IsWebSocketUpgrade returns true if the client // requested upgrade to the WebSocket protocol. @@ -428,20 +362,8 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con apiInterface := apil.endpoint.ApiInterface webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) { - if MaximumNumberOfParallelWebsocketConnectionsPerIp > 0 { // 0 is disabled. - ipForwardedInterface := websocketConn.Locals(common.IP_FORWARDING_HEADER_NAME) - ipForwarded, assertionSuccessful := ipForwardedInterface.(string) - if !assertionSuccessful { - ipForwarded = "" - } - ip := websocketConn.RemoteAddr().String() - key := apil.websocketConnectionLimiter.getKey(ip, ipForwarded) - numberOfActiveConnections := apil.websocketConnectionLimiter.addIpConnectionAndGetCurrentAmount(key) - defer apil.websocketConnectionLimiter.decreaseIpConnectionAndGetCurrentAmount(key) - if numberOfActiveConnections > MaximumNumberOfParallelWebsocketConnectionsPerIp { - websocketConn.WriteMessage(1, []byte(fmt.Sprintf("Too Many Open Connections, limited to %d", MaximumNumberOfParallelWebsocketConnectionsPerIp))) - return - } + if !apil.websocketConnectionLimiter.canOpenConnection(websocketConn) { + return } rateLimitInf := websocketConn.Locals(WebSocketRateLimitHeader) rateLimit, assertionSuccessful := rateLimitInf.(int64) diff --git a/protocol/chainlib/tendermintRPC.go b/protocol/chainlib/tendermintRPC.go index 971c0ad751..68fac97cbd 100644 --- a/protocol/chainlib/tendermintRPC.go +++ b/protocol/chainlib/tendermintRPC.go @@ -371,21 +371,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm apiInterface := apil.endpoint.ApiInterface app.Use("/ws", func(c *fiber.Ctx) error { - forwardedFor := c.Get(common.IP_FORWARDING_HEADER_NAME) - if forwardedFor == "" { - // If not present, fallback to c.IP() which retrieves the real IP - forwardedFor = c.IP() - } - // Store the X-Forwarded-For or real IP in the context - c.Locals(common.IP_FORWARDING_HEADER_NAME, forwardedFor) - - rateLimitString := c.Get(WebSocketRateLimitHeader) - rateLimit, err := strconv.ParseInt(rateLimitString, 10, 64) - if err != nil { - rateLimit = 0 - } - c.Locals(WebSocketRateLimitHeader, rateLimit) - + apil.websocketConnectionLimiter.handleFiberRateLimitFlags(c) // IsWebSocketUpgrade returns true if the client // requested upgrade to the WebSocket protocol. if websocket.IsWebSocketUpgrade(c) { @@ -395,20 +381,8 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm return fiber.ErrUpgradeRequired }) webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) { - if MaximumNumberOfParallelWebsocketConnectionsPerIp > 0 { // 0 is disabled. - ipForwardedInterface := websocketConn.Locals(common.IP_FORWARDING_HEADER_NAME) - ipForwarded, found := ipForwardedInterface.(string) - if !found { - ipForwarded = "" - } - ip := websocketConn.RemoteAddr().String() - key := apil.websocketConnectionLimiter.getKey(ip, ipForwarded) - numberOfActiveConnections := apil.websocketConnectionLimiter.addIpConnectionAndGetCurrentAmount(key) - defer apil.websocketConnectionLimiter.decreaseIpConnectionAndGetCurrentAmount(key) - if numberOfActiveConnections > MaximumNumberOfParallelWebsocketConnectionsPerIp { - websocketConn.WriteMessage(1, []byte(fmt.Sprintf("Too Many Open Connections, limited to %d", MaximumNumberOfParallelWebsocketConnectionsPerIp))) - return - } + if !apil.websocketConnectionLimiter.canOpenConnection(websocketConn) { + return } rateLimitInf := websocketConn.Locals(WebSocketRateLimitHeader) diff --git a/protocol/chainlib/websocket_connection_limiter.go b/protocol/chainlib/websocket_connection_limiter.go new file mode 100644 index 0000000000..d3d1a7cc22 --- /dev/null +++ b/protocol/chainlib/websocket_connection_limiter.go @@ -0,0 +1,122 @@ +package chainlib + +import ( + "fmt" + "net" + "strconv" + "strings" + "sync" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/websocket/v2" + "github.com/lavanet/lava/v3/protocol/common" + "github.com/lavanet/lava/v3/utils" +) + +// Will limit a certain amount of connections per IP +type WebsocketConnectionLimiter struct { + ipToNumberOfActiveConnections map[string]int64 + lock sync.RWMutex +} + +func (wcl *WebsocketConnectionLimiter) handleFiberRateLimitFlags(c *fiber.Ctx) { + forwardedFor := c.Get(common.IP_FORWARDING_HEADER_NAME) + if forwardedFor == "" { + // If not present, fallback to c.IP() which retrieves the real IP + forwardedFor = c.IP() + } + // Store the X-Forwarded-For or real IP in the context + c.Locals(common.IP_FORWARDING_HEADER_NAME, forwardedFor) + + rateLimitString := c.Get(WebSocketRateLimitHeader) + rateLimit, err := strconv.ParseInt(rateLimitString, 10, 64) + if err != nil { + rateLimit = 0 + } + c.Locals(WebSocketRateLimitHeader, rateLimit) + + connectionLimitString := c.Get(WebSocketOpenConnectionsLimitHeader) + connectionLimit, err := strconv.ParseInt(connectionLimitString, 10, 64) + if err != nil { + connectionLimit = 0 + } + c.Locals(WebSocketOpenConnectionsLimitHeader, connectionLimit) +} + +func (wcl *WebsocketConnectionLimiter) getConnectionLimit(websocketConn *websocket.Conn) int64 { + connectionLimitHeaderValue, ok := websocketConn.Locals(WebSocketOpenConnectionsLimitHeader).(int64) + if !ok || connectionLimitHeaderValue < 0 { + connectionLimitHeaderValue = 0 + } + // Do not allow header to overwrite flag value if its set. + if MaximumNumberOfParallelWebsocketConnectionsPerIp > 0 && connectionLimitHeaderValue > MaximumNumberOfParallelWebsocketConnectionsPerIp { + return MaximumNumberOfParallelWebsocketConnectionsPerIp + } + // Return the larger of the global limit (if set) or the header value + return utils.Max(MaximumNumberOfParallelWebsocketConnectionsPerIp, connectionLimitHeaderValue) +} + +func (wcl *WebsocketConnectionLimiter) canOpenConnection(websocketConn *websocket.Conn) bool { + // Check which connection limit is higher and use that. + connectionLimit := wcl.getConnectionLimit(websocketConn) + if connectionLimit > 0 { // 0 is disabled. + ipForwardedInterface := websocketConn.Locals(common.IP_FORWARDING_HEADER_NAME) + ipForwarded, assertionSuccessful := ipForwardedInterface.(string) + if !assertionSuccessful { + ipForwarded = "" + } + ip := websocketConn.RemoteAddr().String() + key := wcl.getKey(ip, ipForwarded) + numberOfActiveConnections := wcl.addIpConnectionAndGetCurrentAmount(key) + defer wcl.decreaseIpConnectionAndGetCurrentAmount(key) + if numberOfActiveConnections > connectionLimit { + websocketConn.WriteMessage(1, []byte(fmt.Sprintf("Too Many Open Connections, limited to %d", connectionLimit))) + return false + } + } + return true +} + +func (wcl *WebsocketConnectionLimiter) addIpConnectionAndGetCurrentAmount(ip string) int64 { + wcl.lock.Lock() + defer wcl.lock.Unlock() + // wether it exists or not we add 1. + wcl.ipToNumberOfActiveConnections[ip] += 1 + return wcl.ipToNumberOfActiveConnections[ip] +} + +func (wcl *WebsocketConnectionLimiter) decreaseIpConnectionAndGetCurrentAmount(ip string) { + wcl.lock.Lock() + defer wcl.lock.Unlock() + // wether it exists or not we add 1. + wcl.ipToNumberOfActiveConnections[ip] -= 1 + if wcl.ipToNumberOfActiveConnections[ip] == 0 { + delete(wcl.ipToNumberOfActiveConnections, ip) + } +} + +func (wcl *WebsocketConnectionLimiter) getKey(ip string, forwardedIp string) string { + returnedKey := "" + ipOriginal := net.ParseIP(ip) + if ipOriginal != nil { + returnedKey = ipOriginal.String() + } else { + ipPart, _, err := net.SplitHostPort(ip) + if err == nil { + returnedKey = ipPart + } + } + ips := strings.Split(forwardedIp, ",") + for _, ipStr := range ips { + ipParsed := net.ParseIP(strings.TrimSpace(ipStr)) + if ipParsed != nil { + returnedKey += SEP + ipParsed.String() + } else { + ipPart, _, err := net.SplitHostPort(ipStr) + if err == nil { + returnedKey += SEP + ipPart + } + } + } + return returnedKey +}