diff --git a/arbos/addressSet/addressSet.go b/arbos/addressSet/addressSet.go index 2d4cc34f70..1f09ff1440 100644 --- a/arbos/addressSet/addressSet.go +++ b/arbos/addressSet/addressSet.go @@ -3,6 +3,8 @@ package addressSet +// TODO lowercase this package name + import ( "errors" @@ -26,49 +28,49 @@ func Initialize(sto *storage.Storage) error { func OpenAddressSet(sto *storage.Storage) *AddressSet { return &AddressSet{ - sto, - sto.OpenStorageBackedUint64(0), - sto.OpenSubStorage([]byte{0}), + backingStorage: sto.WithoutCache(), + size: sto.OpenStorageBackedUint64(0), + byAddress: sto.OpenSubStorage([]byte{0}), } } -func (aset *AddressSet) Size() (uint64, error) { - return aset.size.Get() +func (as *AddressSet) Size() (uint64, error) { + return as.size.Get() } -func (aset *AddressSet) IsMember(addr common.Address) (bool, error) { - value, err := aset.byAddress.Get(util.AddressToHash(addr)) +func (as *AddressSet) IsMember(addr common.Address) (bool, error) { + value, err := as.byAddress.Get(util.AddressToHash(addr)) return value != (common.Hash{}), err } -func (aset *AddressSet) GetAnyMember() (*common.Address, error) { - size, err := aset.size.Get() +func (as *AddressSet) GetAnyMember() (*common.Address, error) { + size, err := as.size.Get() if err != nil || size == 0 { return nil, err } - sba := aset.backingStorage.OpenStorageBackedAddressOrNil(1) + sba := as.backingStorage.OpenStorageBackedAddressOrNil(1) addr, err := sba.Get() return addr, err } -func (aset *AddressSet) Clear() error { - size, err := aset.size.Get() +func (as *AddressSet) Clear() error { + size, err := as.size.Get() if err != nil || size == 0 { return err } for i := uint64(1); i <= size; i++ { - contents, _ := aset.backingStorage.GetByUint64(i) - _ = aset.backingStorage.ClearByUint64(i) - err = aset.byAddress.Clear(contents) + contents, _ := as.backingStorage.GetByUint64(i) + _ = as.backingStorage.ClearByUint64(i) + err = as.byAddress.Clear(contents) if err != nil { return err } } - return aset.size.Clear() + return as.size.Clear() } -func (aset *AddressSet) AllMembers(maxNumToReturn uint64) ([]common.Address, error) { - size, err := aset.size.Get() +func (as *AddressSet) AllMembers(maxNumToReturn uint64) ([]common.Address, error) { + size, err := as.size.Get() if err != nil { return nil, err } @@ -77,7 +79,7 @@ func (aset *AddressSet) AllMembers(maxNumToReturn uint64) ([]common.Address, err } ret := make([]common.Address, size) for i := range ret { - sba := aset.backingStorage.OpenStorageBackedAddress(uint64(i + 1)) + sba := as.backingStorage.OpenStorageBackedAddress(uint64(i + 1)) ret[i], err = sba.Get() if err != nil { return nil, err @@ -86,22 +88,22 @@ func (aset *AddressSet) AllMembers(maxNumToReturn uint64) ([]common.Address, err return ret, nil } -func (aset *AddressSet) ClearList() error { - size, err := aset.size.Get() +func (as *AddressSet) ClearList() error { + size, err := as.size.Get() if err != nil || size == 0 { return err } for i := uint64(1); i <= size; i++ { - err = aset.backingStorage.ClearByUint64(i) + err = as.backingStorage.ClearByUint64(i) if err != nil { return err } } - return aset.size.Clear() + return as.size.Clear() } -func (aset *AddressSet) RectifyMapping(addr common.Address) error { - isOwner, err := aset.IsMember(addr) +func (as *AddressSet) RectifyMapping(addr common.Address) error { + isOwner, err := as.IsMember(addr) if !isOwner || err != nil { return errors.New("RectifyMapping: Address is not an owner") } @@ -109,15 +111,15 @@ func (aset *AddressSet) RectifyMapping(addr common.Address) error { // If the mapping is correct, RectifyMapping shouldn't do anything // Additional safety check to avoid corruption of mapping after the initial fix addrAsHash := common.BytesToHash(addr.Bytes()) - slot, err := aset.byAddress.GetUint64(addrAsHash) + slot, err := as.byAddress.GetUint64(addrAsHash) if err != nil { return err } - atSlot, err := aset.backingStorage.GetByUint64(slot) + atSlot, err := as.backingStorage.GetByUint64(slot) if err != nil { return err } - size, err := aset.size.Get() + size, err := as.size.Get() if err != nil { return err } @@ -126,72 +128,72 @@ func (aset *AddressSet) RectifyMapping(addr common.Address) error { } // Remove the owner from map and add them as a new owner - err = aset.byAddress.Clear(addrAsHash) + err = as.byAddress.Clear(addrAsHash) if err != nil { return err } - return aset.Add(addr) + return as.Add(addr) } -func (aset *AddressSet) Add(addr common.Address) error { - present, err := aset.IsMember(addr) +func (as *AddressSet) Add(addr common.Address) error { + present, err := as.IsMember(addr) if present || err != nil { return err } - size, err := aset.size.Get() + size, err := as.size.Get() if err != nil { return err } slot := util.UintToHash(1 + size) addrAsHash := common.BytesToHash(addr.Bytes()) - err = aset.byAddress.Set(addrAsHash, slot) + err = as.byAddress.Set(addrAsHash, slot) if err != nil { return err } - sba := aset.backingStorage.OpenStorageBackedAddress(1 + size) + sba := as.backingStorage.OpenStorageBackedAddress(1 + size) err = sba.Set(addr) if err != nil { return err } - _, err = aset.size.Increment() + _, err = as.size.Increment() return err } -func (aset *AddressSet) Remove(addr common.Address, arbosVersion uint64) error { +func (as *AddressSet) Remove(addr common.Address, arbosVersion uint64) error { addrAsHash := common.BytesToHash(addr.Bytes()) - slot, err := aset.byAddress.GetUint64(addrAsHash) + slot, err := as.byAddress.GetUint64(addrAsHash) if slot == 0 || err != nil { return err } - err = aset.byAddress.Clear(addrAsHash) + err = as.byAddress.Clear(addrAsHash) if err != nil { return err } - size, err := aset.size.Get() + size, err := as.size.Get() if err != nil { return err } if slot < size { - atSize, err := aset.backingStorage.GetByUint64(size) + atSize, err := as.backingStorage.GetByUint64(size) if err != nil { return err } - err = aset.backingStorage.SetByUint64(slot, atSize) + err = as.backingStorage.SetByUint64(slot, atSize) if err != nil { return err } if arbosVersion >= 11 { - err = aset.byAddress.Set(atSize, util.UintToHash(slot)) + err = as.byAddress.Set(atSize, util.UintToHash(slot)) if err != nil { return err } } } - err = aset.backingStorage.ClearByUint64(size) + err = as.backingStorage.ClearByUint64(size) if err != nil { return err } - _, err = aset.size.Decrement() + _, err = as.size.Decrement() return err } diff --git a/arbos/addressTable/addressTable.go b/arbos/addressTable/addressTable.go index 220c2700f4..3fbb7b3782 100644 --- a/arbos/addressTable/addressTable.go +++ b/arbos/addressTable/addressTable.go @@ -3,6 +3,8 @@ package addressTable +// TODO lowercase this package name + import ( "bytes" "errors" @@ -25,7 +27,7 @@ func Initialize(sto *storage.Storage) { func Open(sto *storage.Storage) *AddressTable { numItems := sto.OpenStorageBackedUint64(0) - return &AddressTable{sto, sto.OpenSubStorage([]byte{}), numItems} + return &AddressTable{sto.WithoutCache(), sto.OpenSubStorage([]byte{}), numItems} } func (atab *AddressTable) Register(addr common.Address) (uint64, error) { diff --git a/arbos/arbosState/arbosstate.go b/arbos/arbosState/arbosstate.go index a47f4e790b..8702c62d16 100644 --- a/arbos/arbosState/arbosstate.go +++ b/arbos/arbosState/arbosstate.go @@ -73,13 +73,13 @@ func OpenArbosState(stateDB vm.StateDB, burner burn.Burner) (*ArbosState, error) backingStorage.OpenStorageBackedUint64(uint64(upgradeVersionOffset)), backingStorage.OpenStorageBackedUint64(uint64(upgradeTimestampOffset)), backingStorage.OpenStorageBackedAddress(uint64(networkFeeAccountOffset)), - l1pricing.OpenL1PricingState(backingStorage.OpenSubStorage(l1PricingSubspace)), - l2pricing.OpenL2PricingState(backingStorage.OpenSubStorage(l2PricingSubspace)), - retryables.OpenRetryableState(backingStorage.OpenSubStorage(retryablesSubspace), stateDB), - addressTable.Open(backingStorage.OpenSubStorage(addressTableSubspace)), - addressSet.OpenAddressSet(backingStorage.OpenSubStorage(chainOwnerSubspace)), - merkleAccumulator.OpenMerkleAccumulator(backingStorage.OpenSubStorage(sendMerkleSubspace)), - blockhash.OpenBlockhashes(backingStorage.OpenSubStorage(blockhashesSubspace)), + l1pricing.OpenL1PricingState(backingStorage.OpenCachedSubStorage(l1PricingSubspace)), + l2pricing.OpenL2PricingState(backingStorage.OpenCachedSubStorage(l2PricingSubspace)), + retryables.OpenRetryableState(backingStorage.OpenCachedSubStorage(retryablesSubspace), stateDB), + addressTable.Open(backingStorage.OpenCachedSubStorage(addressTableSubspace)), + addressSet.OpenAddressSet(backingStorage.OpenCachedSubStorage(chainOwnerSubspace)), + merkleAccumulator.OpenMerkleAccumulator(backingStorage.OpenCachedSubStorage(sendMerkleSubspace)), + blockhash.OpenBlockhashes(backingStorage.OpenCachedSubStorage(blockhashesSubspace)), backingStorage.OpenStorageBackedBigInt(uint64(chainIdOffset)), backingStorage.OpenStorageBackedBytes(chainConfigSubspace), backingStorage.OpenStorageBackedUint64(uint64(genesisBlockNumOffset)), @@ -225,14 +225,14 @@ func InitializeArbosState(stateDB vm.StateDB, burner burn.Burner, chainConfig *p if desiredArbosVersion >= 2 { initialRewardsRecipient = initialChainOwner } - _ = l1pricing.InitializeL1PricingState(sto.OpenSubStorage(l1PricingSubspace), initialRewardsRecipient, initMessage.InitialL1BaseFee) - _ = l2pricing.InitializeL2PricingState(sto.OpenSubStorage(l2PricingSubspace)) - _ = retryables.InitializeRetryableState(sto.OpenSubStorage(retryablesSubspace)) - addressTable.Initialize(sto.OpenSubStorage(addressTableSubspace)) - merkleAccumulator.InitializeMerkleAccumulator(sto.OpenSubStorage(sendMerkleSubspace)) - blockhash.InitializeBlockhashes(sto.OpenSubStorage(blockhashesSubspace)) - - ownersStorage := sto.OpenSubStorage(chainOwnerSubspace) + _ = l1pricing.InitializeL1PricingState(sto.OpenCachedSubStorage(l1PricingSubspace), initialRewardsRecipient, initMessage.InitialL1BaseFee) + _ = l2pricing.InitializeL2PricingState(sto.OpenCachedSubStorage(l2PricingSubspace)) + _ = retryables.InitializeRetryableState(sto.OpenCachedSubStorage(retryablesSubspace)) + addressTable.Initialize(sto.OpenCachedSubStorage(addressTableSubspace)) + merkleAccumulator.InitializeMerkleAccumulator(sto.OpenCachedSubStorage(sendMerkleSubspace)) + blockhash.InitializeBlockhashes(sto.OpenCachedSubStorage(blockhashesSubspace)) + + ownersStorage := sto.OpenCachedSubStorage(chainOwnerSubspace) _ = addressSet.Initialize(ownersStorage) _ = addressSet.OpenAddressSet(ownersStorage).Add(initialChainOwner) @@ -428,7 +428,7 @@ func (state *ArbosState) ChainOwners() *addressSet.AddressSet { func (state *ArbosState) SendMerkleAccumulator() *merkleAccumulator.MerkleAccumulator { if state.sendMerkle == nil { - state.sendMerkle = merkleAccumulator.OpenMerkleAccumulator(state.backingStorage.OpenSubStorage(sendMerkleSubspace)) + state.sendMerkle = merkleAccumulator.OpenMerkleAccumulator(state.backingStorage.OpenCachedSubStorage(sendMerkleSubspace)) } return state.sendMerkle } diff --git a/arbos/arbosState/arbosstate_test.go b/arbos/arbosState/arbosstate_test.go index c4643c9183..ef63c23386 100644 --- a/arbos/arbosState/arbosstate_test.go +++ b/arbos/arbosState/arbosstate_test.go @@ -64,7 +64,7 @@ func TestStorageBackedInt64(t *testing.T) { func TestStorageSlots(t *testing.T) { state, _ := NewArbosMemoryBackedArbOSState() - sto := state.BackingStorage().OpenSubStorage([]byte{}) + sto := state.BackingStorage().OpenCachedSubStorage([]byte{}) println("nil address", colors.Blue, storage.NilAddressRepresentation.String(), colors.Clear) diff --git a/arbos/blockhash/blockhash.go b/arbos/blockhash/blockhash.go index 2eedf7f5bb..34c907207c 100644 --- a/arbos/blockhash/blockhash.go +++ b/arbos/blockhash/blockhash.go @@ -21,7 +21,7 @@ func InitializeBlockhashes(backingStorage *storage.Storage) { } func OpenBlockhashes(backingStorage *storage.Storage) *Blockhashes { - return &Blockhashes{backingStorage, backingStorage.OpenStorageBackedUint64(0)} + return &Blockhashes{backingStorage.WithoutCache(), backingStorage.OpenStorageBackedUint64(0)} } func (bh *Blockhashes) L1BlockNumber() (uint64, error) { diff --git a/arbos/l1pricing/batchPoster.go b/arbos/l1pricing/batchPoster.go index 97b7b16234..a3428c441c 100644 --- a/arbos/l1pricing/batchPoster.go +++ b/arbos/l1pricing/batchPoster.go @@ -42,12 +42,12 @@ func InitializeBatchPostersTable(storage *storage.Storage) error { if err := totalFundsDue.SetChecked(common.Big0); err != nil { return err } - return addressSet.Initialize(storage.OpenSubStorage(PosterAddrsKey)) + return addressSet.Initialize(storage.OpenCachedSubStorage(PosterAddrsKey)) } func OpenBatchPostersTable(storage *storage.Storage) *BatchPostersTable { return &BatchPostersTable{ - posterAddrs: addressSet.OpenAddressSet(storage.OpenSubStorage(PosterAddrsKey)), + posterAddrs: addressSet.OpenAddressSet(storage.OpenCachedSubStorage(PosterAddrsKey)), posterInfo: storage.OpenSubStorage(PosterInfoKey), totalFundsDue: storage.OpenStorageBackedBigInt(totalFundsDueOffset), } diff --git a/arbos/l1pricing/l1pricing.go b/arbos/l1pricing/l1pricing.go index 142efbeafe..27ecae8b85 100644 --- a/arbos/l1pricing/l1pricing.go +++ b/arbos/l1pricing/l1pricing.go @@ -83,7 +83,7 @@ var InitialEquilibrationUnitsV0 = arbmath.UintToBig(60 * params.TxDataNonZeroGas var InitialEquilibrationUnitsV6 = arbmath.UintToBig(params.TxDataNonZeroGasEIP2028 * 10000000) func InitializeL1PricingState(sto *storage.Storage, initialRewardsRecipient common.Address, initialL1BaseFee *big.Int) error { - bptStorage := sto.OpenSubStorage(BatchPosterTableKey) + bptStorage := sto.OpenCachedSubStorage(BatchPosterTableKey) if err := InitializeBatchPostersTable(bptStorage); err != nil { return err } @@ -118,7 +118,7 @@ func InitializeL1PricingState(sto *storage.Storage, initialRewardsRecipient comm func OpenL1PricingState(sto *storage.Storage) *L1PricingState { return &L1PricingState{ sto, - OpenBatchPostersTable(sto.OpenSubStorage(BatchPosterTableKey)), + OpenBatchPostersTable(sto.OpenCachedSubStorage(BatchPosterTableKey)), sto.OpenStorageBackedAddress(payRewardsToOffset), sto.OpenStorageBackedBigUint(equilibrationUnitsOffset), sto.OpenStorageBackedUint64(inertiaOffset), diff --git a/arbos/queue_test.go b/arbos/queue_test.go index d8d491bdb0..ff993a233f 100644 --- a/arbos/queue_test.go +++ b/arbos/queue_test.go @@ -14,7 +14,7 @@ import ( func TestQueue(t *testing.T) { state, statedb := arbosState.NewArbosMemoryBackedArbOSState() - sto := state.BackingStorage().OpenSubStorage([]byte{}) + sto := state.BackingStorage().OpenCachedSubStorage([]byte{}) Require(t, storage.InitializeQueue(sto)) q := storage.OpenQueue(sto) diff --git a/arbos/retryables/retryable.go b/arbos/retryables/retryable.go index abea2ab7bd..6984e41904 100644 --- a/arbos/retryables/retryable.go +++ b/arbos/retryables/retryable.go @@ -31,13 +31,13 @@ var ( ) func InitializeRetryableState(sto *storage.Storage) error { - return storage.InitializeQueue(sto.OpenSubStorage(timeoutQueueKey)) + return storage.InitializeQueue(sto.OpenCachedSubStorage(timeoutQueueKey)) } func OpenRetryableState(sto *storage.Storage, statedb vm.StateDB) *RetryableState { return &RetryableState{ sto, - storage.OpenQueue(sto.OpenSubStorage(timeoutQueueKey)), + storage.OpenQueue(sto.OpenCachedSubStorage(timeoutQueueKey)), } } @@ -150,6 +150,7 @@ func (rs *RetryableState) DeleteRetryable(id common.Hash, evm *vm.EVM, scenario return false, err } + // we ignore returned error as we expect that if one ClearByUint64 fails, than all consecutive calls to ClearByUint64 will fail with the same error (not modifying state), and then ClearBytes will also fail with the same error (also not modifying state) - and this one we check and return _ = retStorage.ClearByUint64(numTriesOffset) _ = retStorage.ClearByUint64(fromOffset) _ = retStorage.ClearByUint64(toOffset) diff --git a/arbos/storage/queue.go b/arbos/storage/queue.go index 55231d3a90..9c02dc1ee7 100644 --- a/arbos/storage/queue.go +++ b/arbos/storage/queue.go @@ -25,7 +25,7 @@ func InitializeQueue(sto *Storage) error { func OpenQueue(sto *Storage) *Queue { return &Queue{ - sto, + sto.WithoutCache(), sto.OpenStorageBackedUint64(0), sto.OpenStorageBackedUint64(1), } diff --git a/arbos/storage/storage.go b/arbos/storage/storage.go index 478ad68f8f..63987b91f8 100644 --- a/arbos/storage/storage.go +++ b/arbos/storage/storage.go @@ -4,10 +4,13 @@ package storage import ( + "bytes" "fmt" "math/big" + "sync/atomic" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/lru" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/vm" @@ -43,12 +46,18 @@ type Storage struct { db vm.StateDB storageKey []byte burner burn.Burner + hashCache *lru.Cache[string, []byte] } const StorageReadCost = params.SloadGasEIP2200 const StorageWriteCost = params.SstoreSetGasEIP2200 const StorageWriteZeroCost = params.SstoreResetGasEIP2200 +const storageKeyCacheSize = 1024 + +var storageHashCache = lru.NewCache[string, []byte](storageKeyCacheSize) +var cacheFullLogged atomic.Bool + // NewGeth uses a Geth database to create an evm key-value store func NewGeth(statedb vm.StateDB, burner burn.Burner) *Storage { account := common.HexToAddress("0xA4B05FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF") @@ -58,6 +67,7 @@ func NewGeth(statedb vm.StateDB, burner burn.Burner) *Storage { db: statedb, storageKey: []byte{}, burner: burner, + hashCache: storageHashCache, } } @@ -81,15 +91,13 @@ func NewMemoryBackedStateDB() vm.StateDB { // a page, to preserve contiguity within a page. This will reduce cost if/when Ethereum switches to storage // representations that reward contiguity. // Because page numbers are 248 bits, this gives us 124-bit security against collision attacks, which is good enough. -func mapAddress(storageKey []byte, key common.Hash) common.Hash { +func (s *Storage) mapAddress(key common.Hash) common.Hash { keyBytes := key.Bytes() boundary := common.HashLength - 1 - return common.BytesToHash( - append( - crypto.Keccak256(storageKey, keyBytes[:boundary])[:boundary], - keyBytes[boundary], - ), - ) + mapped := make([]byte, 0, common.HashLength) + mapped = append(mapped, s.cachedKeccak(s.storageKey, keyBytes[:boundary])[:boundary]...) + mapped = append(mapped, keyBytes[boundary]) + return common.BytesToHash(mapped) } func writeCost(value common.Hash) uint64 { @@ -99,117 +107,139 @@ func writeCost(value common.Hash) uint64 { return StorageWriteCost } -func (store *Storage) Account() common.Address { - return store.account +func (s *Storage) Account() common.Address { + return s.account } -func (store *Storage) Get(key common.Hash) (common.Hash, error) { - err := store.burner.Burn(StorageReadCost) +func (s *Storage) Get(key common.Hash) (common.Hash, error) { + err := s.burner.Burn(StorageReadCost) if err != nil { return common.Hash{}, err } - if info := store.burner.TracingInfo(); info != nil { + if info := s.burner.TracingInfo(); info != nil { info.RecordStorageGet(key) } - return store.db.GetState(store.account, mapAddress(store.storageKey, key)), nil + return s.db.GetState(s.account, s.mapAddress(key)), nil } -func (store *Storage) GetStorageSlot(key common.Hash) common.Hash { - return mapAddress(store.storageKey, key) +func (s *Storage) GetStorageSlot(key common.Hash) common.Hash { + return s.mapAddress(key) } -func (store *Storage) GetUint64(key common.Hash) (uint64, error) { - value, err := store.Get(key) +func (s *Storage) GetUint64(key common.Hash) (uint64, error) { + value, err := s.Get(key) return value.Big().Uint64(), err } -func (store *Storage) GetByUint64(key uint64) (common.Hash, error) { - return store.Get(util.UintToHash(key)) +func (s *Storage) GetByUint64(key uint64) (common.Hash, error) { + return s.Get(util.UintToHash(key)) } -func (store *Storage) GetUint64ByUint64(key uint64) (uint64, error) { - return store.GetUint64(util.UintToHash(key)) +func (s *Storage) GetUint64ByUint64(key uint64) (uint64, error) { + return s.GetUint64(util.UintToHash(key)) } -func (store *Storage) Set(key common.Hash, value common.Hash) error { - if store.burner.ReadOnly() { +func (s *Storage) Set(key common.Hash, value common.Hash) error { + if s.burner.ReadOnly() { log.Error("Read-only burner attempted to mutate state", "key", key, "value", value) return vm.ErrWriteProtection } - err := store.burner.Burn(writeCost(value)) + err := s.burner.Burn(writeCost(value)) if err != nil { return err } - if info := store.burner.TracingInfo(); info != nil { + if info := s.burner.TracingInfo(); info != nil { info.RecordStorageSet(key, value) } - store.db.SetState(store.account, mapAddress(store.storageKey, key), value) + s.db.SetState(s.account, s.mapAddress(key), value) return nil } -func (store *Storage) SetByUint64(key uint64, value common.Hash) error { - return store.Set(util.UintToHash(key), value) +func (s *Storage) SetByUint64(key uint64, value common.Hash) error { + return s.Set(util.UintToHash(key), value) } -func (store *Storage) SetUint64ByUint64(key uint64, value uint64) error { - return store.Set(util.UintToHash(key), util.UintToHash(value)) +func (s *Storage) SetUint64ByUint64(key uint64, value uint64) error { + return s.Set(util.UintToHash(key), util.UintToHash(value)) } -func (store *Storage) Clear(key common.Hash) error { - return store.Set(key, common.Hash{}) +func (s *Storage) Clear(key common.Hash) error { + return s.Set(key, common.Hash{}) } -func (store *Storage) ClearByUint64(key uint64) error { - return store.Set(util.UintToHash(key), common.Hash{}) +func (s *Storage) ClearByUint64(key uint64) error { + return s.Set(util.UintToHash(key), common.Hash{}) } -func (store *Storage) Swap(key common.Hash, newValue common.Hash) (common.Hash, error) { - oldValue, err := store.Get(key) +func (s *Storage) Swap(key common.Hash, newValue common.Hash) (common.Hash, error) { + oldValue, err := s.Get(key) if err != nil { return common.Hash{}, err } - return oldValue, store.Set(key, newValue) + return oldValue, s.Set(key, newValue) +} + +func (s *Storage) OpenCachedSubStorage(id []byte) *Storage { + return &Storage{ + account: s.account, + db: s.db, + storageKey: s.cachedKeccak(s.storageKey, id), + burner: s.burner, + hashCache: storageHashCache, + } +} +func (s *Storage) OpenSubStorage(id []byte) *Storage { + return &Storage{ + account: s.account, + db: s.db, + storageKey: s.cachedKeccak(s.storageKey, id), + burner: s.burner, + hashCache: nil, + } } -func (store *Storage) OpenSubStorage(id []byte) *Storage { +// Returns shallow copy of Storage that won't use storage key hash cache. +// The storage space represented by the returned Storage is kept the same. +func (s *Storage) WithoutCache() *Storage { return &Storage{ - store.account, - store.db, - crypto.Keccak256(store.storageKey, id), - store.burner, + account: s.account, + db: s.db, + storageKey: s.storageKey, + burner: s.burner, + hashCache: nil, } } -func (store *Storage) SetBytes(b []byte) error { - err := store.ClearBytes() +func (s *Storage) SetBytes(b []byte) error { + err := s.ClearBytes() if err != nil { return err } - err = store.SetUint64ByUint64(0, uint64(len(b))) + err = s.SetUint64ByUint64(0, uint64(len(b))) if err != nil { return err } offset := uint64(1) for len(b) >= 32 { - err = store.SetByUint64(offset, common.BytesToHash(b[:32])) + err = s.SetByUint64(offset, common.BytesToHash(b[:32])) if err != nil { return err } b = b[32:] offset++ } - return store.SetByUint64(offset, common.BytesToHash(b)) + return s.SetByUint64(offset, common.BytesToHash(b)) } -func (store *Storage) GetBytes() ([]byte, error) { - bytesLeft, err := store.GetUint64ByUint64(0) +func (s *Storage) GetBytes() ([]byte, error) { + bytesLeft, err := s.GetUint64ByUint64(0) if err != nil { return nil, err } ret := []byte{} offset := uint64(1) for bytesLeft >= 32 { - next, err := store.GetByUint64(offset) + next, err := s.GetByUint64(offset) if err != nil { return nil, err } @@ -217,7 +247,7 @@ func (store *Storage) GetBytes() ([]byte, error) { bytesLeft -= 32 offset++ } - next, err := store.GetByUint64(offset) + next, err := s.GetByUint64(offset) if err != nil { return nil, err } @@ -225,18 +255,18 @@ func (store *Storage) GetBytes() ([]byte, error) { return ret, nil } -func (store *Storage) GetBytesSize() (uint64, error) { - return store.GetUint64ByUint64(0) +func (s *Storage) GetBytesSize() (uint64, error) { + return s.GetUint64ByUint64(0) } -func (store *Storage) ClearBytes() error { - bytesLeft, err := store.GetUint64ByUint64(0) +func (s *Storage) ClearBytes() error { + bytesLeft, err := s.GetUint64ByUint64(0) if err != nil { return err } offset := uint64(1) for bytesLeft > 0 { - err := store.ClearByUint64(offset) + err := s.ClearByUint64(offset) if err != nil { return err } @@ -247,30 +277,51 @@ func (store *Storage) ClearBytes() error { bytesLeft -= 32 } } - return store.ClearByUint64(0) + return s.ClearByUint64(0) } -func (store *Storage) Burner() burn.Burner { - return store.burner // not public because these should never be changed once set +func (s *Storage) Burner() burn.Burner { + return s.burner // not public because these should never be changed once set } -func (store *Storage) Keccak(data ...[]byte) ([]byte, error) { +func (s *Storage) Keccak(data ...[]byte) ([]byte, error) { byteCount := 0 for _, part := range data { byteCount += len(part) } cost := 30 + 6*arbmath.WordsForBytes(uint64(byteCount)) - if err := store.burner.Burn(cost); err != nil { + if err := s.burner.Burn(cost); err != nil { return nil, err } return crypto.Keccak256(data...), nil } -func (store *Storage) KeccakHash(data ...[]byte) (common.Hash, error) { - bytes, err := store.Keccak(data...) +func (s *Storage) KeccakHash(data ...[]byte) (common.Hash, error) { + bytes, err := s.Keccak(data...) return common.BytesToHash(bytes), err } +// Returns crypto.Keccak256 result for the given data +// If available the result is taken from hash cache +// otherwise crypto.Keccak256 is executed and its result is added to the cache and returned +// note: the method doesn't burn gas, as it's only intended for generating storage subspace keys and mapping slot addresses +// note: returned slice is not thread-safe +func (s *Storage) cachedKeccak(data ...[]byte) []byte { + if s.hashCache == nil { + return crypto.Keccak256(data...) + } + keyString := string(bytes.Join(data, []byte{})) + if hash, wasCached := s.hashCache.Get(keyString); wasCached { + return hash + } + hash := crypto.Keccak256(data...) + evicted := s.hashCache.Add(keyString, hash) + if evicted && cacheFullLogged.CompareAndSwap(false, true) { + log.Warn("Hash cache full, we didn't expect that. Some non-static storage keys may fill up the cache.") + } + return hash +} + type StorageSlot struct { account common.Address db vm.StateDB @@ -278,8 +329,8 @@ type StorageSlot struct { burner burn.Burner } -func (store *Storage) NewSlot(offset uint64) StorageSlot { - return StorageSlot{store.account, store.db, mapAddress(store.storageKey, util.UintToHash(offset)), store.burner} +func (s *Storage) NewSlot(offset uint64) StorageSlot { + return StorageSlot{s.account, s.db, s.mapAddress(util.UintToHash(offset)), s.burner} } func (ss *StorageSlot) Get() (common.Hash, error) { @@ -318,8 +369,8 @@ type StorageBackedInt64 struct { StorageSlot } -func (store *Storage) OpenStorageBackedInt64(offset uint64) StorageBackedInt64 { - return StorageBackedInt64{store.NewSlot(offset)} +func (s *Storage) OpenStorageBackedInt64(offset uint64) StorageBackedInt64 { + return StorageBackedInt64{s.NewSlot(offset)} } func (sbu *StorageBackedInt64) Get() (int64, error) { @@ -339,8 +390,8 @@ type StorageBackedBips struct { backing StorageBackedInt64 } -func (store *Storage) OpenStorageBackedBips(offset uint64) StorageBackedBips { - return StorageBackedBips{StorageBackedInt64{store.NewSlot(offset)}} +func (s *Storage) OpenStorageBackedBips(offset uint64) StorageBackedBips { + return StorageBackedBips{StorageBackedInt64{s.NewSlot(offset)}} } func (sbu *StorageBackedBips) Get() (arbmath.Bips, error) { @@ -356,8 +407,8 @@ type StorageBackedUint64 struct { StorageSlot } -func (store *Storage) OpenStorageBackedUint64(offset uint64) StorageBackedUint64 { - return StorageBackedUint64{store.NewSlot(offset)} +func (s *Storage) OpenStorageBackedUint64(offset uint64) StorageBackedUint64 { + return StorageBackedUint64{s.NewSlot(offset)} } func (sbu *StorageBackedUint64) Get() (uint64, error) { @@ -444,8 +495,8 @@ type StorageBackedBigUint struct { StorageSlot } -func (store *Storage) OpenStorageBackedBigUint(offset uint64) StorageBackedBigUint { - return StorageBackedBigUint{store.NewSlot(offset)} +func (s *Storage) OpenStorageBackedBigUint(offset uint64) StorageBackedBigUint { + return StorageBackedBigUint{s.NewSlot(offset)} } func (sbbu *StorageBackedBigUint) Get() (*big.Int, error) { @@ -483,8 +534,8 @@ type StorageBackedBigInt struct { StorageSlot } -func (store *Storage) OpenStorageBackedBigInt(offset uint64) StorageBackedBigInt { - return StorageBackedBigInt{store.NewSlot(offset)} +func (s *Storage) OpenStorageBackedBigInt(offset uint64) StorageBackedBigInt { + return StorageBackedBigInt{s.NewSlot(offset)} } func (sbbi *StorageBackedBigInt) Get() (*big.Int, error) { @@ -540,8 +591,8 @@ type StorageBackedAddress struct { StorageSlot } -func (store *Storage) OpenStorageBackedAddress(offset uint64) StorageBackedAddress { - return StorageBackedAddress{store.NewSlot(offset)} +func (s *Storage) OpenStorageBackedAddress(offset uint64) StorageBackedAddress { + return StorageBackedAddress{s.NewSlot(offset)} } func (sba *StorageBackedAddress) Get() (common.Address, error) { @@ -563,8 +614,8 @@ func init() { NilAddressRepresentation = common.BigToHash(new(big.Int).Lsh(big.NewInt(1), 255)) } -func (store *Storage) OpenStorageBackedAddressOrNil(offset uint64) StorageBackedAddressOrNil { - return StorageBackedAddressOrNil{store.NewSlot(offset)} +func (s *Storage) OpenStorageBackedAddressOrNil(offset uint64) StorageBackedAddressOrNil { + return StorageBackedAddressOrNil{s.NewSlot(offset)} } func (sba *StorageBackedAddressOrNil) Get() (*common.Address, error) { @@ -588,9 +639,9 @@ type StorageBackedBytes struct { Storage } -func (store *Storage) OpenStorageBackedBytes(id []byte) StorageBackedBytes { +func (s *Storage) OpenStorageBackedBytes(id []byte) StorageBackedBytes { return StorageBackedBytes{ - *store.OpenSubStorage(id), + *s.OpenSubStorage(id), } } diff --git a/arbos/storage/storage_test.go b/arbos/storage/storage_test.go index 35e6b7c4be..a8d424d14e 100644 --- a/arbos/storage/storage_test.go +++ b/arbos/storage/storage_test.go @@ -1,11 +1,16 @@ package storage import ( + "bytes" + "fmt" "math/big" + "math/rand" + "sync" "testing" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/crypto" "github.com/offchainlabs/nitro/arbos/burn" "github.com/offchainlabs/nitro/util/arbmath" ) @@ -91,3 +96,73 @@ func TestStorageBackedBigInt(t *testing.T) { }) } } + +func TestOpenCachedSubStorage(t *testing.T) { + s := NewMemoryBacked(burn.NewSystemBurner(nil, false)) + var subSpaceIDs [][]byte + for i := 0; i < 20; i++ { + subSpaceIDs = append(subSpaceIDs, []byte{byte(rand.Intn(0xff))}) + } + var expectedKeys [][]byte + for _, subSpaceID := range subSpaceIDs { + expectedKeys = append(expectedKeys, crypto.Keccak256(s.storageKey, subSpaceID)) + } + n := len(subSpaceIDs) * 50 + start := make(chan struct{}) + errs := make(chan error, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + j := i % len(subSpaceIDs) + subSpaceID, expectedKey := subSpaceIDs[j], expectedKeys[j] + wg.Add(1) + go func() { + defer wg.Done() + <-start + ss := s.OpenCachedSubStorage(subSpaceID) + if !bytes.Equal(ss.storageKey, expectedKey) { + errs <- fmt.Errorf("unexpected storage key, want: %v, have: %v", expectedKey, ss.storageKey) + } + }() + } + close(start) + wg.Wait() + select { + case err := <-errs: + t.Fatal(err) + default: + } +} + +func TestMapAddressCache(t *testing.T) { + s := NewMemoryBacked(burn.NewSystemBurner(nil, false)) + var keys []common.Hash + for i := 0; i < 20; i++ { + keys = append(keys, common.BytesToHash([]byte{byte(rand.Intn(0xff))})) + } + var expectedMapped []common.Hash + for _, key := range keys { + expectedMapped = append(expectedMapped, s.mapAddress(key)) + } + n := len(keys) * 50 + start := make(chan struct{}) + errs := make(chan error, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + j := i % len(keys) + key, expected := keys[j], expectedMapped[j] + wg.Add(1) + go func() { + defer wg.Done() + <-start + mapped := s.mapAddress(key) + if !bytes.Equal(mapped.Bytes(), expected.Bytes()) { + errs <- fmt.Errorf("unexpected storage key, want: %v, have: %v", expected, mapped) + } + }() + } + close(start) + wg.Wait() + if len(errs) > 0 { + t.Fatal(<-errs) + } +}