From 7338448469d014b4b1eb5aaedf0303f687e45f62 Mon Sep 17 00:00:00 2001
From: Dylan Tinianov
Date: Fri, 26 Apr 2024 13:10:57 -0400
Subject: [PATCH] Create HeadPoller for Multi-Node (#12871)

* Create polling transformer
* Update poller
* Rename to HeadPoller
* lint
* update poller
* Update head poller
* Update poller
* lint
* Refactor Poller
* Update poller_test.go
* Update poller
* Synchronize tests
* Refactor with timeout
* Check test logs
* Update Poller
* Update poller_test.go
* Update poller_test.go
* Simplify poller
* Set logging to warn
---
 common/client/poller.go      |  98 +++++++++++++++++
 common/client/poller_test.go | 207 +++++++++++++++++++++++++++++++++++
 2 files changed, 305 insertions(+)
 create mode 100644 common/client/poller.go
 create mode 100644 common/client/poller_test.go

diff --git a/common/client/poller.go b/common/client/poller.go
new file mode 100644
index 00000000000..b21f28fe604
--- /dev/null
+++ b/common/client/poller.go
@@ -0,0 +1,98 @@
+package client
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/smartcontractkit/chainlink-common/pkg/logger"
+	"github.com/smartcontractkit/chainlink-common/pkg/services"
+
+	"github.com/smartcontractkit/chainlink/v2/common/types"
+)
+
+// Poller is a component that polls a function at a given interval
+// and delivers the result to a channel. It is used by multinode to poll
+// for new heads and implements the Subscription interface.
+type Poller[T any] struct {
+	services.StateMachine
+	pollingInterval time.Duration
+	pollingFunc     func(ctx context.Context) (T, error)
+	pollingTimeout  time.Duration
+	logger          logger.Logger
+	channel         chan<- T
+	errCh           chan error
+
+	stopCh services.StopChan
+	wg     sync.WaitGroup
+}
+
+// NewPoller creates a new Poller instance
+func NewPoller[
+	T any,
+](pollingInterval time.Duration, pollingFunc func(ctx context.Context) (T, error), pollingTimeout time.Duration, channel chan<- T, logger logger.Logger) Poller[T] {
+	return Poller[T]{
+		pollingInterval: pollingInterval,
+		pollingFunc:     pollingFunc,
+		pollingTimeout:  pollingTimeout,
+		channel:         channel,
+		logger:          logger,
+		errCh:           make(chan error),
+		stopCh:          make(chan struct{}),
+	}
+}
+
+var _ types.Subscription = &Poller[any]{}
+
+func (p *Poller[T]) Start() error {
+	return p.StartOnce("Poller", func() error {
+		p.wg.Add(1)
+		go p.pollingLoop()
+		return nil
+	})
+}
+
+// Unsubscribe cancels the sending of events to the data channel
+func (p *Poller[T]) Unsubscribe() {
+	_ = p.StopOnce("Poller", func() error {
+		close(p.stopCh)
+		p.wg.Wait()
+		close(p.errCh)
+		return nil
+	})
+}
+
+func (p *Poller[T]) Err() <-chan error {
+	return p.errCh
+}
+
+func (p *Poller[T]) pollingLoop() {
+	defer p.wg.Done()
+
+	ticker := time.NewTicker(p.pollingInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-p.stopCh:
+			return
+		case <-ticker.C:
+			// Set polling timeout
+			pollingCtx, cancelPolling := context.WithTimeout(context.Background(), p.pollingTimeout)
+			p.stopCh.CtxCancel(pollingCtx, cancelPolling)
+			// Execute polling function
+			result, err := p.pollingFunc(pollingCtx)
+			cancelPolling()
+			if err != nil {
+				p.logger.Warnf("polling error: %v", err)
+				continue
+			}
+			// Send result to channel or block if channel is full
+			select {
+			case p.channel <- result:
+			case <-p.stopCh:
+				return
+			}
+		}
+	}
+}
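For reference, the sketch below shows how a caller might wire up the Poller added above: each poll runs under a context bounded by pollingTimeout, and stopCh.CtxCancel also cancels that context when Unsubscribe closes stopCh, so an in-flight poll is interrupted on shutdown. This is illustrative only and not part of the patch; the latestBlockNumber function, the int64 result type, and the use of logger.Nop() are assumptions, not APIs introduced by this change.

// Illustrative usage only (not part of the patch). Would live alongside
// poller.go in package client; polls a stand-in latestBlockNumber function
// once per second and reads results off the delivery channel.
package client

import (
	"context"
	"fmt"
	"time"

	"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

func pollerUsageSketch() error {
	blockNumbers := make(chan int64, 1)

	// Stand-in polling function; a real caller would query an RPC node here.
	latestBlockNumber := func(ctx context.Context) (int64, error) {
		return 42, nil
	}

	// Poll every second; time out any single poll after ten seconds.
	poller := NewPoller[int64](time.Second, latestBlockNumber, 10*time.Second, blockNumbers, logger.Nop())
	if err := poller.Start(); err != nil {
		return err
	}
	defer poller.Unsubscribe()

	// Consume one result, then stop; Unsubscribe also interrupts an in-flight poll.
	fmt.Println("latest block:", <-blockNumbers)
	return nil
}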
"sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +func Test_Poller(t *testing.T) { + lggr := logger.Test(t) + + t.Run("Test multiple start", func(t *testing.T) { + pollFunc := func(ctx context.Context) (Head, error) { + return nil, nil + } + + channel := make(chan Head, 1) + defer close(channel) + + poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr) + err := poller.Start() + require.NoError(t, err) + + err = poller.Start() + require.Error(t, err) + poller.Unsubscribe() + }) + + t.Run("Test polling for heads", func(t *testing.T) { + // Mock polling function that returns a new value every time it's called + var pollNumber int + pollLock := sync.Mutex{} + pollFunc := func(ctx context.Context) (Head, error) { + pollLock.Lock() + defer pollLock.Unlock() + pollNumber++ + h := head{ + BlockNumber: int64(pollNumber), + BlockDifficulty: big.NewInt(int64(pollNumber)), + } + return h.ToMockHead(t), nil + } + + // data channel to receive updates from the poller + channel := make(chan Head, 1) + defer close(channel) + + // Create poller and start to receive data + poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr) + require.NoError(t, poller.Start()) + defer poller.Unsubscribe() + + // Receive updates from the poller + pollCount := 0 + pollMax := 50 + for ; pollCount < pollMax; pollCount++ { + h := <-channel + assert.Equal(t, int64(pollCount+1), h.BlockNumber()) + } + }) + + t.Run("Test polling errors", func(t *testing.T) { + // Mock polling function that returns an error + var pollNumber int + pollLock := sync.Mutex{} + pollFunc := func(ctx context.Context) (Head, error) { + pollLock.Lock() + defer pollLock.Unlock() + pollNumber++ + return nil, fmt.Errorf("polling error %d", pollNumber) + } + + // data channel to receive updates from the poller + channel := make(chan Head, 1) + defer close(channel) + + olggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) + + // Create poller and subscribe to receive data + poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, olggr) + require.NoError(t, poller.Start()) + defer poller.Unsubscribe() + + // Ensure that all errors were logged as expected + logsSeen := func() bool { + for pollCount := 0; pollCount < 50; pollCount++ { + numLogs := observedLogs.FilterMessage(fmt.Sprintf("polling error: polling error %d", pollCount+1)).Len() + if numLogs != 1 { + return false + } + } + return true + } + require.Eventually(t, logsSeen, time.Second, time.Millisecond) + }) + + t.Run("Test polling timeout", func(t *testing.T) { + pollFunc := func(ctx context.Context) (Head, error) { + if <-ctx.Done(); true { + return nil, ctx.Err() + } + return nil, nil + } + + // Set instant timeout + pollingTimeout := time.Duration(0) + + // data channel to receive updates from the poller + channel := make(chan Head, 1) + defer close(channel) + + olggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) + + // Create poller and subscribe to receive data + poller := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, channel, olggr) + require.NoError(t, poller.Start()) + defer poller.Unsubscribe() + + // Ensure that timeout errors were logged as expected + logsSeen := func() bool { + return observedLogs.FilterMessage("polling error: context deadline exceeded").Len() >= 1 + } + require.Eventually(t, logsSeen, time.Second, time.Millisecond) + }) + + 
t.Run("Test unsubscribe during polling", func(t *testing.T) { + wait := make(chan struct{}) + pollFunc := func(ctx context.Context) (Head, error) { + close(wait) + // Block in polling function until context is cancelled + if <-ctx.Done(); true { + return nil, ctx.Err() + } + return nil, nil + } + + // Set long timeout + pollingTimeout := time.Minute + + // data channel to receive updates from the poller + channel := make(chan Head, 1) + defer close(channel) + + olggr, observedLogs := logger.TestObserved(t, zap.WarnLevel) + + // Create poller and subscribe to receive data + poller := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, channel, olggr) + require.NoError(t, poller.Start()) + + // Unsubscribe while blocked in polling function + <-wait + poller.Unsubscribe() + + // Ensure error was logged + logsSeen := func() bool { + return observedLogs.FilterMessage("polling error: context canceled").Len() >= 1 + } + require.Eventually(t, logsSeen, time.Second, time.Millisecond) + }) +} + +func Test_Poller_Unsubscribe(t *testing.T) { + lggr := logger.Test(t) + pollFunc := func(ctx context.Context) (Head, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + h := head{ + BlockNumber: 0, + BlockDifficulty: big.NewInt(0), + } + return h.ToMockHead(t), nil + } + } + + t.Run("Test multiple unsubscribe", func(t *testing.T) { + channel := make(chan Head, 1) + poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr) + err := poller.Start() + require.NoError(t, err) + + <-channel + poller.Unsubscribe() + poller.Unsubscribe() + }) + + t.Run("Test unsubscribe with closed channel", func(t *testing.T) { + channel := make(chan Head, 1) + poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr) + err := poller.Start() + require.NoError(t, err) + + <-channel + close(channel) + poller.Unsubscribe() + }) +}