diff --git a/protocol/provideroptimizer/selection_tier.go b/protocol/provideroptimizer/selection_tier.go index 1bb2db5480..3536480b93 100644 --- a/protocol/provideroptimizer/selection_tier.go +++ b/protocol/provideroptimizer/selection_tier.go @@ -1,6 +1,7 @@ package provideroptimizer import ( + "github.com/lavanet/lava/v3/utils" "github.com/lavanet/lava/v3/utils/rand" ) @@ -15,6 +16,7 @@ type SelectionTier interface { AddScore(entry string, score float64) GetTier(tier int, numTiers int, minimumEntries int) []Entry SelectTierRandomly(numTiers int, tierChances map[int]float64) int + ShiftTierChance(numTiers int, initialYierChances map[int]float64) map[int]float64 } type SelectionTierInst struct { @@ -49,15 +51,12 @@ func (st *SelectionTierInst) AddScore(entry string, score float64) { func (st *SelectionTierInst) SelectTierRandomly(numTiers int, tierChances map[int]float64) int { // select a tier randomly based on the chances given // if the chances are not given, select a tier randomly based on the number of tiers - if len(tierChances) == 0 { + if len(tierChances) == 0 || len(tierChances) > numTiers { + utils.LavaFormatError("Invalid tier chances usage", nil) return rand.Intn(numTiers) } // calculate the total chance - totalChance := 0.0 - for _, chance := range tierChances { - totalChance += chance - } - chanceForDefaultTiers := (1 - totalChance) / float64(numTiers-len(tierChances)) + chanceForDefaultTiers := st.calcChanceForDefaultTiers(tierChances, numTiers) // select a random number between 0 and 1 randChance := rand.Float64() // find the tier that the random chance falls into @@ -76,6 +75,63 @@ func (st *SelectionTierInst) SelectTierRandomly(numTiers int, tierChances map[in return 0 } +func (*SelectionTierInst) calcChanceForDefaultTiers(tierChances map[int]float64, numTiers int) float64 { + totalChance := 0.0 + for _, chance := range tierChances { + totalChance += chance + } + chanceForDefaultTiers := (1 - totalChance) / float64(numTiers-len(tierChances)) + return chanceForDefaultTiers +} + +func (st *SelectionTierInst) averageScoreForTier(tier int, numTiers int) float64 { + // calculate the average score for the given tier and number of tiers + start, end, _, _ := getPositionsForTier(tier, numTiers, len(st.scores)) + sum := 0.0 + parts := 0.0 + for i := start; i < end; i++ { + sum += st.scores[i].Score * st.scores[i].Part + parts += st.scores[i].Part + } + return sum / parts +} + +func (st *SelectionTierInst) ShiftTierChance(numTiers int, initialTierChances map[int]float64) map[int]float64 { + chanceForDefaultTiers := st.calcChanceForDefaultTiers(initialTierChances, numTiers) + + // shift the chances + shiftedTierChances := make(map[int]float64) + // shift tier chances based on the difference in the average score of each tier + reversedScores := make([]float64, numTiers) + totalScore := 0.0 + for i := 0; i < numTiers; i++ { + reversedScores[i] = 1 / (st.averageScoreForTier(i, numTiers) + 0.0001) // add epsilon to avoid 0 + totalScore += reversedScores[i] + } + averageChance := 1 / float64(numTiers) + for i := 0; i < numTiers; i++ { + offsetFactor := reversedScores[i] / totalScore + if _, ok := initialTierChances[i]; !ok { + if chanceForDefaultTiers > 0 { + shiftedTierChances[i] = chanceForDefaultTiers + averageChance*offsetFactor + } + } else { + if initialTierChances[i] > 0 { + shiftedTierChances[i] = initialTierChances[i] + averageChance*offsetFactor + } + } + } + // normalize the chances + totalChance := 0.0 + for _, chance := range shiftedTierChances { + totalChance += chance + } + for i := 0; i < numTiers; i++ { + shiftedTierChances[i] /= totalChance + } + return shiftedTierChances +} + func (st *SelectionTierInst) GetTier(tier int, numTiers int, minimumEntries int) []Entry { // get the tier of scores for the given tier and number of tiers entriesLen := len(st.scores) @@ -85,7 +141,8 @@ func (st *SelectionTierInst) GetTier(tier int, numTiers int, minimumEntries int) start, end, fracStart, fracEnd := getPositionsForTier(tier, numTiers, entriesLen) if end < minimumEntries { - return st.scores[:minimumEntries] + // only allow better tiers if there are not enough entries + return st.scores[:end] } ret := st.scores[start:end] if len(ret) >= minimumEntries { diff --git a/protocol/provideroptimizer/selection_tier_test.go b/protocol/provideroptimizer/selection_tier_test.go index 31617577c4..25fb493db5 100644 --- a/protocol/provideroptimizer/selection_tier_test.go +++ b/protocol/provideroptimizer/selection_tier_test.go @@ -173,6 +173,56 @@ func TestSelectionTierInstGetTierBig(t *testing.T) { } } + +func TestSelectionTierInstShiftTierChanceEquals(t *testing.T) { + st := NewSelectionTier() + numTiers := 4 + for i := 0; i < 25; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + for i := 25; i < 50; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + for i := 50; i < 75; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + for i := 75; i < 100; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + selectionTierChances := st.ShiftTierChance(numTiers, nil) + require.Equal(t, numTiers, len(selectionTierChances)) + require.Equal(t, selectionTierChances[0], selectionTierChances[1]) + + selectionTierChances = st.ShiftTierChance(numTiers, map[int]float64{0: 0.5, 1: 0.5}) + require.Equal(t, 0.0, selectionTierChances[len(selectionTierChances)-1]) + require.Equal(t, 0.5, selectionTierChances[0]) + + selectionTierChances = st.ShiftTierChance(numTiers, map[int]float64{0: 0.5, len(selectionTierChances) - 1: 0.1}) + require.Less(t, 0.5, selectionTierChances[0]) + require.Greater(t, 0.25, selectionTierChances[0]) + require.Greater(t, selectionTierChances[len(selectionTierChances)-1], 0.1) + + st = NewSelectionTier() + for i := 0; i < 25; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + for i := 25; i < 50; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.2) + } + for i := 50; i < 75; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.3) + } + for i := 75; i < 100; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.4) + } + selectionTierChances = st.ShiftTierChance(numTiers, nil) + require.Equal(t, numTiers, len(selectionTierChances)) + require.Greater(t, selectionTierChances[0], selectionTierChances[1]) + require.Greater(t, selectionTierChances[1]*2, selectionTierChances[0]) // make sure the adjustment is not that strong + require.Greater(t, selectionTierChances[1], selectionTierChances[2]) + +} + func TestSelectionTierInst_SelectTierRandomly(t *testing.T) { st := NewSelectionTier() rand.InitRandomSeed()