Skip to content

Commit

Permalink
remove chainid
Browse files Browse the repository at this point in the history
  • Loading branch information
augustbleeds committed Mar 11, 2024
1 parent a546236 commit 9bd3814
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 67 deletions.
47 changes: 20 additions & 27 deletions relayer/pkg/chainlink/txm/nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ type NonceManagerClient interface {

type NonceManager interface {
services.Service
Register(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error
NextSequence(address *felt.Felt, chainID string) (*felt.Felt, error)
IncrementNextSequence(address *felt.Felt, chainID string, currentNonce *felt.Felt) error
Register(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error
NextSequence(address *felt.Felt) (*felt.Felt, error)
IncrementNextSequence(address *felt.Felt, currentNonce *felt.Felt) error
// Resets local account nonce to on-chain account nonce
Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error
Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error
}

var _ NonceManager = (*nonceManager)(nil)
Expand All @@ -33,14 +33,14 @@ type nonceManager struct {
starter utils.StartStopOnce
lggr logger.Logger

n map[string]map[string]*felt.Felt // map address + chain ID to nonce
n map[string]*felt.Felt // map public key to nonce
lock sync.RWMutex
}

func NewNonceManager(lggr logger.Logger) *nonceManager {
return &nonceManager{
lggr: logger.Named(lggr, "NonceManager"),
n: map[string]map[string]*felt.Felt{},
n: map[string]*felt.Felt{},
}
}

Expand All @@ -64,8 +64,8 @@ func (nm *nonceManager) HealthReport() map[string]error {
return map[string]error{nm.Name(): nm.starter.Healthy()}
}

func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error {
if err := nm.validate(address, chainId); err != nil {
func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error {
if err := nm.validate(address); err != nil {
return err
}
nm.lock.Lock()
Expand All @@ -76,65 +76,58 @@ func (nm *nonceManager) Sync(ctx context.Context, address *felt.Felt, publicKey
return err
}

nm.n[publicKey.String()][chainId] = n
nm.n[publicKey.String()] = n

return nil
}

// Register is used because we cannot pre-fetch nonces. the pubkey is known before hand, but the account address is not known until a job is started and sends a tx
func (nm *nonceManager) Register(ctx context.Context, addr *felt.Felt, publicKey *felt.Felt, chainId string, client NonceManagerClient) error {
func (nm *nonceManager) Register(ctx context.Context, addr *felt.Felt, publicKey *felt.Felt, client NonceManagerClient) error {
nm.lock.Lock()
defer nm.lock.Unlock()
addressNonces, exists := nm.n[publicKey.String()]
if !exists {
nm.n[publicKey.String()] = map[string]*felt.Felt{}
}
_, exists = addressNonces[chainId]
_, exists := nm.n[publicKey.String()]
if !exists {
n, err := client.AccountNonce(ctx, addr)
if err != nil {
return err
}
nm.n[publicKey.String()][chainId] = n
nm.n[publicKey.String()] = n
}

return nil
}

func (nm *nonceManager) NextSequence(publicKey *felt.Felt, chainId string) (*felt.Felt, error) {
if err := nm.validate(publicKey, chainId); err != nil {
func (nm *nonceManager) NextSequence(publicKey *felt.Felt) (*felt.Felt, error) {
if err := nm.validate(publicKey); err != nil {
return nil, err
}

nm.lock.RLock()
defer nm.lock.RUnlock()
return nm.n[publicKey.String()][chainId], nil
return nm.n[publicKey.String()], nil
}

func (nm *nonceManager) IncrementNextSequence(publicKey *felt.Felt, chainId string, currentNonce *felt.Felt) error {
if err := nm.validate(publicKey, chainId); err != nil {
func (nm *nonceManager) IncrementNextSequence(publicKey *felt.Felt, currentNonce *felt.Felt) error {
if err := nm.validate(publicKey); err != nil {
return err
}

nm.lock.Lock()
defer nm.lock.Unlock()
n := nm.n[publicKey.String()][chainId]
n := nm.n[publicKey.String()]
if n.Cmp(currentNonce) != 0 {
return fmt.Errorf("mismatched nonce for %s: %s (expected) != %s (got)", publicKey, n, currentNonce)
}
one := new(felt.Felt).SetUint64(1)
nm.n[publicKey.String()][chainId] = new(felt.Felt).Add(n, one)
nm.n[publicKey.String()] = new(felt.Felt).Add(n, one)
return nil
}

func (nm *nonceManager) validate(publicKey *felt.Felt, chainId string) error {
func (nm *nonceManager) validate(publicKey *felt.Felt) error {
nm.lock.RLock()
defer nm.lock.RUnlock()
if _, exists := nm.n[publicKey.String()]; !exists {
return fmt.Errorf("nonce tracking does not exist for key: %s", publicKey.String())
}
if _, exists := nm.n[publicKey.String()][chainId]; !exists {
return fmt.Errorf("nonce does not exist for key: %s and chain: %s", publicKey.String(), chainId)
}
return nil
}
31 changes: 12 additions & 19 deletions relayer/pkg/chainlink/txm/nonce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"github.com/smartcontractkit/chainlink-starknet/relayer/pkg/chainlink/txm/mocks"
)

func newTestNonceManager(t *testing.T, chainID string, initNonce *felt.Felt) (txm.NonceManager, *felt.Felt, func()) {
func newTestNonceManager(t *testing.T, initNonce *felt.Felt) (txm.NonceManager, *felt.Felt, func()) {
// setup
c := mocks.NewNonceManagerClient(t)
lggr := logger.Test(t)
Expand All @@ -31,73 +31,66 @@ func newTestNonceManager(t *testing.T, chainID string, initNonce *felt.Felt) (tx
c.On("AccountNonce", mock.Anything, mock.Anything).Return(initNonce, nil).Once()

require.NoError(t, nm.Start(tests.Context(t)))
require.NoError(t, nm.Register(tests.Context(t), keyHash, keyHash, chainID, c))
require.NoError(t, nm.Register(tests.Context(t), keyHash, keyHash, c))

return nm, keyHash, func() { require.NoError(t, nm.Close()) }
}

func TestNonceManager_NextSequence(t *testing.T) {
t.Parallel()

chainId := "test_nextSequence"
initNonce := new(felt.Felt).SetUint64(10)
nm, k, stop := newTestNonceManager(t, chainId, initNonce)
nm, k, stop := newTestNonceManager(t, initNonce)
defer stop()

// get with proper inputs
nonce, err := nm.NextSequence(k, chainId)
nonce, err := nm.NextSequence(k)
require.NoError(t, err)
assert.Equal(t, initNonce, nonce)

// should fail with invalid chain id
_, err = nm.NextSequence(k, "invalid_chainId")
require.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("nonce does not exist for key: %s and chain: %s", k.String(), "invalid_chainId"))

// should fail with invalid address
randAddr1 := starknetutils.BigIntToFelt(big.NewInt(1))
_, err = nm.NextSequence(randAddr1, chainId)
_, err = nm.NextSequence(randAddr1)
require.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("nonce tracking does not exist for key: %s", randAddr1.String()))
}

func TestNonceManager_IncrementNextSequence(t *testing.T) {
t.Parallel()

chainId := "test_nextSequence"
initNonce := new(felt.Felt).SetUint64(10)
nm, k, stop := newTestNonceManager(t, chainId, initNonce)
nm, k, stop := newTestNonceManager(t, initNonce)
defer stop()

one := new(felt.Felt).SetUint64(1)
initMinusOne := new(felt.Felt).Sub(initNonce, one)
initPlusOne := new(felt.Felt).Add(initNonce, one)

// should fail if nonce is lower then expected
err := nm.IncrementNextSequence(k, chainId, initMinusOne)
err := nm.IncrementNextSequence(k, initMinusOne)
require.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("mismatched nonce for %s: %s (expected) != %s (got)", k, initNonce, initMinusOne))

// increment with proper inputs
err = nm.IncrementNextSequence(k, chainId, initNonce)
err = nm.IncrementNextSequence(k, initNonce)
require.NoError(t, err)
next, err := nm.NextSequence(k, chainId)
next, err := nm.NextSequence(k)
require.NoError(t, err)
assert.Equal(t, initPlusOne, next)

// should fail with invalid chain id
err = nm.IncrementNextSequence(k, "invalid_chainId", initPlusOne)
err = nm.IncrementNextSequence(k, initPlusOne)
require.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("nonce does not exist for key: %s and chain: %s", k.String(), "invalid_chainId"))

// should fail with invalid address
randAddr1 := starknetutils.BigIntToFelt(big.NewInt(1))
err = nm.IncrementNextSequence(randAddr1, chainId, initPlusOne)
err = nm.IncrementNextSequence(randAddr1, initPlusOne)
require.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("nonce tracking does not exist for key: %s", randAddr1.String()))

// verify it didnt get changed by any erroring calls
next, err = nm.NextSequence(k, chainId)
next, err = nm.NextSequence(k)
require.NoError(t, err)
assert.Equal(t, initPlusOne, next)
}
27 changes: 6 additions & 21 deletions relayer/pkg/chainlink/txm/txm.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,15 @@ func (txm *starktxm) handleNonceErr(ctx context.Context, accountAddress *felt.Fe
return err
}

chainId, err := client.Provider.ChainID(ctx)
if err != nil {
return err
}

// get current nonce before syncing (for logging purposes)
oldVal, err := txm.nonce.NextSequence(publicKey, chainId)
oldVal, err := txm.nonce.NextSequence(publicKey)
if err != nil {
return err
}

txm.nonce.Sync(ctx, accountAddress, publicKey, chainId, client)
txm.nonce.Sync(ctx, accountAddress, publicKey, client)

getVal, err := txm.nonce.NextSequence(publicKey, chainId)
getVal, err := txm.nonce.NextSequence(publicKey)
if err != nil {
return err
}
Expand Down Expand Up @@ -226,12 +221,7 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun
return txhash, fmt.Errorf("failed to create new account: %+w", err)
}

chainID, err := client.Provider.ChainID(ctx)
if err != nil {
return txhash, fmt.Errorf("failed to get chainID: %+w", err)
}

nonce, err := txm.nonce.NextSequence(publicKey, chainID)
nonce, err := txm.nonce.NextSequence(publicKey)
if err != nil {
return txhash, fmt.Errorf("failed to get nonce: %+w", err)
}
Expand Down Expand Up @@ -365,7 +355,7 @@ func (txm *starktxm) broadcast(ctx context.Context, publicKey *felt.Felt, accoun
// update nonce if transaction is successful
txhash = res.TransactionHash.String()
err = errors.Join(
txm.nonce.IncrementNextSequence(publicKey, chainID, nonce),
txm.nonce.IncrementNextSequence(publicKey, nonce),
txm.txStore.Save(accountAddress, nonce, txhash, &call, publicKey),
)
return txhash, err
Expand Down Expand Up @@ -501,13 +491,8 @@ func (txm *starktxm) Enqueue(accountAddress, publicKey *felt.Felt, tx starknetrp
return fmt.Errorf("broadcast: failed to fetch client: %+w", err)
}

chainID, err := client.Provider.ChainID(context.TODO())
if err != nil {
return fmt.Errorf("failed to get chainID: %+w", err)
}

// register account for nonce manager
if err := txm.nonce.Register(context.TODO(), accountAddress, publicKey, chainID, client); err != nil {
if err := txm.nonce.Register(context.TODO(), accountAddress, publicKey, client); err != nil {
return fmt.Errorf("failed to register nonce: %+w", err)
}

Expand Down

0 comments on commit 9bd3814

Please sign in to comment.