From 738146e8d39eeaeef8e1479a7a2879b24cf7930d Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:11:52 -0800 Subject: [PATCH] [Functions] Minor Listener refactor (#11323) 1. Add an interface type to make Listener mockable 2. Return internal errors from handleRequest() --- core/services/functions/listener.go | 74 +++++++++++-------- core/services/functions/listener_test.go | 21 +++++- .../functions/mocks/functions_listener.go | 71 ++++++++++++++++++ 3 files changed, 134 insertions(+), 32 deletions(-) create mode 100644 core/services/functions/mocks/functions_listener.go diff --git a/core/services/functions/listener.go b/core/services/functions/listener.go index 5614c5331d4..efb40330cb4 100644 --- a/core/services/functions/listener.go +++ b/core/services/functions/listener.go @@ -29,8 +29,6 @@ import ( ) var ( - _ job.ServiceCtx = &FunctionsListener{} - sizeBuckets = []float64{ 1024, 1024 * 4, @@ -118,7 +116,14 @@ const ( FlagSecretsMaxSize uint32 = 2 ) -type FunctionsListener struct { +//go:generate mockery --quiet --name FunctionsListener --output ./mocks/ --case=underscore +type FunctionsListener interface { + job.ServiceCtx + + HandleOffchainRequest(ctx context.Context, request *OffchainRequest) error +} + +type functionsListener struct { services.StateMachine client client.Client contractAddressHex string @@ -137,11 +142,13 @@ type FunctionsListener struct { logPollerWrapper evmrelayTypes.LogPollerWrapper } -func (l *FunctionsListener) HealthReport() map[string]error { +var _ FunctionsListener = &functionsListener{} + +func (l *functionsListener) HealthReport() map[string]error { return map[string]error{l.Name(): l.Healthy()} } -func (l *FunctionsListener) Name() string { return l.logger.Name() } +func (l *functionsListener) Name() string { return l.logger.Name() } func formatRequestId(requestId [32]byte) string { return fmt.Sprintf("0x%x", requestId) @@ -159,8 +166,8 @@ func NewFunctionsListener( urlsMonEndpoint commontypes.MonitoringEndpoint, decryptor threshold.Decryptor, logPollerWrapper evmrelayTypes.LogPollerWrapper, -) *FunctionsListener { - return &FunctionsListener{ +) *functionsListener { + return &functionsListener{ client: client, contractAddressHex: contractAddressHex, job: job, @@ -177,7 +184,7 @@ func NewFunctionsListener( } // Start complies with job.Service -func (l *FunctionsListener) Start(context.Context) error { +func (l *functionsListener) Start(context.Context) error { return l.StartOnce("FunctionsListener", func() error { l.serviceContext, l.serviceCancel = context.WithCancel(context.Background()) @@ -204,7 +211,7 @@ func (l *FunctionsListener) Start(context.Context) error { } // Close complies with job.Service -func (l *FunctionsListener) Close() error { +func (l *functionsListener) Close() error { return l.StopOnce("FunctionsListener", func() error { l.serviceCancel() close(l.chStop) @@ -213,7 +220,7 @@ func (l *FunctionsListener) Close() error { }) } -func (l *FunctionsListener) processOracleEventsV1() { +func (l *functionsListener) processOracleEventsV1() { defer l.shutdownWaitGroup.Done() freqMillis := l.pluginConfig.ListenerEventsCheckFrequencyMillis if freqMillis == 0 { @@ -247,7 +254,7 @@ func (l *FunctionsListener) processOracleEventsV1() { } } -func (l *FunctionsListener) getNewHandlerContext() (context.Context, context.CancelFunc) { +func (l *functionsListener) getNewHandlerContext() (context.Context, context.CancelFunc) { timeoutSec := l.pluginConfig.ListenerEventHandlerTimeoutSec if timeoutSec == 0 { return context.WithCancel(l.serviceContext) @@ -255,7 +262,7 @@ func (l *FunctionsListener) getNewHandlerContext() (context.Context, context.Can return context.WithTimeout(l.serviceContext, time.Duration(timeoutSec)*time.Second) } -func (l *FunctionsListener) setError(ctx context.Context, requestId RequestID, errType ErrType, errBytes []byte) { +func (l *functionsListener) setError(ctx context.Context, requestId RequestID, errType ErrType, errBytes []byte) { if errType == INTERNAL_ERROR { promRequestInternalError.WithLabelValues(l.contractAddressHex).Inc() } else { @@ -267,7 +274,7 @@ func (l *FunctionsListener) setError(ctx context.Context, requestId RequestID, e } } -func (l *FunctionsListener) getMaxCBORsize(flags RequestFlags) uint32 { +func (l *functionsListener) getMaxCBORsize(flags RequestFlags) uint32 { idx := flags[FlagCBORMaxSize] if int(idx) >= len(l.pluginConfig.MaxRequestSizesList) { return l.pluginConfig.MaxRequestSizeBytes // deprecated @@ -275,7 +282,7 @@ func (l *FunctionsListener) getMaxCBORsize(flags RequestFlags) uint32 { return l.pluginConfig.MaxRequestSizesList[idx] } -func (l *FunctionsListener) getMaxSecretsSize(flags RequestFlags) uint32 { +func (l *functionsListener) getMaxSecretsSize(flags RequestFlags) uint32 { idx := flags[FlagSecretsMaxSize] if int(idx) >= len(l.pluginConfig.MaxSecretsSizesList) { return math.MaxUint32 // not enforced if not configured @@ -283,7 +290,7 @@ func (l *FunctionsListener) getMaxSecretsSize(flags RequestFlags) uint32 { return l.pluginConfig.MaxSecretsSizesList[idx] } -func (l *FunctionsListener) HandleOffchainRequest(ctx context.Context, request *OffchainRequest) error { +func (l *functionsListener) HandleOffchainRequest(ctx context.Context, request *OffchainRequest) error { if request == nil { return errors.New("HandleOffchainRequest: received nil request") } @@ -318,11 +325,10 @@ func (l *FunctionsListener) HandleOffchainRequest(ctx context.Context, request * } return err } - l.handleRequest(ctx, requestId, request.SubscriptionId, subscriptionOwner, RequestFlags{}, &request.Data) - return nil + return l.handleRequest(ctx, requestId, request.SubscriptionId, subscriptionOwner, RequestFlags{}, &request.Data) } -func (l *FunctionsListener) handleOracleRequestV1(request *evmrelayTypes.OracleRequest) { +func (l *functionsListener) handleOracleRequestV1(request *evmrelayTypes.OracleRequest) { defer l.shutdownWaitGroup.Done() l.logger.Infow("handleOracleRequestV1: oracle request v1 received", "requestID", formatRequestId(request.RequestId)) ctx, cancel := l.getNewHandlerContext() @@ -354,10 +360,13 @@ func (l *FunctionsListener) handleOracleRequestV1(request *evmrelayTypes.OracleR l.setError(ctx, request.RequestId, USER_ERROR, []byte(err.Error())) return } - l.handleRequest(ctx, request.RequestId, request.SubscriptionId, request.SubscriptionOwner, request.Flags, requestData) + err = l.handleRequest(ctx, request.RequestId, request.SubscriptionId, request.SubscriptionOwner, request.Flags, requestData) + if err != nil { + l.logger.Errorw("handleOracleRequestV1: error in handleRequest()", "requestID", formatRequestId(request.RequestId), "err", err) + } } -func (l *FunctionsListener) parseCBOR(requestId RequestID, cborData []byte, maxSizeBytes uint32) (*RequestData, error) { +func (l *functionsListener) parseCBOR(requestId RequestID, cborData []byte, maxSizeBytes uint32) (*RequestData, error) { if maxSizeBytes > 0 && uint32(len(cborData)) > maxSizeBytes { l.logger.Errorw("request too big", "requestID", formatRequestId(requestId), "requestSize", len(cborData), "maxRequestSize", maxSizeBytes) return nil, fmt.Errorf("request too big (max %d bytes)", maxSizeBytes) @@ -372,7 +381,8 @@ func (l *FunctionsListener) parseCBOR(requestId RequestID, cborData []byte, maxS return &requestData, nil } -func (l *FunctionsListener) handleRequest(ctx context.Context, requestID RequestID, subscriptionId uint64, subscriptionOwner common.Address, flags RequestFlags, requestData *RequestData) { +// Handle secret fetching/decryption and functions computation. Return error only for internal errors. +func (l *functionsListener) handleRequest(ctx context.Context, requestID RequestID, subscriptionId uint64, subscriptionOwner common.Address, flags RequestFlags, requestData *RequestData) error { startTime := time.Now() defer func() { duration := time.Since(startTime) @@ -385,26 +395,26 @@ func (l *FunctionsListener) handleRequest(ctx context.Context, requestID Request if err != nil { l.logger.Errorw("failed to create ExternalAdapterClient", "requestID", requestIDStr, "err", err) l.setError(ctx, requestID, INTERNAL_ERROR, []byte(err.Error())) - return + return err } nodeProvidedSecrets, userErr, internalErr := l.getSecrets(ctx, eaClient, requestID, subscriptionOwner, requestData) if internalErr != nil { l.logger.Errorw("internal error during getSecrets", "requestID", requestIDStr, "err", internalErr) l.setError(ctx, requestID, INTERNAL_ERROR, []byte(internalErr.Error())) - return + return internalErr } if userErr != nil { l.logger.Debugw("user error during getSecrets", "requestID", requestIDStr, "err", userErr) l.setError(ctx, requestID, USER_ERROR, []byte(userErr.Error())) - return + return nil // user error } maxSecretsSize := l.getMaxSecretsSize(flags) if uint32(len(nodeProvidedSecrets)) > maxSecretsSize { l.logger.Errorw("secrets size too big", "requestID", requestIDStr, "secretsSize", len(nodeProvidedSecrets), "maxSecretsSize", maxSecretsSize) l.setError(ctx, requestID, USER_ERROR, []byte("secrets size too big")) - return + return nil // user error } computationResult, computationError, domains, err := eaClient.RunComputation(ctx, requestIDStr, l.job.Name.ValueOrZero(), subscriptionOwner.Hex(), subscriptionId, flags, nodeProvidedSecrets, requestData) @@ -412,7 +422,7 @@ func (l *FunctionsListener) handleRequest(ctx context.Context, requestID Request if err != nil { l.logger.Errorw("internal adapter error", "requestID", requestIDStr, "err", err) l.setError(ctx, requestID, INTERNAL_ERROR, []byte(err.Error())) - return + return err } if len(computationError) == 0 && len(computationResult) == 0 { @@ -438,11 +448,13 @@ func (l *FunctionsListener) handleRequest(ctx context.Context, requestID Request l.logger.Debugw("saving computation result", "requestID", requestIDStr) if err2 := l.pluginORM.SetResult(requestID, computationResult, time.Now(), pg.WithParentCtx(ctx)); err2 != nil { l.logger.Errorw("call to SetResult failed", "requestID", requestIDStr, "err", err2) + return err2 } } + return nil } -func (l *FunctionsListener) handleOracleResponseV1(response *evmrelayTypes.OracleResponse) { +func (l *functionsListener) handleOracleResponseV1(response *evmrelayTypes.OracleResponse) { defer l.shutdownWaitGroup.Done() l.logger.Infow("oracle response v1 received", "requestID", formatRequestId(response.RequestId)) @@ -454,7 +466,7 @@ func (l *FunctionsListener) handleOracleResponseV1(response *evmrelayTypes.Oracl promRequestConfirmed.WithLabelValues(l.contractAddressHex).Inc() } -func (l *FunctionsListener) timeoutRequests() { +func (l *functionsListener) timeoutRequests() { defer l.shutdownWaitGroup.Done() timeoutSec, freqSec, batchSize := l.pluginConfig.RequestTimeoutSec, l.pluginConfig.RequestTimeoutCheckFrequencySec, l.pluginConfig.RequestTimeoutBatchLookupSize if timeoutSec == 0 || freqSec == 0 || batchSize == 0 { @@ -490,7 +502,7 @@ func (l *FunctionsListener) timeoutRequests() { } } -func (l *FunctionsListener) pruneRequests() { +func (l *functionsListener) pruneRequests() { defer l.shutdownWaitGroup.Done() maxStoredRequests, freqSec, batchSize := l.pluginConfig.PruneMaxStoredRequests, l.pluginConfig.PruneCheckFrequencySec, l.pluginConfig.PruneBatchSize if maxStoredRequests == 0 { @@ -532,7 +544,7 @@ func (l *FunctionsListener) pruneRequests() { } } -func (l *FunctionsListener) reportSourceCodeDomains(requestId RequestID, domains []string) { +func (l *functionsListener) reportSourceCodeDomains(requestId RequestID, domains []string) { r := &telem.FunctionsRequest{ RequestId: formatRequestId(requestId), NodeAddress: l.job.OCR2OracleSpec.TransmitterID.ValueOrZero(), @@ -547,7 +559,7 @@ func (l *FunctionsListener) reportSourceCodeDomains(requestId RequestID, domains } } -func (l *FunctionsListener) getSecrets(ctx context.Context, eaClient ExternalAdapterClient, requestID RequestID, subscriptionOwner common.Address, requestData *RequestData) (decryptedSecrets string, userError, internalError error) { +func (l *functionsListener) getSecrets(ctx context.Context, eaClient ExternalAdapterClient, requestID RequestID, subscriptionOwner common.Address, requestData *RequestData) (decryptedSecrets string, userError, internalError error) { if l.decryptor == nil { l.logger.Warn("Decryptor not configured") return "", nil, nil diff --git a/core/services/functions/listener_test.go b/core/services/functions/listener_test.go index ac2bc64184d..ecad9e4cceb 100644 --- a/core/services/functions/listener_test.go +++ b/core/services/functions/listener_test.go @@ -46,7 +46,7 @@ import ( ) type FunctionsListenerUniverse struct { - service *functions_service.FunctionsListener + service functions_service.FunctionsListener bridgeAccessor *functions_mocks.BridgeAccessor eaClient *functions_mocks.ExternalAdapterClient pluginORM *functions_mocks.ORM @@ -219,6 +219,25 @@ func TestFunctionsListener_HandleOffchainRequest_Invalid(t *testing.T) { require.Error(t, uni.service.HandleOffchainRequest(testutils.Context(t), request)) } +func TestFunctionsListener_HandleOffchainRequest_InternalError(t *testing.T) { + testutils.SkipShortDB(t) + t.Parallel() + uni := NewFunctionsListenerUniverse(t, 0, 1_000_000) + uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) + uni.bridgeAccessor.On("NewExternalAdapterClient").Return(uni.eaClient, nil) + uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil, nil, errors.New("error")) + uni.pluginORM.On("SetError", RequestID, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + request := &functions_service.OffchainRequest{ + RequestId: RequestID[:], + RequestInitiator: SubscriptionOwner.Bytes(), + SubscriptionId: uint64(SubscriptionID), + SubscriptionOwner: SubscriptionOwner.Bytes(), + Data: functions_service.RequestData{}, + } + require.Error(t, uni.service.HandleOffchainRequest(testutils.Context(t), request)) +} + func TestFunctionsListener_HandleOracleRequestV1_ComputationError(t *testing.T) { testutils.SkipShortDB(t) t.Parallel() diff --git a/core/services/functions/mocks/functions_listener.go b/core/services/functions/mocks/functions_listener.go new file mode 100644 index 00000000000..d2aeb2ddab8 --- /dev/null +++ b/core/services/functions/mocks/functions_listener.go @@ -0,0 +1,71 @@ +// Code generated by mockery v2.35.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + + functions "github.com/smartcontractkit/chainlink/v2/core/services/functions" + mock "github.com/stretchr/testify/mock" +) + +// FunctionsListener is an autogenerated mock type for the FunctionsListener type +type FunctionsListener struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *FunctionsListener) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// HandleOffchainRequest provides a mock function with given fields: ctx, request +func (_m *FunctionsListener) HandleOffchainRequest(ctx context.Context, request *functions.OffchainRequest) error { + ret := _m.Called(ctx, request) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *functions.OffchainRequest) error); ok { + r0 = rf(ctx, request) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Start provides a mock function with given fields: _a0 +func (_m *FunctionsListener) Start(_a0 context.Context) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewFunctionsListener creates a new instance of FunctionsListener. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewFunctionsListener(t interface { + mock.TestingT + Cleanup(func()) +}) *FunctionsListener { + mock := &FunctionsListener{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}