Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: audit fixes #1672

Merged
merged 10 commits into from
Sep 25, 2024
12 changes: 12 additions & 0 deletions x/conflict/keeper/conflict.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ func (k Keeper) ValidateFinalizationConflict(ctx sdk.Context, conflictData *type
}

func (k Keeper) ValidateResponseConflict(ctx sdk.Context, conflictData *types.ResponseConflict, clientAddr sdk.AccAddress) error {
// 0. validate conflictData is not nil
if conflictData.IsDataNil() {
Yaroms marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("ValidateResponseConflict: conflict data is nil")
}

// 1. validate mismatching data
chainID := conflictData.ConflictRelayData0.Request.RelaySession.SpecId
if chainID != conflictData.ConflictRelayData1.Request.RelaySession.SpecId {
Expand Down Expand Up @@ -279,6 +284,10 @@ func (k Keeper) ValidateSameProviderConflict(ctx sdk.Context, conflictData *type

func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization, spec *spectypes.Spec) (finalizedBlocksMarshalled map[int64]string, earliestFinalizedBlock int64, latestFinalizedBlock int64, err error) {
EMPTY_MAP := map[int64]string{}
// verify spec is not nil
if spec == nil {
return EMPTY_MAP, 0, 0, fmt.Errorf("validateBlockHeights: spec is nil")
}

// Unmarshall finalized blocks
finalizedBlocks := map[int64]string{}
Expand Down Expand Up @@ -312,6 +321,9 @@ func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization,
}

func (k Keeper) validateFinalizedBlock(relayFinalization *types.RelayFinalization, latestFinalizedBlock int64, spec *spectypes.Spec) error {
if spec == nil {
return fmt.Errorf("validateFinalizedBlock: spec is nil")
}
latestBlock := relayFinalization.GetLatestBlock()
blockDistanceToFinalization := int64(spec.BlockDistanceForFinalizedData)

Expand Down
10 changes: 10 additions & 0 deletions x/conflict/keeper/msg_server_detection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ import (
"golang.org/x/exp/slices"
)

// DetectionIndex creates an index for detection instances.
// WARNING: the detection index should not be used for prefixed iteration.
omerlavanet marked this conversation as resolved.
Show resolved Hide resolved
func DetectionIndex(creatorAddr string, conflict *types.ResponseConflict, epochStart uint64) string {
if conflict.IsDataNil() {
Yaroms marked this conversation as resolved.
Show resolved Hide resolved
return ""
}
return creatorAddr + conflict.ConflictRelayData0.Request.RelaySession.Provider + conflict.ConflictRelayData1.Request.RelaySession.Provider + strconv.FormatUint(epochStart, 10)
}

Expand Down Expand Up @@ -125,6 +130,11 @@ func (k msgServer) handleSameProviderFinalizationConflict(ctx sdk.Context, confl
}

func (k msgServer) handleResponseConflict(ctx sdk.Context, goCtx context.Context, conflict *types.ResponseConflict, clientAddr sdk.AccAddress) (eventData map[string]string, err error) {
if conflict.IsDataNil() {
Yaroms marked this conversation as resolved.
Show resolved Hide resolved
return nil, utils.LavaFormatWarning("conflict data is nil", fmt.Errorf("handleResponseConflict: cannot handle response conflict"),
utils.LogAttr("client", clientAddr.String()),
)
}
err = k.Keeper.ValidateResponseConflict(ctx, conflict, clientAddr)
if err != nil {
return nil, utils.LavaFormatWarning("Simulation: invalid response conflict detection", err,
Expand Down
24 changes: 24 additions & 0 deletions x/conflict/types/conflict.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package types

func (c *ResponseConflict) IsDataNil() bool {
if c == nil {
return true
}
if c.ConflictRelayData0 == nil || c.ConflictRelayData1 == nil {
return true
}
if c.ConflictRelayData0.Request == nil || c.ConflictRelayData1.Request == nil {
return true
}
if c.ConflictRelayData0.Request.RelayData == nil || c.ConflictRelayData1.Request.RelayData == nil {
return true
}
if c.ConflictRelayData0.Request.RelaySession == nil || c.ConflictRelayData1.Request.RelaySession == nil {
return true
}
if c.ConflictRelayData0.Reply == nil || c.ConflictRelayData1.Reply == nil {
return true
}

return false
}
2 changes: 1 addition & 1 deletion x/pairing/keeper/pairing.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, provid
utils.LavaFormatPanic("critical: invalid provider address for payment", err,
utils.Attribute{Key: "chainID", Value: chainID},
utils.Attribute{Key: "client", Value: project.Subscription},
utils.Attribute{Key: "provider", Value: providerAccAddr.String()},
utils.Attribute{Key: "provider", Value: possibleAddr.Address},
utils.Attribute{Key: "epochBlock", Value: strconv.FormatUint(epoch, 10)},
)
}
Expand Down
9 changes: 5 additions & 4 deletions x/pairing/keeper/pairing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1203,7 +1203,7 @@ func verifyGeoScoreForTesting(providerScores []*pairingscores.PairingScore, slot
})

geoReqObject := pairingscores.GeoReq{}
geoReq, ok := slot.Reqs[geoReqObject.GetName()].(pairingscores.GeoReq)
geoReq, ok := slot.Reqs[geoReqObject.GetName()].(*pairingscores.GeoReq)
if !ok {
return false
}
Expand Down Expand Up @@ -1331,7 +1331,8 @@ func TestNoRequiredGeo(t *testing.T) {

// TestGeoSlotCalc checks that the calculated slots always hold a single bit geo req
func TestGeoSlotCalc(t *testing.T) {
geoReqName := pairingscores.GeoReq{}.GetName()
geoReq := pairingscores.GeoReq{}
geoReqName := geoReq.GetName()

allGeos := planstypes.GetAllGeolocations()
maxGeo := lavaslices.Max(allGeos)
Expand All @@ -1347,7 +1348,7 @@ func TestGeoSlotCalc(t *testing.T) {
slots := pairingscores.CalcSlots(&policy)
for _, slot := range slots {
geoReqFromMap := slot.Reqs[geoReqName]
geoReq, ok := geoReqFromMap.(pairingscores.GeoReq)
geoReq, ok := geoReqFromMap.(*pairingscores.GeoReq)
if !ok {
require.Fail(t, "slot geo req is not of GeoReq type")
}
Expand All @@ -1366,7 +1367,7 @@ func TestGeoSlotCalc(t *testing.T) {
slots := pairingscores.CalcSlots(&policy)
for _, slot := range slots {
geoReqFromMap := slot.Reqs[geoReqName]
geoReq, ok := geoReqFromMap.(pairingscores.GeoReq)
geoReq, ok := geoReqFromMap.(*pairingscores.GeoReq)
if !ok {
require.Fail(t, "slot geo req is not of GeoReq type")
}
Expand Down
25 changes: 16 additions & 9 deletions x/pairing/keeper/scores/geo_req.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ const (
minGeoLatency = 1
)

func (gr GeoReq) Init(policy planstypes.Policy) bool {
return true
func (gr *GeoReq) Init(policy planstypes.Policy) bool {
return gr != nil
}

// Score calculates the geo score of a provider based on preset latency data
// Note: each GeoReq must have exactly a single geolocation (bit)
func (gr GeoReq) Score(score PairingScore) math.Uint {
func (gr *GeoReq) Score(score PairingScore) math.Uint {
if gr == nil {
return calculateCostFromLatency(maxGeoLatency)
}

// check if the provider supports the required geolocation
if gr.Geo&^score.Provider.Geolocation == 0 {
return calculateCostFromLatency(minGeoLatency)
Expand All @@ -38,13 +42,16 @@ func (gr GeoReq) Score(score PairingScore) math.Uint {
return cost
}

func (gr GeoReq) GetName() string {
func (gr *GeoReq) GetName() string {
Yaroms marked this conversation as resolved.
Show resolved Hide resolved
if gr == nil {
return ""
}
return geoReqName
}

// Equal() used to compare slots to determine slot groups
func (gr GeoReq) Equal(other ScoreReq) bool {
otherGeoReq, ok := other.(GeoReq)
func (gr *GeoReq) Equal(other ScoreReq) bool {
otherGeoReq, ok := other.(*GeoReq)
if !ok {
return false
}
Expand All @@ -54,17 +61,17 @@ func (gr GeoReq) Equal(other ScoreReq) bool {

// TODO: this function doesn't return the optimal geo reqs for the case
// that there are more required geos than providers to pair
func (gr GeoReq) GetReqForSlot(policy planstypes.Policy, slotIdx int) ScoreReq {
func (gr *GeoReq) GetReqForSlot(policy planstypes.Policy, slotIdx int) ScoreReq {
policyGeoEnums := planstypes.GetGeolocationsFromUint(policy.GeolocationProfile)

if len(policyGeoEnums) == 0 {
utils.LavaFormatError("length of policyGeoEnums is zero", fmt.Errorf("critical: Attempt to divide by zero"),
utils.LogAttr("policyGeoProfile", policy.GeolocationProfile),
)
return GeoReq{Geo: int32(planstypes.Geolocation_USC)}
return &GeoReq{Geo: int32(planstypes.Geolocation_USC)}
}

return GeoReq{Geo: int32(policyGeoEnums[slotIdx%len(policyGeoEnums)])}
return &GeoReq{Geo: int32(policyGeoEnums[slotIdx%len(policyGeoEnums)])}
}

// CalcGeoCost() finds the minimal latency between the required geo and the provider's supported geolocations
Expand Down
2 changes: 1 addition & 1 deletion x/pairing/keeper/scores/pairing_slot.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (psg PairingSlotGroup) Subtract(other *PairingSlotGroup) *PairingSlot {
otherReq, found := other.Reqs[key]
if !found {
reqsDiff[key] = req
} else if !req.Equal(otherReq) {
} else if req != nil && !req.Equal(otherReq) {
reqsDiff[key] = req
}
}
Expand Down
20 changes: 13 additions & 7 deletions x/pairing/keeper/scores/qos_req.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ type QosGetter interface {
// QosReq implements the ScoreReq interface for provider staking requirement(s)
type QosReq struct{}

func (qr QosReq) Init(policy planstypes.Policy) bool {
return true
func (qr *QosReq) Init(policy planstypes.Policy) bool {
return qr != nil
}

// Score calculates the the provider's qos score
func (qr QosReq) Score(score PairingScore) math.Uint {
func (qr *QosReq) Score(score PairingScore) math.Uint {
// TODO: update Qos in providerQosFS properly and uncomment this code below
// Also, the qos score should range between 0.5-2

Expand All @@ -31,19 +31,25 @@ func (qr QosReq) Score(score PairingScore) math.Uint {
// }

// return math.Uint(qosScore)
if qr == nil {
return math.NewUint(1)
}
return math.NewUint(1)
}

func (qr QosReq) GetName() string {
func (qr *QosReq) GetName() string {
if qr == nil {
return ""
}
return qosReqName
}

// Equal used to compare slots to determine slot groups.
// Equal always returns true (there are no different "types" of qos)
func (qr QosReq) Equal(other ScoreReq) bool {
return true
func (qr *QosReq) Equal(other ScoreReq) bool {
return qr != nil
}

func (qr QosReq) GetReqForSlot(policy planstypes.Policy, slotIdx int) ScoreReq {
func (qr *QosReq) GetReqForSlot(policy planstypes.Policy, slotIdx int) ScoreReq {
return qr
}
8 changes: 7 additions & 1 deletion x/pairing/keeper/scores/stake_req.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ const stakeReqName = "stake-req"
type StakeReq struct{}

func (sr *StakeReq) Init(policy planstypes.Policy) bool {
return true
return sr != nil
}

// Score calculates the the provider score as the normalized stake
func (sr *StakeReq) Score(score PairingScore) math.Uint {
if sr == nil {
return math.OneUint()
}
effectiveStake := score.Provider.EffectiveStake()
if !effectiveStake.IsPositive() {
return math.OneUint()
Expand All @@ -25,6 +28,9 @@ func (sr *StakeReq) Score(score PairingScore) math.Uint {
}

func (sr *StakeReq) GetName() string {
if sr == nil {
return ""
}
return stakeReqName
}

Expand Down
5 changes: 2 additions & 3 deletions x/projects/keeper/creation.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (k Keeper) registerKey(ctx sdk.Context, key types.ProjectKey, project *type

// check that the developer key is valid, and that it does not already
// belong to a different project.
if found && devkeyData.ProjectID != project.Index {
if found && devkeyData.ProjectID != project.GetIndex() {
return utils.LavaFormatWarning("failed to register key",
fmt.Errorf("key already exists"),
utils.Attribute{Key: "key", Value: key.Key},
Expand Down Expand Up @@ -254,10 +254,9 @@ func (k Keeper) unregisterKey(ctx sdk.Context, key types.ProjectKey, project *ty
// the developer key belongs to a different project
if devkeyData.ProjectID != project.GetIndex() {
return utils.LavaFormatWarning("failed to unregister key", legacyerrors.ErrNotFound,
utils.Attribute{Key: "projectID", Value: project.Index},
utils.Attribute{Key: "projectID", Value: project.GetIndex()},
utils.Attribute{Key: "key", Value: key.Key},
utils.Attribute{Key: "keyTypes", Value: key.Kinds},
utils.Attribute{Key: "projectID", Value: project.GetIndex()},
utils.Attribute{Key: "otherID", Value: devkeyData.ProjectID},
)
}
Expand Down
3 changes: 3 additions & 0 deletions x/projects/types/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ func (project *Project) GetKey(key string) ProjectKey {
}

func (project *Project) AppendKey(key ProjectKey) bool {
if project == nil {
oren-lava marked this conversation as resolved.
Show resolved Hide resolved
return false
}
for i, projectKey := range project.ProjectKeys {
if projectKey.Key == key.Key {
project.ProjectKeys[i].Kinds |= key.Kinds
Expand Down
6 changes: 6 additions & 0 deletions x/spec/keeper/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ func (k Keeper) RefreshSpec(ctx sdk.Context, spec types.Spec, ancestors []types.
}

if details, err := spec.ValidateSpec(k.MaxCU(ctx)); err != nil {
if details != nil {
details = map[string]string{}
}
details["invalidates"] = spec.Index
attrs := utils.StringMapToAttributes(details)
return nil, utils.LavaFormatWarning("spec refresh failed (invalidate)", err, attrs...)
Expand Down Expand Up @@ -137,6 +140,9 @@ func (k Keeper) doExpandSpec(
inherit *map[string]bool,
details string,
) (string, error) {
if spec == nil {
return "", fmt.Errorf("doExpandSpec: spec is nil")
}
parentsCollections := map[types.CollectionData][]*types.ApiCollection{}

if len(spec.Imports) != 0 {
Expand Down
3 changes: 3 additions & 0 deletions x/spec/types/api_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ func (apic *ApiCollection) InheritAllFields(myCollections map[CollectionData]*Ap
// changes in place inside the apic
// nil merge maps means not to combine that field
func (apic *ApiCollection) CombineWithOthers(others []*ApiCollection, combineWithDisabled, allowOverwrite bool) (err error) {
if apic == nil {
return fmt.Errorf("CombineWithOthers: API collection is nil")
}
mergedApis := map[string]interface{}{}
mergedHeaders := map[string]interface{}{}
mergedParsers := map[string]interface{}{}
Expand Down
2 changes: 1 addition & 1 deletion x/spec/types/combinable.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func CombineUnique[T Combinable](appendFrom, appendTo []T, currentMap map[string
} else {
// overwriting the inherited field might need Overwrite actions
if overwritten, isOverwritten := current.currentCombinable.Overwrite(combinable); isOverwritten {
if appendTo[current.index].Differeniator() != combinable.Differeniator() {
if len(appendTo) <= current.index || appendTo[current.index].Differeniator() != combinable.Differeniator() {
return nil, fmt.Errorf("differentiator mismatch in overwrite %s vs %s", combinable.Differeniator(), appendTo[current.index].Differeniator())
}
overwrittenT, ok := overwritten.(T)
Expand Down
5 changes: 4 additions & 1 deletion x/spec/types/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ func (spec Spec) ValidateSpec(maxCU uint64) (map[string]string, error) {
}

func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*ApiCollection) error {
if spec == nil {
return fmt.Errorf("CombineCollections: spec is nil")
}
collectionDataList := make([]CollectionData, 0)
// Populate the keys slice with the map keys
for key := range parentsCollections {
Expand All @@ -225,7 +228,7 @@ func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*Ap
break
}
}
if !combined.Enabled {
if combined == nil || !combined.Enabled {
// no collections enabled to combine, we skip this
continue
}
Expand Down
10 changes: 10 additions & 0 deletions x/subscription/keeper/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ func (k Keeper) verifySubscriptionBuyInputAndGetPlan(ctx sdk.Context, block uint
func (k Keeper) createNewSubscription(ctx sdk.Context, plan *planstypes.Plan, creator, consumer string,
block uint64, autoRenewalFlag bool,
) (types.Subscription, error) {
if plan == nil {
return types.Subscription{}, utils.LavaFormatError("plan is nil", fmt.Errorf("createNewSubscription: cannot create new subscription"))
}
autoRenewalNextPlan := types.AUTO_RENEWAL_PLAN_NONE
if autoRenewalFlag {
// On subscription creation, auto renewal is set to the subscription's plan
Expand Down Expand Up @@ -223,6 +226,13 @@ func (k Keeper) createNewSubscription(ctx sdk.Context, plan *planstypes.Plan, cr
}

func (k Keeper) upgradeSubscriptionPlan(ctx sdk.Context, sub *types.Subscription, newPlan *planstypes.Plan) error {
if newPlan == nil {
return utils.LavaFormatError("new plan is nil", fmt.Errorf("upgradeSubscriptionPlan: cannot upgrade subscription plan"))
}
if sub == nil {
return utils.LavaFormatError("subscription is nil", fmt.Errorf("upgradeSubscriptionPlan: cannot upgrade subscription plan"))
}

block := uint64(ctx.BlockHeight())

nextEpoch, err := k.epochstorageKeeper.GetNextEpoch(ctx, block)
Expand Down
Loading