diff --git a/chain/db.go b/chain/db.go index eb78fa9..2079e8e 100644 --- a/chain/db.go +++ b/chain/db.go @@ -13,54 +13,28 @@ import ( ) type supplementedBlock struct { - Header *types.BlockHeader + Header types.BlockHeader Block *types.Block Supplement *consensus.V1BlockSupplement } func (sb supplementedBlock) EncodeTo(e *types.Encoder) { - e.WriteUint8(3) - types.EncodePtr(e, sb.Header) + sb.Header.EncodeTo(e) types.EncodePtr(e, (*types.V2Block)(sb.Block)) types.EncodePtr(e, sb.Supplement) } func (sb *supplementedBlock) DecodeFrom(d *types.Decoder) { - switch v := d.ReadUint8(); v { - case 2: - sb.Header = nil - sb.Block = new(types.Block) - (*types.V2Block)(sb.Block).DecodeFrom(d) - types.DecodePtr(d, &sb.Supplement) - case 3: - types.DecodePtr(d, &sb.Header) - types.DecodePtrCast[types.V2Block](d, &sb.Block) - types.DecodePtr(d, &sb.Supplement) - default: - d.SetErr(fmt.Errorf("incompatible version (%d)", v)) - } -} - -type versionedState struct { - State consensus.State -} - -func (vs versionedState) EncodeTo(e *types.Encoder) { - e.WriteUint8(2) - vs.State.EncodeTo(e) -} - -func (vs *versionedState) DecodeFrom(d *types.Decoder) { - if v := d.ReadUint8(); v != 2 { - d.SetErr(fmt.Errorf("incompatible version (%d)", v)) - } - vs.State.DecodeFrom(d) + sb.Header.DecodeFrom(d) + types.DecodePtrCast[types.V2Block](d, &sb.Block) + types.DecodePtr(d, &sb.Supplement) } // A DB is a generic key-value database. type DB interface { Bucket(name []byte) DBBucket CreateBucket(name []byte) (DBBucket, error) + BucketKeys(name []byte) [][]byte Flush() error Cancel() } @@ -163,6 +137,18 @@ func (db *MemDB) CreateBucket(name []byte) (DBBucket, error) { return db.Bucket(name), nil } +// BucketKeys implements DB. +func (db *MemDB) BucketKeys(name []byte) [][]byte { + keys := make([][]byte, 0, len(db.buckets[string(name)])+len(db.puts[string(name)])) + for key := range db.buckets[string(name)] { + keys = append(keys, []byte(key)) + } + for key := range db.puts[string(name)] { + keys = append(keys, []byte(key)) + } + return keys +} + type memBucket struct { name string db *MemDB @@ -201,7 +187,7 @@ func (b *dbBucket) getRaw(key []byte) []byte { } func (b *dbBucket) get(key []byte, v types.DecoderFrom) bool { - val := b.getRaw(key) + val := b.getRaw(b.db.vkey(key)) if val == nil { return false } @@ -224,13 +210,17 @@ func (b *dbBucket) put(key []byte, v types.EncoderTo) { b.db.enc.Reset(&buf) v.EncodeTo(&b.db.enc) b.db.enc.Flush() - b.putRaw(key, buf.Bytes()) + b.putRaw(b.db.vkey(key), buf.Bytes()) } -func (b *dbBucket) delete(key []byte) { +func (b *dbBucket) deleteRaw(key []byte) { check(b.b.Delete(key)) } +func (b *dbBucket) delete(key []byte) { + b.deleteRaw(b.db.vkey(key)) +} + var ( bVersion = []byte("Version") bMainChain = []byte("MainChain") @@ -246,14 +236,21 @@ var ( // DBStore implements Store using a key-value database. type DBStore struct { - db DB - n *consensus.Network // for getState - enc types.Encoder + db DB + n *consensus.Network // for getState + version uint8 + keyBuf []byte + enc types.Encoder unflushed int lastFlush time.Time } +func (db *DBStore) vkey(key []byte) []byte { + db.keyBuf = append(db.keyBuf[:0], key...) + return append(db.keyBuf, db.version) +} + func (db *DBStore) bucket(name []byte) *dbBucket { return &dbBucket{db.db.Bucket(name), db} } @@ -272,55 +269,38 @@ func (db *DBStore) deleteBestIndex(height uint64) { } func (db *DBStore) getHeight() (height uint64) { - if val := db.bucket(bMainChain).getRaw(keyHeight); len(val) == 8 { + if val := db.bucket(bMainChain).getRaw(db.vkey(keyHeight)); len(val) == 8 { height = binary.BigEndian.Uint64(val) } return } func (db *DBStore) putHeight(height uint64) { - db.bucket(bMainChain).putRaw(keyHeight, db.encHeight(height)) + db.bucket(bMainChain).putRaw(db.vkey(keyHeight), db.encHeight(height)) } -func (db *DBStore) getState(id types.BlockID) (consensus.State, bool) { - var vs versionedState - ok := db.bucket(bStates).get(id[:], &vs) - vs.State.Network = db.n - return vs.State, ok +func (db *DBStore) getState(id types.BlockID) (cs consensus.State, ok bool) { + ok = db.bucket(bStates).get(id[:], &cs) + cs.Network = db.n + return } func (db *DBStore) putState(cs consensus.State) { - db.bucket(bStates).put(cs.Index.ID[:], versionedState{cs}) + db.bucket(bStates).put(cs.Index.ID[:], cs) } -func (db *DBStore) getBlock(id types.BlockID) (bh types.BlockHeader, b *types.Block, bs *consensus.V1BlockSupplement, _ bool) { - var sb supplementedBlock - if ok := db.bucket(bBlocks).get(id[:], &sb); !ok { - return types.BlockHeader{}, nil, nil, false - } else if sb.Header == nil { - sb.Header = new(types.BlockHeader) - *sb.Header = sb.Block.Header() - } - return *sb.Header, sb.Block, sb.Supplement, true +func (db *DBStore) getBlock(id types.BlockID) (sb supplementedBlock, ok bool) { + ok = db.bucket(bBlocks).get(id[:], &sb) + return } -func (db *DBStore) putBlock(bh types.BlockHeader, b *types.Block, bs *consensus.V1BlockSupplement) { - id := bh.ID() - db.bucket(bBlocks).put(id[:], supplementedBlock{&bh, b, bs}) +func (db *DBStore) putBlock(sb supplementedBlock) { + id := sb.Header.ID() + db.bucket(bBlocks).put(id[:], sb) } func (db *DBStore) getAncestorInfo(id types.BlockID) (parentID types.BlockID, timestamp time.Time, ok bool) { ok = db.bucket(bBlocks).get(id[:], types.DecoderFunc(func(d *types.Decoder) { - v := d.ReadUint8() - if v != 2 && v != 3 { - d.SetErr(fmt.Errorf("incompatible version (%d)", v)) - } - // kinda cursed; don't worry about it - if v == 3 { - if !d.ReadBool() { - d.ReadBool() - } - } parentID.DecodeFrom(d) _ = d.ReadUint64() // nonce timestamp = d.ReadTime() @@ -328,30 +308,6 @@ func (db *DBStore) getAncestorInfo(id types.BlockID) (parentID types.BlockID, ti return } -func (db *DBStore) getBlockHeader(id types.BlockID) (bh types.BlockHeader, ok bool) { - ok = db.bucket(bBlocks).get(id[:], types.DecoderFunc(func(d *types.Decoder) { - v := d.ReadUint8() - if v != 2 && v != 3 { - d.SetErr(fmt.Errorf("incompatible version (%d)", v)) - return - } - if v == 3 { - bhp := &bh - types.DecodePtr(d, &bhp) - if bhp != nil { - return - } else if !d.ReadBool() { - d.SetErr(errors.New("neither header nor block present")) - return - } - } - var b types.Block - (*types.V2Block)(&b).DecodeFrom(d) - bh = b.Header() - })) - return -} - func (db *DBStore) treeKey(row, col uint64) []byte { // If we assume that the total number of elements is less than 2^32, we can // pack row and col into one uint32 key. We do this by setting the top 'row' @@ -433,13 +389,13 @@ func (db *DBStore) deleteFileContractElement(id types.FileContractID) { func (db *DBStore) putFileContractExpiration(id types.FileContractID, windowEnd uint64) { b := db.bucket(bFileContractElements) - key := db.encHeight(windowEnd) + key := db.vkey(db.encHeight(windowEnd)) b.putRaw(key, append(b.getRaw(key), id[:]...)) } func (db *DBStore) deleteFileContractExpiration(id types.FileContractID, windowEnd uint64) { b := db.bucket(bFileContractElements) - key := db.encHeight(windowEnd) + key := db.vkey(db.encHeight(windowEnd)) val := append([]byte(nil), b.getRaw(key)...) for i := 0; i < len(val); i += 32 { if *(*types.FileContractID)(val[i:]) == id { @@ -625,7 +581,7 @@ func (db *DBStore) SupplementTipBlock(b types.Block) (bs consensus.V1BlockSupple for i, txn := range b.Transactions { bs.Transactions[i] = db.SupplementTipTransaction(txn) } - ids := db.bucket(bFileContractElements).getRaw(db.encHeight(db.getHeight() + 1)) + ids := db.bucket(bFileContractElements).getRaw(db.vkey(db.encHeight(db.getHeight() + 1))) for i := 0; i < len(ids); i += 32 { fce, ok := db.getFileContractElement(*(*types.FileContractID)(ids[i:]), numLeaves) if !ok { @@ -677,22 +633,24 @@ func (db *DBStore) AddState(cs consensus.State) { // Block implements Store. func (db *DBStore) Block(id types.BlockID) (types.Block, *consensus.V1BlockSupplement, bool) { - _, b, bs, ok := db.getBlock(id) - if !ok || b == nil { + sb, ok := db.getBlock(id) + if !ok || sb.Block == nil { return types.Block{}, nil, false } - return *b, bs, ok + return *sb.Block, sb.Supplement, ok } // AddBlock implements Store. func (db *DBStore) AddBlock(b types.Block, bs *consensus.V1BlockSupplement) { - db.putBlock(b.Header(), &b, bs) + db.putBlock(supplementedBlock{b.Header(), &b, bs}) } // PruneBlock implements Store. func (db *DBStore) PruneBlock(id types.BlockID) { - if bh, _, _, ok := db.getBlock(id); ok { - db.putBlock(bh, nil, nil) + if sb, ok := db.getBlock(id); ok { + sb.Block = nil + sb.Supplement = nil + db.putBlock(sb) } } @@ -757,13 +715,12 @@ func NewDBStore(db DB, n *consensus.Network, genesisBlock types.Block) (_ *DBSto } dbs := &DBStore{ - db: db, - n: n, + db: db, + n: n, + version: 2, } - - // if the db is empty, initialize it; otherwise, check that the genesis - // block is correct - if dbGenesis, ok := dbs.BestIndex(0); !ok { + if version := dbs.bucket(bVersion).getRaw(dbs.vkey(bVersion)); len(version) != 1 { + // initialize empty database for _, bucket := range [][]byte{ bVersion, bMainChain, @@ -778,29 +735,34 @@ func NewDBStore(db DB, n *consensus.Network, genesisBlock types.Block) (_ *DBSto panic(err) } } - dbs.bucket(bVersion).putRaw(bVersion, []byte{1}) + dbs.bucket(bVersion).putRaw(dbs.vkey(bVersion), []byte{dbs.version}) // store genesis state and apply genesis block to it genesisState := n.GenesisState() dbs.putState(genesisState) bs := consensus.V1BlockSupplement{Transactions: make([]consensus.V1TransactionSupplement, len(genesisBlock.Transactions))} cs, cau := consensus.ApplyBlock(genesisState, genesisBlock, bs, time.Time{}) - dbs.putBlock(genesisBlock.Header(), &genesisBlock, &bs) + dbs.putBlock(supplementedBlock{genesisBlock.Header(), &genesisBlock, &bs}) dbs.putState(cs) dbs.ApplyBlock(cs, cau) if err := dbs.Flush(); err != nil { return nil, consensus.State{}, err } - } else if dbGenesis.ID != genesisBlock.ID() { - // try to detect network so we can provide a more helpful error message - _, mainnetGenesis := Mainnet() - _, zenGenesis := TestnetZen() - if genesisBlock.ID() == mainnetGenesis.ID() && dbGenesis.ID == zenGenesis.ID() { - return nil, consensus.State{}, errors.New("cannot use Zen testnet database on mainnet") - } else if genesisBlock.ID() == zenGenesis.ID() && dbGenesis.ID == mainnetGenesis.ID() { - return nil, consensus.State{}, errors.New("cannot use mainnet database on Zen testnet") - } else { - return nil, consensus.State{}, errors.New("database previously initialized with different genesis block") + } else if version[0] != dbs.version { + return nil, consensus.State{}, errors.New("incompatible version; please migrate the database") + } else { + // verify the genesis block + if dbGenesis, ok := dbs.BestIndex(0); !ok || dbGenesis.ID != genesisBlock.ID() { + // try to detect network so we can provide a more helpful error message + _, mainnetGenesis := Mainnet() + _, zenGenesis := TestnetZen() + if genesisBlock.ID() == mainnetGenesis.ID() && dbGenesis.ID == zenGenesis.ID() { + return nil, consensus.State{}, errors.New("cannot use Zen testnet database on mainnet") + } else if genesisBlock.ID() == zenGenesis.ID() && dbGenesis.ID == mainnetGenesis.ID() { + return nil, consensus.State{}, errors.New("cannot use mainnet database on Zen testnet") + } else { + return nil, consensus.State{}, errors.New("database previously initialized with different genesis block") + } } } diff --git a/chain/migrate.go b/chain/migrate.go new file mode 100644 index 0000000..0350db6 --- /dev/null +++ b/chain/migrate.go @@ -0,0 +1,188 @@ +package chain + +import ( + "fmt" + + "go.sia.tech/core/consensus" + "go.sia.tech/core/types" +) + +type oldSiacoinElement types.SiacoinElement + +func (oldSiacoinElement) Cast() (sce types.SiacoinElement) { return } + +func (sce *oldSiacoinElement) DecodeFrom(d *types.Decoder) { + sce.ID.DecodeFrom(d) + sce.StateElement.DecodeFrom(d) + (*types.V2SiacoinOutput)(&sce.SiacoinOutput).DecodeFrom(d) + sce.MaturityHeight = d.ReadUint64() +} + +type oldSiafundElement types.SiafundElement + +func (oldSiafundElement) Cast() (sfe types.SiafundElement) { return } + +func (sfe *oldSiafundElement) DecodeFrom(d *types.Decoder) { + sfe.ID.DecodeFrom(d) + sfe.StateElement.DecodeFrom(d) + (*types.V2SiafundOutput)(&sfe.SiafundOutput).DecodeFrom(d) + (*types.V2Currency)(&sfe.ClaimStart).DecodeFrom(d) +} + +type oldFileContractElement types.FileContractElement + +func (oldFileContractElement) Cast() (fce types.FileContractElement) { return } + +func (fce *oldFileContractElement) DecodeFrom(d *types.Decoder) { + fce.ID.DecodeFrom(d) + fce.StateElement.DecodeFrom(d) + fce.FileContract.DecodeFrom(d) +} + +type oldTransactionSupplement consensus.V1TransactionSupplement + +func (oldTransactionSupplement) Cast() (ts consensus.V1TransactionSupplement) { return } + +func (ts *oldTransactionSupplement) DecodeFrom(d *types.Decoder) { + types.DecodeSliceCast[oldSiacoinElement](d, &ts.SiacoinInputs) + types.DecodeSliceCast[oldSiafundElement](d, &ts.SiafundInputs) + types.DecodeSliceCast[oldFileContractElement](d, &ts.RevisedFileContracts) + types.DecodeSliceFn(d, &ts.StorageProofs, func(d *types.Decoder) (sp consensus.V1StorageProofSupplement) { + (*oldFileContractElement)(&sp.FileContract).DecodeFrom(d) + return + }) +} + +type oldBlockSupplement consensus.V1BlockSupplement + +func (oldBlockSupplement) Cast() (bs consensus.V1BlockSupplement) { return } + +func (bs *oldBlockSupplement) DecodeFrom(d *types.Decoder) { + types.DecodeSliceCast[oldTransactionSupplement](d, &bs.Transactions) + types.DecodeSliceCast[oldFileContractElement](d, &bs.ExpiringFileContracts) +} + +type oldSupplementedBlock supplementedBlock + +func (sb *oldSupplementedBlock) DecodeFrom(d *types.Decoder) { + if v := d.ReadUint8(); v != 2 { + d.SetErr(fmt.Errorf("incompatible version (%d)", v)) + } + var b types.Block + (*types.V2Block)(&b).DecodeFrom(d) + sb.Block = &b + sb.Header = b.Header() + types.DecodePtrCast[oldBlockSupplement](d, &sb.Supplement) +} + +type versionedState consensus.State + +func (vs *versionedState) DecodeFrom(d *types.Decoder) { + if v := d.ReadUint8(); v != 2 { + d.SetErr(fmt.Errorf("incompatible version (%d)", v)) + } + (*consensus.State)(vs).DecodeFrom(d) +} + +// MigrateDB upgrades the database to the latest version. +func MigrateDB(db DB, n *consensus.Network) error { + if db.Bucket(bVersion) == nil { + return nil // nothing to migrate + } + dbs := &DBStore{ + db: db, + n: n, + version: 2, + } + + version := dbs.bucket(bVersion).getRaw(dbs.vkey(bVersion)) + if version == nil { + version = []byte{1} + } + switch version[0] { + case 1: + var err error + addVersion := func(bucket []byte, key []byte) { + if err != nil { + return + } + b := dbs.bucket(bucket) + b.putRaw(dbs.vkey(key), b.getRaw(key)) + b.deleteRaw(key) + } + rewrite := func(bucket []byte, key []byte, from types.DecoderFrom, to types.EncoderTo) { + if err != nil { + return + } + b := dbs.bucket(bucket) + val := b.getRaw(key) + d := types.NewBufDecoder(val) + from.DecodeFrom(d) + if d.Err() != nil { + err = d.Err() + return + } + b.deleteRaw(key) + b.put(key, to) + if dbs.shouldFlush() { + dbs.Flush() + } + } + + var sb supplementedBlock + for _, key := range db.BucketKeys(bBlocks) { + if len(key) == 32 { + rewrite(bBlocks, key, (*oldSupplementedBlock)(&sb), &sb) + } + } + var cs consensus.State + for _, key := range db.BucketKeys(bStates) { + if len(key) == 32 { + rewrite(bStates, key, (*versionedState)(&cs), &cs) + } + } + var sce types.SiacoinElement + for _, key := range db.BucketKeys(bSiacoinElements) { + if len(key) == 32 { + rewrite(bSiacoinElements, key, (*oldSiacoinElement)(&sce), &sce) + } + } + var sfe types.SiafundElement + for _, key := range db.BucketKeys(bSiafundElements) { + if len(key) == 32 { + rewrite(bSiafundElements, key, (*oldSiafundElement)(&sfe), &sfe) + } + } + var fce types.FileContractElement + for _, key := range db.BucketKeys(bFileContractElements) { + if len(key) == 32 { + rewrite(bFileContractElements, key, (*oldFileContractElement)(&fce), &fce) + } else if len(key) == 8 { + addVersion(bFileContractElements, key) + } + } + for _, key := range db.BucketKeys(bMainChain) { + if len(key) == 8 || len(key) == 5 { + addVersion(bMainChain, key) + } + } + for _, key := range db.BucketKeys(bTree) { + if len(key) == 4 { + addVersion(bTree, key) + } + } + dbs.bucket(bVersion).deleteRaw(bVersion) + dbs.bucket(bVersion).putRaw(dbs.vkey(bVersion), []byte{2}) + + if err != nil { + return err + } + dbs.Flush() + fallthrough + case dbs.version: + // up-to-date + return nil + default: + return fmt.Errorf("unrecognized version (%d)", version[0]) + } +} diff --git a/db.go b/db.go index 0578085..9b62374 100644 --- a/db.go +++ b/db.go @@ -42,6 +42,19 @@ func (db *BoltChainDB) CreateBucket(name []byte) (chain.DBBucket, error) { return db.tx.CreateBucket(name) } +// BucketKeys implements chain.DB. +func (db *BoltChainDB) BucketKeys(name []byte) [][]byte { + if err := db.newTx(); err != nil { + panic(err) + } + var keys [][]byte + c := db.tx.Bucket(name).Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + keys = append(keys, append([]byte(nil), k...)) + } + return keys +} + // Flush implements chain.DB. func (db *BoltChainDB) Flush() error { if db.tx == nil {