From a92728630efddc8335dcb86c3ce5b954fa6b0b41 Mon Sep 17 00:00:00 2001 From: Lukasz Klimek <842586+lklimek@users.noreply.github.com> Date: Thu, 16 May 2024 18:10:09 +0200 Subject: [PATCH] test(p2pclient): rate limits unit tests --- internal/p2p/client/client_test.go | 10 +- internal/p2p/client/consumer_test.go | 115 -------------- internal/p2p/client/ratelimit_test.go | 219 ++++++++++++++++++++++++++ 3 files changed, 224 insertions(+), 120 deletions(-) create mode 100644 internal/p2p/client/ratelimit_test.go diff --git a/internal/p2p/client/client_test.go b/internal/p2p/client/client_test.go index 50a88c4bb..ee353b0f6 100644 --- a/internal/p2p/client/client_test.go +++ b/internal/p2p/client/client_test.go @@ -54,11 +54,11 @@ func (suite *ChannelTestSuite) SetupTest() { suite.fakeClock = clockwork.NewFakeClock() suite.client = New( suite.descriptors, - func(ctx context.Context, descriptor *p2p.ChannelDescriptor) (p2p.Channel, error) { + func(_ctx context.Context, _descriptor *p2p.ChannelDescriptor) (p2p.Channel, error) { return suite.p2pChannel, nil }, WithClock(suite.fakeClock), - WithChanIDResolver(func(msg proto.Message) p2p.ChannelID { + WithChanIDResolver(func(_msg proto.Message) p2p.ChannelID { return testChannelID }), ) @@ -185,7 +185,7 @@ func (suite *ChannelTestSuite) TestConsumeHandle() { suite.p2pChannel. On("Receive", ctx). Once(). - Return(func(ctx context.Context) p2p.ChannelIterator { + Return(func(_ctx context.Context) p2p.ChannelIterator { return p2p.NewChannelIterator(outCh) }) consumer := newMockConsumer(suite.T()) @@ -226,7 +226,7 @@ func (suite *ChannelTestSuite) TestConsumeResolve() { suite.p2pChannel. On("Receive", ctx). Once(). - Return(func(ctx context.Context) p2p.ChannelIterator { + Return(func(_ctx context.Context) p2p.ChannelIterator { return p2p.NewChannelIterator(outCh) }) resCh := suite.client.addPending(reqID) @@ -278,7 +278,7 @@ func (suite *ChannelTestSuite) TestConsumeError() { suite.p2pChannel. On("Receive", ctx). Once(). - Return(func(ctx context.Context) p2p.ChannelIterator { + Return(func(_ctx context.Context) p2p.ChannelIterator { return p2p.NewChannelIterator(outCh) }) consumer := newMockConsumer(suite.T()) diff --git a/internal/p2p/client/consumer_test.go b/internal/p2p/client/consumer_test.go index dc729463a..31b1347cd 100644 --- a/internal/p2p/client/consumer_test.go +++ b/internal/p2p/client/consumer_test.go @@ -4,24 +4,16 @@ import ( "context" "errors" "fmt" - "math" "regexp" - "strconv" - "sync" - "sync/atomic" "testing" - "time" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" "github.com/dashpay/tenderdash/internal/p2p" tmrequire "github.com/dashpay/tenderdash/internal/test/require" "github.com/dashpay/tenderdash/libs/log" bcproto "github.com/dashpay/tenderdash/proto/tendermint/blocksync" - "github.com/dashpay/tenderdash/types" ) func TestErrorLoggerP2PMessageHandler(t *testing.T) { @@ -151,110 +143,3 @@ func TestValidateMessageHandler(t *testing.T) { }) } } - -// TestRateLimitHandler tests the rate limit middleware. -// -// GIVEN 5 peers named 1..5 and rate limit of 2/s and burst 4, -// WHEN we send 1, 2, 3, 4 and 5 msgs per second respectively for 3 seconds, -// THEN: -// * peer 1 and 2 receive all messages, -// * other peers receive 2 messages per second plus 4 burst messages. -func TestRateLimitHandler(t *testing.T) { - const ( - Peers = 5 - RateLimit = 2.0 - Burst = 4 - TestTimeSeconds = 3 - ) - - // don't run this if we are in short mode - if testing.Short() { - t.Skip("skipping test in short mode.") - } - - fakeHandler := newMockConsumer(t) - - // we cancel manually to control race conditions - ctx, cancel := context.WithCancel(context.Background()) - - logger := log.NewTestingLogger(t) - client := &Client{} - - mw := WithRecvRateLimitPerPeerHandler(ctx, RateLimit, func(*p2p.Envelope) uint { return 1 }, false, logger)(fakeHandler).(*recvRateLimitPerPeerHandler) - mw.burst = Burst - - start := sync.RWMutex{} - start.Lock() - - sent := make([]atomic.Uint32, Peers) - - for peer := 1; peer <= Peers; peer++ { - counter := &sent[peer-1] - peerID := types.NodeID(strconv.Itoa(peer)) - fakeHandler.On("Handle", mock.Anything, mock.Anything, mock.MatchedBy( - func(e *p2p.Envelope) bool { - return e.From == peerID - }, - )).Return(nil).Run(func(_args mock.Arguments) { - counter.Add(1) - }) - - go func(peerID types.NodeID, rate int) { - start.RLock() - defer start.RUnlock() - - for s := 0; s < TestTimeSeconds; s++ { - until := time.NewTimer(time.Second) - defer until.Stop() - - for i := 0; i < rate; i++ { - select { - case <-ctx.Done(): - return - default: - } - - envelope := &p2p.Envelope{ - From: peerID, - } - - err := mw.Handle(ctx, client, envelope) - require.NoError(t, err) - } - - select { - case <-until.C: - // noop, we just sleep - case <-ctx.Done(): - return - } - } - }(peerID, peer) - } - - // start the test - startTime := time.Now() - start.Unlock() - time.Sleep(TestTimeSeconds * time.Second) - cancel() - // wait for all goroutines to finish, that is - drop RLocks - start.Lock() - - // Check assertions - - // we floor with 1 decimal place, as elapsed will be slightly more than TestTimeSeconds - elapsed := math.Floor(time.Since(startTime).Seconds()*10) / 10 - assert.Equal(t, float64(TestTimeSeconds), elapsed, "test should run for %d seconds", TestTimeSeconds) - - for peer := 1; peer <= Peers; peer++ { - expected := int(RateLimit)*TestTimeSeconds + Burst - if expected > peer*TestTimeSeconds { - expected = peer * TestTimeSeconds - } - - assert.Equal(t, expected, int(sent[peer-1].Load()), "peer %d should receive %d messages", peer, expected) - } - // require.Equal(t, uint32(1*TestTimeSeconds), sent[0].Load(), "peer 0 should receive 1 message per second") - // require.Equal(t, uint32(2*TestTimeSeconds), sent[1].Load(), "peer 1 should receive 2 messages per second") - // require.Equal(t, uint32(2*TestTimeSeconds+Burst), sent[2].Load(), "peer 2 should receive 2 messages per second") -} diff --git a/internal/p2p/client/ratelimit_test.go b/internal/p2p/client/ratelimit_test.go new file mode 100644 index 000000000..004a33004 --- /dev/null +++ b/internal/p2p/client/ratelimit_test.go @@ -0,0 +1,219 @@ +package client + +import ( + "context" + "errors" + "math" + "runtime" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/dashpay/tenderdash/internal/p2p" + "github.com/dashpay/tenderdash/internal/p2p/conn" + "github.com/dashpay/tenderdash/libs/log" + "github.com/dashpay/tenderdash/types" +) + +// TestRecvRateLimitHandler tests the rate limit middleware when receiving messages from peers. +// It tests that the rate limit is applied per peer. +// +// GIVEN 5 peers named 1..5 and rate limit of 2/s and burst 4, +// WHEN we send 1, 2, 3, 4 and 5 msgs per second respectively for 3 seconds, +// THEN: +// * peer 1 and 2 receive all messages, +// * other peers receive 2 messages per second plus 4 burst messages. +// +// Reuses testRateLimit from client_test.go +func TestRecvRateLimitHandler(t *testing.T) { + // don't run this if we are in short mode + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + const ( + Limit = 2.0 + Burst = 4 + Peers = 5 + TestTimeSeconds = 3 + ) + + sent := make([]atomic.Uint32, Peers) + + fakeHandler := newMockConsumer(t) + fakeHandler.On("Handle", mock.Anything, mock.Anything, mock.Anything). + Return(nil). + Run(func(args mock.Arguments) { + peerID := args.Get(2).(*p2p.Envelope).From + peerNum, err := strconv.Atoi(string(peerID)) + require.NoError(t, err) + sent[peerNum-1].Add(1) + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewTestingLogger(t) + client := &Client{} + + mw := WithRecvRateLimitPerPeerHandler(ctx, + Limit, + func(*p2p.Envelope) uint { return 1 }, + false, + logger, + )(fakeHandler).(*recvRateLimitPerPeerHandler) + + mw.burst = Burst + + sendFn := func(peerID types.NodeID) error { + envelope := p2p.Envelope{ + From: peerID, + ChannelID: testChannelID, + } + return mw.Handle(ctx, client, &envelope) + } + + parallelSendWithLimit(t, ctx, sendFn, Peers, TestTimeSeconds) + assertRateLimits(t, sent, Limit, Burst, TestTimeSeconds) +} + +// TestSendRateLimit tests the rate limit for sending messages using p2p.client. +// +// Each peer should have his own, independent rate limit. +// +// GIVEN 5 peers named 1..5 and rate limit of 2/s and burst 4, +// WHEN we send 1, 2, 3, 4 and 5 msgs per second respectively for 3 seconds, +// THEN: +// * peer 1 and 2 receive all messages, +// * other peers receive 2 messages per second plus 4 burst messages. +func (suite *ChannelTestSuite) TestSendRateLimit() { + if testing.Short() { + suite.T().Skip("skipping test in short mode.") + } + + const ( + Limit = 2.0 + Burst = 4 + Peers = 5 + TestTimeSeconds = 3 + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := suite.client + + limiter := NewRateLimit(ctx, Limit, false, suite.client.logger) + limiter.burst = Burst + suite.client.rateLimit = map[conn.ChannelID]*RateLimit{ + testChannelID: limiter, + } + + sendFn := func(peerID types.NodeID) error { + envelope := p2p.Envelope{ + To: peerID, + ChannelID: testChannelID, + } + return client.Send(ctx, envelope) + + } + sent := make([]atomic.Uint32, Peers) + + suite.p2pChannel.On("Send", mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + peerID := args.Get(1).(p2p.Envelope).To + peerNum, err := strconv.Atoi(string(peerID)) + suite.NoError(err) + sent[peerNum-1].Add(1) + }). + Return(nil) + + parallelSendWithLimit(suite.T(), ctx, sendFn, Peers, TestTimeSeconds) + assertRateLimits(suite.T(), sent, Limit, Burst, TestTimeSeconds) +} + +// parallelSendWithLimit sends messages to peers in parallel with a rate limit. +// +// The function sends messages to peers. Each peer gets its number, starting from 1. +// Rate limit is equal to the peer number, eg. peer 1 sends 1 msg/s, peeer 2 sends 2 msg/s etc. +func parallelSendWithLimit(t *testing.T, ctx context.Context, sendFn func(peerID types.NodeID) error, + peers int, testTimeSeconds int) { + t.Helper() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // all goroutines will wait for the start signal + start := sync.RWMutex{} + start.Lock() + + for peer := 1; peer <= peers; peer++ { + peerID := types.NodeID(strconv.Itoa(peer)) + // peer number is the rate limit + msgsPerSec := peer + + go func(peerID types.NodeID, rate int) { + start.RLock() + defer start.RUnlock() + + for s := 0; s < testTimeSeconds; s++ { + until := time.NewTimer(time.Second) + defer until.Stop() + + for i := 0; i < rate; i++ { + select { + case <-ctx.Done(): + return + default: + } + + if err := sendFn(peerID); !errors.Is(err, context.Canceled) { + require.NoError(t, err) + } + } + + select { + case <-until.C: + // noop, we just sleep until the end of the second + case <-ctx.Done(): + return + } + } + + }(peerID, msgsPerSec) + } + + // start the test + startTime := time.Now() + start.Unlock() + runtime.Gosched() + time.Sleep(time.Duration(testTimeSeconds) * time.Second) + cancel() + // wait for all goroutines to finish, that is - drop RLocks + start.Lock() + defer start.Unlock() + + // check if test ran for the expected time + // note we ignore up to 99 ms to account for any processing time + elapsed := math.Floor(time.Since(startTime).Seconds()*10) / 10 + assert.Equal(t, float64(testTimeSeconds), elapsed, "test should run for %d seconds", testTimeSeconds) +} + +// assertRateLimits checks if the rate limits were applied correctly +// We assume that index of each item in `sent` is the peer number, as described in parallelSendWithLimit. +func assertRateLimits(t *testing.T, sent []atomic.Uint32, limit float64, burst int, seconds int) { + for peer := 1; peer <= len(sent); peer++ { + expected := int(limit)*seconds + burst + if expected > peer*seconds { + expected = peer * seconds + } + + assert.Equal(t, expected, int(sent[peer-1].Load()), "peer %d should receive %d messages", peer, expected) + } +}