diff --git a/portalnetwork/state/storage.go b/portalnetwork/state/storage.go index d07b2c3d0783..dae8a14d1da1 100644 --- a/portalnetwork/state/storage.go +++ b/portalnetwork/state/storage.go @@ -41,11 +41,11 @@ func (s *StateStorage) Put(contentKey []byte, contentId []byte, content []byte) keyType := contentKey[0] switch keyType { case AccountTrieNodeType: - return s.putAccountTrieNode(contentKey[1:], content) + return s.putAccountTrieNode(contentKey[1:], contentId, content) case ContractStorageTrieNodeType: - return s.putContractStorageTrieNode(contentKey[1:], content) + return s.putContractStorageTrieNode(contentKey[1:], contentId, content) case ContractByteCodeType: - return s.putContractBytecode(contentKey[1:], content) + return s.putContractBytecode(contentKey[1:], contentId, content) } return errors.New("unknown content type") } @@ -55,7 +55,7 @@ func (s *StateStorage) Radius() *uint256.Int { return s.store.Radius() } -func (s *StateStorage) putAccountTrieNode(contentKey []byte, content []byte) error { +func (s *StateStorage) putAccountTrieNode(contentKey []byte, contentId []byte, content []byte) error { accountKey := &AccountTrieNodeKey{} err := accountKey.Deserialize(codec.NewDecodingReader(bytes.NewReader(contentKey), uint64(len(contentKey)))) if err != nil { @@ -81,7 +81,6 @@ func (s *StateStorage) putAccountTrieNode(contentKey []byte, content []byte) err if err != nil { return err } - contentId := defaultContentIdFunc(contentKey) err = s.store.Put(contentId, contentId, contentValueBuf.Bytes()) if err != nil { s.log.Error("failed to save data after validate", "type", contentKey[0], "key", contentKey[1:], "value", content) @@ -89,7 +88,7 @@ func (s *StateStorage) putAccountTrieNode(contentKey []byte, content []byte) err return nil } -func (s *StateStorage) putContractStorageTrieNode(contentKey []byte, content []byte) error { +func (s *StateStorage) putContractStorageTrieNode(contentKey []byte, contentId []byte, content []byte) error { contractStorageKey := &ContractStorageTrieNodeKey{} err := contractStorageKey.Deserialize(codec.NewDecodingReader(bytes.NewReader(contentKey), uint64(len(contentKey)))) if err != nil { @@ -116,7 +115,6 @@ func (s *StateStorage) putContractStorageTrieNode(contentKey []byte, content []b if err != nil { return err } - contentId := defaultContentIdFunc(contentKey) err = s.store.Put(contentId, contentId, contentValueBuf.Bytes()) if err != nil { s.log.Error("failed to save data after validate", "type", contentKey[0], "key", contentKey[1:], "value", content) @@ -124,7 +122,7 @@ func (s *StateStorage) putContractStorageTrieNode(contentKey []byte, content []b return nil } -func (s *StateStorage) putContractBytecode(contentKey []byte, content []byte) error { +func (s *StateStorage) putContractBytecode(contentKey []byte, contentId []byte, content []byte) error { contractByteCodeKey := &ContractBytecodeKey{} err := contractByteCodeKey.Deserialize(codec.NewDecodingReader(bytes.NewReader(contentKey), uint64(len(contentKey)))) if err != nil { @@ -147,7 +145,6 @@ func (s *StateStorage) putContractBytecode(contentKey []byte, content []byte) er if err != nil { return err } - contentId := defaultContentIdFunc(contentKey) err = s.store.Put(contentId, contentId, contentValueBuf.Bytes()) if err != nil { s.log.Error("failed to save data after validate", "type", contentKey[0], "key", contentKey[1:], "value", content) diff --git a/portalnetwork/state/storage_test.go b/portalnetwork/state/storage_test.go index 0b1db78d8319..7212e14407f7 100644 --- a/portalnetwork/state/storage_test.go +++ b/portalnetwork/state/storage_test.go @@ -17,10 +17,10 @@ func TestStorage(t *testing.T) { require.NoError(t, err) for _, tt := range cases { contentKey := hexutil.MustDecode(tt.ContentKey) - contentId := defaultContentIdFunc(contentKey[1:]) + contentId := defaultContentIdFunc(contentKey) err = stateStorage.Put(contentKey, contentId, hexutil.MustDecode(tt.ContentValueOffer)) require.NoError(t, err) - res, err := stateStorage.Get(contentKey[1:], contentId) + res, err := stateStorage.Get(contentKey, contentId) require.NoError(t, err) require.Equal(t, hexutil.MustDecode(tt.ContentValueRetrieval), res) }