diff --git a/cmd/shisui/main.go b/cmd/shisui/main.go index e4d3065dacd3..6ad4880ffbdc 100644 --- a/cmd/shisui/main.go +++ b/cmd/shisui/main.go @@ -343,9 +343,10 @@ func initState(config Config, server *rpc.Server, conn discover.UDPConn, localNo if err != nil { return nil, err } + stateStore := state.NewStateStorage(contentStorage) contentQueue := make(chan *discover.ContentElement, 50) - protocol, err := discover.NewPortalProtocol(config.Protocol, portalwire.State, config.PrivateKey, conn, localNode, discV5, contentStorage, contentQueue) + protocol, err := discover.NewPortalProtocol(config.Protocol, portalwire.State, config.PrivateKey, conn, localNode, discV5, stateStore, contentQueue) if err != nil { return nil, err diff --git a/p2p/discover/portal_protocol_test.go b/p2p/discover/portal_protocol_test.go index 45eda7dc9a2a..065e3e804c52 100644 --- a/p2p/discover/portal_protocol_test.go +++ b/p2p/discover/portal_protocol_test.go @@ -243,6 +243,9 @@ func TestPortalWireProtocolUdp(t *testing.T) { assert.Equal(t, largeTestContent, data) }() workGroup.Wait() + node1.Stop() + node2.Stop() + node3.Stop() } func TestPortalWireProtocol(t *testing.T) { diff --git a/portalnetwork/state/network.go b/portalnetwork/state/network.go index be3c191d68a8..ebb8fccd9879 100644 --- a/portalnetwork/state/network.go +++ b/portalnetwork/state/network.go @@ -96,7 +96,10 @@ func (h *StateNetwork) validateContents(contentKeys [][]byte, contents [][]byte) return fmt.Errorf("content validate failed with content key %x and content %x", contentKey, content) } contentId := h.portalProtocol.ToContentId(contentKey) - _ = h.portalProtocol.Put(contentKey, contentId, content) + err = h.portalProtocol.Put(contentKey, contentId, content) + if err != nil { + return err + } } return nil } diff --git a/portalnetwork/state/storage.go b/portalnetwork/state/storage.go new file mode 100644 index 000000000000..d07b2c3d0783 --- /dev/null +++ b/portalnetwork/state/storage.go @@ -0,0 +1,156 @@ +package state + +import ( + "bytes" + "crypto/sha256" + "errors" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/portalnetwork/storage" + "github.com/holiman/uint256" + "github.com/protolambda/ztyp/codec" +) + +func defaultContentIdFunc(contentKey []byte) []byte { + digest := sha256.Sum256(contentKey) + return digest[:] +} + +var _ storage.ContentStorage = &StateStorage{} + +type StateStorage struct { + store storage.ContentStorage + log log.Logger +} + +func NewStateStorage(store storage.ContentStorage) *StateStorage { + return &StateStorage{ + store: store, + log: log.New("storage", "state"), + } +} + +// Get implements storage.ContentStorage. +func (s *StateStorage) Get(contentKey []byte, contentId []byte) ([]byte, error) { + return s.store.Get(contentKey, contentId) +} + +// Put implements storage.ContentStorage. +func (s *StateStorage) Put(contentKey []byte, contentId []byte, content []byte) error { + keyType := contentKey[0] + switch keyType { + case AccountTrieNodeType: + return s.putAccountTrieNode(contentKey[1:], content) + case ContractStorageTrieNodeType: + return s.putContractStorageTrieNode(contentKey[1:], content) + case ContractByteCodeType: + return s.putContractBytecode(contentKey[1:], content) + } + return errors.New("unknown content type") +} + +// Radius implements storage.ContentStorage. +func (s *StateStorage) Radius() *uint256.Int { + return s.store.Radius() +} + +func (s *StateStorage) putAccountTrieNode(contentKey []byte, content []byte) error { + accountKey := &AccountTrieNodeKey{} + err := accountKey.Deserialize(codec.NewDecodingReader(bytes.NewReader(contentKey), uint64(len(contentKey)))) + if err != nil { + return err + } + accountData := &AccountTrieNodeWithProof{} + err = accountData.Deserialize(codec.NewDecodingReader(bytes.NewReader(content), uint64(len(content)))) + if err != nil { + return err + } + length := len(accountData.Proof) + lastProof := accountData.Proof[length-1] + + lastNodeHash := crypto.Keccak256(lastProof) + if !bytes.Equal(lastNodeHash, accountKey.NodeHash[:]) { + return errors.New("hash of the trie node doesn't match key's node_hash") + } + lastTrieNode := &TrieNode{ + Node: lastProof, + } + var contentValueBuf bytes.Buffer + err = lastTrieNode.Serialize(codec.NewEncodingWriter(&contentValueBuf)) + if err != nil { + return err + } + contentId := defaultContentIdFunc(contentKey) + err = s.store.Put(contentId, contentId, contentValueBuf.Bytes()) + if err != nil { + s.log.Error("failed to save data after validate", "type", contentKey[0], "key", contentKey[1:], "value", content) + } + return nil +} + +func (s *StateStorage) putContractStorageTrieNode(contentKey []byte, content []byte) error { + contractStorageKey := &ContractStorageTrieNodeKey{} + err := contractStorageKey.Deserialize(codec.NewDecodingReader(bytes.NewReader(contentKey), uint64(len(contentKey)))) + if err != nil { + return err + } + contractProof := &ContractStorageTrieNodeWithProof{} + err = contractProof.Deserialize(codec.NewDecodingReader(bytes.NewReader(content), uint64(len(content)))) + if err != nil { + return err + } + length := len(contractProof.StoregeProof) + lastProof := contractProof.StoregeProof[length-1] + + lastNodeHash := crypto.Keccak256(lastProof) + if !bytes.Equal(lastNodeHash, contractStorageKey.NodeHash[:]) { + return errors.New("hash of the contract storage node doesn't match key's node hash") + } + + lastTrieNode := &TrieNode{ + Node: lastProof, + } + var contentValueBuf bytes.Buffer + err = lastTrieNode.Serialize(codec.NewEncodingWriter(&contentValueBuf)) + if err != nil { + return err + } + contentId := defaultContentIdFunc(contentKey) + err = s.store.Put(contentId, contentId, contentValueBuf.Bytes()) + if err != nil { + s.log.Error("failed to save data after validate", "type", contentKey[0], "key", contentKey[1:], "value", content) + } + return nil +} + +func (s *StateStorage) putContractBytecode(contentKey []byte, content []byte) error { + contractByteCodeKey := &ContractBytecodeKey{} + err := contractByteCodeKey.Deserialize(codec.NewDecodingReader(bytes.NewReader(contentKey), uint64(len(contentKey)))) + if err != nil { + return err + } + contractBytecodeWithProof := &ContractBytecodeWithProof{} + err = contractBytecodeWithProof.Deserialize(codec.NewDecodingReader(bytes.NewReader(content), uint64(len(content)))) + if err != nil { + return err + } + codeHash := crypto.Keccak256(contractBytecodeWithProof.Code) + if !bytes.Equal(codeHash, contractByteCodeKey.CodeHash[:]) { + return errors.New("hash of the contract byte doesn't match key's code hash") + } + container := &ContractBytecodeContainer{ + Code: contractBytecodeWithProof.Code, + } + var contentValueBuf bytes.Buffer + err = container.Serialize(codec.NewEncodingWriter(&contentValueBuf)) + if err != nil { + return err + } + contentId := defaultContentIdFunc(contentKey) + err = s.store.Put(contentId, contentId, contentValueBuf.Bytes()) + if err != nil { + s.log.Error("failed to save data after validate", "type", contentKey[0], "key", contentKey[1:], "value", content) + } + return nil +} diff --git a/portalnetwork/state/storage_test.go b/portalnetwork/state/storage_test.go new file mode 100644 index 000000000000..0b1db78d8319 --- /dev/null +++ b/portalnetwork/state/storage_test.go @@ -0,0 +1,28 @@ +package state + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/portalnetwork/storage" + "github.com/stretchr/testify/require" +) + +func TestStorage(t *testing.T) { + storage := storage.NewMockStorage() + stateStorage := NewStateStorage(storage) + testfiles := []string{"account_trie_node.yaml", "contract_storage_trie_node.yaml", "contract_bytecode.yaml"} + for _, file := range testfiles { + cases, err := getTestCases(file) + require.NoError(t, err) + for _, tt := range cases { + contentKey := hexutil.MustDecode(tt.ContentKey) + contentId := defaultContentIdFunc(contentKey[1:]) + err = stateStorage.Put(contentKey, contentId, hexutil.MustDecode(tt.ContentValueOffer)) + require.NoError(t, err) + res, err := stateStorage.Get(contentKey[1:], contentId) + require.NoError(t, err) + require.Equal(t, hexutil.MustDecode(tt.ContentValueRetrieval), res) + } + } +} diff --git a/portalnetwork/storage/content_storage.go b/portalnetwork/storage/content_storage.go index 21dce49f2e64..3a01df93f6ed 100644 --- a/portalnetwork/storage/content_storage.go +++ b/portalnetwork/storage/content_storage.go @@ -43,6 +43,12 @@ type MockStorage struct { Db map[string][]byte } +func NewMockStorage() ContentStorage { + return &MockStorage{ + Db: make(map[string][]byte), + } +} + func (m *MockStorage) Get(contentKey []byte, contentId []byte) ([]byte, error) { if content, ok := m.Db[string(contentId)]; ok { return content, nil