Skip to content

Commit

Permalink
feat: added new websocket connection limit from ngnix headers
Browse files Browse the repository at this point in the history
  • Loading branch information
ranlavanet committed Oct 23, 2024
1 parent 58998e9 commit 89dcd69
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 111 deletions.
3 changes: 2 additions & 1 deletion protocol/chainlib/consumer_websocket_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ var (
)

const (
WebSocketRateLimitHeader = "x-lava-rate-limit"
WebSocketRateLimitHeader = "x-lava-websocket-rate-limit"
WebSocketOpenConnectionsLimitHeader = "x-lava-websocket-open-connections-limit"
)

type ConsumerWebsocketManager struct {
Expand Down
84 changes: 3 additions & 81 deletions protocol/chainlib/jsonRPC.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@ import (
"context"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"

"github.com/goccy/go-json"
Expand Down Expand Up @@ -307,56 +304,6 @@ func (apip *JsonRPCChainParser) ChainBlockStats() (allowedBlockLagForQosSync int
return apip.spec.AllowedBlockLagForQosSync, averageBlockTime, apip.spec.BlockDistanceForFinalizedData, apip.spec.BlocksInFinalizationProof
}

// Will limit a certain amount of connections per IP
type WebsocketConnectionLimiter struct {
ipToNumberOfActiveConnections map[string]int64
lock sync.RWMutex
}

func (wcl *WebsocketConnectionLimiter) addIpConnectionAndGetCurrentAmount(ip string) int64 {
wcl.lock.Lock()
defer wcl.lock.Unlock()
// wether it exists or not we add 1.
wcl.ipToNumberOfActiveConnections[ip] += 1
return wcl.ipToNumberOfActiveConnections[ip]
}

func (wcl *WebsocketConnectionLimiter) decreaseIpConnectionAndGetCurrentAmount(ip string) {
wcl.lock.Lock()
defer wcl.lock.Unlock()
// wether it exists or not we add 1.
wcl.ipToNumberOfActiveConnections[ip] -= 1
if wcl.ipToNumberOfActiveConnections[ip] == 0 {
delete(wcl.ipToNumberOfActiveConnections, ip)
}
}

func (wcl *WebsocketConnectionLimiter) getKey(ip string, forwardedIp string) string {
returnedKey := ""
ipOriginal := net.ParseIP(ip)
if ipOriginal != nil {
returnedKey = ipOriginal.String()
} else {
ipPart, _, err := net.SplitHostPort(ip)
if err == nil {
returnedKey = ipPart
}
}
ips := strings.Split(forwardedIp, ",")
for _, ipStr := range ips {
ipParsed := net.ParseIP(strings.TrimSpace(ipStr))
if ipParsed != nil {
returnedKey += SEP + ipParsed.String()
} else {
ipPart, _, err := net.SplitHostPort(ipStr)
if err == nil {
returnedKey += SEP + ipPart
}
}
}
return returnedKey
}

type JsonRPCChainListener struct {
endpoint *lavasession.RPCEndpoint
relaySender RelaySender
Expand Down Expand Up @@ -400,20 +347,7 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con
app := createAndSetupBaseAppListener(cmdFlags, apil.endpoint.HealthCheckPath, apil.healthReporter)

app.Use("/ws", func(c *fiber.Ctx) error {
forwardedFor := c.Get(common.IP_FORWARDING_HEADER_NAME)
if forwardedFor == "" {
// If not present, fallback to c.IP() which retrieves the real IP
forwardedFor = c.IP()
}
// Store the X-Forwarded-For or real IP in the context
c.Locals(common.IP_FORWARDING_HEADER_NAME, forwardedFor)

rateLimitString := c.Get(WebSocketRateLimitHeader)
rateLimit, err := strconv.ParseInt(rateLimitString, 10, 64)
if err != nil {
rateLimit = 0
}
c.Locals(WebSocketRateLimitHeader, rateLimit)
apil.websocketConnectionLimiter.handleFiberRateLimitFlags(c)

// IsWebSocketUpgrade returns true if the client
// requested upgrade to the WebSocket protocol.
Expand All @@ -428,20 +362,8 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con
apiInterface := apil.endpoint.ApiInterface

webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) {
if MaximumNumberOfParallelWebsocketConnectionsPerIp > 0 { // 0 is disabled.
ipForwardedInterface := websocketConn.Locals(common.IP_FORWARDING_HEADER_NAME)
ipForwarded, assertionSuccessful := ipForwardedInterface.(string)
if !assertionSuccessful {
ipForwarded = ""
}
ip := websocketConn.RemoteAddr().String()
key := apil.websocketConnectionLimiter.getKey(ip, ipForwarded)
numberOfActiveConnections := apil.websocketConnectionLimiter.addIpConnectionAndGetCurrentAmount(key)
defer apil.websocketConnectionLimiter.decreaseIpConnectionAndGetCurrentAmount(key)
if numberOfActiveConnections > MaximumNumberOfParallelWebsocketConnectionsPerIp {
websocketConn.WriteMessage(1, []byte(fmt.Sprintf("Too Many Open Connections, limited to %d", MaximumNumberOfParallelWebsocketConnectionsPerIp)))
return
}
if !apil.websocketConnectionLimiter.canOpenConnection(websocketConn) {
return
}
rateLimitInf := websocketConn.Locals(WebSocketRateLimitHeader)
rateLimit, assertionSuccessful := rateLimitInf.(int64)
Expand Down
32 changes: 3 additions & 29 deletions protocol/chainlib/tendermintRPC.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,21 +371,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm
apiInterface := apil.endpoint.ApiInterface

app.Use("/ws", func(c *fiber.Ctx) error {
forwardedFor := c.Get(common.IP_FORWARDING_HEADER_NAME)
if forwardedFor == "" {
// If not present, fallback to c.IP() which retrieves the real IP
forwardedFor = c.IP()
}
// Store the X-Forwarded-For or real IP in the context
c.Locals(common.IP_FORWARDING_HEADER_NAME, forwardedFor)

rateLimitString := c.Get(WebSocketRateLimitHeader)
rateLimit, err := strconv.ParseInt(rateLimitString, 10, 64)
if err != nil {
rateLimit = 0
}
c.Locals(WebSocketRateLimitHeader, rateLimit)

apil.websocketConnectionLimiter.handleFiberRateLimitFlags(c)
// IsWebSocketUpgrade returns true if the client
// requested upgrade to the WebSocket protocol.
if websocket.IsWebSocketUpgrade(c) {
Expand All @@ -395,20 +381,8 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm
return fiber.ErrUpgradeRequired
})
webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) {
if MaximumNumberOfParallelWebsocketConnectionsPerIp > 0 { // 0 is disabled.
ipForwardedInterface := websocketConn.Locals(common.IP_FORWARDING_HEADER_NAME)
ipForwarded, found := ipForwardedInterface.(string)
if !found {
ipForwarded = ""
}
ip := websocketConn.RemoteAddr().String()
key := apil.websocketConnectionLimiter.getKey(ip, ipForwarded)
numberOfActiveConnections := apil.websocketConnectionLimiter.addIpConnectionAndGetCurrentAmount(key)
defer apil.websocketConnectionLimiter.decreaseIpConnectionAndGetCurrentAmount(key)
if numberOfActiveConnections > MaximumNumberOfParallelWebsocketConnectionsPerIp {
websocketConn.WriteMessage(1, []byte(fmt.Sprintf("Too Many Open Connections, limited to %d", MaximumNumberOfParallelWebsocketConnectionsPerIp)))
return
}
if !apil.websocketConnectionLimiter.canOpenConnection(websocketConn) {
return
}

rateLimitInf := websocketConn.Locals(WebSocketRateLimitHeader)
Expand Down
122 changes: 122 additions & 0 deletions protocol/chainlib/websocket_connection_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package chainlib

import (
"fmt"
"net"
"strconv"
"strings"
"sync"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/lavanet/lava/v3/protocol/common"

Check failure on line 12 in protocol/chainlib/websocket_connection_limiter.go

View workflow job for this annotation

GitHub Actions / test-protocol

no required module provides package github.com/lavanet/lava/v3/protocol/common; to add it:

Check failure on line 12 in protocol/chainlib/websocket_connection_limiter.go

View workflow job for this annotation

GitHub Actions / test-protocol

no required module provides package github.com/lavanet/lava/v3/protocol/common; to add it:

Check failure on line 12 in protocol/chainlib/websocket_connection_limiter.go

View workflow job for this annotation

GitHub Actions / test-protocol-e2e

cannot find module providing package github.com/lavanet/lava/v3/protocol/common: import lookup disabled by -mod=readonly
"github.com/lavanet/lava/v3/utils"

Check failure on line 13 in protocol/chainlib/websocket_connection_limiter.go

View workflow job for this annotation

GitHub Actions / test-protocol

no required module provides package github.com/lavanet/lava/v3/utils; to add it:

Check failure on line 13 in protocol/chainlib/websocket_connection_limiter.go

View workflow job for this annotation

GitHub Actions / test-protocol-e2e

cannot find module providing package github.com/lavanet/lava/v3/utils: import lookup disabled by -mod=readonly
)

// Will limit a certain amount of connections per IP
type WebsocketConnectionLimiter struct {
ipToNumberOfActiveConnections map[string]int64
lock sync.RWMutex
}

func (wcl *WebsocketConnectionLimiter) handleFiberRateLimitFlags(c *fiber.Ctx) {
forwardedFor := c.Get(common.IP_FORWARDING_HEADER_NAME)
if forwardedFor == "" {
// If not present, fallback to c.IP() which retrieves the real IP
forwardedFor = c.IP()
}
// Store the X-Forwarded-For or real IP in the context
c.Locals(common.IP_FORWARDING_HEADER_NAME, forwardedFor)

rateLimitString := c.Get(WebSocketRateLimitHeader)
rateLimit, err := strconv.ParseInt(rateLimitString, 10, 64)
if err != nil {
rateLimit = 0
}
c.Locals(WebSocketRateLimitHeader, rateLimit)

connectionLimitString := c.Get(WebSocketOpenConnectionsLimitHeader)
connectionLimit, err := strconv.ParseInt(connectionLimitString, 10, 64)
if err != nil {
connectionLimit = 0
}
c.Locals(WebSocketOpenConnectionsLimitHeader, connectionLimit)
}

func (wcl *WebsocketConnectionLimiter) getConnectionLimit(websocketConn *websocket.Conn) int64 {
connectionLimitHeaderValue, ok := websocketConn.Locals(WebSocketOpenConnectionsLimitHeader).(int64)
if !ok || connectionLimitHeaderValue < 0 {
connectionLimitHeaderValue = 0
}
// Do not allow header to overwrite flag value if its set.
if MaximumNumberOfParallelWebsocketConnectionsPerIp > 0 && connectionLimitHeaderValue > MaximumNumberOfParallelWebsocketConnectionsPerIp {
return MaximumNumberOfParallelWebsocketConnectionsPerIp
}
// Return the larger of the global limit (if set) or the header value
return utils.Max(MaximumNumberOfParallelWebsocketConnectionsPerIp, connectionLimitHeaderValue)
}

func (wcl *WebsocketConnectionLimiter) canOpenConnection(websocketConn *websocket.Conn) bool {
// Check which connection limit is higher and use that.
connectionLimit := wcl.getConnectionLimit(websocketConn)
if connectionLimit > 0 { // 0 is disabled.
ipForwardedInterface := websocketConn.Locals(common.IP_FORWARDING_HEADER_NAME)
ipForwarded, assertionSuccessful := ipForwardedInterface.(string)
if !assertionSuccessful {
ipForwarded = ""
}
ip := websocketConn.RemoteAddr().String()
key := wcl.getKey(ip, ipForwarded)
numberOfActiveConnections := wcl.addIpConnectionAndGetCurrentAmount(key)
defer wcl.decreaseIpConnectionAndGetCurrentAmount(key)
if numberOfActiveConnections > connectionLimit {
websocketConn.WriteMessage(1, []byte(fmt.Sprintf("Too Many Open Connections, limited to %d", connectionLimit)))
return false
}
}
return true
}

func (wcl *WebsocketConnectionLimiter) addIpConnectionAndGetCurrentAmount(ip string) int64 {
wcl.lock.Lock()
defer wcl.lock.Unlock()
// wether it exists or not we add 1.
wcl.ipToNumberOfActiveConnections[ip] += 1
return wcl.ipToNumberOfActiveConnections[ip]
}

func (wcl *WebsocketConnectionLimiter) decreaseIpConnectionAndGetCurrentAmount(ip string) {
wcl.lock.Lock()
defer wcl.lock.Unlock()
// wether it exists or not we add 1.
wcl.ipToNumberOfActiveConnections[ip] -= 1
if wcl.ipToNumberOfActiveConnections[ip] == 0 {
delete(wcl.ipToNumberOfActiveConnections, ip)
}
}

func (wcl *WebsocketConnectionLimiter) getKey(ip string, forwardedIp string) string {
returnedKey := ""
ipOriginal := net.ParseIP(ip)
if ipOriginal != nil {
returnedKey = ipOriginal.String()
} else {
ipPart, _, err := net.SplitHostPort(ip)
if err == nil {
returnedKey = ipPart
}
}
ips := strings.Split(forwardedIp, ",")
for _, ipStr := range ips {
ipParsed := net.ParseIP(strings.TrimSpace(ipStr))
if ipParsed != nil {
returnedKey += SEP + ipParsed.String()
} else {
ipPart, _, err := net.SplitHostPort(ipStr)
if err == nil {
returnedKey += SEP + ipPart
}
}
}
return returnedKey
}

0 comments on commit 89dcd69

Please sign in to comment.