diff --git a/action/protocol/staking/contractstake_indexer.go b/action/protocol/staking/contractstake_indexer.go index 1a3225f2ef..a81a28df16 100644 --- a/action/protocol/staking/contractstake_indexer.go +++ b/action/protocol/staking/contractstake_indexer.go @@ -8,6 +8,7 @@ package staking import ( _ "embed" "strings" + "time" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/iotexproject/iotex-address/address" @@ -25,6 +26,7 @@ type ( // ContractStakingIndexer defines the interface of contract staking reader ContractStakingIndexer interface { + Height() (uint64, error) // Buckets returns active buckets Buckets(height uint64) ([]*VoteBucket, error) // BucketsByIndices returns active buckets by indices @@ -42,6 +44,17 @@ type ( // BucketTypes returns the active bucket types BucketTypes(height uint64) ([]*ContractStakingBucketType, error) } + + delayTolerantIndexer struct { + ContractStakingIndexer + duration time.Duration + startHeight uint64 + } + + delayTolerantIndexerWithBucketType struct { + *delayTolerantIndexer + indexer ContractStakingIndexerWithBucketType + } ) func init() { @@ -51,3 +64,93 @@ func init() { panic(err) } } + +// NewDelayTolerantIndexer creates a delay tolerant indexer +func NewDelayTolerantIndexer(indexer ContractStakingIndexer, duration time.Duration) ContractStakingIndexer { + d := &delayTolerantIndexer{ContractStakingIndexer: indexer, duration: duration} + if indexWithStart, ok := indexer.(interface{ StartHeight() uint64 }); ok { + d.startHeight = indexWithStart.StartHeight() + } + return d +} + +// NewDelayTolerantIndexerWithBucketType creates a delay tolerant indexer with bucket type +func NewDelayTolerantIndexerWithBucketType(indexer ContractStakingIndexerWithBucketType, duration time.Duration) ContractStakingIndexerWithBucketType { + return &delayTolerantIndexerWithBucketType{ + NewDelayTolerantIndexer(indexer, duration).(*delayTolerantIndexer), + indexer, + } +} + +func (c *delayTolerantIndexer) wait(height uint64) (bool, error) { + // first check if the height is already reached + if c.startHeight >= height { + return false, nil + } + h, err := c.Height() + if err != nil { + return false, err + } + if h >= height { + return true, nil + } + // wait for the height to be reached + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + timer := time.NewTimer(c.duration) + defer timer.Stop() + for { + select { + case <-ticker.C: + h, err := c.Height() + if err != nil { + return false, err + } + if h >= height { + return true, nil + } + case <-timer.C: + return false, nil + } + } +} + +func (c *delayTolerantIndexer) Buckets(height uint64) ([]*VoteBucket, error) { + _, err := c.wait(height) + if err != nil { + return nil, err + } + return c.ContractStakingIndexer.Buckets(height) +} + +func (c *delayTolerantIndexer) BucketsByIndices(indices []uint64, height uint64) ([]*VoteBucket, error) { + _, err := c.wait(height) + if err != nil { + return nil, err + } + return c.ContractStakingIndexer.BucketsByIndices(indices, height) +} + +func (c *delayTolerantIndexer) BucketsByCandidate(ownerAddr address.Address, height uint64) ([]*VoteBucket, error) { + _, err := c.wait(height) + if err != nil { + return nil, err + } + return c.ContractStakingIndexer.BucketsByCandidate(ownerAddr, height) +} + +func (c *delayTolerantIndexer) TotalBucketCount(height uint64) (uint64, error) { + _, err := c.wait(height) + if err != nil { + return 0, err + } + return c.ContractStakingIndexer.TotalBucketCount(height) +} + +func (c *delayTolerantIndexerWithBucketType) BucketTypes(height uint64) ([]*ContractStakingBucketType, error) { + _, err := c.wait(height) + if err != nil { + return nil, err + } + return c.indexer.BucketTypes(height) +} diff --git a/action/protocol/staking/contractstake_indexer_mock.go b/action/protocol/staking/contractstake_indexer_mock.go index bc45f43c7f..b050f69cf2 100644 --- a/action/protocol/staking/contractstake_indexer_mock.go +++ b/action/protocol/staking/contractstake_indexer_mock.go @@ -93,6 +93,21 @@ func (mr *MockContractStakingIndexerMockRecorder) ContractAddress() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractAddress", reflect.TypeOf((*MockContractStakingIndexer)(nil).ContractAddress)) } +// Height mocks base method. +func (m *MockContractStakingIndexer) Height() (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Height") + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Height indicates an expected call of Height. +func (mr *MockContractStakingIndexerMockRecorder) Height() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Height", reflect.TypeOf((*MockContractStakingIndexer)(nil).Height)) +} + // TotalBucketCount mocks base method. func (m *MockContractStakingIndexer) TotalBucketCount(height uint64) (uint64, error) { m.ctrl.T.Helper() @@ -205,6 +220,21 @@ func (mr *MockContractStakingIndexerWithBucketTypeMockRecorder) ContractAddress( return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractAddress", reflect.TypeOf((*MockContractStakingIndexerWithBucketType)(nil).ContractAddress)) } +// Height mocks base method. +func (m *MockContractStakingIndexerWithBucketType) Height() (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Height") + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Height indicates an expected call of Height. +func (mr *MockContractStakingIndexerWithBucketTypeMockRecorder) Height() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Height", reflect.TypeOf((*MockContractStakingIndexerWithBucketType)(nil).Height)) +} + // TotalBucketCount mocks base method. func (m *MockContractStakingIndexerWithBucketType) TotalBucketCount(height uint64) (uint64, error) { m.ctrl.T.Helper() diff --git a/action/protocol/staking/contractstake_indexer_test.go b/action/protocol/staking/contractstake_indexer_test.go new file mode 100644 index 0000000000..b55faec79f --- /dev/null +++ b/action/protocol/staking/contractstake_indexer_test.go @@ -0,0 +1,100 @@ +package staking + +import ( + "errors" + "math/big" + "sync/atomic" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/iotexproject/iotex-address/address" + "github.com/stretchr/testify/require" +) + +func TestDelayTolerantIndexer(t *testing.T) { + r := require.New(t) + ctrl := gomock.NewController(t) + indexer := NewMockContractStakingIndexerWithBucketType(ctrl) + delayIndexer := NewDelayTolerantIndexerWithBucketType(indexer, time.Second) + + var ( + indexerHeight = uint64(10) + indexerBuckets = []*VoteBucket{{Index: 1}} + indexerAddress = address.ZeroAddress + indexerBucketTypes = []*ContractStakingBucketType{{Amount: big.NewInt(1)}} + ) + + // Height + indexer.EXPECT().Height().DoAndReturn(func() (uint64, error) { + return atomic.LoadUint64(&indexerHeight), nil + }).AnyTimes() + + height, err := delayIndexer.Height() + r.NoError(err) + r.Equal(atomic.LoadUint64(&indexerHeight), height) + // Buckets + indexer.EXPECT().Buckets(gomock.Any()).DoAndReturn(func(height uint64) ([]*VoteBucket, error) { + if height <= atomic.LoadUint64(&indexerHeight) { + return indexerBuckets, nil + } + return nil, errors.New("invalid height") + }).AnyTimes() + noDelayHeight, delayHeight := atomic.LoadUint64(&indexerHeight), atomic.LoadUint64(&indexerHeight)+1 + bkts, err := delayIndexer.Buckets(noDelayHeight) + r.NoError(err) + r.Equal(indexerBuckets, bkts) + bkts, err = delayIndexer.Buckets(delayHeight) + r.ErrorContains(err, "invalid height") + go func() { + time.Sleep(100 * time.Millisecond) + atomic.StoreUint64(&indexerHeight, delayHeight) + }() + bkts, err = delayIndexer.Buckets(delayHeight) + r.NoError(err) + r.Equal(indexerBuckets, bkts) + // BucketsByIndices + indexer.EXPECT().BucketsByIndices(gomock.Any(), gomock.Any()).DoAndReturn(func(indices []uint64, height uint64) ([]*VoteBucket, error) { + if height <= atomic.LoadUint64(&indexerHeight) { + return indexerBuckets, nil + } + return nil, errors.New("invalid height") + }).AnyTimes() + bkts, err = delayIndexer.BucketsByIndices(nil, delayHeight) + r.NoError(err) + r.Equal(indexerBuckets, bkts) + // BucketsByCandidate + indexer.EXPECT().BucketsByCandidate(gomock.Any(), gomock.Any()).DoAndReturn(func(ownerAddr address.Address, height uint64) ([]*VoteBucket, error) { + if height <= atomic.LoadUint64(&indexerHeight) { + return indexerBuckets, nil + } + return nil, errors.New("invalid height") + }).AnyTimes() + bkts, err = delayIndexer.BucketsByCandidate(nil, delayHeight) + r.NoError(err) + r.Equal(indexerBuckets, bkts) + // TotalBucketCount + indexer.EXPECT().TotalBucketCount(gomock.Any()).DoAndReturn(func(height uint64) (uint64, error) { + if height <= atomic.LoadUint64(&indexerHeight) { + return uint64(len(indexerBuckets)), nil + } + return 0, errors.New("invalid height") + }).AnyTimes() + count, err := delayIndexer.TotalBucketCount(delayHeight) + r.NoError(err) + r.Equal(uint64(len(indexerBuckets)), count) + // ContractAddress + indexer.EXPECT().ContractAddress().Return(indexerAddress).AnyTimes() + ca := delayIndexer.ContractAddress() + r.Equal(indexerAddress, ca) + // BucketTypes + indexer.EXPECT().BucketTypes(gomock.Any()).DoAndReturn(func(height uint64) ([]*ContractStakingBucketType, error) { + if height <= atomic.LoadUint64(&indexerHeight) { + return indexerBucketTypes, nil + } + return nil, errors.New("invalid height") + }) + bucketTypes, err := delayIndexer.BucketTypes(delayHeight) + r.NoError(err) + r.Equal(indexerBucketTypes, bucketTypes) +} diff --git a/action/protocol/staking/protocol.go b/action/protocol/staking/protocol.go index d3c82227c5..953b66d286 100644 --- a/action/protocol/staking/protocol.go +++ b/action/protocol/staking/protocol.go @@ -581,10 +581,10 @@ func (p *Protocol) ReadState(ctx context.Context, sr protocol.StateReader, metho // stakeSR is the stake state reader including native and contract staking indexers := []ContractStakingIndexer{} if p.contractStakingIndexer != nil { - indexers = append(indexers, p.contractStakingIndexer) + indexers = append(indexers, NewDelayTolerantIndexerWithBucketType(p.contractStakingIndexer, time.Second)) } if p.contractStakingIndexerV2 != nil { - indexers = append(indexers, p.contractStakingIndexerV2) + indexers = append(indexers, NewDelayTolerantIndexer(p.contractStakingIndexerV2, time.Second)) } stakeSR, err := newCompositeStakingStateReader(p.candBucketsIndexer, sr, p.calculateVoteWeight, indexers...) if err != nil { diff --git a/action/protocol/staking/protocol_test.go b/action/protocol/staking/protocol_test.go index 11b4d3e4f0..8aff5445c4 100644 --- a/action/protocol/staking/protocol_test.go +++ b/action/protocol/staking/protocol_test.go @@ -469,6 +469,7 @@ func TestProtocol_ActiveCandidates(t *testing.T) { require.NoError(err) var csIndexerHeight, csVotes uint64 + csIndexer.EXPECT().Height().Return(uint64(0), nil).AnyTimes() csIndexer.EXPECT().BucketsByCandidate(gomock.Any(), gomock.Any()).DoAndReturn(func(ownerAddr address.Address, height uint64) ([]*VoteBucket, error) { if height != csIndexerHeight { return nil, errors.Errorf("invalid height %d", height)