diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist.go b/core/services/gateway/handlers/functions/allowlist/allowlist.go index 20dc92ced70..020de2359c2 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "math/big" + "regexp" "sync" "sync/atomic" "time" @@ -12,6 +13,7 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" + "golang.org/x/mod/semver" "github.com/smartcontractkit/chainlink-common/pkg/services" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" @@ -23,9 +25,10 @@ import ( ) const ( - defaultStoredAllowlistBatchSize = 1000 - defaultOnchainAllowlistBatchSize = 100 - defaultFetchingDelayInRangeSec = 1 + defaultStoredAllowlistBatchSize = 1000 + defaultOnchainAllowlistBatchSize = 100 + defaultFetchingDelayInRangeSec = 1 + tosContractMinBatchProcessingVersion = "v1.1.0" ) type OnchainAllowlistConfig struct { @@ -38,8 +41,6 @@ type OnchainAllowlistConfig struct { UpdateTimeoutSec uint `json:"updateTimeoutSec"` StoredAllowlistBatchSize uint `json:"storedAllowlistBatchSize"` OnchainAllowlistBatchSize uint `json:"onchainAllowlistBatchSize"` - // StoreAllowedSendersEnabled is a feature flag that enables storing in db a copy of the allowlist. - StoreAllowedSendersEnabled bool `json:"storeAllowedSendersEnabled"` // FetchingDelayInRangeSec prevents RPC client being rate limited when fetching the allowlist in ranges. FetchingDelayInRangeSec uint `json:"fetchingDelayInRangeSec"` } @@ -210,7 +211,31 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b } var allowedSenderList []common.Address - if !a.config.StoreAllowedSendersEnabled { + typeAndVersion, err := tosContract.TypeAndVersion(&bind.CallOpts{ + Pending: false, + BlockNumber: blockNum, + Context: ctx, + }) + if err != nil { + return errors.Wrap(err, "failed to fetch the tos contract type and version") + } + + currentVersion, err := ExtractContractVersion(typeAndVersion) + if err != nil { + return fmt.Errorf("failed to extract version: %w", err) + } + + if semver.Compare(tosContractMinBatchProcessingVersion, currentVersion) <= 0 { + err = a.syncBlockedSenders(ctx, tosContract, blockNum) + if err != nil { + return errors.Wrap(err, "failed to sync the stored allowed and blocked senders") + } + + allowedSenderList, err = a.getAllowedSendersBatched(ctx, tosContract, blockNum) + if err != nil { + return errors.Wrap(err, "failed to get allowed senders in rage") + } + } else { allowedSenderList, err = tosContract.GetAllAllowedSenders(&bind.CallOpts{ Pending: false, BlockNumber: blockNum, @@ -219,15 +244,15 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b if err != nil { return errors.Wrap(err, "error calling GetAllAllowedSenders") } - } else { - err = a.syncBlockedSenders(ctx, tosContract, blockNum) + + err = a.orm.PurgeAllowedSenders() if err != nil { - return errors.Wrap(err, "failed to sync the stored allowed and blocked senders") + a.lggr.Errorf("failed to purge allowedSenderList: %w", err) } - allowedSenderList, err = a.getAllowedSendersBatched(ctx, tosContract, blockNum) + err = a.orm.CreateAllowedSenders(allowedSenderList) if err != nil { - return errors.Wrap(err, "failed to get allowed senders in rage") + a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) } } @@ -344,3 +369,14 @@ func (a *onchainAllowlist) loadStoredAllowedSenderList() { a.update(allowedList) } + +func ExtractContractVersion(str string) (string, error) { + pattern := `v(\d+).(\d+).(\d+)` + re := regexp.MustCompile(pattern) + + match := re.FindStringSubmatch(str) + if len(match) != 4 { + return "", fmt.Errorf("version not found in string: %s", str) + } + return fmt.Sprintf("v%s.%s.%s", match[1], match[2], match[3]), nil +} diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go index e8cbca80b94..735c0bff7dc 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go @@ -3,11 +3,14 @@ package allowlist_test import ( "context" "encoding/hex" + "fmt" "math/big" "testing" "time" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/onsi/gomega" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -18,55 +21,105 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/allowlist" amocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/allowlist/mocks" + "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" + "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" ) const ( - addr1 = "9ed925d8206a4f88a2f643b28b3035b315753cd6" - addr2 = "ea6721ac65bced841b8ec3fc5fedea6141a0ade4" - addr3 = "84689acc87ff22841b8ec378300da5e141a99911" + addr1 = "9ed925d8206a4f88a2f643b28b3035b315753cd6" + addr2 = "ea6721ac65bced841b8ec3fc5fedea6141a0ade4" + addr3 = "84689acc87ff22841b8ec378300da5e141a99911" + ToSContractV100 = "Functions Terms of Service Allow List v1.0.0" + ToSContractV110 = "Functions Terms of Service Allow List v1.1.0" ) -func sampleEncodedAllowlist(t *testing.T) []byte { - abiEncodedAddresses := - "0000000000000000000000000000000000000000000000000000000000000020" + - "0000000000000000000000000000000000000000000000000000000000000002" + - "000000000000000000000000" + addr1 + - "000000000000000000000000" + addr2 - rawData, err := hex.DecodeString(abiEncodedAddresses) - require.NoError(t, err) - return rawData -} - -func TestAllowlist_UpdateAndCheck(t *testing.T) { +func TestUpdateAndCheck(t *testing.T) { t.Parallel() - client := mocks.NewClient(t) - client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) - client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(sampleEncodedAllowlist(t), nil) - config := allowlist.OnchainAllowlistConfig{ - ContractVersion: 1, - ContractAddress: common.Address{}, - BlockConfirmations: 1, - } + t.Run("OK-with_ToS_V1.0.0", func(t *testing.T) { + client := mocks.NewClient(t) + client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) - orm := amocks.NewORM(t) - allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) - require.NoError(t, err) + addr := common.HexToAddress("0x0000000000000000000000000000000000000020") + typeAndVersionResponse, err := encodeTypeAndVersionResponse(ToSContractV100) + require.NoError(t, err) - err = allowlist.Start(testutils.Context(t)) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, allowlist.Close()) + client.On("CallContract", mock.Anything, ethereum.CallMsg{ // typeAndVersion + To: &addr, + Data: hexutil.MustDecode("0x181f5a77"), + }, mock.Anything).Return(typeAndVersionResponse, nil) + + client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(sampleEncodedAllowlist(t), nil) + + config := allowlist.OnchainAllowlistConfig{ + ContractVersion: 1, + ContractAddress: common.Address{}, + BlockConfirmations: 1, + } + + orm := amocks.NewORM(t) + orm.On("PurgeAllowedSenders").Times(1).Return(nil) + orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + + allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) + require.NoError(t, err) + + err = allowlist.Start(testutils.Context(t)) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, allowlist.Close()) + }) + + require.NoError(t, allowlist.UpdateFromContract(testutils.Context(t))) + require.False(t, allowlist.Allow(common.Address{})) + require.True(t, allowlist.Allow(common.HexToAddress(addr1))) + require.True(t, allowlist.Allow(common.HexToAddress(addr2))) + require.False(t, allowlist.Allow(common.HexToAddress(addr3))) }) - require.NoError(t, allowlist.UpdateFromContract(testutils.Context(t))) - require.False(t, allowlist.Allow(common.Address{})) - require.True(t, allowlist.Allow(common.HexToAddress(addr1))) - require.True(t, allowlist.Allow(common.HexToAddress(addr2))) - require.False(t, allowlist.Allow(common.HexToAddress(addr3))) + t.Run("OK-with_ToS_V1.1.0", func(t *testing.T) { + client := mocks.NewClient(t) + client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) + + typeAndVersionResponse, err := encodeTypeAndVersionResponse(ToSContractV110) + require.NoError(t, err) + + addr := common.HexToAddress("0x0000000000000000000000000000000000000020") + client.On("CallContract", mock.Anything, ethereum.CallMsg{ // typeAndVersion + To: &addr, + Data: hexutil.MustDecode("0x181f5a77"), + }, mock.Anything).Return(typeAndVersionResponse, nil) + + client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(sampleEncodedAllowlist(t), nil) + + config := allowlist.OnchainAllowlistConfig{ + ContractVersion: 1, + ContractAddress: common.Address{}, + BlockConfirmations: 1, + } + + orm := amocks.NewORM(t) + orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + + allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) + require.NoError(t, err) + + err = allowlist.Start(testutils.Context(t)) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, allowlist.Close()) + }) + + require.NoError(t, allowlist.UpdateFromContract(testutils.Context(t))) + require.False(t, allowlist.Allow(common.Address{})) + require.True(t, allowlist.Allow(common.HexToAddress(addr1))) + require.True(t, allowlist.Allow(common.HexToAddress(addr2))) + require.False(t, allowlist.Allow(common.HexToAddress(addr3))) + }) } -func TestAllowlist_UnsupportedVersion(t *testing.T) { +func TestUnsupportedVersion(t *testing.T) { t.Parallel() client := mocks.NewClient(t) @@ -81,64 +134,132 @@ func TestAllowlist_UnsupportedVersion(t *testing.T) { require.Error(t, err) } -func TestAllowlist_UpdatePeriodically(t *testing.T) { +func TestUpdatePeriodically(t *testing.T) { t.Parallel() - ctx, cancel := context.WithCancel(testutils.Context(t)) - client := mocks.NewClient(t) - client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) - client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - cancel() - }).Return(sampleEncodedAllowlist(t), nil) - config := allowlist.OnchainAllowlistConfig{ - ContractAddress: common.Address{}, - ContractVersion: 1, - BlockConfirmations: 1, - UpdateFrequencySec: 2, - UpdateTimeoutSec: 1, - } + t.Run("OK-with_ToS_V1.0.0", func(t *testing.T) { + ctx, cancel := context.WithCancel(testutils.Context(t)) + client := mocks.NewClient(t) + client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) - orm := amocks.NewORM(t) - orm.On("GetAllowedSenders", uint(0), uint(1000)).Return([]common.Address{}, nil) + addr := common.HexToAddress("0x0000000000000000000000000000000000000020") + typeAndVersionResponse, err := encodeTypeAndVersionResponse(ToSContractV100) + require.NoError(t, err) - allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) - require.NoError(t, err) + client.On("CallContract", mock.Anything, ethereum.CallMsg{ // typeAndVersion + To: &addr, + Data: hexutil.MustDecode("0x181f5a77"), + }, mock.Anything).Return(typeAndVersionResponse, nil) - err = allowlist.Start(ctx) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, allowlist.Close()) + client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + cancel() + }).Return(sampleEncodedAllowlist(t), nil) + config := allowlist.OnchainAllowlistConfig{ + ContractAddress: common.Address{}, + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + } + + orm := amocks.NewORM(t) + orm.On("PurgeAllowedSenders").Times(1).Return(nil) + orm.On("GetAllowedSenders", uint(0), uint(1000)).Return([]common.Address{}, nil) + orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + + allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) + require.NoError(t, err) + + err = allowlist.Start(ctx) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, allowlist.Close()) + }) + + gomega.NewGomegaWithT(t).Eventually(func() bool { + return allowlist.Allow(common.HexToAddress(addr1)) && !allowlist.Allow(common.HexToAddress(addr3)) + }, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue()) }) - gomega.NewGomegaWithT(t).Eventually(func() bool { - return allowlist.Allow(common.HexToAddress(addr1)) && !allowlist.Allow(common.HexToAddress(addr3)) - }, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue()) + t.Run("OK-with_ToS_V1.1.0", func(t *testing.T) { + ctx, cancel := context.WithCancel(testutils.Context(t)) + client := mocks.NewClient(t) + client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) + + addr := common.HexToAddress("0x0000000000000000000000000000000000000020") + typeAndVersionResponse, err := encodeTypeAndVersionResponse(ToSContractV110) + require.NoError(t, err) + + client.On("CallContract", mock.Anything, ethereum.CallMsg{ // typeAndVersion + To: &addr, + Data: hexutil.MustDecode("0x181f5a77"), + }, mock.Anything).Return(typeAndVersionResponse, nil) + + client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + cancel() + }).Return(sampleEncodedAllowlist(t), nil) + config := allowlist.OnchainAllowlistConfig{ + ContractAddress: common.Address{}, + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + } + + orm := amocks.NewORM(t) + orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("GetAllowedSenders", uint(0), uint(1000)).Return([]common.Address{}, nil) + orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + + allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) + require.NoError(t, err) + + err = allowlist.Start(ctx) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, allowlist.Close()) + }) + + gomega.NewGomegaWithT(t).Eventually(func() bool { + return allowlist.Allow(common.HexToAddress(addr1)) && !allowlist.Allow(common.HexToAddress(addr3)) + }, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue()) + }) } -func TestAllowlist_UpdateFromContract(t *testing.T) { + +func TestUpdateFromContract(t *testing.T) { t.Parallel() - t.Run("OK-iterate_over_list_of_allowed_senders", func(t *testing.T) { + t.Run("OK-fetch_complete_list_of_allowed_senders", func(t *testing.T) { ctx, cancel := context.WithCancel(testutils.Context(t)) client := mocks.NewClient(t) client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) + + addr := common.HexToAddress("0x0000000000000000000000000000000000000020") + typeAndVersionResponse, err := encodeTypeAndVersionResponse(ToSContractV100) + require.NoError(t, err) + + client.On("CallContract", mock.Anything, ethereum.CallMsg{ // typeAndVersion + To: &addr, + Data: hexutil.MustDecode("0x181f5a77"), + }, mock.Anything).Return(typeAndVersionResponse, nil) + client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { cancel() }).Return(sampleEncodedAllowlist(t), nil) config := allowlist.OnchainAllowlistConfig{ - ContractAddress: common.HexToAddress(addr3), - ContractVersion: 1, - BlockConfirmations: 1, - UpdateFrequencySec: 2, - UpdateTimeoutSec: 1, - StoredAllowlistBatchSize: 2, - OnchainAllowlistBatchSize: 16, - StoreAllowedSendersEnabled: true, - FetchingDelayInRangeSec: 0, + ContractAddress: common.HexToAddress(addr3), + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + StoredAllowlistBatchSize: 2, + OnchainAllowlistBatchSize: 16, + FetchingDelayInRangeSec: 0, } orm := amocks.NewORM(t) - orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) + orm.On("PurgeAllowedSenders").Times(1).Return(nil) + orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -151,26 +272,38 @@ func TestAllowlist_UpdateFromContract(t *testing.T) { }, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue()) }) - t.Run("OK-fetch_complete_list_of_allowed_senders_without_storing", func(t *testing.T) { + t.Run("OK-iterate_over_list_of_allowed_senders", func(t *testing.T) { ctx, cancel := context.WithCancel(testutils.Context(t)) client := mocks.NewClient(t) client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) + + addr := common.HexToAddress("0x0000000000000000000000000000000000000020") + typeAndVersionResponse, err := encodeTypeAndVersionResponse(ToSContractV110) + require.NoError(t, err) + + client.On("CallContract", mock.Anything, ethereum.CallMsg{ // typeAndVersion + To: &addr, + Data: hexutil.MustDecode("0x181f5a77"), + }, mock.Anything).Return(typeAndVersionResponse, nil) + client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { cancel() }).Return(sampleEncodedAllowlist(t), nil) config := allowlist.OnchainAllowlistConfig{ - ContractAddress: common.HexToAddress(addr3), - ContractVersion: 1, - BlockConfirmations: 1, - UpdateFrequencySec: 2, - UpdateTimeoutSec: 1, - StoredAllowlistBatchSize: 2, - OnchainAllowlistBatchSize: 16, - StoreAllowedSendersEnabled: false, - FetchingDelayInRangeSec: 0, + ContractAddress: common.HexToAddress(addr3), + ContractVersion: 1, + BlockConfirmations: 1, + UpdateFrequencySec: 2, + UpdateTimeoutSec: 1, + StoredAllowlistBatchSize: 2, + OnchainAllowlistBatchSize: 16, + FetchingDelayInRangeSec: 0, } orm := amocks.NewORM(t) + orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) + orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) + allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -181,4 +314,93 @@ func TestAllowlist_UpdateFromContract(t *testing.T) { return allowlist.Allow(common.HexToAddress(addr1)) && !allowlist.Allow(common.HexToAddress(addr3)) }, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue()) }) + +} + +func TestExtractContractVersion(t *testing.T) { + + type tc struct { + name string + versionStr string + expectedResult string + expectedError *string + } + + var errInvalidVersion = func(v string) *string { + ev := fmt.Sprintf("version not found in string: %s", v) + return &ev + } + + tcs := []tc{ + { + name: "OK-Tos_type_and_version", + versionStr: "Functions Terms of Service Allow List v1.1.0", + expectedResult: "v1.1.0", + expectedError: nil, + }, + { + name: "OK-double_digits_minor", + versionStr: "Functions Terms of Service Allow List v1.20.0", + expectedResult: "v1.20.0", + expectedError: nil, + }, + { + name: "NOK-invalid_version", + versionStr: "invalid_version", + expectedResult: "", + expectedError: errInvalidVersion("invalid_version"), + }, + { + name: "NOK-incomplete_version", + versionStr: "v2.0", + expectedResult: "", + expectedError: errInvalidVersion("v2.0"), + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + actualResult, actualError := allowlist.ExtractContractVersion(tc.versionStr) + require.Equal(t, tc.expectedResult, actualResult) + + if tc.expectedError != nil { + require.EqualError(t, actualError, *tc.expectedError) + } else { + require.NoError(t, actualError) + } + }) + } +} + +func encodeTypeAndVersionResponse(typeAndVersion string) ([]byte, error) { + codecName := "my_codec" + evmEncoderConfig := `[{"Name":"typeAndVersion","Type":"string"}]` + codecConfig := types.CodecConfig{Configs: map[string]types.ChainCodecConfig{ + codecName: {TypeABI: evmEncoderConfig}, + }} + encoder, err := evm.NewCodec(codecConfig) + if err != nil { + return nil, err + } + + input := map[string]any{ + "typeAndVersion": typeAndVersion, + } + typeAndVersionResponse, err := encoder.Encode(context.Background(), input, codecName) + if err != nil { + return nil, err + } + + return typeAndVersionResponse, nil +} + +func sampleEncodedAllowlist(t *testing.T) []byte { + abiEncodedAddresses := + "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "000000000000000000000000" + addr1 + + "000000000000000000000000" + addr2 + rawData, err := hex.DecodeString(abiEncodedAddresses) + require.NoError(t, err) + return rawData } diff --git a/core/services/gateway/handlers/functions/allowlist/mocks/orm.go b/core/services/gateway/handlers/functions/allowlist/mocks/orm.go index c2ba27c3a24..daff33d8902 100644 --- a/core/services/gateway/handlers/functions/allowlist/mocks/orm.go +++ b/core/services/gateway/handlers/functions/allowlist/mocks/orm.go @@ -101,6 +101,30 @@ func (_m *ORM) GetAllowedSenders(offset uint, limit uint, qopts ...pg.QOpt) ([]c return r0, r1 } +// PurgeAllowedSenders provides a mock function with given fields: qopts +func (_m *ORM) PurgeAllowedSenders(qopts ...pg.QOpt) error { + _va := make([]interface{}, len(qopts)) + for _i := range qopts { + _va[_i] = qopts[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for PurgeAllowedSenders") + } + + var r0 error + if rf, ok := ret.Get(0).(func(...pg.QOpt) error); ok { + r0 = rf(qopts...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // NewORM creates a new instance of ORM. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewORM(t interface { diff --git a/core/services/gateway/handlers/functions/allowlist/orm.go b/core/services/gateway/handlers/functions/allowlist/orm.go index 07ee1ea3b3b..ccacec81a43 100644 --- a/core/services/gateway/handlers/functions/allowlist/orm.go +++ b/core/services/gateway/handlers/functions/allowlist/orm.go @@ -18,6 +18,7 @@ type ORM interface { GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common.Address, error) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.QOpt) error DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error + PurgeAllowedSenders(qopts ...pg.QOpt) error } type orm struct { @@ -91,6 +92,8 @@ func (o *orm) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg. return nil } +// DeleteAllowedSenders is used to remove blocked senders from the functions_allowlist table. +// This is achieved by specifying a list of blockedSenders to remove. func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error { var valuesPlaceholder []string for i := 1; i <= len(blockedSenders); i++ { @@ -121,3 +124,24 @@ func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg. return nil } + +// PurgeAllowedSenders will remove all the allowed senders for the configured orm routerContractAddress +func (o *orm) PurgeAllowedSenders(qopts ...pg.QOpt) error { + stmt := fmt.Sprintf(` + DELETE FROM %s + WHERE router_contract_address = $1;`, tableName) + + res, err := o.q.WithOpts(qopts...).Exec(stmt, o.routerContractAddress) + if err != nil { + return err + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + o.lggr.Debugf("Successfully purged allowed senders for routerContractAddress: %s. rowsAffected: %d", o.routerContractAddress, rowsAffected) + + return nil +} diff --git a/core/services/gateway/handlers/functions/allowlist/orm_test.go b/core/services/gateway/handlers/functions/allowlist/orm_test.go index 0f63e83cd5f..1d357616fab 100644 --- a/core/services/gateway/handlers/functions/allowlist/orm_test.go +++ b/core/services/gateway/handlers/functions/allowlist/orm_test.go @@ -174,6 +174,71 @@ func TestORM_DeleteAllowedSenders(t *testing.T) { }) } +func TestORM_PurgeAllowedSenders(t *testing.T) { + t.Parallel() + + t.Run("OK-purge_allowed_list", func(t *testing.T) { + orm, err := setupORM(t) + require.NoError(t, err) + add1 := testutils.NewAddress() + add2 := testutils.NewAddress() + add3 := testutils.NewAddress() + err = orm.CreateAllowedSenders([]common.Address{add1, add2, add3}) + require.NoError(t, err) + + results, err := orm.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 3, len(results), "incorrect results length") + require.Equal(t, add1, results[0]) + + err = orm.PurgeAllowedSenders() + require.NoError(t, err) + + results, err = orm.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 0, len(results), "incorrect results length") + }) + + t.Run("OK-purge_allowed_list_for_contract_address", func(t *testing.T) { + orm1, err := setupORM(t) + require.NoError(t, err) + add1 := testutils.NewAddress() + add2 := testutils.NewAddress() + err = orm1.CreateAllowedSenders([]common.Address{add1, add2}) + require.NoError(t, err) + + results, err := orm1.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 2, len(results), "incorrect results length") + require.Equal(t, add1, results[0]) + + orm2, err := setupORM(t) + require.NoError(t, err) + add3 := testutils.NewAddress() + add4 := testutils.NewAddress() + err = orm2.CreateAllowedSenders([]common.Address{add3, add4}) + require.NoError(t, err) + + results, err = orm2.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 2, len(results), "incorrect results length") + require.Equal(t, add3, results[0]) + + err = orm2.PurgeAllowedSenders() + require.NoError(t, err) + + results, err = orm2.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 0, len(results), "incorrect results length") + + results, err = orm1.GetAllowedSenders(0, 10) + require.NoError(t, err) + require.Equal(t, 2, len(results), "incorrect results length") + require.Equal(t, add1, results[0]) + require.Equal(t, add2, results[1]) + }) +} + func Test_NewORM(t *testing.T) { t.Run("OK-create_ORM", func(t *testing.T) { _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), testutils.NewAddress()) diff --git a/go.mod b/go.mod index 2291388def2..1e0c09b06bf 100644 --- a/go.mod +++ b/go.mod @@ -98,6 +98,7 @@ require ( go.uber.org/zap v1.26.0 golang.org/x/crypto v0.19.0 golang.org/x/exp v0.0.0-20240213143201-ec583247a57a + golang.org/x/mod v0.15.0 golang.org/x/sync v0.6.0 golang.org/x/term v0.17.0 golang.org/x/text v0.14.0 @@ -316,7 +317,6 @@ require ( go.opentelemetry.io/proto/otlp v1.0.0 // indirect go.uber.org/ratelimit v0.2.0 // indirect golang.org/x/arch v0.7.0 // indirect - golang.org/x/mod v0.15.0 // indirect golang.org/x/net v0.21.0 // indirect golang.org/x/oauth2 v0.17.0 // indirect golang.org/x/sys v0.17.0 // indirect