diff --git a/core/services/ocr2/plugins/functions/plugin.go b/core/services/ocr2/plugins/functions/plugin.go index 27835127d0d..5a7a152d950 100644 --- a/core/services/ocr2/plugins/functions/plugin.go +++ b/core/services/ocr2/plugins/functions/plugin.go @@ -62,7 +62,7 @@ const ( // Create all OCR2 plugin Oracles and all extra services needed to run a Functions job. func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOracleArgs, s4OracleArgs *libocr2.OCR2OracleArgs, conf *FunctionsServicesConfig) ([]job.ServiceCtx, error) { pluginORM := functions.NewORM(conf.DB, conf.Logger, conf.QConfig, common.HexToAddress(conf.ContractID)) - s4ORM := s4.NewPostgresORM(conf.DB, conf.Logger, conf.QConfig, s4.SharedTableName, FunctionsS4Namespace) + s4ORM := s4.NewCachedORMWrapper(s4.NewPostgresORM(conf.DB, conf.Logger, conf.QConfig, s4.SharedTableName, FunctionsS4Namespace), conf.Logger) var pluginConfig config.PluginConfig if err := json.Unmarshal(conf.Job.OCR2OracleSpec.PluginConfig.Bytes(), &pluginConfig); err != nil { diff --git a/core/services/s4/cached_orm_wrapper.go b/core/services/s4/cached_orm_wrapper.go new file mode 100644 index 00000000000..38b9ecba1ca --- /dev/null +++ b/core/services/s4/cached_orm_wrapper.go @@ -0,0 +1,119 @@ +package s4 + +import ( + "fmt" + "math/big" + "strings" + "time" + + "github.com/patrickmn/go-cache" + + ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/pg" +) + +const ( + // defaultExpiration decides how long info will be valid for. + defaultExpiration = 10 * time.Minute + // cleanupInterval decides when the expired items in cache will be deleted. + cleanupInterval = 5 * time.Minute + + getSnapshotCachePrefix = "GetSnapshot" +) + +// CachedORM is a cached orm wrapper that implements the ORM interface. +// It adds a cache layer in order to remove unnecessary pressure to the underlaying implementation +type CachedORM struct { + underlayingORM ORM + cache *cache.Cache + lggr logger.Logger +} + +var _ ORM = (*CachedORM)(nil) + +func NewCachedORMWrapper(orm ORM, lggr logger.Logger) *CachedORM { + return &CachedORM{ + underlayingORM: orm, + cache: cache.New(defaultExpiration, cleanupInterval), + lggr: lggr, + } +} + +func (c CachedORM) Get(address *ubig.Big, slotId uint, qopts ...pg.QOpt) (*Row, error) { + return c.underlayingORM.Get(address, slotId, qopts...) +} + +func (c CachedORM) Update(row *Row, qopts ...pg.QOpt) error { + c.deleteRowFromSnapshotCache(row) + + return c.underlayingORM.Update(row, qopts...) +} + +func (c CachedORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (int64, error) { + deletedRows, err := c.underlayingORM.DeleteExpired(limit, utcNow, qopts...) + if err != nil { + return 0, err + } + + if deletedRows > 0 { + c.cache.Flush() + } + + return deletedRows, nil +} + +func (c CachedORM) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*SnapshotRow, error) { + key := fmt.Sprintf("%s_%s_%s", getSnapshotCachePrefix, addressRange.MinAddress.String(), addressRange.MaxAddress.String()) + + cached, found := c.cache.Get(key) + if found { + return cached.([]*SnapshotRow), nil + } + + c.lggr.Debug("Snapshot not found in cache, fetching it from underlaying implementation") + data, err := c.underlayingORM.GetSnapshot(addressRange, qopts...) + if err != nil { + return nil, err + } + c.cache.Set(key, data, defaultExpiration) + + return data, nil +} + +func (c CachedORM) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) { + return c.underlayingORM.GetUnconfirmedRows(limit, qopts...) +} + +// deleteRowFromSnapshotCache will clean the cache for every snapshot that would involve a given row +// in case of an error parsing a key it will also delete the key from the cache +func (c CachedORM) deleteRowFromSnapshotCache(row *Row) { + for key := range c.cache.Items() { + keyParts := strings.Split(key, "_") + if len(keyParts) != 3 { + continue + } + + if keyParts[0] != getSnapshotCachePrefix { + continue + } + + minAddress, ok := new(big.Int).SetString(keyParts[1], 10) + if !ok { + c.lggr.Errorf("error while converting minAddress string: %s to big.Int, deleting key %q", keyParts[1], key) + c.cache.Delete(key) + continue + } + + maxAddress, ok := new(big.Int).SetString(keyParts[2], 10) + if !ok { + c.lggr.Errorf("error while converting minAddress string: %s to big.Int, deleting key %q ", keyParts[2], key) + c.cache.Delete(key) + continue + } + + if row.Address.ToInt().Cmp(minAddress) >= 0 && row.Address.ToInt().Cmp(maxAddress) <= 0 { + c.cache.Delete(key) + } + } +} diff --git a/core/services/s4/cached_orm_wrapper_test.go b/core/services/s4/cached_orm_wrapper_test.go new file mode 100644 index 00000000000..6f6ac298557 --- /dev/null +++ b/core/services/s4/cached_orm_wrapper_test.go @@ -0,0 +1,272 @@ +package s4_test + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" + ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" + "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/s4" + "github.com/smartcontractkit/chainlink/v2/core/services/s4/mocks" +) + +func TestGetSnapshotEmpty(t *testing.T) { + t.Run("OK-no_rows", func(t *testing.T) { + psqlORM := setupORM(t, "test") + lggr := logger.TestLogger(t) + orm := s4.NewCachedORMWrapper(psqlORM, lggr) + + rows, err := orm.GetSnapshot(s4.NewFullAddressRange()) + assert.NoError(t, err) + assert.Empty(t, rows) + }) +} + +func TestGetSnapshotCacheFilled(t *testing.T) { + t.Run("OK_with_rows_already_cached", func(t *testing.T) { + rows := generateTestSnapshotRows(t, 100) + + fullAddressRange := s4.NewFullAddressRange() + + lggr := logger.TestLogger(t) + underlayingORM := mocks.NewORM(t) + underlayingORM.On("GetSnapshot", fullAddressRange).Return(rows, nil).Once() + + orm := s4.NewCachedORMWrapper(underlayingORM, lggr) + + // first call will go to the underlaying orm implementation to fill the cache + first_snapshot, err := orm.GetSnapshot(fullAddressRange) + assert.NoError(t, err) + assert.Equal(t, len(rows), len(first_snapshot)) + + // on the second call, the results will come from the cache, if not the mock will return an error because of .Once() + cache_snapshot, err := orm.GetSnapshot(fullAddressRange) + assert.NoError(t, err) + assert.Equal(t, len(rows), len(cache_snapshot)) + + snapshotRowMap := make(map[string]*s4.SnapshotRow) + for i, sr := range cache_snapshot { + // assuming unique addresses + snapshotRowMap[sr.Address.String()] = cache_snapshot[i] + } + + for _, sr := range rows { + snapshotRow, ok := snapshotRowMap[sr.Address.String()] + assert.True(t, ok) + assert.NotNil(t, snapshotRow) + assert.Equal(t, snapshotRow.Address, sr.Address) + assert.Equal(t, snapshotRow.SlotId, sr.SlotId) + assert.Equal(t, snapshotRow.Version, sr.Version) + assert.Equal(t, snapshotRow.Expiration, sr.Expiration) + assert.Equal(t, snapshotRow.Confirmed, sr.Confirmed) + assert.Equal(t, snapshotRow.PayloadSize, sr.PayloadSize) + } + }) +} + +func TestUpdateInvalidatesSnapshotCache(t *testing.T) { + t.Run("OK-GetSnapshot_cache_invalidated_after_update", func(t *testing.T) { + rows := generateTestSnapshotRows(t, 100) + + fullAddressRange := s4.NewFullAddressRange() + + lggr := logger.TestLogger(t) + underlayingORM := mocks.NewORM(t) + underlayingORM.On("GetSnapshot", fullAddressRange).Return(rows, nil).Once() + + orm := s4.NewCachedORMWrapper(underlayingORM, lggr) + + // first call will go to the underlaying orm implementation to fill the cache + first_snapshot, err := orm.GetSnapshot(fullAddressRange) + assert.NoError(t, err) + assert.Equal(t, len(rows), len(first_snapshot)) + + // on the second call, the results will come from the cache, if not the mock will return an error because of .Once() + cache_snapshot, err := orm.GetSnapshot(fullAddressRange) + assert.NoError(t, err) + assert.Equal(t, len(rows), len(cache_snapshot)) + + // this update call will invalidate the cache + row := &s4.Row{ + Address: big.New(common.HexToAddress("0x0000000000000000000000000000000000000000000000000000000000000005").Big()), + SlotId: 1, + Payload: cltest.MustRandomBytes(t, 32), + Version: 1, + Expiration: time.Now().Add(time.Hour).UnixMilli(), + Confirmed: true, + Signature: cltest.MustRandomBytes(t, 32), + } + underlayingORM.On("Update", row).Return(nil).Once() + err = orm.Update(row) + assert.NoError(t, err) + + // given the cache was invalidated this request will reach the underlaying orm implementation + underlayingORM.On("GetSnapshot", fullAddressRange).Return(rows, nil).Once() + third_snapshot, err := orm.GetSnapshot(fullAddressRange) + assert.NoError(t, err) + assert.Equal(t, len(rows), len(third_snapshot)) + }) + + t.Run("OK-GetSnapshot_cache_not_invalidated_after_update", func(t *testing.T) { + rows := generateTestSnapshotRows(t, 5) + + addressRange := &s4.AddressRange{ + MinAddress: ubig.New(common.BytesToAddress(bytes.Repeat([]byte{0x00}, common.AddressLength)).Big()), + MaxAddress: ubig.New(common.BytesToAddress(append(bytes.Repeat([]byte{0x00}, common.AddressLength-1), 3)).Big()), + } + + lggr := logger.TestLogger(t) + underlayingORM := mocks.NewORM(t) + underlayingORM.On("GetSnapshot", addressRange).Return(rows, nil).Once() + + orm := s4.NewCachedORMWrapper(underlayingORM, lggr) + + // first call will go to the underlaying orm implementation to fill the cache + first_snapshot, err := orm.GetSnapshot(addressRange) + assert.NoError(t, err) + assert.Equal(t, len(rows), len(first_snapshot)) + + // on the second call, the results will come from the cache, if not the mock will return an error because of .Once() + cache_snapshot, err := orm.GetSnapshot(addressRange) + assert.NoError(t, err) + assert.Equal(t, len(rows), len(cache_snapshot)) + + // this update call wont invalidate the cache because the address is out of the cache address range + outOfCachedRangeAddress := ubig.New(common.BytesToAddress(append(bytes.Repeat([]byte{0x00}, common.AddressLength-1), 5)).Big()) + row := &s4.Row{ + Address: outOfCachedRangeAddress, + SlotId: 1, + Payload: cltest.MustRandomBytes(t, 32), + Version: 1, + Expiration: time.Now().Add(time.Hour).UnixMilli(), + Confirmed: true, + Signature: cltest.MustRandomBytes(t, 32), + } + underlayingORM.On("Update", row).Return(nil).Once() + err = orm.Update(row) + assert.NoError(t, err) + + // given the cache was not invalidated this request wont reach the underlaying orm implementation + third_snapshot, err := orm.GetSnapshot(addressRange) + assert.NoError(t, err) + assert.Equal(t, len(rows), len(third_snapshot)) + }) +} + +func TestGet(t *testing.T) { + address := big.New(testutils.NewAddress().Big()) + var slotID uint = 1 + + lggr := logger.TestLogger(t) + + t.Run("OK-Get_underlaying_ORM_returns_a_row", func(t *testing.T) { + underlayingORM := mocks.NewORM(t) + expectedRow := &s4.Row{ + Address: address, + SlotId: slotID, + } + underlayingORM.On("Get", address, slotID).Return(expectedRow, nil).Once() + orm := s4.NewCachedORMWrapper(underlayingORM, lggr) + + row, err := orm.Get(address, slotID) + require.NoError(t, err) + require.Equal(t, expectedRow, row) + }) + t.Run("NOK-Get_underlaying_ORM_returns_an_error", func(t *testing.T) { + underlayingORM := mocks.NewORM(t) + underlayingORM.On("Get", address, slotID).Return(nil, fmt.Errorf("some_error")).Once() + orm := s4.NewCachedORMWrapper(underlayingORM, lggr) + + row, err := orm.Get(address, slotID) + require.Nil(t, row) + require.EqualError(t, err, "some_error") + }) +} + +func TestDeletedExpired(t *testing.T) { + var limit uint = 1 + now := time.Now() + + lggr := logger.TestLogger(t) + + t.Run("OK-DeletedExpired_underlaying_ORM_returns_a_row", func(t *testing.T) { + var expectedDeleted int64 = 10 + underlayingORM := mocks.NewORM(t) + underlayingORM.On("DeleteExpired", limit, now).Return(expectedDeleted, nil).Once() + orm := s4.NewCachedORMWrapper(underlayingORM, lggr) + + actualDeleted, err := orm.DeleteExpired(limit, now) + require.NoError(t, err) + require.Equal(t, expectedDeleted, actualDeleted) + }) + t.Run("NOK-DeletedExpired_underlaying_ORM_returns_an_error", func(t *testing.T) { + var expectedDeleted int64 + underlayingORM := mocks.NewORM(t) + underlayingORM.On("DeleteExpired", limit, now).Return(expectedDeleted, fmt.Errorf("some_error")).Once() + orm := s4.NewCachedORMWrapper(underlayingORM, lggr) + + actualDeleted, err := orm.DeleteExpired(limit, now) + require.EqualError(t, err, "some_error") + require.Equal(t, expectedDeleted, actualDeleted) + }) +} + +// GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) +func TestGetUnconfirmedRows(t *testing.T) { + var limit uint = 1 + lggr := logger.TestLogger(t) + + t.Run("OK-GetUnconfirmedRows_underlaying_ORM_returns_a_row", func(t *testing.T) { + address := big.New(testutils.NewAddress().Big()) + var slotID uint = 1 + + expectedRow := []*s4.Row{{ + Address: address, + SlotId: slotID, + }} + underlayingORM := mocks.NewORM(t) + underlayingORM.On("GetUnconfirmedRows", limit).Return(expectedRow, nil).Once() + orm := s4.NewCachedORMWrapper(underlayingORM, lggr) + + actualRow, err := orm.GetUnconfirmedRows(limit) + require.NoError(t, err) + require.Equal(t, expectedRow, actualRow) + }) + t.Run("NOK-GetUnconfirmedRows_underlaying_ORM_returns_an_error", func(t *testing.T) { + underlayingORM := mocks.NewORM(t) + underlayingORM.On("GetUnconfirmedRows", limit).Return(nil, fmt.Errorf("some_error")).Once() + orm := s4.NewCachedORMWrapper(underlayingORM, lggr) + + actualRow, err := orm.GetUnconfirmedRows(limit) + require.Nil(t, actualRow) + require.EqualError(t, err, "some_error") + }) +} + +func generateTestSnapshotRows(t *testing.T, n int) []*s4.SnapshotRow { + t.Helper() + + rows := make([]*s4.SnapshotRow, n) + for i := 0; i < n; i++ { + row := &s4.SnapshotRow{ + Address: big.New(testutils.NewAddress().Big()), + SlotId: 1, + PayloadSize: 32, + Version: 1 + uint64(i), + Expiration: time.Now().Add(time.Hour).UnixMilli(), + Confirmed: i%2 == 0, + } + rows[i] = row + } + + return rows +}