diff --git a/core/chains/evm/txm/mocks/tx_store.go b/core/chains/evm/txm/mocks/tx_store.go index c164e43a6c5..cbbe48c5156 100644 --- a/core/chains/evm/txm/mocks/tx_store.go +++ b/core/chains/evm/txm/mocks/tx_store.go @@ -4,6 +4,7 @@ package mocks import ( context "context" + big "math/big" common "github.com/ethereum/go-ethereum/common" @@ -25,17 +26,17 @@ func (_m *TxStore) EXPECT() *TxStore_Expecter { return &TxStore_Expecter{mock: &_m.Mock} } -// AbandonPendingTransactions provides a mock function with given fields: _a0, _a1 -func (_m *TxStore) AbandonPendingTransactions(_a0 context.Context, _a1 common.Address) error { - ret := _m.Called(_a0, _a1) +// Abandon provides a mock function with given fields: _a0, _a1, _a2 +func (_m *TxStore) Abandon(_a0 context.Context, _a1 *big.Int, _a2 common.Address) error { + ret := _m.Called(_a0, _a1, _a2) if len(ret) == 0 { - panic("no return value specified for AbandonPendingTransactions") + panic("no return value specified for Abandon") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, common.Address) error); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int, common.Address) error); ok { + r0 = rf(_a0, _a1, _a2) } else { r0 = ret.Error(0) } @@ -43,31 +44,32 @@ func (_m *TxStore) AbandonPendingTransactions(_a0 context.Context, _a1 common.Ad return r0 } -// TxStore_AbandonPendingTransactions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AbandonPendingTransactions' -type TxStore_AbandonPendingTransactions_Call struct { +// TxStore_Abandon_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Abandon' +type TxStore_Abandon_Call struct { *mock.Call } -// AbandonPendingTransactions is a helper method to define mock.On call +// Abandon is a helper method to define mock.On call // - _a0 context.Context -// - _a1 common.Address -func (_e *TxStore_Expecter) AbandonPendingTransactions(_a0 interface{}, _a1 interface{}) *TxStore_AbandonPendingTransactions_Call { - return &TxStore_AbandonPendingTransactions_Call{Call: _e.mock.On("AbandonPendingTransactions", _a0, _a1)} +// - _a1 *big.Int +// - _a2 common.Address +func (_e *TxStore_Expecter) Abandon(_a0 interface{}, _a1 interface{}, _a2 interface{}) *TxStore_Abandon_Call { + return &TxStore_Abandon_Call{Call: _e.mock.On("Abandon", _a0, _a1, _a2)} } -func (_c *TxStore_AbandonPendingTransactions_Call) Run(run func(_a0 context.Context, _a1 common.Address)) *TxStore_AbandonPendingTransactions_Call { +func (_c *TxStore_Abandon_Call) Run(run func(_a0 context.Context, _a1 *big.Int, _a2 common.Address)) *TxStore_Abandon_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(common.Address)) + run(args[0].(context.Context), args[1].(*big.Int), args[2].(common.Address)) }) return _c } -func (_c *TxStore_AbandonPendingTransactions_Call) Return(_a0 error) *TxStore_AbandonPendingTransactions_Call { +func (_c *TxStore_Abandon_Call) Return(_a0 error) *TxStore_Abandon_Call { _c.Call.Return(_a0) return _c } -func (_c *TxStore_AbandonPendingTransactions_Call) RunAndReturn(run func(context.Context, common.Address) error) *TxStore_AbandonPendingTransactions_Call { +func (_c *TxStore_Abandon_Call) RunAndReturn(run func(context.Context, *big.Int, common.Address) error) *TxStore_Abandon_Call { _c.Call.Return(run) return _c } @@ -357,6 +359,64 @@ func (_c *TxStore_FetchUnconfirmedTransactionAtNonceWithCount_Call) RunAndReturn return _c } +// FindLatestNonce provides a mock function with given fields: _a0, _a1, _a2 +func (_m *TxStore) FindLatestNonce(_a0 context.Context, _a1 common.Address, _a2 *big.Int) (uint64, error) { + ret := _m.Called(_a0, _a1, _a2) + + if len(ret) == 0 { + panic("no return value specified for FindLatestNonce") + } + + var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int) (uint64, error)); ok { + return rf(_a0, _a1, _a2) + } + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int) uint64); ok { + r0 = rf(_a0, _a1, _a2) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(context.Context, common.Address, *big.Int) error); ok { + r1 = rf(_a0, _a1, _a2) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// TxStore_FindLatestNonce_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FindLatestNonce' +type TxStore_FindLatestNonce_Call struct { + *mock.Call +} + +// FindLatestNonce is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 common.Address +// - _a2 *big.Int +func (_e *TxStore_Expecter) FindLatestNonce(_a0 interface{}, _a1 interface{}, _a2 interface{}) *TxStore_FindLatestNonce_Call { + return &TxStore_FindLatestNonce_Call{Call: _e.mock.On("FindLatestNonce", _a0, _a1, _a2)} +} + +func (_c *TxStore_FindLatestNonce_Call) Run(run func(_a0 context.Context, _a1 common.Address, _a2 *big.Int)) *TxStore_FindLatestNonce_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(common.Address), args[2].(*big.Int)) + }) + return _c +} + +func (_c *TxStore_FindLatestNonce_Call) Return(_a0 uint64, _a1 error) *TxStore_FindLatestNonce_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *TxStore_FindLatestNonce_Call) RunAndReturn(run func(context.Context, common.Address, *big.Int) (uint64, error)) *TxStore_FindLatestNonce_Call { + _c.Call.Return(run) + return _c +} + // MarkConfirmedAndReorgedTransactions provides a mock function with given fields: _a0, _a1, _a2 func (_m *TxStore) MarkConfirmedAndReorgedTransactions(_a0 context.Context, _a1 uint64, _a2 common.Address) ([]*types.Transaction, []uint64, error) { ret := _m.Called(_a0, _a1, _a2) diff --git a/core/chains/evm/txm/orchestrator.go b/core/chains/evm/txm/orchestrator.go index 7cf607bf4df..a3f873526ff 100644 --- a/core/chains/evm/txm/orchestrator.go +++ b/core/chains/evm/txm/orchestrator.go @@ -7,6 +7,7 @@ import ( "fmt" "math" "math/big" + "sync" "github.com/ethereum/go-ethereum/common" "github.com/google/uuid" @@ -28,14 +29,17 @@ import ( ) type OrchestratorTxStore interface { + Abandon(context.Context, *big.Int, common.Address) error Add(addresses ...common.Address) error FetchUnconfirmedTransactionAtNonceWithCount(context.Context, uint64, common.Address) (*txmtypes.Transaction, int, error) FindTxWithIdempotencyKey(context.Context, string) (*txmtypes.Transaction, error) + Remove(addresses ...common.Address) error } type OrchestratorKeystore interface { CheckEnabled(ctx context.Context, address common.Address, chainID *big.Int) error EnabledAddressesForChain(ctx context.Context, chainID *big.Int) (addresses []common.Address, err error) + SubscribeToKeyChanges(ctx context.Context) (ch chan struct{}, unsub func()) } type OrchestratorAttemptBuilder[ @@ -52,14 +56,18 @@ type Orchestrator[ HEAD types.Head[BLOCK_HASH], ] struct { services.StateMachine - lggr logger.SugaredLogger - chainID *big.Int - txm *Txm - txStore OrchestratorTxStore - fwdMgr *forwarders.FwdMgr - keystore OrchestratorKeystore - attemptBuilder OrchestratorAttemptBuilder[BLOCK_HASH, HEAD] - resumeCallback txmgr.ResumeCallback + lggr logger.SugaredLogger + chainID *big.Int + txm *Txm + txStore OrchestratorTxStore + fwdMgr *forwarders.FwdMgr + keystore OrchestratorKeystore + attemptBuilder OrchestratorAttemptBuilder[BLOCK_HASH, HEAD] + resumeCallback txmgr.ResumeCallback + enabledAddresses map[common.Address]bool + chReset chan *common.Address + chStop services.StopChan + wg *sync.WaitGroup } func NewTxmOrchestrator[BLOCK_HASH types.Hashable, HEAD types.Head[BLOCK_HASH]]( @@ -72,13 +80,17 @@ func NewTxmOrchestrator[BLOCK_HASH types.Hashable, HEAD types.Head[BLOCK_HASH]]( attemptBuilder OrchestratorAttemptBuilder[BLOCK_HASH, HEAD], ) *Orchestrator[BLOCK_HASH, HEAD] { return &Orchestrator[BLOCK_HASH, HEAD]{ - lggr: logger.Sugared(logger.Named(lggr, "Orchestrator")), - chainID: chainID, - txm: txm, - txStore: txStore, - keystore: keystore, - attemptBuilder: attemptBuilder, - fwdMgr: fwdMgr, + lggr: logger.Sugared(logger.Named(lggr, "Orchestrator")), + chainID: chainID, + txm: txm, + txStore: txStore, + keystore: keystore, + attemptBuilder: attemptBuilder, + fwdMgr: fwdMgr, + enabledAddresses: make(map[common.Address]bool), + chReset: make(chan *common.Address), + chStop: make(chan struct{}), + wg: new(sync.WaitGroup), } } @@ -93,6 +105,7 @@ func (o *Orchestrator[BLOCK_HASH, HEAD]) Start(ctx context.Context) error { return err } for _, address := range addresses { + o.enabledAddresses[address] = true err := o.txStore.Add(address) if err != nil { return err @@ -106,12 +119,19 @@ func (o *Orchestrator[BLOCK_HASH, HEAD]) Start(ctx context.Context) error { return fmt.Errorf("Orchestrator: ForwarderManager failed to start: %w", err) } } + + o.wg.Add(1) + go o.runLoop() return nil }) } func (o *Orchestrator[BLOCK_HASH, HEAD]) Close() (merr error) { return o.StopOnce("Orchestrator", func() error { + close(o.chReset) + close(o.chStop) + o.wg.Done() + if o.fwdMgr != nil { if err := o.fwdMgr.Close(); err != nil { merr = errors.Join(merr, fmt.Errorf("Orchestrator failed to stop ForwarderManager: %w", err)) @@ -127,6 +147,59 @@ func (o *Orchestrator[BLOCK_HASH, HEAD]) Close() (merr error) { }) } +func (o *Orchestrator[BLOCK_HASH, HEAD]) runLoop() { + defer o.wg.Done() + ctx, cancel := o.chStop.NewCtx() + defer cancel() + + keysChanged, unsub := o.keystore.SubscribeToKeyChanges(ctx) + defer unsub() + + for { + select { + case <-o.chStop: + return + case <-keysChanged: + updatedEnabledAddresses, err := o.keystore.EnabledAddressesForChain(ctx, o.chainID) + if err != nil { + o.lggr.Critical("Failed to reload key states after key change") + o.SvcErrBuffer.Append(err) + continue + } + o.lggr.Debugw("Keys changed, reloading", "enabledAddresses", updatedEnabledAddresses) + + // this will help with lookup + updatedEnabledAddressesMap := make(map[common.Address]bool) + for _, updatedAddress := range updatedEnabledAddresses { + updatedEnabledAddressesMap[updatedAddress] = true + if _, exists := o.enabledAddresses[updatedAddress]; !exists { + if err := o.txStore.Add(updatedAddress); err != nil { + o.lggr.Errorw("Failed to add address to InMemoryStore", "address", updatedAddress, "err", err) + continue + } + } + } + + for oldEnabledAddress := range o.enabledAddresses { + if !updatedEnabledAddressesMap[oldEnabledAddress] { + if err := o.txStore.Remove(oldEnabledAddress); err != nil { + o.lggr.Errorw("Failed to remove address from InMemoryStore", "address", oldEnabledAddress, "err", err) + continue + } + } + } + o.enabledAddresses = updatedEnabledAddressesMap + if err := o.txm.Reset(ctx, nil); err != nil { + o.lggr.Errorw("Failed to Reset TXM", "err", err) + } + case abandonAddress := <-o.chReset: + if err := o.txm.Reset(ctx, abandonAddress); err != nil { + o.lggr.Errorw("Failed to Reset TXM", "err", err) + } + } + } +} + func (o *Orchestrator[BLOCK_HASH, HEAD]) Trigger(addr common.Address) { o.txm.Trigger(addr) } @@ -143,16 +216,18 @@ func (o *Orchestrator[BLOCK_HASH, HEAD]) RegisterResumeCallback(fn txmgr.ResumeC o.resumeCallback = fn } -func (o *Orchestrator[BLOCK_HASH, HEAD]) Reset(addr common.Address, abandon bool) error { +func (o *Orchestrator[BLOCK_HASH, HEAD]) Reset(addr common.Address, abandon bool) (err error) { ok := o.IfStarted(func() { - if err := o.txm.Abandon(addr); err != nil { - o.lggr.Error(err) + if abandon { + o.chReset <- &addr + } else { + o.chReset <- nil } }) if !ok { return errors.New("Orchestrator not started yet") } - return nil + return err } func (o *Orchestrator[BLOCK_HASH, HEAD]) OnNewLongestChain(ctx context.Context, head HEAD) { diff --git a/core/chains/evm/txm/storage/inmemory_store.go b/core/chains/evm/txm/storage/inmemory_store.go index 918fa5ba740..9da36731a9d 100644 --- a/core/chains/evm/txm/storage/inmemory_store.go +++ b/core/chains/evm/txm/storage/inmemory_store.go @@ -52,7 +52,7 @@ func NewInMemoryStore(lggr logger.Logger, address common.Address, chainID *big.I } } -func (m *InMemoryStore) AbandonPendingTransactions() { +func (m *InMemoryStore) Abandon() { // TODO: append existing fatal transactions and cap the size m.Lock() defer m.Unlock() @@ -182,6 +182,24 @@ func (m *InMemoryStore) FetchUnconfirmedTransactionAtNonceWithCount(latestNonce return } +func (m *InMemoryStore) FindLatestNonce() (maxNonce uint64) { + m.RLock() + defer m.RUnlock() + + for _, tx := range m.UnconfirmedTransactions { + if tx.Nonce != nil { + maxNonce = max(*tx.Nonce, maxNonce) + } + } + + for _, tx := range m.ConfirmedTransactions { + if tx.Nonce != nil { + maxNonce = max(*tx.Nonce, maxNonce) + } + } + return +} + func (m *InMemoryStore) MarkConfirmedAndReorgedTransactions(latestNonce uint64) ([]*types.Transaction, []uint64, error) { m.Lock() defer m.Unlock() diff --git a/core/chains/evm/txm/storage/inmemory_store_manager.go b/core/chains/evm/txm/storage/inmemory_store_manager.go index 86abaf4b7cc..ffa39a1c267 100644 --- a/core/chains/evm/txm/storage/inmemory_store_manager.go +++ b/core/chains/evm/txm/storage/inmemory_store_manager.go @@ -28,9 +28,9 @@ func NewInMemoryStoreManager(lggr logger.Logger, chainID *big.Int) *InMemoryStor InMemoryStoreMap: inMemoryStoreMap} } -func (m *InMemoryStoreManager) AbandonPendingTransactions(_ context.Context, fromAddress common.Address) error { +func (m *InMemoryStoreManager) Abandon(_ context.Context, _ *big.Int, fromAddress common.Address) error { if store, exists := m.InMemoryStoreMap[fromAddress]; exists { - store.AbandonPendingTransactions() + store.Abandon() return nil } return fmt.Errorf(StoreNotFoundForAddress, fromAddress) @@ -46,6 +46,16 @@ func (m *InMemoryStoreManager) Add(addresses ...common.Address) (err error) { return } +func (m *InMemoryStoreManager) Remove(addresses ...common.Address) (err error) { + for _, address := range addresses { + if _, exists := m.InMemoryStoreMap[address]; !exists { + err = multierr.Append(err, fmt.Errorf("address %v doesn't exist in store manager", address)) + } + delete(m.InMemoryStoreMap, address) + } + return +} + func (m *InMemoryStoreManager) AppendAttemptToTransaction(_ context.Context, txNonce uint64, fromAddress common.Address, attempt *types.Attempt) error { if store, exists := m.InMemoryStoreMap[fromAddress]; exists { return store.AppendAttemptToTransaction(txNonce, attempt) @@ -82,6 +92,14 @@ func (m *InMemoryStoreManager) FetchUnconfirmedTransactionAtNonceWithCount(_ con return nil, 0, fmt.Errorf(StoreNotFoundForAddress, fromAddress) } +func (m *InMemoryStoreManager) FindLatestNonce(ctx context.Context, fromAddress common.Address, chainID *big.Int) (maxNonce uint64, err error) { + if store, exists := m.InMemoryStoreMap[fromAddress]; exists { + maxNonce = store.FindLatestNonce() + return + } + return 0, fmt.Errorf(StoreNotFoundForAddress, fromAddress) +} + func (m *InMemoryStoreManager) MarkConfirmedAndReorgedTransactions(_ context.Context, nonce uint64, fromAddress common.Address) (confirmedTxs []*types.Transaction, unconfirmedTxIDs []uint64, err error) { if store, exists := m.InMemoryStoreMap[fromAddress]; exists { confirmedTxs, unconfirmedTxIDs, err = store.MarkConfirmedAndReorgedTransactions(nonce) diff --git a/core/chains/evm/txm/storage/inmemory_store_test.go b/core/chains/evm/txm/storage/inmemory_store_test.go index 226cf284bba..163b8d6b514 100644 --- a/core/chains/evm/txm/storage/inmemory_store_test.go +++ b/core/chains/evm/txm/storage/inmemory_store_test.go @@ -34,7 +34,7 @@ func TestAbandonPendingTransactions(t *testing.T) { tx4, err := insertUnconfirmedTransaction(m, 4) require.NoError(t, err) - m.AbandonPendingTransactions() + m.Abandon() assert.Equal(t, txmgr.TxFatalError, tx1.State) assert.Equal(t, txmgr.TxFatalError, tx2.State) @@ -54,7 +54,7 @@ func TestAbandonPendingTransactions(t *testing.T) { tx4, err := insertConfirmedTransaction(m, 4) require.NoError(t, err) - m.AbandonPendingTransactions() + m.Abandon() assert.Equal(t, txmgr.TxFatalError, tx1.State) assert.Equal(t, txmgr.TxFatalError, tx2.State) diff --git a/core/chains/evm/txm/txm.go b/core/chains/evm/txm/txm.go index 683e8a45010..a0cdb825853 100644 --- a/core/chains/evm/txm/txm.go +++ b/core/chains/evm/txm/txm.go @@ -2,6 +2,7 @@ package txm import ( "context" + "errors" "fmt" "math/big" "sync" @@ -33,11 +34,12 @@ type Client interface { } type TxStore interface { - AbandonPendingTransactions(context.Context, common.Address) error + Abandon(context.Context, *big.Int, common.Address) error AppendAttemptToTransaction(context.Context, uint64, common.Address, *types.Attempt) error CreateEmptyUnconfirmedTransaction(context.Context, common.Address, uint64, uint64) (*types.Transaction, error) CreateTransaction(context.Context, *types.TxRequest) (*types.Transaction, error) FetchUnconfirmedTransactionAtNonceWithCount(context.Context, uint64, common.Address) (*types.Transaction, int, error) + FindLatestNonce(context.Context, common.Address, *big.Int) (uint64, error) MarkConfirmedAndReorgedTransactions(context.Context, uint64, common.Address) ([]*types.Transaction, []uint64, error) MarkUnconfirmedTransactionPurgeable(context.Context, uint64, common.Address) error UpdateTransactionBroadcast(context.Context, uint64, uint64, common.Hash, common.Address) error @@ -137,13 +139,39 @@ func (t *Txm) startAddress(address common.Address) { go t.backfillLoop(address) } +func (t *Txm) Reset(ctx context.Context, address *common.Address) (err error) { + if t.IfStarted(func() { + close(t.stopCh) + t.wg.Wait() + + if address != nil { + if err = t.txStore.Abandon(ctx, t.chainID, *address); err != nil { + return + } + } + + var addresses []common.Address + addresses, err = t.keystore.EnabledAddressesForChain(ctx, t.chainID) + if err != nil { + return + } + for _, address := range addresses { + t.startAddress(address) + } + }) { + return errors.New("Txm unstarted") + } + return +} + func (t *Txm) initializeNonce(ctx context.Context, address common.Address) { ctxWithTimeout, cancel := context.WithTimeout(ctx, pendingNonceDefaultTimeout) defer cancel() for { - pendingNonce, err := t.client.PendingNonceAt(ctxWithTimeout, address) - if err != nil { - t.lggr.Errorw("Error when fetching initial nonce", "address", address, "err", err) + pendingNonce, rErr := t.client.PendingNonceAt(ctxWithTimeout, address) + storedNonce, sErr := t.txStore.FindLatestNonce(ctxWithTimeout, address, t.chainID) + if rErr != nil || sErr != nil { + t.lggr.Errorw("Error when fetching initial nonce", "address", address, "requestError", rErr, "storageError", sErr) select { case <-time.After(pendingNonceRecheckInterval): case <-ctx.Done(): @@ -152,7 +180,7 @@ func (t *Txm) initializeNonce(ctx context.Context, address common.Address) { } continue } - t.setNonce(address, pendingNonce) + t.setNonce(address, max(pendingNonce, storedNonce)) t.lggr.Debugf("Set initial nonce for address: %v to %d", address, pendingNonce) return } @@ -190,12 +218,6 @@ func (t *Txm) Trigger(address common.Address) { } } -func (t *Txm) Abandon(address common.Address) error { - // TODO: restart txm - t.lggr.Infof("Dropping unstarted and unconfirmed transactions for address: %v", address) - return t.txStore.AbandonPendingTransactions(context.TODO(), address) -} - func (t *Txm) getNonce(address common.Address) uint64 { t.nonceMapMu.Lock() defer t.nonceMapMu.Unlock()