From a5206e752de5fb558943caf12f7840c4e0236e3e Mon Sep 17 00:00:00 2001 From: oren-lava <111131399+oren-lava@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:27:33 +0300 Subject: [PATCH] fix: audit fixes (#1672) * nilaway fixes * added warning comment on detection index * remove redundant conflict data checks * reverted qos req and geo req turning pointers * removed pairing query cache * fix comment --- x/conflict/keeper/conflict.go | 12 +++++++ x/conflict/keeper/msg_server_detection.go | 3 ++ x/conflict/types/conflict.go | 24 ++++++++++++++ x/pairing/keeper/keeper.go | 8 ----- x/pairing/keeper/pairing.go | 30 +++-------------- x/pairing/keeper/pairing_cache.go | 24 -------------- x/pairing/keeper/pairing_cache_test.go | 39 +---------------------- x/pairing/keeper/scores/pairing_slot.go | 2 +- x/pairing/keeper/scores/stake_req.go | 3 ++ x/projects/keeper/creation.go | 5 ++- x/projects/types/project.go | 3 ++ x/spec/keeper/spec.go | 6 ++++ x/spec/types/api_collection.go | 3 ++ x/spec/types/combinable.go | 2 +- x/spec/types/spec.go | 5 ++- x/subscription/keeper/subscription.go | 10 ++++++ 16 files changed, 78 insertions(+), 101 deletions(-) create mode 100644 x/conflict/types/conflict.go diff --git a/x/conflict/keeper/conflict.go b/x/conflict/keeper/conflict.go index a3e6e31a1a..b1f5c19ae9 100644 --- a/x/conflict/keeper/conflict.go +++ b/x/conflict/keeper/conflict.go @@ -20,6 +20,11 @@ func (k Keeper) ValidateFinalizationConflict(ctx sdk.Context, conflictData *type } func (k Keeper) ValidateResponseConflict(ctx sdk.Context, conflictData *types.ResponseConflict, clientAddr sdk.AccAddress) error { + // 0. validate conflictData is not nil + if conflictData.IsDataNil() { + return fmt.Errorf("ValidateResponseConflict: conflict data is nil") + } + // 1. validate mismatching data chainID := conflictData.ConflictRelayData0.Request.RelaySession.SpecId if chainID != conflictData.ConflictRelayData1.Request.RelaySession.SpecId { @@ -279,6 +284,10 @@ func (k Keeper) ValidateSameProviderConflict(ctx sdk.Context, conflictData *type func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization, spec *spectypes.Spec) (finalizedBlocksMarshalled map[int64]string, earliestFinalizedBlock int64, latestFinalizedBlock int64, err error) { EMPTY_MAP := map[int64]string{} + // verify spec is not nil + if spec == nil { + return EMPTY_MAP, 0, 0, fmt.Errorf("validateBlockHeights: spec is nil") + } // Unmarshall finalized blocks finalizedBlocks := map[int64]string{} @@ -312,6 +321,9 @@ func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization, } func (k Keeper) validateFinalizedBlock(relayFinalization *types.RelayFinalization, latestFinalizedBlock int64, spec *spectypes.Spec) error { + if spec == nil { + return fmt.Errorf("validateFinalizedBlock: spec is nil") + } latestBlock := relayFinalization.GetLatestBlock() blockDistanceToFinalization := int64(spec.BlockDistanceForFinalizedData) diff --git a/x/conflict/keeper/msg_server_detection.go b/x/conflict/keeper/msg_server_detection.go index e9d262062d..5a2dbd3727 100644 --- a/x/conflict/keeper/msg_server_detection.go +++ b/x/conflict/keeper/msg_server_detection.go @@ -14,6 +14,9 @@ import ( "golang.org/x/exp/slices" ) +// DetectionIndex creates an index for detection instances. +// WARNING: the detection index should not be used for prefixed iteration since it doesn't contain delimeters +// thus it's not sanitized for such iterations and could cause issues in the future as the codebase evolves. func DetectionIndex(creatorAddr string, conflict *types.ResponseConflict, epochStart uint64) string { return creatorAddr + conflict.ConflictRelayData0.Request.RelaySession.Provider + conflict.ConflictRelayData1.Request.RelaySession.Provider + strconv.FormatUint(epochStart, 10) } diff --git a/x/conflict/types/conflict.go b/x/conflict/types/conflict.go new file mode 100644 index 0000000000..3b9a7c72bb --- /dev/null +++ b/x/conflict/types/conflict.go @@ -0,0 +1,24 @@ +package types + +func (c *ResponseConflict) IsDataNil() bool { + if c == nil { + return true + } + if c.ConflictRelayData0 == nil || c.ConflictRelayData1 == nil { + return true + } + if c.ConflictRelayData0.Request == nil || c.ConflictRelayData1.Request == nil { + return true + } + if c.ConflictRelayData0.Request.RelayData == nil || c.ConflictRelayData1.Request.RelayData == nil { + return true + } + if c.ConflictRelayData0.Request.RelaySession == nil || c.ConflictRelayData1.Request.RelaySession == nil { + return true + } + if c.ConflictRelayData0.Reply == nil || c.ConflictRelayData1.Reply == nil { + return true + } + + return false +} diff --git a/x/pairing/keeper/keeper.go b/x/pairing/keeper/keeper.go index 5451e3da5d..88642b77c0 100644 --- a/x/pairing/keeper/keeper.go +++ b/x/pairing/keeper/keeper.go @@ -4,7 +4,6 @@ import ( "fmt" storetypes "github.com/cosmos/cosmos-sdk/store/types" - epochstoragetypes "github.com/lavanet/lava/v3/x/epochstorage/types" timerstoretypes "github.com/lavanet/lava/v3/x/timerstore/types" "github.com/cometbft/cometbft/libs/log" @@ -35,8 +34,6 @@ type ( downtimeKeeper types.DowntimeKeeper dualstakingKeeper types.DualstakingKeeper stakingKeeper types.StakingKeeper - - pairingQueryCache *map[string][]epochstoragetypes.StakeEntry } ) @@ -74,8 +71,6 @@ func NewKeeper( ps = ps.WithKeyTable(types.ParamKeyTable()) } - emptypairingQueryCache := map[string][]epochstoragetypes.StakeEntry{} - keeper := &Keeper{ cdc: cdc, storeKey: storeKey, @@ -91,7 +86,6 @@ func NewKeeper( downtimeKeeper: downtimeKeeper, dualstakingKeeper: dualstakingKeeper, stakingKeeper: stakingKeeper, - pairingQueryCache: &emptypairingQueryCache, } // note that the timer and badgeUsedCu keys are the same (so we can use only the second arg) @@ -113,8 +107,6 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger { func (k Keeper) BeginBlock(ctx sdk.Context) { if k.epochStorageKeeper.IsEpochStart(ctx) { - // reset pairing query cache every epoch - *k.pairingQueryCache = map[string][]epochstoragetypes.StakeEntry{} // remove old session payments k.RemoveOldEpochPayments(ctx) // unstake/jail unresponsive providers diff --git a/x/pairing/keeper/pairing.go b/x/pairing/keeper/pairing.go index 023eae047d..638c01648a 100644 --- a/x/pairing/keeper/pairing.go +++ b/x/pairing/keeper/pairing.go @@ -80,7 +80,7 @@ func (k Keeper) GetPairingForClient(ctx sdk.Context, chainID string, clientAddre return nil, fmt.Errorf("invalid user for pairing: %s", err.Error()) } - providers, _, _, err = k.getPairingForClient(ctx, chainID, block, strictestPolicy, cluster, project.Index, false, true) + providers, _, _, err = k.getPairingForClient(ctx, chainID, block, strictestPolicy, cluster, project.Index, false) return providers, err } @@ -90,7 +90,7 @@ func (k Keeper) CalculatePairingChance(ctx sdk.Context, provider string, chainID totalScore := cosmosmath.ZeroUint() providerScore := cosmosmath.ZeroUint() - _, _, scores, err := k.getPairingForClient(ctx, chainID, uint64(ctx.BlockHeight()), policy, cluster, "dummy", true, false) + _, _, scores, err := k.getPairingForClient(ctx, chainID, uint64(ctx.BlockHeight()), policy, cluster, "dummy", true) if err != nil { return cosmosmath.LegacyZeroDec(), err } @@ -117,22 +117,12 @@ func (k Keeper) CalculatePairingChance(ctx sdk.Context, provider string, chainID // function used to get a new pairing from provider and client // first argument has all metadata, second argument is only the addresses -// useCache is a boolean argument that is used to determine whether pairing cache should be used -// Note: useCache should only be true for queries! functions that write to the state and use this function should never put useCache=true -func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint64, policy *planstypes.Policy, cluster string, projectIndex string, calcChance bool, useCache bool) (providers []epochstoragetypes.StakeEntry, allowedCU uint64, providerScores []*pairingscores.PairingScore, errorRet error) { +func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint64, policy *planstypes.Policy, cluster string, projectIndex string, calcChance bool) (providers []epochstoragetypes.StakeEntry, allowedCU uint64, providerScores []*pairingscores.PairingScore, errorRet error) { epoch, providersType, err := k.VerifyPairingData(ctx, chainID, block) if err != nil { return nil, 0, nil, fmt.Errorf("invalid pairing data: %s", err) } - // to be used only in queries as this changes gas calculations, and therefore must not be part of consensus - if useCache { - providers, found := k.GetPairingQueryCache(projectIndex, chainID, epoch) - if found { - return providers, policy.EpochCuLimit, nil, nil - } - } - stakeEntries := k.epochStorageKeeper.GetAllStakeEntriesForEpochChainId(ctx, epoch, chainID) if len(stakeEntries) == 0 { return nil, 0, nil, fmt.Errorf("did not find providers for pairing: epoch:%d, chainID: %s", block, chainID) @@ -149,9 +139,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6 stakeEntriesFiltered = append(stakeEntriesFiltered, stakeEntries[i]) } } - if useCache { - k.SetPairingQueryCache(projectIndex, chainID, epoch, stakeEntriesFiltered) - } return stakeEntriesFiltered, policy.EpochCuLimit, nil, nil } @@ -171,9 +158,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6 for _, score := range providerScores { filteredEntries = append(filteredEntries, *score.Provider) } - if useCache { - k.SetPairingQueryCache(projectIndex, chainID, epoch, filteredEntries) - } return filteredEntries, policy.EpochCuLimit, nil, nil } @@ -194,10 +178,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6 prevGroupSlot = group } - if useCache { - k.SetPairingQueryCache(projectIndex, chainID, epoch, providers) - } - return providers, policy.EpochCuLimit, providerScores, nil } @@ -350,7 +330,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, provid return false, allowedCU, []epochstoragetypes.StakeEntry{}, fmt.Errorf("invalid user for pairing: %s", err.Error()) } - validAddresses, allowedCU, _, err = k.getPairingForClient(ctx, chainID, epoch, strictestPolicy, cluster, project.Index, false, false) + validAddresses, allowedCU, _, err = k.getPairingForClient(ctx, chainID, epoch, strictestPolicy, cluster, project.Index, false) if err != nil { return false, allowedCU, []epochstoragetypes.StakeEntry{}, err } @@ -363,7 +343,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, provid utils.LavaFormatPanic("critical: invalid provider address for payment", err, utils.Attribute{Key: "chainID", Value: chainID}, utils.Attribute{Key: "client", Value: project.Subscription}, - utils.Attribute{Key: "provider", Value: providerAccAddr.String()}, + utils.Attribute{Key: "provider", Value: possibleAddr.Address}, utils.Attribute{Key: "epochBlock", Value: strconv.FormatUint(epoch, 10)}, ) } diff --git a/x/pairing/keeper/pairing_cache.go b/x/pairing/keeper/pairing_cache.go index 5cad220cd6..a3652babc8 100644 --- a/x/pairing/keeper/pairing_cache.go +++ b/x/pairing/keeper/pairing_cache.go @@ -38,27 +38,3 @@ func (k Keeper) ResetPairingRelayCache(ctx sdk.Context) { store.Delete(iterator.Key()) } } - -// the cache used for the query, does not write into state -func (k Keeper) SetPairingQueryCache(project string, chainID string, epoch uint64, pairedProviders []epochstoragetypes.StakeEntry) { - if k.pairingQueryCache == nil { - // pairing cache is not initialized, will be in next epoch so simply skip - return - } - key := types.NewPairingCacheKey(project, chainID, epoch) - - (*k.pairingQueryCache)[key] = pairedProviders -} - -func (k Keeper) GetPairingQueryCache(project string, chainID string, epoch uint64) ([]epochstoragetypes.StakeEntry, bool) { - if k.pairingQueryCache == nil { - // pairing cache is not initialized, will be in next epoch so simply skip - return nil, false - } - key := types.NewPairingCacheKey(project, chainID, epoch) - if providers, ok := (*k.pairingQueryCache)[key]; ok { - return providers, true - } - - return nil, false -} diff --git a/x/pairing/keeper/pairing_cache_test.go b/x/pairing/keeper/pairing_cache_test.go index 31b6669192..fb1f0829ef 100644 --- a/x/pairing/keeper/pairing_cache_test.go +++ b/x/pairing/keeper/pairing_cache_test.go @@ -7,44 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -// TestPairingQueryCache tests the following: -// 1. The pairing query cache is reset every epoch -// 2. Getting pairing with a query using an existent cache entry consumes fewer gas than without one -func TestPairingQueryCache(t *testing.T) { - ts := newTester(t) - ts.setupForPayments(1, 1, 0) // 1 provider, 1 client, default providers-to-pair - - _, consumer := ts.GetAccount(common.CONSUMER, 0) - - getPairingGas := func(ts *tester) uint64 { - gm := ts.Ctx.GasMeter() - before := gm.GasConsumed() - _, err := ts.QueryPairingGetPairing(ts.spec.Index, consumer) - require.NoError(t, err) - return gm.GasConsumed() - before - } - - // query for pairing for the first time - empty cache - emptyCacheGas := getPairingGas(ts) - - // query for pairing for the second time - non-empty cache - filledCacheGas := getPairingGas(ts) - - // second time gas should be smaller than first time - require.Less(t, filledCacheGas, emptyCacheGas) - - // advance block to test it stays the same (should still be less than empty cache gas) - ts.AdvanceBlock() - filledAfterBlockCacheGas := getPairingGas(ts) - require.Less(t, filledAfterBlockCacheGas, emptyCacheGas) - - // advance epoch to reset the cache - ts.AdvanceEpoch() - emptyCacheAgainGas := getPairingGas(ts) - require.Equal(t, emptyCacheGas, emptyCacheAgainGas) -} - -// TestPairingQueryCache tests the following: +// TestPairingRelayCache tests the following: // 1. The pairing relay cache is reset every block // 2. Getting pairing in relay payment using an existent cache entry consumes fewer gas than without one func TestPairingRelayCache(t *testing.T) { diff --git a/x/pairing/keeper/scores/pairing_slot.go b/x/pairing/keeper/scores/pairing_slot.go index 5d6935f56e..fe13eaff63 100644 --- a/x/pairing/keeper/scores/pairing_slot.go +++ b/x/pairing/keeper/scores/pairing_slot.go @@ -36,7 +36,7 @@ func (psg PairingSlotGroup) Subtract(other *PairingSlotGroup) *PairingSlot { otherReq, found := other.Reqs[key] if !found { reqsDiff[key] = req - } else if !req.Equal(otherReq) { + } else if req != nil && !req.Equal(otherReq) { reqsDiff[key] = req } } diff --git a/x/pairing/keeper/scores/stake_req.go b/x/pairing/keeper/scores/stake_req.go index 44eb3ff5c0..74af1713fb 100644 --- a/x/pairing/keeper/scores/stake_req.go +++ b/x/pairing/keeper/scores/stake_req.go @@ -17,6 +17,9 @@ func (sr *StakeReq) Init(policy planstypes.Policy) bool { // Score calculates the the provider score as the normalized stake func (sr *StakeReq) Score(score PairingScore) math.Uint { + if sr == nil { + return math.OneUint() + } effectiveStake := score.Provider.TotalStake() if !effectiveStake.IsPositive() { return math.OneUint() diff --git a/x/projects/keeper/creation.go b/x/projects/keeper/creation.go index 423aed7227..a3f1c75741 100644 --- a/x/projects/keeper/creation.go +++ b/x/projects/keeper/creation.go @@ -191,7 +191,7 @@ func (k Keeper) registerKey(ctx sdk.Context, key types.ProjectKey, project *type // check that the developer key is valid, and that it does not already // belong to a different project. - if found && devkeyData.ProjectID != project.Index { + if found && devkeyData.ProjectID != project.GetIndex() { return utils.LavaFormatWarning("failed to register key", fmt.Errorf("key already exists"), utils.Attribute{Key: "key", Value: key.Key}, @@ -254,10 +254,9 @@ func (k Keeper) unregisterKey(ctx sdk.Context, key types.ProjectKey, project *ty // the developer key belongs to a different project if devkeyData.ProjectID != project.GetIndex() { return utils.LavaFormatWarning("failed to unregister key", legacyerrors.ErrNotFound, - utils.Attribute{Key: "projectID", Value: project.Index}, + utils.Attribute{Key: "projectID", Value: project.GetIndex()}, utils.Attribute{Key: "key", Value: key.Key}, utils.Attribute{Key: "keyTypes", Value: key.Kinds}, - utils.Attribute{Key: "projectID", Value: project.GetIndex()}, utils.Attribute{Key: "otherID", Value: devkeyData.ProjectID}, ) } diff --git a/x/projects/types/project.go b/x/projects/types/project.go index d1f7f5e3a2..9289e67257 100644 --- a/x/projects/types/project.go +++ b/x/projects/types/project.go @@ -82,6 +82,9 @@ func (project *Project) GetKey(key string) ProjectKey { } func (project *Project) AppendKey(key ProjectKey) bool { + if project == nil { + return false + } for i, projectKey := range project.ProjectKeys { if projectKey.Key == key.Key { project.ProjectKeys[i].Kinds |= key.Kinds diff --git a/x/spec/keeper/spec.go b/x/spec/keeper/spec.go index 59c0a5e6ce..b52e77fe1a 100644 --- a/x/spec/keeper/spec.go +++ b/x/spec/keeper/spec.go @@ -107,6 +107,9 @@ func (k Keeper) RefreshSpec(ctx sdk.Context, spec types.Spec, ancestors []types. } if details, err := spec.ValidateSpec(k.MaxCU(ctx)); err != nil { + if details != nil { + details = map[string]string{} + } details["invalidates"] = spec.Index attrs := utils.StringMapToAttributes(details) return nil, utils.LavaFormatWarning("spec refresh failed (invalidate)", err, attrs...) @@ -137,6 +140,9 @@ func (k Keeper) doExpandSpec( inherit *map[string]bool, details string, ) (string, error) { + if spec == nil { + return "", fmt.Errorf("doExpandSpec: spec is nil") + } parentsCollections := map[types.CollectionData][]*types.ApiCollection{} if len(spec.Imports) != 0 { diff --git a/x/spec/types/api_collection.go b/x/spec/types/api_collection.go index 06561b16bf..7df10e64ee 100644 --- a/x/spec/types/api_collection.go +++ b/x/spec/types/api_collection.go @@ -65,6 +65,9 @@ func (apic *ApiCollection) InheritAllFields(myCollections map[CollectionData]*Ap // changes in place inside the apic // nil merge maps means not to combine that field func (apic *ApiCollection) CombineWithOthers(others []*ApiCollection, combineWithDisabled, allowOverwrite bool) (err error) { + if apic == nil { + return fmt.Errorf("CombineWithOthers: API collection is nil") + } mergedApis := map[string]interface{}{} mergedHeaders := map[string]interface{}{} mergedParsers := map[string]interface{}{} diff --git a/x/spec/types/combinable.go b/x/spec/types/combinable.go index f0c9b781a5..2ee3c5a7b8 100644 --- a/x/spec/types/combinable.go +++ b/x/spec/types/combinable.go @@ -141,7 +141,7 @@ func CombineUnique[T Combinable](appendFrom, appendTo []T, currentMap map[string } else { // overwriting the inherited field might need Overwrite actions if overwritten, isOverwritten := current.currentCombinable.Overwrite(combinable); isOverwritten { - if appendTo[current.index].Differeniator() != combinable.Differeniator() { + if len(appendTo) <= current.index || appendTo[current.index].Differeniator() != combinable.Differeniator() { return nil, fmt.Errorf("differentiator mismatch in overwrite %s vs %s", combinable.Differeniator(), appendTo[current.index].Differeniator()) } overwrittenT, ok := overwritten.(T) diff --git a/x/spec/types/spec.go b/x/spec/types/spec.go index 237feb9bf4..f41ee8fda0 100644 --- a/x/spec/types/spec.go +++ b/x/spec/types/spec.go @@ -200,6 +200,9 @@ func (spec Spec) ValidateSpec(maxCU uint64) (map[string]string, error) { } func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*ApiCollection) error { + if spec == nil { + return fmt.Errorf("CombineCollections: spec is nil") + } collectionDataList := make([]CollectionData, 0) // Populate the keys slice with the map keys for key := range parentsCollections { @@ -225,7 +228,7 @@ func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*Ap break } } - if !combined.Enabled { + if combined == nil || !combined.Enabled { // no collections enabled to combine, we skip this continue } diff --git a/x/subscription/keeper/subscription.go b/x/subscription/keeper/subscription.go index 7a8a04d3f9..7cde34e7ed 100644 --- a/x/subscription/keeper/subscription.go +++ b/x/subscription/keeper/subscription.go @@ -192,6 +192,9 @@ func (k Keeper) verifySubscriptionBuyInputAndGetPlan(ctx sdk.Context, block uint func (k Keeper) createNewSubscription(ctx sdk.Context, plan *planstypes.Plan, creator, consumer string, block uint64, autoRenewalFlag bool, ) (types.Subscription, error) { + if plan == nil { + return types.Subscription{}, utils.LavaFormatError("plan is nil", fmt.Errorf("createNewSubscription: cannot create new subscription")) + } autoRenewalNextPlan := types.AUTO_RENEWAL_PLAN_NONE if autoRenewalFlag { // On subscription creation, auto renewal is set to the subscription's plan @@ -223,6 +226,13 @@ func (k Keeper) createNewSubscription(ctx sdk.Context, plan *planstypes.Plan, cr } func (k Keeper) upgradeSubscriptionPlan(ctx sdk.Context, sub *types.Subscription, newPlan *planstypes.Plan) error { + if newPlan == nil { + return utils.LavaFormatError("new plan is nil", fmt.Errorf("upgradeSubscriptionPlan: cannot upgrade subscription plan")) + } + if sub == nil { + return utils.LavaFormatError("subscription is nil", fmt.Errorf("upgradeSubscriptionPlan: cannot upgrade subscription plan")) + } + block := uint64(ctx.BlockHeight()) nextEpoch, err := k.epochstorageKeeper.GetNextEpoch(ctx, block)