From a7249de744c7d70f4c05a02fdfa629b3735401bb Mon Sep 17 00:00:00 2001 From: Ran Mishael Date: Wed, 4 Dec 2024 19:10:27 +0100 Subject: [PATCH] max idle duration for ws connections --- .../chainlib/consumer_websocket_manager.go | 26 +++++++++++++++++++ protocol/common/cobra_common.go | 1 + protocol/rpcconsumer/rpcconsumer.go | 1 + 3 files changed, 28 insertions(+) diff --git a/protocol/chainlib/consumer_websocket_manager.go b/protocol/chainlib/consumer_websocket_manager.go index 618491be49..bb26884fe7 100644 --- a/protocol/chainlib/consumer_websocket_manager.go +++ b/protocol/chainlib/consumer_websocket_manager.go @@ -2,6 +2,7 @@ package chainlib import ( "context" + "fmt" "strconv" "sync/atomic" "time" @@ -20,6 +21,7 @@ import ( var ( WebSocketRateLimit = -1 // rate limit requests per second on websocket connection WebSocketBanDuration = time.Duration(0) // once rate limit is reached, will not allow new incoming message for a duration + MaxIdleTimeInSeconds = int64(20 * 60) // 20 minutes of idle time will disconnect the websocket connection ) const ( @@ -178,7 +180,31 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } }() + idleFor := atomic.Int64{} + idleFor.Store(time.Now().Unix()) + go (func() { + if MaxIdleTimeInSeconds <= 0 { + return // unlimited idle time + } + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + select { + case <-webSocketCtx.Done(): + utils.LavaFormatDebug("ctx done in idle time checker") + return + case <-ticker.C: + utils.LavaFormatDebug("checking idle time", utils.LogAttr("idleFor", idleFor.Load()), utils.LogAttr("maxIdleTime", MaxIdleTimeInSeconds), utils.LogAttr("now", time.Now().Unix())) + idleDuration := idleFor.Load() + MaxIdleTimeInSeconds + if time.Now().Unix() > idleDuration { + websocketConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("Connection idle for too long, closing connection. Idle time: %d", idleDuration))) + } + } + } + })() + for { + idleFor.Store(time.Now().Unix()) startTime := time.Now() msgSeed := guidString + "_" + strconv.Itoa(rand.Intn(10000000000)) // use message seed with original guid and new int diff --git a/protocol/common/cobra_common.go b/protocol/common/cobra_common.go index 0fa6f464ce..05e6259ec6 100644 --- a/protocol/common/cobra_common.go +++ b/protocol/common/cobra_common.go @@ -48,6 +48,7 @@ const ( RateLimitWebSocketFlag = "rate-limit-websocket-requests-per-connection" BanDurationForWebsocketRateLimitExceededFlag = "ban-duration-for-websocket-rate-limit-exceeded" LimitParallelWebsocketConnectionsPerIpFlag = "limit-parallel-websocket-connections-per-ip" + LimitWebsocketIdleTimeFlag = "limit-websocket-connection-idle-time" RateLimitRequestPerSecondFlag = "rate-limit-requests-per-second" ) diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index 97e331fb62..c69a7a069f 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -767,6 +767,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77 cmdRPCConsumer.Flags().DurationVar(&metrics.OptimizerQosServerSamplingInterval, common.OptimizerQosServerSamplingIntervalFlag, time.Second*1, "interval to sample optimizer qos reports") cmdRPCConsumer.Flags().IntVar(&chainlib.WebSocketRateLimit, common.RateLimitWebSocketFlag, chainlib.WebSocketRateLimit, "rate limit (per second) websocket requests per user connection, default is unlimited") cmdRPCConsumer.Flags().Int64Var(&chainlib.MaximumNumberOfParallelWebsocketConnectionsPerIp, common.LimitParallelWebsocketConnectionsPerIpFlag, chainlib.MaximumNumberOfParallelWebsocketConnectionsPerIp, "limit number of parallel connections to websocket, per ip, default is unlimited (0)") + cmdRPCConsumer.Flags().Int64Var(&chainlib.MaxIdleTimeInSeconds, common.LimitWebsocketIdleTimeFlag, chainlib.MaxIdleTimeInSeconds, "limit the idle time in seconds for a websocket connection, default is 20 minutes ( 20 * 60 )") cmdRPCConsumer.Flags().DurationVar(&chainlib.WebSocketBanDuration, common.BanDurationForWebsocketRateLimitExceededFlag, chainlib.WebSocketBanDuration, "once websocket rate limit is reached, user will be banned Xfor a duration, default no ban") cmdRPCConsumer.Flags().Bool(LavaOverLavaBackupFlagName, true, "enable lava over lava backup to regular rpc calls") common.AddRollingLogConfig(cmdRPCConsumer)