Skip to content

Commit

Permalink
added shift tier chance to selection tier
Browse files Browse the repository at this point in the history
  • Loading branch information
omerlavanet committed Sep 9, 2024
1 parent f9d5846 commit 0e059af
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 7 deletions.
71 changes: 64 additions & 7 deletions protocol/provideroptimizer/selection_tier.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package provideroptimizer

import (
"github.com/lavanet/lava/v3/utils"
"github.com/lavanet/lava/v3/utils/rand"
)

Expand All @@ -15,6 +16,7 @@ type SelectionTier interface {
AddScore(entry string, score float64)
GetTier(tier int, numTiers int, minimumEntries int) []Entry
SelectTierRandomly(numTiers int, tierChances map[int]float64) int
ShiftTierChance(numTiers int, initialYierChances map[int]float64) map[int]float64
}

type SelectionTierInst struct {
Expand Down Expand Up @@ -49,15 +51,12 @@ func (st *SelectionTierInst) AddScore(entry string, score float64) {
func (st *SelectionTierInst) SelectTierRandomly(numTiers int, tierChances map[int]float64) int {
// select a tier randomly based on the chances given
// if the chances are not given, select a tier randomly based on the number of tiers
if len(tierChances) == 0 {
if len(tierChances) == 0 || len(tierChances) > numTiers {
utils.LavaFormatError("Invalid tier chances usage", nil)
return rand.Intn(numTiers)
}
// calculate the total chance
totalChance := 0.0
for _, chance := range tierChances {
totalChance += chance
}
chanceForDefaultTiers := (1 - totalChance) / float64(numTiers-len(tierChances))
chanceForDefaultTiers := st.calcChanceForDefaultTiers(tierChances, numTiers)
// select a random number between 0 and 1
randChance := rand.Float64()
// find the tier that the random chance falls into
Expand All @@ -76,6 +75,63 @@ func (st *SelectionTierInst) SelectTierRandomly(numTiers int, tierChances map[in
return 0
}

func (*SelectionTierInst) calcChanceForDefaultTiers(tierChances map[int]float64, numTiers int) float64 {
totalChance := 0.0
for _, chance := range tierChances {
totalChance += chance
}
chanceForDefaultTiers := (1 - totalChance) / float64(numTiers-len(tierChances))
return chanceForDefaultTiers
}

func (st *SelectionTierInst) averageScoreForTier(tier int, numTiers int) float64 {
// calculate the average score for the given tier and number of tiers
start, end, _, _ := getPositionsForTier(tier, numTiers, len(st.scores))
sum := 0.0
parts := 0.0
for i := start; i < end; i++ {
sum += st.scores[i].Score * st.scores[i].Part
parts += st.scores[i].Part
}
return sum / parts
}

func (st *SelectionTierInst) ShiftTierChance(numTiers int, initialTierChances map[int]float64) map[int]float64 {
chanceForDefaultTiers := st.calcChanceForDefaultTiers(initialTierChances, numTiers)

// shift the chances
shiftedTierChances := make(map[int]float64)
// shift tier chances based on the difference in the average score of each tier
reversedScores := make([]float64, numTiers)
totalScore := 0.0
for i := 0; i < numTiers; i++ {
reversedScores[i] = 1 / (st.averageScoreForTier(i, numTiers) + 0.0001) // add epsilon to avoid 0
totalScore += reversedScores[i]
}
averageChance := 1 / float64(numTiers)
for i := 0; i < numTiers; i++ {
offsetFactor := reversedScores[i] / totalScore
if _, ok := initialTierChances[i]; !ok {
if chanceForDefaultTiers > 0 {
shiftedTierChances[i] = chanceForDefaultTiers + averageChance*offsetFactor
}
} else {
if initialTierChances[i] > 0 {
shiftedTierChances[i] = initialTierChances[i] + averageChance*offsetFactor
}
}
}
// normalize the chances
totalChance := 0.0
for _, chance := range shiftedTierChances {
totalChance += chance
}
for i := 0; i < numTiers; i++ {
shiftedTierChances[i] /= totalChance
}
return shiftedTierChances
}

func (st *SelectionTierInst) GetTier(tier int, numTiers int, minimumEntries int) []Entry {
// get the tier of scores for the given tier and number of tiers
entriesLen := len(st.scores)
Expand All @@ -85,7 +141,8 @@ func (st *SelectionTierInst) GetTier(tier int, numTiers int, minimumEntries int)

start, end, fracStart, fracEnd := getPositionsForTier(tier, numTiers, entriesLen)
if end < minimumEntries {
return st.scores[:minimumEntries]
// only allow better tiers if there are not enough entries
return st.scores[:end]
}
ret := st.scores[start:end]
if len(ret) >= minimumEntries {
Expand Down
50 changes: 50 additions & 0 deletions protocol/provideroptimizer/selection_tier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,56 @@ func TestSelectionTierInstGetTierBig(t *testing.T) {
}

Check failure on line 174 in protocol/provideroptimizer/selection_tier_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not `gofumpt`-ed (gofumpt)
}

Check failure on line 175 in protocol/provideroptimizer/selection_tier_test.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary trailing newline (whitespace)

func TestSelectionTierInstShiftTierChanceEquals(t *testing.T) {
st := NewSelectionTier()
numTiers := 4
for i := 0; i < 25; i++ {
st.AddScore("entry"+strconv.Itoa(i), 0.1)
}
for i := 25; i < 50; i++ {
st.AddScore("entry"+strconv.Itoa(i), 0.1)
}
for i := 50; i < 75; i++ {
st.AddScore("entry"+strconv.Itoa(i), 0.1)
}
for i := 75; i < 100; i++ {
st.AddScore("entry"+strconv.Itoa(i), 0.1)
}
selectionTierChances := st.ShiftTierChance(numTiers, nil)
require.Equal(t, numTiers, len(selectionTierChances))
require.Equal(t, selectionTierChances[0], selectionTierChances[1])

selectionTierChances = st.ShiftTierChance(numTiers, map[int]float64{0: 0.5, 1: 0.5})
require.Equal(t, 0.0, selectionTierChances[len(selectionTierChances)-1])
require.Equal(t, 0.5, selectionTierChances[0])

selectionTierChances = st.ShiftTierChance(numTiers, map[int]float64{0: 0.5, len(selectionTierChances) - 1: 0.1})
require.Less(t, 0.5, selectionTierChances[0])
require.Greater(t, 0.25, selectionTierChances[0])
require.Greater(t, selectionTierChances[len(selectionTierChances)-1], 0.1)

st = NewSelectionTier()
for i := 0; i < 25; i++ {
st.AddScore("entry"+strconv.Itoa(i), 0.1)
}
for i := 25; i < 50; i++ {
st.AddScore("entry"+strconv.Itoa(i), 0.2)
}
for i := 50; i < 75; i++ {
st.AddScore("entry"+strconv.Itoa(i), 0.3)
}
for i := 75; i < 100; i++ {
st.AddScore("entry"+strconv.Itoa(i), 0.4)
}
selectionTierChances = st.ShiftTierChance(numTiers, nil)
require.Equal(t, numTiers, len(selectionTierChances))
require.Greater(t, selectionTierChances[0], selectionTierChances[1])
require.Greater(t, selectionTierChances[1]*2, selectionTierChances[0]) // make sure the adjustment is not that strong
require.Greater(t, selectionTierChances[1], selectionTierChances[2])

Check failure on line 223 in protocol/provideroptimizer/selection_tier_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not `gofumpt`-ed (gofumpt)
}

Check failure on line 224 in protocol/provideroptimizer/selection_tier_test.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary trailing newline (whitespace)

func TestSelectionTierInst_SelectTierRandomly(t *testing.T) {
st := NewSelectionTier()
rand.InitRandomSeed()
Expand Down

0 comments on commit 0e059af

Please sign in to comment.