Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: audit fixes #1672

Merged
merged 10 commits into from
Sep 25, 2024
12 changes: 12 additions & 0 deletions x/conflict/keeper/conflict.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ func (k Keeper) ValidateFinalizationConflict(ctx sdk.Context, conflictData *type
}

func (k Keeper) ValidateResponseConflict(ctx sdk.Context, conflictData *types.ResponseConflict, clientAddr sdk.AccAddress) error {
// 0. validate conflictData is not nil
if conflictData.IsDataNil() {
Yaroms marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("ValidateResponseConflict: conflict data is nil")
}

// 1. validate mismatching data
chainID := conflictData.ConflictRelayData0.Request.RelaySession.SpecId
if chainID != conflictData.ConflictRelayData1.Request.RelaySession.SpecId {
Expand Down Expand Up @@ -279,6 +284,10 @@ func (k Keeper) ValidateSameProviderConflict(ctx sdk.Context, conflictData *type

func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization, spec *spectypes.Spec) (finalizedBlocksMarshalled map[int64]string, earliestFinalizedBlock int64, latestFinalizedBlock int64, err error) {
EMPTY_MAP := map[int64]string{}
// verify spec is not nil
if spec == nil {
return EMPTY_MAP, 0, 0, fmt.Errorf("validateBlockHeights: spec is nil")
}

// Unmarshall finalized blocks
finalizedBlocks := map[int64]string{}
Expand Down Expand Up @@ -312,6 +321,9 @@ func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization,
}

func (k Keeper) validateFinalizedBlock(relayFinalization *types.RelayFinalization, latestFinalizedBlock int64, spec *spectypes.Spec) error {
if spec == nil {
return fmt.Errorf("validateFinalizedBlock: spec is nil")
}
latestBlock := relayFinalization.GetLatestBlock()
blockDistanceToFinalization := int64(spec.BlockDistanceForFinalizedData)

Expand Down
2 changes: 2 additions & 0 deletions x/conflict/keeper/msg_server_detection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"golang.org/x/exp/slices"
)

// DetectionIndex creates an index for detection instances.
// WARNING: the detection index should not be used for prefixed iteration.
omerlavanet marked this conversation as resolved.
Show resolved Hide resolved
func DetectionIndex(creatorAddr string, conflict *types.ResponseConflict, epochStart uint64) string {
return creatorAddr + conflict.ConflictRelayData0.Request.RelaySession.Provider + conflict.ConflictRelayData1.Request.RelaySession.Provider + strconv.FormatUint(epochStart, 10)
}
Expand Down
24 changes: 24 additions & 0 deletions x/conflict/types/conflict.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package types

func (c *ResponseConflict) IsDataNil() bool {
if c == nil {
return true
}
if c.ConflictRelayData0 == nil || c.ConflictRelayData1 == nil {
return true
}
if c.ConflictRelayData0.Request == nil || c.ConflictRelayData1.Request == nil {
return true
}
if c.ConflictRelayData0.Request.RelayData == nil || c.ConflictRelayData1.Request.RelayData == nil {
return true
}
if c.ConflictRelayData0.Request.RelaySession == nil || c.ConflictRelayData1.Request.RelaySession == nil {
return true
}
if c.ConflictRelayData0.Reply == nil || c.ConflictRelayData1.Reply == nil {
return true
}

return false
}
8 changes: 0 additions & 8 deletions x/pairing/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"

storetypes "github.com/cosmos/cosmos-sdk/store/types"
epochstoragetypes "github.com/lavanet/lava/v3/x/epochstorage/types"
timerstoretypes "github.com/lavanet/lava/v3/x/timerstore/types"

"github.com/cometbft/cometbft/libs/log"
Expand Down Expand Up @@ -35,8 +34,6 @@ type (
downtimeKeeper types.DowntimeKeeper
dualstakingKeeper types.DualstakingKeeper
stakingKeeper types.StakingKeeper

pairingQueryCache *map[string][]epochstoragetypes.StakeEntry
}
)

Expand Down Expand Up @@ -74,8 +71,6 @@ func NewKeeper(
ps = ps.WithKeyTable(types.ParamKeyTable())
}

emptypairingQueryCache := map[string][]epochstoragetypes.StakeEntry{}

keeper := &Keeper{
cdc: cdc,
storeKey: storeKey,
Expand All @@ -91,7 +86,6 @@ func NewKeeper(
downtimeKeeper: downtimeKeeper,
dualstakingKeeper: dualstakingKeeper,
stakingKeeper: stakingKeeper,
pairingQueryCache: &emptypairingQueryCache,
}

// note that the timer and badgeUsedCu keys are the same (so we can use only the second arg)
Expand All @@ -113,8 +107,6 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger {

func (k Keeper) BeginBlock(ctx sdk.Context) {
if k.epochStorageKeeper.IsEpochStart(ctx) {
// reset pairing query cache every epoch
*k.pairingQueryCache = map[string][]epochstoragetypes.StakeEntry{}
// remove old session payments
k.RemoveOldEpochPayments(ctx)
// unstake/jail unresponsive providers
Expand Down
30 changes: 5 additions & 25 deletions x/pairing/keeper/pairing.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (k Keeper) GetPairingForClient(ctx sdk.Context, chainID string, clientAddre
return nil, fmt.Errorf("invalid user for pairing: %s", err.Error())
}

providers, _, _, err = k.getPairingForClient(ctx, chainID, block, strictestPolicy, cluster, project.Index, false, true)
providers, _, _, err = k.getPairingForClient(ctx, chainID, block, strictestPolicy, cluster, project.Index, false)

return providers, err
}
Expand All @@ -90,7 +90,7 @@ func (k Keeper) CalculatePairingChance(ctx sdk.Context, provider string, chainID
totalScore := cosmosmath.ZeroUint()
providerScore := cosmosmath.ZeroUint()

_, _, scores, err := k.getPairingForClient(ctx, chainID, uint64(ctx.BlockHeight()), policy, cluster, "dummy", true, false)
_, _, scores, err := k.getPairingForClient(ctx, chainID, uint64(ctx.BlockHeight()), policy, cluster, "dummy", true)
if err != nil {
return cosmosmath.LegacyZeroDec(), err
}
Expand All @@ -117,22 +117,12 @@ func (k Keeper) CalculatePairingChance(ctx sdk.Context, provider string, chainID

// function used to get a new pairing from provider and client
// first argument has all metadata, second argument is only the addresses
// useCache is a boolean argument that is used to determine whether pairing cache should be used
// Note: useCache should only be true for queries! functions that write to the state and use this function should never put useCache=true
func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint64, policy *planstypes.Policy, cluster string, projectIndex string, calcChance bool, useCache bool) (providers []epochstoragetypes.StakeEntry, allowedCU uint64, providerScores []*pairingscores.PairingScore, errorRet error) {
func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint64, policy *planstypes.Policy, cluster string, projectIndex string, calcChance bool) (providers []epochstoragetypes.StakeEntry, allowedCU uint64, providerScores []*pairingscores.PairingScore, errorRet error) {
epoch, providersType, err := k.VerifyPairingData(ctx, chainID, block)
if err != nil {
return nil, 0, nil, fmt.Errorf("invalid pairing data: %s", err)
}

// to be used only in queries as this changes gas calculations, and therefore must not be part of consensus
if useCache {
providers, found := k.GetPairingQueryCache(projectIndex, chainID, epoch)
if found {
return providers, policy.EpochCuLimit, nil, nil
}
}

stakeEntries := k.epochStorageKeeper.GetAllStakeEntriesForEpochChainId(ctx, epoch, chainID)
if len(stakeEntries) == 0 {
return nil, 0, nil, fmt.Errorf("did not find providers for pairing: epoch:%d, chainID: %s", block, chainID)
Expand All @@ -149,9 +139,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6
stakeEntriesFiltered = append(stakeEntriesFiltered, stakeEntries[i])
}
}
if useCache {
k.SetPairingQueryCache(projectIndex, chainID, epoch, stakeEntriesFiltered)
}
return stakeEntriesFiltered, policy.EpochCuLimit, nil, nil
}

Expand All @@ -171,9 +158,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6
for _, score := range providerScores {
filteredEntries = append(filteredEntries, *score.Provider)
}
if useCache {
k.SetPairingQueryCache(projectIndex, chainID, epoch, filteredEntries)
}
return filteredEntries, policy.EpochCuLimit, nil, nil
}

Expand All @@ -194,10 +178,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6
prevGroupSlot = group
}

if useCache {
k.SetPairingQueryCache(projectIndex, chainID, epoch, providers)
}

return providers, policy.EpochCuLimit, providerScores, nil
}

Expand Down Expand Up @@ -350,7 +330,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, provid
return false, allowedCU, []epochstoragetypes.StakeEntry{}, fmt.Errorf("invalid user for pairing: %s", err.Error())
}

validAddresses, allowedCU, _, err = k.getPairingForClient(ctx, chainID, epoch, strictestPolicy, cluster, project.Index, false, false)
validAddresses, allowedCU, _, err = k.getPairingForClient(ctx, chainID, epoch, strictestPolicy, cluster, project.Index, false)
if err != nil {
return false, allowedCU, []epochstoragetypes.StakeEntry{}, err
}
Expand All @@ -363,7 +343,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, provid
utils.LavaFormatPanic("critical: invalid provider address for payment", err,
utils.Attribute{Key: "chainID", Value: chainID},
utils.Attribute{Key: "client", Value: project.Subscription},
utils.Attribute{Key: "provider", Value: providerAccAddr.String()},
utils.Attribute{Key: "provider", Value: possibleAddr.Address},
utils.Attribute{Key: "epochBlock", Value: strconv.FormatUint(epoch, 10)},
)
}
Expand Down
24 changes: 0 additions & 24 deletions x/pairing/keeper/pairing_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,3 @@ func (k Keeper) ResetPairingRelayCache(ctx sdk.Context) {
store.Delete(iterator.Key())
}
}

// the cache used for the query, does not write into state
func (k Keeper) SetPairingQueryCache(project string, chainID string, epoch uint64, pairedProviders []epochstoragetypes.StakeEntry) {
if k.pairingQueryCache == nil {
// pairing cache is not initialized, will be in next epoch so simply skip
return
}
key := types.NewPairingCacheKey(project, chainID, epoch)

(*k.pairingQueryCache)[key] = pairedProviders
}

func (k Keeper) GetPairingQueryCache(project string, chainID string, epoch uint64) ([]epochstoragetypes.StakeEntry, bool) {
if k.pairingQueryCache == nil {
// pairing cache is not initialized, will be in next epoch so simply skip
return nil, false
}
key := types.NewPairingCacheKey(project, chainID, epoch)
if providers, ok := (*k.pairingQueryCache)[key]; ok {
return providers, true
}

return nil, false
}
39 changes: 1 addition & 38 deletions x/pairing/keeper/pairing_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,7 @@ import (
"github.com/stretchr/testify/require"
)

// TestPairingQueryCache tests the following:
// 1. The pairing query cache is reset every epoch
// 2. Getting pairing with a query using an existent cache entry consumes fewer gas than without one
func TestPairingQueryCache(t *testing.T) {
ts := newTester(t)
ts.setupForPayments(1, 1, 0) // 1 provider, 1 client, default providers-to-pair

_, consumer := ts.GetAccount(common.CONSUMER, 0)

getPairingGas := func(ts *tester) uint64 {
gm := ts.Ctx.GasMeter()
before := gm.GasConsumed()
_, err := ts.QueryPairingGetPairing(ts.spec.Index, consumer)
require.NoError(t, err)
return gm.GasConsumed() - before
}

// query for pairing for the first time - empty cache
emptyCacheGas := getPairingGas(ts)

// query for pairing for the second time - non-empty cache
filledCacheGas := getPairingGas(ts)

// second time gas should be smaller than first time
require.Less(t, filledCacheGas, emptyCacheGas)

// advance block to test it stays the same (should still be less than empty cache gas)
ts.AdvanceBlock()
filledAfterBlockCacheGas := getPairingGas(ts)
require.Less(t, filledAfterBlockCacheGas, emptyCacheGas)

// advance epoch to reset the cache
ts.AdvanceEpoch()
emptyCacheAgainGas := getPairingGas(ts)
require.Equal(t, emptyCacheGas, emptyCacheAgainGas)
}

// TestPairingQueryCache tests the following:
// TestPairingRelayCache tests the following:
// 1. The pairing relay cache is reset every block
// 2. Getting pairing in relay payment using an existent cache entry consumes fewer gas than without one
func TestPairingRelayCache(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion x/pairing/keeper/scores/pairing_slot.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (psg PairingSlotGroup) Subtract(other *PairingSlotGroup) *PairingSlot {
otherReq, found := other.Reqs[key]
if !found {
reqsDiff[key] = req
} else if !req.Equal(otherReq) {
} else if req != nil && !req.Equal(otherReq) {
reqsDiff[key] = req
}
}
Expand Down
3 changes: 3 additions & 0 deletions x/pairing/keeper/scores/stake_req.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ func (sr *StakeReq) Init(policy planstypes.Policy) bool {

// Score calculates the the provider score as the normalized stake
func (sr *StakeReq) Score(score PairingScore) math.Uint {
if sr == nil {
return math.OneUint()
}
effectiveStake := score.Provider.TotalStake()
if !effectiveStake.IsPositive() {
return math.OneUint()
Expand Down
5 changes: 2 additions & 3 deletions x/projects/keeper/creation.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (k Keeper) registerKey(ctx sdk.Context, key types.ProjectKey, project *type

// check that the developer key is valid, and that it does not already
// belong to a different project.
if found && devkeyData.ProjectID != project.Index {
if found && devkeyData.ProjectID != project.GetIndex() {
return utils.LavaFormatWarning("failed to register key",
fmt.Errorf("key already exists"),
utils.Attribute{Key: "key", Value: key.Key},
Expand Down Expand Up @@ -254,10 +254,9 @@ func (k Keeper) unregisterKey(ctx sdk.Context, key types.ProjectKey, project *ty
// the developer key belongs to a different project
if devkeyData.ProjectID != project.GetIndex() {
return utils.LavaFormatWarning("failed to unregister key", legacyerrors.ErrNotFound,
utils.Attribute{Key: "projectID", Value: project.Index},
utils.Attribute{Key: "projectID", Value: project.GetIndex()},
utils.Attribute{Key: "key", Value: key.Key},
utils.Attribute{Key: "keyTypes", Value: key.Kinds},
utils.Attribute{Key: "projectID", Value: project.GetIndex()},
utils.Attribute{Key: "otherID", Value: devkeyData.ProjectID},
)
}
Expand Down
3 changes: 3 additions & 0 deletions x/projects/types/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ func (project *Project) GetKey(key string) ProjectKey {
}

func (project *Project) AppendKey(key ProjectKey) bool {
if project == nil {
oren-lava marked this conversation as resolved.
Show resolved Hide resolved
return false
}
for i, projectKey := range project.ProjectKeys {
if projectKey.Key == key.Key {
project.ProjectKeys[i].Kinds |= key.Kinds
Expand Down
6 changes: 6 additions & 0 deletions x/spec/keeper/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ func (k Keeper) RefreshSpec(ctx sdk.Context, spec types.Spec, ancestors []types.
}

if details, err := spec.ValidateSpec(k.MaxCU(ctx)); err != nil {
if details != nil {
details = map[string]string{}
}
details["invalidates"] = spec.Index
attrs := utils.StringMapToAttributes(details)
return nil, utils.LavaFormatWarning("spec refresh failed (invalidate)", err, attrs...)
Expand Down Expand Up @@ -137,6 +140,9 @@ func (k Keeper) doExpandSpec(
inherit *map[string]bool,
details string,
) (string, error) {
if spec == nil {
return "", fmt.Errorf("doExpandSpec: spec is nil")
}
parentsCollections := map[types.CollectionData][]*types.ApiCollection{}

if len(spec.Imports) != 0 {
Expand Down
3 changes: 3 additions & 0 deletions x/spec/types/api_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ func (apic *ApiCollection) InheritAllFields(myCollections map[CollectionData]*Ap
// changes in place inside the apic
// nil merge maps means not to combine that field
func (apic *ApiCollection) CombineWithOthers(others []*ApiCollection, combineWithDisabled, allowOverwrite bool) (err error) {
if apic == nil {
return fmt.Errorf("CombineWithOthers: API collection is nil")
}
mergedApis := map[string]interface{}{}
mergedHeaders := map[string]interface{}{}
mergedParsers := map[string]interface{}{}
Expand Down
2 changes: 1 addition & 1 deletion x/spec/types/combinable.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func CombineUnique[T Combinable](appendFrom, appendTo []T, currentMap map[string
} else {
// overwriting the inherited field might need Overwrite actions
if overwritten, isOverwritten := current.currentCombinable.Overwrite(combinable); isOverwritten {
if appendTo[current.index].Differeniator() != combinable.Differeniator() {
if len(appendTo) <= current.index || appendTo[current.index].Differeniator() != combinable.Differeniator() {
return nil, fmt.Errorf("differentiator mismatch in overwrite %s vs %s", combinable.Differeniator(), appendTo[current.index].Differeniator())
}
overwrittenT, ok := overwritten.(T)
Expand Down
5 changes: 4 additions & 1 deletion x/spec/types/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ func (spec Spec) ValidateSpec(maxCU uint64) (map[string]string, error) {
}

func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*ApiCollection) error {
if spec == nil {
return fmt.Errorf("CombineCollections: spec is nil")
}
collectionDataList := make([]CollectionData, 0)
// Populate the keys slice with the map keys
for key := range parentsCollections {
Expand All @@ -225,7 +228,7 @@ func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*Ap
break
}
}
if !combined.Enabled {
if combined == nil || !combined.Enabled {
// no collections enabled to combine, we skip this
continue
}
Expand Down
Loading
Loading