Skip to content

Commit

Permalink
update forwarder ORM to be schema-specific
Browse files Browse the repository at this point in the history
  • Loading branch information
krehermann committed Jun 18, 2024
1 parent 472f44c commit 051dba7
Show file tree
Hide file tree
Showing 21 changed files with 110 additions and 122 deletions.
12 changes: 7 additions & 5 deletions core/chains/evm/forwarders/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"time"

"github.com/ethereum/go-ethereum/common"

"github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big"
)

// Forwarder is the struct for Forwarder Addresses
type Forwarder struct {
ID int64
Address common.Address
// EVMChainID big.Big
CreatedAt time.Time
UpdatedAt time.Time
ID int64
Address common.Address
EVMChainID big.Big
CreatedAt time.Time
UpdatedAt time.Time
}
2 changes: 1 addition & 1 deletion core/chains/evm/forwarders/forwarder_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func NewFwdMgr(ds sqlutil.DataSource, client evmclient.Client, logpoller evmlogp
logger: lggr,
cfg: cfg,
evmClient: client,
ORM: NewORM(ds),
ORM: NewScopedORM(ds, (*big.Big)(client.ConfiguredChainID())),
logpoller: logpoller,
sendersCache: make(map[common.Address][]common.Address),
}
Expand Down
34 changes: 20 additions & 14 deletions core/chains/evm/forwarders/forwarder_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/utils"

"github.com/smartcontractkit/chainlink/v2/core/services/ocr2/testhelpers"
"github.com/smartcontractkit/chainlink/v2/core/store/migrate/plugins/relayer/evm"
evmtestdb "github.com/smartcontractkit/chainlink/v2/core/store/migrate/plugins/relayer/evm/testutils"

"github.com/smartcontractkit/chainlink/v2/core/chains/evm/client"
"github.com/smartcontractkit/chainlink/v2/core/chains/evm/forwarders"
Expand All @@ -41,7 +43,8 @@ var SimpleOracleCallABI = evmtypes.MustGetABI(operator_wrapper.OperatorABI).Meth

func TestFwdMgr_MaybeForwardTransaction(t *testing.T) {
lggr := logger.Test(t)
db := pgtest.NewSqlxDB(t)
testChainId := testutils.FixtureChainID
emvPlugindb := evmtestdb.NewDB(t, evm.Cfg{Schema: "evm_" + testChainId.String(), ChainID: ubig.New(testChainId)})
cfg := configtest.NewTestGeneralConfig(t)
evmcfg := evmtest.NewChainScopedConfig(t, cfg)
owner := testutils.MustNewSimTransactor(t)
Expand All @@ -66,7 +69,7 @@ func TestFwdMgr_MaybeForwardTransaction(t *testing.T) {
require.NoError(t, err)
t.Log(authorized)

evmClient := client.NewSimulatedBackendClient(t, ec, testutils.FixtureChainID)
evmClient := client.NewSimulatedBackendClient(t, ec, testChainId)

lpOpts := logpoller.Opts{
PollPeriod: 100 * time.Millisecond,
Expand All @@ -75,13 +78,14 @@ func TestFwdMgr_MaybeForwardTransaction(t *testing.T) {
RpcBatchSize: 2,
KeepFinalizedBlocksDepth: 1000,
}
lp := logpoller.NewLogPoller(logpoller.NewORM(testutils.FixtureChainID, db, lggr), evmClient, lggr, lpOpts)
fwdMgr := forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM())
fwdMgr.ORM = forwarders.NewORM(db)
lp := logpoller.NewLogPoller(logpoller.NewORM(testChainId, emvPlugindb, lggr), evmClient, lggr, lpOpts)
fwdMgr := forwarders.NewFwdMgr(emvPlugindb, evmClient, lp, lggr, evmcfg.EVM())
cid := ubig.Big(*testChainId)
fwdMgr.ORM = forwarders.NewScopedORM(emvPlugindb, &cid)

fwd, err := fwdMgr.ORM.CreateForwarder(ctx, forwarderAddr, ubig.Big(*testutils.FixtureChainID))
fwd, err := fwdMgr.ORM.CreateForwarder(ctx, forwarderAddr, cid)
require.NoError(t, err)
lst, err := fwdMgr.ORM.FindForwardersByChain(ctx, ubig.Big(*testutils.FixtureChainID))
lst, err := fwdMgr.ORM.FindForwardersByChain(ctx, cid)
require.NoError(t, err)
require.Equal(t, len(lst), 1)
require.Equal(t, lst[0].Address, forwarderAddr)
Expand All @@ -95,7 +99,7 @@ func TestFwdMgr_MaybeForwardTransaction(t *testing.T) {

cleanupCalled := false
cleanup := func(tx sqlutil.DataSource, evmChainId int64, addr common.Address) error {
require.Equal(t, testutils.FixtureChainID.Int64(), evmChainId)
require.Equal(t, testChainId.Int64(), evmChainId)
require.Equal(t, forwarderAddr, addr)
require.NotNil(t, tx)
cleanupCalled = true
Expand Down Expand Up @@ -138,11 +142,12 @@ func TestFwdMgr_AccountUnauthorizedToForward_SkipsForwarding(t *testing.T) {
}
lp := logpoller.NewLogPoller(logpoller.NewORM(testutils.FixtureChainID, db, lggr), evmClient, lggr, lpOpts)
fwdMgr := forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM())
fwdMgr.ORM = forwarders.NewORM(db)
testChainID := ubig.Big(*testutils.FixtureChainID)
fwdMgr.ORM = forwarders.NewScopedORM(db, &testChainID)

_, err = fwdMgr.ORM.CreateForwarder(ctx, forwarderAddr, ubig.Big(*testutils.FixtureChainID))
_, err = fwdMgr.ORM.CreateForwarder(ctx, forwarderAddr, testChainID)
require.NoError(t, err)
lst, err := fwdMgr.ORM.FindForwardersByChain(ctx, ubig.Big(*testutils.FixtureChainID))
lst, err := fwdMgr.ORM.FindForwardersByChain(ctx, testChainID)
require.NoError(t, err)
require.Equal(t, len(lst), 1)
require.Equal(t, lst[0].Address, forwarderAddr)
Expand Down Expand Up @@ -203,11 +208,12 @@ func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) {
}
lp := logpoller.NewLogPoller(logpoller.NewORM(testutils.FixtureChainID, db, lggr), evmClient, lggr, lpOpts)
fwdMgr := forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM())
fwdMgr.ORM = forwarders.NewORM(db)
testChainID := ubig.Big(*testutils.FixtureChainID)
fwdMgr.ORM = forwarders.NewScopedORM(db, &testChainID)

_, err = fwdMgr.ORM.CreateForwarder(ctx, forwarderAddr, ubig.Big(*testutils.FixtureChainID))
_, err = fwdMgr.ORM.CreateForwarder(ctx, forwarderAddr, testChainID)
require.NoError(t, err)
lst, err := fwdMgr.ORM.FindForwardersByChain(ctx, ubig.Big(*testutils.FixtureChainID))
lst, err := fwdMgr.ORM.FindForwardersByChain(ctx, testChainID)
require.NoError(t, err)
require.Equal(t, len(lst), 1)
require.Equal(t, lst[0].Address, forwarderAddr)
Expand Down
36 changes: 20 additions & 16 deletions core/chains/evm/forwarders/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,20 @@ import (
"fmt"

"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"
"github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big"

"github.com/ethereum/go-ethereum/common"
"github.com/jmoiron/sqlx"
pkgerrors "github.com/pkg/errors"

"github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big"
)

//go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore

type ORM interface {
CreateForwarder(ctx context.Context, addr common.Address, evmChainId big.Big) (fwd Forwarder, err error)
FindForwarders(ctx context.Context, offset, limit int) ([]Forwarder, int, error)
//FindForwarders(ctx context.Context, offset, limit int) ([]Forwarder, int, error)
FindForwardersByChain(ctx context.Context, evmChainId big.Big) ([]Forwarder, error)
DeleteForwarder(ctx context.Context, id int64, cleanup func(tx sqlutil.DataSource, evmChainId int64, addr common.Address) error) error
FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error)
// FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error)
}

type DSORM struct {
Expand All @@ -31,10 +29,6 @@ type DSORM struct {

var _ ORM = &DSORM{}

func NewORM(ds sqlutil.DataSource) *DSORM {
return &DSORM{ds: ds}
}

func NewScopedORM(ds sqlutil.DataSource, evmChainId *big.Big) *DSORM {
return &DSORM{ds: ds, cid: evmChainId}
}
Expand All @@ -44,7 +38,7 @@ func (o *DSORM) Transact(ctx context.Context, fn func(*DSORM) error) (err error)
}

// new returns a NewORM like o, but backed by q.
func (o *DSORM) new(q sqlutil.DataSource) *DSORM { return NewORM(q) }
func (o *DSORM) new(q sqlutil.DataSource) *DSORM { return NewScopedORM(q, o.cid) }

func (o *DSORM) schemaName() string {
if o.cid != nil {
Expand All @@ -63,7 +57,11 @@ func (o *DSORM) CreateForwarder(ctx context.Context, addr common.Address, evmCha
// sql := `INSERT INTO evm.forwarders (address, evm_chain_id, created_at, updated_at) VALUES ($1, $2, now(), now()) RETURNING *`
sql := fmt.Sprintf("INSERT INTO %s.forwarders (address, created_at, updated_at) VALUES ($1, now(), now()) RETURNING *", o.schemaName())
err = o.ds.GetContext(ctx, &fwd, sql, addr)
return fwd, err
if err != nil {
return fwd, err
}
fwd.EVMChainID = *o.cid
return fwd, nil
}

// DeleteForwarder removes a forwarder address.
Expand Down Expand Up @@ -103,12 +101,15 @@ func (o *DSORM) DeleteForwarder(ctx context.Context, id int64, cleanup func(tx s

// FindForwarders returns all forwarder addresses from offset up until limit.
func (o *DSORM) FindForwarders(ctx context.Context, offset, limit int) (fwds []Forwarder, count int, err error) {
sql := `SELECT count(*) FROM evm.forwarders`
// sql := `SELECT count(*) FROM evm.forwarders`
sql := fmt.Sprintf("SELECT count(*) FROM %s.forwarders", o.schemaName())

if err = o.ds.GetContext(ctx, &count, sql); err != nil {
return
}

sql = `SELECT * FROM evm.forwarders ORDER BY created_at DESC, id DESC LIMIT $1 OFFSET $2`
// sql = `SELECT * FROM evm.forwarders ORDER BY created_at DESC, id DESC LIMIT $1 OFFSET $2`
sql = fmt.Sprintf("SELECT * FROM %s.forwarders ORDER BY created_at DESC, id DESC LIMIT $1 OFFSET $2", o.schemaName())
if err = o.ds.SelectContext(ctx, &fwds, sql, limit, offset); err != nil {
return
}
Expand All @@ -117,11 +118,13 @@ func (o *DSORM) FindForwarders(ctx context.Context, offset, limit int) (fwds []F

// FindForwardersByChain returns all forwarder addresses for a chain.
func (o *DSORM) FindForwardersByChain(ctx context.Context, evmChainId big.Big) (fwds []Forwarder, err error) {
sql := `SELECT * FROM evm.forwarders where evm_chain_id = $1 ORDER BY created_at DESC, id DESC`
err = o.ds.SelectContext(ctx, &fwds, sql, evmChainId)
// sql := `SELECT * FROM evm.forwarders where evm_chain_id = $1 ORDER BY created_at DESC, id DESC`
sql := fmt.Sprintf("SELECT * FROM %s.forwarders ORDER BY created_at DESC, id DESC", o.schemaName())
err = o.ds.SelectContext(ctx, &fwds, sql)
return
}

/*
func (o *DSORM) FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error) {
var fwdrs []Forwarder
Expand All @@ -131,7 +134,7 @@ func (o *DSORM) FindForwardersInListByChain(ctx context.Context, evmChainId big.
}
query, args, err := sqlx.Named(`
SELECT * FROM evm.forwarders
SELECT * FROM evm.forwarders
WHERE evm_chain_id = :chainid
AND address IN (:addresses)
ORDER BY created_at DESC, id DESC`,
Expand All @@ -156,3 +159,4 @@ func (o *DSORM) FindForwardersInListByChain(ctx context.Context, evmChainId big.
return fwdrs, nil
}
*/
2 changes: 1 addition & 1 deletion core/chains/evm/forwarders/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func Test_DeleteForwarder(t *testing.T) {
chainID := testutils.FixtureChainID
orm := forwarders.NewScopedORM(evmtestdb.NewDB(t, evm.Cfg{
Schema: "evm_" + chainID.String(),
ChainID: int(chainID.Int64()),
ChainID: big.New(chainID),
}), big.New(chainID))

addr := testutils.NewAddress()
Expand Down
8 changes: 2 additions & 6 deletions core/chains/evm/txmgr/txmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets"
evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client"
evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config"
"github.com/smartcontractkit/chainlink/v2/core/chains/evm/forwarders"
"github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas"
"github.com/smartcontractkit/chainlink/v2/core/chains/evm/keystore"
ksmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/keystore/mocks"
Expand Down Expand Up @@ -304,23 +303,20 @@ func TestTxm_CreateTransaction(t *testing.T) {
})

t.Run("forwards tx when a proper forwarder is set up", func(t *testing.T) {

pgtest.MustExec(t, db, `DELETE FROM evm.txes`)
pgtest.MustExec(t, db, `DELETE FROM evm.forwarders`)
evmConfig.MaxQueued = uint64(1)

// Create mock forwarder, mock authorizedsenders call.
form := forwarders.NewORM(db)
fwdrAddr := testutils.NewAddress()
fwdr, err := form.CreateForwarder(tests.Context(t), fwdrAddr, ubig.Big(cltest.FixtureChainID))
require.NoError(t, err)
require.Equal(t, fwdr.Address, fwdrAddr)

etx, err := txm.CreateTransaction(tests.Context(t), txmgr.TxRequest{
FromAddress: fromAddress,
ToAddress: toAddress,
EncodedPayload: payload,
FeeLimit: gasLimit,
ForwarderAddress: fwdr.Address,
ForwarderAddress: fwdrAddr,
Strategy: txmgrcommon.NewSendEveryStrategy(),
})
assert.NoError(t, err)
Expand Down
5 changes: 3 additions & 2 deletions core/cmd/ocr2vrf_configure_commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,9 @@ func (s *Shell) authorizeForwarder(c *cli.Context, db *sqlx.DB, chainID int64, e
}

// Create forwarder for management in forwarder_manager.go.
orm := forwarders.NewORM(db)
_, err = orm.CreateForwarder(ctx, common.HexToAddress(forwarderAddress), *ubig.NewI(chainID))
chainId := ubig.NewI(chainID)
orm := forwarders.NewScopedORM(db, chainId)
_, err = orm.CreateForwarder(ctx, common.HexToAddress(forwarderAddress), *chainId)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion core/cmd/shell_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ func migrateDB(ctx context.Context, config dbConfig, lggr logger.Logger, opts *r
if err != nil {
return fmt.Errorf("failed to parse chain as int id: %v", err)
}
err = evmdb.Migrate(ctx, db.DB, evmdb.Cfg{Schema: opts.Relayer + "_" + opts.ChainID, ChainID: cid})
err = evmdb.Migrate(ctx, db.DB, evmdb.Cfg{Schema: opts.Relayer + "_" + opts.ChainID, ChainID: ubig.NewI(int64(cid))})
if err != nil {
return fmt.Errorf("migrateDB failed: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion core/internal/features/features_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,8 +778,9 @@ func setupForwarderEnabledNode(t *testing.T, owner *bind.TransactOpts, portV2 in
b.Commit()

// add forwarder address to be tracked in db
forwarderORM := forwarders.NewORM(app.GetDB())
chainID := ubig.Big(*b.Blockchain().Config().ChainID)

forwarderORM := forwarders.NewScopedORM(app.GetDB(), &chainID)
_, err = forwarderORM.CreateForwarder(testutils.Context(t), forwarder, chainID)
require.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion core/internal/features/ocr2/features_ocr2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ func setupNodeOCR2(
b.Commit()

// add forwarder address to be tracked in db
forwarderORM := forwarders.NewORM(app.GetDB())
chainID := ubig.Big(*b.Blockchain().Config().ChainID)
forwarderORM := forwarders.NewScopedORM(app.GetDB(), &chainID)
_, err2 = forwarderORM.CreateForwarder(testutils.Context(t), faddr, chainID)
require.NoError(t, err2)

Expand Down
2 changes: 1 addition & 1 deletion core/services/keeper/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ func TestKeeperForwarderEthIntegration(t *testing.T) {
app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, backend.Backend(), nodeKey)
require.NoError(t, app.Start(ctx))

forwarderORM := forwarders.NewORM(db)
chainID := ubig.Big(*backend.ConfiguredChainID())
forwarderORM := forwarders.NewScopedORM(db, &chainID)
_, err = forwarderORM.CreateForwarder(ctx, fwdrAddress, chainID)
require.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion core/services/ocr2/plugins/ocr2keeper/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,8 @@ func setupForwarderForNode(
backend.Commit()

// add forwarder address to be tracked in db
forwarderORM := forwarders.NewORM(app.GetDB())
chainID := ubig.Big(*backend.Blockchain().Config().ChainID)
forwarderORM := forwarders.NewScopedORM(app.GetDB(), &chainID)
_, err = forwarderORM.CreateForwarder(ctx, faddr, chainID)
require.NoError(t, err)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ func setupNodeOCR2(
b.Commit()

// Add the forwarder to the node's forwarder manager.
forwarderORM := forwarders.NewORM(app.GetDB())
chainID := ubig.Big(*b.Blockchain().Config().ChainID)
forwarderORM := forwarders.NewScopedORM(app.GetDB(), &chainID)
_, err = forwarderORM.CreateForwarder(testutils.Context(t), faddr, chainID)
require.NoError(t, err)
effectiveTransmitter = faddr
Expand Down
Loading

0 comments on commit 051dba7

Please sign in to comment.