diff --git a/relayer/lavasession/consumer_session_manager.go b/relayer/lavasession/consumer_session_manager.go index b2671fe146..70f6df83f9 100644 --- a/relayer/lavasession/consumer_session_manager.go +++ b/relayer/lavasession/consumer_session_manager.go @@ -17,10 +17,10 @@ import ( ) type ConsumerSessionManager struct { - lock sync.RWMutex - pairing map[string]*ConsumerSessionsWithProvider // key == provider address - currentEpoch uint64 - + lock sync.RWMutex + pairing map[string]*ConsumerSessionsWithProvider // key == provider address + currentEpoch uint64 + numberOfResets uint64 // pairingAddresses for Data reliability pairingAddresses []string // contains all addresses from the initial pairing. pairingAddressesLength uint64 @@ -47,10 +47,11 @@ func (csm *ConsumerSessionManager) UpdateAllProviders(ctx context.Context, epoch csm.atomicWriteCurrentEpoch(epoch) // Reset States - csm.validAddresses = make([]string, pairingListLength) + // csm.validAddresses length is reset in setValidAddressesToDefaultValue csm.pairingAddresses = make([]string, pairingListLength) csm.addedToPurgeAndReport = make(map[string]struct{}, 0) csm.pairingAddressesLength = uint64(pairingListLength) + csm.numberOfResets = 0 // Reset the pairingPurge. // This happens only after an entire epoch. so its impossible to have session connected to the old purged list @@ -60,11 +61,16 @@ func (csm *ConsumerSessionManager) UpdateAllProviders(ctx context.Context, epoch csm.pairingAddresses[idx] = provider.Acc csm.pairing[provider.Acc] = provider } - copy(csm.validAddresses, csm.pairingAddresses) // the starting point is that valid addresses are equal to pairing addresses. + csm.setValidAddressesToDefaultValue() // the starting point is that valid addresses are equal to pairing addresses. 
return nil } +func (csm *ConsumerSessionManager) setValidAddressesToDefaultValue() { + csm.validAddresses = make([]string, len(csm.pairingAddresses)) + copy(csm.validAddresses, csm.pairingAddresses) +} + // reads cs.currentEpoch atomically func (csm *ConsumerSessionManager) atomicWriteCurrentEpoch(epoch uint64) { atomic.StoreUint64(&csm.currentEpoch, epoch) @@ -75,11 +81,46 @@ func (csm *ConsumerSessionManager) atomicReadCurrentEpoch() (epoch uint64) { return atomic.LoadUint64(&csm.currentEpoch) } +// validate if reset is needed for valid addresses list. +func (csm *ConsumerSessionManager) shouldResetValidAddresses() (reset bool, numberOfResets uint64) { + csm.lock.RLock() // lock read to validate length + defer csm.lock.RUnlock() + numberOfResets = csm.numberOfResets + if len(csm.validAddresses) == 0 { + reset = true + } + return +} + +// reset the valid addresses list and increase numberOfResets +func (csm *ConsumerSessionManager) resetValidAddresses() uint64 { + csm.lock.Lock() // lock write + defer csm.lock.Unlock() + if len(csm.validAddresses) == 0 { // re-verify it didn't change while waiting for the lock. + utils.LavaFormatWarning("Provider pairing list is empty, resetting state.", nil, nil) + csm.setValidAddressesToDefaultValue() + csm.numberOfResets += 1 + } + // if len(csm.validAddresses) != 0, another goroutine already reset the list (or an epoch change occurred) while we waited for the lock; return the numberOfResets currently stored in csm + return csm.numberOfResets +} + +// validating we still have providers, otherwise reset valid addresses list +func (csm *ConsumerSessionManager) validatePairingListNotEmpty() uint64 { + reset, numberOfResets := csm.shouldResetValidAddresses() + if reset { + numberOfResets = csm.resetValidAddresses() + } + return numberOfResets +} + // GetSession will return a ConsumerSession, given cu needed for that session. // The user can also request specific providers to not be included in the search for a session. 
func (csm *ConsumerSessionManager) GetSession(ctx context.Context, cuNeededForSession uint64, initUnwantedProviders map[string]struct{}) ( consumerSession *SingleConsumerSession, epoch uint64, providerPublicAddress string, reportedProviders []byte, errRet error, ) { + numberOfResets := csm.validatePairingListNotEmpty() // if pairing list is empty we reset the state. + if initUnwantedProviders == nil { // verify initUnwantedProviders is not nil initUnwantedProviders = make(map[string]struct{}) } @@ -135,7 +176,7 @@ func (csm *ConsumerSessionManager) GetSession(ctx context.Context, cuNeededForSe } // Get session from endpoint or create new or continue. if more than 10 connections are open. - consumerSession, pairingEpoch, err := consumerSessionWithProvider.getConsumerSessionInstanceFromEndpoint(endpoint) + consumerSession, pairingEpoch, err := consumerSessionWithProvider.getConsumerSessionInstanceFromEndpoint(endpoint, numberOfResets) if err != nil { utils.LavaFormatDebug("Error on consumerSessionWithProvider.getConsumerSessionInstanceFromEndpoint", &map[string]string{"Error": err.Error()}) if MaximumNumberOfSessionsExceededError.Is(err) { diff --git a/relayer/lavasession/consumer_session_manager_test.go b/relayer/lavasession/consumer_session_manager_test.go index 6507f9a623..2fa13a9837 100644 --- a/relayer/lavasession/consumer_session_manager_test.go +++ b/relayer/lavasession/consumer_session_manager_test.go @@ -18,6 +18,7 @@ import ( const ( parallelGoRoutines = 40 numberOfProviders = 10 + numberOfResetsToTest = 10 numberOfAllowedSessionsPerConsumer = 10 firstEpochHeight = 20 secondEpochHeight = 40 @@ -83,6 +84,92 @@ func TestHappyFlow(t *testing.T) { require.Equal(t, cs.LatestBlock, servicedBlockNumber) } +func TestPairingReset(t *testing.T) { + s := createGRPCServer(t) // create a grpcServer so we can connect to its endpoint and validate everything works. + defer s.Stop() // stop the server when finished. 
+ ctx := context.Background() + csm := CreateConsumerSessionManager() + pairingList := createPairingList() + err := csm.UpdateAllProviders(ctx, firstEpochHeight, pairingList) // update the providers. + require.Nil(t, err) + csm.validAddresses = []string{} // set valid addresses to zero + cs, epoch, _, _, err := csm.GetSession(ctx, cuForFirstRequest, nil) // get a session + require.Nil(t, err) + require.Equal(t, len(csm.validAddresses), len(csm.pairingAddresses)) + require.NotNil(t, cs) + require.Equal(t, epoch, csm.currentEpoch) + require.Equal(t, cs.LatestRelayCu, uint64(cuForFirstRequest)) + err = csm.OnSessionDone(cs, firstEpochHeight, servicedBlockNumber, cuForFirstRequest, time.Duration(time.Millisecond), (servicedBlockNumber - 1), numberOfProviders, numberOfProviders) + require.Nil(t, err) + require.Equal(t, cs.CuSum, cuForFirstRequest) + require.Equal(t, cs.LatestRelayCu, latestRelayCuAfterDone) + require.Equal(t, cs.RelayNum, relayNumberAfterFirstCall) + require.Equal(t, cs.LatestBlock, servicedBlockNumber) + require.Equal(t, csm.numberOfResets, uint64(0x1)) // verify we had one reset only +} + +func TestPairingResetWithFailures(t *testing.T) { + s := createGRPCServer(t) // create a grpcServer so we can connect to its endpoint and validate everything works. + defer s.Stop() // stop the server when finished. + ctx := context.Background() + csm := CreateConsumerSessionManager() + pairingList := createPairingList() + err := csm.UpdateAllProviders(ctx, firstEpochHeight, pairingList) // update the providers. + require.Nil(t, err) + for { + fmt.Printf("%v", len(csm.validAddresses)) + cs, _, _, _, err := csm.GetSession(ctx, cuForFirstRequest, nil) // get a session + if err != nil { + if len(csm.validAddresses) == 0 { // wait for all pairings to be blocked. + break + } + require.True(t, false) // fail test. 
+ } + err = csm.OnSessionFailure(cs, nil) + + } + require.Equal(t, len(csm.validAddresses), 0) + cs, epoch, _, _, err := csm.GetSession(ctx, cuForFirstRequest, nil) // get a session + require.Nil(t, err) + require.Equal(t, len(csm.validAddresses), len(csm.pairingAddresses)) + require.NotNil(t, cs) + require.Equal(t, epoch, csm.currentEpoch) + require.Equal(t, cs.LatestRelayCu, uint64(cuForFirstRequest)) + require.Equal(t, csm.numberOfResets, uint64(0x1)) // verify we had one reset only +} + +func TestPairingResetWithMultipleFailures(t *testing.T) { + s := createGRPCServer(t) // create a grpcServer so we can connect to its endpoint and validate everything works. + defer s.Stop() // stop the server when finished. + ctx := context.Background() + csm := CreateConsumerSessionManager() + pairingList := createPairingList() + err := csm.UpdateAllProviders(ctx, firstEpochHeight, pairingList) // update the providers. + require.Nil(t, err) + for numberOfResets := 0; numberOfResets < numberOfResetsToTest; numberOfResets++ { + for { + fmt.Printf("%v", len(csm.validAddresses)) + cs, _, _, _, err := csm.GetSession(ctx, cuForFirstRequest, nil) // get a session + if err != nil { + if len(csm.validAddresses) == 0 { // wait for all pairings to be blocked. + break + } + require.True(t, false) // fail test. 
+ } + err = csm.OnSessionFailure(cs, nil) + } + require.Equal(t, len(csm.validAddresses), 0) + cs, epoch, _, _, err := csm.GetSession(ctx, cuForFirstRequest, nil) // get a session + require.Nil(t, err) + require.Equal(t, len(csm.validAddresses), len(csm.pairingAddresses)) + require.NotNil(t, cs) + require.Equal(t, epoch, csm.currentEpoch) + require.Equal(t, cs.LatestRelayCu, uint64(cuForFirstRequest)) + require.Equal(t, csm.numberOfResets, uint64(numberOfResets+1)) // verify the reset counter increased by exactly one on each iteration + } + +} + // Test the basic functionality of the consumerSessionManager func TestSuccessAndFailureOfSessionWithUpdatePairingsInTheMiddle(t *testing.T) { s := createGRPCServer(t) // create a grpcServer so we can connect to its endpoint and validate everything works. diff --git a/relayer/lavasession/consumer_types.go b/relayer/lavasession/consumer_types.go index 861400843f..825ad83780 100644 --- a/relayer/lavasession/consumer_types.go +++ b/relayer/lavasession/consumer_types.go @@ -149,14 +149,16 @@ func (cswp *ConsumerSessionsWithProvider) connectRawClientWithTimeout(ctx contex return &c, nil } -func (cswp *ConsumerSessionsWithProvider) getConsumerSessionInstanceFromEndpoint(endpoint *Endpoint) (singleConsumerSession *SingleConsumerSession, pairingEpoch uint64, err error) { +func (cswp *ConsumerSessionsWithProvider) getConsumerSessionInstanceFromEndpoint(endpoint *Endpoint, numberOfResets uint64) (singleConsumerSession *SingleConsumerSession, pairingEpoch uint64, err error) { // TODO: validate that the endpoint even belongs to the ConsumerSessionsWithProvider and is enabled. + // Multiply MaxAllowedBlockListedSessionPerProvider by (numberOfResets + 1), as every reset allows additional block-listed sessions per provider. 
+ maximumBlockedSessionsAllowed := MaxAllowedBlockListedSessionPerProvider * (numberOfResets + 1) // +1 as we start from 0 cswp.Lock.Lock() defer cswp.Lock.Unlock() // try to lock an existing session, if can't create a new one - numberOfBlockedSessions := 0 + var numberOfBlockedSessions uint64 = 0 for sessionID, session := range cswp.Sessions { if sessionID == DataReliabilitySessionId { continue // we cant use the data reliability session. which is located at key DataReliabilitySessionId @@ -165,7 +167,7 @@ func (cswp *ConsumerSessionsWithProvider) getConsumerSessionInstanceFromEndpoint // skip sessions that don't belong to the active connection continue } - if numberOfBlockedSessions >= MaxAllowedBlockListedSessionPerProvider { + if numberOfBlockedSessions >= maximumBlockedSessionsAllowed { return nil, 0, MaximumNumberOfBlockListedSessionsError }