diff --git a/go.work b/go.work new file mode 100644 index 0000000..b9dc24f --- /dev/null +++ b/go.work @@ -0,0 +1,6 @@ +go 1.21.5 + +use ( + . + ./merkle +) \ No newline at end of file diff --git a/go.work.sum b/go.work.sum new file mode 100644 index 0000000..c5b581a --- /dev/null +++ b/go.work.sum @@ -0,0 +1,3 @@ +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/tools v0.16.0/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= diff --git a/merkle/README.md b/merkle/README.md new file mode 100644 index 0000000..00d6124 --- /dev/null +++ b/merkle/README.md @@ -0,0 +1,6 @@ +# Merkle Tree + +For smaller static data structures that don't require immutable snapshots or mutability; +for instance the transactions and validation signatures of a block can be hashed using this simple merkle tree logic. + +This is a forked copy of github.com/cometbft/cometbft/crypto/merkle diff --git a/merkle/doc.go b/merkle/doc.go new file mode 100644 index 0000000..fe50b34 --- /dev/null +++ b/merkle/doc.go @@ -0,0 +1,30 @@ +/* +Package merkle computes a deterministic minimal height Merkle tree hash. +If the number of items is not a power of two, some leaves +will be at different levels. Tries to keep both sides of +the tree the same size, but the left may be one greater. + +Use this for short deterministic trees, such as the validator list. +For larger datasets, use IAVLTree. + +Be aware that the current implementation by itself does not prevent +second pre-image attacks. Hence, use this library with caution. +Otherwise you might run into similar issues as, e.g., in early Bitcoin: +https://bitcointalk.org/?topic=102395 + + * + / \ + / \ + / \ + / \ + * * + / \ / \ + / \ / \ + / \ / \ + * * * h6 + / \ / \ / \ + h0 h1 h2 h3 h4 h5 + +TODO(ismail): add 2nd pre-image protection or clarify further on how we use this and why this secure. 
+*/ +package merkle diff --git a/merkle/go.mod b/merkle/go.mod new file mode 100644 index 0000000..3acc43d --- /dev/null +++ b/merkle/go.mod @@ -0,0 +1,15 @@ +module github.com/celestiaorg/go-square/merkle + +go 1.21.5 + +require ( + github.com/stretchr/testify v1.8.4 + google.golang.org/protobuf v1.31.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/go-cmp v0.5.9 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/merkle/go.sum b/merkle/go.sum new file mode 100644 index 0000000..0d44549 --- /dev/null +++ b/merkle/go.sum @@ -0,0 +1,18 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 
h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/merkle/hash.go b/merkle/hash.go new file mode 100644 index 0000000..973629e --- /dev/null +++ b/merkle/hash.go @@ -0,0 +1,37 @@ +package merkle + +import ( + "crypto/sha256" +) + +// TODO: make these have a large predefined capacity +var ( + leafPrefix = []byte{0} + innerPrefix = []byte{1} +) + +// returns empty sha256 hash +func emptyHash() []byte { + return hash([]byte{}) +} + +// returns sha256(0x00 || leaf) +func leafHash(leaf []byte) []byte { + return hash(append(leafPrefix, leaf...)) +} + +// returns sha256(0x01 || left || right) +// The preimage is built in a fresh buffer: append(left, right...) would write +// into left's backing array whenever left has spare capacity. +func innerHash(left []byte, right []byte) []byte { + data := make([]byte, 0, len(innerPrefix)+len(left)+len(right)) + data = append(data, innerPrefix...) + data = append(data, left...) + data = append(data, right...) + return hash(data) +} + +func hash(bz []byte) []byte { + h := sha256.Sum256(bz) + return h[:] +} diff --git a/merkle/proof.go b/merkle/proof.go new file mode 100644 index 0000000..2491b47 --- /dev/null +++ b/merkle/proof.go @@ -0,0 +1,252 @@ +package merkle + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + + wire "github.com/celestiaorg/go-square/merkle/proto/gen/merkle/v1" +) + +const ( + // MaxAunts is the maximum number of aunts that can be included in a Proof. + // This corresponds to a tree of size 2^100, which should be sufficient for all conceivable purposes. + // This maximum helps prevent Denial-of-Service attacks by limiting the size of the proofs. + MaxAunts = 100 +) + +// Proof represents a Merkle proof. +// NOTE: The convention for proofs is to include leaf hashes but to +// exclude the root hash. +// This convention is implemented across IAVL range proofs as well. +// Keep this consistent unless there's a very good reason to change +// everything. 
This also affects the generalized proof system as +// well. +type Proof struct { + Total int64 `json:"total"` // Total number of items. + Index int64 `json:"index"` // Index of item to prove. + LeafHash []byte `json:"leaf_hash"` // Hash of item value. + Aunts [][]byte `json:"aunts"` // Hashes from leaf's sibling to a root's child. +} + +// ProofsFromByteSlices computes inclusion proof for given items. +// proofs[0] is the proof for items[0]. +func ProofsFromByteSlices(items [][]byte) (rootHash []byte, proofs []*Proof) { + trails, rootSPN := trailsFromByteSlices(items) + rootHash = rootSPN.Hash + proofs = make([]*Proof, len(items)) + for i, trail := range trails { + proofs[i] = &Proof{ + Total: int64(len(items)), + Index: int64(i), + LeafHash: trail.Hash, + Aunts: trail.FlattenAunts(), + } + } + return +} + +// Verify that the Proof proves the root hash. +// Check sp.Index/sp.Total manually if needed +func (sp *Proof) Verify(rootHash []byte, leaf []byte) error { + if rootHash == nil { + return fmt.Errorf("invalid root hash: cannot be nil") + } + if sp.Total < 0 { + return errors.New("proof total must be positive") + } + if sp.Index < 0 { + return errors.New("proof index cannot be negative") + } + leafHash := leafHash(leaf) + if !bytes.Equal(sp.LeafHash, leafHash) { + return fmt.Errorf("invalid leaf hash: wanted %X got %X", leafHash, sp.LeafHash) + } + computedHash, err := sp.computeRootHash() + if err != nil { + return fmt.Errorf("compute root hash: %w", err) + } + if !bytes.Equal(computedHash, rootHash) { + return fmt.Errorf("invalid root hash: wanted %X got %X", rootHash, computedHash) + } + return nil +} + +// Compute the root hash given a leaf hash. Panics in case of errors. +func (sp *Proof) ComputeRootHash() []byte { + computedHash, err := sp.computeRootHash() + if err != nil { + panic(fmt.Errorf("ComputeRootHash errored %w", err)) + } + return computedHash +} + +// Compute the root hash given a leaf hash. 
+func (sp *Proof) computeRootHash() ([]byte, error) { + return computeHashFromAunts( + sp.Index, + sp.Total, + sp.LeafHash, + sp.Aunts, + ) +} + +// String implements the stringer interface for Proof. +// It is a wrapper around StringIndented. +func (sp *Proof) String() string { + return sp.StringIndented("") +} + +// StringIndented generates a canonical string representation of a Proof. +func (sp *Proof) StringIndented(indent string) string { + return fmt.Sprintf(`Proof{ +%s Aunts: %X +%s}`, + indent, sp.Aunts, + indent) +} + +// ValidateBasic performs basic validation. +// NOTE: it expects the LeafHash and the elements of Aunts to be of size tmhash.Size, +// and it expects at most MaxAunts elements in Aunts. +func (sp *Proof) ValidateBasic() error { + if sp.Total < 0 { + return errors.New("negative Total") + } + if sp.Index < 0 { + return errors.New("negative Index") + } + if len(sp.LeafHash) != sha256.Size { + return fmt.Errorf("expected LeafHash size to be %d, got %d", sha256.Size, len(sp.LeafHash)) + } + if len(sp.Aunts) > MaxAunts { + return fmt.Errorf("expected no more than %d aunts, got %d", MaxAunts, len(sp.Aunts)) + } + for i, auntHash := range sp.Aunts { + if len(auntHash) != sha256.Size { + return fmt.Errorf("expected Aunts#%d size to be %d, got %d", i, sha256.Size, len(auntHash)) + } + } + return nil +} + +func (sp *Proof) ToProto() *wire.Proof { + if sp == nil { + return nil + } + pb := new(wire.Proof) + + pb.Total = sp.Total + pb.Index = sp.Index + pb.LeafHash = sp.LeafHash + pb.Aunts = sp.Aunts + + return pb +} + +func ProofFromProto(pb *wire.Proof) (*Proof, error) { + if pb == nil { + return nil, errors.New("nil proof") + } + + sp := new(Proof) + + sp.Total = pb.Total + sp.Index = pb.Index + sp.LeafHash = pb.LeafHash + sp.Aunts = pb.Aunts + + return sp, sp.ValidateBasic() +} + +// Use the leafHash and innerHashes to get the root merkle hash. +// If the length of the innerHashes slice isn't exactly correct, the result is nil. +// Recursive impl. 
+func computeHashFromAunts(index, total int64, leafHash []byte, innerHashes [][]byte) ([]byte, error) { + if index >= total || index < 0 || total <= 0 { + return nil, fmt.Errorf("invalid index %d and/or total %d", index, total) + } + switch total { + case 0: + panic("Cannot call computeHashFromAunts() with 0 total") + case 1: + if len(innerHashes) != 0 { + return nil, fmt.Errorf("unexpected inner hashes") + } + return leafHash, nil + default: + if len(innerHashes) == 0 { + return nil, fmt.Errorf("expected at least one inner hash") + } + numLeft := getSplitPoint(total) + if index < numLeft { + leftHash, err := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if err != nil { + return nil, err + } + + return innerHash(leftHash, innerHashes[len(innerHashes)-1]), nil + } + rightHash, err := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if err != nil { + return nil, err + } + return innerHash(innerHashes[len(innerHashes)-1], rightHash), nil + } +} + +// ProofNode is a helper structure to construct merkle proof. +// The node and the tree is thrown away afterwards. +// Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil. +// node.Parent.Hash = hash(node.Hash, node.Right.Hash) or +// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. +type ProofNode struct { + Hash []byte + Parent *ProofNode + Left *ProofNode // Left sibling (only one of Left,Right is set) + Right *ProofNode // Right sibling (only one of Left,Right is set) +} + +// FlattenAunts will return the inner hashes for the item corresponding to the leaf, +// starting from a leaf ProofNode. +func (spn *ProofNode) FlattenAunts() [][]byte { + // Nonrecursive impl. 
+ innerHashes := [][]byte{} + for spn != nil { + switch { + case spn.Left != nil: + innerHashes = append(innerHashes, spn.Left.Hash) + case spn.Right != nil: + innerHashes = append(innerHashes, spn.Right.Hash) + default: + break + } + spn = spn.Parent + } + return innerHashes +} + +// trails[0].Hash is the leaf hash for items[0]. +// trails[i].Parent.Parent....Parent == root for all i. +func trailsFromByteSlices(items [][]byte) (trails []*ProofNode, root *ProofNode) { + // Recursive impl. + switch len(items) { + case 0: + return []*ProofNode{}, &ProofNode{emptyHash(), nil, nil, nil} + case 1: + trail := &ProofNode{leafHash(items[0]), nil, nil, nil} + return []*ProofNode{trail}, trail + default: + k := getSplitPoint(int64(len(items))) + lefts, leftRoot := trailsFromByteSlices(items[:k]) + rights, rightRoot := trailsFromByteSlices(items[k:]) + rootHash := innerHash(leftRoot.Hash, rightRoot.Hash) + root := &ProofNode{rootHash, nil, nil, nil} + leftRoot.Parent = root + leftRoot.Right = rightRoot + rightRoot.Parent = root + rightRoot.Left = leftRoot + return append(lefts, rights...), root + } +} diff --git a/merkle/proof_key_path.go b/merkle/proof_key_path.go new file mode 100644 index 0000000..ca8b5f0 --- /dev/null +++ b/merkle/proof_key_path.go @@ -0,0 +1,110 @@ +package merkle + +import ( + "encoding/hex" + "errors" + "fmt" + "net/url" + "strings" +) + +/* + + For generalized Merkle proofs, each layer of the proof may require an + optional key. The key may be encoded either by URL-encoding or + (upper-case) hex-encoding. + TODO: In the future, more encodings may be supported, like base32 (e.g. 
+ /32:) + + For example, for a Cosmos-SDK application where the first two proof layers + are ValueOps, and the third proof layer is an IAVLValueOp, the keys + might look like: + + 0: []byte("App") + 1: []byte("IBC") + 2: []byte{0x01, 0x02, 0x03} + + Assuming that we know that the first two layers are always ASCII texts, we + probably want to use URLEncoding for those, whereas the third layer will + require HEX encoding for efficient representation. + + var kp KeyPath + kp = kp.AppendKey([]byte("App"), KeyEncodingURL) + kp = kp.AppendKey([]byte("IBC"), KeyEncodingURL) + kp = kp.AppendKey([]byte{0x01, 0x02, 0x03}, KeyEncodingHex) + kp.String() // Should return "/App/IBC/x:010203" + + NOTE: Key paths must begin with a `/`. + + NOTE: All encodings *MUST* work compatibly, such that you can choose to use + whatever encoding, and the decoded keys will always be the same. In other + words, it's just as good to encode all three keys using URL encoding or HEX + encoding... it just wouldn't be optimal in terms of readability or space + efficiency. + + NOTE: Punycode will never be supported here, because not all values can be + decoded. For example, no string decodes to the string "xn--blah" in + Punycode. + +*/ + +type keyEncoding int + +const ( + KeyEncodingURL keyEncoding = iota + KeyEncodingHex + KeyEncodingMax // Number of known encodings. Used for testing +) + +type Key struct { + name []byte + enc keyEncoding +} + +type KeyPath []Key + +func (pth KeyPath) AppendKey(key []byte, enc keyEncoding) KeyPath { + return append(pth, Key{key, enc}) +} + +func (pth KeyPath) String() string { + res := "" + for _, key := range pth { + switch key.enc { + case KeyEncodingURL: + res += "/" + url.PathEscape(string(key.name)) + case KeyEncodingHex: + res += "/x:" + fmt.Sprintf("%X", key.name) + default: + panic("unexpected key encoding type") + } + } + return res +} + +// Decode a path to a list of keys. Path must begin with `/`. +// Each key must use a known encoding. 
+func KeyPathToKeys(path string) (keys [][]byte, err error) { + if path == "" || path[0] != '/' { + return nil, errors.New("key path string must start with a forward slash '/'") + } + parts := strings.Split(path[1:], "/") + keys = make([][]byte, len(parts)) + for i, part := range parts { + if strings.HasPrefix(part, "x:") { + hexPart := part[2:] + key, err := hex.DecodeString(hexPart) + if err != nil { + return nil, fmt.Errorf("decoding hex-encoded part #%d: /%s: %w", i, part, err) + } + keys[i] = key + } else { + key, err := url.PathUnescape(part) + if err != nil { + return nil, fmt.Errorf("decoding url-encoded part #%d: /%s: %w", i, part, err) + } + keys[i] = []byte(key) // TODO Test this with random bytes, I'm not sure that it works for arbitrary bytes... + } + } + return keys, nil +} diff --git a/merkle/proof_key_path_test.go b/merkle/proof_key_path_test.go new file mode 100644 index 0000000..25a61af --- /dev/null +++ b/merkle/proof_key_path_test.go @@ -0,0 +1,44 @@ +package merkle + +import ( + // it is ok to use math/rand here: we do not need a cryptographically secure random + // number generator here and we can run the tests a bit faster + crand "crypto/rand" + "math/rand" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestKeyPath(t *testing.T) { + var path KeyPath + keys := make([][]byte, 10) + alphanum := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + for d := 0; d < 1e4; d++ { + path = nil + + for i := range keys { + enc := keyEncoding(rand.Intn(int(KeyEncodingMax))) + keys[i] = make([]byte, rand.Uint32()%20) + switch enc { + case KeyEncodingURL: + for j := range keys[i] { + keys[i][j] = alphanum[rand.Intn(len(alphanum))] + } + case KeyEncodingHex: + _, _ = crand.Read(keys[i]) + default: + panic("Unexpected encoding") + } + path = path.AppendKey(keys[i], enc) + } + + res, err := KeyPathToKeys(path.String()) + require.Nil(t, err) + + for i, key := range keys { + require.Equal(t, key, res[i]) + } + } +} diff --git 
a/merkle/proof_op.go b/merkle/proof_op.go new file mode 100644 index 0000000..122b600 --- /dev/null +++ b/merkle/proof_op.go @@ -0,0 +1,205 @@ +package merkle + +import ( + "bytes" + "errors" + "fmt" + + wire "github.com/celestiaorg/go-square/merkle/proto/gen/merkle/v1" +) + +//---------------------------------------- +// ProofOp gets converted to an instance of ProofOperator: + +// ProofOperator is a layer for calculating intermediate Merkle roots +// when a series of Merkle trees are chained together. +// Run() takes leaf values from a tree and returns the Merkle +// root for the corresponding tree. It takes and returns a list of bytes +// to allow multiple leaves to be part of a single proof, for instance in a range proof. +// ProofOp() encodes the ProofOperator in a generic way so it can later be +// decoded with OpDecoder. +type ProofOperator interface { + Run([][]byte) ([][]byte, error) + GetKey() []byte + ProofOp() wire.ProofOp +} + +//---------------------------------------- +// Operations on a list of ProofOperators + +// ProofOperators is a slice of ProofOperator(s). 
+// Each operator will be applied to the input value sequentially +// and the last Merkle root will be verified with already known data +type ProofOperators []ProofOperator + +func (poz ProofOperators) VerifyValue(root []byte, keypath string, value []byte) (err error) { + return poz.Verify(root, keypath, [][]byte{value}) +} + +// Verify runs the operators in order, feeding each operator's output to the +// next, and checks that the first element of the final output equals root. +// Operator keys are matched against keypath from its last segment to its +// first; it is an error if any segment is left unconsumed or if no output +// remains to compare against root. +func (poz ProofOperators) Verify(root []byte, keypath string, args [][]byte) (err error) { + keys, err := KeyPathToKeys(keypath) + if err != nil { + return + } + + for i, op := range poz { + key := op.GetKey() + if len(key) != 0 { + if len(keys) == 0 { + return fmt.Errorf("key path has insufficient # of parts: expected no more keys but got %+v", string(key)) + } + lastKey := keys[len(keys)-1] + if !bytes.Equal(lastKey, key) { + return fmt.Errorf("key mismatch on operation #%d: expected %+v but got %+v", i, string(lastKey), string(key)) + } + keys = keys[:len(keys)-1] + } + args, err = op.Run(args) + if err != nil { + return + } + } + if len(args) == 0 { + return errors.New("proof operators produced no output to compare with root hash") + } + if !bytes.Equal(root, args[0]) { + return fmt.Errorf("calculated root hash is invalid: expected %X but got %X", root, args[0]) + } + if len(keys) != 0 { + return errors.New("keypath not consumed all") + } + return nil +} + +// VerifyFromKeys performs the same verification logic as the normal Verify +// method, except it does not perform any processing on the keypath. This is +// useful when using keys that have split or escape points as a part of the key. 
+func (poz ProofOperators) VerifyFromKeys(root []byte, keys [][]byte, args [][]byte) (err error) { + for i, op := range poz { + key := op.GetKey() + if len(key) != 0 { + if len(keys) == 0 { + return fmt.Errorf("key path has insufficient # of parts: expected no more keys but got %+v", string(key)) + } + lastKey := keys[len(keys)-1] + if !bytes.Equal(lastKey, key) { + return fmt.Errorf("key mismatch on operation #%d: expected %+v but got %+v", i, string(lastKey), string(key)) + } + keys = keys[:len(keys)-1] + } + args, err = op.Run(args) + if err != nil { + return + } + } + if len(args) == 0 { + return errors.New("proof operators produced no output to compare with root hash") + } + if !bytes.Equal(root, args[0]) { + return fmt.Errorf("calculated root hash is invalid: expected %X but got %X", root, args[0]) + } + if len(keys) != 0 { + return fmt.Errorf("keypath not consumed all: %s", string(bytes.Join(keys, []byte("/")))) + } + return nil +} + +//---------------------------------------- +// ProofRuntime - main entrypoint + +type OpDecoder func(*wire.ProofOp) (ProofOperator, error) + +type ProofRuntime struct { + decoders map[string]OpDecoder +} + +func NewProofRuntime() *ProofRuntime { + return &ProofRuntime{ + decoders: make(map[string]OpDecoder), + } +} + +func (prt *ProofRuntime) RegisterOpDecoder(typ string, dec OpDecoder) { + _, ok := prt.decoders[typ] + if ok { + panic("already registered for type " + typ) + } + prt.decoders[typ] = dec +} + +// Decode looks up the decoder registered for pop.Type and applies it to pop. +// It returns an error if pop is nil or no decoder is registered for its type. 
+func (prt *ProofRuntime) Decode(pop *wire.ProofOp) (ProofOperator, error) { + if pop == nil { + return nil, errors.New("nil ProofOp") + } + decoder := prt.decoders[pop.Type] + if decoder == nil { + return nil, fmt.Errorf("unrecognized proof type %v", pop.Type) + } + return decoder(pop) +} + +// DecodeProof decodes every op in proof, preserving order. +// It returns an error if proof is nil or any op fails to decode. +func (prt *ProofRuntime) DecodeProof(proof *wire.ProofOps) (ProofOperators, error) { + if proof == nil { + return nil, errors.New("nil ProofOps") + } + poz := make(ProofOperators, 0, len(proof.Ops)) + for _, pop := range proof.Ops { + operator, err := prt.Decode(pop) + if err != nil { + return nil, fmt.Errorf("decoding a proof operator: %w", err) + } + poz = append(poz, operator) + } + return poz, nil +} + 
+func (prt *ProofRuntime) VerifyValue(proof *wire.ProofOps, root []byte, keypath string, value []byte) (err error) { + return prt.Verify(proof, root, keypath, [][]byte{value}) +} + +func (prt *ProofRuntime) VerifyValueFromKeys(proof *wire.ProofOps, root []byte, keys [][]byte, value []byte) (err error) { + return prt.VerifyFromKeys(proof, root, keys, [][]byte{value}) +} + +// TODO In the long run we'll need a method of classification of ops, +// whether existence or absence or perhaps a third? +func (prt *ProofRuntime) VerifyAbsence(proof *wire.ProofOps, root []byte, keypath string) (err error) { + return prt.Verify(proof, root, keypath, nil) +} + +func (prt *ProofRuntime) Verify(proof *wire.ProofOps, root []byte, keypath string, args [][]byte) (err error) { + poz, err := prt.DecodeProof(proof) + if err != nil { + return fmt.Errorf("decoding proof: %w", err) + } + return poz.Verify(root, keypath, args) +} + +// VerifyFromKeys performs the same verification logic as the normal Verify +// method, except it does not perform any processing on the keypath. This is +// useful when using keys that have split or escape points as a part of the key. +func (prt *ProofRuntime) VerifyFromKeys(proof *wire.ProofOps, root []byte, keys [][]byte, args [][]byte) (err error) { + poz, err := prt.DecodeProof(proof) + if err != nil { + return fmt.Errorf("decoding proof: %w", err) + } + return poz.VerifyFromKeys(root, keys, args) +} + +// DefaultProofRuntime only knows about value proofs. +// To use e.g. IAVL proofs, register op-decoders as +// defined in the IAVL package. 
+func DefaultProofRuntime() (prt *ProofRuntime) { + prt = NewProofRuntime() + prt.RegisterOpDecoder(ProofOpValue, ValueOpDecoder) + return +} diff --git a/merkle/proof_test.go b/merkle/proof_test.go new file mode 100644 index 0000000..abaea7d --- /dev/null +++ b/merkle/proof_test.go @@ -0,0 +1,288 @@ +package merkle + +import ( + "bytes" + "errors" + "fmt" + "testing" + + wire "github.com/celestiaorg/go-square/merkle/proto/gen/merkle/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +const ProofOpDomino = "test:domino" + +// Expects given input, produces given output. +// Like the game dominos. +type DominoOp struct { + key string // unexported, may be empty + Input string + Output string +} + +func NewDominoOp(key, input, output string) DominoOp { + return DominoOp{ + key: key, + Input: input, + Output: output, + } +} + +func (dop DominoOp) ProofOp() wire.ProofOp { + dopb := &wire.DominoOp{ + Key: dop.key, + Input: dop.Input, + Output: dop.Output, + } + bz, err := proto.Marshal(dopb) + if err != nil { + panic(err) + } + + return wire.ProofOp{ + Type: ProofOpDomino, + Key: []byte(dop.key), + Data: bz, + } +} + +func (dop DominoOp) Run(input [][]byte) (output [][]byte, err error) { + if len(input) != 1 { + return nil, errors.New("expected input of length 1") + } + if string(input[0]) != dop.Input { + return nil, fmt.Errorf("expected input %v, got %v", + dop.Input, string(input[0])) + } + return [][]byte{[]byte(dop.Output)}, nil +} + +func (dop DominoOp) GetKey() []byte { + return []byte(dop.key) +} + +//---------------------------------------- + +func TestProofOperators(t *testing.T) { + var err error + + // ProofRuntime setup + // TODO test this somehow. 
+ + // ProofOperators setup + op1 := NewDominoOp("KEY1", "INPUT1", "INPUT2") + op2 := NewDominoOp("KEY2", "INPUT2", "INPUT3") + op3 := NewDominoOp("", "INPUT3", "INPUT4") + op4 := NewDominoOp("KEY4", "INPUT4", "OUTPUT4") + + // Good + popz := ProofOperators([]ProofOperator{op1, op2, op3, op4}) + err = popz.Verify(bz("OUTPUT4"), "/KEY4/KEY2/KEY1", [][]byte{bz("INPUT1")}) + assert.Nil(t, err) + err = popz.VerifyValue(bz("OUTPUT4"), "/KEY4/KEY2/KEY1", bz("INPUT1")) + assert.Nil(t, err) + + // BAD INPUT + err = popz.Verify(bz("OUTPUT4"), "/KEY4/KEY2/KEY1", [][]byte{bz("INPUT1_WRONG")}) + assert.NotNil(t, err) + err = popz.VerifyValue(bz("OUTPUT4"), "/KEY4/KEY2/KEY1", bz("INPUT1_WRONG")) + assert.NotNil(t, err) + + // BAD KEY 1 + err = popz.Verify(bz("OUTPUT4"), "/KEY3/KEY2/KEY1", [][]byte{bz("INPUT1")}) + assert.NotNil(t, err) + + // BAD KEY 2 + err = popz.Verify(bz("OUTPUT4"), "KEY4/KEY2/KEY1", [][]byte{bz("INPUT1")}) + assert.NotNil(t, err) + + // BAD KEY 3 + err = popz.Verify(bz("OUTPUT4"), "/KEY4/KEY2/KEY1/", [][]byte{bz("INPUT1")}) + assert.NotNil(t, err) + + // BAD KEY 4 + err = popz.Verify(bz("OUTPUT4"), "//KEY4/KEY2/KEY1", [][]byte{bz("INPUT1")}) + assert.NotNil(t, err) + + // BAD KEY 5 + err = popz.Verify(bz("OUTPUT4"), "/KEY2/KEY1", [][]byte{bz("INPUT1")}) + assert.NotNil(t, err) + + // BAD OUTPUT 1 + err = popz.Verify(bz("OUTPUT4_WRONG"), "/KEY4/KEY2/KEY1", [][]byte{bz("INPUT1")}) + assert.NotNil(t, err) + + // BAD OUTPUT 2 + err = popz.Verify(bz(""), "/KEY4/KEY2/KEY1", [][]byte{bz("INPUT1")}) + assert.NotNil(t, err) + + // BAD POPZ 1 + popz = []ProofOperator{op1, op2, op4} + err = popz.Verify(bz("OUTPUT4"), "/KEY4/KEY2/KEY1", [][]byte{bz("INPUT1")}) + assert.NotNil(t, err) + + // BAD POPZ 2 + popz = []ProofOperator{op4, op3, op2, op1} + err = popz.Verify(bz("OUTPUT4"), "/KEY4/KEY2/KEY1", [][]byte{bz("INPUT1")}) + assert.NotNil(t, err) + + // BAD POPZ 3 + popz = []ProofOperator{} + err = popz.Verify(bz("OUTPUT4"), "/KEY4/KEY2/KEY1", [][]byte{bz("INPUT1")}) + 
assert.NotNil(t, err) +} + +func TestProofOperatorsFromKeys(t *testing.T) { + var err error + + // ProofRuntime setup + // TODO test this somehow. + + // ProofOperators setup + op1 := NewDominoOp("KEY1", "INPUT1", "INPUT2") + op2 := NewDominoOp("KEY%2", "INPUT2", "INPUT3") + op3 := NewDominoOp("", "INPUT3", "INPUT4") + op4 := NewDominoOp("KEY/4", "INPUT4", "OUTPUT4") + + // add characters to the keys that would otherwise result in bad keypath if + // processed + keys1 := [][]byte{bz("KEY/4"), bz("KEY%2"), bz("KEY1")} + badkeys1 := [][]byte{bz("WrongKey"), bz("KEY%2"), bz("KEY1")} + keys2 := [][]byte{bz("KEY3"), bz("KEY%2"), bz("KEY1")} + keys3 := [][]byte{bz("KEY2"), bz("KEY1")} + + // Good + popz := ProofOperators([]ProofOperator{op1, op2, op3, op4}) + err = popz.VerifyFromKeys(bz("OUTPUT4"), keys1, [][]byte{bz("INPUT1")}) + assert.NoError(t, err) + + // BAD INPUT + err = popz.VerifyFromKeys(bz("OUTPUT4"), keys1, [][]byte{bz("INPUT1_WRONG")}) + assert.Error(t, err) + + // BAD KEY 1 + err = popz.VerifyFromKeys(bz("OUTPUT4"), keys2, [][]byte{bz("INPUT1")}) + assert.Error(t, err) + + // BAD KEY 2 + err = popz.VerifyFromKeys(bz("OUTPUT4"), badkeys1, [][]byte{bz("INPUT1")}) + assert.Error(t, err) + + // BAD KEY 5 + err = popz.VerifyFromKeys(bz("OUTPUT4"), keys3, [][]byte{bz("INPUT1")}) + assert.Error(t, err) + + // BAD OUTPUT 1 + err = popz.VerifyFromKeys(bz("OUTPUT4_WRONG"), keys1, [][]byte{bz("INPUT1")}) + assert.Error(t, err) + + // BAD OUTPUT 2 + err = popz.VerifyFromKeys(bz(""), keys1, [][]byte{bz("INPUT1")}) + assert.Error(t, err) + + // BAD POPZ 1 + popz = []ProofOperator{op1, op2, op4} + err = popz.VerifyFromKeys(bz("OUTPUT4"), keys1, [][]byte{bz("INPUT1")}) + assert.Error(t, err) + + // BAD POPZ 2 + popz = []ProofOperator{op4, op3, op2, op1} + err = popz.VerifyFromKeys(bz("OUTPUT4"), keys1, [][]byte{bz("INPUT1")}) + assert.Error(t, err) + + // BAD POPZ 3 + popz = []ProofOperator{} + err = popz.VerifyFromKeys(bz("OUTPUT4"), keys1, [][]byte{bz("INPUT1")}) + 
assert.Error(t, err) +} + +func bz(s string) []byte { + return []byte(s) +} + +func TestProofValidateBasic(t *testing.T) { + testCases := []struct { + testName string + malleateProof func(*Proof) + errStr string + }{ + {"Good", func(sp *Proof) {}, ""}, + {"Negative Total", func(sp *Proof) { sp.Total = -1 }, "negative Total"}, + {"Negative Index", func(sp *Proof) { sp.Index = -1 }, "negative Index"}, + {"Invalid LeafHash", func(sp *Proof) { sp.LeafHash = make([]byte, 10) }, + "expected LeafHash size to be 32, got 10"}, + {"Too many Aunts", func(sp *Proof) { sp.Aunts = make([][]byte, MaxAunts+1) }, + "expected no more than 100 aunts, got 101"}, + {"Invalid Aunt", func(sp *Proof) { sp.Aunts[0] = make([]byte, 10) }, + "expected Aunts#0 size to be 32, got 10"}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + _, proofs := ProofsFromByteSlices([][]byte{ + []byte("apple"), + []byte("watermelon"), + []byte("kiwi"), + }) + tc.malleateProof(proofs[0]) + err := proofs[0].ValidateBasic() + if tc.errStr != "" { + assert.Contains(t, err.Error(), tc.errStr) + } + }) + } +} +func TestVoteProtobuf(t *testing.T) { + + _, proofs := ProofsFromByteSlices([][]byte{ + []byte("apple"), + []byte("watermelon"), + []byte("kiwi"), + }) + testCases := []struct { + testName string + v1 *Proof + expPass bool + }{ + {"empty proof", &Proof{}, false}, + {"failure nil", nil, false}, + {"success", proofs[0], true}, + } + for _, tc := range testCases { + pb := tc.v1.ToProto() + + v, err := ProofFromProto(pb) + if tc.expPass { + require.NoError(t, err) + require.Equal(t, tc.v1, v, tc.testName) + } else { + require.Error(t, err) + } + } +} + +// TestVsa2022_100 verifies https://blog.verichains.io/p/vsa-2022-100-tendermint-forging-membership-proof +func TestVsa2022_100(t *testing.T) { + // a fake key-value pair and its hash + key := []byte{0x13} + value := []byte{0x37} + vhash := hash(value) + bz := new(bytes.Buffer) + _ = encodeByteSlice(bz, key) + _ = 
encodeByteSlice(bz, vhash) + kvhash := hash(append([]byte{0}, bz.Bytes()...)) + + // the malicious `op` + op := NewValueOp( + key, + &Proof{LeafHash: kvhash}, + ) + + // the nil root + var root []byte + + assert.NotNil(t, ProofOperators{op}.Verify(root, "/"+string(key), [][]byte{value})) +} diff --git a/merkle/proof_value.go b/merkle/proof_value.go new file mode 100644 index 0000000..5ef8933 --- /dev/null +++ b/merkle/proof_value.go @@ -0,0 +1,108 @@ +package merkle + +import ( + "bytes" + "crypto/sha256" + "fmt" + + wire "github.com/celestiaorg/go-square/merkle/proto/gen/merkle/v1" + "google.golang.org/protobuf/proto" +) + +const ProofOpValue = "simple:v" + +// ValueOp takes a key and a single value as argument and +// produces the root hash. The corresponding tree structure is +// the SimpleMap tree. SimpleMap takes a Hasher, and currently +// CometBFT uses tmhash. SimpleValueOp should support +// the hash function as used in tmhash. TODO support +// additional hash functions here as options/args to this +// operator. +// +// If the produced root hash matches the expected hash, the +// proof is good. +type ValueOp struct { + // Encoded in ProofOp.Key. + key []byte + + // To encode in ProofOp.Data + Proof *Proof `json:"proof"` +} + +var _ ProofOperator = ValueOp{} + +func NewValueOp(key []byte, proof *Proof) ValueOp { + return ValueOp{ + key: key, + Proof: proof, + } +} + +func ValueOpDecoder(pop *wire.ProofOp) (ProofOperator, error) { + if pop.Type != ProofOpValue { + return nil, fmt.Errorf("unexpected ProofOp.Type; got %v, want %v", pop.Type, ProofOpValue) + } + var pbop wire.ValueOp // a bit strange as we'll discard this, but it works. 
+ err := proto.Unmarshal(pop.Data, &pbop) + if err != nil { + return nil, fmt.Errorf("decoding ProofOp.Data into ValueOp: %w", err) + } + + sp, err := ProofFromProto(pbop.Proof) + if err != nil { + return nil, err + } + return NewValueOp(pop.Key, sp), nil +} + +func (op ValueOp) ProofOp() wire.ProofOp { + pbval := &wire.ValueOp{ + Key: op.key, + Proof: op.Proof.ToProto(), + } + bz, err := proto.Marshal(pbval) + if err != nil { + panic(err) + } + return wire.ProofOp{ + Type: ProofOpValue, + Key: op.key, + Data: bz, + } +} + +func (op ValueOp) String() string { + return fmt.Sprintf("ValueOp{%v}", op.GetKey()) +} + +func (op ValueOp) Run(args [][]byte) ([][]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("expected 1 arg, got %v", len(args)) + } + value := args[0] + hasher := sha256.New() + hasher.Write(value) + vhash := hasher.Sum(nil) + + bz := new(bytes.Buffer) + // Wrap to hash the KVPair. + encodeByteSlice(bz, op.key) //nolint: errcheck // does not error + encodeByteSlice(bz, vhash) //nolint: errcheck // does not error + kvhash := leafHash(bz.Bytes()) + + if !bytes.Equal(kvhash, op.Proof.LeafHash) { + return nil, fmt.Errorf("leaf hash mismatch: want %X got %X", op.Proof.LeafHash, kvhash) + } + + rootHash, err := op.Proof.computeRootHash() + if err != nil { + return nil, err + } + return [][]byte{ + rootHash, + }, nil +} + +func (op ValueOp) GetKey() []byte { + return op.key +} diff --git a/merkle/proto/buf.gen.yaml b/merkle/proto/buf.gen.yaml new file mode 100644 index 0000000..d887b47 --- /dev/null +++ b/merkle/proto/buf.gen.yaml @@ -0,0 +1,6 @@ +version: v1 +plugins: + - plugin: buf.build/protocolbuffers/go + out: gen + opt: + - paths=source_relative \ No newline at end of file diff --git a/merkle/proto/gen/merkle/v1/proof.pb.go b/merkle/proto/gen/merkle/v1/proof.pb.go new file mode 100644 index 0000000..f26aa80 --- /dev/null +++ b/merkle/proto/gen/merkle/v1/proof.pb.go @@ -0,0 +1,478 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.31.0 +// protoc (unknown) +// source: merkle/v1/proof.proto + +package merkle + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Proof struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Total int64 `protobuf:"varint,1,opt,name=total,proto3" json:"total,omitempty"` + Index int64 `protobuf:"varint,2,opt,name=index,proto3" json:"index,omitempty"` + LeafHash []byte `protobuf:"bytes,3,opt,name=leaf_hash,json=leafHash,proto3" json:"leaf_hash,omitempty"` + Aunts [][]byte `protobuf:"bytes,4,rep,name=aunts,proto3" json:"aunts,omitempty"` +} + +func (x *Proof) Reset() { + *x = Proof{} + if protoimpl.UnsafeEnabled { + mi := &file_merkle_v1_proof_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Proof) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Proof) ProtoMessage() {} + +func (x *Proof) ProtoReflect() protoreflect.Message { + mi := &file_merkle_v1_proof_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Proof.ProtoReflect.Descriptor instead. 
+func (*Proof) Descriptor() ([]byte, []int) { + return file_merkle_v1_proof_proto_rawDescGZIP(), []int{0} +} + +func (x *Proof) GetTotal() int64 { + if x != nil { + return x.Total + } + return 0 +} + +func (x *Proof) GetIndex() int64 { + if x != nil { + return x.Index + } + return 0 +} + +func (x *Proof) GetLeafHash() []byte { + if x != nil { + return x.LeafHash + } + return nil +} + +func (x *Proof) GetAunts() [][]byte { + if x != nil { + return x.Aunts + } + return nil +} + +type ValueOp struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Encoded in ProofOp.Key. + Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + // To encode in ProofOp.Data + Proof *Proof `protobuf:"bytes,2,opt,name=proof,proto3" json:"proof,omitempty"` +} + +func (x *ValueOp) Reset() { + *x = ValueOp{} + if protoimpl.UnsafeEnabled { + mi := &file_merkle_v1_proof_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ValueOp) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ValueOp) ProtoMessage() {} + +func (x *ValueOp) ProtoReflect() protoreflect.Message { + mi := &file_merkle_v1_proof_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ValueOp.ProtoReflect.Descriptor instead. 
+func (*ValueOp) Descriptor() ([]byte, []int) { + return file_merkle_v1_proof_proto_rawDescGZIP(), []int{1} +} + +func (x *ValueOp) GetKey() []byte { + if x != nil { + return x.Key + } + return nil +} + +func (x *ValueOp) GetProof() *Proof { + if x != nil { + return x.Proof + } + return nil +} + +type DominoOp struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Input string `protobuf:"bytes,2,opt,name=input,proto3" json:"input,omitempty"` + Output string `protobuf:"bytes,3,opt,name=output,proto3" json:"output,omitempty"` +} + +func (x *DominoOp) Reset() { + *x = DominoOp{} + if protoimpl.UnsafeEnabled { + mi := &file_merkle_v1_proof_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DominoOp) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DominoOp) ProtoMessage() {} + +func (x *DominoOp) ProtoReflect() protoreflect.Message { + mi := &file_merkle_v1_proof_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DominoOp.ProtoReflect.Descriptor instead. 
+func (*DominoOp) Descriptor() ([]byte, []int) { + return file_merkle_v1_proof_proto_rawDescGZIP(), []int{2} +} + +func (x *DominoOp) GetKey() string { + if x != nil { + return x.Key + } + return "" +} + +func (x *DominoOp) GetInput() string { + if x != nil { + return x.Input + } + return "" +} + +func (x *DominoOp) GetOutput() string { + if x != nil { + return x.Output + } + return "" +} + +// ProofOp defines an operation used for calculating Merkle root +// The data could be arbitrary format, providing necessary data +// for example neighbouring node hash +type ProofOp struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` + Key []byte `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"` + Data []byte `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *ProofOp) Reset() { + *x = ProofOp{} + if protoimpl.UnsafeEnabled { + mi := &file_merkle_v1_proof_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ProofOp) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ProofOp) ProtoMessage() {} + +func (x *ProofOp) ProtoReflect() protoreflect.Message { + mi := &file_merkle_v1_proof_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ProofOp.ProtoReflect.Descriptor instead. 
+func (*ProofOp) Descriptor() ([]byte, []int) { + return file_merkle_v1_proof_proto_rawDescGZIP(), []int{3} +} + +func (x *ProofOp) GetType() string { + if x != nil { + return x.Type + } + return "" +} + +func (x *ProofOp) GetKey() []byte { + if x != nil { + return x.Key + } + return nil +} + +func (x *ProofOp) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +// ProofOps is Merkle proof defined by the list of ProofOps +type ProofOps struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Ops []*ProofOp `protobuf:"bytes,1,rep,name=ops,proto3" json:"ops,omitempty"` +} + +func (x *ProofOps) Reset() { + *x = ProofOps{} + if protoimpl.UnsafeEnabled { + mi := &file_merkle_v1_proof_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ProofOps) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ProofOps) ProtoMessage() {} + +func (x *ProofOps) ProtoReflect() protoreflect.Message { + mi := &file_merkle_v1_proof_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ProofOps.ProtoReflect.Descriptor instead. 
+func (*ProofOps) Descriptor() ([]byte, []int) { + return file_merkle_v1_proof_proto_rawDescGZIP(), []int{4} +} + +func (x *ProofOps) GetOps() []*ProofOp { + if x != nil { + return x.Ops + } + return nil +} + +var File_merkle_v1_proof_proto protoreflect.FileDescriptor + +var file_merkle_v1_proof_proto_rawDesc = []byte{ + 0x0a, 0x15, 0x6d, 0x65, 0x72, 0x6b, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x70, 0x72, 0x6f, 0x6f, + 0x66, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0d, 0x73, 0x71, 0x75, 0x61, 0x72, 0x65, 0x2e, + 0x6d, 0x65, 0x72, 0x6b, 0x6c, 0x65, 0x22, 0x66, 0x0a, 0x05, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x12, + 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, + 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x12, 0x1b, 0x0a, 0x09, 0x6c, + 0x65, 0x61, 0x66, 0x5f, 0x68, 0x61, 0x73, 0x68, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, + 0x6c, 0x65, 0x61, 0x66, 0x48, 0x61, 0x73, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x61, 0x75, 0x6e, 0x74, + 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x05, 0x61, 0x75, 0x6e, 0x74, 0x73, 0x22, 0x47, + 0x0a, 0x07, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x4f, 0x70, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x2a, 0x0a, 0x05, 0x70, + 0x72, 0x6f, 0x6f, 0x66, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x71, 0x75, + 0x61, 0x72, 0x65, 0x2e, 0x6d, 0x65, 0x72, 0x6b, 0x6c, 0x65, 0x2e, 0x50, 0x72, 0x6f, 0x6f, 0x66, + 0x52, 0x05, 0x70, 0x72, 0x6f, 0x6f, 0x66, 0x22, 0x4a, 0x0a, 0x08, 0x44, 0x6f, 0x6d, 0x69, 0x6e, + 0x6f, 0x4f, 0x70, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6f, + 0x75, 
0x74, 0x70, 0x75, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6f, 0x75, 0x74, + 0x70, 0x75, 0x74, 0x22, 0x43, 0x0a, 0x07, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x4f, 0x70, 0x12, 0x12, + 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x79, + 0x70, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x03, 0x6b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x6f, + 0x66, 0x4f, 0x70, 0x73, 0x12, 0x28, 0x0a, 0x03, 0x6f, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x73, 0x71, 0x75, 0x61, 0x72, 0x65, 0x2e, 0x6d, 0x65, 0x72, 0x6b, 0x6c, + 0x65, 0x2e, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x4f, 0x70, 0x52, 0x03, 0x6f, 0x70, 0x73, 0x42, 0x29, + 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x65, 0x6c, + 0x65, 0x73, 0x74, 0x69, 0x61, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x71, 0x75, 0x61, + 0x72, 0x65, 0x2f, 0x6d, 0x65, 0x72, 0x6b, 0x6c, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, +} + +var ( + file_merkle_v1_proof_proto_rawDescOnce sync.Once + file_merkle_v1_proof_proto_rawDescData = file_merkle_v1_proof_proto_rawDesc +) + +func file_merkle_v1_proof_proto_rawDescGZIP() []byte { + file_merkle_v1_proof_proto_rawDescOnce.Do(func() { + file_merkle_v1_proof_proto_rawDescData = protoimpl.X.CompressGZIP(file_merkle_v1_proof_proto_rawDescData) + }) + return file_merkle_v1_proof_proto_rawDescData +} + +var file_merkle_v1_proof_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_merkle_v1_proof_proto_goTypes = []interface{}{ + (*Proof)(nil), // 0: square.merkle.Proof + (*ValueOp)(nil), // 1: square.merkle.ValueOp + (*DominoOp)(nil), // 2: square.merkle.DominoOp + (*ProofOp)(nil), // 3: square.merkle.ProofOp + (*ProofOps)(nil), // 4: square.merkle.ProofOps +} +var 
file_merkle_v1_proof_proto_depIdxs = []int32{ + 0, // 0: square.merkle.ValueOp.proof:type_name -> square.merkle.Proof + 3, // 1: square.merkle.ProofOps.ops:type_name -> square.merkle.ProofOp + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_merkle_v1_proof_proto_init() } +func file_merkle_v1_proof_proto_init() { + if File_merkle_v1_proof_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_merkle_v1_proof_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Proof); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_merkle_v1_proof_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ValueOp); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_merkle_v1_proof_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DominoOp); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_merkle_v1_proof_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ProofOp); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_merkle_v1_proof_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ProofOps); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + 
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_merkle_v1_proof_proto_rawDesc, + NumEnums: 0, + NumMessages: 5, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_merkle_v1_proof_proto_goTypes, + DependencyIndexes: file_merkle_v1_proof_proto_depIdxs, + MessageInfos: file_merkle_v1_proof_proto_msgTypes, + }.Build() + File_merkle_v1_proof_proto = out.File + file_merkle_v1_proof_proto_rawDesc = nil + file_merkle_v1_proof_proto_goTypes = nil + file_merkle_v1_proof_proto_depIdxs = nil +} diff --git a/merkle/proto/merkle/v1/proof.proto b/merkle/proto/merkle/v1/proof.proto new file mode 100644 index 0000000..47fe221 --- /dev/null +++ b/merkle/proto/merkle/v1/proof.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; +package square.merkle; + +option go_package = "github.com/celestiaorg/go-square/merkle"; + +message Proof { + int64 total = 1; + int64 index = 2; + bytes leaf_hash = 3; + repeated bytes aunts = 4; +} + +message ValueOp { + // Encoded in ProofOp.Key. + bytes key = 1; + + // To encode in ProofOp.Data + Proof proof = 2; +} + +message DominoOp { + string key = 1; + string input = 2; + string output = 3; +} + +// ProofOp defines an operation used for calculating Merkle root +// The data could be arbitrary format, providing necessary data +// for example neighbouring node hash +message ProofOp { + string type = 1; + bytes key = 2; + bytes data = 3; +} + +// ProofOps is Merkle proof defined by the list of ProofOps +message ProofOps { + repeated ProofOp ops = 1; +} diff --git a/merkle/rfc6962_test.go b/merkle/rfc6962_test.go new file mode 100644 index 0000000..ae4b54a --- /dev/null +++ b/merkle/rfc6962_test.go @@ -0,0 +1,104 @@ +package merkle + +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// These tests were taken from https://github.com/google/trillian/blob/master/merkle/rfc6962/rfc6962_test.go, +// and consequently fall under the above license. +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "testing" +) + +func TestRFC6962Hasher(t *testing.T) { + _, leafHashTrail := trailsFromByteSlices([][]byte{[]byte("L123456")}) + leafHash := leafHashTrail.Hash + _, leafHashTrail = trailsFromByteSlices([][]byte{{}}) + emptyLeafHash := leafHashTrail.Hash + _, emptyHashTrail := trailsFromByteSlices([][]byte{}) + emptyTreeHash := emptyHashTrail.Hash + for _, tc := range []struct { + desc string + got []byte + want string + }{ + // Check that empty trees return the hash of an empty string. + // echo -n '' | sha256sum + { + desc: "RFC6962 Empty Tree", + want: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"[:sha256.Size*2], + got: emptyTreeHash, + }, + + // Check that the empty hash is not the same as the hash of an empty leaf. 
+ // echo -n 00 | xxd -r -p | sha256sum + { + desc: "RFC6962 Empty Leaf", + want: "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d"[:sha256.Size*2], + got: emptyLeafHash, + }, + // echo -n 004C313233343536 | xxd -r -p | sha256sum + { + desc: "RFC6962 Leaf", + want: "395aa064aa4c29f7010acfe3f25db9485bbd4b91897b6ad7ad547639252b4d56"[:sha256.Size*2], + got: leafHash, + }, + // echo -n 014E3132334E343536 | xxd -r -p | sha256sum + { + desc: "RFC6962 Node", + want: "aa217fe888e47007fa15edab33c2b492a722cb106c64667fc2b044444de66bbb"[:sha256.Size*2], + got: innerHash([]byte("N123"), []byte("N456")), + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + wantBytes, err := hex.DecodeString(tc.want) + if err != nil { + t.Fatalf("hex.DecodeString(%x): %v", tc.want, err) + } + if got, want := tc.got, wantBytes; !bytes.Equal(got, want) { + t.Errorf("got %x, want %x", got, want) + } + }) + } +} + +func TestRFC6962HasherCollisions(t *testing.T) { + // Check that different leaves have different hashes. + leaf1, leaf2 := []byte("Hello"), []byte("World") + _, leafHashTrail := trailsFromByteSlices([][]byte{leaf1}) + hash1 := leafHashTrail.Hash + _, leafHashTrail = trailsFromByteSlices([][]byte{leaf2}) + hash2 := leafHashTrail.Hash + if bytes.Equal(hash1, hash2) { + t.Errorf("leaf hashes should differ, but both are %x", hash1) + } + // Compute an intermediate subtree hash. + _, subHash1Trail := trailsFromByteSlices([][]byte{hash1, hash2}) + subHash1 := subHash1Trail.Hash + // Check that this is not the same as a leaf hash of their concatenation. + preimage := append(hash1, hash2...) + _, forgedHashTrail := trailsFromByteSlices([][]byte{preimage}) + forgedHash := forgedHashTrail.Hash + if bytes.Equal(subHash1, forgedHash) { + t.Errorf("hasher is not second-preimage resistant") + } + // Swap the order of nodes and check that the hash is different. 
+ _, subHash2Trail := trailsFromByteSlices([][]byte{hash2, hash1}) + subHash2 := subHash2Trail.Hash + if bytes.Equal(subHash1, subHash2) { + t.Errorf("subtree hash does not depend on the order of leaves") + } +} diff --git a/merkle/tree.go b/merkle/tree.go new file mode 100644 index 0000000..089c2f8 --- /dev/null +++ b/merkle/tree.go @@ -0,0 +1,106 @@ +package merkle + +import ( + "math/bits" +) + +// HashFromByteSlices computes a Merkle tree where the leaves are the byte slice, +// in the provided order. It follows RFC-6962. +func HashFromByteSlices(items [][]byte) []byte { + switch len(items) { + case 0: + return emptyHash() + case 1: + return leafHash(items[0]) + default: + k := getSplitPoint(int64(len(items))) + left := HashFromByteSlices(items[:k]) + right := HashFromByteSlices(items[k:]) + return innerHash(left, right) + } +} + +// HashFromByteSliceIterative is an iterative alternative to +// HashFromByteSlice motivated by potential performance improvements. +// (#2611) had suggested that an iterative version of +// HashFromByteSlice would be faster, presumably because +// we can envision some overhead accumulating from stack +// frames and function calls. Additionally, a recursive algorithm risks +// hitting the stack limit and causing a stack overflow should the tree +// be too large. +// +// Provided here is an iterative alternative, a test to assert +// correctness and a benchmark. On the performance side, there appears to +// be no overall difference: +// +// BenchmarkHashAlternatives/recursive-4 20000 77677 ns/op +// BenchmarkHashAlternatives/iterative-4 20000 76802 ns/op +// +// On the surface it might seem that the additional overhead is due to +// the different allocation patterns of the implementations. The recursive +// version uses a single [][]byte slices which it then re-slices at each level of the tree. 
// getSplitPoint returns the largest power of two strictly less than length.
// As a special case it returns 0 for length == 1, since no power of two is
// smaller than 1. It panics when length < 1.
func getSplitPoint(length int64) int64 {
	if length < 1 {
		panic("Trying to split a tree with size < 1")
	}
	// Use the 64-bit variant explicitly: converting length to uint would
	// silently truncate values above 2^32-1 where uint is 32 bits wide.
	bitlen := bits.Len64(uint64(length))
	// k is the highest power of two that is <= length.
	k := int64(1) << uint(bitlen-1)
	if k == length {
		// length is itself a power of two; step down to stay strictly below.
		k >>= 1
	}
	return k
}
"crypto/sha256" + "encoding/hex" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testItem []byte + +func (tI testItem) Hash() []byte { + return []byte(tI) +} + +func TestHashFromByteSlices(t *testing.T) { + testcases := map[string]struct { + slices [][]byte + expectHash string // in hex format + }{ + "nil": {nil, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"}, + "empty": {[][]byte{}, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"}, + "single": {[][]byte{{1, 2, 3}}, "054edec1d0211f624fed0cbca9d4f9400b0e491c43742af2c5b0abebf0c990d8"}, + "single blank": {[][]byte{{}}, "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d"}, + "two": {[][]byte{{1, 2, 3}, {4, 5, 6}}, "82e6cfce00453804379b53962939eaa7906b39904be0813fcadd31b100773c4b"}, + "many": { + [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}}, + "f326493eceab4f2d9ffbc78c59432a0a005d6ea98392045c74df5d14a113be18", + }, + } + for name, tc := range testcases { + tc := tc + t.Run(name, func(t *testing.T) { + hash := HashFromByteSlices(tc.slices) + assert.Equal(t, tc.expectHash, hex.EncodeToString(hash)) + }) + } +} + +func TestProof(t *testing.T) { + + // Try an empty proof first + rootHash, proofs := ProofsFromByteSlices([][]byte{}) + require.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hex.EncodeToString(rootHash)) + require.Empty(t, proofs) + + total := 100 + + items := make([][]byte, total) + for i := 0; i < total; i++ { + items[i] = testItem(randBytes(sha256.Size)) + } + + rootHash = HashFromByteSlices(items) + + rootHash2, proofs := ProofsFromByteSlices(items) + + require.Equal(t, rootHash, rootHash2, "Unmatched root hashes: %X vs %X", rootHash, rootHash2) + + // For each item, check the trail. 
+ for i, item := range items { + proof := proofs[i] + + // Check total/index + require.EqualValues(t, proof.Index, i, "Unmatched indices: %d vs %d", proof.Index, i) + + require.EqualValues(t, proof.Total, total, "Unmatched totals: %d vs %d", proof.Total, total) + + // Verify success + err := proof.Verify(rootHash, item) + require.NoError(t, err, "Verification failed: %v.", err) + + // Trail too long should make it fail + origAunts := proof.Aunts + proof.Aunts = append(proof.Aunts, randBytes(32)) + err = proof.Verify(rootHash, item) + require.Error(t, err, "Expected verification to fail for wrong trail length") + + proof.Aunts = origAunts + + // Trail too short should make it fail + proof.Aunts = proof.Aunts[0 : len(proof.Aunts)-1] + err = proof.Verify(rootHash, item) + require.Error(t, err, "Expected verification to fail for wrong trail length") + + proof.Aunts = origAunts + + // Mutating the itemHash should make it fail. + err = proof.Verify(rootHash, mutateByteSlice(item)) + require.Error(t, err, "Expected verification to fail for mutated leaf hash") + + // Mutating the rootHash should make it fail. 
// randBytes returns a slice of size bytes filled from crypto/rand.
// It panics if the random source reports an error, so a caller never
// continues with a buffer that was not (fully) randomized.
func randBytes(size int) []byte {
	b := make([]byte, size)
	if _, err := crand.Read(b); err != nil {
		panic("reading random bytes: " + err.Error())
	}
	return b
}
// encodeByteSlice writes bz to w, preceded by len(bz) encoded as an unsigned
// varint. The first write error is returned; if writing the length prefix
// fails, the payload is not written.
func encodeByteSlice(w io.Writer, bz []byte) (err error) {
	var lenPrefix [binary.MaxVarintLen64]byte
	prefixLen := binary.PutUvarint(lenPrefix[:], uint64(len(bz)))

	// Only attempt the payload once the prefix has gone out successfully.
	if _, err = w.Write(lenPrefix[:prefixLen]); err == nil {
		_, err = w.Write(bz)
	}
	return err
}