diff --git a/.changeset/real-numbers-taste.md b/.changeset/real-numbers-taste.md new file mode 100644 index 00000000000..d9f545444c2 --- /dev/null +++ b/.changeset/real-numbers-taste.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +core/services/functions: switch to sqlutil.DataStore #internal diff --git a/core/services/chainlink/application.go b/core/services/chainlink/application.go index edc613e25dd..6c373846205 100644 --- a/core/services/chainlink/application.go +++ b/core/services/chainlink/application.go @@ -437,6 +437,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { delegates[job.OffchainReporting2] = ocr2.NewDelegate( sqlxDB, + opts.DB, jobORM, bridgeORM, mercuryORM, diff --git a/core/services/functions/listener.go b/core/services/functions/listener.go index ff4e268573a..d2033ff74de 100644 --- a/core/services/functions/listener.go +++ b/core/services/functions/listener.go @@ -23,7 +23,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/threshold" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" evmrelayTypes "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" "github.com/smartcontractkit/chainlink/v2/core/services/s4" "github.com/smartcontractkit/chainlink/v2/core/services/synchronization/telem" @@ -270,7 +269,7 @@ func (l *functionsListener) setError(ctx context.Context, requestId RequestID, e promRequestComputationError.WithLabelValues(l.contractAddressHex).Inc() } readyForProcessing := errType != INTERNAL_ERROR - if err := l.pluginORM.SetError(requestId, errType, errBytes, time.Now(), readyForProcessing, pg.WithParentCtx(ctx)); err != nil { + if err := l.pluginORM.SetError(ctx, requestId, errType, errBytes, time.Now(), readyForProcessing); err != nil { l.logger.Errorw("call to SetError failed", "requestID", formatRequestId(requestId), "err", err) } } @@ -321,7 +320,7 @@ func (l *functionsListener) HandleOffchainRequest(ctx context.Context, request * CoordinatorContractAddress: &senderAddr, OnchainMetadata: []byte(OffchainRequestMarker), } - if err := l.pluginORM.CreateRequest(newReq, pg.WithParentCtx(ctx)); err != nil { + if err := l.pluginORM.CreateRequest(ctx, newReq); err != nil { if errors.Is(err, ErrDuplicateRequestID) { l.logger.Warnw("HandleOffchainRequest: received duplicate request ID", "requestID", formatRequestId(requestId), "err", err) } else { @@ -348,7 +347,7 @@ func (l *functionsListener) handleOracleRequestV1(request *evmrelayTypes.OracleR CoordinatorContractAddress: &request.CoordinatorContract, OnchainMetadata: request.OnchainMetadata, } - if err := l.pluginORM.CreateRequest(newReq, pg.WithParentCtx(ctx)); err != nil { + if err := l.pluginORM.CreateRequest(ctx, newReq); err != nil { if errors.Is(err, ErrDuplicateRequestID) { l.logger.Warnw("handleOracleRequestV1: received a log with duplicate request ID", "requestID", formatRequestId(request.RequestId), "err", err) } else { @@ -450,7 +449,7 @@ func (l *functionsListener) handleRequest(ctx context.Context, requestID Request promRequestComputationSuccess.WithLabelValues(l.contractAddressHex).Inc() promComputationResultSize.WithLabelValues(l.contractAddressHex).Set(float64(len(computationResult))) l.logger.Debugw("saving computation result", "requestID", requestIDStr) - if err2 := l.pluginORM.SetResult(requestID, computationResult, time.Now(), pg.WithParentCtx(ctx)); err2 != nil { + if err2 := l.pluginORM.SetResult(ctx, requestID, computationResult, time.Now()); err2 != nil { l.logger.Errorw("call to SetResult failed", "requestID", requestIDStr, "err", err2) return err2 } @@ -464,7 +463,7 @@ func (l *functionsListener) handleOracleResponseV1(response *evmrelayTypes.Oracl ctx, cancel := l.getNewHandlerContext() defer cancel() - if err := l.pluginORM.SetConfirmed(response.RequestId, pg.WithParentCtx(ctx)); err != nil { + if err := l.pluginORM.SetConfirmed(ctx, response.RequestId); err != nil { l.logger.Errorw("setting CONFIRMED state failed", "requestID", formatRequestId(response.RequestId), "err", err) } promRequestConfirmed.WithLabelValues(l.contractAddressHex).Inc() @@ -486,7 +485,7 @@ func (l *functionsListener) timeoutRequests() { case <-ticker.C: cutoff := time.Now().Add(-(time.Duration(timeoutSec) * time.Second)) ctx, cancel := l.getNewHandlerContext() - ids, err := l.pluginORM.TimeoutExpiredResults(cutoff, batchSize, pg.WithParentCtx(ctx)) + ids, err := l.pluginORM.TimeoutExpiredResults(ctx, cutoff, batchSize) cancel() if err != nil { l.logger.Errorw("error when calling FindExpiredResults", "err", err) @@ -531,7 +530,7 @@ func (l *functionsListener) pruneRequests() { case <-ticker.C: ctx, cancel := l.getNewHandlerContext() startTime := time.Now() - nTotal, nPruned, err := l.pluginORM.PruneOldestRequests(maxStoredRequests, batchSize, pg.WithParentCtx(ctx)) + nTotal, nPruned, err := l.pluginORM.PruneOldestRequests(ctx, maxStoredRequests, batchSize) cancel() elapsedMillis := time.Since(startTime).Milliseconds() if err != nil { diff --git a/core/services/functions/listener_test.go b/core/services/functions/listener_test.go index 090ced7c91d..d6cd9aa23d6 100644 --- a/core/services/functions/listener_test.go +++ b/core/services/functions/listener_test.go @@ -1,6 +1,7 @@ package functions_test import ( + "context" "encoding/json" "errors" "fmt" @@ -35,7 +36,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" threshold_mocks "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/threshold/mocks" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" @@ -172,7 +172,7 @@ func TestFunctionsListener_HandleOracleRequestV1_Success(t *testing.T) { uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(ResultBytes, nil, nil, nil) - uni.pluginORM.On("SetResult", RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetResult", mock.Anything, RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) @@ -189,7 +189,7 @@ func TestFunctionsListener_HandleOffchainRequest_Success(t *testing.T) { uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(ResultBytes, nil, nil, nil) - uni.pluginORM.On("SetResult", RequestID, ResultBytes, mock.Anything, mock.Anything).Return(nil) + uni.pluginORM.On("SetResult", mock.Anything, RequestID, ResultBytes, mock.Anything, mock.Anything).Return(nil) request := &functions_service.OffchainRequest{ RequestId: RequestID[:], @@ -233,7 +233,7 @@ func TestFunctionsListener_HandleOffchainRequest_InternalError(t *testing.T) { uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil, nil, errors.New("error")) - uni.pluginORM.On("SetError", RequestID, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + uni.pluginORM.On("SetError", mock.Anything, RequestID, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) request := &functions_service.OffchainRequest{ RequestId: RequestID[:], @@ -266,7 +266,7 @@ func TestFunctionsListener_HandleOracleRequestV1_ComputationError(t *testing.T) uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrorBytes, nil, nil) - uni.pluginORM.On("SetError", RequestID, mock.Anything, ErrorBytes, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetError", mock.Anything, RequestID, mock.Anything, ErrorBytes, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) @@ -307,7 +307,7 @@ func TestFunctionsListener_HandleOracleRequestV1_ThresholdDecryptedSecrets(t *te uni.eaClient.On("FetchEncryptedSecrets", mock.Anything, mock.Anything, RequestIDStr, mock.Anything, mock.Anything).Return(EncryptedSecrets, nil, nil) uni.decryptor.On("Decrypt", mock.Anything, decryptionPlugin.CiphertextId(RequestID[:]), EncryptedSecrets).Return(DecryptedSecrets, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(ResultBytes, nil, nil, nil) - uni.pluginORM.On("SetResult", RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetResult", mock.Anything, RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) @@ -333,7 +333,7 @@ func TestFunctionsListener_HandleOracleRequestV1_CBORTooBig(t *testing.T) { uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return([]types.OracleRequest{request}, nil, nil).Once() uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return(nil, nil, nil) uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) - uni.pluginORM.On("SetError", RequestID, functions_service.USER_ERROR, []byte("request too big (max 10 bytes)"), mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetError", mock.Anything, RequestID, functions_service.USER_ERROR, []byte("request too big (max 10 bytes)"), mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) @@ -361,7 +361,7 @@ func TestFunctionsListener_ReportSourceCodeDomains(t *testing.T) { uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(ResultBytes, nil, Domains, nil) - uni.pluginORM.On("SetResult", RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetResult", mock.Anything, RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) var sentMessage []byte @@ -388,7 +388,7 @@ func TestFunctionsListener_PruneRequests(t *testing.T) { uni := NewFunctionsListenerUniverse(t, 0, 1) doneCh := make(chan bool) uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return(nil, nil, nil) - uni.pluginORM.On("PruneOldestRequests", functions_service.DefaultPruneMaxStoredRequests, functions_service.DefaultPruneBatchSize, mock.Anything).Return(uint32(0), uint32(0), nil).Run(func(args mock.Arguments) { + uni.pluginORM.On("PruneOldestRequests", mock.Anything, functions_service.DefaultPruneMaxStoredRequests, functions_service.DefaultPruneBatchSize, mock.Anything).Return(uint32(0), uint32(0), nil).Run(func(args mock.Arguments) { doneCh <- true }) @@ -403,7 +403,7 @@ func TestFunctionsListener_TimeoutRequests(t *testing.T) { uni := NewFunctionsListenerUniverse(t, 1, 0) doneCh := make(chan bool) uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return(nil, nil, nil) - uni.pluginORM.On("TimeoutExpiredResults", mock.Anything, uint32(1), mock.Anything).Return([]functions_service.RequestID{}, nil).Run(func(args mock.Arguments) { + uni.pluginORM.On("TimeoutExpiredResults", mock.Anything, mock.Anything, uint32(1), mock.Anything).Return([]functions_service.RequestID{}, nil).Run(func(args mock.Arguments) { doneCh <- true }) @@ -423,9 +423,7 @@ func TestFunctionsListener_ORMDoesNotFreezeHandlersForever(t *testing.T) { uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return([]types.OracleRequest{request}, nil, nil).Once() uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return(nil, nil, nil) uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - var queryerWrapper pg.Q - args.Get(1).(pg.QOpt)(&queryerWrapper) - <-queryerWrapper.ParentCtx.Done() + <-args.Get(0).(context.Context).Done() ormCallExited.Done() }).Return(errors.New("timeout")) diff --git a/core/services/functions/mocks/orm.go b/core/services/functions/mocks/orm.go index 90055fe6286..ff72916171b 100644 --- a/core/services/functions/mocks/orm.go +++ b/core/services/functions/mocks/orm.go @@ -3,11 +3,11 @@ package mocks import ( + context "context" + functions "github.com/smartcontractkit/chainlink/v2/core/services/functions" mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" - time "time" ) @@ -16,24 +16,17 @@ type ORM struct { mock.Mock } -// CreateRequest provides a mock function with given fields: request, qopts -func (_m *ORM) CreateRequest(request *functions.Request, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, request) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateRequest provides a mock function with given fields: ctx, request +func (_m *ORM) CreateRequest(ctx context.Context, request *functions.Request) error { + ret := _m.Called(ctx, request) if len(ret) == 0 { panic("no return value specified for CreateRequest") } var r0 error - if rf, ok := ret.Get(0).(func(*functions.Request, ...pg.QOpt) error); ok { - r0 = rf(request, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *functions.Request) error); ok { + r0 = rf(ctx, request) } else { r0 = ret.Error(0) } @@ -41,16 +34,9 @@ func (_m *ORM) CreateRequest(request *functions.Request, qopts ...pg.QOpt) error return r0 } -// FindById provides a mock function with given fields: requestID, qopts -func (_m *ORM) FindById(requestID functions.RequestID, qopts ...pg.QOpt) (*functions.Request, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindById provides a mock function with given fields: ctx, requestID +func (_m *ORM) FindById(ctx context.Context, requestID functions.RequestID) (*functions.Request, error) { + ret := _m.Called(ctx, requestID) if len(ret) == 0 { panic("no return value specified for FindById") @@ -58,19 +44,19 @@ func (_m *ORM) FindById(requestID functions.RequestID, qopts ...pg.QOpt) (*funct var r0 *functions.Request var r1 error - if rf, ok := ret.Get(0).(func(functions.RequestID, ...pg.QOpt) (*functions.Request, error)); ok { - return rf(requestID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID) (*functions.Request, error)); ok { + return rf(ctx, requestID) } - if rf, ok := ret.Get(0).(func(functions.RequestID, ...pg.QOpt) *functions.Request); ok { - r0 = rf(requestID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID) *functions.Request); ok { + r0 = rf(ctx, requestID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*functions.Request) } } - if rf, ok := ret.Get(1).(func(functions.RequestID, ...pg.QOpt) error); ok { - r1 = rf(requestID, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, functions.RequestID) error); ok { + r1 = rf(ctx, requestID) } else { r1 = ret.Error(1) } @@ -78,16 +64,9 @@ func (_m *ORM) FindById(requestID functions.RequestID, qopts ...pg.QOpt) (*funct return r0, r1 } -// FindOldestEntriesByState provides a mock function with given fields: state, limit, qopts -func (_m *ORM) FindOldestEntriesByState(state functions.RequestState, limit uint32, qopts ...pg.QOpt) ([]functions.Request, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, state, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindOldestEntriesByState provides a mock function with given fields: ctx, state, limit +func (_m *ORM) FindOldestEntriesByState(ctx context.Context, state functions.RequestState, limit uint32) ([]functions.Request, error) { + ret := _m.Called(ctx, state, limit) if len(ret) == 0 { panic("no return value specified for FindOldestEntriesByState") @@ -95,19 +74,19 @@ func (_m *ORM) FindOldestEntriesByState(state functions.RequestState, limit uint var r0 []functions.Request var r1 error - if rf, ok := ret.Get(0).(func(functions.RequestState, uint32, ...pg.QOpt) ([]functions.Request, error)); ok { - return rf(state, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestState, uint32) ([]functions.Request, error)); ok { + return rf(ctx, state, limit) } - if rf, ok := ret.Get(0).(func(functions.RequestState, uint32, ...pg.QOpt) []functions.Request); ok { - r0 = rf(state, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestState, uint32) []functions.Request); ok { + r0 = rf(ctx, state, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]functions.Request) } } - if rf, ok := ret.Get(1).(func(functions.RequestState, uint32, ...pg.QOpt) error); ok { - r1 = rf(state, limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, functions.RequestState, uint32) error); ok { + r1 = rf(ctx, state, limit) } else { r1 = ret.Error(1) } @@ -115,16 +94,9 @@ func (_m *ORM) FindOldestEntriesByState(state functions.RequestState, limit uint return r0, r1 } -// PruneOldestRequests provides a mock function with given fields: maxRequestsInDB, batchSize, qopts -func (_m *ORM) PruneOldestRequests(maxRequestsInDB uint32, batchSize uint32, qopts ...pg.QOpt) (uint32, uint32, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, maxRequestsInDB, batchSize) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// PruneOldestRequests provides a mock function with given fields: ctx, maxRequestsInDB, batchSize +func (_m *ORM) PruneOldestRequests(ctx context.Context, maxRequestsInDB uint32, batchSize uint32) (uint32, uint32, error) { + ret := _m.Called(ctx, maxRequestsInDB, batchSize) if len(ret) == 0 { panic("no return value specified for PruneOldestRequests") @@ -133,23 +105,23 @@ func (_m *ORM) PruneOldestRequests(maxRequestsInDB uint32, batchSize uint32, qop var r0 uint32 var r1 uint32 var r2 error - if rf, ok := ret.Get(0).(func(uint32, uint32, ...pg.QOpt) (uint32, uint32, error)); ok { - return rf(maxRequestsInDB, batchSize, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint32, uint32) (uint32, uint32, error)); ok { + return rf(ctx, maxRequestsInDB, batchSize) } - if rf, ok := ret.Get(0).(func(uint32, uint32, ...pg.QOpt) uint32); ok { - r0 = rf(maxRequestsInDB, batchSize, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint32, uint32) uint32); ok { + r0 = rf(ctx, maxRequestsInDB, batchSize) } else { r0 = ret.Get(0).(uint32) } - if rf, ok := ret.Get(1).(func(uint32, uint32, ...pg.QOpt) uint32); ok { - r1 = rf(maxRequestsInDB, batchSize, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint32, uint32) uint32); ok { + r1 = rf(ctx, maxRequestsInDB, batchSize) } else { r1 = ret.Get(1).(uint32) } - if rf, ok := ret.Get(2).(func(uint32, uint32, ...pg.QOpt) error); ok { - r2 = rf(maxRequestsInDB, batchSize, qopts...) + if rf, ok := ret.Get(2).(func(context.Context, uint32, uint32) error); ok { + r2 = rf(ctx, maxRequestsInDB, batchSize) } else { r2 = ret.Error(2) } @@ -157,24 +129,17 @@ func (_m *ORM) PruneOldestRequests(maxRequestsInDB uint32, batchSize uint32, qop return r0, r1, r2 } -// SetConfirmed provides a mock function with given fields: requestID, qopts -func (_m *ORM) SetConfirmed(requestID functions.RequestID, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// SetConfirmed provides a mock function with given fields: ctx, requestID +func (_m *ORM) SetConfirmed(ctx context.Context, requestID functions.RequestID) error { + ret := _m.Called(ctx, requestID) if len(ret) == 0 { panic("no return value specified for SetConfirmed") } var r0 error - if rf, ok := ret.Get(0).(func(functions.RequestID, ...pg.QOpt) error); ok { - r0 = rf(requestID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID) error); ok { + r0 = rf(ctx, requestID) } else { r0 = ret.Error(0) } @@ -182,24 +147,17 @@ func (_m *ORM) SetConfirmed(requestID functions.RequestID, qopts ...pg.QOpt) err return r0 } -// SetError provides a mock function with given fields: requestID, errorType, computationError, readyAt, readyForProcessing, qopts -func (_m *ORM) SetError(requestID functions.RequestID, errorType functions.ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID, errorType, computationError, readyAt, readyForProcessing) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// SetError provides a mock function with given fields: ctx, requestID, errorType, computationError, readyAt, readyForProcessing +func (_m *ORM) SetError(ctx context.Context, requestID functions.RequestID, errorType functions.ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool) error { + ret := _m.Called(ctx, requestID, errorType, computationError, readyAt, readyForProcessing) if len(ret) == 0 { panic("no return value specified for SetError") } var r0 error - if rf, ok := ret.Get(0).(func(functions.RequestID, functions.ErrType, []byte, time.Time, bool, ...pg.QOpt) error); ok { - r0 = rf(requestID, errorType, computationError, readyAt, readyForProcessing, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID, functions.ErrType, []byte, time.Time, bool) error); ok { + r0 = rf(ctx, requestID, errorType, computationError, readyAt, readyForProcessing) } else { r0 = ret.Error(0) } @@ -207,24 +165,17 @@ func (_m *ORM) SetError(requestID functions.RequestID, errorType functions.ErrTy return r0 } -// SetFinalized provides a mock function with given fields: requestID, reportedResult, reportedError, qopts -func (_m *ORM) SetFinalized(requestID functions.RequestID, reportedResult []byte, reportedError []byte, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID, reportedResult, reportedError) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// SetFinalized provides a mock function with given fields: ctx, requestID, reportedResult, reportedError +func (_m *ORM) SetFinalized(ctx context.Context, requestID functions.RequestID, reportedResult []byte, reportedError []byte) error { + ret := _m.Called(ctx, requestID, reportedResult, reportedError) if len(ret) == 0 { panic("no return value specified for SetFinalized") } var r0 error - if rf, ok := ret.Get(0).(func(functions.RequestID, []byte, []byte, ...pg.QOpt) error); ok { - r0 = rf(requestID, reportedResult, reportedError, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID, []byte, []byte) error); ok { + r0 = rf(ctx, requestID, reportedResult, reportedError) } else { r0 = ret.Error(0) } @@ -232,24 +183,17 @@ func (_m *ORM) SetFinalized(requestID functions.RequestID, reportedResult []byte return r0 } -// SetResult provides a mock function with given fields: requestID, computationResult, readyAt, qopts -func (_m *ORM) SetResult(requestID functions.RequestID, computationResult []byte, readyAt time.Time, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID, computationResult, readyAt) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// SetResult provides a mock function with given fields: ctx, requestID, computationResult, readyAt +func (_m *ORM) SetResult(ctx context.Context, requestID functions.RequestID, computationResult []byte, readyAt time.Time) error { + ret := _m.Called(ctx, requestID, computationResult, readyAt) if len(ret) == 0 { panic("no return value specified for SetResult") } var r0 error - if rf, ok := ret.Get(0).(func(functions.RequestID, []byte, time.Time, ...pg.QOpt) error); ok { - r0 = rf(requestID, computationResult, readyAt, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID, []byte, time.Time) error); ok { + r0 = rf(ctx, requestID, computationResult, readyAt) } else { r0 = ret.Error(0) } @@ -257,16 +201,9 @@ func (_m *ORM) SetResult(requestID functions.RequestID, computationResult []byte return r0 } -// TimeoutExpiredResults provides a mock function with given fields: cutoff, limit, qopts -func (_m *ORM) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg.QOpt) ([]functions.RequestID, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, cutoff, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// TimeoutExpiredResults provides a mock function with given fields: ctx, cutoff, limit +func (_m *ORM) TimeoutExpiredResults(ctx context.Context, cutoff time.Time, limit uint32) ([]functions.RequestID, error) { + ret := _m.Called(ctx, cutoff, limit) if len(ret) == 0 { panic("no return value specified for TimeoutExpiredResults") @@ -274,19 +211,19 @@ func (_m *ORM) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg var r0 []functions.RequestID var r1 error - if rf, ok := ret.Get(0).(func(time.Time, uint32, ...pg.QOpt) ([]functions.RequestID, error)); ok { - return rf(cutoff, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, time.Time, uint32) ([]functions.RequestID, error)); ok { + return rf(ctx, cutoff, limit) } - if rf, ok := ret.Get(0).(func(time.Time, uint32, ...pg.QOpt) []functions.RequestID); ok { - r0 = rf(cutoff, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, time.Time, uint32) []functions.RequestID); ok { + r0 = rf(ctx, cutoff, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]functions.RequestID) } } - if rf, ok := ret.Get(1).(func(time.Time, uint32, ...pg.QOpt) error); ok { - r1 = rf(cutoff, limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, time.Time, uint32) error); ok { + r1 = rf(ctx, cutoff, limit) } else { r1 = ret.Error(1) } diff --git a/core/services/functions/orm.go b/core/services/functions/orm.go index 7838c700858..f45effa9354 100644 --- a/core/services/functions/orm.go +++ b/core/services/functions/orm.go @@ -1,38 +1,37 @@ package functions import ( + "context" "fmt" "time" "github.com/ethereum/go-ethereum/common" - "github.com/pkg/errors" - "github.com/jmoiron/sqlx" + "github.com/pkg/errors" - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) //go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { - CreateRequest(request *Request, qopts ...pg.QOpt) error + CreateRequest(ctx context.Context, request *Request) error - SetResult(requestID RequestID, computationResult []byte, readyAt time.Time, qopts ...pg.QOpt) error - SetError(requestID RequestID, errorType ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool, qopts ...pg.QOpt) error - SetFinalized(requestID RequestID, reportedResult []byte, reportedError []byte, qopts ...pg.QOpt) error - SetConfirmed(requestID RequestID, qopts ...pg.QOpt) error + SetResult(ctx context.Context, requestID RequestID, computationResult []byte, readyAt time.Time) error + SetError(ctx context.Context, requestID RequestID, errorType ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool) error + SetFinalized(ctx context.Context, requestID RequestID, reportedResult []byte, reportedError []byte) error + SetConfirmed(ctx context.Context, requestID RequestID) error - TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg.QOpt) ([]RequestID, error) + TimeoutExpiredResults(ctx context.Context, cutoff time.Time, limit uint32) ([]RequestID, error) - FindOldestEntriesByState(state RequestState, limit uint32, qopts ...pg.QOpt) ([]Request, error) - FindById(requestID RequestID, qopts ...pg.QOpt) (*Request, error) + FindOldestEntriesByState(ctx context.Context, state RequestState, limit uint32) ([]Request, error) + FindById(ctx context.Context, requestID RequestID) (*Request, error) - PruneOldestRequests(maxRequestsInDB uint32, batchSize uint32, qopts ...pg.QOpt) (total uint32, pruned uint32, err error) + PruneOldestRequests(ctx context.Context, maxRequestsInDB uint32, batchSize uint32) (total uint32, pruned uint32, err error) } type orm struct { - q pg.Q + ds sqlutil.DataSource contractAddress common.Address } @@ -49,19 +48,20 @@ const ( "callback_gas_limit, coordinator_contract_address, onchain_metadata, processing_metadata" ) -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, contractAddress common.Address) ORM { +func NewORM(ds sqlutil.DataSource, contractAddress common.Address) ORM { return &orm{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, contractAddress: contractAddress, } } -func (o *orm) CreateRequest(request *Request, qopts ...pg.QOpt) error { +func (o *orm) CreateRequest(ctx context.Context, request *Request) error { stmt := fmt.Sprintf(` INSERT INTO %s (request_id, contract_address, received_at, request_tx_hash, state, flags, aggregation_method, callback_gas_limit, coordinator_contract_address, onchain_metadata) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) ON CONFLICT (request_id) DO NOTHING; `, tableName) - result, err := o.q.WithOpts(qopts...).Exec( + result, err := o.ds.ExecContext( + ctx, stmt, request.RequestID, o.contractAddress, @@ -86,11 +86,11 @@ func (o *orm) CreateRequest(request *Request, qopts ...pg.QOpt) error { return nil } -func (o *orm) setWithStateTransitionCheck(requestID RequestID, newState RequestState, setter func(pg.Queryer) error, qopts ...pg.QOpt) error { - err := o.q.WithOpts(qopts...).Transaction(func(tx pg.Queryer) error { +func (o *orm) setWithStateTransitionCheck(ctx context.Context, requestID RequestID, newState RequestState, setter func(sqlutil.DataSource) error) error { + err := sqlutil.TransactDataSource(ctx, o.ds, nil, func(tx sqlutil.DataSource) error { prevState := defaultInitialState stmt := fmt.Sprintf(`SELECT state FROM %s WHERE request_id=$1 AND contract_address=$2;`, tableName) - if err2 := tx.Get(&prevState, stmt, requestID, o.contractAddress); err2 != nil { + if err2 := tx.GetContext(ctx, &prevState, stmt, requestID, o.contractAddress); err2 != nil { return err2 } if err2 := CheckStateTransition(prevState, newState); err2 != nil { @@ -102,64 +102,64 @@ func (o *orm) setWithStateTransitionCheck(requestID RequestID, newState RequestS return err } -func (o *orm) SetResult(requestID RequestID, computationResult []byte, readyAt time.Time, qopts ...pg.QOpt) error { +func (o *orm) SetResult(ctx context.Context, requestID RequestID, computationResult []byte, readyAt time.Time) error { newState := RESULT_READY - err := o.setWithStateTransitionCheck(requestID, newState, func(tx pg.Queryer) error { + err := o.setWithStateTransitionCheck(ctx, requestID, newState, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(` UPDATE %s SET result=$3, result_ready_at=$4, state=$5 WHERE request_id=$1 AND contract_address=$2; `, tableName) - _, err2 := tx.Exec(stmt, requestID, o.contractAddress, computationResult, readyAt, newState) + _, err2 := tx.ExecContext(ctx, stmt, requestID, o.contractAddress, computationResult, readyAt, newState) return err2 - }, qopts...) + }) return err } -func (o *orm) SetError(requestID RequestID, errorType ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool, qopts ...pg.QOpt) error { +func (o *orm) SetError(ctx context.Context, requestID RequestID, errorType ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool) error { var newState RequestState if readyForProcessing { newState = RESULT_READY } else { newState = IN_PROGRESS } - err := o.setWithStateTransitionCheck(requestID, newState, func(tx pg.Queryer) error { + err := o.setWithStateTransitionCheck(ctx, requestID, newState, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(` UPDATE %s SET error=$3, error_type=$4, result_ready_at=$5, state=$6 WHERE request_id=$1 AND contract_address=$2; `, tableName) - _, err2 := tx.Exec(stmt, requestID, o.contractAddress, computationError, errorType, readyAt, newState) + _, err2 := tx.ExecContext(ctx, stmt, requestID, o.contractAddress, computationError, errorType, readyAt, newState) return err2 - }, qopts...) + }) return err } -func (o *orm) SetFinalized(requestID RequestID, reportedResult []byte, reportedError []byte, qopts ...pg.QOpt) error { +func (o *orm) SetFinalized(ctx context.Context, requestID RequestID, reportedResult []byte, reportedError []byte) error { newState := FINALIZED - err := o.setWithStateTransitionCheck(requestID, newState, func(tx pg.Queryer) error { + err := o.setWithStateTransitionCheck(ctx, requestID, newState, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(` UPDATE %s SET transmitted_result=$3, transmitted_error=$4, state=$5 WHERE request_id=$1 AND contract_address=$2; `, tableName) - _, err2 := tx.Exec(stmt, requestID, o.contractAddress, reportedResult, reportedError, newState) + _, err2 := tx.ExecContext(ctx, stmt, requestID, o.contractAddress, reportedResult, reportedError, newState) return err2 - }, qopts...) + }) return err } -func (o *orm) SetConfirmed(requestID RequestID, qopts ...pg.QOpt) error { +func (o *orm) SetConfirmed(ctx context.Context, requestID RequestID) error { newState := CONFIRMED - err := o.setWithStateTransitionCheck(requestID, newState, func(tx pg.Queryer) error { + err := o.setWithStateTransitionCheck(ctx, requestID, newState, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(`UPDATE %s SET state=$3 WHERE request_id=$1 AND contract_address=$2;`, tableName) - _, err2 := tx.Exec(stmt, requestID, o.contractAddress, newState) + _, err2 := tx.ExecContext(ctx, stmt, requestID, o.contractAddress, newState) return err2 - }, qopts...) + }) return err } -func (o *orm) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg.QOpt) ([]RequestID, error) { +func (o *orm) TimeoutExpiredResults(ctx context.Context, cutoff time.Time, limit uint32) ([]RequestID, error) { var ids []RequestID allowedPrevStates := []RequestState{IN_PROGRESS, RESULT_READY, FINALIZED} nextState := TIMED_OUT @@ -169,14 +169,14 @@ func (o *orm) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg. return ids, err } } - err := o.q.WithOpts(qopts...).Transaction(func(tx pg.Queryer) error { + err := sqlutil.TransactDataSource(ctx, o.ds, nil, func(tx sqlutil.DataSource) error { selectStmt := fmt.Sprintf(` SELECT request_id FROM %s WHERE (state=$1 OR state=$2 OR state=$3) AND contract_address=$4 AND received_at < ($5) ORDER BY received_at LIMIT $6;`, tableName) - if err2 := tx.Select(&ids, selectStmt, allowedPrevStates[0], allowedPrevStates[1], allowedPrevStates[2], o.contractAddress, cutoff, limit); err2 != nil { + if err2 := tx.SelectContext(ctx, &ids, selectStmt, allowedPrevStates[0], allowedPrevStates[1], allowedPrevStates[2], o.contractAddress, cutoff, limit); err2 != nil { return err2 } if len(ids) == 0 { @@ -200,7 +200,7 @@ func (o *orm) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg. return err2 } updateStmt = tx.Rebind(updateStmt) - if _, err2 := tx.Exec(updateStmt, args...); err2 != nil { + if _, err2 := tx.ExecContext(ctx, updateStmt, args...); err2 != nil { return err2 } return nil @@ -209,28 +209,28 @@ func (o *orm) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg. return ids, err } -func (o *orm) FindOldestEntriesByState(state RequestState, limit uint32, qopts ...pg.QOpt) ([]Request, error) { +func (o *orm) FindOldestEntriesByState(ctx context.Context, state RequestState, limit uint32) ([]Request, error) { var requests []Request stmt := fmt.Sprintf(`SELECT %s FROM %s WHERE state=$1 AND contract_address=$2 ORDER BY received_at LIMIT $3;`, requestFields, tableName) - if err := o.q.WithOpts(qopts...).Select(&requests, stmt, state, o.contractAddress, limit); err != nil { + if err := o.ds.SelectContext(ctx, &requests, stmt, state, o.contractAddress, limit); err != nil { return nil, err } return requests, nil } -func (o *orm) FindById(requestID RequestID, qopts ...pg.QOpt) (*Request, error) { +func (o *orm) FindById(ctx context.Context, requestID RequestID) (*Request, error) { var request Request stmt := fmt.Sprintf(`SELECT %s FROM %s WHERE request_id=$1 AND contract_address=$2;`, requestFields, tableName) - if err := o.q.WithOpts(qopts...).Get(&request, stmt, requestID, o.contractAddress); err != nil { + if err := o.ds.GetContext(ctx, &request, stmt, requestID, o.contractAddress); err != nil { return nil, err } return &request, nil } -func (o *orm) PruneOldestRequests(maxStoredRequests uint32, batchSize uint32, qopts ...pg.QOpt) (total uint32, pruned uint32, err error) { - err = o.q.WithOpts(qopts...).Transaction(func(tx pg.Queryer) error { +func (o *orm) PruneOldestRequests(ctx context.Context, maxStoredRequests uint32, batchSize uint32) (total uint32, pruned uint32, err error) { + err = sqlutil.TransactDataSource(ctx, o.ds, nil, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE contract_address=$1`, tableName) - if err2 := tx.Get(&total, stmt, o.contractAddress); err2 != nil { + if err2 := tx.GetContext(ctx, &total, stmt, o.contractAddress); err2 != nil { return errors.Wrap(err, "failed to get request count") } @@ -246,7 +246,7 @@ func (o *orm) PruneOldestRequests(maxStoredRequests uint32, batchSize uint32, qo with := fmt.Sprintf(`WITH ids AS (SELECT request_id FROM %s WHERE contract_address = $1 ORDER BY received_at LIMIT $2)`, tableName) deleteStmt := fmt.Sprintf(`%s DELETE FROM %s WHERE contract_address = $1 AND request_id IN (SELECT request_id FROM ids);`, with, tableName) - res, err2 := tx.Exec(deleteStmt, o.contractAddress, pruneLimit) + res, err2 := tx.ExecContext(ctx, deleteStmt, o.contractAddress, pruneLimit) if err2 != nil { return err2 } diff --git a/core/services/functions/orm_test.go b/core/services/functions/orm_test.go index ca92aafcb0e..37b3a28256f 100644 --- a/core/services/functions/orm_test.go +++ b/core/services/functions/orm_test.go @@ -11,7 +11,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" - "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/functions" ) @@ -28,9 +27,8 @@ func setupORM(t *testing.T) functions.ORM { var ( db = pgtest.NewSqlxDB(t) - lggr = logger.TestLogger(t) contract = testutils.NewAddress() - orm = functions.NewORM(db, lggr, pgtest.NewQConfig(true), contract) + orm = functions.NewORM(db, contract) ) return orm @@ -47,6 +45,7 @@ func createRequest(t *testing.T, orm functions.ORM) (functions.RequestID, common } func createRequestWithTimestamp(t *testing.T, orm functions.ORM, ts time.Time) (functions.RequestID, common.Hash) { + ctx := testutils.Context(t) id := newRequestID() txHash := utils.RandomHash() newReq := &functions.Request{ @@ -59,19 +58,20 @@ func createRequestWithTimestamp(t *testing.T, orm functions.ORM, ts time.Time) ( CoordinatorContractAddress: &defaultCoordinatorContract, OnchainMetadata: defaultMetadata, } - err := orm.CreateRequest(newReq) + err := orm.CreateRequest(ctx, newReq) require.NoError(t, err) return id, txHash } func TestORM_CreateRequestsAndFindByID(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id1, txHash1, ts1 := createRequest(t, orm) id2, txHash2, ts2 := createRequest(t, orm) - req1, err := orm.FindById(id1) + req1, err := orm.FindById(ctx, id1) require.NoError(t, err) require.Equal(t, id1, req1.RequestID) require.Equal(t, &txHash1, req1.RequestTxHash) @@ -83,7 +83,7 @@ func TestORM_CreateRequestsAndFindByID(t *testing.T) { require.Equal(t, defaultCoordinatorContract, *req1.CoordinatorContractAddress) require.Equal(t, defaultMetadata, req1.OnchainMetadata) - req2, err := orm.FindById(id2) + req2, err := orm.FindById(ctx, id2) require.NoError(t, err) require.Equal(t, id2, req2.RequestID) require.Equal(t, &txHash2, req2.RequestTxHash) @@ -91,14 +91,14 @@ func TestORM_CreateRequestsAndFindByID(t *testing.T) { require.Equal(t, functions.IN_PROGRESS, req2.State) t.Run("missing ID", func(t *testing.T) { - req, err := orm.FindById(newRequestID()) + req, err := orm.FindById(testutils.Context(t), newRequestID()) require.Error(t, err) require.Nil(t, req) }) t.Run("duplicated", func(t *testing.T) { newReq := &functions.Request{RequestID: id1, RequestTxHash: &txHash1, ReceivedAt: ts1} - err := orm.CreateRequest(newReq) + err := orm.CreateRequest(testutils.Context(t), newReq) require.Error(t, err) require.True(t, errors.Is(err, functions.ErrDuplicateRequestID)) }) @@ -106,15 +106,16 @@ func TestORM_CreateRequestsAndFindByID(t *testing.T) { func TestORM_SetResult(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, ts := createRequest(t, orm) rdts := time.Now().Round(time.Second) - err := orm.SetResult(id, []byte("result"), rdts) + err := orm.SetResult(ctx, id, []byte("result"), rdts) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, id, req.RequestID) require.Equal(t, ts, req.ReceivedAt) @@ -126,15 +127,16 @@ func TestORM_SetResult(t *testing.T) { func TestORM_SetError(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, ts := createRequest(t, orm) rdts := time.Now().Round(time.Second) - err := orm.SetError(id, functions.USER_ERROR, []byte("error"), rdts, true) + err := orm.SetError(ctx, id, functions.USER_ERROR, []byte("error"), rdts, true) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, id, req.RequestID) require.Equal(t, ts, req.ReceivedAt) @@ -148,15 +150,16 @@ func TestORM_SetError(t *testing.T) { func TestORM_SetError_Internal(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, ts := createRequest(t, orm) rdts := time.Now().Round(time.Second) - err := orm.SetError(id, functions.INTERNAL_ERROR, []byte("error"), rdts, false) + err := orm.SetError(ctx, id, functions.INTERNAL_ERROR, []byte("error"), rdts, false) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, id, req.RequestID) require.Equal(t, ts, req.ReceivedAt) @@ -167,14 +170,15 @@ func TestORM_SetError_Internal(t *testing.T) { func TestORM_SetFinalized(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, _ := createRequest(t, orm) - err := orm.SetFinalized(id, []byte("result"), []byte("error")) + err := orm.SetFinalized(ctx, id, []byte("result"), []byte("error")) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, []byte("result"), req.TransmittedResult) require.Equal(t, []byte("error"), req.TransmittedError) @@ -183,49 +187,51 @@ func TestORM_SetFinalized(t *testing.T) { func TestORM_SetConfirmed(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, _ := createRequest(t, orm) - err := orm.SetConfirmed(id) + err := orm.SetConfirmed(ctx, id) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.CONFIRMED, req.State) } func TestORM_StateTransitions(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) now := time.Now() id, _ := createRequestWithTimestamp(t, orm, now) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.IN_PROGRESS, req.State) - err = orm.SetResult(id, []byte{}, now) + err = orm.SetResult(ctx, id, []byte{}, now) require.NoError(t, err) - req, err = orm.FindById(id) + req, err = orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.RESULT_READY, req.State) - _, err = orm.TimeoutExpiredResults(now.Add(time.Minute), 1) + _, err = orm.TimeoutExpiredResults(ctx, now.Add(time.Minute), 1) require.NoError(t, err) - req, err = orm.FindById(id) + req, err = orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.TIMED_OUT, req.State) - err = orm.SetFinalized(id, nil, nil) + err = orm.SetFinalized(ctx, id, nil, nil) require.Error(t, err) - req, err = orm.FindById(id) + req, err = orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.TIMED_OUT, req.State) - err = orm.SetConfirmed(id) + err = orm.SetConfirmed(ctx, id) require.NoError(t, err) - req, err = orm.FindById(id) + req, err = orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.CONFIRMED, req.State) } @@ -240,7 +246,8 @@ func TestORM_FindOldestEntriesByState(t *testing.T) { id1, _ := createRequestWithTimestamp(t, orm, now.Add(1*time.Minute)) t.Run("with limit", func(t *testing.T) { - result, err := orm.FindOldestEntriesByState(functions.IN_PROGRESS, 2) + ctx := testutils.Context(t) + result, err := orm.FindOldestEntriesByState(ctx, functions.IN_PROGRESS, 2) require.NoError(t, err) require.Equal(t, 2, len(result), "incorrect results length") require.Equal(t, id1, result[0].RequestID, "incorrect results order") @@ -255,13 +262,15 @@ func TestORM_FindOldestEntriesByState(t *testing.T) { }) t.Run("with no limit", func(t *testing.T) { - result, err := orm.FindOldestEntriesByState(functions.IN_PROGRESS, 20) + ctx := testutils.Context(t) + result, err := orm.FindOldestEntriesByState(ctx, functions.IN_PROGRESS, 20) require.NoError(t, err) require.Equal(t, 3, len(result), "incorrect results length") }) t.Run("no matching entries", func(t *testing.T) { - result, err := orm.FindOldestEntriesByState(functions.RESULT_READY, 10) + ctx := testutils.Context(t) + result, err := orm.FindOldestEntriesByState(ctx, functions.RESULT_READY, 10) require.NoError(t, err) require.Equal(t, 0, len(result), "incorrect results length") }) @@ -269,6 +278,7 @@ func TestORM_FindOldestEntriesByState(t *testing.T) { func TestORM_TimeoutExpiredResults(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) now := time.Now() @@ -278,26 +288,26 @@ func TestORM_TimeoutExpiredResults(t *testing.T) { ids = append(ids, id) } // can time out IN_PROGRESS, RESULT_READY or FINALIZED - err := orm.SetResult(ids[0], []byte("result"), now) + err := orm.SetResult(ctx, ids[0], []byte("result"), now) require.NoError(t, err) - err = orm.SetFinalized(ids[1], []byte("result"), []byte("")) + err = orm.SetFinalized(ctx, ids[1], []byte("result"), []byte("")) require.NoError(t, err) // can't time out CONFIRMED - err = orm.SetConfirmed(ids[2]) + err = orm.SetConfirmed(ctx, ids[2]) require.NoError(t, err) - results, err := orm.TimeoutExpiredResults(now.Add(-35*time.Minute), 1) + results, err := orm.TimeoutExpiredResults(ctx, now.Add(-35*time.Minute), 1) require.NoError(t, err) require.Equal(t, 1, len(results), "not respecting limit") require.Equal(t, ids[0], results[0], "incorrect results order") - results, err = orm.TimeoutExpiredResults(now.Add(-15*time.Minute), 10) + results, err = orm.TimeoutExpiredResults(ctx, now.Add(-15*time.Minute), 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, ids[1], results[0], "incorrect results order") require.Equal(t, ids[3], results[1], "incorrect results order") - results, err = orm.TimeoutExpiredResults(now.Add(-15*time.Minute), 10) + results, err = orm.TimeoutExpiredResults(ctx, now.Add(-15*time.Minute), 10) require.NoError(t, err) require.Equal(t, 0, len(results), "not idempotent") @@ -309,7 +319,7 @@ func TestORM_TimeoutExpiredResults(t *testing.T) { functions.IN_PROGRESS, } for i, expectedState := range expectedFinalStates { - req, err := orm.FindById(ids[i]) + req, err := orm.FindById(ctx, ids[i]) require.NoError(t, err) require.Equal(t, req.State, expectedState, "incorrect state") } @@ -317,6 +327,7 @@ func TestORM_TimeoutExpiredResults(t *testing.T) { func TestORM_PruneOldestRequests(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) now := time.Now() @@ -328,31 +339,31 @@ func TestORM_PruneOldestRequests(t *testing.T) { } // don't prune if max not hit - total, pruned, err := orm.PruneOldestRequests(6, 3) + total, pruned, err := orm.PruneOldestRequests(ctx, 6, 3) require.NoError(t, err) require.Equal(t, uint32(5), total) require.Equal(t, uint32(0), pruned) // prune up to max batch size - total, pruned, err = orm.PruneOldestRequests(1, 2) + total, pruned, err = orm.PruneOldestRequests(ctx, 1, 2) require.NoError(t, err) require.Equal(t, uint32(5), total) require.Equal(t, uint32(2), pruned) // prune all above the limit - total, pruned, err = orm.PruneOldestRequests(1, 20) + total, pruned, err = orm.PruneOldestRequests(ctx, 1, 20) require.NoError(t, err) require.Equal(t, uint32(3), total) require.Equal(t, uint32(2), pruned) // no pruning needed any more - total, pruned, err = orm.PruneOldestRequests(1, 20) + total, pruned, err = orm.PruneOldestRequests(ctx, 1, 20) require.NoError(t, err) require.Equal(t, uint32(1), total) require.Equal(t, uint32(0), pruned) // verify only the newest one is left after pruning - result, err := orm.FindOldestEntriesByState(functions.IN_PROGRESS, 20) + result, err := orm.FindOldestEntriesByState(ctx, functions.IN_PROGRESS, 20) require.NoError(t, err) require.Equal(t, 1, len(result), "incorrect results length") require.Equal(t, ids[4], result[0].RequestID, "incorrect results order") @@ -360,6 +371,7 @@ func TestORM_PruneOldestRequests(t *testing.T) { func TestORM_PruneOldestRequests_Large(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) now := time.Now() @@ -369,13 +381,13 @@ func TestORM_PruneOldestRequests_Large(t *testing.T) { } // prune 900/1000 - total, pruned, err := orm.PruneOldestRequests(100, 1000) + total, pruned, err := orm.PruneOldestRequests(ctx, 100, 1000) require.NoError(t, err) require.Equal(t, uint32(1000), total) require.Equal(t, uint32(900), pruned) // verify there's 100 left - result, err := orm.FindOldestEntriesByState(functions.IN_PROGRESS, 200) + result, err := orm.FindOldestEntriesByState(ctx, functions.IN_PROGRESS, 200) require.NoError(t, err) require.Equal(t, 100, len(result), "incorrect results length") } diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist.go b/core/services/gateway/handlers/functions/allowlist/allowlist.go index 020de2359c2..f0fe5c8c829 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist.go @@ -128,7 +128,7 @@ func (a *onchainAllowlist) Start(ctx context.Context) error { return nil } - a.loadStoredAllowedSenderList() + a.loadStoredAllowedSenderList(ctx) updateOnce := func() { timeoutCtx, cancel := utils.ContextFromChanWithTimeout(a.stopCh, time.Duration(a.config.UpdateTimeoutSec)*time.Second) @@ -245,12 +245,12 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b return errors.Wrap(err, "error calling GetAllAllowedSenders") } - err = a.orm.PurgeAllowedSenders() + err = a.orm.PurgeAllowedSenders(ctx) if err != nil { a.lggr.Errorf("failed to purge allowedSenderList: %w", err) } - err = a.orm.CreateAllowedSenders(allowedSenderList) + err = a.orm.CreateAllowedSenders(ctx, allowedSenderList) if err != nil { a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) } @@ -290,7 +290,7 @@ func (a *onchainAllowlist) getAllowedSendersBatched(ctx context.Context, tosCont } allowedSenderList = append(allowedSenderList, allowedSendersBatch...) - err = a.orm.CreateAllowedSenders(allowedSendersBatch) + err = a.orm.CreateAllowedSenders(ctx, allowedSendersBatch) if err != nil { a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) } @@ -330,7 +330,7 @@ func (a *onchainAllowlist) syncBlockedSenders(ctx context.Context, tosContract * return errors.Wrap(err, "error calling GetAllowedSendersInRange") } - err = a.orm.DeleteAllowedSenders(blockedSendersBatch) + err = a.orm.DeleteAllowedSenders(ctx, blockedSendersBatch) if err != nil { a.lggr.Errorf("failed to delete blocked address from allowed list in storage: %w", err) } @@ -349,11 +349,11 @@ func (a *onchainAllowlist) update(addrList []common.Address) { a.lggr.Infow("allowlist updated successfully", "len", len(addrList)) } -func (a *onchainAllowlist) loadStoredAllowedSenderList() { +func (a *onchainAllowlist) loadStoredAllowedSenderList(ctx context.Context) { allowedList := make([]common.Address, 0) offset := uint(0) for { - asBatch, err := a.orm.GetAllowedSenders(offset, a.config.StoredAllowlistBatchSize) + asBatch, err := a.orm.GetAllowedSenders(ctx, offset, a.config.StoredAllowlistBatchSize) if err != nil { a.lggr.Errorf("failed to get stored allowed senders: %w", err) break diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go index 735c0bff7dc..d4900627bdb 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go @@ -58,8 +58,8 @@ func TestUpdateAndCheck(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("PurgeAllowedSenders").Times(1).Return(nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("PurgeAllowedSenders", mock.Anything).Times(1).Return(nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -99,8 +99,8 @@ func TestUpdateAndCheck(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("DeleteAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -163,9 +163,9 @@ func TestUpdatePeriodically(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("PurgeAllowedSenders").Times(1).Return(nil) - orm.On("GetAllowedSenders", uint(0), uint(1000)).Return([]common.Address{}, nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("PurgeAllowedSenders", mock.Anything).Times(1).Return(nil) + orm.On("GetAllowedSenders", mock.Anything, uint(0), uint(1000)).Return([]common.Address{}, nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -207,9 +207,9 @@ func TestUpdatePeriodically(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) - orm.On("GetAllowedSenders", uint(0), uint(1000)).Return([]common.Address{}, nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("DeleteAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("GetAllowedSenders", mock.Anything, uint(0), uint(1000)).Return([]common.Address{}, nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -258,8 +258,8 @@ func TestUpdateFromContract(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("PurgeAllowedSenders").Times(1).Return(nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("PurgeAllowedSenders", mock.Anything).Times(1).Return(nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -301,8 +301,8 @@ func TestUpdateFromContract(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) + orm.On("DeleteAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) diff --git a/core/services/gateway/handlers/functions/allowlist/mocks/orm.go b/core/services/gateway/handlers/functions/allowlist/mocks/orm.go index daff33d8902..76121270518 100644 --- a/core/services/gateway/handlers/functions/allowlist/mocks/orm.go +++ b/core/services/gateway/handlers/functions/allowlist/mocks/orm.go @@ -3,10 +3,11 @@ package mocks import ( + context "context" + common "github.com/ethereum/go-ethereum/common" - mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + mock "github.com/stretchr/testify/mock" ) // ORM is an autogenerated mock type for the ORM type @@ -14,24 +15,17 @@ type ORM struct { mock.Mock } -// CreateAllowedSenders provides a mock function with given fields: allowedSenders, qopts -func (_m *ORM) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, allowedSenders) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateAllowedSenders provides a mock function with given fields: ctx, allowedSenders +func (_m *ORM) CreateAllowedSenders(ctx context.Context, allowedSenders []common.Address) error { + ret := _m.Called(ctx, allowedSenders) if len(ret) == 0 { panic("no return value specified for CreateAllowedSenders") } var r0 error - if rf, ok := ret.Get(0).(func([]common.Address, ...pg.QOpt) error); ok { - r0 = rf(allowedSenders, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []common.Address) error); ok { + r0 = rf(ctx, allowedSenders) } else { r0 = ret.Error(0) } @@ -39,24 +33,17 @@ func (_m *ORM) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg return r0 } -// DeleteAllowedSenders provides a mock function with given fields: blockedSenders, qopts -func (_m *ORM) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, blockedSenders) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// DeleteAllowedSenders provides a mock function with given fields: ctx, blockedSenders +func (_m *ORM) DeleteAllowedSenders(ctx context.Context, blockedSenders []common.Address) error { + ret := _m.Called(ctx, blockedSenders) if len(ret) == 0 { panic("no return value specified for DeleteAllowedSenders") } var r0 error - if rf, ok := ret.Get(0).(func([]common.Address, ...pg.QOpt) error); ok { - r0 = rf(blockedSenders, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []common.Address) error); ok { + r0 = rf(ctx, blockedSenders) } else { r0 = ret.Error(0) } @@ -64,16 +51,9 @@ func (_m *ORM) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg return r0 } -// GetAllowedSenders provides a mock function with given fields: offset, limit, qopts -func (_m *ORM) GetAllowedSenders(offset uint, limit uint, qopts ...pg.QOpt) ([]common.Address, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, offset, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetAllowedSenders provides a mock function with given fields: ctx, offset, limit +func (_m *ORM) GetAllowedSenders(ctx context.Context, offset uint, limit uint) ([]common.Address, error) { + ret := _m.Called(ctx, offset, limit) if len(ret) == 0 { panic("no return value specified for GetAllowedSenders") @@ -81,19 +61,19 @@ func (_m *ORM) GetAllowedSenders(offset uint, limit uint, qopts ...pg.QOpt) ([]c var r0 []common.Address var r1 error - if rf, ok := ret.Get(0).(func(uint, uint, ...pg.QOpt) ([]common.Address, error)); ok { - return rf(offset, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, uint) ([]common.Address, error)); ok { + return rf(ctx, offset, limit) } - if rf, ok := ret.Get(0).(func(uint, uint, ...pg.QOpt) []common.Address); ok { - r0 = rf(offset, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, uint) []common.Address); ok { + r0 = rf(ctx, offset, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]common.Address) } } - if rf, ok := ret.Get(1).(func(uint, uint, ...pg.QOpt) error); ok { - r1 = rf(offset, limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint, uint) error); ok { + r1 = rf(ctx, offset, limit) } else { r1 = ret.Error(1) } @@ -101,23 +81,17 @@ func (_m *ORM) GetAllowedSenders(offset uint, limit uint, qopts ...pg.QOpt) ([]c return r0, r1 } -// PurgeAllowedSenders provides a mock function with given fields: qopts -func (_m *ORM) PurgeAllowedSenders(qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// PurgeAllowedSenders provides a mock function with given fields: ctx +func (_m *ORM) PurgeAllowedSenders(ctx context.Context) error { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for PurgeAllowedSenders") } var r0 error - if rf, ok := ret.Get(0).(func(...pg.QOpt) error); ok { - r0 = rf(qopts...) + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) } else { r0 = ret.Error(0) } diff --git a/core/services/gateway/handlers/functions/allowlist/orm.go b/core/services/gateway/handlers/functions/allowlist/orm.go index ccacec81a43..7867c06d5d4 100644 --- a/core/services/gateway/handlers/functions/allowlist/orm.go +++ b/core/services/gateway/handlers/functions/allowlist/orm.go @@ -1,28 +1,27 @@ package allowlist import ( + "context" "fmt" "strings" "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) //go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { - GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common.Address, error) - CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.QOpt) error - DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error - PurgeAllowedSenders(qopts ...pg.QOpt) error + GetAllowedSenders(ctx context.Context, offset, limit uint) ([]common.Address, error) + CreateAllowedSenders(ctx context.Context, allowedSenders []common.Address) error + DeleteAllowedSenders(ctx context.Context, blockedSenders []common.Address) error + PurgeAllowedSenders(ctx context.Context) error } type orm struct { - q pg.Q + ds sqlutil.DataSource lggr logger.Logger routerContractAddress common.Address } @@ -36,19 +35,19 @@ const ( tableName = "functions_allowlist" ) -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, routerContractAddress common.Address) (ORM, error) { - if db == nil || cfg == nil || lggr == nil || routerContractAddress == (common.Address{}) { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger, routerContractAddress common.Address) (ORM, error) { + if ds == nil || lggr == nil || routerContractAddress == (common.Address{}) { return nil, ErrInvalidParameters } return &orm{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, lggr: lggr, routerContractAddress: routerContractAddress, }, nil } -func (o *orm) GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common.Address, error) { +func (o *orm) GetAllowedSenders(ctx context.Context, offset, limit uint) ([]common.Address, error) { var addresses []common.Address stmt := fmt.Sprintf(` SELECT allowed_address @@ -58,7 +57,7 @@ func (o *orm) GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common. OFFSET $2 LIMIT $3; `, tableName) - err := o.q.WithOpts(qopts...).Select(&addresses, stmt, o.routerContractAddress, offset, limit) + err := o.ds.SelectContext(ctx, &addresses, stmt, o.routerContractAddress, offset, limit) if err != nil { return addresses, err } @@ -67,7 +66,7 @@ func (o *orm) GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common. return addresses, nil } -func (o *orm) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.QOpt) error { +func (o *orm) CreateAllowedSenders(ctx context.Context, allowedSenders []common.Address) error { var valuesPlaceholder []string for i := 1; i <= len(allowedSenders)*2; i += 2 { valuesPlaceholder = append(valuesPlaceholder, fmt.Sprintf("($%d, $%d)", i, i+1)) @@ -82,7 +81,7 @@ func (o *orm) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg. args = append(args, as, o.routerContractAddress) } - _, err := o.q.WithOpts(qopts...).Exec(stmt, args...) + _, err := o.ds.ExecContext(ctx, stmt, args...) if err != nil { return err } @@ -94,7 +93,7 @@ func (o *orm) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg. // DeleteAllowedSenders is used to remove blocked senders from the functions_allowlist table. // This is achieved by specifying a list of blockedSenders to remove. -func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error { +func (o *orm) DeleteAllowedSenders(ctx context.Context, blockedSenders []common.Address) error { var valuesPlaceholder []string for i := 1; i <= len(blockedSenders); i++ { valuesPlaceholder = append(valuesPlaceholder, fmt.Sprintf("$%d", i+1)) @@ -110,7 +109,7 @@ func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg. args = append(args, bs) } - res, err := o.q.WithOpts(qopts...).Exec(stmt, args...) + res, err := o.ds.ExecContext(ctx, stmt, args...) if err != nil { return err } @@ -126,12 +125,12 @@ func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg. } // PurgeAllowedSenders will remove all the allowed senders for the configured orm routerContractAddress -func (o *orm) PurgeAllowedSenders(qopts ...pg.QOpt) error { +func (o *orm) PurgeAllowedSenders(ctx context.Context) error { stmt := fmt.Sprintf(` DELETE FROM %s WHERE router_contract_address = $1;`, tableName) - res, err := o.q.WithOpts(qopts...).Exec(stmt, o.routerContractAddress) + res, err := o.ds.ExecContext(ctx, stmt, o.routerContractAddress) if err != nil { return err } diff --git a/core/services/gateway/handlers/functions/allowlist/orm_test.go b/core/services/gateway/handlers/functions/allowlist/orm_test.go index 1d357616fab..2584e131968 100644 --- a/core/services/gateway/handlers/functions/allowlist/orm_test.go +++ b/core/services/gateway/handlers/functions/allowlist/orm_test.go @@ -20,17 +20,18 @@ func setupORM(t *testing.T) (allowlist.ORM, error) { lggr = logger.TestLogger(t) ) - return allowlist.NewORM(db, lggr, pgtest.NewQConfig(true), testutils.NewAddress()) + return allowlist.NewORM(db, lggr, testutils.NewAddress()) } func seedAllowedSenders(t *testing.T, orm allowlist.ORM, amount int) []common.Address { + ctx := testutils.Context(t) storedAllowedSenders := make([]common.Address, amount) for i := 0; i < amount; i++ { address := testutils.NewAddress() storedAllowedSenders[i] = address } - err := orm.CreateAllowedSenders(storedAllowedSenders) + err := orm.CreateAllowedSenders(ctx, storedAllowedSenders) require.NoError(t, err) return storedAllowedSenders @@ -38,20 +39,22 @@ func seedAllowedSenders(t *testing.T, orm allowlist.ORM, amount int) []common.Ad func TestORM_GetAllowedSenders(t *testing.T) { t.Parallel() t.Run("fetch first page", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) storedAllowedSenders := seedAllowedSenders(t, orm, 2) - results, err := orm.GetAllowedSenders(0, 1) + results, err := orm.GetAllowedSenders(ctx, 0, 1) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, storedAllowedSenders[0], results[0]) }) t.Run("fetch second page", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) storedAllowedSenders := seedAllowedSenders(t, orm, 2) - results, err := orm.GetAllowedSenders(1, 5) + results, err := orm.GetAllowedSenders(ctx, 1, 5) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, storedAllowedSenders[1], results[0]) @@ -62,42 +65,45 @@ func TestORM_CreateAllowedSenders(t *testing.T) { t.Parallel() t.Run("OK-create_an_allowed_sender", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) expected := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{expected}) + err = orm.CreateAllowedSenders(ctx, []common.Address{expected}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 1) + results, err := orm.GetAllowedSenders(ctx, 0, 1) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, expected, results[0]) }) t.Run("OK-create_an_existing_allowed_sender", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) expected := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{expected}) + err = orm.CreateAllowedSenders(ctx, []common.Address{expected}) require.NoError(t, err) - err = orm.CreateAllowedSenders([]common.Address{expected}) + err = orm.CreateAllowedSenders(ctx, []common.Address{expected}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 5) + results, err := orm.GetAllowedSenders(ctx, 0, 5) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, expected, results[0]) }) t.Run("OK-create_multiple_allowed_senders_in_one_query", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) expected := []common.Address{testutils.NewAddress(), testutils.NewAddress()} - err = orm.CreateAllowedSenders(expected) + err = orm.CreateAllowedSenders(ctx, expected) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 2) + results, err := orm.GetAllowedSenders(ctx, 0, 2) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, expected[0], results[0]) @@ -105,6 +111,7 @@ func TestORM_CreateAllowedSenders(t *testing.T) { }) t.Run("OK-create_multiple_allowed_senders_with_duplicates", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) addr1 := testutils.NewAddress() @@ -112,10 +119,10 @@ func TestORM_CreateAllowedSenders(t *testing.T) { expected := []common.Address{addr1, addr2} duplicatedAddressInput := []common.Address{addr1, addr1, addr1, addr2} - err = orm.CreateAllowedSenders(duplicatedAddressInput) + err = orm.CreateAllowedSenders(ctx, duplicatedAddressInput) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 10) + results, err := orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, expected[0], results[0]) @@ -127,46 +134,48 @@ func TestORM_DeleteAllowedSenders(t *testing.T) { t.Parallel() t.Run("OK-delete_blocked_sender_from_allowed_list", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) add1 := testutils.NewAddress() add2 := testutils.NewAddress() add3 := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{add1, add2, add3}) + err = orm.CreateAllowedSenders(ctx, []common.Address{add1, add2, add3}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 10) + results, err := orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 3, len(results), "incorrect results length") require.Equal(t, add1, results[0]) - err = orm.DeleteAllowedSenders([]common.Address{add1, add3}) + err = orm.DeleteAllowedSenders(ctx, []common.Address{add1, add3}) require.NoError(t, err) - results, err = orm.GetAllowedSenders(0, 10) + results, err = orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, add2, results[0]) }) t.Run("OK-delete_non_existing_blocked_sender_from_allowed_list", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) add1 := testutils.NewAddress() add2 := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{add1, add2}) + err = orm.CreateAllowedSenders(ctx, []common.Address{add1, add2}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 10) + results, err := orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add1, results[0]) add3 := testutils.NewAddress() - err = orm.DeleteAllowedSenders([]common.Address{add3}) + err = orm.DeleteAllowedSenders(ctx, []common.Address{add3}) require.NoError(t, err) - results, err = orm.GetAllowedSenders(0, 10) + results, err = orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add1, results[0]) @@ -178,36 +187,38 @@ func TestORM_PurgeAllowedSenders(t *testing.T) { t.Parallel() t.Run("OK-purge_allowed_list", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) add1 := testutils.NewAddress() add2 := testutils.NewAddress() add3 := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{add1, add2, add3}) + err = orm.CreateAllowedSenders(ctx, []common.Address{add1, add2, add3}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 10) + results, err := orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 3, len(results), "incorrect results length") require.Equal(t, add1, results[0]) - err = orm.PurgeAllowedSenders() + err = orm.PurgeAllowedSenders(ctx) require.NoError(t, err) - results, err = orm.GetAllowedSenders(0, 10) + results, err = orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 0, len(results), "incorrect results length") }) t.Run("OK-purge_allowed_list_for_contract_address", func(t *testing.T) { + ctx := testutils.Context(t) orm1, err := setupORM(t) require.NoError(t, err) add1 := testutils.NewAddress() add2 := testutils.NewAddress() - err = orm1.CreateAllowedSenders([]common.Address{add1, add2}) + err = orm1.CreateAllowedSenders(ctx, []common.Address{add1, add2}) require.NoError(t, err) - results, err := orm1.GetAllowedSenders(0, 10) + results, err := orm1.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add1, results[0]) @@ -216,22 +227,22 @@ func TestORM_PurgeAllowedSenders(t *testing.T) { require.NoError(t, err) add3 := testutils.NewAddress() add4 := testutils.NewAddress() - err = orm2.CreateAllowedSenders([]common.Address{add3, add4}) + err = orm2.CreateAllowedSenders(ctx, []common.Address{add3, add4}) require.NoError(t, err) - results, err = orm2.GetAllowedSenders(0, 10) + results, err = orm2.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add3, results[0]) - err = orm2.PurgeAllowedSenders() + err = orm2.PurgeAllowedSenders(ctx) require.NoError(t, err) - results, err = orm2.GetAllowedSenders(0, 10) + results, err = orm2.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 0, len(results), "incorrect results length") - results, err = orm1.GetAllowedSenders(0, 10) + results, err = orm1.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add1, results[0]) @@ -241,15 +252,15 @@ func TestORM_PurgeAllowedSenders(t *testing.T) { func Test_NewORM(t *testing.T) { t.Run("OK-create_ORM", func(t *testing.T) { - _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), testutils.NewAddress()) + _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), testutils.NewAddress()) require.NoError(t, err) }) t.Run("NOK-create_ORM_with_nil_fields", func(t *testing.T) { - _, err := allowlist.NewORM(nil, nil, nil, common.Address{}) + _, err := allowlist.NewORM(nil, nil, common.Address{}) require.Error(t, err) }) t.Run("NOK-create_ORM_with_empty_address", func(t *testing.T) { - _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), common.Address{}) + _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), common.Address{}) require.Error(t, err) }) } diff --git a/core/services/gateway/handlers/functions/handler.functions.go b/core/services/gateway/handlers/functions/handler.functions.go index ff272e4e577..692534db598 100644 --- a/core/services/gateway/handlers/functions/handler.functions.go +++ b/core/services/gateway/handlers/functions/handler.functions.go @@ -114,7 +114,7 @@ func NewFunctionsHandlerFromConfig(handlerConfig json.RawMessage, donConfig *con return nil, err2 } - orm, err2 := fallow.NewORM(db, lggr, qcfg, cfg.OnchainAllowlist.ContractAddress) + orm, err2 := fallow.NewORM(db, lggr, cfg.OnchainAllowlist.ContractAddress) if err2 != nil { return nil, err2 } @@ -143,7 +143,7 @@ func NewFunctionsHandlerFromConfig(handlerConfig json.RawMessage, donConfig *con return nil, err2 } - orm, err2 := fsub.NewORM(db, lggr, qcfg, cfg.OnchainSubscriptions.ContractAddress) + orm, err2 := fsub.NewORM(db, lggr, cfg.OnchainSubscriptions.ContractAddress) if err2 != nil { return nil, err2 } diff --git a/core/services/gateway/handlers/functions/subscriptions/mocks/orm.go b/core/services/gateway/handlers/functions/subscriptions/mocks/orm.go index 0f278aa49b0..16a82a488b4 100644 --- a/core/services/gateway/handlers/functions/subscriptions/mocks/orm.go +++ b/core/services/gateway/handlers/functions/subscriptions/mocks/orm.go @@ -3,8 +3,9 @@ package mocks import ( + context "context" + subscriptions "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/subscriptions" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" mock "github.com/stretchr/testify/mock" ) @@ -13,16 +14,9 @@ type ORM struct { mock.Mock } -// GetSubscriptions provides a mock function with given fields: offset, limit, qopts -func (_m *ORM) GetSubscriptions(offset uint, limit uint, qopts ...pg.QOpt) ([]subscriptions.StoredSubscription, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, offset, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetSubscriptions provides a mock function with given fields: ctx, offset, limit +func (_m *ORM) GetSubscriptions(ctx context.Context, offset uint, limit uint) ([]subscriptions.StoredSubscription, error) { + ret := _m.Called(ctx, offset, limit) if len(ret) == 0 { panic("no return value specified for GetSubscriptions") @@ -30,19 +24,19 @@ func (_m *ORM) GetSubscriptions(offset uint, limit uint, qopts ...pg.QOpt) ([]su var r0 []subscriptions.StoredSubscription var r1 error - if rf, ok := ret.Get(0).(func(uint, uint, ...pg.QOpt) ([]subscriptions.StoredSubscription, error)); ok { - return rf(offset, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, uint) ([]subscriptions.StoredSubscription, error)); ok { + return rf(ctx, offset, limit) } - if rf, ok := ret.Get(0).(func(uint, uint, ...pg.QOpt) []subscriptions.StoredSubscription); ok { - r0 = rf(offset, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, uint) []subscriptions.StoredSubscription); ok { + r0 = rf(ctx, offset, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]subscriptions.StoredSubscription) } } - if rf, ok := ret.Get(1).(func(uint, uint, ...pg.QOpt) error); ok { - r1 = rf(offset, limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint, uint) error); ok { + r1 = rf(ctx, offset, limit) } else { r1 = ret.Error(1) } @@ -50,24 +44,17 @@ func (_m *ORM) GetSubscriptions(offset uint, limit uint, qopts ...pg.QOpt) ([]su return r0, r1 } -// UpsertSubscription provides a mock function with given fields: subscription, qopts -func (_m *ORM) UpsertSubscription(subscription subscriptions.StoredSubscription, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, subscription) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// UpsertSubscription provides a mock function with given fields: ctx, subscription +func (_m *ORM) UpsertSubscription(ctx context.Context, subscription subscriptions.StoredSubscription) error { + ret := _m.Called(ctx, subscription) if len(ret) == 0 { panic("no return value specified for UpsertSubscription") } var r0 error - if rf, ok := ret.Get(0).(func(subscriptions.StoredSubscription, ...pg.QOpt) error); ok { - r0 = rf(subscription, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, subscriptions.StoredSubscription) error); ok { + r0 = rf(ctx, subscription) } else { r0 = ret.Error(0) } diff --git a/core/services/gateway/handlers/functions/subscriptions/orm.go b/core/services/gateway/handlers/functions/subscriptions/orm.go index 369291ace54..d97437a39dc 100644 --- a/core/services/gateway/handlers/functions/subscriptions/orm.go +++ b/core/services/gateway/handlers/functions/subscriptions/orm.go @@ -1,6 +1,7 @@ package subscriptions import ( + "context" "fmt" "math/big" @@ -8,21 +9,19 @@ import ( "github.com/lib/pq" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/functions/generated/functions_router" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) //go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { - GetSubscriptions(offset, limit uint, qopts ...pg.QOpt) ([]StoredSubscription, error) - UpsertSubscription(subscription StoredSubscription, qopts ...pg.QOpt) error + GetSubscriptions(ctx context.Context, offset, limit uint) ([]StoredSubscription, error) + UpsertSubscription(ctx context.Context, subscription StoredSubscription) error } type orm struct { - q pg.Q + ds sqlutil.DataSource lggr logger.Logger routerContractAddress common.Address } @@ -47,19 +46,19 @@ type storedSubscriptionRow struct { RouterContractAddress common.Address } -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, routerContractAddress common.Address) (ORM, error) { - if db == nil || cfg == nil || lggr == nil || routerContractAddress == (common.Address{}) { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger, routerContractAddress common.Address) (ORM, error) { + if ds == nil || lggr == nil || routerContractAddress == (common.Address{}) { return nil, ErrInvalidParameters } return &orm{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, lggr: lggr, routerContractAddress: routerContractAddress, }, nil } -func (o *orm) GetSubscriptions(offset, limit uint, qopts ...pg.QOpt) ([]StoredSubscription, error) { +func (o *orm) GetSubscriptions(ctx context.Context, offset, limit uint) ([]StoredSubscription, error) { var storedSubscriptions []StoredSubscription var storedSubscriptionRows []storedSubscriptionRow stmt := fmt.Sprintf(` @@ -70,7 +69,7 @@ func (o *orm) GetSubscriptions(offset, limit uint, qopts ...pg.QOpt) ([]StoredSu OFFSET $2 LIMIT $3; `, tableName) - err := o.q.WithOpts(qopts...).Select(&storedSubscriptionRows, stmt, o.routerContractAddress, offset, limit) + err := o.ds.SelectContext(ctx, &storedSubscriptionRows, stmt, o.routerContractAddress, offset, limit) if err != nil { return storedSubscriptions, err } @@ -84,7 +83,7 @@ func (o *orm) GetSubscriptions(offset, limit uint, qopts ...pg.QOpt) ([]StoredSu // UpsertSubscription will update if a subscription exists or create if it does not. // In case a subscription gets deleted we will update it with an owner address equal to 0x0. -func (o *orm) UpsertSubscription(subscription StoredSubscription, qopts ...pg.QOpt) error { +func (o *orm) UpsertSubscription(ctx context.Context, subscription StoredSubscription) error { stmt := fmt.Sprintf(` INSERT INTO %s (subscription_id, owner, balance, blocked_balance, proposed_owner, consumers, flags, router_contract_address) VALUES ($1,$2,$3,$4,$5,$6,$7,$8) ON CONFLICT (subscription_id, router_contract_address) DO UPDATE @@ -103,7 +102,8 @@ func (o *orm) UpsertSubscription(subscription StoredSubscription, qopts ...pg.QO consumers = append(consumers, c.Bytes()) } - _, err := o.q.WithOpts(qopts...).Exec( + _, err := o.ds.ExecContext( + ctx, stmt, subscription.SubscriptionID, subscription.Owner, diff --git a/core/services/gateway/handlers/functions/subscriptions/orm_test.go b/core/services/gateway/handlers/functions/subscriptions/orm_test.go index 6cb1146f03c..f75ab0b98c1 100644 --- a/core/services/gateway/handlers/functions/subscriptions/orm_test.go +++ b/core/services/gateway/handlers/functions/subscriptions/orm_test.go @@ -27,10 +27,11 @@ func setupORM(t *testing.T) (subscriptions.ORM, error) { lggr = logger.TestLogger(t) ) - return subscriptions.NewORM(db, lggr, pgtest.NewQConfig(true), testutils.NewAddress()) + return subscriptions.NewORM(db, lggr, testutils.NewAddress()) } func seedSubscriptions(t *testing.T, orm subscriptions.ORM, amount int) []subscriptions.StoredSubscription { + ctx := testutils.Context(t) storedSubscriptions := make([]subscriptions.StoredSubscription, 0) for i := amount; i > 0; i-- { cs := subscriptions.StoredSubscription{ @@ -45,7 +46,7 @@ func seedSubscriptions(t *testing.T, orm subscriptions.ORM, amount int) []subscr }, } storedSubscriptions = append(storedSubscriptions, cs) - err := orm.UpsertSubscription(cs) + err := orm.UpsertSubscription(ctx, cs) require.NoError(t, err) } return storedSubscriptions @@ -54,20 +55,22 @@ func seedSubscriptions(t *testing.T, orm subscriptions.ORM, amount int) []subscr func TestORM_GetSubscriptions(t *testing.T) { t.Parallel() t.Run("fetch first page", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) storedSubscriptions := seedSubscriptions(t, orm, 2) - results, err := orm.GetSubscriptions(0, 1) + results, err := orm.GetSubscriptions(ctx, 0, 1) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, storedSubscriptions[1], results[0]) }) t.Run("fetch second page", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) storedSubscriptions := seedSubscriptions(t, orm, 2) - results, err := orm.GetSubscriptions(1, 5) + results, err := orm.GetSubscriptions(ctx, 1, 5) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, storedSubscriptions[0], results[0]) @@ -78,6 +81,7 @@ func TestORM_UpsertSubscription(t *testing.T) { t.Parallel() t.Run("create a subscription", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) expected := subscriptions.StoredSubscription{ @@ -91,16 +95,17 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: defaultFlags, }, } - err = orm.UpsertSubscription(expected) + err = orm.UpsertSubscription(ctx, expected) require.NoError(t, err) - results, err := orm.GetSubscriptions(0, 1) + results, err := orm.GetSubscriptions(ctx, 0, 1) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, expected, results[0]) }) t.Run("update a subscription", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) @@ -115,7 +120,7 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: defaultFlags, }, } - err = orm.UpsertSubscription(expectedUpdated) + err = orm.UpsertSubscription(ctx, expectedUpdated) require.NoError(t, err) expectedNotUpdated := subscriptions.StoredSubscription{ @@ -129,15 +134,15 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: defaultFlags, }, } - err = orm.UpsertSubscription(expectedNotUpdated) + err = orm.UpsertSubscription(ctx, expectedNotUpdated) require.NoError(t, err) // update the balance value expectedUpdated.Balance = big.NewInt(20) - err = orm.UpsertSubscription(expectedUpdated) + err = orm.UpsertSubscription(ctx, expectedUpdated) require.NoError(t, err) - results, err := orm.GetSubscriptions(0, 5) + results, err := orm.GetSubscriptions(ctx, 0, 5) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, expectedNotUpdated, results[1]) @@ -145,6 +150,7 @@ func TestORM_UpsertSubscription(t *testing.T) { }) t.Run("update a deleted subscription", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) @@ -159,7 +165,7 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: defaultFlags, }, } - err = orm.UpsertSubscription(subscription) + err = orm.UpsertSubscription(ctx, subscription) require.NoError(t, err) // empty subscription @@ -172,24 +178,25 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: [32]byte{}, } - err = orm.UpsertSubscription(subscription) + err = orm.UpsertSubscription(ctx, subscription) require.NoError(t, err) - results, err := orm.GetSubscriptions(0, 5) + results, err := orm.GetSubscriptions(ctx, 0, 5) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, subscription, results[0]) }) t.Run("create a subscription with same id but different router address", func(t *testing.T) { + ctx := testutils.Context(t) var ( db = pgtest.NewSqlxDB(t) lggr = logger.TestLogger(t) ) - orm1, err := subscriptions.NewORM(db, lggr, pgtest.NewQConfig(true), testutils.NewAddress()) + orm1, err := subscriptions.NewORM(db, lggr, testutils.NewAddress()) require.NoError(t, err) - orm2, err := subscriptions.NewORM(db, lggr, pgtest.NewQConfig(true), testutils.NewAddress()) + orm2, err := subscriptions.NewORM(db, lggr, testutils.NewAddress()) require.NoError(t, err) subscription := subscriptions.StoredSubscription{ @@ -204,42 +211,42 @@ func TestORM_UpsertSubscription(t *testing.T) { }, } - err = orm1.UpsertSubscription(subscription) + err = orm1.UpsertSubscription(ctx, subscription) require.NoError(t, err) // should update the existing subscription subscription.Balance = assets.Ether(12).ToInt() - err = orm1.UpsertSubscription(subscription) + err = orm1.UpsertSubscription(ctx, subscription) require.NoError(t, err) - results, err := orm1.GetSubscriptions(0, 10) + results, err := orm1.GetSubscriptions(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") // should create a new subscription because it comes from a different router contract - err = orm2.UpsertSubscription(subscription) + err = orm2.UpsertSubscription(ctx, subscription) require.NoError(t, err) - results, err = orm1.GetSubscriptions(0, 10) + results, err = orm1.GetSubscriptions(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") - results, err = orm2.GetSubscriptions(0, 10) + results, err = orm2.GetSubscriptions(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") }) } func Test_NewORM(t *testing.T) { t.Run("OK-create_ORM", func(t *testing.T) { - _, err := subscriptions.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), testutils.NewAddress()) + _, err := subscriptions.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), testutils.NewAddress()) require.NoError(t, err) }) t.Run("NOK-create_ORM_with_nil_fields", func(t *testing.T) { - _, err := subscriptions.NewORM(nil, nil, nil, common.Address{}) + _, err := subscriptions.NewORM(nil, nil, common.Address{}) require.Error(t, err) }) t.Run("NOK-create_ORM_with_empty_address", func(t *testing.T) { - _, err := subscriptions.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), common.Address{}) + _, err := subscriptions.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), common.Address{}) require.Error(t, err) }) } diff --git a/core/services/gateway/handlers/functions/subscriptions/subscriptions.go b/core/services/gateway/handlers/functions/subscriptions/subscriptions.go index e90201a31a9..d481ecf12ed 100644 --- a/core/services/gateway/handlers/functions/subscriptions/subscriptions.go +++ b/core/services/gateway/handlers/functions/subscriptions/subscriptions.go @@ -99,7 +99,7 @@ func (s *onchainSubscriptions) Start(ctx context.Context) error { return errors.New("OnchainSubscriptionsConfig.UpdateRangeSize must be greater than 0") } - s.loadStoredSubscriptions() + s.loadStoredSubscriptions(ctx) s.closeWait.Add(1) go s.queryLoop() @@ -206,7 +206,7 @@ func (s *onchainSubscriptions) querySubscriptionsRange(ctx context.Context, bloc subscription := subscription updated := s.subscriptions.UpdateSubscription(subscriptionId, &subscription) if updated { - if err = s.orm.UpsertSubscription(StoredSubscription{ + if err = s.orm.UpsertSubscription(ctx, StoredSubscription{ SubscriptionID: subscriptionId, IFunctionsSubscriptionsSubscription: subscription, }); err != nil { @@ -226,10 +226,10 @@ func (s *onchainSubscriptions) getSubscriptionsCount(ctx context.Context, blockN }) } -func (s *onchainSubscriptions) loadStoredSubscriptions() { +func (s *onchainSubscriptions) loadStoredSubscriptions(ctx context.Context) { offset := uint(0) for { - csBatch, err := s.orm.GetSubscriptions(offset, s.config.StoreBatchSize) + csBatch, err := s.orm.GetSubscriptions(ctx, offset, s.config.StoreBatchSize) if err != nil { break } diff --git a/core/services/gateway/handlers/functions/subscriptions/subscriptions_test.go b/core/services/gateway/handlers/functions/subscriptions/subscriptions_test.go index be1d2520434..212029b73f7 100644 --- a/core/services/gateway/handlers/functions/subscriptions/subscriptions_test.go +++ b/core/services/gateway/handlers/functions/subscriptions/subscriptions_test.go @@ -51,8 +51,8 @@ func TestSubscriptions_OnePass(t *testing.T) { UpdateRangeSize: 3, } orm := smocks.NewORM(t) - orm.On("GetSubscriptions", uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil) - orm.On("UpsertSubscription", mock.Anything).Return(nil) + orm.On("GetSubscriptions", mock.Anything, uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil) + orm.On("UpsertSubscription", mock.Anything, mock.Anything).Return(nil) subscriptions, err := subscriptions.NewOnchainSubscriptions(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -102,8 +102,8 @@ func TestSubscriptions_MultiPass(t *testing.T) { UpdateRangeSize: 3, } orm := smocks.NewORM(t) - orm.On("GetSubscriptions", uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil) - orm.On("UpsertSubscription", mock.Anything).Return(nil) + orm.On("GetSubscriptions", mock.Anything, uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil) + orm.On("UpsertSubscription", mock.Anything, mock.Anything).Return(nil) subscriptions, err := subscriptions.NewOnchainSubscriptions(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -144,7 +144,7 @@ func TestSubscriptions_Stored(t *testing.T) { expectedBalance := big.NewInt(5) orm := smocks.NewORM(t) - orm.On("GetSubscriptions", uint(0), uint(1)).Return([]subscriptions.StoredSubscription{ + orm.On("GetSubscriptions", mock.Anything, uint(0), uint(1)).Return([]subscriptions.StoredSubscription{ { SubscriptionID: 1, IFunctionsSubscriptionsSubscription: functions_router.IFunctionsSubscriptionsSubscription{ @@ -154,8 +154,8 @@ func TestSubscriptions_Stored(t *testing.T) { }, }, }, nil) - orm.On("GetSubscriptions", uint(1), uint(1)).Return([]subscriptions.StoredSubscription{}, nil) - orm.On("UpsertSubscription", mock.Anything).Return(nil) + orm.On("GetSubscriptions", mock.Anything, uint(1), uint(1)).Return([]subscriptions.StoredSubscription{}, nil) + orm.On("UpsertSubscription", mock.Anything, mock.Anything).Return(nil) subscriptions, err := subscriptions.NewOnchainSubscriptions(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) diff --git a/core/services/job/spawner_test.go b/core/services/job/spawner_test.go index 3ac32309775..d2e7a80d5d4 100644 --- a/core/services/job/spawner_test.go +++ b/core/services/job/spawner_test.go @@ -306,7 +306,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { processConfig := plugins.NewRegistrarConfig(loop.GRPCOpts{}, func(name string) (*plugins.RegisteredLoop, error) { return nil, nil }, func(loopId string) {}) ocr2DelegateConfig := ocr2.NewDelegateConfig(config.OCR2(), config.Mercury(), config.Threshold(), config.Insecure(), config.JobPipeline(), config.Database(), processConfig) - d := ocr2.NewDelegate(nil, orm, nil, nil, nil, nil, nil, monitoringEndpoint, legacyChains, lggr, ocr2DelegateConfig, + d := ocr2.NewDelegate(nil, nil, orm, nil, nil, nil, nil, nil, monitoringEndpoint, legacyChains, lggr, ocr2DelegateConfig, keyStore.OCR2(), keyStore.DKGSign(), keyStore.DKGEncrypt(), ethKeyStore, testRelayGetter, mailMon, capabilities.NewRegistry(lggr)) delegateOCR2 := &delegate{jobOCR2VRF.Type, []job.ServiceCtx{}, 0, nil, d} diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index 7b4200efd68..a00ed195903 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -27,6 +27,7 @@ import ( ocr2keepers21config "github.com/smartcontractkit/chainlink-automation/pkg/v3/config" ocr2keepers21 "github.com/smartcontractkit/chainlink-automation/pkg/v3/plugin" "github.com/smartcontractkit/chainlink-common/pkg/loop/reportingplugins/ocr3" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/config/env" "github.com/smartcontractkit/chainlink-vrf/altbn_128" @@ -109,7 +110,8 @@ type RelayGetter interface { Get(id relay.ID) (loop.Relayer, error) } type Delegate struct { - db *sqlx.DB + db *sqlx.DB // legacy: prefer to use ds instead + ds sqlutil.DataSource jobORM job.ORM bridgeORM bridges.ORM mercuryORM evmmercury.ORM @@ -223,6 +225,7 @@ var _ job.Delegate = (*Delegate)(nil) func NewDelegate( db *sqlx.DB, + ds sqlutil.DataSource, jobORM job.ORM, bridgeORM bridges.ORM, mercuryORM evmmercury.ORM, @@ -243,6 +246,7 @@ func NewDelegate( ) *Delegate { return &Delegate{ db: db, + ds: ds, jobORM: jobORM, bridgeORM: bridgeORM, mercuryORM: mercuryORM, @@ -1669,8 +1673,7 @@ func (d *Delegate) newServicesOCR2Functions( Job: jb, JobORM: d.jobORM, BridgeORM: d.bridgeORM, - QConfig: d.cfg.Database(), - DB: d.db, + DS: d.ds, Chain: chain, ContractID: spec.ContractID, Logger: lggr, diff --git a/core/services/ocr2/plugins/functions/plugin.go b/core/services/ocr2/plugins/functions/plugin.go index 92b15892885..d6ffa1a3f06 100644 --- a/core/services/ocr2/plugins/functions/plugin.go +++ b/core/services/ocr2/plugins/functions/plugin.go @@ -8,13 +8,13 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/jmoiron/sqlx" "github.com/jonboulle/clockwork" "github.com/pkg/errors" "github.com/smartcontractkit/libocr/commontypes" libocr2 "github.com/smartcontractkit/libocr/offchainreporting2plus" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/core/bridges" @@ -31,7 +31,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" s4_plugin "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/s4" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/threshold" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" evmrelayTypes "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" "github.com/smartcontractkit/chainlink/v2/core/services/s4" ) @@ -40,8 +39,7 @@ type FunctionsServicesConfig struct { Job job.Job JobORM job.ORM BridgeORM bridges.ORM - QConfig pg.QConfig - DB *sqlx.DB + DS sqlutil.DataSource Chain legacyevm.Chain ContractID string Logger logger.Logger @@ -63,8 +61,8 @@ const ( // Create all OCR2 plugin Oracles and all extra services needed to run a Functions job. func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOracleArgs, s4OracleArgs *libocr2.OCR2OracleArgs, conf *FunctionsServicesConfig) ([]job.ServiceCtx, error) { - pluginORM := functions.NewORM(conf.DB, conf.Logger, conf.QConfig, common.HexToAddress(conf.ContractID)) - s4ORM := s4.NewCachedORMWrapper(s4.NewPostgresORM(conf.DB, conf.Logger, conf.QConfig, s4.SharedTableName, FunctionsS4Namespace), conf.Logger) + pluginORM := functions.NewORM(conf.DS, common.HexToAddress(conf.ContractID)) + s4ORM := s4.NewCachedORMWrapper(s4.NewPostgresORM(conf.DS, s4.SharedTableName, FunctionsS4Namespace), conf.Logger) var pluginConfig config.PluginConfig if err := json.Unmarshal(conf.Job.OCR2OracleSpec.PluginConfig.Bytes(), &pluginConfig); err != nil { @@ -155,7 +153,7 @@ func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOra allServices = append(allServices, job.NewServiceAdapter(functionsReportingPluginOracle)) if pluginConfig.GatewayConnectorConfig != nil && s4Storage != nil && pluginConfig.OnchainAllowlist != nil && pluginConfig.RateLimiter != nil && pluginConfig.OnchainSubscriptions != nil { - allowlistORM, err := gwAllowlist.NewORM(conf.DB, conf.Logger, conf.QConfig, pluginConfig.OnchainAllowlist.ContractAddress) + allowlistORM, err := gwAllowlist.NewORM(conf.DS, conf.Logger, pluginConfig.OnchainAllowlist.ContractAddress) if err != nil { return nil, errors.Wrap(err, "failed to create allowlist ORM") } @@ -167,7 +165,7 @@ func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOra if err2 != nil { return nil, errors.Wrap(err, "failed to create a RateLimiter") } - subscriptionsORM, err := gwSubscriptions.NewORM(conf.DB, conf.Logger, conf.QConfig, pluginConfig.OnchainSubscriptions.ContractAddress) + subscriptionsORM, err := gwSubscriptions.NewORM(conf.DS, conf.Logger, pluginConfig.OnchainSubscriptions.ContractAddress) if err != nil { return nil, errors.Wrap(err, "failed to create subscriptions ORM") } diff --git a/core/services/ocr2/plugins/functions/reporting.go b/core/services/ocr2/plugins/functions/reporting.go index 36e8a882734..d9d68ec9097 100644 --- a/core/services/ocr2/plugins/functions/reporting.go +++ b/core/services/ocr2/plugins/functions/reporting.go @@ -18,7 +18,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/functions" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/encoding" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type FunctionsReportingPluginFactory struct { @@ -151,7 +150,7 @@ func (r *functionsReporting) Query(ctx context.Context, ts types.ReportTimestamp "oracleID": r.genericConfig.OracleID, }) maxBatchSize := r.specificConfig.Config.GetMaxRequestBatchSize() - results, err := r.pluginORM.FindOldestEntriesByState(functions.RESULT_READY, maxBatchSize, pg.WithParentCtx(ctx)) + results, err := r.pluginORM.FindOldestEntriesByState(ctx, functions.RESULT_READY, maxBatchSize) if err != nil { return nil, err } @@ -222,7 +221,7 @@ func (r *functionsReporting) Observation(ctx context.Context, ts types.ReportTim continue } processedIds[id] = true - localResult, err2 := r.pluginORM.FindById(id, pg.WithParentCtx(ctx)) + localResult, err2 := r.pluginORM.FindById(ctx, id) if err2 != nil { r.logger.Debug("FunctionsReporting Observation can't find request from query", commontypes.LogFields{ "requestID": formatRequestId(id[:]), @@ -429,14 +428,14 @@ func (r *functionsReporting) ShouldAcceptFinalizedReport(ctx context.Context, ts r.logger.Error("FunctionsReporting ShouldAcceptFinalizedReport: invalid ID", commontypes.LogFields{"requestID": reqIdStr, "err": err}) continue } - _, err = r.pluginORM.FindById(id, pg.WithParentCtx(ctx)) + _, err = r.pluginORM.FindById(ctx, id) if err != nil { // TODO: Differentiate between ID not found and other ORM errors (https://smartcontract-it.atlassian.net/browse/DRO-215) r.logger.Warn("FunctionsReporting ShouldAcceptFinalizedReport: request doesn't exist locally! Accepting anyway.", commontypes.LogFields{"requestID": reqIdStr}) needTransmissionIds = append(needTransmissionIds, reqIdStr) continue } - err = r.pluginORM.SetFinalized(id, item.Result, item.Error, pg.WithParentCtx(ctx)) // validates state transition + err = r.pluginORM.SetFinalized(ctx, id, item.Result, item.Error) // validates state transition if err != nil { r.logger.Debug("FunctionsReporting ShouldAcceptFinalizedReport: state couldn't be changed to FINALIZED. Not transmitting.", commontypes.LogFields{"requestID": reqIdStr, "err": err}) continue @@ -490,7 +489,7 @@ func (r *functionsReporting) ShouldTransmitAcceptedReport(ctx context.Context, t r.logger.Error("FunctionsReporting ShouldAcceptFinalizedReport: invalid ID", commontypes.LogFields{"requestID": reqIdStr, "err": err}) continue } - request, err := r.pluginORM.FindById(id, pg.WithParentCtx(ctx)) + request, err := r.pluginORM.FindById(ctx, id) if err != nil { r.logger.Warn("FunctionsReporting ShouldTransmitAcceptedReport: request doesn't exist locally! Transmitting anyway.", commontypes.LogFields{"requestID": reqIdStr, "err": err}) needTransmissionIds = append(needTransmissionIds, reqIdStr) diff --git a/core/services/ocr2/plugins/functions/reporting_test.go b/core/services/ocr2/plugins/functions/reporting_test.go index 5b9f59ccb23..7d6686a0b4f 100644 --- a/core/services/ocr2/plugins/functions/reporting_test.go +++ b/core/services/ocr2/plugins/functions/reporting_test.go @@ -134,7 +134,7 @@ func TestFunctionsReporting_Query(t *testing.T) { const batchSize = 10 plugin, orm, _, _ := preparePlugin(t, batchSize, 0) reqs := []functions_srv.Request{newRequest(), newRequest()} - orm.On("FindOldestEntriesByState", functions_srv.RESULT_READY, uint32(batchSize), mock.Anything).Return(reqs, nil) + orm.On("FindOldestEntriesByState", mock.Anything, functions_srv.RESULT_READY, uint32(batchSize), mock.Anything).Return(reqs, nil) q, err := plugin.Query(testutils.Context(t), types.ReportTimestamp{}) require.NoError(t, err) @@ -154,7 +154,7 @@ func TestFunctionsReporting_Query_HandleCoordinatorMismatch(t *testing.T) { reqs := []functions_srv.Request{newRequest(), newRequest()} reqs[0].CoordinatorContractAddress = &common.Address{1} reqs[1].CoordinatorContractAddress = &common.Address{2} - orm.On("FindOldestEntriesByState", functions_srv.RESULT_READY, uint32(batchSize), mock.Anything).Return(reqs, nil) + orm.On("FindOldestEntriesByState", mock.Anything, functions_srv.RESULT_READY, uint32(batchSize), mock.Anything).Return(reqs, nil) q, err := plugin.Query(testutils.Context(t), types.ReportTimestamp{}) require.NoError(t, err) @@ -177,11 +177,11 @@ func TestFunctionsReporting_Observation(t *testing.T) { req4 := newRequestTimedOut() nonexistentId := newRequestID() - orm.On("FindById", req1.RequestID, mock.Anything).Return(&req1, nil) - orm.On("FindById", req2.RequestID, mock.Anything).Return(&req2, nil) - orm.On("FindById", req3.RequestID, mock.Anything).Return(&req3, nil) - orm.On("FindById", req4.RequestID, mock.Anything).Return(&req4, nil) - orm.On("FindById", nonexistentId, mock.Anything).Return(nil, errors.New("nonexistent ID")) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(&req1, nil) + orm.On("FindById", mock.Anything, req2.RequestID, mock.Anything).Return(&req2, nil) + orm.On("FindById", mock.Anything, req3.RequestID, mock.Anything).Return(&req3, nil) + orm.On("FindById", mock.Anything, req4.RequestID, mock.Anything).Return(&req4, nil) + orm.On("FindById", mock.Anything, nonexistentId, mock.Anything).Return(nil, errors.New("nonexistent ID")) // Query asking for 5 requests (with duplicates), out of which: // - two are ready @@ -209,7 +209,7 @@ func TestFunctionsReporting_Observation_IncorrectQuery(t *testing.T) { req1 := newRequestWithResult([]byte("abc")) invalidId := []byte("invalid") - orm.On("FindById", req1.RequestID, mock.Anything).Return(&req1, nil) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(&req1, nil) // Query asking for 3 requests (with duplicates), out of which: // - two are invalid @@ -441,13 +441,13 @@ func TestFunctionsReporting_ShouldAcceptFinalizedReport(t *testing.T) { req3 := newRequestFinalized() req4 := newRequestTimedOut() - orm.On("FindById", req1.RequestID, mock.Anything).Return(nil, errors.New("nonexistent ID")) - orm.On("FindById", req2.RequestID, mock.Anything).Return(&req2, nil) - orm.On("SetFinalized", req2.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(nil) - orm.On("FindById", req3.RequestID, mock.Anything).Return(&req3, nil) - orm.On("SetFinalized", req3.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("same state")) - orm.On("FindById", req4.RequestID, mock.Anything).Return(&req4, nil) - orm.On("SetFinalized", req4.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("already timed out")) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(nil, errors.New("nonexistent ID")) + orm.On("FindById", mock.Anything, req2.RequestID, mock.Anything).Return(&req2, nil) + orm.On("SetFinalized", mock.Anything, req2.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(nil) + orm.On("FindById", mock.Anything, req3.RequestID, mock.Anything).Return(&req3, nil) + orm.On("SetFinalized", mock.Anything, req3.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("same state")) + orm.On("FindById", mock.Anything, req4.RequestID, mock.Anything).Return(&req4, nil) + orm.On("SetFinalized", mock.Anything, req4.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("already timed out")) // Attempting to transmit 2 requests, out of which: // - one was already accepted for transmission earlier @@ -477,8 +477,8 @@ func TestFunctionsReporting_ShouldAcceptFinalizedReport_OffchainTransmission(t * req1 := newRequestWithResult([]byte("abc")) req1.OnchainMetadata = []byte(functions_srv.OffchainRequestMarker) - orm.On("FindById", req1.RequestID, mock.Anything).Return(&req1, nil) - orm.On("SetFinalized", req1.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(nil) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(&req1, nil) + orm.On("SetFinalized", mock.Anything, req1.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(nil) offchainTransmitter.On("TransmitReport", mock.Anything, mock.Anything).Return(nil) should, err := plugin.ShouldAcceptFinalizedReport(testutils.Context(t), types.ReportTimestamp{}, getReportBytes(t, codec, req1)) @@ -496,11 +496,11 @@ func TestFunctionsReporting_ShouldTransmitAcceptedReport(t *testing.T) { req4 := newRequestTimedOut() req5 := newRequestConfirmed() - orm.On("FindById", req1.RequestID, mock.Anything).Return(nil, errors.New("nonexistent ID")) - orm.On("FindById", req2.RequestID, mock.Anything).Return(&req2, nil) - orm.On("FindById", req3.RequestID, mock.Anything).Return(&req3, nil) - orm.On("FindById", req4.RequestID, mock.Anything).Return(&req4, nil) - orm.On("FindById", req5.RequestID, mock.Anything).Return(&req5, nil) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(nil, errors.New("nonexistent ID")) + orm.On("FindById", mock.Anything, req2.RequestID, mock.Anything).Return(&req2, nil) + orm.On("FindById", mock.Anything, req3.RequestID, mock.Anything).Return(&req3, nil) + orm.On("FindById", mock.Anything, req4.RequestID, mock.Anything).Return(&req4, nil) + orm.On("FindById", mock.Anything, req5.RequestID, mock.Anything).Return(&req5, nil) // Attempting to transmit 2 requests, out of which: // - one was already confirmed on chain diff --git a/core/services/ocr2/plugins/s4/integration_test.go b/core/services/ocr2/plugins/s4/integration_test.go index 8efe38f8e2d..5148ea6e26d 100644 --- a/core/services/ocr2/plugins/s4/integration_test.go +++ b/core/services/ocr2/plugins/s4/integration_test.go @@ -15,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/s4" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" s4_svc "github.com/smartcontractkit/chainlink/v2/core/services/s4" commonlogger "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -53,7 +52,7 @@ func newDON(t *testing.T, size int, config *s4.PluginConfig) *don { for i := 0; i < size; i++ { ns := fmt.Sprintf("s4_int_test_%d", i) - orm := s4_svc.NewPostgresORM(db, logger, pgtest.NewQConfig(false), s4_svc.SharedTableName, ns) + orm := s4_svc.NewPostgresORM(db, s4_svc.SharedTableName, ns) orms[i] = orm ocrLogger := commonlogger.NewOCRWrapper(logger, true, func(msg string) {}) @@ -149,7 +148,7 @@ func checkNoErrors(t *testing.T, errors []error) { func checkNoUnconfirmedRows(ctx context.Context, t *testing.T, orm s4_svc.ORM, limit uint) { t.Helper() - rows, err := orm.GetUnconfirmedRows(limit, pg.WithParentCtx(ctx)) + rows, err := orm.GetUnconfirmedRows(ctx, limit) assert.NoError(t, err) assert.Empty(t, rows) } @@ -161,10 +160,10 @@ func TestS4Integration_HappyDON(t *testing.T) { // injecting new records rows := generateTestOrmRows(t, 10, time.Minute) for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } - originSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + originSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) // S4 to propagate all records in one OCR round @@ -172,7 +171,7 @@ func TestS4Integration_HappyDON(t *testing.T) { checkNoErrors(t, errors) for i := 0; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(originSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) @@ -188,7 +187,7 @@ func TestS4Integration_HappyDON_4X(t *testing.T) { for o := 0; o < don.size; o++ { rows := generateTestOrmRows(t, 10, time.Minute) for _, row := range rows { - err := don.orms[o].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[o].Update(ctx, row) require.NoError(t, err) } } @@ -197,11 +196,11 @@ func TestS4Integration_HappyDON_4X(t *testing.T) { errors := don.simulateOCR(ctx, 1) checkNoErrors(t, errors) - firstSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + firstSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) for i := 1; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(firstSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) @@ -217,10 +216,10 @@ func TestS4Integration_WrongSignature(t *testing.T) { rows := generateTestOrmRows(t, 10, time.Minute) rows[0].Signature = rows[1].Signature for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } - originSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + originSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) originSnapshot = filter(originSnapshot, func(row *s4_svc.SnapshotRow) bool { return row.Address.Cmp(rows[0].Address) != 0 || row.SlotId != rows[0].SlotId @@ -232,14 +231,14 @@ func TestS4Integration_WrongSignature(t *testing.T) { checkNoErrors(t, errors) for i := 1; i < don.size; i++ { - snapshot, err2 := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err2 := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err2) equal := compareSnapshots(originSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) } // record with a wrong signature must remain unconfirmed - ur, err := don.orms[0].GetUnconfirmedRows(10, pg.WithParentCtx(ctx)) + ur, err := don.orms[0].GetUnconfirmedRows(ctx, 10) require.NoError(t, err) require.Len(t, ur, 1) } @@ -253,10 +252,10 @@ func TestS4Integration_MaxObservations(t *testing.T) { // injecting new records rows := generateTestOrmRows(t, 10, time.Minute) for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } - originSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + originSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) // It requires at least two rounds due to MaxObservationEntries = rows / 2 @@ -264,7 +263,7 @@ func TestS4Integration_MaxObservations(t *testing.T) { checkNoErrors(t, errors) for i := 1; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(originSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) @@ -280,7 +279,7 @@ func TestS4Integration_Expired(t *testing.T) { // injecting expiring records rows := generateTestOrmRows(t, 10, time.Millisecond) for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } @@ -290,7 +289,7 @@ func TestS4Integration_Expired(t *testing.T) { checkNoErrors(t, errors) for i := 0; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) require.Len(t, snapshot, 0) } @@ -305,10 +304,10 @@ func TestS4Integration_NSnapshotShards(t *testing.T) { // injecting lots of new records (to be close to normal address distribution) rows := generateTestOrmRows(t, 1000, time.Minute) for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } - originSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + originSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) // this still requires one round, because Observation takes all unconfirmed rows @@ -316,7 +315,7 @@ func TestS4Integration_NSnapshotShards(t *testing.T) { checkNoErrors(t, errors) for i := 1; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(originSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) @@ -332,7 +331,7 @@ func TestS4Integration_OneNodeOutOfSync(t *testing.T) { rows := generateConfirmedTestOrmRows(t, 10, time.Minute) for o := 0; o < don.size-1; o++ { for _, row := range rows { - err := don.orms[o].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[o].Update(ctx, row) require.NoError(t, err) } } @@ -342,9 +341,9 @@ func TestS4Integration_OneNodeOutOfSync(t *testing.T) { errors := don.simulateOCR(ctx, 4) checkNoErrors(t, errors) - firstSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + firstSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) - lastSnapshot, err := don.orms[don.size-1].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + lastSnapshot, err := don.orms[don.size-1].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(firstSnapshot, lastSnapshot) assert.True(t, equal) @@ -389,7 +388,7 @@ func TestS4Integration_RandomState(t *testing.T) { sig, err := env.Sign(user.privateKey) require.NoError(t, err) row.Signature = sig - err = don.orms[o].Update(row, pg.WithParentCtx(ctx)) + err = don.orms[o].Update(ctx, row) require.NoError(t, err) } } @@ -398,13 +397,13 @@ func TestS4Integration_RandomState(t *testing.T) { errors := don.simulateOCR(ctx, 4) checkNoErrors(t, errors) - firstSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + firstSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) require.NotEmpty(t, firstSnapshot) checkNoUnconfirmedRows(ctx, t, don.orms[0], 1000) for i := 1; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(firstSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) diff --git a/core/services/ocr2/plugins/s4/plugin.go b/core/services/ocr2/plugins/s4/plugin.go index 2b55ebf3cc5..6976c606045 100644 --- a/core/services/ocr2/plugins/s4/plugin.go +++ b/core/services/ocr2/plugins/s4/plugin.go @@ -12,7 +12,6 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/s4" ) @@ -69,7 +68,7 @@ func NewReportingPlugin(logger commontypes.Logger, config *PluginConfig, orm s4. func (c *plugin) Query(ctx context.Context, ts types.ReportTimestamp) (types.Query, error) { promReportingPluginQuery.WithLabelValues(c.config.ProductName).Inc() - snapshot, err := c.orm.GetSnapshot(c.addressRange, pg.WithParentCtx(ctx)) + snapshot, err := c.orm.GetSnapshot(ctx, c.addressRange) if err != nil { return nil, errors.Wrap(err, "failed to GetVersions in Query()") } @@ -111,7 +110,7 @@ func (c *plugin) Observation(ctx context.Context, ts types.ReportTimestamp, quer promReportingPluginObservation.WithLabelValues(c.config.ProductName).Inc() now := time.Now().UTC() - count, err := c.orm.DeleteExpired(c.config.MaxDeleteExpiredEntries, now, pg.WithParentCtx(ctx)) + count, err := c.orm.DeleteExpired(ctx, c.config.MaxDeleteExpiredEntries, now) if err != nil { return nil, errors.Wrap(err, "failed to DeleteExpired in Observation()") } @@ -122,7 +121,7 @@ func (c *plugin) Observation(ctx context.Context, ts types.ReportTimestamp, quer return MarshalRows(convertRows(rows)) } - unconfirmedRows, err := c.orm.GetUnconfirmedRows(c.config.MaxObservationEntries, pg.WithParentCtx(ctx)) + unconfirmedRows, err := c.orm.GetUnconfirmedRows(ctx, c.config.MaxObservationEntries) if err != nil { return nil, errors.Wrap(err, "failed to GetUnconfirmedRows in Observation()") } @@ -138,7 +137,7 @@ func (c *plugin) Observation(ctx context.Context, ts types.ReportTimestamp, quer if err != nil { c.logger.Error("Failed to unmarshal query (likely malformed)", commontypes.LogFields{"err": err}) } else { - snapshot, err := c.orm.GetSnapshot(addressRange, pg.WithParentCtx(ctx)) + snapshot, err := c.orm.GetSnapshot(ctx, addressRange) if err != nil { c.logger.Error("ORM GetSnapshot error", commontypes.LogFields{"err": err}) } else { @@ -178,7 +177,7 @@ func (c *plugin) Observation(ctx context.Context, ts types.ReportTimestamp, quer } for _, k := range toBeAdded { - row, err := c.orm.Get(k.address, k.slotID, pg.WithParentCtx(ctx)) + row, err := c.orm.Get(ctx, k.address, k.slotID) if err == nil { remainingRows = append(remainingRows, row) } else if !errors.Is(err, s4.ErrNotFound) { @@ -283,7 +282,7 @@ func (c *plugin) ShouldAcceptFinalizedReport(ctx context.Context, ts types.Repor continue } - err = c.orm.Update(ormRow, pg.WithParentCtx(ctx)) + err = c.orm.Update(ctx, ormRow) if err != nil && !errors.Is(err, s4.ErrVersionTooLow) { c.logger.Error("Failed to Update a row in ShouldAcceptFinalizedReport()", commontypes.LogFields{"err": err}) continue diff --git a/core/services/ocr2/plugins/s4/plugin_test.go b/core/services/ocr2/plugins/s4/plugin_test.go index b53ab40bfcb..6321b8ce867 100644 --- a/core/services/ocr2/plugins/s4/plugin_test.go +++ b/core/services/ocr2/plugins/s4/plugin_test.go @@ -205,7 +205,7 @@ func TestPlugin_ShouldAcceptFinalizedReport(t *testing.T) { ormRows := make([]*s4_svc.Row, 0) rows := generateTestRows(t, 10, time.Minute) orm.On("Update", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - updateRow := args.Get(0).(*s4_svc.Row) + updateRow := args.Get(1).(*s4_svc.Row) ormRows = append(ormRows, updateRow) }).Return(nil).Times(10) @@ -344,8 +344,8 @@ func TestPlugin_Observation(t *testing.T) { for _, or := range ormRows { or.Confirmed = false } - orm.On("DeleteExpired", uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() - orm.On("GetUnconfirmedRows", config.MaxObservationEntries, mock.Anything).Return(ormRows, nil).Once() + orm.On("DeleteExpired", mock.Anything, uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() + orm.On("GetUnconfirmedRows", mock.Anything, config.MaxObservationEntries).Return(ormRows, nil).Once() observation, err := plugin.Observation(testutils.Context(t), types.ReportTimestamp{}, []byte{}) assert.NoError(t, err) @@ -370,8 +370,8 @@ func TestPlugin_Observation(t *testing.T) { Confirmed: or.Confirmed, } } - orm.On("DeleteExpired", uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() - orm.On("GetUnconfirmedRows", config.MaxObservationEntries, mock.Anything).Return(ormRows[numUnconfirmed:], nil).Once() + orm.On("DeleteExpired", mock.Anything, uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() + orm.On("GetUnconfirmedRows", mock.Anything, config.MaxObservationEntries).Return(ormRows[numUnconfirmed:], nil).Once() orm.On("GetSnapshot", mock.Anything, mock.Anything).Return(snapshot, nil).Once() snapshotRows := rowsToShapshotRows(ormRows) @@ -388,7 +388,7 @@ func TestPlugin_Observation(t *testing.T) { if i < numHigherVersion { ormRows[i].Version++ snapshot[i].Version++ - orm.On("Get", v.Address, v.SlotId, mock.Anything).Return(ormRows[i], nil).Once() + orm.On("Get", mock.Anything, v.Address, v.SlotId).Return(ormRows[i], nil).Once() } } queryBytes, err := proto.Marshal(query) @@ -447,11 +447,11 @@ func TestPlugin_Observation(t *testing.T) { queryBytes, err := proto.Marshal(query) assert.NoError(t, err) - orm.On("DeleteExpired", uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() - orm.On("GetUnconfirmedRows", config.MaxObservationEntries, mock.Anything).Return([]*s4_svc.Row{}, nil).Once() + orm.On("DeleteExpired", mock.Anything, uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() + orm.On("GetUnconfirmedRows", mock.Anything, config.MaxObservationEntries).Return([]*s4_svc.Row{}, nil).Once() orm.On("GetSnapshot", mock.Anything, mock.Anything).Return(snapshot, nil).Once() - orm.On("Get", snapshot[1].Address, snapshot[1].SlotId, mock.Anything).Return(ormRows[1], nil).Once() - orm.On("Get", snapshot[2].Address, snapshot[2].SlotId, mock.Anything).Return(ormRows[2], nil).Once() + orm.On("Get", mock.Anything, snapshot[1].Address, snapshot[1].SlotId).Return(ormRows[1], nil).Once() + orm.On("Get", mock.Anything, snapshot[2].Address, snapshot[2].SlotId).Return(ormRows[2], nil).Once() observation, err := plugin.Observation(testutils.Context(t), types.ReportTimestamp{}, queryBytes) assert.NoError(t, err) diff --git a/core/services/s4/cached_orm_wrapper.go b/core/services/s4/cached_orm_wrapper.go index 38b9ecba1ca..fe6cb20e3cd 100644 --- a/core/services/s4/cached_orm_wrapper.go +++ b/core/services/s4/cached_orm_wrapper.go @@ -1,6 +1,7 @@ package s4 import ( + "context" "fmt" "math/big" "strings" @@ -10,7 +11,6 @@ import ( ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) const ( @@ -40,18 +40,18 @@ func NewCachedORMWrapper(orm ORM, lggr logger.Logger) *CachedORM { } } -func (c CachedORM) Get(address *ubig.Big, slotId uint, qopts ...pg.QOpt) (*Row, error) { - return c.underlayingORM.Get(address, slotId, qopts...) +func (c CachedORM) Get(ctx context.Context, address *ubig.Big, slotId uint) (*Row, error) { + return c.underlayingORM.Get(ctx, address, slotId) } -func (c CachedORM) Update(row *Row, qopts ...pg.QOpt) error { +func (c CachedORM) Update(ctx context.Context, row *Row) error { c.deleteRowFromSnapshotCache(row) - return c.underlayingORM.Update(row, qopts...) + return c.underlayingORM.Update(ctx, row) } -func (c CachedORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (int64, error) { - deletedRows, err := c.underlayingORM.DeleteExpired(limit, utcNow, qopts...) +func (c CachedORM) DeleteExpired(ctx context.Context, limit uint, utcNow time.Time) (int64, error) { + deletedRows, err := c.underlayingORM.DeleteExpired(ctx, limit, utcNow) if err != nil { return 0, err } @@ -63,7 +63,7 @@ func (c CachedORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) return deletedRows, nil } -func (c CachedORM) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*SnapshotRow, error) { +func (c CachedORM) GetSnapshot(ctx context.Context, addressRange *AddressRange) ([]*SnapshotRow, error) { key := fmt.Sprintf("%s_%s_%s", getSnapshotCachePrefix, addressRange.MinAddress.String(), addressRange.MaxAddress.String()) cached, found := c.cache.Get(key) @@ -72,7 +72,7 @@ func (c CachedORM) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([] } c.lggr.Debug("Snapshot not found in cache, fetching it from underlaying implementation") - data, err := c.underlayingORM.GetSnapshot(addressRange, qopts...) + data, err := c.underlayingORM.GetSnapshot(ctx, addressRange) if err != nil { return nil, err } @@ -81,8 +81,8 @@ func (c CachedORM) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([] return data, nil } -func (c CachedORM) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) { - return c.underlayingORM.GetUnconfirmedRows(limit, qopts...) +func (c CachedORM) GetUnconfirmedRows(ctx context.Context, limit uint) ([]*Row, error) { + return c.underlayingORM.GetUnconfirmedRows(ctx, limit) } // deleteRowFromSnapshotCache will clean the cache for every snapshot that would involve a given row diff --git a/core/services/s4/cached_orm_wrapper_test.go b/core/services/s4/cached_orm_wrapper_test.go index 6f6ac298557..5b94ce3b253 100644 --- a/core/services/s4/cached_orm_wrapper_test.go +++ b/core/services/s4/cached_orm_wrapper_test.go @@ -8,6 +8,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" @@ -21,11 +22,12 @@ import ( func TestGetSnapshotEmpty(t *testing.T) { t.Run("OK-no_rows", func(t *testing.T) { + ctx := testutils.Context(t) psqlORM := setupORM(t, "test") lggr := logger.TestLogger(t) orm := s4.NewCachedORMWrapper(psqlORM, lggr) - rows, err := orm.GetSnapshot(s4.NewFullAddressRange()) + rows, err := orm.GetSnapshot(ctx, s4.NewFullAddressRange()) assert.NoError(t, err) assert.Empty(t, rows) }) @@ -33,23 +35,24 @@ func TestGetSnapshotEmpty(t *testing.T) { func TestGetSnapshotCacheFilled(t *testing.T) { t.Run("OK_with_rows_already_cached", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestSnapshotRows(t, 100) fullAddressRange := s4.NewFullAddressRange() lggr := logger.TestLogger(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetSnapshot", fullAddressRange).Return(rows, nil).Once() + underlayingORM.On("GetSnapshot", mock.Anything, fullAddressRange).Return(rows, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) // first call will go to the underlaying orm implementation to fill the cache - first_snapshot, err := orm.GetSnapshot(fullAddressRange) + first_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(first_snapshot)) // on the second call, the results will come from the cache, if not the mock will return an error because of .Once() - cache_snapshot, err := orm.GetSnapshot(fullAddressRange) + cache_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(cache_snapshot)) @@ -75,23 +78,24 @@ func TestGetSnapshotCacheFilled(t *testing.T) { func TestUpdateInvalidatesSnapshotCache(t *testing.T) { t.Run("OK-GetSnapshot_cache_invalidated_after_update", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestSnapshotRows(t, 100) fullAddressRange := s4.NewFullAddressRange() lggr := logger.TestLogger(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetSnapshot", fullAddressRange).Return(rows, nil).Once() + underlayingORM.On("GetSnapshot", mock.Anything, fullAddressRange).Return(rows, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) // first call will go to the underlaying orm implementation to fill the cache - first_snapshot, err := orm.GetSnapshot(fullAddressRange) + first_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(first_snapshot)) // on the second call, the results will come from the cache, if not the mock will return an error because of .Once() - cache_snapshot, err := orm.GetSnapshot(fullAddressRange) + cache_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(cache_snapshot)) @@ -105,18 +109,19 @@ func TestUpdateInvalidatesSnapshotCache(t *testing.T) { Confirmed: true, Signature: cltest.MustRandomBytes(t, 32), } - underlayingORM.On("Update", row).Return(nil).Once() - err = orm.Update(row) + underlayingORM.On("Update", mock.Anything, row).Return(nil).Once() + err = orm.Update(ctx, row) assert.NoError(t, err) // given the cache was invalidated this request will reach the underlaying orm implementation - underlayingORM.On("GetSnapshot", fullAddressRange).Return(rows, nil).Once() - third_snapshot, err := orm.GetSnapshot(fullAddressRange) + underlayingORM.On("GetSnapshot", mock.Anything, fullAddressRange).Return(rows, nil).Once() + third_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(third_snapshot)) }) t.Run("OK-GetSnapshot_cache_not_invalidated_after_update", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestSnapshotRows(t, 5) addressRange := &s4.AddressRange{ @@ -126,17 +131,17 @@ func TestUpdateInvalidatesSnapshotCache(t *testing.T) { lggr := logger.TestLogger(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetSnapshot", addressRange).Return(rows, nil).Once() + underlayingORM.On("GetSnapshot", mock.Anything, addressRange).Return(rows, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) // first call will go to the underlaying orm implementation to fill the cache - first_snapshot, err := orm.GetSnapshot(addressRange) + first_snapshot, err := orm.GetSnapshot(ctx, addressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(first_snapshot)) // on the second call, the results will come from the cache, if not the mock will return an error because of .Once() - cache_snapshot, err := orm.GetSnapshot(addressRange) + cache_snapshot, err := orm.GetSnapshot(ctx, addressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(cache_snapshot)) @@ -151,12 +156,12 @@ func TestUpdateInvalidatesSnapshotCache(t *testing.T) { Confirmed: true, Signature: cltest.MustRandomBytes(t, 32), } - underlayingORM.On("Update", row).Return(nil).Once() - err = orm.Update(row) + underlayingORM.On("Update", mock.Anything, row).Return(nil).Once() + err = orm.Update(ctx, row) assert.NoError(t, err) // given the cache was not invalidated this request wont reach the underlaying orm implementation - third_snapshot, err := orm.GetSnapshot(addressRange) + third_snapshot, err := orm.GetSnapshot(ctx, addressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(third_snapshot)) }) @@ -169,24 +174,26 @@ func TestGet(t *testing.T) { lggr := logger.TestLogger(t) t.Run("OK-Get_underlaying_ORM_returns_a_row", func(t *testing.T) { + ctx := testutils.Context(t) underlayingORM := mocks.NewORM(t) expectedRow := &s4.Row{ Address: address, SlotId: slotID, } - underlayingORM.On("Get", address, slotID).Return(expectedRow, nil).Once() + underlayingORM.On("Get", mock.Anything, address, slotID).Return(expectedRow, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - row, err := orm.Get(address, slotID) + row, err := orm.Get(ctx, address, slotID) require.NoError(t, err) require.Equal(t, expectedRow, row) }) t.Run("NOK-Get_underlaying_ORM_returns_an_error", func(t *testing.T) { + ctx := testutils.Context(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("Get", address, slotID).Return(nil, fmt.Errorf("some_error")).Once() + underlayingORM.On("Get", mock.Anything, address, slotID).Return(nil, fmt.Errorf("some_error")).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - row, err := orm.Get(address, slotID) + row, err := orm.Get(ctx, address, slotID) require.Nil(t, row) require.EqualError(t, err, "some_error") }) @@ -199,22 +206,24 @@ func TestDeletedExpired(t *testing.T) { lggr := logger.TestLogger(t) t.Run("OK-DeletedExpired_underlaying_ORM_returns_a_row", func(t *testing.T) { + ctx := testutils.Context(t) var expectedDeleted int64 = 10 underlayingORM := mocks.NewORM(t) - underlayingORM.On("DeleteExpired", limit, now).Return(expectedDeleted, nil).Once() + underlayingORM.On("DeleteExpired", mock.Anything, limit, now).Return(expectedDeleted, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - actualDeleted, err := orm.DeleteExpired(limit, now) + actualDeleted, err := orm.DeleteExpired(ctx, limit, now) require.NoError(t, err) require.Equal(t, expectedDeleted, actualDeleted) }) t.Run("NOK-DeletedExpired_underlaying_ORM_returns_an_error", func(t *testing.T) { + ctx := testutils.Context(t) var expectedDeleted int64 underlayingORM := mocks.NewORM(t) - underlayingORM.On("DeleteExpired", limit, now).Return(expectedDeleted, fmt.Errorf("some_error")).Once() + underlayingORM.On("DeleteExpired", mock.Anything, limit, now).Return(expectedDeleted, fmt.Errorf("some_error")).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - actualDeleted, err := orm.DeleteExpired(limit, now) + actualDeleted, err := orm.DeleteExpired(ctx, limit, now) require.EqualError(t, err, "some_error") require.Equal(t, expectedDeleted, actualDeleted) }) @@ -226,6 +235,7 @@ func TestGetUnconfirmedRows(t *testing.T) { lggr := logger.TestLogger(t) t.Run("OK-GetUnconfirmedRows_underlaying_ORM_returns_a_row", func(t *testing.T) { + ctx := testutils.Context(t) address := big.New(testutils.NewAddress().Big()) var slotID uint = 1 @@ -234,19 +244,20 @@ func TestGetUnconfirmedRows(t *testing.T) { SlotId: slotID, }} underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetUnconfirmedRows", limit).Return(expectedRow, nil).Once() + underlayingORM.On("GetUnconfirmedRows", mock.Anything, limit).Return(expectedRow, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - actualRow, err := orm.GetUnconfirmedRows(limit) + actualRow, err := orm.GetUnconfirmedRows(ctx, limit) require.NoError(t, err) require.Equal(t, expectedRow, actualRow) }) t.Run("NOK-GetUnconfirmedRows_underlaying_ORM_returns_an_error", func(t *testing.T) { + ctx := testutils.Context(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetUnconfirmedRows", limit).Return(nil, fmt.Errorf("some_error")).Once() + underlayingORM.On("GetUnconfirmedRows", mock.Anything, limit).Return(nil, fmt.Errorf("some_error")).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - actualRow, err := orm.GetUnconfirmedRows(limit) + actualRow, err := orm.GetUnconfirmedRows(ctx, limit) require.Nil(t, actualRow) require.EqualError(t, err, "some_error") }) diff --git a/core/services/s4/in_memory_orm.go b/core/services/s4/in_memory_orm.go index 28b50ce430c..723f8820999 100644 --- a/core/services/s4/in_memory_orm.go +++ b/core/services/s4/in_memory_orm.go @@ -1,12 +1,12 @@ package s4 import ( + "context" "sort" "sync" "time" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type key struct { @@ -32,7 +32,7 @@ func NewInMemoryORM() ORM { } } -func (o *inMemoryOrm) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*Row, error) { +func (o *inMemoryOrm) Get(ctx context.Context, address *big.Big, slotId uint) (*Row, error) { o.mu.RLock() defer o.mu.RUnlock() @@ -47,7 +47,7 @@ func (o *inMemoryOrm) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*Row return mrow.Row.Clone(), nil } -func (o *inMemoryOrm) Update(row *Row, qopts ...pg.QOpt) error { +func (o *inMemoryOrm) Update(ctx context.Context, row *Row) error { o.mu.Lock() defer o.mu.Unlock() @@ -74,7 +74,7 @@ func (o *inMemoryOrm) Update(row *Row, qopts ...pg.QOpt) error { return nil } -func (o *inMemoryOrm) DeleteExpired(limit uint, now time.Time, qopts ...pg.QOpt) (int64, error) { +func (o *inMemoryOrm) DeleteExpired(ctx context.Context, limit uint, now time.Time) (int64, error) { o.mu.Lock() defer o.mu.Unlock() @@ -94,7 +94,7 @@ func (o *inMemoryOrm) DeleteExpired(limit uint, now time.Time, qopts ...pg.QOpt) return int64(len(queue)), nil } -func (o *inMemoryOrm) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*SnapshotRow, error) { +func (o *inMemoryOrm) GetSnapshot(ctx context.Context, _ *AddressRange) ([]*SnapshotRow, error) { o.mu.RLock() defer o.mu.RUnlock() @@ -115,7 +115,7 @@ func (o *inMemoryOrm) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) return rows, nil } -func (o *inMemoryOrm) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) { +func (o *inMemoryOrm) GetUnconfirmedRows(ctx context.Context, limit uint) ([]*Row, error) { o.mu.RLock() defer o.mu.RUnlock() diff --git a/core/services/s4/in_memory_orm_test.go b/core/services/s4/in_memory_orm_test.go index 318db5f1a44..db4f73ba1ef 100644 --- a/core/services/s4/in_memory_orm_test.go +++ b/core/services/s4/in_memory_orm_test.go @@ -33,33 +33,36 @@ func TestInMemoryORM(t *testing.T) { orm := s4.NewInMemoryORM() t.Run("row not found", func(t *testing.T) { - _, err := orm.Get(big.New(address.Big()), slotId) + ctx := testutils.Context(t) + _, err := orm.Get(ctx, big.New(address.Big()), slotId) assert.ErrorIs(t, err, s4.ErrNotFound) }) t.Run("insert and get", func(t *testing.T) { - err := orm.Update(row) + ctx := testutils.Context(t) + err := orm.Update(ctx, row) assert.NoError(t, err) - e, err := orm.Get(big.New(address.Big()), slotId) + e, err := orm.Get(ctx, big.New(address.Big()), slotId) assert.NoError(t, err) assert.Equal(t, row, e) }) t.Run("update and get", func(t *testing.T) { + ctx := testutils.Context(t) row.Version = 5 - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) // unconfirmed row requires greater version - err = orm.Update(row) + err = orm.Update(ctx, row) assert.ErrorIs(t, err, s4.ErrVersionTooLow) row.Confirmed = true - err = orm.Update(row) + err = orm.Update(ctx, row) assert.NoError(t, err) - e, err := orm.Get(big.New(address.Big()), slotId) + e, err := orm.Get(ctx, big.New(address.Big()), slotId) assert.NoError(t, err) assert.Equal(t, row, e) }) @@ -67,6 +70,7 @@ func TestInMemoryORM(t *testing.T) { func TestInMemoryORM_DeleteExpired(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := s4.NewInMemoryORM() baseTime := time.Now().Add(time.Minute).UTC() @@ -84,22 +88,23 @@ func TestInMemoryORM_DeleteExpired(t *testing.T) { Confirmed: false, Signature: []byte{}, } - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) } deadline := baseTime.Add(100 * time.Second) - count, err := orm.DeleteExpired(200, deadline) + count, err := orm.DeleteExpired(ctx, 200, deadline) assert.NoError(t, err) assert.Equal(t, int64(100), count) - rows, err := orm.GetUnconfirmedRows(200) + rows, err := orm.GetUnconfirmedRows(ctx, 200) assert.NoError(t, err) assert.Len(t, rows, 156) } func TestInMemoryORM_GetUnconfirmedRows(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := s4.NewInMemoryORM() expiration := time.Now().Add(100 * time.Second).UnixMilli() @@ -117,18 +122,19 @@ func TestInMemoryORM_GetUnconfirmedRows(t *testing.T) { Confirmed: i >= 100, Signature: []byte{}, } - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) time.Sleep(time.Millisecond) } - rows, err := orm.GetUnconfirmedRows(100) + rows, err := orm.GetUnconfirmedRows(ctx, 100) assert.NoError(t, err) assert.Len(t, rows, 100) } func TestInMemoryORM_GetSnapshot(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := s4.NewInMemoryORM() expiration := time.Now().Add(100 * time.Second).UnixMilli() @@ -147,11 +153,11 @@ func TestInMemoryORM_GetSnapshot(t *testing.T) { Confirmed: i >= 100, Signature: []byte{}, } - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) } - rows, err := orm.GetSnapshot(s4.NewFullAddressRange()) + rows, err := orm.GetSnapshot(ctx, s4.NewFullAddressRange()) assert.NoError(t, err) assert.Len(t, rows, n) diff --git a/core/services/s4/mocks/orm.go b/core/services/s4/mocks/orm.go index 3b8cac8e76d..4a5d7fa992d 100644 --- a/core/services/s4/mocks/orm.go +++ b/core/services/s4/mocks/orm.go @@ -3,10 +3,11 @@ package mocks import ( + context "context" + big "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + mock "github.com/stretchr/testify/mock" s4 "github.com/smartcontractkit/chainlink/v2/core/services/s4" @@ -18,16 +19,9 @@ type ORM struct { mock.Mock } -// DeleteExpired provides a mock function with given fields: limit, utcNow, qopts -func (_m *ORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (int64, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, limit, utcNow) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// DeleteExpired provides a mock function with given fields: ctx, limit, utcNow +func (_m *ORM) DeleteExpired(ctx context.Context, limit uint, utcNow time.Time) (int64, error) { + ret := _m.Called(ctx, limit, utcNow) if len(ret) == 0 { panic("no return value specified for DeleteExpired") @@ -35,17 +29,17 @@ func (_m *ORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (in var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(uint, time.Time, ...pg.QOpt) (int64, error)); ok { - return rf(limit, utcNow, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, time.Time) (int64, error)); ok { + return rf(ctx, limit, utcNow) } - if rf, ok := ret.Get(0).(func(uint, time.Time, ...pg.QOpt) int64); ok { - r0 = rf(limit, utcNow, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, time.Time) int64); ok { + r0 = rf(ctx, limit, utcNow) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(uint, time.Time, ...pg.QOpt) error); ok { - r1 = rf(limit, utcNow, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint, time.Time) error); ok { + r1 = rf(ctx, limit, utcNow) } else { r1 = ret.Error(1) } @@ -53,16 +47,9 @@ func (_m *ORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (in return r0, r1 } -// Get provides a mock function with given fields: address, slotId, qopts -func (_m *ORM) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*s4.Row, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, address, slotId) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// Get provides a mock function with given fields: ctx, address, slotId +func (_m *ORM) Get(ctx context.Context, address *big.Big, slotId uint) (*s4.Row, error) { + ret := _m.Called(ctx, address, slotId) if len(ret) == 0 { panic("no return value specified for Get") @@ -70,19 +57,19 @@ func (_m *ORM) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*s4.Row, er var r0 *s4.Row var r1 error - if rf, ok := ret.Get(0).(func(*big.Big, uint, ...pg.QOpt) (*s4.Row, error)); ok { - return rf(address, slotId, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *big.Big, uint) (*s4.Row, error)); ok { + return rf(ctx, address, slotId) } - if rf, ok := ret.Get(0).(func(*big.Big, uint, ...pg.QOpt) *s4.Row); ok { - r0 = rf(address, slotId, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *big.Big, uint) *s4.Row); ok { + r0 = rf(ctx, address, slotId) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*s4.Row) } } - if rf, ok := ret.Get(1).(func(*big.Big, uint, ...pg.QOpt) error); ok { - r1 = rf(address, slotId, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, *big.Big, uint) error); ok { + r1 = rf(ctx, address, slotId) } else { r1 = ret.Error(1) } @@ -90,16 +77,9 @@ func (_m *ORM) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*s4.Row, er return r0, r1 } -// GetSnapshot provides a mock function with given fields: addressRange, qopts -func (_m *ORM) GetSnapshot(addressRange *s4.AddressRange, qopts ...pg.QOpt) ([]*s4.SnapshotRow, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, addressRange) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetSnapshot provides a mock function with given fields: ctx, addressRange +func (_m *ORM) GetSnapshot(ctx context.Context, addressRange *s4.AddressRange) ([]*s4.SnapshotRow, error) { + ret := _m.Called(ctx, addressRange) if len(ret) == 0 { panic("no return value specified for GetSnapshot") @@ -107,19 +87,19 @@ func (_m *ORM) GetSnapshot(addressRange *s4.AddressRange, qopts ...pg.QOpt) ([]* var r0 []*s4.SnapshotRow var r1 error - if rf, ok := ret.Get(0).(func(*s4.AddressRange, ...pg.QOpt) ([]*s4.SnapshotRow, error)); ok { - return rf(addressRange, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *s4.AddressRange) ([]*s4.SnapshotRow, error)); ok { + return rf(ctx, addressRange) } - if rf, ok := ret.Get(0).(func(*s4.AddressRange, ...pg.QOpt) []*s4.SnapshotRow); ok { - r0 = rf(addressRange, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *s4.AddressRange) []*s4.SnapshotRow); ok { + r0 = rf(ctx, addressRange) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*s4.SnapshotRow) } } - if rf, ok := ret.Get(1).(func(*s4.AddressRange, ...pg.QOpt) error); ok { - r1 = rf(addressRange, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, *s4.AddressRange) error); ok { + r1 = rf(ctx, addressRange) } else { r1 = ret.Error(1) } @@ -127,16 +107,9 @@ func (_m *ORM) GetSnapshot(addressRange *s4.AddressRange, qopts ...pg.QOpt) ([]* return r0, r1 } -// GetUnconfirmedRows provides a mock function with given fields: limit, qopts -func (_m *ORM) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*s4.Row, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetUnconfirmedRows provides a mock function with given fields: ctx, limit +func (_m *ORM) GetUnconfirmedRows(ctx context.Context, limit uint) ([]*s4.Row, error) { + ret := _m.Called(ctx, limit) if len(ret) == 0 { panic("no return value specified for GetUnconfirmedRows") @@ -144,19 +117,19 @@ func (_m *ORM) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*s4.Row, erro var r0 []*s4.Row var r1 error - if rf, ok := ret.Get(0).(func(uint, ...pg.QOpt) ([]*s4.Row, error)); ok { - return rf(limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint) ([]*s4.Row, error)); ok { + return rf(ctx, limit) } - if rf, ok := ret.Get(0).(func(uint, ...pg.QOpt) []*s4.Row); ok { - r0 = rf(limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint) []*s4.Row); ok { + r0 = rf(ctx, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*s4.Row) } } - if rf, ok := ret.Get(1).(func(uint, ...pg.QOpt) error); ok { - r1 = rf(limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint) error); ok { + r1 = rf(ctx, limit) } else { r1 = ret.Error(1) } @@ -164,24 +137,17 @@ func (_m *ORM) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*s4.Row, erro return r0, r1 } -// Update provides a mock function with given fields: row, qopts -func (_m *ORM) Update(row *s4.Row, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, row) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// Update provides a mock function with given fields: ctx, row +func (_m *ORM) Update(ctx context.Context, row *s4.Row) error { + ret := _m.Called(ctx, row) if len(ret) == 0 { panic("no return value specified for Update") } var r0 error - if rf, ok := ret.Get(0).(func(*s4.Row, ...pg.QOpt) error); ok { - r0 = rf(row, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *s4.Row) error); ok { + r0 = rf(ctx, row) } else { r0 = ret.Error(0) } diff --git a/core/services/s4/orm.go b/core/services/s4/orm.go index 4d3cee9312a..952d8a33b24 100644 --- a/core/services/s4/orm.go +++ b/core/services/s4/orm.go @@ -1,10 +1,10 @@ package s4 import ( + "context" "time" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) // Row represents a data row persisted by ORM. @@ -36,26 +36,26 @@ type ORM interface { // Get reads a row for the given address and slotId combination. // If such row does not exist, ErrNotFound is returned. // There is no filter on Expiration. - Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*Row, error) + Get(ctx context.Context, address *big.Big, slotId uint) (*Row, error) // Update inserts or updates the row identified by (Address, SlotId) pair. // When updating, the new row must have greater or equal version, // otherwise ErrVersionTooLow is returned. // UpdatedAt field value is ignored. - Update(row *Row, qopts ...pg.QOpt) error + Update(ctx context.Context, row *Row) error // DeleteExpired deletes any entries having Expiration < utcNow, // up to the given limit. // Returns the number of deleted rows. - DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (int64, error) + DeleteExpired(ctx context.Context, limit uint, utcNow time.Time) (int64, error) // GetSnapshot selects all non-expired row versions for the given addresses range. // For the full address range, use NewFullAddressRange(). - GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*SnapshotRow, error) + GetSnapshot(ctx context.Context, addressRange *AddressRange) ([]*SnapshotRow, error) // GetUnconfirmedRows selects all non-expired, non-confirmed rows ordered by UpdatedAt. // The number of returned rows is limited to the given limit. - GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) + GetUnconfirmedRows(ctx context.Context, limit uint) ([]*Row, error) } func (r Row) Clone() *Row { diff --git a/core/services/s4/postgres_orm.go b/core/services/s4/postgres_orm.go index 1f92f2e1281..3d271e543d7 100644 --- a/core/services/s4/postgres_orm.go +++ b/core/services/s4/postgres_orm.go @@ -1,16 +1,15 @@ package s4 import ( + "context" "database/sql" "fmt" "time" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" - - "github.com/jmoiron/sqlx" "github.com/pkg/errors" + + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" ) const ( @@ -19,28 +18,27 @@ const ( ) type orm struct { - q pg.Q + ds sqlutil.DataSource tableName string namespace string } var _ ORM = (*orm)(nil) -func NewPostgresORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, tableName, namespace string) ORM { +func NewPostgresORM(ds sqlutil.DataSource, tableName, namespace string) ORM { return &orm{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, tableName: fmt.Sprintf(`"%s".%s`, s4PostgresSchema, tableName), namespace: namespace, } } -func (o orm) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*Row, error) { +func (o *orm) Get(ctx context.Context, address *big.Big, slotId uint) (*Row, error) { row := &Row{} - q := o.q.WithOpts(qopts...) stmt := fmt.Sprintf(`SELECT address, slot_id, version, expiration, confirmed, payload, signature FROM %s WHERE namespace=$1 AND address=$2 AND slot_id=$3;`, o.tableName) - if err := q.Get(row, stmt, o.namespace, address, slotId); err != nil { + if err := o.ds.GetContext(ctx, row, stmt, o.namespace, address, slotId); err != nil { if errors.Is(err, sql.ErrNoRows) { err = ErrNotFound } @@ -49,9 +47,7 @@ WHERE namespace=$1 AND address=$2 AND slot_id=$3;`, o.tableName) return row, nil } -func (o orm) Update(row *Row, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) - +func (o *orm) Update(ctx context.Context, row *Row) error { // This query inserts or updates a row, depending on whether the version is higher than the existing one. // We only allow the same version when the row is confirmed. // We never transition back from unconfirmed to confirmed state. @@ -67,31 +63,28 @@ updated_at = NOW() WHERE (t.version < EXCLUDED.version) OR (t.version <= EXCLUDED.version AND EXCLUDED.confirmed IS TRUE) RETURNING id;`, o.tableName) var id uint64 - err := q.Get(&id, stmt, o.namespace, row.Address, row.SlotId, row.Version, row.Expiration, row.Confirmed, row.Payload, row.Signature) + err := o.ds.GetContext(ctx, &id, stmt, o.namespace, row.Address, row.SlotId, row.Version, row.Expiration, row.Confirmed, row.Payload, row.Signature) if errors.Is(err, sql.ErrNoRows) { return ErrVersionTooLow } return err } -func (o orm) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (int64, error) { - q := o.q.WithOpts(qopts...) - +func (o *orm) DeleteExpired(ctx context.Context, limit uint, utcNow time.Time) (int64, error) { with := fmt.Sprintf(`WITH rows AS (SELECT id FROM %s WHERE namespace = $1 AND expiration < $2 LIMIT $3)`, o.tableName) stmt := fmt.Sprintf(`%s DELETE FROM %s WHERE id IN (SELECT id FROM rows);`, with, o.tableName) - result, err := q.Exec(stmt, o.namespace, utcNow.UnixMilli(), limit) + result, err := o.ds.ExecContext(ctx, stmt, o.namespace, utcNow.UnixMilli(), limit) if err != nil { return 0, err } return result.RowsAffected() } -func (o orm) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*SnapshotRow, error) { - q := o.q.WithOpts(qopts...) +func (o *orm) GetSnapshot(ctx context.Context, addressRange *AddressRange) ([]*SnapshotRow, error) { rows := make([]*SnapshotRow, 0) stmt := fmt.Sprintf(`SELECT address, slot_id, version, expiration, confirmed, octet_length(payload) AS payload_size FROM %s WHERE namespace = $1 AND address >= $2 AND address <= $3;`, o.tableName) - if err := q.Select(&rows, stmt, o.namespace, addressRange.MinAddress, addressRange.MaxAddress); err != nil { + if err := o.ds.SelectContext(ctx, &rows, stmt, o.namespace, addressRange.MinAddress, addressRange.MaxAddress); err != nil { if !errors.Is(err, sql.ErrNoRows) { return nil, err } @@ -99,13 +92,12 @@ func (o orm) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*Snaps return rows, nil } -func (o orm) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) { - q := o.q.WithOpts(qopts...) +func (o *orm) GetUnconfirmedRows(ctx context.Context, limit uint) ([]*Row, error) { rows := make([]*Row, 0) stmt := fmt.Sprintf(`SELECT address, slot_id, version, expiration, confirmed, payload, signature FROM %s WHERE namespace = $1 AND confirmed IS FALSE ORDER BY updated_at LIMIT $2;`, o.tableName) - if err := q.Select(&rows, stmt, o.namespace, limit); err != nil { + if err := o.ds.SelectContext(ctx, &rows, stmt, o.namespace, limit); err != nil { if !errors.Is(err, sql.ErrNoRows) { return nil, err } diff --git a/core/services/s4/postgres_orm_test.go b/core/services/s4/postgres_orm_test.go index d26f082ce5b..660002a2e3b 100644 --- a/core/services/s4/postgres_orm_test.go +++ b/core/services/s4/postgres_orm_test.go @@ -10,7 +10,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" - "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/s4" "github.com/stretchr/testify/assert" @@ -20,8 +19,7 @@ func setupORM(t *testing.T, namespace string) s4.ORM { t.Helper() db := pgtest.NewSqlxDB(t) - lggr := logger.TestLogger(t) - orm := s4.NewPostgresORM(db, lggr, pgtest.NewQConfig(true), s4.SharedTableName, namespace) + orm := s4.NewPostgresORM(db, s4.SharedTableName, namespace) t.Cleanup(func() { assert.NoError(t, db.Close()) @@ -59,64 +57,67 @@ func TestNewPostgresOrm(t *testing.T) { func TestPostgresORM_UpdateAndGet(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t, "test") rows := generateTestRows(t, 10) for _, row := range rows { - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) row.Version++ - err = orm.Update(row) + err = orm.Update(ctx, row) assert.NoError(t, err) - err = orm.Update(row) + err = orm.Update(ctx, row) if !row.Confirmed { assert.ErrorIs(t, err, s4.ErrVersionTooLow) } } for _, row := range rows { - gotRow, err := orm.Get(row.Address, row.SlotId) + gotRow, err := orm.Get(ctx, row.Address, row.SlotId) assert.NoError(t, err) assert.Equal(t, row, gotRow) } rows = generateTestRows(t, 1) - _, err := orm.Get(rows[0].Address, rows[0].SlotId) + _, err := orm.Get(ctx, rows[0].Address, rows[0].SlotId) assert.ErrorIs(t, err, s4.ErrNotFound) } func TestPostgresORM_UpdateSimpleFlow(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t, "test") row := generateTestRows(t, 1)[0] // user sends a new version - assert.NoError(t, orm.Update(row)) + assert.NoError(t, orm.Update(ctx, row)) // OCR round confirms it row.Confirmed = true - assert.NoError(t, orm.Update(row)) + assert.NoError(t, orm.Update(ctx, row)) // user sends a higher version (unconfirmed) row.Version++ row.Confirmed = false - assert.NoError(t, orm.Update(row)) + assert.NoError(t, orm.Update(ctx, row)) // and again, before OCR has a chance to confirm row.Version++ - assert.NoError(t, orm.Update(row)) + assert.NoError(t, orm.Update(ctx, row)) // user tries to send a lower version row.Version-- - assert.Error(t, orm.Update(row)) + assert.Error(t, orm.Update(ctx, row)) } func TestPostgresORM_DeleteExpired(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t, "test") @@ -125,17 +126,17 @@ func TestPostgresORM_DeleteExpired(t *testing.T) { rows := generateTestRows(t, total) for _, row := range rows { - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) } - deleted, err := orm.DeleteExpired(expired, time.Now().Add(2*time.Hour).UTC()) + deleted, err := orm.DeleteExpired(ctx, expired, time.Now().Add(2*time.Hour).UTC()) assert.NoError(t, err) assert.Equal(t, int64(expired), deleted) count := 0 for _, row := range rows { - _, err := orm.Get(row.Address, row.SlotId) + _, err := orm.Get(ctx, row.Address, row.SlotId) if !errors.Is(err, s4.ErrNotFound) { count++ } @@ -149,21 +150,23 @@ func TestPostgresORM_GetSnapshot(t *testing.T) { orm := setupORM(t, "test") t.Run("no rows", func(t *testing.T) { - rows, err := orm.GetSnapshot(s4.NewFullAddressRange()) + ctx := testutils.Context(t) + rows, err := orm.GetSnapshot(ctx, s4.NewFullAddressRange()) assert.NoError(t, err) assert.Empty(t, rows) }) t.Run("with rows", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestRows(t, 100) for _, row := range rows { - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) } t.Run("full range", func(t *testing.T) { - snapshot, err := orm.GetSnapshot(s4.NewFullAddressRange()) + snapshot, err := orm.GetSnapshot(testutils.Context(t), s4.NewFullAddressRange()) assert.NoError(t, err) assert.Equal(t, len(rows), len(snapshot)) @@ -188,7 +191,7 @@ func TestPostgresORM_GetSnapshot(t *testing.T) { t.Run("half range", func(t *testing.T) { ar, err := s4.NewInitialAddressRangeForIntervals(2) assert.NoError(t, err) - snapshot, err := orm.GetSnapshot(ar) + snapshot, err := orm.GetSnapshot(testutils.Context(t), ar) assert.NoError(t, err) for _, sr := range snapshot { assert.True(t, ar.Contains(sr.Address)) @@ -203,21 +206,23 @@ func TestPostgresORM_GetUnconfirmedRows(t *testing.T) { orm := setupORM(t, "test") t.Run("no rows", func(t *testing.T) { - rows, err := orm.GetUnconfirmedRows(5) + ctx := testutils.Context(t) + rows, err := orm.GetUnconfirmedRows(ctx, 5) assert.NoError(t, err) assert.Empty(t, rows) }) t.Run("with rows", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestRows(t, 10) for _, row := range rows { - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) time.Sleep(testutils.TestInterval / 10) } - gotRows, err := orm.GetUnconfirmedRows(5) + gotRows, err := orm.GetUnconfirmedRows(ctx, 5) assert.NoError(t, err) assert.Len(t, gotRows, 5) @@ -229,6 +234,7 @@ func TestPostgresORM_GetUnconfirmedRows(t *testing.T) { func TestPostgresORM_Namespace(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) ormA := setupORM(t, "a") ormB := setupORM(t, "b") @@ -237,44 +243,45 @@ func TestPostgresORM_Namespace(t *testing.T) { rowsA := generateTestRows(t, n) rowsB := generateTestRows(t, n) for i := 0; i < n; i++ { - err := ormA.Update(rowsA[i]) + err := ormA.Update(ctx, rowsA[i]) assert.NoError(t, err) - err = ormB.Update(rowsB[i]) + err = ormB.Update(ctx, rowsB[i]) assert.NoError(t, err) } - urowsA, err := ormA.GetUnconfirmedRows(n) + urowsA, err := ormA.GetUnconfirmedRows(ctx, n) assert.NoError(t, err) assert.Len(t, urowsA, n/2) - urowsB, err := ormB.GetUnconfirmedRows(n) + urowsB, err := ormB.GetUnconfirmedRows(ctx, n) assert.NoError(t, err) assert.Len(t, urowsB, n/2) - _, err = ormB.DeleteExpired(n, time.Now().UTC()) + _, err = ormB.DeleteExpired(ctx, n, time.Now().UTC()) assert.NoError(t, err) - snapshotA, err := ormA.GetSnapshot(s4.NewFullAddressRange()) + snapshotA, err := ormA.GetSnapshot(ctx, s4.NewFullAddressRange()) assert.NoError(t, err) assert.Len(t, snapshotA, n) } func TestPostgresORM_BigIntVersion(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t, "test") row := generateTestRows(t, 1)[0] row.Version = math.MaxUint64 - 10 - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) row.Version++ - err = orm.Update(row) + err = orm.Update(ctx, row) assert.NoError(t, err) - gotRow, err := orm.Get(row.Address, row.SlotId) + gotRow, err := orm.Get(ctx, row.Address, row.SlotId) assert.NoError(t, err) assert.Equal(t, row, gotRow) } diff --git a/core/services/s4/storage.go b/core/services/s4/storage.go index 02ba9c7bd50..1af14ec269f 100644 --- a/core/services/s4/storage.go +++ b/core/services/s4/storage.go @@ -5,11 +5,10 @@ import ( "github.com/jonboulle/clockwork" + "github.com/ethereum/go-ethereum/common" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" - - "github.com/ethereum/go-ethereum/common" ) // Constraints specifies the global storage constraints. @@ -95,7 +94,7 @@ func (s *storage) Get(ctx context.Context, key *Key) (*Record, *Metadata, error) } bigAddress := big.New(key.Address.Big()) - row, err := s.orm.Get(bigAddress, key.SlotId, pg.WithParentCtx(ctx)) + row, err := s.orm.Get(ctx, bigAddress, key.SlotId) if err != nil { return nil, nil, err } @@ -125,7 +124,7 @@ func (s *storage) List(ctx context.Context, address common.Address) ([]*Snapshot if err != nil { return nil, err } - return s.orm.GetSnapshot(sar, pg.WithParentCtx(ctx)) + return s.orm.GetSnapshot(ctx, sar) } func (s *storage) Put(ctx context.Context, key *Key, record *Record, signature []byte) error { @@ -161,5 +160,5 @@ func (s *storage) Put(ctx context.Context, key *Key, record *Record, signature [ copy(row.Payload, record.Payload) copy(row.Signature, signature) - return s.orm.Update(row, pg.WithParentCtx(ctx)) + return s.orm.Update(ctx, row) } diff --git a/core/services/s4/storage_test.go b/core/services/s4/storage_test.go index b643609f449..8deb23bb979 100644 --- a/core/services/s4/storage_test.go +++ b/core/services/s4/storage_test.go @@ -53,7 +53,7 @@ func TestStorage_Errors(t *testing.T) { SlotId: 1, Version: 0, } - ormMock.On("Get", big.New(key.Address.Big()), key.SlotId, mock.Anything).Return(nil, s4.ErrNotFound) + ormMock.On("Get", mock.Anything, big.New(key.Address.Big()), key.SlotId).Return(nil, s4.ErrNotFound) _, _, err := storage.Get(testutils.Context(t), key) assert.ErrorIs(t, err, s4.ErrNotFound) }) @@ -181,7 +181,7 @@ func TestStorage_PutAndGet(t *testing.T) { assert.NoError(t, err) ormMock.On("Update", mock.Anything, mock.Anything).Return(nil) - ormMock.On("Get", big.New(key.Address.Big()), uint(2), mock.Anything).Return(&s4.Row{ + ormMock.On("Get", mock.Anything, big.New(key.Address.Big()), uint(2)).Return(&s4.Row{ Address: big.New(key.Address.Big()), SlotId: key.SlotId, Version: key.Version, @@ -221,7 +221,7 @@ func TestStorage_List(t *testing.T) { addressRange, err := s4.NewSingleAddressRange(big.New(address.Big())) assert.NoError(t, err) - ormMock.On("GetSnapshot", addressRange, mock.Anything).Return(ormRows, nil) + ormMock.On("GetSnapshot", mock.Anything, addressRange).Return(ormRows, nil) rows, err := storage.List(testutils.Context(t), address) require.NoError(t, err)