Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: PRT - Websocket limited per ip #1738

Merged
merged 22 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions protocol/chainlib/consumer_websocket_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ var (
WebSocketBanDuration = time.Duration(0) // once rate limit is reached, will not allow new incoming message for a duration
)

const (
WebSocketRateLimitHeader = "x-lava-websocket-rate-limit"
WebSocketOpenConnectionsLimitHeader = "x-lava-websocket-open-connections-limit"
)

type ConsumerWebsocketManager struct {
websocketConn *websocket.Conn
rpcConsumerLogs *metrics.RPCConsumerLogs
Expand All @@ -35,6 +40,7 @@ type ConsumerWebsocketManager struct {
relaySender RelaySender
consumerWsSubscriptionManager *ConsumerWSSubscriptionManager
WebsocketConnectionUID string
headerRateLimit uint64
}

type ConsumerWebsocketManagerOptions struct {
Expand All @@ -50,6 +56,7 @@ type ConsumerWebsocketManagerOptions struct {
RelaySender RelaySender
ConsumerWsSubscriptionManager *ConsumerWSSubscriptionManager
WebsocketConnectionUID string
headerRateLimit uint64
}

func NewConsumerWebsocketManager(options ConsumerWebsocketManagerOptions) *ConsumerWebsocketManager {
Expand All @@ -66,6 +73,7 @@ func NewConsumerWebsocketManager(options ConsumerWebsocketManagerOptions) *Consu
refererData: options.RefererData,
consumerWsSubscriptionManager: options.ConsumerWsSubscriptionManager,
WebsocketConnectionUID: options.WebsocketConnectionUID,
headerRateLimit: options.headerRateLimit,
}
return cwm
}
Expand Down Expand Up @@ -145,7 +153,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
// rate limit routine
requestsPerSecond := &atomic.Uint64{}
go func() {
if WebSocketRateLimit <= 0 {
if WebSocketRateLimit <= 0 && cwm.headerRateLimit <= 0 {
return
}
ticker := time.NewTicker(time.Second) // rate limit per second.
Expand All @@ -156,7 +164,8 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
return
case <-ticker.C:
// check if rate limit reached, and ban is required
if WebSocketBanDuration > 0 && requestsPerSecond.Load() > uint64(WebSocketRateLimit) {
currentRequestsPerSecondLoad := requestsPerSecond.Load()
if WebSocketBanDuration > 0 && (cwm.headerRateLimit > currentRequestsPerSecondLoad || currentRequestsPerSecondLoad > uint64(WebSocketRateLimit)) {
ranlavanet marked this conversation as resolved.
Show resolved Hide resolved
// wait the ban duration before resetting the store.
select {
case <-webSocketCtx.Done():
Expand Down Expand Up @@ -185,7 +194,9 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
}

// Check rate limit is met
if WebSocketRateLimit > 0 && requestsPerSecond.Add(1) > uint64(WebSocketRateLimit) {
currentRequestsPerSecond := requestsPerSecond.Add(1)
if (cwm.headerRateLimit > 0 && currentRequestsPerSecond > cwm.headerRateLimit) ||
(WebSocketRateLimit > 0 && currentRequestsPerSecond > uint64(WebSocketRateLimit)) {
rateLimitResponse, err := cwm.handleRateLimitReached(msg)
if err == nil {
websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: rateLimitResponse}
Expand Down
20 changes: 19 additions & 1 deletion protocol/chainlib/jsonRPC.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ import (
spectypes "github.com/lavanet/lava/v4/x/spec/types"
)

const SEP = "&"
const (
SEP = "&"
)

var MaximumNumberOfParallelWebsocketConnectionsPerIp int64 = 0

type JsonRPCChainParser struct {
BaseChainParser
Expand Down Expand Up @@ -308,6 +312,7 @@ type JsonRPCChainListener struct {
refererData *RefererData
consumerWsSubscriptionManager *ConsumerWSSubscriptionManager
listeningAddress string
websocketConnectionLimiter *WebsocketConnectionLimiter
}

// NewJrpcChainListener creates a new instance of JsonRPCChainListener
Expand All @@ -325,6 +330,7 @@ func NewJrpcChainListener(ctx context.Context, listenEndpoint *lavasession.RPCEn
logger: rpcConsumerLogs,
refererData: refererData,
consumerWsSubscriptionManager: consumerWsSubscriptionManager,
websocketConnectionLimiter: &WebsocketConnectionLimiter{ipToNumberOfActiveConnections: make(map[string]int64)},
}

return chainListener
Expand All @@ -341,6 +347,8 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con
app := createAndSetupBaseAppListener(cmdFlags, apil.endpoint.HealthCheckPath, apil.healthReporter)

app.Use("/ws", func(c *fiber.Ctx) error {
apil.websocketConnectionLimiter.handleFiberRateLimitFlags(c)

// IsWebSocketUpgrade returns true if the client
// requested upgrade to the WebSocket protocol.
if websocket.IsWebSocketUpgrade(c) {
Expand All @@ -354,6 +362,15 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con
apiInterface := apil.endpoint.ApiInterface

webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) {
if !apil.websocketConnectionLimiter.canOpenConnection(websocketConn) {
return
}
rateLimitInf := websocketConn.Locals(WebSocketRateLimitHeader)
rateLimit, assertionSuccessful := rateLimitInf.(int64)
if !assertionSuccessful || rateLimit < 0 {
rateLimit = 0
}

utils.LavaFormatDebug("jsonrpc websocket opened", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String()))
defer utils.LavaFormatDebug("jsonrpc websocket closed", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String()))

Expand All @@ -370,6 +387,7 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con
RelaySender: apil.relaySender,
ConsumerWsSubscriptionManager: apil.consumerWsSubscriptionManager,
WebsocketConnectionUID: strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10),
headerRateLimit: uint64(rateLimit),
})

consumerWebsocketManager.ListenToMessages()
Expand Down
14 changes: 14 additions & 0 deletions protocol/chainlib/tendermintRPC.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ type TendermintRpcChainListener struct {
refererData *RefererData
consumerWsSubscriptionManager *ConsumerWSSubscriptionManager
listeningAddress string
websocketConnectionLimiter *WebsocketConnectionLimiter
}

// NewTendermintRpcChainListener creates a new instance of TendermintRpcChainListener
Expand All @@ -351,6 +352,7 @@ func NewTendermintRpcChainListener(ctx context.Context, listenEndpoint *lavasess
logger: rpcConsumerLogs,
refererData: refererData,
consumerWsSubscriptionManager: consumerWsSubscriptionManager,
websocketConnectionLimiter: &WebsocketConnectionLimiter{ipToNumberOfActiveConnections: make(map[string]int64)},
}

return chainListener
Expand All @@ -369,6 +371,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm
apiInterface := apil.endpoint.ApiInterface

app.Use("/ws", func(c *fiber.Ctx) error {
apil.websocketConnectionLimiter.handleFiberRateLimitFlags(c)
// IsWebSocketUpgrade returns true if the client
// requested upgrade to the WebSocket protocol.
if websocket.IsWebSocketUpgrade(c) {
Expand All @@ -378,6 +381,16 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm
return fiber.ErrUpgradeRequired
})
webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) {
if !apil.websocketConnectionLimiter.canOpenConnection(websocketConn) {
return
}

rateLimitInf := websocketConn.Locals(WebSocketRateLimitHeader)
rateLimit, assertionSuccessful := rateLimitInf.(int64)
if !assertionSuccessful || rateLimit < 0 {
rateLimit = 0
}

utils.LavaFormatDebug("tendermintrpc websocket opened", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String()))
defer utils.LavaFormatDebug("tendermintrpc websocket closed", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String()))

Expand All @@ -394,6 +407,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm
RelaySender: apil.relaySender,
ConsumerWsSubscriptionManager: apil.consumerWsSubscriptionManager,
WebsocketConnectionUID: strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10),
headerRateLimit: uint64(rateLimit),
})

consumerWebsocketManager.ListenToMessages()
Expand Down
122 changes: 122 additions & 0 deletions protocol/chainlib/websocket_connection_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package chainlib

import (
"fmt"
"net"
"strconv"
"strings"
"sync"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/lavanet/lava/v4/protocol/common"
"github.com/lavanet/lava/v4/utils"
)

// Will limit a certain amount of connections per IP
type WebsocketConnectionLimiter struct {
ipToNumberOfActiveConnections map[string]int64
lock sync.RWMutex
}

func (wcl *WebsocketConnectionLimiter) handleFiberRateLimitFlags(c *fiber.Ctx) {
omerlavanet marked this conversation as resolved.
Show resolved Hide resolved
forwardedFor := c.Get(common.IP_FORWARDING_HEADER_NAME)
if forwardedFor == "" {
// If not present, fallback to c.IP() which retrieves the real IP
forwardedFor = c.IP()
}
// Store the X-Forwarded-For or real IP in the context
c.Locals(common.IP_FORWARDING_HEADER_NAME, forwardedFor)

rateLimitString := c.Get(WebSocketRateLimitHeader)
rateLimit, err := strconv.ParseInt(rateLimitString, 10, 64)
if err != nil {
rateLimit = 0
}
c.Locals(WebSocketRateLimitHeader, rateLimit)

connectionLimitString := c.Get(WebSocketOpenConnectionsLimitHeader)
connectionLimit, err := strconv.ParseInt(connectionLimitString, 10, 64)
if err != nil {
connectionLimit = 0
}
c.Locals(WebSocketOpenConnectionsLimitHeader, connectionLimit)
}

func (wcl *WebsocketConnectionLimiter) getConnectionLimit(websocketConn *websocket.Conn) int64 {
connectionLimitHeaderValue, ok := websocketConn.Locals(WebSocketOpenConnectionsLimitHeader).(int64)
if !ok || connectionLimitHeaderValue < 0 {
connectionLimitHeaderValue = 0
}
// Do not allow header to overwrite flag value if its set.
if MaximumNumberOfParallelWebsocketConnectionsPerIp > 0 && connectionLimitHeaderValue > MaximumNumberOfParallelWebsocketConnectionsPerIp {
return MaximumNumberOfParallelWebsocketConnectionsPerIp
}
// Return the larger of the global limit (if set) or the header value
return utils.Max(MaximumNumberOfParallelWebsocketConnectionsPerIp, connectionLimitHeaderValue)
}

func (wcl *WebsocketConnectionLimiter) canOpenConnection(websocketConn *websocket.Conn) bool {
// Check which connection limit is higher and use that.
connectionLimit := wcl.getConnectionLimit(websocketConn)
if connectionLimit > 0 { // 0 is disabled.
ipForwardedInterface := websocketConn.Locals(common.IP_FORWARDING_HEADER_NAME)
ipForwarded, assertionSuccessful := ipForwardedInterface.(string)
if !assertionSuccessful {
ipForwarded = ""
}
ip := websocketConn.RemoteAddr().String()
key := wcl.getKey(ip, ipForwarded)
ranlavanet marked this conversation as resolved.
Show resolved Hide resolved
numberOfActiveConnections := wcl.addIpConnectionAndGetCurrentAmount(key)
ranlavanet marked this conversation as resolved.
Show resolved Hide resolved
defer wcl.decreaseIpConnectionAndGetCurrentAmount(key)
if numberOfActiveConnections > connectionLimit {
websocketConn.WriteMessage(1, []byte(fmt.Sprintf("Too Many Open Connections, limited to %d", connectionLimit)))
return false
}
}
return true
}

func (wcl *WebsocketConnectionLimiter) addIpConnectionAndGetCurrentAmount(ip string) int64 {
wcl.lock.Lock()
defer wcl.lock.Unlock()
// wether it exists or not we add 1.
wcl.ipToNumberOfActiveConnections[ip] += 1
return wcl.ipToNumberOfActiveConnections[ip]
}

func (wcl *WebsocketConnectionLimiter) decreaseIpConnectionAndGetCurrentAmount(ip string) {
wcl.lock.Lock()
defer wcl.lock.Unlock()
// wether it exists or not we add 1.
wcl.ipToNumberOfActiveConnections[ip] -= 1
if wcl.ipToNumberOfActiveConnections[ip] == 0 {
delete(wcl.ipToNumberOfActiveConnections, ip)
}
}

func (wcl *WebsocketConnectionLimiter) getKey(ip string, forwardedIp string) string {
returnedKey := ""
ipOriginal := net.ParseIP(ip)
if ipOriginal != nil {
returnedKey = ipOriginal.String()
} else {
ipPart, _, err := net.SplitHostPort(ip)
if err == nil {
returnedKey = ipPart
}
}
ips := strings.Split(forwardedIp, ",")
for _, ipStr := range ips {
ipParsed := net.ParseIP(strings.TrimSpace(ipStr))
if ipParsed != nil {
returnedKey += SEP + ipParsed.String()
} else {
ipPart, _, err := net.SplitHostPort(ipStr)
if err == nil {
returnedKey += SEP + ipPart
}
}
}
return returnedKey
}
1 change: 1 addition & 0 deletions protocol/common/cobra_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ const (
// websocket flags
RateLimitWebSocketFlag = "rate-limit-websocket-requests-per-connection"
BanDurationForWebsocketRateLimitExceededFlag = "ban-duration-for-websocket-rate-limit-exceeded"
LimitParallelWebsocketConnectionsPerIpFlag = "limit-parallel-websocket-connections-per-ip"
)

const (
Expand Down
1 change: 1 addition & 0 deletions protocol/rpcconsumer/rpcconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
cmdRPCConsumer.Flags().DurationVar(&metrics.OptimizerQosServerPushInterval, common.OptimizerQosServerPushIntervalFlag, time.Minute*5, "interval to push optimizer qos reports")
cmdRPCConsumer.Flags().DurationVar(&metrics.OptimizerQosServerSamplingInterval, common.OptimizerQosServerSamplingIntervalFlag, time.Second*1, "interval to sample optimizer qos reports")
cmdRPCConsumer.Flags().IntVar(&chainlib.WebSocketRateLimit, common.RateLimitWebSocketFlag, chainlib.WebSocketRateLimit, "rate limit (per second) websocket requests per user connection, default is unlimited")
cmdRPCConsumer.Flags().Int64Var(&chainlib.MaximumNumberOfParallelWebsocketConnectionsPerIp, common.LimitParallelWebsocketConnectionsPerIpFlag, chainlib.MaximumNumberOfParallelWebsocketConnectionsPerIp, "limit number of parallel connections to websocket, per ip, default is unlimited (0)")
cmdRPCConsumer.Flags().DurationVar(&chainlib.WebSocketBanDuration, common.BanDurationForWebsocketRateLimitExceededFlag, chainlib.WebSocketBanDuration, "once websocket rate limit is reached, user will be banned Xfor a duration, default no ban")
common.AddRollingLogConfig(cmdRPCConsumer)
return cmdRPCConsumer
Expand Down
2 changes: 1 addition & 1 deletion scripts/pre_setups/init_lava_only_with_node.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ wait_next_block

screen -d -m -S consumers bash -c "source ~/.bashrc; lavap rpcconsumer \
127.0.0.1:3360 LAV1 rest 127.0.0.1:3361 LAV1 tendermintrpc 127.0.0.1:3362 LAV1 grpc \
$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level trace --from user1 --chain-id lava --add-api-method-metrics --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25
$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level trace --from user1 --chain-id lava --limit-parallel-websocket-connections-per-ip 1 --add-api-method-metrics --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25

echo "--- setting up screens done ---"
screen -ls
Loading