diff --git a/import.go b/import.go index 282e260ca..2379de6de 100644 --- a/import.go +++ b/import.go @@ -68,15 +68,27 @@ func (i *Importer) writeNode(node *Node) error { buf.Reset() defer bufPool.Put(buf) - if err := node.writeBytes(buf); err != nil { - return err + if i.tree.useLegacyFormat { + if err := node.writeLegacyBytes(buf); err != nil { + return err + } + } else { + if err := node.writeBytes(buf); err != nil { + return err + } } bytesCopy := make([]byte, buf.Len()) copy(bytesCopy, buf.Bytes()) - if err := i.batch.Set(i.tree.ndb.nodeKey(node.GetKey()), bytesCopy); err != nil { - return err + if i.tree.useLegacyFormat { + if err := i.batch.Set(i.tree.ndb.legacyNodeKey(node.GetKey()), bytesCopy); err != nil { + return err + } + } else { + if err := i.batch.Set(i.tree.ndb.nodeKey(node.GetKey()), bytesCopy); err != nil { + return err + } } i.batchSize++ diff --git a/mutable_tree.go b/mutable_tree.go index 596d8b22c..e3890c0f3 100644 --- a/mutable_tree.go +++ b/mutable_tree.go @@ -285,7 +285,7 @@ func (tree *MutableTree) set(key []byte, value []byte) (updated bool, err error) if !tree.skipFastStorageUpgrade { tree.addUnsavedAddition(key, fastnode.NewNode(key, value, tree.version+1)) } - tree.ImmutableTree.root = NewNode(key, value) + tree.ImmutableTree.root = NewNode(key, value, tree.useLegacyFormat) return updated, nil } @@ -344,7 +344,7 @@ func (tree *MutableTree) recursiveSetLeaf(node *Node, key []byte, value []byte) subtreeHeight: 1, size: 2, nodeKey: nil, - leftNode: NewNode(key, value), + leftNode: NewNode(key, value, tree.useLegacyFormat), rightNode: node, }, false, nil case 1: // setKey > leafKey @@ -354,10 +354,10 @@ func (tree *MutableTree) recursiveSetLeaf(node *Node, key []byte, value []byte) size: 2, nodeKey: nil, leftNode: node, - rightNode: NewNode(key, value), + rightNode: NewNode(key, value, tree.useLegacyFormat), }, false, nil default: - return NewNode(key, value), true, nil + return NewNode(key, value, tree.useLegacyFormat), true, nil } } @@ -805,17 +805,30 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) { } else { if tree.root.nodeKey != nil { // it means there are no updated nodes - if err := tree.ndb.SaveRoot(version, tree.root.nodeKey); err != nil { - return nil, 0, err - } - // it means the reference node is a legacy node - if tree.root.isLegacy { - // it will update the legacy node to the new format - // which ensures the reference node is not a legacy node - tree.root.isLegacy = tree.useLegacyFormat + if tree.useLegacyFormat { + if len(tree.root.hash) == 0 { + tree.root._hash(version) + } + if err := tree.ndb.SaveLegacyRoot(version, tree.root.hash); err != nil { + return nil, 0, err + } + tree.root.isLegacy = true if err := tree.ndb.SaveNode(tree.root); err != nil { return nil, 0, fmt.Errorf("failed to save the reference legacy node: %w", err) } + } else { + if err := tree.ndb.SaveRoot(version, tree.root.nodeKey); err != nil { + return nil, 0, err + } + // it means the reference node is a legacy node + if tree.root.isLegacy { + // it will update the legacy node to the new format + // which ensures the reference node is not a legacy node + tree.root.isLegacy = false + if err := tree.ndb.SaveNode(tree.root); err != nil { + return nil, 0, fmt.Errorf("failed to save the reference legacy node: %w", err) + } + } } } else { if err := tree.saveNewNodes(version); err != nil { @@ -1080,16 +1093,14 @@ func (tree *MutableTree) saveNewNodes(version int64) error { var recursiveAssignKey func(*Node) ([]byte, error) recursiveAssignKey = func(node *Node) ([]byte, error) { node.isLegacy = tree.useLegacyFormat - if node.nodeKey != nil || (node.isLegacy && node.hash != nil) { + if (!node.isLegacy && node.nodeKey != nil) || (node.isLegacy && node.hash != nil) { return node.GetKey(), nil } - if !node.isLegacy { - nonce++ - node.nodeKey = &NodeKey{ - version: version, - nonce: nonce, - } + nonce++ + node.nodeKey = &NodeKey{ + version: version, + nonce: nonce, } var err error diff --git a/node.go b/node.go index 06ffc9e58..b7aa51e3b 100644 --- a/node.go +++ b/node.go @@ -77,12 +77,13 @@ type Node struct { var _ cache.Node = (*Node)(nil) // NewNode returns a new node from a key, value and version. -func NewNode(key []byte, value []byte) *Node { +func NewNode(key []byte, value []byte, useLegacy bool) *Node { return &Node{ key: key, value: value, subtreeHeight: 0, size: 1, + isLegacy: useLegacy, } } @@ -243,11 +244,9 @@ func MakeLegacyNode(hash, buf []byte) (*Node, error) { nodeKey: &NodeKey{version: ver}, key: key, hash: hash, - isLegacy: true, } // Read node body. - if node.isLeaf() { val, _, err := encoding.DecodeBytes(buf) if err != nil { @@ -268,6 +267,7 @@ func MakeLegacyNode(hash, buf []byte) (*Node, error) { node.leftNodeKey = leftHash node.rightNodeKey = rightHash } + //fmt.Printf("legacy node: %+v\r\n", *node) return node, nil } @@ -612,6 +612,9 @@ func (node *Node) writeBytes(w io.Writer) error { if node.leftNodeKey == nil { return ErrLeftNodeKeyEmpty } + if node.rightNodeKey == nil { + return ErrRightNodeKeyEmpty + } // check if children NodeKeys are legacy mode if len(node.leftNodeKey) == hashSize { mode += ModeLegacyLeftNode @@ -662,6 +665,54 @@ func (node *Node) writeBytes(w io.Writer) error { return nil } +func (node *Node) writeLegacyBytes(w io.Writer) error { + if node == nil { + return errors.New("cannot write nil node") + } + err := encoding.EncodeVarint(w, int64(node.subtreeHeight)) + if err != nil { + return fmt.Errorf("writing height, %w", err) + } + err = encoding.EncodeVarint(w, node.size) + if err != nil { + return fmt.Errorf("writing size, %w", err) + } + err = encoding.EncodeVarint(w, node.nodeKey.version) + if err != nil { + return fmt.Errorf("writing version, %w", err) + } + + // Unlike writeHashBytes, key is written for inner nodes. + err = encoding.EncodeBytes(w, node.key) + if err != nil { + return fmt.Errorf("writing key, %w", err) + } + + if node.isLeaf() { + err = encoding.EncodeBytes(w, node.value) + if err != nil { + return fmt.Errorf("writing value, %w", err) + } + } else { + if len(node.leftNodeKey) != hashSize { + return errors.New("node provided to writeLegacyBytes does not have a hash for leftNodeKey") + } + err = encoding.EncodeBytes(w, node.leftNodeKey) + if err != nil { + return fmt.Errorf("writing left hash, %w", err) + } + + if len(node.leftNodeKey) != 32 { + return errors.New("node provided to writeLegacyBytes does not have a hash for rightNodeKey") + } + err = encoding.EncodeBytes(w, node.rightNodeKey) + if err != nil { + return fmt.Errorf("writing right hash, %w", err) + } + } + return nil +} + func (node *Node) getLeftNode(t *ImmutableTree) (*Node, error) { if node.leftNode != nil { return node.leftNode, nil diff --git a/nodedb.go b/nodedb.go index 334c4e588..2aae2cf4d 100644 --- a/nodedb.go +++ b/nodedb.go @@ -164,9 +164,9 @@ func (ndb *nodeDB) GetNode(nk []byte) (*Node, error) { ndb.opts.Stat.IncCacheMissCnt() // Doesn't exist, load. - isLegcyNode := len(nk) == hashSize + isLegacyNode := len(nk) == hashSize var nodeKey []byte - if isLegcyNode { + if isLegacyNode { nodeKey = ndb.legacyNodeKey(nk) } else { nodeKey = ndb.nodeKey(nk) @@ -180,11 +180,12 @@ func (ndb *nodeDB) GetNode(nk []byte) (*Node, error) { } var node *Node - if isLegcyNode { + if isLegacyNode { node, err = MakeLegacyNode(nk, buf) if err != nil { return nil, fmt.Errorf("error reading Legacy Node. bytes: %x, error: %v", buf, err) } + node.isLegacy = ndb.useLegacyFormat } else { node, err = MakeNode(nk, buf) if err != nil { @@ -238,7 +239,7 @@ func (ndb *nodeDB) SaveNode(node *Node) error { ndb.mtx.Lock() defer ndb.mtx.Unlock() - if (node.nodeKey == nil && !node.isLegacy) || (node.hash == nil && node.isLegacy) { + if node.nodeKey == nil || (ndb.useLegacyFormat && node.hash == nil) { return ErrNodeMissingNodeKey } @@ -246,12 +247,21 @@ func (ndb *nodeDB) SaveNode(node *Node) error { var buf bytes.Buffer buf.Grow(node.encodedSize()) - if err := node.writeBytes(&buf); err != nil { - return err - } - - if err := ndb.batch.Set(ndb.nodeKey(node.GetKey()), buf.Bytes()); err != nil { - return err + nk := node.GetKey() + if len(nk) == hashSize { + if err := node.writeLegacyBytes(&buf); err != nil { + return err + } + if err := ndb.batch.Set(ndb.legacyNodeKey(nk), buf.Bytes()); err != nil { + return err + } + } else { + if err := node.writeBytes(&buf); err != nil { + return err + } + if err := ndb.batch.Set(ndb.nodeKey(nk), buf.Bytes()); err != nil { + return err + } } ndb.logger.Debug("BATCH SAVE", "node", node) @@ -357,6 +367,9 @@ func (ndb *nodeDB) saveFastNodeUnlocked(node *fastnode.Node, shouldAddToCache bo // Has checks if a node key exists in the database. func (ndb *nodeDB) Has(nk []byte) (bool, error) { + if len(nk) == hashSize { + return ndb.db.Has(ndb.legacyNodeKey(nk)) + } return ndb.db.Has(ndb.nodeKey(nk)) } @@ -860,6 +873,13 @@ func (ndb *nodeDB) SaveRoot(version int64, nk *NodeKey) error { return ndb.batch.Set(nodeKeyFormat.Key(GetRootKey(version)), nodeKeyFormat.Key(nk.GetKey())) } +// SaveLegacyRoot saves the root when no updates. +func (ndb *nodeDB) SaveLegacyRoot(version int64, key []byte) error { + ndb.mtx.Lock() + defer ndb.mtx.Unlock() + return ndb.batch.Set(nodeKeyFormat.Key(GetRootKey(version)), legacyNodeKeyFormat.Key(key)) +} + // Traverse fast nodes and return error if any, nil otherwise func (ndb *nodeDB) traverseFastNodes(fn func(k, v []byte) error) error { return ndb.traversePrefix(fastKeyFormat.Key(), fn) diff --git a/testutils_test.go b/testutils_test.go index b64a54984..7bbd8f24c 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -52,12 +52,12 @@ func N(l, r interface{}) *Node { if _, ok := l.(*Node); ok { left = l.(*Node) } else { - left = NewNode(i2b(l.(int)), nil) + left = NewNode(i2b(l.(int)), nil, false) } if _, ok := r.(*Node); ok { right = r.(*Node) } else { - right = NewNode(i2b(r.(int)), nil) + right = NewNode(i2b(r.(int)), nil, false) } n := &Node{