Skip to content

Commit

Permalink
adding unitest
Browse files Browse the repository at this point in the history
  • Loading branch information
ranlavanet committed Dec 4, 2024
1 parent c907e59 commit 673d433
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 14 deletions.
118 changes: 118 additions & 0 deletions protocol/chainlib/consumer_websocket_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package chainlib

import (
"net"

"testing"

"github.com/golang/mock/gomock"
"github.com/lavanet/lava/v4/protocol/common"
"github.com/stretchr/testify/assert"
)

func TestWebsocketConnectionLimiter(t *testing.T) {
tests := []struct {
name string
connectionLimit int64
headerLimit int64
ipAddress string
forwardedIP string
userAgent string
expectSuccess []bool
}{
{
name: "Single connection allowed",
connectionLimit: 1,
headerLimit: 0,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true},
},
{
name: "Single connection allowed",
connectionLimit: 1,
headerLimit: 0,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, false},
},
{
name: "Multiple connections allowed",
connectionLimit: 2,
headerLimit: 0,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, true},
},
{
name: "Multiple connections allowed",
connectionLimit: 2,
headerLimit: 0,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, true, false},
},
{
name: "Header limit overrides global limit succeed",
connectionLimit: 3,
headerLimit: 2,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, true},
},
{
name: "Header limit overrides global limit fail",
connectionLimit: 0,
headerLimit: 2,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, true, false},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

// Create a new connection limiter
wcl := &WebsocketConnectionLimiter{
ipToNumberOfActiveConnections: make(map[string]int64),
}

// Set global connection limit for testing
MaximumNumberOfParallelWebsocketConnectionsPerIp = tt.connectionLimit

// Create mock websocket connection
mockWsConn := NewMockWebsocketConnection(ctrl)

// Set up expectations
mockWsConn.EXPECT().Locals(WebSocketOpenConnectionsLimitHeader).Return(tt.headerLimit).AnyTimes()
mockWsConn.EXPECT().Locals(common.IP_FORWARDING_HEADER_NAME).Return(tt.forwardedIP).AnyTimes()
mockWsConn.EXPECT().Locals("User-Agent").Return(tt.userAgent).AnyTimes()
mockWsConn.EXPECT().RemoteAddr().Return(&net.TCPAddr{
IP: net.ParseIP(tt.ipAddress),
Port: 8080,
}).AnyTimes()
mockWsConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Do(func(messageType int, data []byte) {
t.Logf("WriteMessage called with messageType: %d, data: %s", messageType, string(data))
}).AnyTimes()

// Test the connection
for _, expectSuccess := range tt.expectSuccess {
canOpen, _ := wcl.canOpenConnection(mockWsConn)
if expectSuccess {
assert.True(t, canOpen, "Expected connection to be allowed")
} else {
assert.False(t, canOpen, "Expected connection to be denied")
}
}
})
}
}
77 changes: 77 additions & 0 deletions protocol/chainlib/mock_websocket.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 33 additions & 14 deletions protocol/chainlib/websocket_connection_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ import (
"github.com/lavanet/lava/v4/utils"
)

// WebsocketConnection defines the interface for websocket connections
type WebsocketConnection interface {
// Add only the methods you need to mock
RemoteAddr() net.Addr
Locals(key string) interface{}
WriteMessage(messageType int, data []byte) error
}

// Will limit a certain amount of connections per IP
type WebsocketConnectionLimiter struct {
ipToNumberOfActiveConnections map[string]int64
Expand Down Expand Up @@ -47,7 +55,7 @@ func (wcl *WebsocketConnectionLimiter) handleFiberRateLimitFlags(c *fiber.Ctx) {
c.Locals(WebSocketOpenConnectionsLimitHeader, connectionLimit)
}

func (wcl *WebsocketConnectionLimiter) getConnectionLimit(websocketConn *websocket.Conn) int64 {
func (wcl *WebsocketConnectionLimiter) getConnectionLimit(websocketConn WebsocketConnection) int64 {
connectionLimitHeaderValue, ok := websocketConn.Locals(WebSocketOpenConnectionsLimitHeader).(int64)
if !ok || connectionLimitHeaderValue < 0 {
connectionLimitHeaderValue = 0
Expand All @@ -60,9 +68,10 @@ func (wcl *WebsocketConnectionLimiter) getConnectionLimit(websocketConn *websock
return utils.Max(MaximumNumberOfParallelWebsocketConnectionsPerIp, connectionLimitHeaderValue)
}

func (wcl *WebsocketConnectionLimiter) canOpenConnection(websocketConn *websocket.Conn) (bool, func()) {
func (wcl *WebsocketConnectionLimiter) canOpenConnection(websocketConn WebsocketConnection) (bool, func()) {
// Check which connection limit is higher and use that.
connectionLimit := wcl.getConnectionLimit(websocketConn)
decreaseIpConnectionCallback := func() {}
if connectionLimit > 0 { // 0 is disabled.
ipForwardedInterface := websocketConn.Locals(common.IP_FORWARDING_HEADER_NAME)
ipForwarded, assertionSuccessful := ipForwardedInterface.(string)
Expand All @@ -75,31 +84,41 @@ func (wcl *WebsocketConnectionLimiter) canOpenConnection(websocketConn *websocke
userAgent = ""
}
key := wcl.getKey(ip, ipForwarded, userAgent)
numberOfActiveConnections := wcl.addIpConnectionAndGetCurrentAmount(key)

if numberOfActiveConnections > connectionLimit {
// Check current connections before incrementing
currentConnections := wcl.getCurrentAmountOfConnections(key)
// If already at or exceeding limit, deny the connection
if currentConnections >= connectionLimit {
websocketConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("Too Many Open Connections, limited to %d", connectionLimit)))
return false, func() { wcl.decreaseIpConnection(key) }
return false, decreaseIpConnectionCallback
}
// If under limit, increment and return cleanup function
wcl.addIpConnectionAndGetCurrentAmount(key)
decreaseIpConnectionCallback = func() { wcl.decreaseIpConnection(key) }
}
return true, func() {}
return true, decreaseIpConnectionCallback
}

func (wcl *WebsocketConnectionLimiter) addIpConnectionAndGetCurrentAmount(ip string) int64 {
func (wcl *WebsocketConnectionLimiter) getCurrentAmountOfConnections(key string) int64 {
wcl.lock.RLock()
defer wcl.lock.RUnlock()
return wcl.ipToNumberOfActiveConnections[key]
}

func (wcl *WebsocketConnectionLimiter) addIpConnectionAndGetCurrentAmount(key string) {
wcl.lock.Lock()
defer wcl.lock.Unlock()
// wether it exists or not we add 1.
wcl.ipToNumberOfActiveConnections[ip] += 1
return wcl.ipToNumberOfActiveConnections[ip]
wcl.ipToNumberOfActiveConnections[key] += 1
}

func (wcl *WebsocketConnectionLimiter) decreaseIpConnection(ip string) {
func (wcl *WebsocketConnectionLimiter) decreaseIpConnection(key string) {
wcl.lock.Lock()
defer wcl.lock.Unlock()
// wether it exists or not we add 1.
wcl.ipToNumberOfActiveConnections[ip] -= 1
if wcl.ipToNumberOfActiveConnections[ip] == 0 {
delete(wcl.ipToNumberOfActiveConnections, ip)
// it must exist as we dont get here without adding it prior
wcl.ipToNumberOfActiveConnections[key] -= 1
if wcl.ipToNumberOfActiveConnections[key] == 0 {
delete(wcl.ipToNumberOfActiveConnections, key)
}
}

Expand Down

0 comments on commit 673d433

Please sign in to comment.