From d7c9423036bc7377d8898c4f340cf5161adc5646 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Thu, 17 Oct 2024 10:52:19 -0400 Subject: [PATCH] Introduce and use `database.WithDefault` (#3478) --- database/helpers.go | 15 +++++++++++++++ database/helpers_test.go | 26 +++++++++++++++++++++++++- indexer/index.go | 15 +++++++-------- vms/example/xsvm/state/storage.go | 18 +++--------------- vms/platformvm/state/state.go | 10 +--------- 5 files changed, 51 insertions(+), 33 deletions(-) diff --git a/database/helpers.go b/database/helpers.go index d17e6669e4fa..db12878cb56c 100644 --- a/database/helpers.go +++ b/database/helpers.go @@ -137,6 +137,21 @@ func GetBool(db KeyValueReader, key []byte) (bool, error) { return b[0] == BoolTrue, nil } +// WithDefault returns the value at [key] in [db]. If the key doesn't exist, it +// returns [def]. +func WithDefault[V any]( + get func(KeyValueReader, []byte) (V, error), + db KeyValueReader, + key []byte, + def V, +) (V, error) { + v, err := get(db, key) + if err == ErrNotFound { + return def, nil + } + return v, err +} + func Count(db Iteratee) (int, error) { iterator := db.NewIterator() defer iterator.Release() diff --git a/database/helpers_test.go b/database/helpers_test.go index 1cce64c82baf..2de3f2af5728 100644 --- a/database/helpers_test.go +++ b/database/helpers_test.go @@ -1,7 +1,7 @@ // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package database +package database_test import ( "math/rand" @@ -11,7 +11,10 @@ import ( "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/utils" + + . "github.com/ava-labs/avalanchego/database" ) func TestSortednessUint64(t *testing.T) { @@ -49,3 +52,24 @@ func TestSortednessUint32(t *testing.T) { } require.True(t, utils.IsSortedBytes(intBytes)) } + +func TestOrDefault(t *testing.T) { + require := require.New(t) + + var ( + db = memdb.New() + key = utils.RandomBytes(32) + ) + + // Key doesn't exist + v, err := WithDefault(GetUInt64, db, key, 1) + require.NoError(err) + require.Equal(uint64(1), v) + + require.NoError(PutUInt64(db, key, 2)) + + // Key does exist + v, err = WithDefault(GetUInt64, db, key, 1) + require.NoError(err) + require.Equal(uint64(2), v) +} diff --git a/indexer/index.go b/indexer/index.go index 188658c16ca6..58bb215a5833 100644 --- a/indexer/index.go +++ b/indexer/index.go @@ -78,17 +78,16 @@ func newIndex( } // Get next accepted index from db - nextAcceptedIndex, err := database.GetUInt64(i.vDB, nextAcceptedIndexKey) - if err == database.ErrNotFound { - // Couldn't find it in the database. Must not have accepted any containers in previous runs. - i.log.Info("created new index", - zap.Uint64("nextAcceptedIndex", i.nextAcceptedIndex), - ) - return i, nil - } + nextAcceptedIndex, err := database.WithDefault( + database.GetUInt64, + i.vDB, + nextAcceptedIndexKey, + 0, + ) if err != nil { return nil, fmt.Errorf("couldn't get next accepted index from database: %w", err) } + i.nextAcceptedIndex = nextAcceptedIndex i.log.Info("created new index", zap.Uint64("nextAcceptedIndex", i.nextAcceptedIndex), diff --git a/vms/example/xsvm/state/storage.go b/vms/example/xsvm/state/storage.go index 00fb20e45c7b..48e96f977efd 100644 --- a/vms/example/xsvm/state/storage.go +++ b/vms/example/xsvm/state/storage.go @@ -77,11 +77,7 @@ func AddBlock(db database.KeyValueWriter, height uint64, blkID ids.ID, blk []byt func GetNonce(db database.KeyValueReader, address ids.ShortID) (uint64, error) { key := Flatten(addressPrefix, address[:]) - nonce, err := database.GetUInt64(db, key) - if errors.Is(err, database.ErrNotFound) { - return 0, nil - } - return nonce, err + return database.WithDefault(database.GetUInt64, db, key, 0) } func SetNonce(db database.KeyValueWriter, address ids.ShortID, nonce uint64) error { @@ -102,11 +98,7 @@ func IncrementNonce(db database.KeyValueReaderWriter, address ids.ShortID, nonce func GetBalance(db database.KeyValueReader, address ids.ShortID, chainID ids.ID) (uint64, error) { key := Flatten(addressPrefix, address[:], chainID[:]) - balance, err := database.GetUInt64(db, key) - if errors.Is(err, database.ErrNotFound) { - return 0, nil - } - return balance, err + return database.WithDefault(database.GetUInt64, db, key, 0) } func SetBalance(db database.KeyValueWriterDeleter, address ids.ShortID, chainID ids.ID, balance uint64) error { @@ -154,11 +146,7 @@ func AddLoanID(db database.KeyValueWriter, chainID ids.ID, loanID ids.ID) error func GetLoan(db database.KeyValueReader, chainID ids.ID) (uint64, error) { key := Flatten(chainPrefix, chainID[:]) - balance, err := database.GetUInt64(db, key) - if errors.Is(err, database.ErrNotFound) { - return 0, nil - } - return balance, err + return database.WithDefault(database.GetUInt64, db, key, 0) } func SetLoan(db database.KeyValueWriterDeleter, chainID ids.ID, balance uint64) error { diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index a624564233a6..f351bf5940b2 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -1391,7 +1391,7 @@ func (s *state) loadMetadata() error { s.persistedFeeState = feeState s.SetFeeState(feeState) - accruedFees, err := getAccruedFees(s.singletonDB) + accruedFees, err := database.WithDefault(database.GetUInt64, s.singletonDB, AccruedFeesKey, 0) if err != nil { return err } @@ -2665,11 +2665,3 @@ func getFeeState(db database.KeyValueReader) (gas.State, error) { } return feeState, nil } - -func getAccruedFees(db database.KeyValueReader) (uint64, error) { - accruedFees, err := database.GetUInt64(db, AccruedFeesKey) - if err == database.ErrNotFound { - return 0, nil - } - return accruedFees, err -}