diff --git a/relayer/pkg/chainlink/txm/nonce.go b/relayer/pkg/chainlink/txm/nonce.go index 437b6d9c4..332ab2830 100644 --- a/relayer/pkg/chainlink/txm/nonce.go +++ b/relayer/pkg/chainlink/txm/nonce.go @@ -20,11 +20,11 @@ type NonceManagerClient interface { type NonceManager interface { services.Service - Register(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error - NextSequence(address *felt.Felt, chainID string) (*felt.Felt, error) - IncrementNextSequence(address *felt.Felt, chainID string, currentNonce *felt.Felt) error + Register(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error + NextSequence(address *felt.Felt) (*felt.Felt, error) + IncrementNextSequence(address *felt.Felt, currentNonce *felt.Felt) error // Resets local account nonce to on-chain account nonce - Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error + Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error } var _ NonceManager = (*nonceManager)(nil) @@ -33,14 +33,14 @@ type nonceManager struct { starter utils.StartStopOnce lggr logger.Logger - n map[string]map[string]*felt.Felt // map address + chain ID to nonce + n map[string]*felt.Felt // map public key to nonce lock sync.RWMutex } func NewNonceManager(lggr logger.Logger) *nonceManager { return &nonceManager{ lggr: logger.Named(lggr, "NonceManager"), - n: map[string]map[string]*felt.Felt{}, + n: map[string]*felt.Felt{}, } } @@ -64,8 +64,8 @@ func (nm *nonceManager) HealthReport() map[string]error { return map[string]error{nm.Name(): nm.starter.Healthy()} } -func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error { - if err := nm.validate(address, chainId); err != nil { +func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error { + if err := nm.validate(address); err != nil { return err } nm.lock.Lock() @@ -76,65 +76,58 @@ func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, publicKey return err } - nm.n[publicKey.String()][chainId] = n + nm.n[publicKey.String()] = n return nil } // Register is used because we cannot pre-fetch nonces. the pubkey is known before hand, but the account address is not known until a job is started and sends a tx -func (nm *nonceManager) Register(ctx context.Context, addr *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error { +func (nm *nonceManager) Register(ctx context.Context, addr *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error { nm.lock.Lock() defer nm.lock.Unlock() - addressNonces, exists := nm.n[publicKey.String()] - if !exists { - nm.n[publicKey.String()] = map[string]*felt.Felt{} - } - _, exists = addressNonces[chainId] + _, exists := nm.n[publicKey.String()] if !exists { n, err := client.AccountNonce(ctx, addr) if err != nil { return err } - nm.n[publicKey.String()][chainId] = n + nm.n[publicKey.String()] = n } return nil } -func (nm *nonceManager) NextSequence(publicKey *felt.Felt, chainId string) (*felt.Felt, error) { - if err := nm.validate(publicKey, chainId); err != nil { +func (nm *nonceManager) NextSequence(publicKey *felt.Felt) (*felt.Felt, error) { + if err := nm.validate(publicKey); err != nil { return nil, err } nm.lock.RLock() defer nm.lock.RUnlock() - return nm.n[publicKey.String()][chainId], nil + return nm.n[publicKey.String()], nil } -func (nm *nonceManager) IncrementNextSequence(publicKey *felt.Felt, chainId string, currentNonce *felt.Felt) error { - if err := nm.validate(publicKey, chainId); err != nil { +func (nm *nonceManager) IncrementNextSequence(publicKey *felt.Felt, currentNonce *felt.Felt) error { + if err := nm.validate(publicKey); err != nil { return err } nm.lock.Lock() defer nm.lock.Unlock() - n := nm.n[publicKey.String()][chainId] + n := nm.n[publicKey.String()] if n.Cmp(currentNonce) != 0 { return fmt.Errorf("mismatched nonce for %s: %s (expected) != %s (got)", publicKey, n, currentNonce) } one := new(felt.Felt).SetUint64(1) - nm.n[publicKey.String()][chainId] = new(felt.Felt).Add(n, one) + nm.n[publicKey.String()] = new(felt.Felt).Add(n, one) return nil } -func (nm *nonceManager) validate(publicKey *felt.Felt, chainId string) error { +func (nm *nonceManager) validate(publicKey *felt.Felt) error { nm.lock.RLock() defer nm.lock.RUnlock() if _, exists := nm.n[publicKey.String()]; !exists { return fmt.Errorf("nonce tracking does not exist for key: %s", publicKey.String()) } - if _, exists := nm.n[publicKey.String()][chainId]; !exists { - return fmt.Errorf("nonce does not exist for key: %s and chain: %s", publicKey.String(), chainId) - } return nil } diff --git a/relayer/pkg/chainlink/txm/nonce_test.go b/relayer/pkg/chainlink/txm/nonce_test.go index 3ccfab00e..b7b4a6478 100644 --- a/relayer/pkg/chainlink/txm/nonce_test.go +++ b/relayer/pkg/chainlink/txm/nonce_test.go @@ -19,7 +19,7 @@ import ( "github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/txm/mocks" ) -func newTestNonceManager(t *testing.T, chainID string, initNonce *felt.Felt) (txm.NonceManager, *felt.Felt, func()) { +func newTestNonceManager(t *testing.T, initNonce *felt.Felt) (txm.NonceManager, *felt.Felt, func()) { // setup c := mocks.NewNonceManagerClient(t) lggr := logger.Test(t) @@ -31,7 +31,7 @@ func newTestNonceManager(t *testing.T, chainID string, initNonce *felt.Felt) (tx c.On("AccountNonce", mock.Anything, mock.Anything).Return(initNonce, nil).Once() require.NoError(t, nm.Start(tests.Context(t))) - require.NoError(t, nm.Register(tests.Context(t), keyHash, keyHash, chainID, c)) + require.NoError(t, nm.Register(tests.Context(t), keyHash, keyHash, c)) return nm, keyHash, func() { require.NoError(t, nm.Close()) } } @@ -39,24 +39,18 @@ func newTestNonceManager(t *testing.T, chainID string, initNonce *felt.Felt) (tx func TestNonceManager_NextSequence(t *testing.T) { t.Parallel() - chainId := "test_nextSequence" initNonce := new(felt.Felt).SetUint64(10) - nm, k, stop := newTestNonceManager(t, chainId, initNonce) + nm, k, stop := newTestNonceManager(t, initNonce) defer stop() // get with proper inputs - nonce, err := nm.NextSequence(k, chainId) + nonce, err := nm.NextSequence(k) require.NoError(t, err) assert.Equal(t, initNonce, nonce) - // should fail with invalid chain id - _, err = nm.NextSequence(k, "invalid_chainId") - require.Error(t, err) - assert.Contains(t, err.Error(), fmt.Sprintf("nonce does not exist for key: %s and chain: %s", k.String(), "invalid_chainId")) - // should fail with invalid address randAddr1 := starknetutils.BigIntToFelt(big.NewInt(1)) - _, err = nm.NextSequence(randAddr1, chainId) + _, err = nm.NextSequence(randAddr1) require.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("nonce tracking does not exist for key: %s", randAddr1.String())) } @@ -64,9 +58,8 @@ func TestNonceManager_NextSequence(t *testing.T) { func TestNonceManager_IncrementNextSequence(t *testing.T) { t.Parallel() - chainId := "test_nextSequence" initNonce := new(felt.Felt).SetUint64(10) - nm, k, stop := newTestNonceManager(t, chainId, initNonce) + nm, k, stop := newTestNonceManager(t, initNonce) defer stop() one := new(felt.Felt).SetUint64(1) @@ -74,30 +67,30 @@ func TestNonceManager_IncrementNextSequence(t *testing.T) { initPlusOne := new(felt.Felt).Add(initNonce, one) // should fail if nonce is lower then expected - err := nm.IncrementNextSequence(k, chainId, initMinusOne) + err := nm.IncrementNextSequence(k, initMinusOne) require.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("mismatched nonce for %s: %s (expected) != %s (got)", k, initNonce, initMinusOne)) // increment with proper inputs - err = nm.IncrementNextSequence(k, chainId, initNonce) + err = nm.IncrementNextSequence(k, initNonce) require.NoError(t, err) - next, err := nm.NextSequence(k, chainId) + next, err := nm.NextSequence(k) require.NoError(t, err) assert.Equal(t, initPlusOne, next) // should fail with invalid chain id - err = nm.IncrementNextSequence(k, "invalid_chainId", initPlusOne) + err = nm.IncrementNextSequence(k, initPlusOne) require.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("nonce does not exist for key: %s and chain: %s", k.String(), "invalid_chainId")) // should fail with invalid address randAddr1 := starknetutils.BigIntToFelt(big.NewInt(1)) - err = nm.IncrementNextSequence(randAddr1, chainId, initPlusOne) + err = nm.IncrementNextSequence(randAddr1, initPlusOne) require.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("nonce tracking does not exist for key: %s", randAddr1.String())) // verify it didnt get changed by any erroring calls - next, err = nm.NextSequence(k, chainId) + next, err = nm.NextSequence(k) require.NoError(t, err) assert.Equal(t, initPlusOne, next) } diff --git a/relayer/pkg/chainlink/txm/txm.go b/relayer/pkg/chainlink/txm/txm.go index 4894d4009..468a6a64d 100644 --- a/relayer/pkg/chainlink/txm/txm.go +++ b/relayer/pkg/chainlink/txm/txm.go @@ -165,20 +165,15 @@ func (txm *starktxm) handleNonceErr(ctx context.Context, accountAddress *felt.Fe return err } - chainId, err := client.Provider.ChainID(ctx) - if err != nil { - return err - } - // get current nonce before syncing (for logging purposes) - oldVal, err := txm.nonce.NextSequence(publicKey, chainId) + oldVal, err := txm.nonce.NextSequence(publicKey) if err != nil { return err } - txm.nonce.Sync(ctx, accountAddress, publicKey, chainId, client) + txm.nonce.Sync(ctx, accountAddress, publicKey, client) - getVal, err := txm.nonce.NextSequence(publicKey, chainId) + getVal, err := txm.nonce.NextSequence(publicKey) if err != nil { return err } @@ -226,12 +221,7 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun return txhash, fmt.Errorf("failed to create new account: %+w", err) } - chainID, err := client.Provider.ChainID(ctx) - if err != nil { - return txhash, fmt.Errorf("failed to get chainID: %+w", err) - } - - nonce, err := txm.nonce.NextSequence(publicKey, chainID) + nonce, err := txm.nonce.NextSequence(publicKey) if err != nil { return txhash, fmt.Errorf("failed to get nonce: %+w", err) } @@ -365,7 +355,7 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun // update nonce if transaction is successful txhash = res.TransactionHash.String() err = errors.Join( - txm.nonce.IncrementNextSequence(publicKey, chainID, nonce), + txm.nonce.IncrementNextSequence(publicKey, nonce), txm.txStore.Save(accountAddress, nonce, txhash, &call, publicKey), ) return txhash, err @@ -501,13 +491,8 @@ func (txm *starktxm) Enqueue(accountAddress, publicKey *felt.Felt, tx starknetrp return fmt.Errorf("broadcast: failed to fetch client: %+w", err) } - chainID, err := client.Provider.ChainID(context.TODO()) - if err != nil { - return fmt.Errorf("failed to get chainID: %+w", err) - } - // register account for nonce manager - if err := txm.nonce.Register(context.TODO(), accountAddress, publicKey, chainID, client); err != nil { + if err := txm.nonce.Register(context.TODO(), accountAddress, publicKey, client); err != nil { return fmt.Errorf("failed to register nonce: %+w", err) }