diff --git a/relayer/pkg/chainlink/txm/nonce.go b/relayer/pkg/chainlink/txm/nonce.go deleted file mode 100644 index bd7266bff..000000000 --- a/relayer/pkg/chainlink/txm/nonce.go +++ /dev/null @@ -1,135 +0,0 @@ -package txm - -import ( - "context" - "fmt" - "sync" - - "github.com/NethermindEth/juno/core/felt" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/services" - "github.com/smartcontractkit/chainlink-common/pkg/utils" -) - -//go:generate mockery --name NonceManagerClient --output ./mocks/ --case=underscore --filename nonce_manager_client.go - -type NonceManagerClient interface { - AccountNonce(context.Context, *felt.Felt) (*felt.Felt, error) -} - -type NonceManager interface { - services.Service - Register(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error - NextSequence(address *felt.Felt) (*felt.Felt, error) - IncrementNextSequence(address *felt.Felt, currentNonce *felt.Felt) error - // Resets local account nonce to on-chain account nonce - Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error -} - -var _ NonceManager = (*nonceManager)(nil) - -type nonceManager struct { - starter utils.StartStopOnce - lggr logger.Logger - - n map[string]*felt.Felt // map public key to nonce - lock sync.RWMutex -} - -func NewNonceManager(lggr logger.Logger) *nonceManager { - return &nonceManager{ - lggr: logger.Named(lggr, "NonceManager"), - n: map[string]*felt.Felt{}, - } -} - -func (nm *nonceManager) Start(ctx context.Context) error { - return nm.starter.StartOnce(nm.Name(), func() error { return nil }) -} - -func (nm *nonceManager) Ready() error { - return nm.starter.Ready() -} - -func (nm *nonceManager) Name() string { - return nm.lggr.Name() -} - -func (nm *nonceManager) Close() error { - return nm.starter.StopOnce(nm.Name(), func() error { return nil }) -} - -func (nm *nonceManager) HealthReport() map[string]error { - return map[string]error{nm.Name(): nm.starter.Healthy()} -} - -func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error { - nm.lock.Lock() - defer nm.lock.Unlock() - - if err := nm.validate(publicKey); err != nil { - return err - } - - n, err := client.AccountNonce(ctx, address) - if err != nil { - return err - } - - nm.n[publicKey.String()] = n - - return nil -} - -// Register is used because we cannot pre-fetch nonces. the pubkey is known before hand, but the account address is not known until a job is started and sends a tx -func (nm *nonceManager) Register(ctx context.Context, addr *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error { - nm.lock.Lock() - defer nm.lock.Unlock() - - _, exists := nm.n[publicKey.String()] - if !exists { - n, err := client.AccountNonce(ctx, addr) - if err != nil { - return err - } - nm.n[publicKey.String()] = n - } - - return nil -} - -func (nm *nonceManager) NextSequence(publicKey *felt.Felt) (*felt.Felt, error) { - nm.lock.RLock() - defer nm.lock.RUnlock() - - if err := nm.validate(publicKey); err != nil { - return nil, err - } - - return nm.n[publicKey.String()], nil -} - -func (nm *nonceManager) IncrementNextSequence(publicKey *felt.Felt, currentNonce *felt.Felt) error { - nm.lock.Lock() - defer nm.lock.Unlock() - - if err := nm.validate(publicKey); err != nil { - return err - } - - n := nm.n[publicKey.String()] - if n.Cmp(currentNonce) != 0 { - return fmt.Errorf("mismatched nonce for %s: %s (expected) != %s (got)", publicKey, n, currentNonce) - } - one := new(felt.Felt).SetUint64(1) - nm.n[publicKey.String()] = new(felt.Felt).Add(n, one) - return nil -} - -func (nm *nonceManager) validate(publicKey *felt.Felt) error { - if _, exists := nm.n[publicKey.String()]; !exists { - return fmt.Errorf("nonce tracking does not exist for key: %s", publicKey.String()) - } - return nil -} diff --git a/relayer/pkg/chainlink/txm/nonce_test.go b/relayer/pkg/chainlink/txm/nonce_test.go deleted file mode 100644 index 9748b81c1..000000000 --- a/relayer/pkg/chainlink/txm/nonce_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package txm_test - -import ( - "fmt" - "math/big" - "testing" - - "github.com/NethermindEth/juno/core/felt" - starknetutils "github.com/NethermindEth/starknet.go/utils" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" - "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/txm" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - - "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/txm/mocks" -) - -func newTestNonceManager(t *testing.T, initNonce *felt.Felt) (txm.NonceManager, *felt.Felt, func()) { - // setup - c := mocks.NewNonceManagerClient(t) - lggr := logger.Test(t) - nm := txm.NewNonceManager(lggr) - - // mock returns - keyHash, err := starknetutils.HexToFelt("0x0") - require.NoError(t, err) - c.On("AccountNonce", mock.Anything, mock.Anything).Return(initNonce, nil).Once() - - require.NoError(t, nm.Start(tests.Context(t))) - require.NoError(t, nm.Register(tests.Context(t), keyHash, keyHash, c)) - - return nm, keyHash, func() { require.NoError(t, nm.Close()) } -} - -func TestNonceManager_NextSequence(t *testing.T) { - t.Parallel() - - initNonce := new(felt.Felt).SetUint64(10) - nm, k, stop := newTestNonceManager(t, initNonce) - defer stop() - - // get with proper inputs - nonce, err := nm.NextSequence(k) - require.NoError(t, err) - assert.Equal(t, initNonce, nonce) - - // should fail with invalid address - randAddr1 := starknetutils.BigIntToFelt(big.NewInt(1)) - _, err = nm.NextSequence(randAddr1) - require.Error(t, err) - assert.Contains(t, err.Error(), fmt.Sprintf("nonce tracking does not exist for key: %s", randAddr1.String())) -} - -func TestNonceManager_IncrementNextSequence(t *testing.T) { - t.Parallel() - - initNonce := new(felt.Felt).SetUint64(10) - nm, k, stop := newTestNonceManager(t, initNonce) - defer stop() - - one := new(felt.Felt).SetUint64(1) - initMinusOne := new(felt.Felt).Sub(initNonce, one) - initPlusOne := new(felt.Felt).Add(initNonce, one) - - // should fail if nonce is lower then expected - err := nm.IncrementNextSequence(k, initMinusOne) - require.Error(t, err) - assert.Contains(t, err.Error(), fmt.Sprintf("mismatched nonce for %s: %s (expected) != %s (got)", k, initNonce, initMinusOne)) - - // increment with proper inputs - err = nm.IncrementNextSequence(k, initNonce) - require.NoError(t, err) - next, err := nm.NextSequence(k) - require.NoError(t, err) - assert.Equal(t, initPlusOne, next) - - // should fail with invalid address - randAddr1 := starknetutils.BigIntToFelt(big.NewInt(1)) - err = nm.IncrementNextSequence(randAddr1, initPlusOne) - require.Error(t, err) - assert.Contains(t, err.Error(), fmt.Sprintf("nonce tracking does not exist for key: %s", randAddr1.String())) - - // verify it didnt get changed by any erroring calls - next, err = nm.NextSequence(k) - require.NoError(t, err) - assert.Equal(t, initPlusOne, next) -} diff --git a/relayer/pkg/chainlink/txm/txm.go b/relayer/pkg/chainlink/txm/txm.go index b4e431d8d..7a2a45dd6 100644 --- a/relayer/pkg/chainlink/txm/txm.go +++ b/relayer/pkg/chainlink/txm/txm.go @@ -13,7 +13,6 @@ import ( starknetaccount "github.com/NethermindEth/starknet.go/account" starknetrpc "github.com/NethermindEth/starknet.go/rpc" starknetutils "github.com/NethermindEth/starknet.go/utils" - "golang.org/x/exp/maps" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/loop" @@ -52,7 +51,6 @@ type starktxm struct { queue chan Tx ks KeystoreAdapter cfg Config - nonce NonceManager client *utils.LazyLoad[*starknet.Client] feederClient *utils.LazyLoad[*starknet.FeederClient] @@ -71,7 +69,6 @@ func New(lggr logger.Logger, keystore loop.Keystore, cfg Config, getClient func( cfg: cfg, accountStore: NewAccountStore(), } - txm.nonce = NewNonceManager(txm.lggr) return txm, nil } @@ -82,11 +79,7 @@ func (txm *starktxm) Name() string { func (txm *starktxm) Start(ctx context.Context) error { return txm.starter.StartOnce("starktxm", func() error { - if err := txm.nonce.Start(ctx); err != nil { - return err - } - - txm.done.Add(2) // waitgroup: tx sender + confirmer + txm.done.Add(2) // waitgroup: broadcast loop and confirm loop go txm.broadcastLoop() go txm.confirmLoop() @@ -183,6 +176,20 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun txm.client.Reset() return txhash, fmt.Errorf("broadcast: failed to fetch client: %+w", err) } + + txStore := txm.accountStore.GetTxStore(accountAddress) + if txStore == nil { + initialNonce, accountNonceErr := client.AccountNonce(ctx, accountAddress) + if accountNonceErr != nil { + return txhash, fmt.Errorf("failed to check account nonce during TxStore creation: %+w", accountNonceErr) + } + newTxStore, createErr := txm.accountStore.CreateTxStore(accountAddress, initialNonce) + if createErr != nil { + return txhash, fmt.Errorf("failed to create TxStore: %+w", createErr) + } + txStore = newTxStore + } + // create new account cairoVersion := 2 account, err := starknetaccount.NewAccount(client.Provider, accountAddress, publicKey.String(), txm.ks, cairoVersion) @@ -239,10 +246,8 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun txm.lggr.Infow("Set resource bounds", "L1MaxAmount", tx.ResourceBounds.L1Gas.MaxAmount, "L1MaxPricePerUnit", tx.ResourceBounds.L1Gas.MaxPricePerUnit) - nonce, err := txm.nonce.NextSequence(publicKey) - if err != nil { - return txhash, fmt.Errorf("failed to get nonce: %+w", err) - } + nonce := txStore.GetNextNonce() + tx.Nonce = nonce // Re-sign transaction now that we've determined MaxFee // TODO: SignInvokeTransaction for V3 is missing so we do it by hand @@ -279,11 +284,11 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun // update nonce if transaction is successful txhash = res.TransactionHash.String() - err = errors.Join( - txm.nonce.IncrementNextSequence(publicKey, nonce), - txm.accountStore.GetTxStore(accountAddress).Save(nonce, txhash, &call, publicKey), - ) - return txhash, err + err = txStore.AddUnconfirmed(nonce, txhash, call, publicKey) + if err != nil { + return txhash, fmt.Errorf("failed to add unconfirmed tx: %+w", err) + } + return txhash, nil } func (txm *starktxm) confirmLoop() { @@ -307,10 +312,16 @@ func (txm *starktxm) confirmLoop() { break } - hashes := txm.accountStore.GetAllUnconfirmed() - for addr := range hashes { - for i := range hashes[addr] { - hash := hashes[addr][i] + allUnconfirmedTxs := txm.accountStore.GetAllUnconfirmed() + for accountAddressStr, unconfirmedTxs := range allUnconfirmedTxs { + accountAddress, err := new(felt.Felt).SetString(accountAddressStr) + // this should never occur because the acccount address string key was created from the account address felt. + if err != nil { + txm.lggr.Errorw("could not recreate account address felt", "accountAddress", accountAddressStr) + continue + } + for _, unconfirmedTx := range unconfirmedTxs { + hash := unconfirmedTx.Hash f, err := starknetutils.HexToFelt(hash) if err != nil { txm.lggr.Errorw("invalid felt value", "hash", hash) @@ -332,8 +343,8 @@ func (txm *starktxm) confirmLoop() { // any finalityStatus other than received if finalityStatus == starknetrpc.TxnStatus_Accepted_On_L1 || finalityStatus == starknetrpc.TxnStatus_Accepted_On_L2 || finalityStatus == starknetrpc.TxnStatus_Rejected { txm.lggr.Debugw(fmt.Sprintf("tx confirmed: %s", finalityStatus), "hash", hash, "finalityStatus", finalityStatus) - if err := txm.accountStore.GetTxStore(addr).Confirm(hash); err != nil { - txm.lggr.Errorw("failed to confirm tx in TxStore", "hash", hash, "sender", addr, "error", err) + if err := txm.accountStore.GetTxStore(accountAddress).Confirm(unconfirmedTx.Nonce, hash); err != nil { + txm.lggr.Errorw("failed to confirm tx in TxStore", "hash", hash, "accountAddress", accountAddress, "error", err) } } @@ -402,17 +413,6 @@ func (txm *starktxm) Enqueue(accountAddress, publicKey *felt.Felt, tx starknetrp return fmt.Errorf("enqueue: failed to sign: %+w", err) } - client, err := txm.client.Get() - if err != nil { - txm.client.Reset() - return fmt.Errorf("broadcast: failed to fetch client: %+w", err) - } - - // register account for nonce manager - if err := txm.nonce.Register(context.TODO(), accountAddress, publicKey, client); err != nil { - return fmt.Errorf("failed to register nonce: %+w", err) - } - select { case txm.queue <- Tx{publicKey: publicKey, accountAddress: accountAddress, call: tx}: // TODO fix naming here default: @@ -423,9 +423,5 @@ func (txm *starktxm) Enqueue(accountAddress, publicKey *felt.Felt, tx starknetrp } func (txm *starktxm) InflightCount() (queue int, unconfirmed int) { - list := maps.Values(txm.accountStore.GetAllInflightCount()) - for _, count := range list { - unconfirmed += count - } - return len(txm.queue), unconfirmed + return len(txm.queue), txm.accountStore.GetTotalInflightCount() } diff --git a/relayer/pkg/chainlink/txm/txstore.go b/relayer/pkg/chainlink/txm/txstore.go index a77f852b1..8cd6893f5 100644 --- a/relayer/pkg/chainlink/txm/txstore.go +++ b/relayer/pkg/chainlink/txm/txstore.go @@ -1,7 +1,6 @@ package txm import ( - "errors" "fmt" "sort" "sync" @@ -11,192 +10,168 @@ import ( "golang.org/x/exp/maps" ) +type UnconfirmedTx struct { + Hash string + PublicKey *felt.Felt + Nonce *felt.Felt + Call starknetrpc.FunctionCall +} + // TxStore tracks broadcast & unconfirmed txs per account address per chain id type TxStore struct { - lock sync.RWMutex - nonceToHash map[felt.Felt]string // map nonce to txhash - hashToNonce map[string]felt.Felt // map hash to nonce - hashToCall map[string]*starknetrpc.FunctionCall - hashToKey map[string]felt.Felt + lock sync.RWMutex + + nextNonce *felt.Felt + unconfirmedNonces map[string]*UnconfirmedTx } -func NewTxStore() *TxStore { +func NewTxStore(initialNonce *felt.Felt) *TxStore { return &TxStore{ - nonceToHash: map[felt.Felt]string{}, - hashToNonce: map[string]felt.Felt{}, - hashToCall: map[string]*starknetrpc.FunctionCall{}, - hashToKey: map[string]felt.Felt{}, + nextNonce: new(felt.Felt).Set(initialNonce), + unconfirmedNonces: map[string]*UnconfirmedTx{}, } } -func deepCopy(nonce *felt.Felt, call *starknetrpc.FunctionCall, publicKey *felt.Felt) (newNonce *felt.Felt, newCall *starknetrpc.FunctionCall, newPublicKey *felt.Felt) { - newNonce = new(felt.Felt).Set(nonce) - newPublicKey = new(felt.Felt).Set(publicKey) - newCall = copyCall(call) - return -} +func (s *TxStore) SetNextNonce(newNextNonce *felt.Felt) { + s.lock.Lock() + defer s.lock.Unlock() -func copyCall(call *starknetrpc.FunctionCall) *starknetrpc.FunctionCall { - copyCall := starknetrpc.FunctionCall{ - ContractAddress: new(felt.Felt).Set(call.ContractAddress), - EntryPointSelector: new(felt.Felt).Set(call.EntryPointSelector), - Calldata: []*felt.Felt{}, - } - for i := 0; i < len(call.Calldata); i++ { - copyCall.Calldata = append(copyCall.Calldata, new(felt.Felt).Set(call.Calldata[i])) + s.nextNonce = new(felt.Felt).Set(newNextNonce) + + // Remove any stale transactions with nonces greater than the new next nonce. + for nonceStr, tx := range s.unconfirmedNonces { + if tx.Nonce.Cmp(s.nextNonce) >= 0 { + delete(s.unconfirmedNonces, nonceStr) + } } - return ©Call } -func (s *TxStore) Save(nonce *felt.Felt, hash string, call *starknetrpc.FunctionCall, publicKey *felt.Felt) error { +func (s *TxStore) GetNextNonce() *felt.Felt { s.lock.Lock() defer s.lock.Unlock() + return new(felt.Felt).Set(s.nextNonce) +} - if h, exists := s.nonceToHash[*nonce]; exists { - return fmt.Errorf("nonce used: tried to use nonce (%s) for tx (%s), already used by (%s)", nonce, hash, h) +func (s *TxStore) AddUnconfirmed(nonce *felt.Felt, hash string, call starknetrpc.FunctionCall, publicKey *felt.Felt) error { + s.lock.Lock() + defer s.lock.Unlock() + + if nonce.Cmp(s.nextNonce) < 0 { + return fmt.Errorf("tried to add an unconfirmed tx at an old nonce: expected %s, got %s", s.nextNonce, nonce) } - if n, exists := s.hashToNonce[hash]; exists { - return fmt.Errorf("hash used: tried to use tx (%s) for nonce (%s), already used nonce (%s)", hash, nonce, &n) + if nonce.Cmp(s.nextNonce) > 0 { + return fmt.Errorf("tried to add an unconfirmed tx at a future nonce: expected %s, got %s", s.nextNonce, nonce) } - newNonce, newCall, newPublicKey := deepCopy(nonce, call, publicKey) - - // store hash - s.nonceToHash[*newNonce] = hash + nonceStr := nonce.String() + if h, exists := s.unconfirmedNonces[nonceStr]; exists { + return fmt.Errorf("nonce used: tried to use nonce (%s) for tx (%s), already used by (%s)", nonce, h.Hash, h) + } - s.hashToNonce[hash] = *newNonce - s.hashToCall[hash] = newCall - s.hashToKey[hash] = *newPublicKey + s.unconfirmedNonces[nonceStr] = &UnconfirmedTx{ + Nonce: new(felt.Felt).Set(nonce), + PublicKey: new(felt.Felt).Set(publicKey), + Hash: hash, + Call: call, + } + s.nextNonce = new(felt.Felt).Add(s.nextNonce, new(felt.Felt).SetUint64(1)) return nil } -func (s *TxStore) Confirm(hash string) error { +func (s *TxStore) Confirm(nonce *felt.Felt, hash string) error { s.lock.Lock() defer s.lock.Unlock() - if nonce, exists := s.hashToNonce[hash]; exists { - delete(s.nonceToHash, nonce) - - delete(s.hashToNonce, hash) - delete(s.hashToCall, hash) - delete(s.hashToKey, hash) - return nil + nonceStr := nonce.String() + unconfirmed, exists := s.unconfirmedNonces[nonceStr] + if !exists { + return fmt.Errorf("no such unconfirmed nonce: %s", nonce) } - return fmt.Errorf("tx hash does not exist - it may already be confirmed: %s", hash) -} - -func (s *TxStore) GetUnconfirmed() []string { - s.lock.RLock() - defer s.lock.RUnlock() - return maps.Values(s.nonceToHash) -} - -type UnconfirmedTx struct { - PublicKey *felt.Felt - Hash string - Nonce *felt.Felt - Call *starknetrpc.FunctionCall -} - -func (s *TxStore) GetSingleUnconfirmed(hash string) (tx UnconfirmedTx, err error) { - s.lock.RLock() - defer s.lock.RUnlock() - - n, hExists := s.hashToNonce[hash] - c, cExists := s.hashToCall[hash] - k, kExists := s.hashToKey[hash] - - if !hExists || !cExists || !kExists { - return tx, errors.New("datum not found in txstore") + // sanity check that the hash matches + if unconfirmed.Hash != hash { + return fmt.Errorf("unexpected tx hash: expected %s, got %s", unconfirmed.Hash, hash) } - - newNonce, newCall, newPublicKey := deepCopy(&n, c, &k) - - tx.Call = newCall - tx.Nonce = newNonce - tx.PublicKey = newPublicKey - tx.Hash = hash - - return tx, nil + delete(s.unconfirmedNonces, nonceStr) + return nil } -// Retrieve Unconfirmed Txs in their queued order (by nonce) -func (s *TxStore) GetUnconfirmedSorted() (txs []UnconfirmedTx) { +func (s *TxStore) GetUnconfirmed() []*UnconfirmedTx { s.lock.RLock() defer s.lock.RUnlock() - nonces := maps.Values(s.hashToNonce) - sort.Slice(nonces, func(i, j int) bool { - return nonces[i].Cmp(&nonces[j]) == -1 + unconfirmed := maps.Values(s.unconfirmedNonces) + sort.Slice(unconfirmed, func(i, j int) bool { + a := unconfirmed[i] + b := unconfirmed[j] + return a.Nonce.Cmp(b.Nonce) < 0 }) - for i := 0; i < len(nonces); i++ { - n := nonces[i] - h := s.nonceToHash[n] - k := s.hashToKey[h] - c := s.hashToCall[h] - - newNonce, newCall, newPublicKey := deepCopy(&n, c, &k) - - txs = append(txs, UnconfirmedTx{Hash: h, Nonce: newNonce, Call: newCall, PublicKey: newPublicKey}) - } - - return txs + return unconfirmed } func (s *TxStore) InflightCount() int { s.lock.RLock() defer s.lock.RUnlock() - return len(s.nonceToHash) + return len(s.unconfirmedNonces) } type AccountStore struct { - store map[*felt.Felt]*TxStore // map account address to txstore + store map[string]*TxStore // map account address to txstore lock sync.RWMutex } func NewAccountStore() *AccountStore { return &AccountStore{ - store: map[*felt.Felt]*TxStore{}, + store: map[string]*TxStore{}, } } -// GetTxStore returns the TxStore for the provided account, creating it if it does not exist. +func (c *AccountStore) CreateTxStore(accountAddress *felt.Felt, initialNonce *felt.Felt) (*TxStore, error) { + c.lock.Lock() + defer c.lock.Unlock() + addressStr := accountAddress.String() + _, ok := c.store[addressStr] + if ok { + return nil, fmt.Errorf("TxStore already exists: %s", accountAddress) + } + store := NewTxStore(initialNonce) + c.store[addressStr] = store + return store, nil +} + +// GetTxStore returns the TxStore for the provided account. func (c *AccountStore) GetTxStore(accountAddress *felt.Felt) *TxStore { c.lock.Lock() defer c.lock.Unlock() - store, ok := c.store[accountAddress] + store, ok := c.store[accountAddress.String()] if !ok { - store = NewTxStore() - c.store[accountAddress] = store + return nil } return store } -func (c *AccountStore) GetAllInflightCount() map[*felt.Felt]int { +func (c *AccountStore) GetTotalInflightCount() int { // use read lock for methods that read underlying data c.lock.RLock() defer c.lock.RUnlock() - list := map[*felt.Felt]int{} - - for i := range c.store { - list[i] = c.store[i].InflightCount() + count := 0 + for _, store := range c.store { + count += store.InflightCount() } - return list + return count } -func (c *AccountStore) GetAllUnconfirmed() map[*felt.Felt][]string { +func (c *AccountStore) GetAllUnconfirmed() map[string][]*UnconfirmedTx { // use read lock for methods that read underlying data c.lock.RLock() defer c.lock.RUnlock() - list := map[*felt.Felt][]string{} - - for i := range c.store { - list[i] = c.store[i].GetUnconfirmed() + allUnconfirmed := map[string][]*UnconfirmedTx{} + for accountAddressStr, store := range c.store { + allUnconfirmed[accountAddressStr] = store.GetUnconfirmed() } - return list + return allUnconfirmed } diff --git a/relayer/pkg/chainlink/txm/txstore_test.go b/relayer/pkg/chainlink/txm/txstore_test.go index 77d73f171..0e8a84655 100644 --- a/relayer/pkg/chainlink/txm/txstore_test.go +++ b/relayer/pkg/chainlink/txm/txstore_test.go @@ -18,59 +18,54 @@ func TestTxStore(t *testing.T) { t.Run("happypath", func(t *testing.T) { t.Parallel() - call := &starknetrpc.FunctionCall{ + call := starknetrpc.FunctionCall{ ContractAddress: new(felt.Felt).SetUint64(0), EntryPointSelector: new(felt.Felt).SetUint64(0), } - feltKey := new(felt.Felt).SetUint64(7) + nonce := new(felt.Felt).SetUint64(3) + publicKey := new(felt.Felt).SetUint64(7) - s := NewTxStore() + s := NewTxStore(nonce) assert.Equal(t, 0, s.InflightCount()) - require.NoError(t, s.Save(new(felt.Felt).SetUint64(0), "0x0", call, feltKey)) + require.NoError(t, s.AddUnconfirmed(nonce, "0x42", call, publicKey)) assert.Equal(t, 1, s.InflightCount()) - assert.Equal(t, []string{"0x0"}, s.GetUnconfirmed()) - require.NoError(t, s.Confirm("0x0")) + assert.Equal(t, 1, len(s.GetUnconfirmed())) + assert.Equal(t, "0x42", s.GetUnconfirmed()[0].Hash) + require.NoError(t, s.Confirm(nonce, "0x42")) assert.Equal(t, 0, s.InflightCount()) - assert.Equal(t, []string{}, s.GetUnconfirmed()) + assert.Equal(t, 0, len(s.GetUnconfirmed())) + assert.True(t, s.GetNextNonce().Cmp(new(felt.Felt).Add(nonce, new(felt.Felt).SetUint64(1))) == 0) }) t.Run("save", func(t *testing.T) { t.Parallel() // create - s := NewTxStore() + s := NewTxStore(new(felt.Felt).SetUint64(0)) - call := &starknetrpc.FunctionCall{ + call := starknetrpc.FunctionCall{ ContractAddress: new(felt.Felt).SetUint64(0), EntryPointSelector: new(felt.Felt).SetUint64(0), } - feltKey := new(felt.Felt).SetUint64(7) + publicKey := new(felt.Felt).SetUint64(7) // accepts tx in order - require.NoError(t, s.Save(new(felt.Felt).SetUint64(0), "0x0", call, feltKey)) + require.NoError(t, s.AddUnconfirmed(new(felt.Felt).SetUint64(0), "0x0", call, publicKey)) assert.Equal(t, 1, s.InflightCount()) - // accepts tx that skips a nonce - require.NoError(t, s.Save(new(felt.Felt).SetUint64(2), "0x2", call, feltKey)) - assert.Equal(t, 2, s.InflightCount()) - - // accepts tx that fills in the missing nonce - require.NoError(t, s.Save(new(felt.Felt).SetUint64(1), "0x1", call, feltKey)) - assert.Equal(t, 3, s.InflightCount()) + // reject tx that skips a nonce + require.ErrorContains(t, s.AddUnconfirmed(new(felt.Felt).SetUint64(2), "0x2", call, publicKey), "tried to add an unconfirmed tx at a future nonce") + assert.Equal(t, 1, s.InflightCount()) - // skip a nonce for later tests - require.NoError(t, s.Save(new(felt.Felt).SetUint64(4), "0x4", call, feltKey)) - assert.Equal(t, 4, s.InflightCount()) + // accepts a subsequent tx + require.NoError(t, s.AddUnconfirmed(new(felt.Felt).SetUint64(1), "0x1", call, publicKey)) + assert.Equal(t, 2, s.InflightCount()) // reject already in use nonce - require.ErrorContains(t, s.Save(new(felt.Felt).SetUint64(4), "0xskip", call, feltKey), "nonce used: tried to use nonce (0x4) for tx (0xskip), already used by (0x4)") - assert.Equal(t, 4, s.InflightCount()) - - // reject already in use tx hash - require.ErrorContains(t, s.Save(new(felt.Felt).SetUint64(5), "0x0", call, feltKey), "hash used: tried to use tx (0x0) for nonce (0x5), already used nonce (0x0)") - assert.Equal(t, 4, s.InflightCount()) + require.ErrorContains(t, s.AddUnconfirmed(new(felt.Felt).SetUint64(1), "0xskip", call, publicKey), "tried to add an unconfirmed tx at an old nonce") + assert.Equal(t, 2, s.InflightCount()) // race save var err0 error @@ -78,111 +73,120 @@ func TestTxStore(t *testing.T) { var wg sync.WaitGroup wg.Add(2) go func() { - err0 = s.Save(new(felt.Felt).SetUint64(10), "0x10", call, feltKey) + err0 = s.AddUnconfirmed(new(felt.Felt).SetUint64(2), "0x10", call, publicKey) wg.Done() }() go func() { - err1 = s.Save(new(felt.Felt).SetUint64(10), "0x10", call, feltKey) + err1 = s.AddUnconfirmed(new(felt.Felt).SetUint64(2), "0x10", call, publicKey) wg.Done() }() wg.Wait() - assert.True(t, !errors.Is(err0, err1) && (err0 != nil || err1 != nil)) + assert.True(t, !errors.Is(err0, err1) && ((err0 != nil && err1 == nil) || (err0 == nil && err1 != nil))) + assert.Equal(t, 3, s.InflightCount()) + + // check that returned unconfirmed tx's are sorted + unconfirmed := s.GetUnconfirmed() + assert.Equal(t, 3, len(unconfirmed)) + assert.Equal(t, 0, unconfirmed[0].Nonce.Cmp(new(felt.Felt).SetUint64(0))) + assert.Equal(t, 0, unconfirmed[1].Nonce.Cmp(new(felt.Felt).SetUint64(1))) + assert.Equal(t, 0, unconfirmed[2].Nonce.Cmp(new(felt.Felt).SetUint64(2))) + }) t.Run("confirm", func(t *testing.T) { t.Parallel() - call := &starknetrpc.FunctionCall{ + call := starknetrpc.FunctionCall{ ContractAddress: new(felt.Felt).SetUint64(0), EntryPointSelector: new(felt.Felt).SetUint64(0), } - feltKey := new(felt.Felt).SetUint64(7) + publicKey := new(felt.Felt).SetUint64(7) // init store - s := NewTxStore() - for i := 0; i < 5; i++ { - require.NoError(t, s.Save(new(felt.Felt).SetUint64(uint64(i)), "0x"+fmt.Sprintf("%d", i), call, feltKey)) + s := NewTxStore(new(felt.Felt).SetUint64(0)) + for i := 0; i < 6; i++ { + require.NoError(t, s.AddUnconfirmed(new(felt.Felt).SetUint64(uint64(i)), "0x"+fmt.Sprintf("%d", i), call, publicKey)) } // confirm in order - require.NoError(t, s.Confirm("0x0")) - require.NoError(t, s.Confirm("0x1")) - assert.Equal(t, 3, s.InflightCount()) + require.NoError(t, s.Confirm(new(felt.Felt).SetUint64(0), "0x0")) + require.NoError(t, s.Confirm(new(felt.Felt).SetUint64(1), "0x1")) + assert.Equal(t, 4, s.InflightCount()) // confirm out of order - require.NoError(t, s.Confirm("0x4")) - require.NoError(t, s.Confirm("0x3")) - require.NoError(t, s.Confirm("0x2")) - assert.Equal(t, 0, s.InflightCount()) + require.NoError(t, s.Confirm(new(felt.Felt).SetUint64(4), "0x4")) + require.NoError(t, s.Confirm(new(felt.Felt).SetUint64(3), "0x3")) + require.NoError(t, s.Confirm(new(felt.Felt).SetUint64(2), "0x2")) + assert.Equal(t, 1, s.InflightCount()) // confirm unknown/duplicate - require.ErrorContains(t, s.Confirm("0x2"), "tx hash does not exist - it may already be confirmed") - require.ErrorContains(t, s.Confirm("0xNULL"), "tx hash does not exist - it may already be confirmed") + require.ErrorContains(t, s.Confirm(new(felt.Felt).SetUint64(10), "0x10"), "no such unconfirmed nonce") + // confirm with incorrect hash + require.ErrorContains(t, s.Confirm(new(felt.Felt).SetUint64(5), "0x99"), "unexpected tx hash") // race confirm - require.NoError(t, s.Save(new(felt.Felt).SetUint64(10), "0x10", call, feltKey)) var err0 error var err1 error var wg sync.WaitGroup wg.Add(2) go func() { - err0 = s.Confirm("0x10") + err0 = s.Confirm(new(felt.Felt).SetUint64(5), "0x5") wg.Done() }() go func() { - err1 = s.Confirm("0x10") + err1 = s.Confirm(new(felt.Felt).SetUint64(5), "0x5") wg.Done() }() wg.Wait() - assert.True(t, !errors.Is(err0, err1) && (err0 != nil || err1 != nil)) + assert.True(t, !errors.Is(err0, err1) && ((err0 != nil && err1 == nil) || (err0 == nil && err1 != nil))) + assert.Equal(t, 0, s.InflightCount()) }) } -func TestChainTxStore(t *testing.T) { +func TestAccountStore(t *testing.T) { t.Parallel() - c := NewChainTxStore() + c := NewAccountStore() felt0 := new(felt.Felt).SetUint64(0) felt1 := new(felt.Felt).SetUint64(1) - feltKey := new(felt.Felt).SetUint64(2) - call := &starknetrpc.FunctionCall{ + store0, err := c.CreateTxStore(felt0, felt0) + require.NoError(t, err) + + store1, err := c.CreateTxStore(felt1, felt1) + require.NoError(t, err) + + _, err = c.CreateTxStore(felt0, felt0) + require.ErrorContains(t, err, "TxStore already exists") + + assert.Equal(t, store0, c.GetTxStore(felt0)) + assert.Equal(t, store1, c.GetTxStore(felt1)) + + assert.Equal(t, c.GetTotalInflightCount(), 0) + + publicKey := new(felt.Felt).SetUint64(2) + + call := starknetrpc.FunctionCall{ ContractAddress: new(felt.Felt).SetUint64(0), EntryPointSelector: new(felt.Felt).SetUint64(0), } - // automatically save the from address - require.NoError(t, c.Save(felt0, new(felt.Felt).SetUint64(0), "0x0", call, feltKey)) - - // reject saving for existing address and reused hash & nonce - // error messages are tested within TestTxStore - assert.Error(t, c.Save(felt0, new(felt.Felt).SetUint64(0), "0x1", call, feltKey), "nonce exists") - assert.Error(t, c.Save(felt0, new(felt.Felt).SetUint64(1), "0x0", call, feltKey), "hash exists") - // inflight count - count, exists := c.GetAllInflightCount()[felt0] - require.True(t, exists) - assert.Equal(t, 1, count) - _, exists = c.GetAllInflightCount()[felt1] - require.False(t, exists) + require.NoError(t, store0.AddUnconfirmed(felt0, "0x0", call, publicKey)) + require.NoError(t, store1.AddUnconfirmed(felt1, "0x1", call, publicKey)) + assert.Equal(t, c.GetTotalInflightCount(), 2) // get unconfirmed - list := c.GetAllUnconfirmed() - assert.Equal(t, 1, len(list)) - hashes, ok := list[felt0] + m := c.GetAllUnconfirmed() + assert.Equal(t, 2, len(m)) + hashes0, ok := m[felt0.String()] + assert.True(t, ok) + assert.Equal(t, len(hashes0), 1) + assert.Equal(t, hashes0[0].Hash, "0x0") + hashes1, ok := m[felt1.String()] assert.True(t, ok) - assert.Equal(t, []string{"0x0"}, hashes) - - // confirm - assert.NoError(t, c.Confirm(felt0, "0x0")) - assert.ErrorContains(t, c.Confirm(felt1, "0x0"), "from address does not exist") - assert.Error(t, c.Confirm(felt0, "0x1")) - list = c.GetAllUnconfirmed() - assert.Equal(t, 1, len(list)) - assert.Equal(t, 0, len(list[felt0])) - count, exists = c.GetAllInflightCount()[felt0] - assert.True(t, exists) - assert.Equal(t, 0, count) + assert.Equal(t, len(hashes1), 1) + assert.Equal(t, hashes1[0].Hash, "0x1") }