diff --git a/core/store/ledgerstore/block_store.go b/core/store/ledgerstore/block_store.go index 743500e553..b1b01c59b6 100644 --- a/core/store/ledgerstore/block_store.go +++ b/core/store/ledgerstore/block_store.go @@ -249,27 +249,13 @@ func (this *BlockStore) loadRawHeader(blockHash common.Uint256) (*types.RawHeade if err != nil { return nil, err } - height, err := extractHeaderHeight(source) + header := &types.RawHeader{} + err = header.Deserialization(source) if err != nil { return nil, err } - header := &types.RawHeader{ - Height: height, - Payload: value[8:], - } - return header, nil -} -func extractHeaderHeight(source *common.ZeroCopySource) (uint32, error) { - eof := source.Skip(4 + 32 + 32 + 32 + 4) - if eof { - return 0, fmt.Errorf("[exactHeaderHeight] NextUint32 error") - } - height, eof := source.NextUint32() - if eof { - return 0, fmt.Errorf("[exactHeaderHeight] NextUint32 error") - } - return height, nil + return header, nil } //GetCurrentBlock return the current block hash and current block height diff --git a/core/store/ledgerstore/block_store_test.go b/core/store/ledgerstore/block_store_test.go index 2157e8e4d6..0317eca6ec 100644 --- a/core/store/ledgerstore/block_store_test.go +++ b/core/store/ledgerstore/block_store_test.go @@ -47,9 +47,10 @@ func TestExtractHeaderHeight(t *testing.T) { sink := common.NewZeroCopySink(nil) header.Serialization(sink) source := common.NewZeroCopySource(sink.Bytes()) - height, err := extractHeaderHeight(source) + raw := types.RawHeader{} + err := raw.Deserialization(source) assert.Nil(t, err) - assert.Equal(t, uint32(99999), height) + assert.Equal(t, uint32(99999), raw.Height) } func TestVersion(t *testing.T) { diff --git a/core/types/header.go b/core/types/header.go index edca46d37f..73ba3cb426 100644 --- a/core/types/header.go +++ b/core/types/header.go @@ -31,6 +31,80 @@ type RawHeader struct { Payload []byte } +func (self *RawHeader) Serialization(sink *common.ZeroCopySink) { + sink.WriteBytes(self.Payload) +} + +// note: can only be called when source is trusted, like data from local ledger store +func (self *RawHeader) Deserialization(source *common.ZeroCopySource) error { + pstart := source.Pos() + err := self.deserializationUnsigned(source) + if err != nil { + return err + } + + n, _, irregular, eof := source.NextVarUint() + if eof { + return io.ErrUnexpectedEOF + } + if irregular { + return common.ErrIrregularData + } + + for i := 0; i < int(n); i++ { + _, _, irregular, eof := source.NextVarBytes() + if eof { + return io.ErrUnexpectedEOF + } + if irregular { + return common.ErrIrregularData + } + } + + m, _, irregular, eof := source.NextVarUint() + if eof { + return io.ErrUnexpectedEOF + } + if irregular { + return common.ErrIrregularData + } + + for i := 0; i < int(m); i++ { + _, _, irregular, eof := source.NextVarBytes() + if eof { + return io.ErrUnexpectedEOF + } + if irregular { + return common.ErrIrregularData + } + } + plen := source.Pos() - pstart + source.BackUp(plen) + self.Payload, _ = source.NextBytes(plen) + + return nil +} + +func (self *RawHeader) deserializationUnsigned(source *common.ZeroCopySource) error { + // version + preHash + tx root + block root + timestamp + source.Skip(4 + 32*3 + 4) + self.Height, _ = source.NextUint32() + //ConsensusData uint64 + source.Skip(8) + // ConsensusPayload + _, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData + } + + // next bookkeeper + eof = source.Skip(20) + if eof { + return io.ErrUnexpectedEOF + } + return nil +} + func (hd *Header) GetRawHeader() *RawHeader { sink := common.NewZeroCopySink(nil) hd.Serialization(sink) diff --git a/core/types/header_test.go b/core/types/header_test.go index 87bc4abb21..eace312f52 100644 --- a/core/types/header_test.go +++ b/core/types/header_test.go @@ -38,7 +38,14 @@ func TestHeader_Serialize(t *testing.T) { var h2 Header source := common.NewZeroCopySource(bs) err := h2.Deserialization(source) + assert.Nil(t, err) assert.Equal(t, fmt.Sprint(header), fmt.Sprint(h2)) + var h3 RawHeader + source = common.NewZeroCopySource(bs) + err = h3.Deserialization(source) assert.Nil(t, err) + assert.Equal(t, header.Height, h3.Height) + assert.Equal(t, bs, h3.Payload) + } diff --git a/p2pserver/message/types/block_header.go b/p2pserver/message/types/block_header.go index 2b081c97b8..45c7bfff9a 100644 --- a/p2pserver/message/types/block_header.go +++ b/p2pserver/message/types/block_header.go @@ -39,7 +39,7 @@ func (this *RawBlockHeader) Serialization(sink *common.ZeroCopySink) { sink.WriteUint32(uint32(len(this.BlkHdr))) for _, header := range this.BlkHdr { - sink.WriteBytes(header.Payload) + header.Serialization(sink) } } func (this *RawBlockHeader) Deserialization(source *common.ZeroCopySource) error {