Skip to content

Commit

Permalink
chain: Implement DB migration
Browse files Browse the repository at this point in the history
  • Loading branch information
lukechampine authored and n8maninger committed Dec 18, 2024
1 parent 43f6ab8 commit 5545b6c
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 91 deletions.
129 changes: 38 additions & 91 deletions chain/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,54 +13,28 @@ import (
)

type supplementedBlock struct {
Header *types.BlockHeader
Header types.BlockHeader
Block *types.Block
Supplement *consensus.V1BlockSupplement
}

func (sb supplementedBlock) EncodeTo(e *types.Encoder) {
e.WriteUint8(3)
types.EncodePtr(e, sb.Header)
sb.Header.EncodeTo(e)
types.EncodePtr(e, (*types.V2Block)(sb.Block))
types.EncodePtr(e, sb.Supplement)
}

func (sb *supplementedBlock) DecodeFrom(d *types.Decoder) {
switch v := d.ReadUint8(); v {
case 2:
sb.Header = nil
sb.Block = new(types.Block)
(*types.V2Block)(sb.Block).DecodeFrom(d)
types.DecodePtr(d, &sb.Supplement)
case 3:
types.DecodePtr(d, &sb.Header)
types.DecodePtrCast[types.V2Block](d, &sb.Block)
types.DecodePtr(d, &sb.Supplement)
default:
d.SetErr(fmt.Errorf("incompatible version (%d)", v))
}
}

type versionedState struct {
State consensus.State
}

func (vs versionedState) EncodeTo(e *types.Encoder) {
e.WriteUint8(2)
vs.State.EncodeTo(e)
}

func (vs *versionedState) DecodeFrom(d *types.Decoder) {
if v := d.ReadUint8(); v != 2 {
d.SetErr(fmt.Errorf("incompatible version (%d)", v))
}
vs.State.DecodeFrom(d)
sb.Header.DecodeFrom(d)
types.DecodePtrCast[types.V2Block](d, &sb.Block)
types.DecodePtr(d, &sb.Supplement)
}

// A DB is a generic key-value database.
type DB interface {
Bucket(name []byte) DBBucket
CreateBucket(name []byte) (DBBucket, error)
BucketKeys(name []byte) [][]byte
Flush() error
Cancel()
}
Expand Down Expand Up @@ -163,6 +137,18 @@ func (db *MemDB) CreateBucket(name []byte) (DBBucket, error) {
return db.Bucket(name), nil
}

// BucketKeys implements DB.
func (db *MemDB) BucketKeys(name []byte) [][]byte {
keys := make([][]byte, 0, len(db.buckets[string(name)])+len(db.puts[string(name)]))
for key := range db.buckets[string(name)] {
keys = append(keys, []byte(key))
}
for key := range db.puts[string(name)] {
keys = append(keys, []byte(key))
}
return keys
}

type memBucket struct {
name string
db *MemDB
Expand Down Expand Up @@ -282,76 +268,35 @@ func (db *DBStore) putHeight(height uint64) {
db.bucket(bMainChain).putRaw(keyHeight, db.encHeight(height))
}

func (db *DBStore) getState(id types.BlockID) (consensus.State, bool) {
var vs versionedState
ok := db.bucket(bStates).get(id[:], &vs)
vs.State.Network = db.n
return vs.State, ok
func (db *DBStore) getState(id types.BlockID) (cs consensus.State, ok bool) {
ok = db.bucket(bStates).get(id[:], &cs)
cs.Network = db.n
return
}

func (db *DBStore) putState(cs consensus.State) {
db.bucket(bStates).put(cs.Index.ID[:], versionedState{cs})
db.bucket(bStates).put(cs.Index.ID[:], cs)
}

func (db *DBStore) getBlock(id types.BlockID) (bh types.BlockHeader, b *types.Block, bs *consensus.V1BlockSupplement, _ bool) {
var sb supplementedBlock
if ok := db.bucket(bBlocks).get(id[:], &sb); !ok {
return types.BlockHeader{}, nil, nil, false
} else if sb.Header == nil {
sb.Header = new(types.BlockHeader)
*sb.Header = sb.Block.Header()
}
return *sb.Header, sb.Block, sb.Supplement, true
func (db *DBStore) getBlock(id types.BlockID) (sb supplementedBlock, ok bool) {
ok = db.bucket(bBlocks).get(id[:], &sb)
return
}

func (db *DBStore) putBlock(bh types.BlockHeader, b *types.Block, bs *consensus.V1BlockSupplement) {
id := bh.ID()
db.bucket(bBlocks).put(id[:], supplementedBlock{&bh, b, bs})
func (db *DBStore) putBlock(sb supplementedBlock) {
id := sb.Header.ID()
db.bucket(bBlocks).put(id[:], sb)
}

func (db *DBStore) getAncestorInfo(id types.BlockID) (parentID types.BlockID, timestamp time.Time, ok bool) {
ok = db.bucket(bBlocks).get(id[:], types.DecoderFunc(func(d *types.Decoder) {
v := d.ReadUint8()
if v != 2 && v != 3 {
d.SetErr(fmt.Errorf("incompatible version (%d)", v))
}
// kinda cursed; don't worry about it
if v == 3 {
if !d.ReadBool() {
d.ReadBool()
}
}
parentID.DecodeFrom(d)
_ = d.ReadUint64() // nonce
timestamp = d.ReadTime()
}))
return
}

func (db *DBStore) getBlockHeader(id types.BlockID) (bh types.BlockHeader, ok bool) {
ok = db.bucket(bBlocks).get(id[:], types.DecoderFunc(func(d *types.Decoder) {
v := d.ReadUint8()
if v != 2 && v != 3 {
d.SetErr(fmt.Errorf("incompatible version (%d)", v))
return
}
if v == 3 {
bhp := &bh
types.DecodePtr(d, &bhp)
if bhp != nil {
return
} else if !d.ReadBool() {
d.SetErr(errors.New("neither header nor block present"))
return
}
}
var b types.Block
(*types.V2Block)(&b).DecodeFrom(d)
bh = b.Header()
}))
return
}

func (db *DBStore) treeKey(row, col uint64) []byte {
// If we assume that the total number of elements is less than 2^32, we can
// pack row and col into one uint32 key. We do this by setting the top 'row'
Expand Down Expand Up @@ -677,22 +622,24 @@ func (db *DBStore) AddState(cs consensus.State) {

// Block implements Store.
func (db *DBStore) Block(id types.BlockID) (types.Block, *consensus.V1BlockSupplement, bool) {
_, b, bs, ok := db.getBlock(id)
if !ok || b == nil {
sb, ok := db.getBlock(id)
if !ok || sb.Block == nil {
return types.Block{}, nil, false
}
return *b, bs, ok
return *sb.Block, sb.Supplement, ok
}

// AddBlock implements Store.
func (db *DBStore) AddBlock(b types.Block, bs *consensus.V1BlockSupplement) {
db.putBlock(b.Header(), &b, bs)
db.putBlock(supplementedBlock{b.Header(), &b, bs})
}

// PruneBlock implements Store.
func (db *DBStore) PruneBlock(id types.BlockID) {
if bh, _, _, ok := db.getBlock(id); ok {
db.putBlock(bh, nil, nil)
if sb, ok := db.getBlock(id); ok {
sb.Block = nil
sb.Supplement = nil
db.putBlock(sb)
}
}

Expand Down Expand Up @@ -785,7 +732,7 @@ func NewDBStore(db DB, n *consensus.Network, genesisBlock types.Block) (_ *DBSto
dbs.putState(genesisState)
bs := consensus.V1BlockSupplement{Transactions: make([]consensus.V1TransactionSupplement, len(genesisBlock.Transactions))}
cs, cau := consensus.ApplyBlock(genesisState, genesisBlock, bs, time.Time{})
dbs.putBlock(genesisBlock.Header(), &genesisBlock, &bs)
dbs.putBlock(supplementedBlock{genesisBlock.Header(), &genesisBlock, &bs})
dbs.putState(cs)
dbs.ApplyBlock(cs, cau)
if err := dbs.Flush(); err != nil {
Expand Down
159 changes: 159 additions & 0 deletions chain/migrate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package chain

import (
"errors"
"fmt"

"go.sia.tech/core/consensus"
"go.sia.tech/core/types"
)

type oldSiacoinElement types.SiacoinElement

func (oldSiacoinElement) Cast() (sce types.SiacoinElement) { return }

func (sce *oldSiacoinElement) DecodeFrom(d *types.Decoder) {
sce.ID.DecodeFrom(d)
sce.StateElement.DecodeFrom(d)
(*types.V2SiacoinOutput)(&sce.SiacoinOutput).DecodeFrom(d)
sce.MaturityHeight = d.ReadUint64()
}

type oldSiafundElement types.SiafundElement

func (oldSiafundElement) Cast() (sfe types.SiafundElement) { return }

func (sfe *oldSiafundElement) DecodeFrom(d *types.Decoder) {
sfe.ID.DecodeFrom(d)
sfe.StateElement.DecodeFrom(d)
(*types.V2SiafundOutput)(&sfe.SiafundOutput).DecodeFrom(d)
(*types.V2Currency)(&sfe.ClaimStart).DecodeFrom(d)
}

type oldFileContractElement types.FileContractElement

func (oldFileContractElement) Cast() (fce types.FileContractElement) { return }

func (fce *oldFileContractElement) DecodeFrom(d *types.Decoder) {
fce.ID.DecodeFrom(d)
fce.StateElement.DecodeFrom(d)
fce.FileContract.DecodeFrom(d)
}

type oldTransactionSupplement consensus.V1TransactionSupplement

func (oldTransactionSupplement) Cast() (ts consensus.V1TransactionSupplement) { return }

func (ts *oldTransactionSupplement) DecodeFrom(d *types.Decoder) {
types.DecodeSliceCast[oldSiacoinElement](d, &ts.SiacoinInputs)
types.DecodeSliceCast[oldSiafundElement](d, &ts.SiafundInputs)
types.DecodeSliceCast[oldFileContractElement](d, &ts.RevisedFileContracts)
types.DecodeSliceFn(d, &ts.StorageProofs, func(d *types.Decoder) (sp consensus.V1StorageProofSupplement) {
(*oldFileContractElement)(&sp.FileContract).DecodeFrom(d)
return
})
}

type oldBlockSupplement consensus.V1BlockSupplement

func (oldBlockSupplement) Cast() (bs consensus.V1BlockSupplement) { return }

func (bs *oldBlockSupplement) DecodeFrom(d *types.Decoder) {
types.DecodeSliceCast[oldTransactionSupplement](d, &bs.Transactions)
types.DecodeSliceCast[oldFileContractElement](d, &bs.ExpiringFileContracts)
}

type oldSupplementedBlock supplementedBlock

func (sb *oldSupplementedBlock) DecodeFrom(d *types.Decoder) {
if v := d.ReadUint8(); v != 2 {
d.SetErr(fmt.Errorf("incompatible version (%d)", v))
}
var b types.Block
(*types.V2Block)(&b).DecodeFrom(d)
sb.Block = &b
sb.Header = b.Header()
types.DecodePtrCast[oldBlockSupplement](d, &sb.Supplement)
}

type versionedState consensus.State

func (vs *versionedState) DecodeFrom(d *types.Decoder) {
if v := d.ReadUint8(); v != 2 {
d.SetErr(fmt.Errorf("incompatible version (%d)", v))
}
(*consensus.State)(vs).DecodeFrom(d)
}

// MigrateDB upgrades the database to the latest version.
func MigrateDB(db DB, n *consensus.Network) error {
if db.Bucket(bVersion) == nil {
return nil // nothing to migrate
}
dbs := &DBStore{
db: db,
n: n,
}
var err error
rewrite := func(bucket []byte, key []byte, from types.DecoderFrom, to types.EncoderTo) {
if err != nil {
return
}
b := dbs.bucket(bucket)
val := b.getRaw(key)
if val == nil {
return
}
d := types.NewBufDecoder(val)
from.DecodeFrom(d)
if d.Err() != nil {
err = d.Err()
return
}
b.put(key, to)
if dbs.shouldFlush() {
dbs.Flush()
}
}

version := dbs.bucket(bVersion).getRaw(bVersion)
if len(version) != 1 {
return errors.New("invalid version")
}
switch version[0] {
case 1:
var sb supplementedBlock
for _, key := range db.BucketKeys(bBlocks) {
rewrite(bBlocks, key, (*oldSupplementedBlock)(&sb), &sb)
}
var cs consensus.State
for _, key := range db.BucketKeys(bStates) {
rewrite(bStates, key, (*versionedState)(&cs), &cs)
}
var sce types.SiacoinElement
for _, key := range db.BucketKeys(bSiacoinElements) {
rewrite(bSiacoinElements, key, (*oldSiacoinElement)(&sce), &sce)
}
var sfe types.SiafundElement
for _, key := range db.BucketKeys(bSiafundElements) {
rewrite(bSiafundElements, key, (*oldSiafundElement)(&sfe), &sfe)
}
var fce types.FileContractElement
for _, key := range db.BucketKeys(bFileContractElements) {
if len(key) == 32 {
rewrite(bFileContractElements, key, (*oldFileContractElement)(&fce), &fce)
}
}
if err != nil {
return err
}
dbs.bucket(bVersion).putRaw(bVersion, []byte{2})
dbs.Flush()
fallthrough
case 2:
// up-to-date
return nil
default:
return fmt.Errorf("unrecognized version (%d)", version[0])
}
}

0 comments on commit 5545b6c

Please sign in to comment.