Skip to content

Commit

Permalink
fix: audit fixes (#1672)
Browse files Browse the repository at this point in the history
* nilaway fixes

* added warning comment on detection index

* remove redundant conflict data checks

* reverted qos req and geo req turning pointers

* removed pairing query cache

* fix comment
  • Loading branch information
oren-lava authored Sep 25, 2024
1 parent ab51b73 commit a5206e7
Show file tree
Hide file tree
Showing 16 changed files with 78 additions and 101 deletions.
12 changes: 12 additions & 0 deletions x/conflict/keeper/conflict.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ func (k Keeper) ValidateFinalizationConflict(ctx sdk.Context, conflictData *type
}

func (k Keeper) ValidateResponseConflict(ctx sdk.Context, conflictData *types.ResponseConflict, clientAddr sdk.AccAddress) error {
// 0. validate conflictData is not nil
if conflictData.IsDataNil() {
return fmt.Errorf("ValidateResponseConflict: conflict data is nil")
}

// 1. validate mismatching data
chainID := conflictData.ConflictRelayData0.Request.RelaySession.SpecId
if chainID != conflictData.ConflictRelayData1.Request.RelaySession.SpecId {
Expand Down Expand Up @@ -279,6 +284,10 @@ func (k Keeper) ValidateSameProviderConflict(ctx sdk.Context, conflictData *type

func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization, spec *spectypes.Spec) (finalizedBlocksMarshalled map[int64]string, earliestFinalizedBlock int64, latestFinalizedBlock int64, err error) {
EMPTY_MAP := map[int64]string{}
// verify spec is not nil
if spec == nil {
return EMPTY_MAP, 0, 0, fmt.Errorf("validateBlockHeights: spec is nil")
}

// Unmarshall finalized blocks
finalizedBlocks := map[int64]string{}
Expand Down Expand Up @@ -312,6 +321,9 @@ func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization,
}

func (k Keeper) validateFinalizedBlock(relayFinalization *types.RelayFinalization, latestFinalizedBlock int64, spec *spectypes.Spec) error {
if spec == nil {
return fmt.Errorf("validateFinalizedBlock: spec is nil")
}
latestBlock := relayFinalization.GetLatestBlock()
blockDistanceToFinalization := int64(spec.BlockDistanceForFinalizedData)

Expand Down
3 changes: 3 additions & 0 deletions x/conflict/keeper/msg_server_detection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"golang.org/x/exp/slices"
)

// DetectionIndex creates an index for detection instances.
// WARNING: the detection index should not be used for prefixed iteration since it doesn't contain delimeters
// thus it's not sanitized for such iterations and could cause issues in the future as the codebase evolves.
func DetectionIndex(creatorAddr string, conflict *types.ResponseConflict, epochStart uint64) string {
return creatorAddr + conflict.ConflictRelayData0.Request.RelaySession.Provider + conflict.ConflictRelayData1.Request.RelaySession.Provider + strconv.FormatUint(epochStart, 10)
}
Expand Down
24 changes: 24 additions & 0 deletions x/conflict/types/conflict.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package types

func (c *ResponseConflict) IsDataNil() bool {
if c == nil {
return true
}
if c.ConflictRelayData0 == nil || c.ConflictRelayData1 == nil {
return true
}
if c.ConflictRelayData0.Request == nil || c.ConflictRelayData1.Request == nil {
return true
}
if c.ConflictRelayData0.Request.RelayData == nil || c.ConflictRelayData1.Request.RelayData == nil {
return true
}
if c.ConflictRelayData0.Request.RelaySession == nil || c.ConflictRelayData1.Request.RelaySession == nil {
return true
}
if c.ConflictRelayData0.Reply == nil || c.ConflictRelayData1.Reply == nil {
return true
}

return false
}
8 changes: 0 additions & 8 deletions x/pairing/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"

storetypes "github.com/cosmos/cosmos-sdk/store/types"
epochstoragetypes "github.com/lavanet/lava/v3/x/epochstorage/types"
timerstoretypes "github.com/lavanet/lava/v3/x/timerstore/types"

"github.com/cometbft/cometbft/libs/log"
Expand Down Expand Up @@ -35,8 +34,6 @@ type (
downtimeKeeper types.DowntimeKeeper
dualstakingKeeper types.DualstakingKeeper
stakingKeeper types.StakingKeeper

pairingQueryCache *map[string][]epochstoragetypes.StakeEntry
}
)

Expand Down Expand Up @@ -74,8 +71,6 @@ func NewKeeper(
ps = ps.WithKeyTable(types.ParamKeyTable())
}

emptypairingQueryCache := map[string][]epochstoragetypes.StakeEntry{}

keeper := &Keeper{
cdc: cdc,
storeKey: storeKey,
Expand All @@ -91,7 +86,6 @@ func NewKeeper(
downtimeKeeper: downtimeKeeper,
dualstakingKeeper: dualstakingKeeper,
stakingKeeper: stakingKeeper,
pairingQueryCache: &emptypairingQueryCache,
}

// note that the timer and badgeUsedCu keys are the same (so we can use only the second arg)
Expand All @@ -113,8 +107,6 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger {

func (k Keeper) BeginBlock(ctx sdk.Context) {
if k.epochStorageKeeper.IsEpochStart(ctx) {
// reset pairing query cache every epoch
*k.pairingQueryCache = map[string][]epochstoragetypes.StakeEntry{}
// remove old session payments
k.RemoveOldEpochPayments(ctx)
// unstake/jail unresponsive providers
Expand Down
30 changes: 5 additions & 25 deletions x/pairing/keeper/pairing.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (k Keeper) GetPairingForClient(ctx sdk.Context, chainID string, clientAddre
return nil, fmt.Errorf("invalid user for pairing: %s", err.Error())
}

providers, _, _, err = k.getPairingForClient(ctx, chainID, block, strictestPolicy, cluster, project.Index, false, true)
providers, _, _, err = k.getPairingForClient(ctx, chainID, block, strictestPolicy, cluster, project.Index, false)

return providers, err
}
Expand All @@ -90,7 +90,7 @@ func (k Keeper) CalculatePairingChance(ctx sdk.Context, provider string, chainID
totalScore := cosmosmath.ZeroUint()
providerScore := cosmosmath.ZeroUint()

_, _, scores, err := k.getPairingForClient(ctx, chainID, uint64(ctx.BlockHeight()), policy, cluster, "dummy", true, false)
_, _, scores, err := k.getPairingForClient(ctx, chainID, uint64(ctx.BlockHeight()), policy, cluster, "dummy", true)
if err != nil {
return cosmosmath.LegacyZeroDec(), err
}
Expand All @@ -117,22 +117,12 @@ func (k Keeper) CalculatePairingChance(ctx sdk.Context, provider string, chainID

// function used to get a new pairing from provider and client
// first argument has all metadata, second argument is only the addresses
// useCache is a boolean argument that is used to determine whether pairing cache should be used
// Note: useCache should only be true for queries! functions that write to the state and use this function should never put useCache=true
func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint64, policy *planstypes.Policy, cluster string, projectIndex string, calcChance bool, useCache bool) (providers []epochstoragetypes.StakeEntry, allowedCU uint64, providerScores []*pairingscores.PairingScore, errorRet error) {
func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint64, policy *planstypes.Policy, cluster string, projectIndex string, calcChance bool) (providers []epochstoragetypes.StakeEntry, allowedCU uint64, providerScores []*pairingscores.PairingScore, errorRet error) {
epoch, providersType, err := k.VerifyPairingData(ctx, chainID, block)
if err != nil {
return nil, 0, nil, fmt.Errorf("invalid pairing data: %s", err)
}

// to be used only in queries as this changes gas calculations, and therefore must not be part of consensus
if useCache {
providers, found := k.GetPairingQueryCache(projectIndex, chainID, epoch)
if found {
return providers, policy.EpochCuLimit, nil, nil
}
}

stakeEntries := k.epochStorageKeeper.GetAllStakeEntriesForEpochChainId(ctx, epoch, chainID)
if len(stakeEntries) == 0 {
return nil, 0, nil, fmt.Errorf("did not find providers for pairing: epoch:%d, chainID: %s", block, chainID)
Expand All @@ -149,9 +139,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6
stakeEntriesFiltered = append(stakeEntriesFiltered, stakeEntries[i])
}
}
if useCache {
k.SetPairingQueryCache(projectIndex, chainID, epoch, stakeEntriesFiltered)
}
return stakeEntriesFiltered, policy.EpochCuLimit, nil, nil
}

Expand All @@ -171,9 +158,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6
for _, score := range providerScores {
filteredEntries = append(filteredEntries, *score.Provider)
}
if useCache {
k.SetPairingQueryCache(projectIndex, chainID, epoch, filteredEntries)
}
return filteredEntries, policy.EpochCuLimit, nil, nil
}

Expand All @@ -194,10 +178,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6
prevGroupSlot = group
}

if useCache {
k.SetPairingQueryCache(projectIndex, chainID, epoch, providers)
}

return providers, policy.EpochCuLimit, providerScores, nil
}

Expand Down Expand Up @@ -350,7 +330,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, provid
return false, allowedCU, []epochstoragetypes.StakeEntry{}, fmt.Errorf("invalid user for pairing: %s", err.Error())
}

validAddresses, allowedCU, _, err = k.getPairingForClient(ctx, chainID, epoch, strictestPolicy, cluster, project.Index, false, false)
validAddresses, allowedCU, _, err = k.getPairingForClient(ctx, chainID, epoch, strictestPolicy, cluster, project.Index, false)
if err != nil {
return false, allowedCU, []epochstoragetypes.StakeEntry{}, err
}
Expand All @@ -363,7 +343,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, provid
utils.LavaFormatPanic("critical: invalid provider address for payment", err,
utils.Attribute{Key: "chainID", Value: chainID},
utils.Attribute{Key: "client", Value: project.Subscription},
utils.Attribute{Key: "provider", Value: providerAccAddr.String()},
utils.Attribute{Key: "provider", Value: possibleAddr.Address},
utils.Attribute{Key: "epochBlock", Value: strconv.FormatUint(epoch, 10)},
)
}
Expand Down
24 changes: 0 additions & 24 deletions x/pairing/keeper/pairing_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,3 @@ func (k Keeper) ResetPairingRelayCache(ctx sdk.Context) {
store.Delete(iterator.Key())
}
}

// the cache used for the query, does not write into state
func (k Keeper) SetPairingQueryCache(project string, chainID string, epoch uint64, pairedProviders []epochstoragetypes.StakeEntry) {
if k.pairingQueryCache == nil {
// pairing cache is not initialized, will be in next epoch so simply skip
return
}
key := types.NewPairingCacheKey(project, chainID, epoch)

(*k.pairingQueryCache)[key] = pairedProviders
}

func (k Keeper) GetPairingQueryCache(project string, chainID string, epoch uint64) ([]epochstoragetypes.StakeEntry, bool) {
if k.pairingQueryCache == nil {
// pairing cache is not initialized, will be in next epoch so simply skip
return nil, false
}
key := types.NewPairingCacheKey(project, chainID, epoch)
if providers, ok := (*k.pairingQueryCache)[key]; ok {
return providers, true
}

return nil, false
}
39 changes: 1 addition & 38 deletions x/pairing/keeper/pairing_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,7 @@ import (
"github.com/stretchr/testify/require"
)

// TestPairingQueryCache tests the following:
// 1. The pairing query cache is reset every epoch
// 2. Getting pairing with a query using an existent cache entry consumes fewer gas than without one
func TestPairingQueryCache(t *testing.T) {
ts := newTester(t)
ts.setupForPayments(1, 1, 0) // 1 provider, 1 client, default providers-to-pair

_, consumer := ts.GetAccount(common.CONSUMER, 0)

getPairingGas := func(ts *tester) uint64 {
gm := ts.Ctx.GasMeter()
before := gm.GasConsumed()
_, err := ts.QueryPairingGetPairing(ts.spec.Index, consumer)
require.NoError(t, err)
return gm.GasConsumed() - before
}

// query for pairing for the first time - empty cache
emptyCacheGas := getPairingGas(ts)

// query for pairing for the second time - non-empty cache
filledCacheGas := getPairingGas(ts)

// second time gas should be smaller than first time
require.Less(t, filledCacheGas, emptyCacheGas)

// advance block to test it stays the same (should still be less than empty cache gas)
ts.AdvanceBlock()
filledAfterBlockCacheGas := getPairingGas(ts)
require.Less(t, filledAfterBlockCacheGas, emptyCacheGas)

// advance epoch to reset the cache
ts.AdvanceEpoch()
emptyCacheAgainGas := getPairingGas(ts)
require.Equal(t, emptyCacheGas, emptyCacheAgainGas)
}

// TestPairingQueryCache tests the following:
// TestPairingRelayCache tests the following:
// 1. The pairing relay cache is reset every block
// 2. Getting pairing in relay payment using an existent cache entry consumes fewer gas than without one
func TestPairingRelayCache(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion x/pairing/keeper/scores/pairing_slot.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (psg PairingSlotGroup) Subtract(other *PairingSlotGroup) *PairingSlot {
otherReq, found := other.Reqs[key]
if !found {
reqsDiff[key] = req
} else if !req.Equal(otherReq) {
} else if req != nil && !req.Equal(otherReq) {
reqsDiff[key] = req
}
}
Expand Down
3 changes: 3 additions & 0 deletions x/pairing/keeper/scores/stake_req.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ func (sr *StakeReq) Init(policy planstypes.Policy) bool {

// Score calculates the the provider score as the normalized stake
func (sr *StakeReq) Score(score PairingScore) math.Uint {
if sr == nil {
return math.OneUint()
}
effectiveStake := score.Provider.TotalStake()
if !effectiveStake.IsPositive() {
return math.OneUint()
Expand Down
5 changes: 2 additions & 3 deletions x/projects/keeper/creation.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (k Keeper) registerKey(ctx sdk.Context, key types.ProjectKey, project *type

// check that the developer key is valid, and that it does not already
// belong to a different project.
if found && devkeyData.ProjectID != project.Index {
if found && devkeyData.ProjectID != project.GetIndex() {
return utils.LavaFormatWarning("failed to register key",
fmt.Errorf("key already exists"),
utils.Attribute{Key: "key", Value: key.Key},
Expand Down Expand Up @@ -254,10 +254,9 @@ func (k Keeper) unregisterKey(ctx sdk.Context, key types.ProjectKey, project *ty
// the developer key belongs to a different project
if devkeyData.ProjectID != project.GetIndex() {
return utils.LavaFormatWarning("failed to unregister key", legacyerrors.ErrNotFound,
utils.Attribute{Key: "projectID", Value: project.Index},
utils.Attribute{Key: "projectID", Value: project.GetIndex()},
utils.Attribute{Key: "key", Value: key.Key},
utils.Attribute{Key: "keyTypes", Value: key.Kinds},
utils.Attribute{Key: "projectID", Value: project.GetIndex()},
utils.Attribute{Key: "otherID", Value: devkeyData.ProjectID},
)
}
Expand Down
3 changes: 3 additions & 0 deletions x/projects/types/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ func (project *Project) GetKey(key string) ProjectKey {
}

func (project *Project) AppendKey(key ProjectKey) bool {
if project == nil {
return false
}
for i, projectKey := range project.ProjectKeys {
if projectKey.Key == key.Key {
project.ProjectKeys[i].Kinds |= key.Kinds
Expand Down
6 changes: 6 additions & 0 deletions x/spec/keeper/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ func (k Keeper) RefreshSpec(ctx sdk.Context, spec types.Spec, ancestors []types.
}

if details, err := spec.ValidateSpec(k.MaxCU(ctx)); err != nil {
if details != nil {
details = map[string]string{}
}
details["invalidates"] = spec.Index
attrs := utils.StringMapToAttributes(details)
return nil, utils.LavaFormatWarning("spec refresh failed (invalidate)", err, attrs...)
Expand Down Expand Up @@ -137,6 +140,9 @@ func (k Keeper) doExpandSpec(
inherit *map[string]bool,
details string,
) (string, error) {
if spec == nil {
return "", fmt.Errorf("doExpandSpec: spec is nil")
}
parentsCollections := map[types.CollectionData][]*types.ApiCollection{}

if len(spec.Imports) != 0 {
Expand Down
3 changes: 3 additions & 0 deletions x/spec/types/api_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ func (apic *ApiCollection) InheritAllFields(myCollections map[CollectionData]*Ap
// changes in place inside the apic
// nil merge maps means not to combine that field
func (apic *ApiCollection) CombineWithOthers(others []*ApiCollection, combineWithDisabled, allowOverwrite bool) (err error) {
if apic == nil {
return fmt.Errorf("CombineWithOthers: API collection is nil")
}
mergedApis := map[string]interface{}{}
mergedHeaders := map[string]interface{}{}
mergedParsers := map[string]interface{}{}
Expand Down
2 changes: 1 addition & 1 deletion x/spec/types/combinable.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func CombineUnique[T Combinable](appendFrom, appendTo []T, currentMap map[string
} else {
// overwriting the inherited field might need Overwrite actions
if overwritten, isOverwritten := current.currentCombinable.Overwrite(combinable); isOverwritten {
if appendTo[current.index].Differeniator() != combinable.Differeniator() {
if len(appendTo) <= current.index || appendTo[current.index].Differeniator() != combinable.Differeniator() {
return nil, fmt.Errorf("differentiator mismatch in overwrite %s vs %s", combinable.Differeniator(), appendTo[current.index].Differeniator())
}
overwrittenT, ok := overwritten.(T)
Expand Down
5 changes: 4 additions & 1 deletion x/spec/types/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ func (spec Spec) ValidateSpec(maxCU uint64) (map[string]string, error) {
}

func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*ApiCollection) error {
if spec == nil {
return fmt.Errorf("CombineCollections: spec is nil")
}
collectionDataList := make([]CollectionData, 0)
// Populate the keys slice with the map keys
for key := range parentsCollections {
Expand All @@ -225,7 +228,7 @@ func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*Ap
break
}
}
if !combined.Enabled {
if combined == nil || !combined.Enabled {
// no collections enabled to combine, we skip this
continue
}
Expand Down
Loading

0 comments on commit a5206e7

Please sign in to comment.