From 6c4baae8e6460fa8d0df2d1fa4b8fec7c23f9f1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20Irmak?= Date: Thu, 31 Oct 2024 16:29:13 +0300 Subject: [PATCH] feat: bring ZkTrie in (#1076) Co-authored-by: colinlyguo --- cmd/evm/runner.go | 2 - cmd/utils/flags.go | 14 +- core/blockchain.go | 14 +- core/blockchain_repair_test.go | 3 +- core/blockchain_sethead_test.go | 2 +- core/blockchain_snapshot_test.go | 3 + core/blockchain_test.go | 91 +- core/chain_makers.go | 12 +- core/chain_makers_test.go | 1 + core/forkid/forkid_test.go | 3 + core/genesis.go | 20 +- core/genesis_test.go | 10 +- core/state/database.go | 16 - core/state/iterator_test.go | 3 + core/state/snapshot/difflayer_test.go | 3 + core/state/snapshot/disklayer_test.go | 3 + core/state/snapshot/generate_test.go | 3 + core/state/snapshot/holdable_iterator_test.go | 3 + core/state/snapshot/iterator_test.go | 3 + core/state/snapshot/snapshot_test.go | 3 + core/state/state_object.go | 21 +- core/state/state_prove.go | 85 -- core/state/state_test.go | 2 + core/state/statedb.go | 12 +- core/state/statedb_fuzz_test.go | 1 + core/state/statedb_test.go | 8 +- core/state/sync_test.go | 3 + core/state_processor.go | 2 +- core/state_processor_test.go | 8 +- core/txpool/blobpool/blobpool_test.go | 7 +- core/txpool/legacypool/legacypool_test.go | 1 + core/types/hashes.go | 14 +- core/types/hashing_test.go | 1 + core/types/state_account_marshalling.go | 6 +- core/types/state_account_marshalling_test.go | 46 +- core/vm/gas_table_test.go | 81 +- core/vm/logger.go | 2 - core/vm/runtime/runtime_test.go | 17 +- eth/filters/filter_system_test.go | 2 +- eth/filters/filter_test.go | 2 +- eth/tracers/js/goja.go | 4 - eth/tracers/logger/access_list_tracer.go | 4 - eth/tracers/logger/logger.go | 4 - eth/tracers/logger/logger_json.go | 4 - eth/tracers/native/4byte.go | 4 - eth/tracers/native/call.go | 4 - eth/tracers/native/call_flat.go | 4 - eth/tracers/native/mux.go | 4 - eth/tracers/native/noop.go | 2 - go.mod | 1 - go.sum | 2 - internal/ethapi/api_test.go | 7 +- rollup/ccc/async_checker_test.go | 2 +- rollup/ccc/logger.go | 4 - .../tracing/proof.go | 65 +- rollup/tracing/proof_test.go | 111 ++ rollup/tracing/tracing.go | 49 +- tests/fuzzers/trie/trie-fuzzer.go | 2 +- trie/byte32.go | 42 + trie/byte32_test.go | 44 + trie/database.go | 54 +- trie/database_supplement.go | 32 - trie/hash.go | 149 ++ trie/hash_test.go | 83 + trie/iterator.go | 4 +- trie/iterator_test.go | 28 +- trie/node_test.go | 3 + trie/proof.go | 7 +- trie/proof_test.go | 3 + trie/secure_trie.go | 50 +- trie/secure_trie_test.go | 9 +- trie/stacktrie.go | 7 +- trie/stacktrie_test.go | 9 + trie/sync.go | 2 +- trie/sync_test.go | 15 +- trie/tracer_test.go | 11 +- trie/trie.go | 10 +- trie/trie_reader.go | 13 +- trie/trie_test.go | 29 +- trie/triedb/hashdb/database.go | 8 +- trie/triedb/hashdb/database_supplement.go | 15 - trie/util.go | 117 ++ trie/util_test.go | 96 ++ trie/zk_trie.go | 1346 +++++++++++++++-- trie/zk_trie_database.go | 172 --- trie/zk_trie_database_test.go | 63 - trie/zk_trie_impl_test.go | 289 ---- trie/zk_trie_node.go | 405 +++++ trie/zk_trie_node_test.go | 240 +++ trie/zk_trie_proof_test.go | 200 +-- trie/zk_trie_test.go | 919 ++++++++--- 91 files changed, 3653 insertions(+), 1631 deletions(-) delete mode 100644 core/state/state_prove.go rename trie/zktrie_deletionproof.go => rollup/tracing/proof.go (71%) create mode 100644 rollup/tracing/proof_test.go create mode 100644 trie/byte32.go create mode 100644 trie/byte32_test.go delete mode 100644 trie/database_supplement.go create mode 100644 trie/hash.go create mode 100644 trie/hash_test.go delete mode 100644 trie/triedb/hashdb/database_supplement.go create mode 100644 trie/util.go create mode 100644 trie/util_test.go delete mode 100644 trie/zk_trie_database.go delete mode 100644 trie/zk_trie_database_test.go delete mode 100644 trie/zk_trie_impl_test.go create mode 100644 trie/zk_trie_node.go create mode 100644 trie/zk_trie_node_test.go diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index ea539490edde..81d2c6d62256 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -151,8 +151,6 @@ func runCmd(ctx *cli.Context) error { triedb := trie.NewDatabase(db, &trie.Config{ Preimages: preimages, HashDB: hashdb.Defaults, - // scroll related - IsUsingZktrie: genesisConfig.Config.Scroll.ZktrieEnabled(), }) defer triedb.Close() genesis := genesisConfig.MustCommit(db, triedb) diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 25a61c20838e..f934dd4801ed 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -283,7 +283,7 @@ var ( GCModeFlag = &cli.StringFlag{ Name: "gcmode", Usage: `Blockchain garbage collection mode, only relevant in state.scheme=hash ("full", "archive")`, - Value: GCModeArchive, + Value: GCModeFull, Category: flags.StateCategory, } StateSchemeFlag = &cli.StringFlag{ @@ -2056,12 +2056,6 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) { stack.Config().L1Confirmations = rpc.FinalizedBlockNumber log.Info("Setting flag", "--l1.sync.startblock", "4038000") stack.Config().L1DeploymentBlock = 4038000 - // disable pruning - if ctx.String(GCModeFlag.Name) != GCModeArchive { - log.Crit("Must use --gcmode=archive") - } - log.Info("Pruning disabled") - cfg.NoPruning = true case ctx.Bool(ScrollFlag.Name): if !ctx.IsSet(NetworkIdFlag.Name) { cfg.NetworkId = 534352 @@ -2072,12 +2066,6 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) { stack.Config().L1Confirmations = rpc.FinalizedBlockNumber log.Info("Setting flag", "--l1.sync.startblock", "18306000") stack.Config().L1DeploymentBlock = 18306000 - // disable pruning - if ctx.String(GCModeFlag.Name) != GCModeArchive { - log.Crit("Must use --gcmode=archive") - } - log.Info("Pruning disabled") - cfg.NoPruning = true case ctx.Bool(DeveloperFlag.Name): if !ctx.IsSet(NetworkIdFlag.Name) { cfg.NetworkId = 1337 diff --git a/core/blockchain.go b/core/blockchain.go index 0306f1681b3c..db608897e297 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -153,8 +153,8 @@ type CacheConfig struct { } // triedbConfig derives the configures for trie database. -func (c *CacheConfig) triedbConfig(isUsingZktrie bool) *trie.Config { - config := &trie.Config{Preimages: c.Preimages, IsUsingZktrie: isUsingZktrie} +func (c *CacheConfig) triedbConfig() *trie.Config { + config := &trie.Config{Preimages: c.Preimages} if c.StateScheme == rawdb.HashScheme { config.HashDB = &hashdb.Config{ CleanCacheSize: c.TrieCleanLimit * 1024 * 1024, @@ -176,8 +176,8 @@ var defaultCacheConfig = &CacheConfig{ TrieCleanLimit: 256, TrieDirtyLimit: 256, TrieTimeLimit: 5 * time.Minute, - SnapshotLimit: 256, - SnapshotWait: true, + SnapshotLimit: 0, // Snapshots don't support zkTrie yet + SnapshotWait: false, StateScheme: rawdb.HashScheme, } @@ -272,11 +272,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis cacheConfig = defaultCacheConfig } // Open trie database with provided config - triedbConfig := cacheConfig.triedbConfig(false) - if genesis != nil && genesis.Config != nil && genesis.Config.Scroll.ZktrieEnabled() { - cacheConfig.triedbConfig(true) - } - triedb := trie.NewDatabase(db, triedbConfig) + triedb := trie.NewDatabase(db, cacheConfig.triedbConfig()) // Setup the genesis block, commit the provided genesis specification // to database if the genesis block is not present yet, or load the diff --git a/core/blockchain_repair_test.go b/core/blockchain_repair_test.go index 987816d95cb5..9edb9b8c5cbe 100644 --- a/core/blockchain_repair_test.go +++ b/core/blockchain_repair_test.go @@ -1750,7 +1750,7 @@ func testLongReorgedSnapSyncingDeepRepair(t *testing.T, snapshots bool) { } func testRepair(t *testing.T, tt *rewindTest, snapshots bool) { - for _, scheme := range []string{rawdb.HashScheme, rawdb.PathScheme} { + for _, scheme := range []string{rawdb.HashScheme /*, rawdb.PathScheme*/} { testRepairWithScheme(t, tt, snapshots, scheme) } } @@ -1898,6 +1898,7 @@ func testRepairWithScheme(t *testing.T, tt *rewindTest, snapshots bool, scheme s // In this case the snapshot layer of B3 is not created because of existent // state. func TestIssue23496(t *testing.T) { + t.Skip("snapshot doesn't support zktrie") testIssue23496(t, rawdb.HashScheme) testIssue23496(t, rawdb.PathScheme) } diff --git a/core/blockchain_sethead_test.go b/core/blockchain_sethead_test.go index e16fe57eec50..3e98a76f54b7 100644 --- a/core/blockchain_sethead_test.go +++ b/core/blockchain_sethead_test.go @@ -1954,7 +1954,7 @@ func testLongReorgedSnapSyncingDeepSetHead(t *testing.T, snapshots bool) { } func testSetHead(t *testing.T, tt *rewindTest, snapshots bool) { - for _, scheme := range []string{rawdb.HashScheme, rawdb.PathScheme} { + for _, scheme := range []string{rawdb.HashScheme /*, rawdb.PathScheme*/} { testSetHeadWithScheme(t, tt, snapshots, scheme) } } diff --git a/core/blockchain_snapshot_test.go b/core/blockchain_snapshot_test.go index 4a0900657dd3..89cf7bd22159 100644 --- a/core/blockchain_snapshot_test.go +++ b/core/blockchain_snapshot_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2020 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/blockchain_test.go b/core/blockchain_test.go index bc4772f46d74..548f0868f736 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -279,7 +279,7 @@ func TestExtendCanonicalHeaders(t *testing.T) { } func TestExtendCanonicalBlocks(t *testing.T) { testExtendCanonical(t, true, rawdb.HashScheme) - testExtendCanonical(t, true, rawdb.PathScheme) + // testExtendCanonical(t, true, rawdb.PathScheme) } func testExtendCanonical(t *testing.T, full bool, scheme string) { @@ -313,7 +313,7 @@ func TestExtendCanonicalHeadersAfterMerge(t *testing.T) { } func TestExtendCanonicalBlocksAfterMerge(t *testing.T) { testExtendCanonicalAfterMerge(t, true, rawdb.HashScheme) - testExtendCanonicalAfterMerge(t, true, rawdb.PathScheme) + // testExtendCanonicalAfterMerge(t, true, rawdb.PathScheme) } func testExtendCanonicalAfterMerge(t *testing.T, full bool, scheme string) { @@ -338,7 +338,7 @@ func TestShorterForkHeaders(t *testing.T) { } func TestShorterForkBlocks(t *testing.T) { testShorterFork(t, true, rawdb.HashScheme) - testShorterFork(t, true, rawdb.PathScheme) + // testShorterFork(t, true, rawdb.PathScheme) } func testShorterFork(t *testing.T, full bool, scheme string) { @@ -374,7 +374,7 @@ func TestShorterForkHeadersAfterMerge(t *testing.T) { } func TestShorterForkBlocksAfterMerge(t *testing.T) { testShorterForkAfterMerge(t, true, rawdb.HashScheme) - testShorterForkAfterMerge(t, true, rawdb.PathScheme) + // testShorterForkAfterMerge(t, true, rawdb.PathScheme) } func testShorterForkAfterMerge(t *testing.T, full bool, scheme string) { @@ -403,7 +403,7 @@ func TestLongerForkHeaders(t *testing.T) { } func TestLongerForkBlocks(t *testing.T) { testLongerFork(t, true, rawdb.HashScheme) - testLongerFork(t, true, rawdb.PathScheme) + // testLongerFork(t, true, rawdb.PathScheme) } func testLongerFork(t *testing.T, full bool, scheme string) { @@ -432,7 +432,7 @@ func TestLongerForkHeadersAfterMerge(t *testing.T) { } func TestLongerForkBlocksAfterMerge(t *testing.T) { testLongerForkAfterMerge(t, true, rawdb.HashScheme) - testLongerForkAfterMerge(t, true, rawdb.PathScheme) + // testLongerForkAfterMerge(t, true, rawdb.PathScheme) } func testLongerForkAfterMerge(t *testing.T, full bool, scheme string) { @@ -461,7 +461,7 @@ func TestEqualForkHeaders(t *testing.T) { } func TestEqualForkBlocks(t *testing.T) { testEqualFork(t, true, rawdb.HashScheme) - testEqualFork(t, true, rawdb.PathScheme) + // testEqualFork(t, true, rawdb.PathScheme) } func testEqualFork(t *testing.T, full bool, scheme string) { @@ -497,7 +497,7 @@ func TestEqualForkHeadersAfterMerge(t *testing.T) { } func TestEqualForkBlocksAfterMerge(t *testing.T) { testEqualForkAfterMerge(t, true, rawdb.HashScheme) - testEqualForkAfterMerge(t, true, rawdb.PathScheme) + // testEqualForkAfterMerge(t, true, rawdb.PathScheme) } func testEqualForkAfterMerge(t *testing.T, full bool, scheme string) { @@ -525,7 +525,7 @@ func TestBrokenHeaderChain(t *testing.T) { } func TestBrokenBlockChain(t *testing.T) { testBrokenChain(t, true, rawdb.HashScheme) - testBrokenChain(t, true, rawdb.PathScheme) + // testBrokenChain(t, true, rawdb.PathScheme) } func testBrokenChain(t *testing.T, full bool, scheme string) { @@ -558,7 +558,7 @@ func TestReorgLongHeaders(t *testing.T) { } func TestReorgLongBlocks(t *testing.T) { testReorgLong(t, true, rawdb.HashScheme) - testReorgLong(t, true, rawdb.PathScheme) + // testReorgLong(t, true, rawdb.PathScheme) } func testReorgLong(t *testing.T, full bool, scheme string) { @@ -573,7 +573,7 @@ func TestReorgShortHeaders(t *testing.T) { } func TestReorgShortBlocks(t *testing.T) { testReorgShort(t, true, rawdb.HashScheme) - testReorgShort(t, true, rawdb.PathScheme) + // testReorgShort(t, true, rawdb.PathScheme) } func testReorgShort(t *testing.T, full bool, scheme string) { @@ -667,7 +667,7 @@ func TestBadHeaderHashes(t *testing.T) { } func TestBadBlockHashes(t *testing.T) { testBadHashes(t, true, rawdb.HashScheme) - testBadHashes(t, true, rawdb.PathScheme) + // testBadHashes(t, true, rawdb.PathScheme) } func testBadHashes(t *testing.T, full bool, scheme string) { @@ -707,7 +707,7 @@ func TestReorgBadHeaderHashes(t *testing.T) { } func TestReorgBadBlockHashes(t *testing.T) { testReorgBadHashes(t, true, rawdb.HashScheme) - testReorgBadHashes(t, true, rawdb.PathScheme) + // testReorgBadHashes(t, true, rawdb.PathScheme) } func testReorgBadHashes(t *testing.T, full bool, scheme string) { @@ -768,7 +768,7 @@ func TestHeadersInsertNonceError(t *testing.T) { } func TestBlocksInsertNonceError(t *testing.T) { testInsertNonceError(t, true, rawdb.HashScheme) - testInsertNonceError(t, true, rawdb.PathScheme) + // testInsertNonceError(t, true, rawdb.PathScheme) } func testInsertNonceError(t *testing.T, full bool, scheme string) { @@ -830,7 +830,7 @@ func testInsertNonceError(t *testing.T, full bool, scheme string) { // classical full block processing. func TestFastVsFullChains(t *testing.T) { testFastVsFullChains(t, rawdb.HashScheme) - testFastVsFullChains(t, rawdb.PathScheme) + // testFastVsFullChains(t, rawdb.PathScheme) } func testFastVsFullChains(t *testing.T, scheme string) { @@ -963,7 +963,7 @@ func testFastVsFullChains(t *testing.T, scheme string) { // positions. func TestLightVsFastVsFullChainHeads(t *testing.T) { testLightVsFastVsFullChainHeads(t, rawdb.HashScheme) - testLightVsFastVsFullChainHeads(t, rawdb.PathScheme) + // testLightVsFastVsFullChainHeads(t, rawdb.PathScheme) } func testLightVsFastVsFullChainHeads(t *testing.T, scheme string) { @@ -1080,7 +1080,7 @@ func testLightVsFastVsFullChainHeads(t *testing.T, scheme string) { // Tests that chain reorganisations handle transaction removals and reinsertions. func TestChainTxReorgs(t *testing.T) { testChainTxReorgs(t, rawdb.HashScheme) - testChainTxReorgs(t, rawdb.PathScheme) + // testChainTxReorgs(t, rawdb.PathScheme) } func testChainTxReorgs(t *testing.T, scheme string) { @@ -1199,7 +1199,7 @@ func testChainTxReorgs(t *testing.T, scheme string) { func TestLogReorgs(t *testing.T) { testLogReorgs(t, rawdb.HashScheme) - testLogReorgs(t, rawdb.PathScheme) + // testLogReorgs(t, rawdb.PathScheme) } func testLogReorgs(t *testing.T, scheme string) { @@ -1259,7 +1259,7 @@ var logCode = common.Hex2Bytes("60606040525b7f24ec1d3ff24c2f6ff210738839dbc339cd // when the chain reorganizes. func TestLogRebirth(t *testing.T) { testLogRebirth(t, rawdb.HashScheme) - testLogRebirth(t, rawdb.PathScheme) + // testLogRebirth(t, rawdb.PathScheme) } func testLogRebirth(t *testing.T, scheme string) { @@ -1341,7 +1341,7 @@ func testLogRebirth(t *testing.T, scheme string) { // when a side chain containing log events overtakes the canonical chain. func TestSideLogRebirth(t *testing.T) { testSideLogRebirth(t, rawdb.HashScheme) - testSideLogRebirth(t, rawdb.PathScheme) + // testSideLogRebirth(t, rawdb.PathScheme) } func testSideLogRebirth(t *testing.T, scheme string) { @@ -1436,7 +1436,7 @@ func checkLogEvents(t *testing.T, logsCh <-chan []*types.Log, rmLogsCh <-chan Re func TestReorgSideEvent(t *testing.T) { testReorgSideEvent(t, rawdb.HashScheme) - testReorgSideEvent(t, rawdb.PathScheme) + // testReorgSideEvent(t, rawdb.PathScheme) } func testReorgSideEvent(t *testing.T, scheme string) { @@ -1521,7 +1521,7 @@ done: // Tests if the canonical block can be fetched from the database during chain insertion. func TestCanonicalBlockRetrieval(t *testing.T) { testCanonicalBlockRetrieval(t, rawdb.HashScheme) - testCanonicalBlockRetrieval(t, rawdb.PathScheme) + // testCanonicalBlockRetrieval(t, rawdb.PathScheme) } func testCanonicalBlockRetrieval(t *testing.T, scheme string) { @@ -1571,7 +1571,7 @@ func testCanonicalBlockRetrieval(t *testing.T, scheme string) { } func TestEIP155Transition(t *testing.T) { testEIP155Transition(t, rawdb.HashScheme) - testEIP155Transition(t, rawdb.PathScheme) + // testEIP155Transition(t, rawdb.PathScheme) } func testEIP155Transition(t *testing.T, scheme string) { @@ -1685,7 +1685,7 @@ func testEIP155Transition(t *testing.T, scheme string) { } func TestEIP161AccountRemoval(t *testing.T) { testEIP161AccountRemoval(t, rawdb.HashScheme) - testEIP161AccountRemoval(t, rawdb.PathScheme) + // testEIP161AccountRemoval(t, rawdb.PathScheme) } func testEIP161AccountRemoval(t *testing.T, scheme string) { @@ -1760,7 +1760,7 @@ func testEIP161AccountRemoval(t *testing.T, scheme string) { // https://github.com/ethereum/go-ethereum/pull/15941 func TestBlockchainHeaderchainReorgConsistency(t *testing.T) { testBlockchainHeaderchainReorgConsistency(t, rawdb.HashScheme) - testBlockchainHeaderchainReorgConsistency(t, rawdb.PathScheme) + // testBlockchainHeaderchainReorgConsistency(t, rawdb.PathScheme) } func testBlockchainHeaderchainReorgConsistency(t *testing.T, scheme string) { @@ -1856,7 +1856,7 @@ func TestTrieForkGC(t *testing.T) { // forking point is not available any more. func TestLargeReorgTrieGC(t *testing.T) { testLargeReorgTrieGC(t, rawdb.HashScheme) - testLargeReorgTrieGC(t, rawdb.PathScheme) + // testLargeReorgTrieGC(t, rawdb.PathScheme) } func testLargeReorgTrieGC(t *testing.T, scheme string) { @@ -1865,6 +1865,10 @@ func testLargeReorgTrieGC(t *testing.T, scheme string) { genesis := &Genesis{ Config: params.TestChainConfig, BaseFee: big.NewInt(params.InitialBaseFee), + Alloc: GenesisAlloc{ + common.Address{2}: {Balance: big.NewInt(1)}, + common.Address{3}: {Balance: big.NewInt(1)}, + }, } genDb, shared, _ := GenerateChainWithGenesis(genesis, engine, 64, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) }) original, _ := GenerateChain(genesis.Config, shared[len(shared)-1], engine, genDb, 2*TriesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) }) @@ -1925,7 +1929,7 @@ func testLargeReorgTrieGC(t *testing.T, scheme string) { func TestBlockchainRecovery(t *testing.T) { testBlockchainRecovery(t, rawdb.HashScheme) - testBlockchainRecovery(t, rawdb.PathScheme) + // testBlockchainRecovery(t, rawdb.PathScheme) } func testBlockchainRecovery(t *testing.T, scheme string) { @@ -1981,7 +1985,7 @@ func testBlockchainRecovery(t *testing.T, scheme string) { // This test checks that InsertReceiptChain will roll back correctly when attempting to insert a side chain. func TestInsertReceiptChainRollback(t *testing.T) { testInsertReceiptChainRollback(t, rawdb.HashScheme) - testInsertReceiptChainRollback(t, rawdb.PathScheme) + // testInsertReceiptChainRollback(t, rawdb.PathScheme) } func testInsertReceiptChainRollback(t *testing.T, scheme string) { @@ -2063,7 +2067,7 @@ func testInsertReceiptChainRollback(t *testing.T, scheme string) { // - https://github.com/ethereum/go-ethereum/pull/18988 func TestLowDiffLongChain(t *testing.T) { testLowDiffLongChain(t, rawdb.HashScheme) - testLowDiffLongChain(t, rawdb.PathScheme) + // testLowDiffLongChain(t, rawdb.PathScheme) } func testLowDiffLongChain(t *testing.T, scheme string) { @@ -2270,15 +2274,15 @@ func TestPrunedImportSideWithMerging(t *testing.T) { func TestInsertKnownHeaders(t *testing.T) { testInsertKnownChainData(t, "headers", rawdb.HashScheme) - testInsertKnownChainData(t, "headers", rawdb.PathScheme) + // testInsertKnownChainData(t, "headers", rawdb.PathScheme) } func TestInsertKnownReceiptChain(t *testing.T) { testInsertKnownChainData(t, "receipts", rawdb.HashScheme) - testInsertKnownChainData(t, "receipts", rawdb.PathScheme) + // testInsertKnownChainData(t, "receipts", rawdb.PathScheme) } func TestInsertKnownBlocks(t *testing.T) { testInsertKnownChainData(t, "blocks", rawdb.HashScheme) - testInsertKnownChainData(t, "blocks", rawdb.PathScheme) + // testInsertKnownChainData(t, "blocks", rawdb.PathScheme) } func testInsertKnownChainData(t *testing.T, typ string, scheme string) { @@ -2636,7 +2640,7 @@ func getLongAndShortChains(scheme string) (*BlockChain, []*types.Block, []*types // 4. The forked block should still be retrievable by hash func TestReorgToShorterRemovesCanonMapping(t *testing.T) { testReorgToShorterRemovesCanonMapping(t, rawdb.HashScheme) - testReorgToShorterRemovesCanonMapping(t, rawdb.PathScheme) + // testReorgToShorterRemovesCanonMapping(t, rawdb.PathScheme) } func testReorgToShorterRemovesCanonMapping(t *testing.T, scheme string) { @@ -2679,7 +2683,7 @@ func testReorgToShorterRemovesCanonMapping(t *testing.T, scheme string) { // imports -- that is, for fast sync func TestReorgToShorterRemovesCanonMappingHeaderChain(t *testing.T) { testReorgToShorterRemovesCanonMappingHeaderChain(t, rawdb.HashScheme) - testReorgToShorterRemovesCanonMappingHeaderChain(t, rawdb.PathScheme) + // testReorgToShorterRemovesCanonMappingHeaderChain(t, rawdb.PathScheme) } func testReorgToShorterRemovesCanonMappingHeaderChain(t *testing.T, scheme string) { @@ -2827,7 +2831,7 @@ func TestTransactionIndices(t *testing.T) { func TestSkipStaleTxIndicesInSnapSync(t *testing.T) { testSkipStaleTxIndicesInSnapSync(t, rawdb.HashScheme) - testSkipStaleTxIndicesInSnapSync(t, rawdb.PathScheme) + // testSkipStaleTxIndicesInSnapSync(t, rawdb.PathScheme) } func testSkipStaleTxIndicesInSnapSync(t *testing.T, scheme string) { @@ -3024,7 +3028,7 @@ func BenchmarkBlockChain_1x1000Executions(b *testing.B) { // 3. The blocks fetched are all known and canonical blocks func TestSideImportPrunedBlocks(t *testing.T) { testSideImportPrunedBlocks(t, rawdb.HashScheme) - testSideImportPrunedBlocks(t, rawdb.PathScheme) + // testSideImportPrunedBlocks(t, rawdb.PathScheme) } func testSideImportPrunedBlocks(t *testing.T, scheme string) { @@ -3082,7 +3086,7 @@ func testSideImportPrunedBlocks(t *testing.T, scheme string) { // first, but the journal wiped the entire state object on create-revert. func TestDeleteCreateRevert(t *testing.T) { testDeleteCreateRevert(t, rawdb.HashScheme) - testDeleteCreateRevert(t, rawdb.PathScheme) + // testDeleteCreateRevert(t, rawdb.PathScheme) } func testDeleteCreateRevert(t *testing.T, scheme string) { @@ -3156,6 +3160,7 @@ func testDeleteCreateRevert(t *testing.T, scheme string) { // Expected outcome is that _all_ slots are cleared from A, due to the selfdestruct, // and then the new slots exist func TestDeleteRecreateSlots(t *testing.T) { + t.Skip("Scroll doesn't support SELFDESTRUCT") testDeleteRecreateSlots(t, rawdb.HashScheme) testDeleteRecreateSlots(t, rawdb.PathScheme) } @@ -3284,6 +3289,7 @@ func testDeleteRecreateSlots(t *testing.T, scheme string) { // regular value-transfer // Expected outcome is that _all_ slots are cleared from A func TestDeleteRecreateAccount(t *testing.T) { + t.Skip("Scroll doesn't support SELFDESTRUCT") testDeleteRecreateAccount(t, rawdb.HashScheme) testDeleteRecreateAccount(t, rawdb.PathScheme) } @@ -3362,6 +3368,7 @@ func testDeleteRecreateAccount(t *testing.T, scheme string) { // Expected outcome is that _all_ slots are cleared from A, due to the selfdestruct, // and then the new slots exist func TestDeleteRecreateSlotsAcrossManyBlocks(t *testing.T) { + t.Skip("Scroll doesn't support SELFDESTRUCT") testDeleteRecreateSlotsAcrossManyBlocks(t, rawdb.HashScheme) testDeleteRecreateSlotsAcrossManyBlocks(t, rawdb.PathScheme) } @@ -3569,7 +3576,7 @@ func testDeleteRecreateSlotsAcrossManyBlocks(t *testing.T, scheme string) { func TestInitThenFailCreateContract(t *testing.T) { testInitThenFailCreateContract(t, rawdb.HashScheme) - testInitThenFailCreateContract(t, rawdb.PathScheme) + // testInitThenFailCreateContract(t, rawdb.PathScheme) } func testInitThenFailCreateContract(t *testing.T, scheme string) { @@ -3684,7 +3691,7 @@ func testInitThenFailCreateContract(t *testing.T, scheme string) { // correctly. func TestEIP2718Transition(t *testing.T) { testEIP2718Transition(t, rawdb.HashScheme) - testEIP2718Transition(t, rawdb.PathScheme) + // testEIP2718Transition(t, rawdb.PathScheme) } func testEIP2718Transition(t *testing.T, scheme string) { @@ -3766,7 +3773,7 @@ func testEIP2718Transition(t *testing.T, scheme string) { // 6. Legacy transaction behave as expected (e.g. gasPrice = gasFeeCap = gasTipCap). func TestEIP1559Transition(t *testing.T) { testEIP1559Transition(t, rawdb.HashScheme) - testEIP1559Transition(t, rawdb.PathScheme) + // testEIP1559Transition(t, rawdb.PathScheme) } func testEIP1559Transition(t *testing.T, scheme string) { @@ -3912,7 +3919,7 @@ func testEIP1559Transition(t *testing.T, scheme string) { // It expects the state is recovered and all relevant chain markers are set correctly. func TestSetCanonical(t *testing.T) { testSetCanonical(t, rawdb.HashScheme) - testSetCanonical(t, rawdb.PathScheme) + // testSetCanonical(t, rawdb.PathScheme) } func testSetCanonical(t *testing.T, scheme string) { @@ -3999,7 +4006,7 @@ func testSetCanonical(t *testing.T, scheme string) { // correctly in case reorg is called. func TestCanonicalHashMarker(t *testing.T) { testCanonicalHashMarker(t, rawdb.HashScheme) - testCanonicalHashMarker(t, rawdb.PathScheme) + // testCanonicalHashMarker(t, rawdb.PathScheme) } func testCanonicalHashMarker(t *testing.T, scheme string) { diff --git a/core/chain_makers.go b/core/chain_makers.go index 47979b030d1c..31f70a321d6d 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -354,11 +354,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse return nil, nil } // Forcibly use hash-based state scheme for retaining all nodes in disk. - trieConfig := trie.HashDefaults - if config.Scroll.ZktrieEnabled() { - trieConfig = trie.HashDefaultsWithZktrie - } - triedb := trie.NewDatabase(db, trieConfig) + triedb := trie.NewDatabase(db, trie.HashDefaults) defer triedb.Close() for i := 0; i < n; i++ { @@ -379,11 +375,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse // then generate chain on top. func GenerateChainWithGenesis(genesis *Genesis, engine consensus.Engine, n int, gen func(int, *BlockGen)) (ethdb.Database, []*types.Block, []types.Receipts) { db := rawdb.NewMemoryDatabase() - trieConfig := trie.HashDefaults - if genesis.Config != nil && genesis.Config.Scroll.ZktrieEnabled() { - trieConfig = trie.HashDefaultsWithZktrie - } - triedb := trie.NewDatabase(db, trieConfig) + triedb := trie.NewDatabase(db, trie.HashDefaults) defer triedb.Close() _, err := genesis.Commit(db, triedb) if err != nil { diff --git a/core/chain_makers_test.go b/core/chain_makers_test.go index dd2164a33e23..b1f6ba9be68e 100644 --- a/core/chain_makers_test.go +++ b/core/chain_makers_test.go @@ -33,6 +33,7 @@ import ( ) func TestGeneratePOSChain(t *testing.T) { + t.Skip("POS is out of scope") var ( keyHex = "9c647b8b7c4e7c3490668fb6c11473619db80c93704c70893d3813af4090c39c" key, _ = crypto.HexToECDSA(keyHex) diff --git a/core/forkid/forkid_test.go b/core/forkid/forkid_test.go index 54d7bff8ba5d..9dd1f3015c87 100644 --- a/core/forkid/forkid_test.go +++ b/core/forkid/forkid_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2019 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/genesis.go b/core/genesis.go index a939e0174e34..210b48cf29eb 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -125,11 +125,7 @@ func (ga *GenesisAlloc) UnmarshalJSON(data []byte) error { func (ga *GenesisAlloc) hash(isUsingZktrie bool) (common.Hash, error) { // Create an ephemeral in-memory database for computing hash, // all the derived states will be discarded to not pollute disk. - trieConfig := trie.HashDefaults - if isUsingZktrie { - trieConfig = trie.HashDefaultsWithZktrie - } - db := state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), trieConfig) + db := state.NewDatabase(rawdb.NewMemoryDatabase()) statedb, err := state.New(types.EmptyRootHash, db, nil) if err != nil { return common.Hash{}, err @@ -292,10 +288,6 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, triedb *trie.Database, gen } else { log.Info("Writing custom genesis block") } - if genesis.Config.Scroll.ZktrieEnabled() { // genesis.Config must be not nil atm - // overwrite triedb IsUsingZktrie config to be safe - triedb.SetIsUsingZktrie(genesis.Config.Scroll.ZktrieEnabled()) - } block, err := genesis.Commit(db, triedb) if err != nil { return genesis.Config, common.Hash{}, err @@ -309,13 +301,6 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, triedb *trie.Database, gen // in this case. header := rawdb.ReadHeader(db, stored, 0) storedcfg := rawdb.ReadChainConfig(db, stored) - if genesis != nil { // genesis.Config must be not nil atm - // overwrite triedb IsUsingZktrie config to be safe - triedb.SetIsUsingZktrie(genesis.Config.Scroll.ZktrieEnabled()) - } else if storedcfg != nil && storedcfg.Scroll.ZktrieEnabled() { - // overwrite triedb IsUsingZktrie config to be safe - triedb.SetIsUsingZktrie(storedcfg.Scroll.ZktrieEnabled()) - } // if header.Root != types.EmptyRootHash && !triedb.Initialized(header.Hash()) { if _, err := state.New(header.Root, state.NewDatabaseWithNodeDB(db, triedb), nil); err != nil { if genesis == nil { @@ -496,9 +481,6 @@ func (g *Genesis) Commit(db ethdb.Database, triedb *trie.Database) (*types.Block if config == nil { config = params.AllEthashProtocolChanges } - if config.Scroll.ZktrieEnabled() != triedb.IsUsingZktrie() { - return nil, fmt.Errorf("ZktrieEnabled mismatch. genesis: %v, triedb: %v", g.Config.Scroll.ZktrieEnabled(), triedb.IsUsingZktrie()) - } block := g.ToBlock() if block.Number().Sign() != 0 { diff --git a/core/genesis_test.go b/core/genesis_test.go index 8784bace214e..7c5d6cf3431e 100644 --- a/core/genesis_test.go +++ b/core/genesis_test.go @@ -44,12 +44,12 @@ func TestInvalidCliqueConfig(t *testing.T) { func TestSetupGenesis(t *testing.T) { testSetupGenesis(t, rawdb.HashScheme) - testSetupGenesis(t, rawdb.PathScheme) + // testSetupGenesis(t, rawdb.PathScheme) } func testSetupGenesis(t *testing.T, scheme string) { var ( - customghash = common.HexToHash("0x700380ab70d789c462c4e8f0db082842095321f390d0a3f25f400f0746db32bc") + customghash = common.HexToHash("0xc96ed5df64e683d5af1b14ec67126e31b914ca828021c330efa00572a61ede8f") customg = Genesis{ Config: ¶ms.ChainConfig{HomesteadBlock: big.NewInt(3)}, Alloc: GenesisAlloc{ @@ -189,11 +189,7 @@ func TestGenesisHashes(t *testing.T) { } { // Test via MustCommit db := rawdb.NewMemoryDatabase() - trieConfig := trie.HashDefaults - if c.genesis.Config.Scroll.ZktrieEnabled() { - trieConfig = trie.HashDefaultsWithZktrie - } - if have := c.genesis.MustCommit(db, trie.NewDatabase(db, trieConfig)).Hash(); have != c.want { + if have := c.genesis.MustCommit(db, trie.NewDatabase(db, trie.HashDefaults)).Hash(); have != c.want { t.Errorf("case: %d a), want: %s, got: %s", i, c.want.Hex(), have.Hex()) } // Test via ToBlock diff --git a/core/state/database.go b/core/state/database.go index 1b7cc0e23006..7dd410264eb8 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -170,13 +170,6 @@ type cachingDB struct { // OpenTrie opens the main account trie at a specific root hash. func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { - if db.triedb.IsUsingZktrie() { - tr, err := trie.NewZkTrie(root, trie.NewZktrieDatabaseFromTriedb(db.triedb)) - if err != nil { - return nil, err - } - return tr, nil - } tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb) if err != nil { return nil, err @@ -186,13 +179,6 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { // OpenStorageTrie opens the storage trie of an account. func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (Trie, error) { - if db.triedb.IsUsingZktrie() { - tr, err := trie.NewZkTrie(root, trie.NewZktrieDatabaseFromTriedb(db.triedb)) - if err != nil { - return nil, err - } - return tr, nil - } tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, crypto.Keccak256Hash(address.Bytes()), root), db.triedb) if err != nil { return nil, err @@ -205,8 +191,6 @@ func (db *cachingDB) CopyTrie(t Trie) Trie { switch t := t.(type) { case *trie.StateTrie: return t.Copy() - case *trie.ZkTrie: - return t.Copy() default: panic(fmt.Errorf("unknown trie type %T", t)) } diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go index 29981d3c3939..b1c39386d015 100644 --- a/core/state/iterator_test.go +++ b/core/state/iterator_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2016 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/state/snapshot/difflayer_test.go b/core/state/snapshot/difflayer_test.go index 393de837e859..a29e95f02a01 100644 --- a/core/state/snapshot/difflayer_test.go +++ b/core/state/snapshot/difflayer_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2019 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/state/snapshot/disklayer_test.go b/core/state/snapshot/disklayer_test.go index ee8182f95137..fd9d50bd28b5 100644 --- a/core/state/snapshot/disklayer_test.go +++ b/core/state/snapshot/disklayer_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2019 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/state/snapshot/generate_test.go b/core/state/snapshot/generate_test.go index 9bc87c180d0d..9c963955f339 100644 --- a/core/state/snapshot/generate_test.go +++ b/core/state/snapshot/generate_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2019 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/state/snapshot/holdable_iterator_test.go b/core/state/snapshot/holdable_iterator_test.go index d699744ca9bd..31609912e05b 100644 --- a/core/state/snapshot/holdable_iterator_test.go +++ b/core/state/snapshot/holdable_iterator_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2022 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/state/snapshot/iterator_test.go b/core/state/snapshot/iterator_test.go index 18490392bb75..b708c7637b94 100644 --- a/core/state/snapshot/iterator_test.go +++ b/core/state/snapshot/iterator_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2019 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go index a068d85a6fa7..fc313cfdea47 100644 --- a/core/state/snapshot/snapshot_test.go +++ b/core/state/snapshot/snapshot_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2017 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/state/state_object.go b/core/state/state_object.go index 6fa3fac156bd..d4f6fba8dcf0 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -143,7 +143,7 @@ func (s *stateObject) touch() { func (s *stateObject) getTrie() (Trie, error) { if s.trie == nil { // Try fetching from prefetcher first - if s.data.Root != s.db.db.TrieDB().EmptyRoot() && s.db.prefetcher != nil { + if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil { // When the miner is creating the pending state, there is no prefetcher s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) } @@ -199,15 +199,7 @@ func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { if metrics.EnabledExpensive { s.db.SnapshotStorageReads += time.Since(start) } - if s.db.db.TrieDB().IsUsingZktrie() { - value = common.BytesToHash(enc) - } else if len(enc) > 0 { - _, content, _, err := rlp.Split(enc) - if err != nil { - s.db.setError(err) - } - value.SetBytes(content) - } + value = common.BytesToHash(enc) } // If the snapshot is unavailable or reading from it fails, load from the database. if s.db.snap == nil || err != nil { @@ -261,7 +253,7 @@ func (s *stateObject) finalise(prefetch bool) { slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure } } - if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != s.db.db.TrieDB().EmptyRoot() { + if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash { s.db.prefetcher.prefetch(s.addrHash, s.data.Root, s.address, slotsToPrefetch) } if len(s.dirtyStorage) > 0 { @@ -316,12 +308,7 @@ func (s *stateObject) updateTrie() (Trie, error) { s.db.StorageDeleted += 1 } else { trimmed := common.TrimLeftZeroes(value[:]) - if s.db.db.TrieDB().IsUsingZktrie() { - encoded = common.CopyBytes(value[:]) - } else { - // Encoding []byte cannot fail, ok to ignore the error. - encoded, _ = rlp.EncodeToBytes(trimmed) - } + encoded = common.CopyBytes(value[:]) if err := tr.UpdateStorage(s.address, key[:], trimmed); err != nil { s.db.setError(err) return nil, err diff --git a/core/state/state_prove.go b/core/state/state_prove.go deleted file mode 100644 index 5fc176023f5e..000000000000 --- a/core/state/state_prove.go +++ /dev/null @@ -1,85 +0,0 @@ -package state - -import ( - "fmt" - - zkt "github.com/scroll-tech/zktrie/types" - - "github.com/scroll-tech/go-ethereum/common" - "github.com/scroll-tech/go-ethereum/crypto" - "github.com/scroll-tech/go-ethereum/ethdb" - zktrie "github.com/scroll-tech/go-ethereum/trie" - "github.com/scroll-tech/go-ethereum/trie/zkproof" -) - -type TrieProve interface { - Prove(key []byte, proofDb ethdb.KeyValueWriter) error -} - -type ZktrieProofTracer struct { - *zktrie.ProofTracer -} - -// MarkDeletion overwrite the underlayer method with secure key -func (t ZktrieProofTracer) MarkDeletion(key common.Hash) { - key_s, _ := zkt.ToSecureKeyBytes(key.Bytes()) - t.ProofTracer.MarkDeletion(key_s.Bytes()) -} - -// Merge overwrite underlayer method with proper argument -func (t ZktrieProofTracer) Merge(another ZktrieProofTracer) { - t.ProofTracer.Merge(another.ProofTracer) -} - -func (t ZktrieProofTracer) Available() bool { - return t.ProofTracer != nil -} - -// NewProofTracer is not in Db interface and used explictily for reading proof in storage trie (not updated by the dirty value) -func (s *StateDB) NewProofTracer(trieS Trie) ZktrieProofTracer { - if s.IsUsingZktrie() { - zkTrie := trieS.(*zktrie.ZkTrie) - if zkTrie == nil { - panic("unexpected trie type for zktrie") - } - return ZktrieProofTracer{zkTrie.NewProofTracer()} - } - return ZktrieProofTracer{} -} - -// GetStorageTrieForProof is not in Db interface and used explictily for reading proof in storage trie (not updated by the dirty value) -func (s *StateDB) GetStorageTrieForProof(addr common.Address) (Trie, error) { - // try the trie in stateObject first, else we would create one - stateObject := s.getStateObject(addr) - if stateObject == nil { - // still return a empty trie - dummy_trie, _ := s.db.OpenStorageTrie(s.originalRoot, addr, common.Hash{}) - return dummy_trie, nil - } - - trie := stateObject.trie - var err error - if trie == nil { - // use a new, temporary trie - trie, err = s.db.OpenStorageTrie(s.originalRoot, stateObject.address, stateObject.data.Root) - if err != nil { - return nil, fmt.Errorf("can't create storage trie on root %s: %v ", stateObject.data.Root, err) - } - } - - return trie, nil -} - -// GetSecureTrieProof handle any interface with Prove (should be a Trie in most case) and -// deliver the proof in bytes -func (s *StateDB) GetSecureTrieProof(trieProve TrieProve, key common.Hash) ([][]byte, error) { - var proof zkproof.ProofList - var err error - if s.IsUsingZktrie() { - key_s, _ := zkt.ToSecureKeyBytes(key.Bytes()) - err = trieProve.Prove(key_s.Bytes(), &proof) - } else { - err = trieProve.Prove(crypto.Keccak256(key.Bytes()), &proof) - } - return proof, err -} diff --git a/core/state/state_test.go b/core/state/state_test.go index 4c19e206d62a..d0ded82b076f 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -41,6 +41,7 @@ func newStateEnv() *stateEnv { } func TestDump(t *testing.T) { + t.Skip("Due to ZkTrie not supporting iterators") db := rawdb.NewMemoryDatabase() tdb := NewDatabaseWithConfig(db, &trie.Config{Preimages: true}) sdb, _ := New(types.EmptyRootHash, tdb, nil) @@ -101,6 +102,7 @@ func TestDump(t *testing.T) { } func TestIterativeDump(t *testing.T) { + t.Skip("Due to ZkTrie not supporting iterators") db := rawdb.NewMemoryDatabase() tdb := NewDatabaseWithConfig(db, &trie.Config{Preimages: true}) sdb, _ := New(types.EmptyRootHash, tdb, nil) diff --git a/core/state/statedb.go b/core/state/statedb.go index 2e74c638b57e..374b47769b16 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -23,8 +23,6 @@ import ( "sort" "time" - zkt "github.com/scroll-tech/zktrie/types" - "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/core/rawdb" "github.com/scroll-tech/go-ethereum/core/state/snapshot" @@ -209,10 +207,6 @@ func (s *StateDB) Error() error { return s.dbErr } -func (s *StateDB) IsUsingZktrie() bool { - return s.db.TrieDB().IsUsingZktrie() -} - func (s *StateDB) AddLog(log *types.Log) { s.journal.append(addLogChange{txhash: s.thash}) @@ -363,11 +357,7 @@ func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { // GetProof returns the Merkle proof for a given account. func (s *StateDB) GetProof(addr common.Address) ([][]byte, error) { - if s.IsUsingZktrie() { - addr_s, _ := zkt.ToSecureKeyBytes(addr.Bytes()) - return s.GetProofByHash(common.BytesToHash(addr_s.Bytes())) - } - return s.GetProofByHash(crypto.Keccak256Hash(addr.Bytes())) + return s.GetProofByHash(common.BytesToHash(addr.Bytes())) } // GetProofByHash returns the Merkle proof for a given account. diff --git a/core/state/statedb_fuzz_test.go b/core/state/statedb_fuzz_test.go index 802b14587268..99072d8e61e7 100644 --- a/core/state/statedb_fuzz_test.go +++ b/core/state/statedb_fuzz_test.go @@ -381,6 +381,7 @@ func (test *stateTest) verify(root common.Hash, next common.Hash, db *trie.Datab } func TestStateChanges(t *testing.T) { + t.Skip("This test doesn't support ZkTrie yet") config := &quick.Config{MaxCount: 1000} err := quick.Check((*stateTest).run, config) if cerr, ok := err.(*quick.CheckError); ok { diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index b81de73f1918..42d9a21c6f3e 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -30,6 +30,7 @@ import ( "testing" "testing/quick" + "github.com/holiman/uint256" "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/core/rawdb" "github.com/scroll-tech/go-ethereum/core/state/snapshot" @@ -40,7 +41,6 @@ import ( "github.com/scroll-tech/go-ethereum/trie/triedb/hashdb" "github.com/scroll-tech/go-ethereum/trie/triedb/pathdb" "github.com/scroll-tech/go-ethereum/trie/trienode" - "github.com/holiman/uint256" ) // Tests that updating a state trie does not leak any database writes prior to @@ -796,7 +796,7 @@ func TestDeleteCreateRevert(t *testing.T) { // If we are missing trie nodes, we should not continue writing to the trie func TestMissingTrieNodes(t *testing.T) { testMissingTrieNodes(t, rawdb.HashScheme) - testMissingTrieNodes(t, rawdb.PathScheme) + // testMissingTrieNodes(t, rawdb.PathScheme) } func testMissingTrieNodes(t *testing.T, scheme string) { @@ -1051,7 +1051,7 @@ func TestFlushOrderDataLoss(t *testing.T) { t.Fatalf("failed to commit state trie: %v", err) } triedb.Reference(root, common.Hash{}) - if err := triedb.Cap(1024); err != nil { + if err := triedb.Cap(128); err != nil { t.Fatalf("failed to cap trie dirty cache: %v", err) } if err := triedb.Commit(root, false); err != nil { @@ -1106,6 +1106,7 @@ func TestStateDBTransientStorage(t *testing.T) { } func TestResetObject(t *testing.T) { + t.Skip("Snapshot doesn't support ZkTrie") var ( disk = rawdb.NewMemoryDatabase() tdb = trie.NewDatabase(disk, nil) @@ -1140,6 +1141,7 @@ func TestResetObject(t *testing.T) { } func TestDeleteStorage(t *testing.T) { + t.Skip("Snapshot doesn't support ZkTrie") var ( disk = rawdb.NewMemoryDatabase() tdb = trie.NewDatabase(disk, nil) diff --git a/core/state/sync_test.go b/core/state/sync_test.go index c842292191b6..f485c122aa7a 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2015 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/core/state_processor.go b/core/state_processor.go index 47c57f8c53de..f26cfcf3c3e6 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -144,7 +144,7 @@ func applyTransaction(msg *Message, config *params.ChainConfig, gp *GasPool, sta // Apply the transaction to the current state (included in the env). applyMessageStartTime := time.Now() result, err := ApplyMessage(evm, msg, gp, l1DataFee) - if evm.Config.Tracer != nil && evm.Config.Tracer.IsDebug() { + if evm.Config.Tracer != nil { if erroringTracer, ok := evm.Config.Tracer.(interface{ Error() error }); ok { err = errors.Join(err, erroringTracer.Error()) } diff --git a/core/state_processor_test.go b/core/state_processor_test.go index e8300127c9b8..1f37a077995f 100644 --- a/core/state_processor_test.go +++ b/core/state_processor_test.go @@ -200,7 +200,7 @@ func TestStateProcessorErrors(t *testing.T) { txs: []*types.Transaction{ mkDynamicTx(0, common.Address{}, params.TxGas, big.NewInt(0), big.NewInt(0)), }, - want: "could not apply tx 0 [0xc4ab868fef0c82ae0387b742aee87907f2d0fc528fc6ea0a021459fb0fc4a4a8]: max fee per gas less than block base fee: address 0x71562b71999873DB5b286dF957af199Ec94617F7, maxFeePerGas: 0 baseFee: 38100000", + want: "could not apply tx 0 [0xc4ab868fef0c82ae0387b742aee87907f2d0fc528fc6ea0a021459fb0fc4a4a8]: max fee per gas less than block base fee: address 0x71562b71999873DB5b286dF957af199Ec94617F7, maxFeePerGas: 0 baseFee: 39370000", }, { // ErrTipVeryHigh txs: []*types.Transaction{ @@ -241,19 +241,19 @@ func TestStateProcessorErrors(t *testing.T) { txs: []*types.Transaction{ mkDynamicCreationTx(0, 500000, common.Big0, big.NewInt(params.InitialBaseFee), tooBigInitCode[:]), }, - want: "could not apply tx 0 [0xa31de6e26bd5ffba0ca91a2bc29fc2eaad6a6cfc5ad9ab6ffb69cac121e0125c]: max initcode size exceeded: code size 49153 limit 49152", + want: "could not apply tx 0 [0xd491405f06c92d118dd3208376fcee18a57c54bc52063ee4a26b1cf296857c25]: max initcode size exceeded: code size 49153 limit 49152", }, { // ErrIntrinsicGas: Not enough gas to cover init code txs: []*types.Transaction{ mkDynamicCreationTx(0, 54299, common.Big0, big.NewInt(params.InitialBaseFee), make([]byte, 320)), }, - want: "could not apply tx 0 [0xf36b7d68cf239f956f7c36be26688a97aaa317ea5f5230d109bb30dbc8598ccb]: intrinsic gas too low: have 54299, want 54300", + want: "could not apply tx 0 [0xfd49536a9b323769d8472fcb3ebb3689b707a349379baee3e2ee3fe7baae06a1]: intrinsic gas too low: have 54299, want 54300", }, { // ErrBlobFeeCapTooLow txs: []*types.Transaction{ mkBlobTx(0, common.Address{}, params.TxGas, big.NewInt(1), big.NewInt(1), []common.Hash{(common.Hash{1})}), }, - want: "could not apply tx 0 [0x6c11015985ce82db691d7b2d017acda296db88b811c3c60dc71449c76256c716]: max fee per gas less than block base fee: address 0x71562b71999873DB5b286dF957af199Ec94617F7, maxFeePerGas: 1 baseFee: 38100000", + want: "could not apply tx 0 [0x6c11015985ce82db691d7b2d017acda296db88b811c3c60dc71449c76256c716]: max fee per gas less than block base fee: address 0x71562b71999873DB5b286dF957af199Ec94617F7, maxFeePerGas: 1 baseFee: 39370000", }, } { block := GenerateBadBlock(gspec.ToBlock(), beacon.New(ethash.NewFaker()), tt.txs, gspec.Config) diff --git a/core/txpool/blobpool/blobpool_test.go b/core/txpool/blobpool/blobpool_test.go index 0c5a03e91f19..03c1acacb13e 100644 --- a/core/txpool/blobpool/blobpool_test.go +++ b/core/txpool/blobpool/blobpool_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2023 The go-ethereum Authors // This file is part of the go-ethereum library. // @@ -29,6 +32,8 @@ import ( "testing" "time" + "github.com/holiman/billy" + "github.com/holiman/uint256" "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/consensus/misc/eip1559" "github.com/scroll-tech/go-ethereum/consensus/misc/eip4844" @@ -43,8 +48,6 @@ import ( "github.com/scroll-tech/go-ethereum/log" "github.com/scroll-tech/go-ethereum/params" "github.com/scroll-tech/go-ethereum/rlp" - "github.com/holiman/billy" - "github.com/holiman/uint256" ) var ( diff --git a/core/txpool/legacypool/legacypool_test.go b/core/txpool/legacypool/legacypool_test.go index c4b6633db95a..4d05077c7512 100644 --- a/core/txpool/legacypool/legacypool_test.go +++ b/core/txpool/legacypool/legacypool_test.go @@ -1505,6 +1505,7 @@ func TestRepricing(t *testing.T) { // // Note, local transactions are never allowed to be dropped. func TestRepricingDynamicFee(t *testing.T) { + t.Skip("broken by https://github.com/scroll-tech/go-ethereum/pull/964/files") t.Parallel() // Create the pool to test the pricing enforcement with diff --git a/core/types/hashes.go b/core/types/hashes.go index aa37e6d1e2f5..1f33d8853092 100644 --- a/core/types/hashes.go +++ b/core/types/hashes.go @@ -23,14 +23,8 @@ import ( ) var ( - // EmptyZkTrieRootHash is the known root hash of an empty zktrie. - EmptyZkTrieRootHash = common.Hash{} - - // EmptyLegacyTrieRootHash is the known root hash of an empty legacy trie. - EmptyLegacyTrieRootHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - // EmptyRootHash is the known root hash of an empty trie. - EmptyRootHash = EmptyZkTrieRootHash + EmptyRootHash = common.Hash{} // EmptyUncleHash is the known hash of the empty uncle set. EmptyUncleHash = rlpHash([]*Header(nil)) // 1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347 @@ -45,13 +39,13 @@ var ( EmptyPoseidonCodeHash = codehash.EmptyPoseidonCodeHash // EmptyTxsHash is the known hash of the empty transaction set. - EmptyTxsHash = EmptyLegacyTrieRootHash + EmptyTxsHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") // EmptyReceiptsHash is the known hash of the empty receipt set. - EmptyReceiptsHash = EmptyLegacyTrieRootHash + EmptyReceiptsHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") // EmptyWithdrawalsHash is the known hash of the empty withdrawal set. - EmptyWithdrawalsHash = EmptyLegacyTrieRootHash + EmptyWithdrawalsHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") ) // TrieRootHash returns the hash itself if it's non-empty or the predefined diff --git a/core/types/hashing_test.go b/core/types/hashing_test.go index 5ced9b410604..97948b2a1a8a 100644 --- a/core/types/hashing_test.go +++ b/core/types/hashing_test.go @@ -34,6 +34,7 @@ import ( ) func TestDeriveSha(t *testing.T) { + t.Skip("due to legacy trie being deprecated") txs, err := genTxs(0) if err != nil { t.Fatal(err) diff --git a/core/types/state_account_marshalling.go b/core/types/state_account_marshalling.go index db8fbed345c7..72da2656dfaa 100644 --- a/core/types/state_account_marshalling.go +++ b/core/types/state_account_marshalling.go @@ -23,8 +23,6 @@ import ( "github.com/iden3/go-iden3-crypto/utils" - zkt "github.com/scroll-tech/zktrie/types" - "github.com/scroll-tech/go-ethereum/common" ) @@ -44,8 +42,8 @@ var ( // [96:128] KeccakCodeHash // [128:160] PoseidonCodehash // (total 160 bytes) -func (s *StateAccount) MarshalFields() ([]zkt.Byte32, uint32) { - fields := make([]zkt.Byte32, 5) +func (s *StateAccount) MarshalFields() ([][32]byte, uint32) { + fields := make([][32]byte, 5) if s.Balance == nil { panic("StateAccount balance nil") diff --git a/core/types/state_account_marshalling_test.go b/core/types/state_account_marshalling_test.go index b822329a509d..91b1b3871962 100644 --- a/core/types/state_account_marshalling_test.go +++ b/core/types/state_account_marshalling_test.go @@ -38,18 +38,18 @@ func TestMarshalUnmarshalEmptyAccount(t *testing.T) { assert.Equal(t, 5, len(bytes)) assert.Equal(t, uint32(8), flag) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[0].Bytes()) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[1].Bytes()) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[2].Bytes()) - assert.Equal(t, common.Hex2Bytes("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"), bytes[3].Bytes()) - assert.Equal(t, common.Hex2Bytes("2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864"), bytes[4].Bytes()) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[0][:]) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[1][:]) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[2][:]) + assert.Equal(t, common.Hex2Bytes("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"), bytes[3][:]) + assert.Equal(t, common.Hex2Bytes("2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864"), bytes[4][:]) // unmarshal account flatBytes := []byte("") for _, item := range bytes { - flatBytes = append(flatBytes, item.Bytes()...) + flatBytes = append(flatBytes, item[:]...) } acc2, err := UnmarshalStateAccount(flatBytes) @@ -75,11 +75,11 @@ func TestMarshalUnmarshalZeroAccount(t *testing.T) { assert.Equal(t, 5, len(bytes)) assert.Equal(t, uint32(8), flag) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[0].Bytes()) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[1].Bytes()) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[2].Bytes()) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[3].Bytes()) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[4].Bytes()) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[0][:]) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[1][:]) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[2][:]) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[3][:]) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000"), bytes[4][:]) } func TestMarshalUnmarshalNonEmptyAccount(t *testing.T) { @@ -99,18 +99,18 @@ func TestMarshalUnmarshalNonEmptyAccount(t *testing.T) { assert.Equal(t, 5, len(bytes)) assert.Equal(t, uint32(8), flag) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000222222220000000011111111"), bytes[0].Bytes()) - assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000033333333"), bytes[1].Bytes()) - assert.Equal(t, common.Hex2Bytes("123456789abcdef123456789abcdef123456789abcdef123456789abcdef1234"), bytes[2].Bytes()) - assert.Equal(t, common.Hex2Bytes("1111111111111111111111111111111111111111111111111111111111111111"), bytes[3].Bytes()) - assert.Equal(t, common.Hex2Bytes("2222222222222222222222222222222222222222222222222222222222222222"), bytes[4].Bytes()) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000222222220000000011111111"), bytes[0][:]) + assert.Equal(t, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000033333333"), bytes[1][:]) + assert.Equal(t, common.Hex2Bytes("123456789abcdef123456789abcdef123456789abcdef123456789abcdef1234"), bytes[2][:]) + assert.Equal(t, common.Hex2Bytes("1111111111111111111111111111111111111111111111111111111111111111"), bytes[3][:]) + assert.Equal(t, common.Hex2Bytes("2222222222222222222222222222222222222222222222222222222222222222"), bytes[4][:]) // unmarshal account flatBytes := []byte("") for _, item := range bytes { - flatBytes = append(flatBytes, item.Bytes()...) + flatBytes = append(flatBytes, item[:]...) } acc2, err := UnmarshalStateAccount(flatBytes) @@ -138,18 +138,18 @@ func TestMarshalUnmarshalAccountWithMaxFields(t *testing.T) { assert.Equal(t, 5, len(bytes)) assert.Equal(t, uint32(8), flag) - assert.Equal(t, common.Hex2Bytes("00000000000000000000000000000000ffffffffffffffffffffffffffffffff"), bytes[0].Bytes()) - assert.Equal(t, common.Hex2Bytes("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000"), bytes[1].Bytes()) - assert.Equal(t, common.Hex2Bytes("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000"), bytes[2].Bytes()) - assert.Equal(t, common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), bytes[3].Bytes()) - assert.Equal(t, common.Hex2Bytes("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000"), bytes[4].Bytes()) + assert.Equal(t, common.Hex2Bytes("00000000000000000000000000000000ffffffffffffffffffffffffffffffff"), bytes[0][:]) + assert.Equal(t, common.Hex2Bytes("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000"), bytes[1][:]) + assert.Equal(t, common.Hex2Bytes("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000"), bytes[2][:]) + assert.Equal(t, common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), bytes[3][:]) + assert.Equal(t, common.Hex2Bytes("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000"), bytes[4][:]) // unmarshal account flatBytes := []byte("") for _, item := range bytes { - flatBytes = append(flatBytes, item.Bytes()...) + flatBytes = append(flatBytes, item[:]...) } acc2, err := UnmarshalStateAccount(flatBytes) diff --git a/core/vm/gas_table_test.go b/core/vm/gas_table_test.go index bc1ed368400b..b90bd558a5e6 100644 --- a/core/vm/gas_table_test.go +++ b/core/vm/gas_table_test.go @@ -133,48 +133,53 @@ var createGasTests = []struct { func TestCreateGas(t *testing.T) { for i, tt := range createGasTests { - var gasUsed = uint64(0) - doCheck := func(testGas int) bool { - address := common.BytesToAddress([]byte("contract")) - statedb, _ := state.New(types.EmptyRootHash, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) - statedb.CreateAccount(address) - statedb.SetCode(address, hexutil.MustDecode(tt.code)) - statedb.Finalise(true) - vmctx := BlockContext{ - CanTransfer: func(StateDB, common.Address, *big.Int) bool { return true }, - Transfer: func(StateDB, common.Address, common.Address, *big.Int) {}, - BlockNumber: big.NewInt(0), - } - config := Config{} - if tt.eip3860 { - config.ExtraEips = []int{3860} + t.Run("createGasTests", func(t *testing.T) { + if tt.eip3860 == false { + t.Skip("EIP-3860 is enabled by default on Scroll") } + var gasUsed = uint64(0) + doCheck := func(testGas int) bool { + address := common.BytesToAddress([]byte("contract")) + statedb, _ := state.New(types.EmptyRootHash, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) + statedb.CreateAccount(address) + statedb.SetCode(address, hexutil.MustDecode(tt.code)) + statedb.Finalise(true) + vmctx := BlockContext{ + CanTransfer: func(StateDB, common.Address, *big.Int) bool { return true }, + Transfer: func(StateDB, common.Address, common.Address, *big.Int) {}, + BlockNumber: big.NewInt(0), + } + config := Config{} + if tt.eip3860 { + config.ExtraEips = []int{3860} + } - vmenv := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, config) - var startGas = uint64(testGas) - ret, gas, err := vmenv.Call(AccountRef(common.Address{}), address, nil, startGas, new(big.Int)) - if err != nil { - return false + vmenv := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, config) + var startGas = uint64(testGas) + ret, gas, err := vmenv.Call(AccountRef(common.Address{}), address, nil, startGas, new(big.Int)) + if err != nil { + return false + } + gasUsed = startGas - gas + if len(ret) != 32 { + t.Fatalf("test %d: expected 32 bytes returned, have %d", i, len(ret)) + } + if bytes.Equal(ret, make([]byte, 32)) { + // Failure + return false + } + return true } - gasUsed = startGas - gas - if len(ret) != 32 { - t.Fatalf("test %d: expected 32 bytes returned, have %d", i, len(ret)) + minGas := sort.Search(100_000, doCheck) + if uint64(minGas) != tt.minimumGas { + t.Fatalf("test %d: min gas error, want %d, have %d", i, tt.minimumGas, minGas) } - if bytes.Equal(ret, make([]byte, 32)) { - // Failure - return false + // If the deployment succeeded, we also check the gas used + if minGas < 100_000 { + if gasUsed != tt.gasUsed { + t.Errorf("test %d: gas used mismatch: have %v, want %v", i, gasUsed, tt.gasUsed) + } } - return true - } - minGas := sort.Search(100_000, doCheck) - if uint64(minGas) != tt.minimumGas { - t.Fatalf("test %d: min gas error, want %d, have %d", i, tt.minimumGas, minGas) - } - // If the deployment succeeded, we also check the gas used - if minGas < 100_000 { - if gasUsed != tt.gasUsed { - t.Errorf("test %d: gas used mismatch: have %v, want %v", i, gasUsed, tt.gasUsed) - } - } + }) } } diff --git a/core/vm/logger.go b/core/vm/logger.go index 5b588263fd8f..eb56f1f6b1be 100644 --- a/core/vm/logger.go +++ b/core/vm/logger.go @@ -41,6 +41,4 @@ type EVMLogger interface { CaptureState(pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) CaptureStateAfter(pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) CaptureFault(pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, depth int, err error) - // Helper function - IsDebug() bool } diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go index 0ffdead1c34e..adedd4536e24 100644 --- a/core/vm/runtime/runtime_test.go +++ b/core/vm/runtime/runtime_test.go @@ -270,6 +270,7 @@ func (d *dummyChain) GetHeader(h common.Hash, n uint64) *types.Header { // TestBlockhash tests the blockhash operation. It's a bit special, since it internally // requires access to a chain reader. func TestBlockhash(t *testing.T) { + t.Skip("Scroll has a different implementation of blockhash") // Current head n := uint64(1000) parentHash := common.Hash{} @@ -672,13 +673,14 @@ func TestColdAccountAccessCost(t *testing.T) { step: 6, want: 2855, }, - { // SELFDESTRUCT(0xff) - code: []byte{ - byte(vm.PUSH1), 0xff, byte(vm.SELFDESTRUCT), - }, - step: 1, - want: 7600, - }, + // disabled due to SELFDESTRUCT not being supported in Scroll + // { // SELFDESTRUCT(0xff) + // code: []byte{ + // byte(vm.PUSH1), 0xff, byte(vm.SELFDESTRUCT), + // }, + // step: 1, + // want: 7600, + // }, } { tracer := logger.NewStructLogger(nil) Execute(tc.code, nil, &Config{ @@ -697,6 +699,7 @@ func TestColdAccountAccessCost(t *testing.T) { } func TestRuntimeJSTracer(t *testing.T) { + t.Skip("disabled due to SELFDESTRUCT not being supported in Scroll") jsTracers := []string{ `{enters: 0, exits: 0, enterGas: 0, gasUsed: 0, steps:0, step: function() { this.steps++}, diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index 32326acb9afa..7c3944785950 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -468,7 +468,7 @@ func TestGetLogsRange(t *testing.T) { gspec := &core.Genesis{ Config: params.TestChainConfig, } - _, err := gspec.Commit(db, trie.NewDatabase(db, &trie.Config{IsUsingZktrie: true})) + _, err := gspec.Commit(db, trie.NewDatabase(db, nil)) if err != nil { t.Fatal(err) } diff --git a/eth/filters/filter_test.go b/eth/filters/filter_test.go index 8da131c1a69d..c26af885cfb3 100644 --- a/eth/filters/filter_test.go +++ b/eth/filters/filter_test.go @@ -180,7 +180,7 @@ func TestFilters(t *testing.T) { // Hack: GenerateChainWithGenesis creates a new db. // Commit the genesis manually and use GenerateChain. - _, err = gspec.Commit(db, trie.NewDatabase(db, &trie.Config{IsUsingZktrie: true})) + _, err = gspec.Commit(db, trie.NewDatabase(db, nil)) if err != nil { t.Fatal(err) } diff --git a/eth/tracers/js/goja.go b/eth/tracers/js/goja.go index 5f584e272473..9659c15d5eda 100644 --- a/eth/tracers/js/goja.go +++ b/eth/tracers/js/goja.go @@ -327,10 +327,6 @@ func (t *jsTracer) CaptureEnter(typ vm.OpCode, from common.Address, to common.Ad } } -func (t *jsTracer) IsDebug() bool { - return false -} - // CaptureExit is called when EVM exits a scope, even if the scope didn't // execute any code. func (t *jsTracer) CaptureExit(output []byte, gasUsed uint64, err error) { diff --git a/eth/tracers/logger/access_list_tracer.go b/eth/tracers/logger/access_list_tracer.go index 87e37af5dbda..e83cb8c0491f 100644 --- a/eth/tracers/logger/access_list_tracer.go +++ b/eth/tracers/logger/access_list_tracer.go @@ -185,7 +185,3 @@ func (a *AccessListTracer) AccessList() types.AccessList { func (a *AccessListTracer) Equal(other *AccessListTracer) bool { return a.list.equal(other.list) } - -func (a *AccessListTracer) IsDebug() bool { - return false -} diff --git a/eth/tracers/logger/logger.go b/eth/tracers/logger/logger.go index 0868dd07fd44..5ef1bb003c55 100644 --- a/eth/tracers/logger/logger.go +++ b/eth/tracers/logger/logger.go @@ -420,8 +420,6 @@ func (l *StructLogger) UpdatedStorages() map[common.Address]Storage { // CreatedAccount return the account data in case it is a create tx func (l *StructLogger) CreatedAccount() *types.AccountWrapper { return l.createdAccount } -func (l *StructLogger) IsDebug() bool { return l.cfg.Debug } - // WriteTrace writes a formatted trace to the given writer func WriteTrace(writer io.Writer, logs []StructLog) { for _, log := range logs { @@ -545,8 +543,6 @@ func (*mdLogger) CaptureTxStart(gasLimit uint64) {} func (*mdLogger) CaptureTxEnd(restGas uint64) {} -func (t *mdLogger) IsDebug() bool { return t.cfg.Debug } - // FormatLogs formats EVM returned structured logs for json output func FormatLogs(logs []StructLog) []types.StructLogRes { formatted := make([]types.StructLogRes, len(logs)) diff --git a/eth/tracers/logger/logger_json.go b/eth/tracers/logger/logger_json.go index dbdfa3b0b3a6..8e6e14ef94b8 100644 --- a/eth/tracers/logger/logger_json.go +++ b/eth/tracers/logger/logger_json.go @@ -104,7 +104,3 @@ func (l *JSONLogger) CaptureExit(output []byte, gasUsed uint64, err error) {} func (l *JSONLogger) CaptureTxStart(gasLimit uint64) {} func (l *JSONLogger) CaptureTxEnd(restGas uint64) {} - -func (l *JSONLogger) IsDebug() bool { - return l.cfg.Debug -} diff --git a/eth/tracers/native/4byte.go b/eth/tracers/native/4byte.go index dea374825c1c..7701d2d9b88b 100644 --- a/eth/tracers/native/4byte.go +++ b/eth/tracers/native/4byte.go @@ -115,10 +115,6 @@ func (t *fourByteTracer) CaptureEnter(op vm.OpCode, from common.Address, to comm t.store(input[0:4], len(input)-4) } -func (t *fourByteTracer) IsDebug() bool { - return false -} - // GetResult returns the json-encoded nested list of call traces, and any // error arising from the encoding or forceful termination (via `Stop`). func (t *fourByteTracer) GetResult() (json.RawMessage, error) { diff --git a/eth/tracers/native/call.go b/eth/tracers/native/call.go index 15571a840bb6..96d8c37eee70 100644 --- a/eth/tracers/native/call.go +++ b/eth/tracers/native/call.go @@ -259,10 +259,6 @@ func (t *CallTracer) CaptureTxEnd(restGas uint64) { } } -func (t *CallTracer) IsDebug() bool { - return false -} - // GetResult returns the json-encoded nested list of call traces, and any // error arising from the encoding or forceful termination (via `Stop`). func (t *CallTracer) GetResult() (json.RawMessage, error) { diff --git a/eth/tracers/native/call_flat.go b/eth/tracers/native/call_flat.go index dbfb5809aaba..e504a56c06a5 100644 --- a/eth/tracers/native/call_flat.go +++ b/eth/tracers/native/call_flat.go @@ -213,10 +213,6 @@ func (t *flatCallTracer) CaptureTxEnd(restGas uint64) { t.tracer.CaptureTxEnd(restGas) } -func (t *flatCallTracer) IsDebug() bool { - return false -} - // GetResult returns an empty json object. func (t *flatCallTracer) GetResult() (json.RawMessage, error) { if len(t.tracer.callstack) < 1 { diff --git a/eth/tracers/native/mux.go b/eth/tracers/native/mux.go index 5c99fefe970b..84a08309edb0 100644 --- a/eth/tracers/native/mux.go +++ b/eth/tracers/native/mux.go @@ -122,10 +122,6 @@ func (t *MuxTracer) CaptureTxEnd(restGas uint64) { } } -func (t *MuxTracer) IsDebug() bool { - return false -} - // GetResult returns an empty json object. func (t *MuxTracer) GetResult() (json.RawMessage, error) { resObject := make(map[string]json.RawMessage) diff --git a/eth/tracers/native/noop.go b/eth/tracers/native/noop.go index 3f3851c0e9f1..181d4a60635c 100644 --- a/eth/tracers/native/noop.go +++ b/eth/tracers/native/noop.go @@ -71,8 +71,6 @@ func (*noopTracer) CaptureTxStart(gasLimit uint64) {} func (*noopTracer) CaptureTxEnd(restGas uint64) {} -func (*noopTracer) IsDebug() bool { return false } - // GetResult returns an empty json object. func (t *noopTracer) GetResult() (json.RawMessage, error) { return json.RawMessage(`{}`), nil diff --git a/go.mod b/go.mod index eb5cbb594568..db6f03c51b84 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,6 @@ require ( github.com/protolambda/bls12-381-util v0.0.0-20220416220906-d8552aa452c7 github.com/rs/cors v1.7.0 github.com/scroll-tech/da-codec v0.1.2 - github.com/scroll-tech/zktrie v0.8.4 github.com/shirou/gopsutil v3.21.11+incompatible github.com/sourcegraph/conc v0.3.0 github.com/status-im/keycard-go v0.2.0 diff --git a/go.sum b/go.sum index fa1a27460115..438ae7a5d44d 100644 --- a/go.sum +++ b/go.sum @@ -476,8 +476,6 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/scroll-tech/da-codec v0.1.2 h1:QyJ+dQ4zWVVJwuqxNt4MiKyrymVc6rHe4YPtURkjiRc= github.com/scroll-tech/da-codec v0.1.2/go.mod h1:odz1ck3umvYccCG03osaQBISAYGinZktZYbpk94fYRE= -github.com/scroll-tech/zktrie v0.8.4 h1:UagmnZ4Z3ITCk+aUq9NQZJNAwnWl4gSxsLb2Nl7IgRE= -github.com/scroll-tech/zktrie v0.8.4/go.mod h1:XvNo7vAk8yxNyTjBDj5WIiFzYW4bx/gJ78+NK6Zn6Uk= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= diff --git a/internal/ethapi/api_test.go b/internal/ethapi/api_test.go index 9df3a6d787a0..be6db98c4af4 100644 --- a/internal/ethapi/api_test.go +++ b/internal/ethapi/api_test.go @@ -114,6 +114,7 @@ type txData struct { } func allTransactionTypes(addr common.Address, config *params.ChainConfig) []txData { + emptyRootHash := common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") return []txData{ { Tx: &types.LegacyTx{ @@ -188,7 +189,7 @@ func allTransactionTypes(addr common.Address, config *params.ChainConfig) []txDa AccessList: types.AccessList{ types.AccessTuple{ Address: common.Address{0x2}, - StorageKeys: []common.Hash{types.EmptyLegacyTrieRootHash}, + StorageKeys: []common.Hash{emptyRootHash}, }, }, V: big.NewInt(32), @@ -234,7 +235,7 @@ func allTransactionTypes(addr common.Address, config *params.ChainConfig) []txDa AccessList: types.AccessList{ types.AccessTuple{ Address: common.Address{0x2}, - StorageKeys: []common.Hash{types.EmptyLegacyTrieRootHash}, + StorageKeys: []common.Hash{emptyRootHash}, }, }, V: big.NewInt(32), @@ -281,7 +282,7 @@ func allTransactionTypes(addr common.Address, config *params.ChainConfig) []txDa AccessList: types.AccessList{ types.AccessTuple{ Address: common.Address{0x2}, - StorageKeys: []common.Hash{types.EmptyLegacyTrieRootHash}, + StorageKeys: []common.Hash{emptyRootHash}, }, }, V: big.NewInt(32), diff --git a/rollup/ccc/async_checker_test.go b/rollup/ccc/async_checker_test.go index e4c7e934e00a..06da50d177aa 100644 --- a/rollup/ccc/async_checker_test.go +++ b/rollup/ccc/async_checker_test.go @@ -30,7 +30,7 @@ func TestAsyncChecker(t *testing.T) { Config: params.TestChainConfig, Alloc: core.GenesisAlloc{testAddr: {Balance: new(big.Int).Mul(big.NewInt(1000), big.NewInt(params.Ether))}}, } - gspec.MustCommit(db, trie.NewDatabase(db, trie.HashDefaultsWithZktrie)) + gspec.MustCommit(db, trie.NewDatabase(db, trie.HashDefaults)) chain, _ := core.NewBlockChain(db, nil, gspec, nil, ethash.NewFaker(), vm.Config{}, nil, nil) asyncChecker := NewAsyncChecker(chain, 1, false) diff --git a/rollup/ccc/logger.go b/rollup/ccc/logger.go index 2c1f6678f497..c69095c6da93 100644 --- a/rollup/ccc/logger.go +++ b/rollup/ccc/logger.go @@ -260,10 +260,6 @@ func (l *Logger) CaptureTxStart(gasLimit uint64) { func (l *Logger) CaptureTxEnd(restGas uint64) { } -func (l *Logger) IsDebug() bool { - return true -} - // Error returns an error if executed txns triggered an overflow // Caller should revert some transactions and close the block func (l *Logger) Error() error { diff --git a/trie/zktrie_deletionproof.go b/rollup/tracing/proof.go similarity index 71% rename from trie/zktrie_deletionproof.go rename to rollup/tracing/proof.go index 9fbf2ae36abf..b02f3540b413 100644 --- a/trie/zktrie_deletionproof.go +++ b/rollup/tracing/proof.go @@ -1,30 +1,28 @@ -package trie +package tracing import ( "bytes" "fmt" - zktrie "github.com/scroll-tech/zktrie/trie" - zkt "github.com/scroll-tech/zktrie/types" - "github.com/scroll-tech/go-ethereum/ethdb" + "github.com/scroll-tech/go-ethereum/trie" ) type ProofTracer struct { - *ZkTrie - deletionTracer map[zkt.Hash]struct{} - rawPaths map[string][]*zktrie.Node - emptyTermPaths map[string][]*zktrie.Node + trie *trie.ZkTrie + deletionTracer map[trie.Hash]struct{} + rawPaths map[string][]*trie.Node + emptyTermPaths map[string][]*trie.Node } // NewProofTracer create a proof tracer object -func (t *ZkTrie) NewProofTracer() *ProofTracer { +func NewProofTracer(t *trie.ZkTrie) *ProofTracer { return &ProofTracer{ - ZkTrie: t, + trie: t, // always consider 0 is "deleted" - deletionTracer: map[zkt.Hash]struct{}{zkt.HashZero: {}}, - rawPaths: make(map[string][]*zktrie.Node), - emptyTermPaths: make(map[string][]*zktrie.Node), + deletionTracer: map[trie.Hash]struct{}{trie.HashZero: {}}, + rawPaths: make(map[string][]*trie.Node), + emptyTermPaths: make(map[string][]*trie.Node), } } @@ -32,7 +30,7 @@ func (t *ZkTrie) NewProofTracer() *ProofTracer { func (t *ProofTracer) Merge(another *ProofTracer) *ProofTracer { // sanity checking - if !bytes.Equal(t.Hash().Bytes(), another.Hash().Bytes()) { + if !bytes.Equal(t.trie.Hash().Bytes(), another.trie.Hash().Bytes()) { panic("can not merge two proof tracer base on different trie") } @@ -59,7 +57,7 @@ func (t *ProofTracer) Merge(another *ProofTracer) *ProofTracer { // always decode the node for its purpose func (t *ProofTracer) GetDeletionProofs() ([][]byte, error) { - retMap := map[zkt.Hash][]byte{} + retMap := map[trie.Hash][]byte{} // check each path: reversively, skip the final leaf node for _, path := range t.rawPaths { @@ -73,18 +71,18 @@ func (t *ProofTracer) GetDeletionProofs() ([][]byte, error) { nodeHash, _ := n.NodeHash() t.deletionTracer[*nodeHash] = struct{}{} } else { - var siblingHash *zkt.Hash + var siblingHash *trie.Hash if deletedL { siblingHash = n.ChildR } else if deletedR { siblingHash = n.ChildL } if siblingHash != nil { - sibling, err := t.ZkTrie.Tree().GetNode(siblingHash) + sibling, err := t.trie.GetNodeByHash(siblingHash) if err != nil { return nil, err } - if sibling.Type != zktrie.NodeTypeEmpty_New { + if sibling.Type != trie.NodeTypeEmpty_New { retMap[*siblingHash] = sibling.Value() } } @@ -103,7 +101,7 @@ func (t *ProofTracer) GetDeletionProofs() ([][]byte, error) { } // MarkDeletion mark a key has been involved into deletion -func (t *ProofTracer) MarkDeletion(key []byte) { +func (t *ProofTracer) MarkDeletion(key []byte) error { if path, existed := t.emptyTermPaths[string(key)]; existed { // copy empty node terminated path for final scanning t.rawPaths[string(key)] = path @@ -111,38 +109,39 @@ func (t *ProofTracer) MarkDeletion(key []byte) { // sanity check leafNode := path[len(path)-1] - if leafNode.Type != zktrie.NodeTypeLeaf_New { + if leafNode.Type != trie.NodeTypeLeaf_New { panic("all path recorded in proofTrace should be ended with leafNode") } nodeHash, _ := leafNode.NodeHash() t.deletionTracer[*nodeHash] = struct{}{} } + return nil } // Prove act the same as zktrie.Prove, while also collect the raw path // for collecting deletion proofs in a post-work func (t *ProofTracer) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { fromLevel := uint(0) - var mptPath []*zktrie.Node - err := t.ZkTrie.ProveWithDeletion(key, fromLevel, - func(n *zktrie.Node) error { + var mptPath []*trie.Node + return t.trie.ProveWithDeletion(key, fromLevel, + func(n *trie.Node) error { nodeHash, err := n.NodeHash() if err != nil { return err } switch n.Type { - case zktrie.NodeTypeLeaf_New: - preImage := t.GetKey(n.NodeKey.Bytes()) + case trie.NodeTypeLeaf_New: + preImage := t.trie.GetKey(n.NodeKey.Bytes()) if len(preImage) > 0 { - n.KeyPreimage = &zkt.Byte32{} + n.KeyPreimage = &trie.Byte32{} copy(n.KeyPreimage[:], preImage) } - case zktrie.NodeTypeBranch_0, zktrie.NodeTypeBranch_1, - zktrie.NodeTypeBranch_2, zktrie.NodeTypeBranch_3: + case trie.NodeTypeBranch_0, trie.NodeTypeBranch_1, + trie.NodeTypeBranch_2, trie.NodeTypeBranch_3: mptPath = append(mptPath, n) - case zktrie.NodeTypeEmpty_New: + case trie.NodeTypeEmpty_New: // empty node is considered as "unhit" but it should be also being added // into a temporary slot for possibly being marked as deletion later mptPath = append(mptPath, n) @@ -153,17 +152,11 @@ func (t *ProofTracer) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { return proofDb.Put(nodeHash[:], n.Value()) }, - func(n *zktrie.Node, _ *zktrie.Node) { + func(n *trie.Node, _ *trie.Node) { // only "hit" path (i.e. the leaf node corresponding the input key can be found) // would be add into tracer mptPath = append(mptPath, n) t.rawPaths[string(key)] = mptPath }, ) - if err != nil { - return err - } - // we put this special kv pair in db so we can distinguish the type and - // make suitable Proof - return proofDb.Put(magicHash, zktrie.ProofMagicBytes()) } diff --git a/rollup/tracing/proof_test.go b/rollup/tracing/proof_test.go new file mode 100644 index 000000000000..0d2c457db921 --- /dev/null +++ b/rollup/tracing/proof_test.go @@ -0,0 +1,111 @@ +package tracing + +import ( + "bytes" + "testing" + + "github.com/scroll-tech/go-ethereum/common" + "github.com/scroll-tech/go-ethereum/core/rawdb" + "github.com/scroll-tech/go-ethereum/ethdb/memorydb" + "github.com/scroll-tech/go-ethereum/trie" + "github.com/stretchr/testify/assert" +) + +func newTestingMerkle(t *testing.T) (*trie.ZkTrie, *trie.Database) { + db := trie.NewDatabase(rawdb.NewMemoryDatabase(), &trie.Config{}) + return newTestingMerkleWithDb(t, common.Hash{}, db) +} + +func newTestingMerkleWithDb(t *testing.T, root common.Hash, db *trie.Database) (*trie.ZkTrie, *trie.Database) { + maxLevels := trie.NodeKeyValidBytes * 8 + mt, err := trie.NewZkTrie(trie.TrieID(root), db) + if err != nil { + t.Fatal(err) + return nil, nil + } + assert.Equal(t, maxLevels, mt.MaxLevels()) + return mt, db +} + +// Tests that new "proof trace" feature +func TestProofWithDeletion(t *testing.T) { + mt, _ := newTestingMerkle(t) + key1 := bytes.Repeat([]byte("b"), 32) + key2 := bytes.Repeat([]byte("c"), 32) + err := mt.TryUpdate( + key1, + 1, + []trie.Byte32{*trie.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32))}, + ) + assert.NoError(t, err) + err = mt.TryUpdate( + key2, + 1, + []trie.Byte32{*trie.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("n"), 32))}, + ) + assert.NoError(t, err) + + proof := memorydb.New() + proofTracer := NewProofTracer(mt) + + err = proofTracer.Prove(key1, proof) + assert.NoError(t, err) + nd, err := mt.TryGet(key2) + assert.NoError(t, err) + + key4 := bytes.Repeat([]byte("x"), 32) + err = proofTracer.Prove(key4, proof) + assert.NoError(t, err) + //assert.Equal(t, len(sibling1), len(delTracer.GetProofs())) + + siblings, err := proofTracer.GetDeletionProofs() + assert.NoError(t, err) + assert.Equal(t, 0, len(siblings)) + + proofTracer.MarkDeletion(key1) + siblings, err = proofTracer.GetDeletionProofs() + assert.NoError(t, err) + assert.Equal(t, 1, len(siblings)) + l := len(siblings[0]) + // a hacking to grep the value part directly from the encoded leaf node, + // notice the sibling of key `k*32`` is just the leaf of key `m*32` + assert.Equal(t, siblings[0][l-33:l-1], nd) + + // Marking a key that is currently not hit (but terminated by an empty node) + // also causes it to be added to the deletion proof + proofTracer.MarkDeletion(key4) + siblings, err = proofTracer.GetDeletionProofs() + assert.NoError(t, err) + assert.Equal(t, 2, len(siblings)) + + key3 := bytes.Repeat([]byte("x"), 32) + err = mt.TryUpdate( + key3, + 1, + []trie.Byte32{*trie.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("z"), 32))}, + ) + assert.NoError(t, err) + + proofTracer = NewProofTracer(mt) + err = proofTracer.Prove(key1, proof) + assert.NoError(t, err) + err = proofTracer.Prove(key4, proof) + assert.NoError(t, err) + + proofTracer.MarkDeletion(key1) + siblings, err = proofTracer.GetDeletionProofs() + assert.NoError(t, err) + assert.Equal(t, 1, len(siblings)) + + proofTracer.MarkDeletion(key4) + siblings, err = proofTracer.GetDeletionProofs() + assert.NoError(t, err) + assert.Equal(t, 2, len(siblings)) + + // one of the siblings is just leaf for key2, while + // another one must be a middle node + match1 := bytes.Equal(siblings[0][l-33:l-1], nd) + match2 := bytes.Equal(siblings[1][l-33:l-1], nd) + assert.True(t, match1 || match2) + assert.False(t, match1 && match2) +} diff --git a/rollup/tracing/tracing.go b/rollup/tracing/tracing.go index 6e4c16a9619a..a84604c5244e 100644 --- a/rollup/tracing/tracing.go +++ b/rollup/tracing/tracing.go @@ -27,6 +27,8 @@ import ( "github.com/scroll-tech/go-ethereum/rollup/fees" "github.com/scroll-tech/go-ethereum/rollup/rcfg" "github.com/scroll-tech/go-ethereum/rollup/withdrawtrie" + "github.com/scroll-tech/go-ethereum/trie" + "github.com/scroll-tech/go-ethereum/trie/zkproof" ) var ( @@ -78,7 +80,7 @@ type TraceEnv struct { TxStorageTraces []*types.StorageTrace Codes map[common.Hash]logger.CodeInfo // zktrie tracer is used for zktrie storage to build additional deletion proof - ZkTrieTracer map[string]state.ZktrieProofTracer + ZkTrieTracer map[string]*ProofTracer // StartL1QueueIndex is the next L1 message queue index that this block can process. // Example: If the parent block included QueueIndex=9, then StartL1QueueIndex will @@ -117,7 +119,7 @@ func CreateTraceEnvHelper(chainConfig *params.ChainConfig, logConfig *logger.Con }, TxStorageTraces: make([]*types.StorageTrace, block.Transactions().Len()), Codes: make(map[common.Hash]logger.CodeInfo), - ZkTrieTracer: make(map[string]state.ZktrieProofTracer), + ZkTrieTracer: make(map[string]*ProofTracer), StartL1QueueIndex: startL1QueueIndex, } } @@ -435,14 +437,15 @@ func (env *TraceEnv) getTxResult(state *state.StateDB, index int, block *types.B } env.sMu.Lock() - trie, err := state.GetStorageTrieForProof(addr) - if err != nil { + storageTrie, err := state.Database().OpenStorageTrie(state.GetRootHash(), addr, state.GetOrNewStateObject(addr).Root()) + zkStorageTrie, isZk := storageTrie.(*trie.ZkTrie) + if err != nil || !isZk { // but we still continue to next address log.Error("Storage trie not available", "error", err, "address", addr) env.sMu.Unlock() continue } - zktrieTracer := state.NewProofTracer(trie) + zktrieTracer := NewProofTracer(zkStorageTrie) env.sMu.Unlock() for key := range keys { @@ -458,29 +461,23 @@ func (env *TraceEnv) getTxResult(state *state.StateDB, index int, block *types.B m = make(map[string][]hexutil.Bytes) env.StorageProofs[addrStr] = m } - if zktrieTracer.Available() && !env.ZkTrieTracer[addrStr].Available() { - env.ZkTrieTracer[addrStr] = state.NewProofTracer(trie) + if _, exists := env.ZkTrieTracer[addrStr]; !exists { + env.ZkTrieTracer[addrStr] = zktrieTracer } if proof, existed := m[keyStr]; existed { txm[keyStr] = proof // still need to touch tracer for deletion - if isDelete && zktrieTracer.Available() { - env.ZkTrieTracer[addrStr].MarkDeletion(key) + if isDelete { + env.ZkTrieTracer[addrStr].MarkDeletion(key.Bytes()) } env.sMu.Unlock() continue } env.sMu.Unlock() - var proof [][]byte - var err error - if zktrieTracer.Available() { - proof, err = state.GetSecureTrieProof(zktrieTracer, key) - } else { - proof, err = state.GetSecureTrieProof(trie, key) - } - if err != nil { + var proof zkproof.ProofList + if err = zkStorageTrie.Prove(key.Bytes(), &proof); err != nil { log.Error("Storage proof not available", "error", err, "address", addrStr, "key", keyStr) // but we still mark the proofs map with nil array } @@ -488,12 +485,10 @@ func (env *TraceEnv) getTxResult(state *state.StateDB, index int, block *types.B env.sMu.Lock() txm[keyStr] = wrappedProof m[keyStr] = wrappedProof - if zktrieTracer.Available() { - if isDelete { - zktrieTracer.MarkDeletion(key) - } - env.ZkTrieTracer[addrStr].Merge(zktrieTracer) + if isDelete { + zktrieTracer.MarkDeletion(key.Bytes()) } + env.ZkTrieTracer[addrStr].Merge(zktrieTracer) env.sMu.Unlock() } } @@ -564,9 +559,13 @@ func (env *TraceEnv) fillBlockTrace(block *types.Block) (*types.BlockTrace, erro for _, slot := range storages { if _, existed := env.StorageProofs[addr.String()][slot.String()]; !existed { - if trie, err := statedb.GetStorageTrieForProof(addr); err != nil { - log.Error("Storage proof for intrinstic address not available", "error", err, "address", addr) - } else if proof, err := statedb.GetSecureTrieProof(trie, slot); err != nil { + var proof zkproof.ProofList + storageTrie, err := statedb.Database().OpenStorageTrie(statedb.GetRootHash(), addr, statedb.GetOrNewStateObject(addr).Root()) + zkStorageTrie, isZk := storageTrie.(*trie.ZkTrie) + if err != nil || !isZk { + // but we still continue to next address + log.Error("Storage trie not available", "error", err, "address", addr) + } else if err := zkStorageTrie.Prove(slot.Bytes(), &proof); err != nil { log.Error("Get storage proof for intrinstic address failed", "error", err, "address", addr, "slot", slot) } else { env.StorageProofs[addr.String()][slot.String()] = types.WrapProof(proof) diff --git a/tests/fuzzers/trie/trie-fuzzer.go b/tests/fuzzers/trie/trie-fuzzer.go index e8c700d67ee1..b1fef12fb13d 100644 --- a/tests/fuzzers/trie/trie-fuzzer.go +++ b/tests/fuzzers/trie/trie-fuzzer.go @@ -145,7 +145,7 @@ func runRandTest(rt randTest) error { var ( triedb = trie.NewDatabase(rawdb.NewMemoryDatabase(), nil) tr = trie.NewEmpty(triedb) - origin = types.EmptyLegacyTrieRootHash + origin = types.EmptyRootHash values = make(map[string]string) // tracks content of the trie ) for i, step := range rt { diff --git a/trie/byte32.go b/trie/byte32.go new file mode 100644 index 000000000000..313b4062b725 --- /dev/null +++ b/trie/byte32.go @@ -0,0 +1,42 @@ +package trie + +import ( + "math/big" + + "github.com/scroll-tech/go-ethereum/crypto/poseidon" +) + +type Byte32 [32]byte + +func (b *Byte32) Hash() (*big.Int, error) { + first16 := new(big.Int).SetBytes(b[0:16]) + last16 := new(big.Int).SetBytes(b[16:32]) + hash, err := poseidon.HashFixedWithDomain([]*big.Int{first16, last16}, big.NewInt(HASH_DOMAIN_BYTE32)) + if err != nil { + return nil, err + } + return hash, nil +} + +func (b *Byte32) Bytes() []byte { return b[:] } + +// same action as common.Hash (truncate bytes longer than 32 bytes FROM beginning, +// and padding 0 at the beginning for shorter bytes) +func NewByte32FromBytes(b []byte) *Byte32 { + + byte32 := new(Byte32) + + if len(b) > 32 { + b = b[len(b)-32:] + } + + copy(byte32[32-len(b):], b) + return byte32 +} + +// create bytes32 with zeropadding to shorter bytes, or truncate it +func NewByte32FromBytesPaddingZero(b []byte) *Byte32 { + byte32 := new(Byte32) + copy(byte32[:], b) + return byte32 +} diff --git a/trie/byte32_test.go b/trie/byte32_test.go new file mode 100644 index 000000000000..d261c97de2a4 --- /dev/null +++ b/trie/byte32_test.go @@ -0,0 +1,44 @@ +package trie + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewByte32(t *testing.T) { + var tests = []struct { + input []byte + expected []byte + expectedPaddingZero []byte + expectedHash string + expectedHashPadding string + }{ + {bytes.Repeat([]byte{1}, 4), + []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}, + []byte{1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + "1120169262217660912395665138727312015286293827539936259020934722663991619468", + "11815021958450380571374861379539732018094133931187815125213818828376493710327", + }, + {bytes.Repeat([]byte{1}, 34), + []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + "2219239698457798269997113163039475489501011181643161136091371987815450431154", + "2219239698457798269997113163039475489501011181643161136091371987815450431154", + }, + } + + for _, tt := range tests { + byte32Result := NewByte32FromBytes(tt.input) + byte32PaddingResult := NewByte32FromBytesPaddingZero(tt.input) + assert.Equal(t, tt.expected, byte32Result.Bytes()) + assert.Equal(t, tt.expectedPaddingZero, byte32PaddingResult.Bytes()) + hashResult, err := byte32Result.Hash() + assert.NoError(t, err) + hashPaddingResult, err := byte32PaddingResult.Hash() + assert.NoError(t, err) + assert.Equal(t, tt.expectedHash, hashResult.String()) + assert.Equal(t, tt.expectedHashPadding, hashPaddingResult.String()) + } +} diff --git a/trie/database.go b/trie/database.go index f0268d096317..da243da4c1c0 100644 --- a/trie/database.go +++ b/trie/database.go @@ -18,7 +18,6 @@ package trie import ( "errors" - "sync" "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/ethdb" @@ -34,9 +33,6 @@ type Config struct { Preimages bool // Flag whether the preimage of node key is recorded HashDB *hashdb.Config // Configs for hash-based scheme PathDB *pathdb.Config // Configs for experimental path-based scheme - - // zktrie related stuff - IsUsingZktrie bool } // HashDefaults represents a config for using hash-based scheme with @@ -46,13 +42,6 @@ var HashDefaults = &Config{ HashDB: hashdb.Defaults, } -// HashDefaultsWithZktrie represents a config based on HashDefaults but with zktrie enabled. -var HashDefaultsWithZktrie = &Config{ - Preimages: false, - HashDB: hashdb.Defaults, - IsUsingZktrie: true, -} - // HashDefaultsWithPreimages represents a config based on HashDefaults but with Preimages enabled. var HashDefaultsWithPreimages = &Config{ Preimages: true, @@ -90,9 +79,6 @@ type backend interface { // Close closes the trie database backend and releases all held resources. Close() error - - // database supplementary methods, to get the underlying fields - GetLock() *sync.RWMutex } // Database is the wrapper of the underlying backend which is shared by different @@ -103,10 +89,6 @@ type Database struct { diskdb ethdb.Database // Persistent database to store the snapshot preimages *preimageStore // The store for caching preimages backend backend // The backend for managing trie nodes - - // zktrie related stuff - // TODO: It's a quick&dirty implementation. FIXME later. - rawDirties KvMap } // NewDatabase initializes the trie database with default settings, note @@ -124,8 +106,6 @@ func NewDatabase(diskdb ethdb.Database, config *Config) *Database { config: config, diskdb: diskdb, preimages: preimages, - // scroll-related - rawDirties: make(KvMap), } if config.HashDB != nil && config.PathDB != nil { log.Crit("Both 'hash' and 'path' mode are configured") @@ -133,24 +113,11 @@ func NewDatabase(diskdb ethdb.Database, config *Config) *Database { if config.PathDB != nil { db.backend = pathdb.New(diskdb, config.PathDB) } else { - db.backend = hashdb.New(diskdb, config.HashDB, mptResolver{}) + db.backend = hashdb.New(diskdb, config.HashDB, ZkChildResolver{}) } return db } -func (db *Database) IsUsingZktrie() bool { - // compatible logic for light mode - if db == nil || db.config == nil { - return false - } - return db.config.IsUsingZktrie -} - -func (db *Database) SetIsUsingZktrie(isUsingZktrie bool) { - // config must not be nil - db.config.IsUsingZktrie = isUsingZktrie -} - // Reader returns a reader for accessing all trie nodes with provided state root. // An error will be returned if the requested state is not available. func (db *Database) Reader(blockRoot common.Hash) (Reader, error) { @@ -181,25 +148,6 @@ func (db *Database) Update(root common.Hash, parent common.Hash, block uint64, n // to disk. As a side effect, all pre-images accumulated up to this point are // also written. func (db *Database) Commit(root common.Hash, report bool) error { - batch := db.diskdb.NewBatch() - - db.GetLock().Lock() - for _, v := range db.rawDirties { - batch.Put(v.K, v.V) - } - for k := range db.rawDirties { - delete(db.rawDirties, k) - } - db.GetLock().Unlock() - if err := batch.Write(); err != nil { - return err - } - batch.Reset() - - if (root == common.Hash{}) { - return nil - } - if db.preimages != nil { db.preimages.commit(true) } diff --git a/trie/database_supplement.go b/trie/database_supplement.go deleted file mode 100644 index fa04d4dbb2ff..000000000000 --- a/trie/database_supplement.go +++ /dev/null @@ -1,32 +0,0 @@ -package trie - -import ( - "sync" - - "github.com/VictoriaMetrics/fastcache" - - "github.com/scroll-tech/go-ethereum/common" - "github.com/scroll-tech/go-ethereum/core/types" - "github.com/scroll-tech/go-ethereum/trie/triedb/hashdb" -) - -func (db *Database) GetLock() *sync.RWMutex { - return db.backend.GetLock() -} - -func (db *Database) GetCleans() *fastcache.Cache { - hdb, ok := db.backend.(*hashdb.Database) - if !ok { - panic("only hashdb supported") - } - return hdb.GetCleans() -} - -// EmptyRoot indicate what root is for an empty trie, it depends on its underlying implement (zktrie or common trie) -func (db *Database) EmptyRoot() common.Hash { - if db.IsUsingZktrie() { - return types.EmptyZkTrieRootHash - } else { - return types.EmptyLegacyTrieRootHash - } -} diff --git a/trie/hash.go b/trie/hash.go new file mode 100644 index 000000000000..e97013a9f4b0 --- /dev/null +++ b/trie/hash.go @@ -0,0 +1,149 @@ +package trie + +import ( + "encoding/hex" + "fmt" + "math/big" + "slices" +) + +var Q *big.Int + +const ( + HASH_DOMAIN_ELEMS_BASE = 256 + HASH_DOMAIN_BYTE32 = 2 * HASH_DOMAIN_ELEMS_BASE +) + +func init() { + qString := "21888242871839275222246405745257275088548364400416034343698204186575808495617" + var ok bool + Q, ok = new(big.Int).SetString(qString, 10) //nolint:gomnd + if !ok { + panic(fmt.Sprintf("Bad base 10 string %s", qString)) + } +} + +// CheckBigIntInField checks if given *big.Int fits in a Field Q element +func CheckBigIntInField(a *big.Int) bool { + return a.Cmp(Q) == -1 +} + +const numCharPrint = 8 + +// HashByteLen is the length of the Hash byte array +const HashByteLen = 32 + +var HashZero = Hash{} + +// Hash is the generic type to store the hash in the MerkleTree, encoded in little endian +type Hash [HashByteLen]byte + +// MarshalText implements the marshaler for the Hash type +func (h Hash) MarshalText() ([]byte, error) { + return []byte(h.BigInt().String()), nil +} + +// UnmarshalText implements the unmarshaler for the Hash type +func (h *Hash) UnmarshalText(b []byte) error { + ha, err := NewHashFromString(string(b)) + copy(h[:], ha[:]) + return err +} + +// String returns decimal representation in string format of the Hash +func (h Hash) String() string { + s := h.BigInt().String() + if len(s) < numCharPrint { + return s + } + return s[0:numCharPrint] + "..." +} + +// Hex returns the hexadecimal representation of the Hash +func (h Hash) Hex() string { + return hex.EncodeToString(h.Bytes()) +} + +// BigInt returns the *big.Int representation of the *Hash +func (h *Hash) BigInt() *big.Int { + return big.NewInt(0).SetBytes(h.Bytes()) +} + +// SetBytes sets the value of the hash from the given big endian byte array +func (h *Hash) SetBytes(b []byte) { + *h = HashZero + _ = h[len(b)-1] // eliminate range checks + for i := 0; i < len(b); i++ { + h[len(b)-i-1] = b[i] + } +} + +// Bytes returns the byte representation of the *Hash in big-endian encoding. +// The function converts the byte order from little endian to big endian. +func (h *Hash) Bytes() []byte { + b := [HashByteLen]byte{} + copy(b[:], h[:]) + slices.Reverse(b[:]) + return b[:] +} + +// Set copies the given hash in to this +func (h *Hash) Set(other *Hash) { + *h = *other +} + +// Copy copies the given hash in to this +func (h *Hash) Clone() *Hash { + var clone Hash + clone.Set(h) + return &clone +} + +// NewBigIntFromHashBytes returns a *big.Int from a byte array, swapping the +// endianness in the process. This is the intended method to get a *big.Int +// from a byte array that previously has ben generated by the Hash.Bytes() +// method. +func NewBigIntFromHashBytes(b []byte) (*big.Int, error) { + if len(b) != HashByteLen { + return nil, fmt.Errorf("expected %d bytes, but got %d bytes", HashByteLen, len(b)) + } + bi := new(big.Int).SetBytes(b) + if !CheckBigIntInField(bi) { + return nil, fmt.Errorf("NewBigIntFromHashBytes: Value not inside the Finite Field") + } + return bi, nil +} + +// NewHashFromBigInt returns a *Hash representation of the given *big.Int +func NewHashFromBigInt(b *big.Int) *Hash { + var bytes [HashByteLen]byte + return NewHashFromBytes(b.FillBytes(bytes[:])) +} + +// NewHashFromBytes returns a *Hash from a byte array considered to be +// a represent of big-endian integer, it swapping the endianness +// in the process. +func NewHashFromBytes(b []byte) *Hash { + var h Hash + h.SetBytes(b) + return &h +} + +// NewHashFromCheckedBytes is the intended method to get a *Hash from a byte array +// that previously has ben generated by the Hash.Bytes() method. so it check the +// size of bytes to be expected length +func NewHashFromCheckedBytes(b []byte) (*Hash, error) { + if len(b) != HashByteLen { + return nil, fmt.Errorf("expected %d bytes, but got %d bytes", HashByteLen, len(b)) + } + return NewHashFromBytes(b), nil +} + +// NewHashFromString returns a *Hash representation of the given decimal string +func NewHashFromString(s string) (*Hash, error) { + bi, ok := new(big.Int).SetString(s, 10) + if !ok { + return nil, fmt.Errorf("cannot parse the string to Hash") + } + return NewHashFromBigInt(bi), nil +} diff --git a/trie/hash_test.go b/trie/hash_test.go new file mode 100644 index 000000000000..8f25e731cc9d --- /dev/null +++ b/trie/hash_test.go @@ -0,0 +1,83 @@ +package trie + +import ( + "bytes" + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCheckBigIntInField(t *testing.T) { + bi := big.NewInt(0) + assert.True(t, CheckBigIntInField(bi)) + + bi = new(big.Int).Sub(Q, big.NewInt(1)) + assert.True(t, CheckBigIntInField(bi)) + + bi = new(big.Int).Set(Q) + assert.False(t, CheckBigIntInField(bi)) +} + +func TestNewHashAndBigIntFromBytes(t *testing.T) { + b := bytes.Repeat([]byte{1, 2}, 16) + h := NewHashFromBytes(b) + assert.Equal(t, "0102010201020102010201020102010201020102010201020102010201020102", h.Hex()) + assert.Equal(t, "45585349...", h.String()) + + h, err := NewHashFromCheckedBytes(b) + assert.NoError(t, err) + assert.Equal(t, "0102010201020102010201020102010201020102010201020102010201020102", h.Hex()) + + bi, err := NewBigIntFromHashBytes(b) + assert.NoError(t, err) + assert.Equal(t, "455853498485199945361735166433836579326217380693297711485161465995904286978", bi.String()) + + h1 := NewHashFromBytes(b) + text, err := h1.MarshalText() + assert.NoError(t, err) + assert.Equal(t, "455853498485199945361735166433836579326217380693297711485161465995904286978", h1.BigInt().String()) + h2 := &Hash{} + err = h2.UnmarshalText(text) + assert.NoError(t, err) + assert.Equal(t, h1, h2) + + short := []byte{1, 2, 3, 4, 5} + _, err = NewHashFromCheckedBytes(short) + assert.Error(t, err) + assert.Equal(t, fmt.Sprintf("expected %d bytes, but got %d bytes", HashByteLen, len(short)), err.Error()) + + short = []byte{1, 2, 3, 4, 5} + _, err = NewBigIntFromHashBytes(short) + assert.Error(t, err) + assert.Equal(t, fmt.Sprintf("expected %d bytes, but got %d bytes", HashByteLen, len(short)), err.Error()) + + outOfField := bytes.Repeat([]byte{255}, 32) + _, err = NewBigIntFromHashBytes(outOfField) + assert.Error(t, err) + assert.Equal(t, "NewBigIntFromHashBytes: Value not inside the Finite Field", err.Error()) +} + +func TestNewHashFromBigIntAndString(t *testing.T) { + bi := big.NewInt(12345) + h := NewHashFromBigInt(bi) + assert.Equal(t, "0000000000000000000000000000000000000000000000000000000000003039", h.Hex()) + assert.Equal(t, "12345", h.String()) + + s := "454086624460063511464984254936031011189294057512315937409637584344757371137" + h, err := NewHashFromString(s) + assert.NoError(t, err) + assert.Equal(t, "0101010101010101010101010101010101010101010101010101010101010101", h.Hex()) + assert.Equal(t, "45408662...", h.String()) +} + +func TestNewHashFromBytes(t *testing.T) { + h := HashZero + read, err := rand.Read(h[:]) + require.NoError(t, err) + require.Equal(t, HashByteLen, read) + require.Equal(t, h, *NewHashFromBytes(h.Bytes())) +} diff --git a/trie/iterator.go b/trie/iterator.go index d4539c1a397a..6a06be88b44b 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -161,7 +161,7 @@ func (e seekError) Error() string { } func newNodeIterator(trie *Trie, start []byte) NodeIterator { - if trie.Hash() == types.EmptyLegacyTrieRootHash { + if trie.Hash() == types.EmptyRootHash { return &nodeIterator{ trie: trie, err: errIteratorEnd, @@ -303,7 +303,7 @@ func (it *nodeIterator) seek(prefix []byte) error { func (it *nodeIterator) init() (*nodeIteratorState, error) { root := it.trie.Hash() state := &nodeIteratorState{node: it.trie.root, index: -1} - if root != types.EmptyLegacyTrieRootHash { + if root != types.EmptyRootHash { state.hash = root } return state, state.resolve(it, nil) diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 007be947bfe9..77569aecc72c 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2014 The go-ethereum Authors // This file is part of the go-ethereum library. // @@ -60,7 +63,7 @@ func TestIterator(t *testing.T) { trie.MustUpdate([]byte(val.k), []byte(val.v)) } root, nodes, _ := trie.Commit(false) - db.Update(root, types.EmptyZkTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) trie, _ = New(TrieID(root), db) found := make(map[string]string) @@ -76,15 +79,6 @@ func TestIterator(t *testing.T) { } } -type kv struct { - k, v []byte - t bool -} - -func (k *kv) cmp(other *kv) int { - return bytes.Compare(k.k, other.k) -} - func TestIteratorLargeData(t *testing.T) { trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), nil)) vals := make(map[string]*kv) @@ -252,7 +246,7 @@ func TestDifferenceIterator(t *testing.T) { triea.MustUpdate([]byte(val.k), []byte(val.v)) } rootA, nodesA, _ := triea.Commit(false) - dba.Update(rootA, types.EmptyZkTrieRootHash, 0, trienode.NewWithNodeSet(nodesA), nil) + dba.Update(rootA, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodesA), nil) triea, _ = New(TrieID(rootA), dba) dbb := NewDatabase(rawdb.NewMemoryDatabase(), nil) @@ -261,7 +255,7 @@ func TestDifferenceIterator(t *testing.T) { trieb.MustUpdate([]byte(val.k), []byte(val.v)) } rootB, nodesB, _ := trieb.Commit(false) - dbb.Update(rootB, types.EmptyZkTrieRootHash, 0, trienode.NewWithNodeSet(nodesB), nil) + dbb.Update(rootB, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodesB), nil) trieb, _ = New(TrieID(rootB), dbb) found := make(map[string]string) @@ -294,7 +288,7 @@ func TestUnionIterator(t *testing.T) { triea.MustUpdate([]byte(val.k), []byte(val.v)) } rootA, nodesA, _ := triea.Commit(false) - dba.Update(rootA, types.EmptyZkTrieRootHash, 0, trienode.NewWithNodeSet(nodesA), nil) + dba.Update(rootA, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodesA), nil) triea, _ = New(TrieID(rootA), dba) dbb := NewDatabase(rawdb.NewMemoryDatabase(), nil) @@ -303,7 +297,7 @@ func TestUnionIterator(t *testing.T) { trieb.MustUpdate([]byte(val.k), []byte(val.v)) } rootB, nodesB, _ := trieb.Commit(false) - dbb.Update(rootB, types.EmptyZkTrieRootHash, 0, trienode.NewWithNodeSet(nodesB), nil) + dbb.Update(rootB, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodesB), nil) trieb, _ = New(TrieID(rootB), dbb) di, _ := NewUnionIterator([]NodeIterator{triea.MustNodeIterator(nil), trieb.MustNodeIterator(nil)}) @@ -365,7 +359,7 @@ func testIteratorContinueAfterError(t *testing.T, memonly bool, scheme string) { tr.MustUpdate([]byte(val.k), []byte(val.v)) } root, nodes, _ := tr.Commit(false) - tdb.Update(root, types.EmptyZkTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + tdb.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) if !memonly { tdb.Commit(root, false) } @@ -481,7 +475,7 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool, scheme strin break } } - triedb.Update(root, types.EmptyZkTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + triedb.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) if !memonly { triedb.Commit(root, false) } @@ -555,7 +549,7 @@ func testIteratorNodeBlob(t *testing.T, scheme string) { trie.MustUpdate([]byte(val.k), []byte(val.v)) } root, nodes, _ := trie.Commit(false) - triedb.Update(root, types.EmptyZkTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + triedb.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) triedb.Commit(root, false) var found = make(map[common.Hash][]byte) diff --git a/trie/node_test.go b/trie/node_test.go index 70a924f86268..3c0359ca5ae5 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2016 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/trie/proof.go b/trie/proof.go index 4e98708c945e..15418758909c 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -107,7 +107,7 @@ func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { // If the trie does not contain a value for key, the returned proof contains all // nodes of the longest existing prefix of the key (at least the root node), ending // with the node that proves the absence of the key. -func (t *StateTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { +func (t *stateTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { return t.trie.Prove(key, proofDb) } @@ -115,11 +115,6 @@ func (t *StateTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { // key in a trie with the given root hash. VerifyProof returns an error if the // proof contains invalid trie nodes or the wrong value. func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { - // test the type of proof (for trie or SMT) - if buf, _ := proofDb.Get(magicHash); buf != nil { - return VerifyProofSMT(rootHash, key, proofDb) - } - key = keybytesToHex(key) wantHash := rootHash for i := 0; ; i++ { diff --git a/trie/proof_test.go b/trie/proof_test.go index 249c4a021ae6..07057fb20ebb 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2015 The go-ethereum Authors // This file is part of the go-ethereum library. // diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 0b21d7a34580..de1364941034 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -48,12 +48,12 @@ func NewSecure(stateRoot common.Hash, owner common.Hash, root common.Hash, db *D // the preimage of each key if preimage recording is enabled. // // StateTrie is not safe for concurrent use. -type StateTrie struct { +type stateTrie struct { trie Trie preimages *preimageStore hashKeyBuf [common.HashLength]byte secKeyCache map[string][]byte - secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch + secKeyCacheOwner *stateTrie // Pointer to self, replace the key cache on mismatch } // NewStateTrie creates a trie with an existing root node from a backing database. @@ -61,7 +61,7 @@ type StateTrie struct { // If root is the zero hash or the sha3 hash of an empty string, the // trie is initially empty. Otherwise, New will panic if db is nil // and returns MissingNodeError if the root node cannot be found. -func NewStateTrie(id *ID, db *Database) (*StateTrie, error) { +func newStateTrie(id *ID, db *Database) (*stateTrie, error) { if db == nil { panic("trie.NewStateTrie called without a database") } @@ -69,7 +69,7 @@ func NewStateTrie(id *ID, db *Database) (*StateTrie, error) { if err != nil { return nil, err } - return &StateTrie{trie: *trie, preimages: db.preimages}, nil + return &stateTrie{trie: *trie, preimages: db.preimages}, nil } // MustGet returns the value for key stored in the trie. @@ -77,7 +77,7 @@ func NewStateTrie(id *ID, db *Database) (*StateTrie, error) { // // This function will omit any encountered error but just // print out an error message. -func (t *StateTrie) MustGet(key []byte) []byte { +func (t *stateTrie) MustGet(key []byte) []byte { return t.trie.MustGet(t.hashKey(key)) } @@ -85,7 +85,7 @@ func (t *StateTrie) MustGet(key []byte) []byte { // and slot key. The value bytes must not be modified by the caller. // If the specified storage slot is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { +func (t *stateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { enc, err := t.trie.Get(t.hashKey(key)) if err != nil || len(enc) == 0 { return nil, err @@ -97,7 +97,7 @@ func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { // GetAccount attempts to retrieve an account with provided account address. // If the specified account is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, error) { +func (t *stateTrie) GetAccount(address common.Address) (*types.StateAccount, error) { res, err := t.trie.Get(t.hashKey(address.Bytes())) if res == nil || err != nil { return nil, err @@ -110,7 +110,7 @@ func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, err // GetAccountByHash does the same thing as GetAccount, however it expects an // account hash that is the hash of address. This constitutes an abstraction // leak, since the client code needs to know the key format. -func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { +func (t *stateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { res, err := t.trie.Get(addrHash.Bytes()) if res == nil || err != nil { return nil, err @@ -124,7 +124,7 @@ func (t *StateTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, // possible to use keybyte-encoding as the path might contain odd nibbles. // If the specified trie node is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) GetNode(path []byte) ([]byte, int, error) { +func (t *stateTrie) GetNode(path []byte) ([]byte, int, error) { return t.trie.GetNode(path) } @@ -137,7 +137,7 @@ func (t *StateTrie) GetNode(path []byte) ([]byte, int, error) { // // This function will omit any encountered error but just print out an // error message. -func (t *StateTrie) MustUpdate(key, value []byte) { +func (t *stateTrie) MustUpdate(key, value []byte) { hk := t.hashKey(key) t.trie.MustUpdate(hk, value) t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) @@ -151,7 +151,7 @@ func (t *StateTrie) MustUpdate(key, value []byte) { // stored in the trie. // // If a node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error { +func (t *stateTrie) UpdateStorage(_ common.Address, key, value []byte) error { hk := t.hashKey(key) v, _ := rlp.EncodeToBytes(value) err := t.trie.Update(hk, v) @@ -163,7 +163,7 @@ func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error { } // UpdateAccount will abstract the write of an account to the secure trie. -func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { +func (t *stateTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { hk := t.hashKey(address.Bytes()) data, err := rlp.EncodeToBytes(acc) if err != nil { @@ -176,13 +176,13 @@ func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccoun return nil } -func (t *StateTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error { +func (t *stateTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error { return nil } // MustDelete removes any existing value for key from the trie. This function // will omit any encountered error but just print out an error message. -func (t *StateTrie) MustDelete(key []byte) { +func (t *stateTrie) MustDelete(key []byte) { hk := t.hashKey(key) delete(t.getSecKeyCache(), string(hk)) t.trie.MustDelete(hk) @@ -191,14 +191,14 @@ func (t *StateTrie) MustDelete(key []byte) { // DeleteStorage removes any existing storage slot from the trie. // If the specified trie node is not in the trie, nothing will be changed. // If a node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) DeleteStorage(_ common.Address, key []byte) error { +func (t *stateTrie) DeleteStorage(_ common.Address, key []byte) error { hk := t.hashKey(key) delete(t.getSecKeyCache(), string(hk)) return t.trie.Delete(hk) } // DeleteAccount abstracts an account deletion from the trie. -func (t *StateTrie) DeleteAccount(address common.Address) error { +func (t *stateTrie) DeleteAccount(address common.Address) error { hk := t.hashKey(address.Bytes()) delete(t.getSecKeyCache(), string(hk)) return t.trie.Delete(hk) @@ -206,7 +206,7 @@ func (t *StateTrie) DeleteAccount(address common.Address) error { // GetKey returns the sha3 preimage of a hashed key that was // previously used to store a value. -func (t *StateTrie) GetKey(shaKey []byte) []byte { +func (t *stateTrie) GetKey(shaKey []byte) []byte { if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { return key } @@ -223,7 +223,7 @@ func (t *StateTrie) GetKey(shaKey []byte) []byte { // All cached preimages will be also flushed if preimages recording is enabled. // Once the trie is committed, it's not usable anymore. A new trie must // be created with new root and updated trie database for following usage -func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) { +func (t *stateTrie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) { // Write all the pre-images to the actual disk database if len(t.getSecKeyCache()) > 0 { if t.preimages != nil { @@ -241,13 +241,13 @@ func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, er // Hash returns the root hash of StateTrie. It does not write to the // database and can be used even if the trie doesn't have one. -func (t *StateTrie) Hash() common.Hash { +func (t *stateTrie) Hash() common.Hash { return t.trie.Hash() } // Copy returns a copy of StateTrie. -func (t *StateTrie) Copy() *StateTrie { - return &StateTrie{ +func (t *stateTrie) Copy() *stateTrie { + return &stateTrie{ trie: *t.trie.Copy(), preimages: t.preimages, secKeyCache: t.secKeyCache, @@ -256,20 +256,20 @@ func (t *StateTrie) Copy() *StateTrie { // NodeIterator returns an iterator that returns nodes of the underlying trie. // Iteration starts at the key after the given start key. -func (t *StateTrie) NodeIterator(start []byte) (NodeIterator, error) { +func (t *stateTrie) NodeIterator(start []byte) (NodeIterator, error) { return t.trie.NodeIterator(start) } // MustNodeIterator is a wrapper of NodeIterator and will omit any encountered // error but just print out an error message. -func (t *StateTrie) MustNodeIterator(start []byte) NodeIterator { +func (t *stateTrie) MustNodeIterator(start []byte) NodeIterator { return t.trie.MustNodeIterator(start) } // hashKey returns the hash of key as an ephemeral buffer. // The caller must not hold onto the return value because it will become // invalid on the next call to hashKey or secKey. -func (t *StateTrie) hashKey(key []byte) []byte { +func (t *stateTrie) hashKey(key []byte) []byte { h := newHasher(false) h.sha.Reset() h.sha.Write(key) @@ -281,7 +281,7 @@ func (t *StateTrie) hashKey(key []byte) []byte { // getSecKeyCache returns the current secure key cache, creating a new one if // ownership changed (i.e. the current secure trie is a copy of another owning // the actual cache). -func (t *StateTrie) getSecKeyCache() map[string][]byte { +func (t *stateTrie) getSecKeyCache() map[string][]byte { if t != t.secKeyCacheOwner { t.secKeyCacheOwner = t t.secKeyCache = make(map[string][]byte) diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index 69003e3245b4..0b7f4d3fb8cd 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2015 The go-ethereum Authors // This file is part of the go-ethereum library. // @@ -31,7 +34,7 @@ import ( ) func newEmptySecure() *StateTrie { - trie, _ := NewStateTrie(TrieID(types.EmptyZkTrieRootHash), NewDatabase(rawdb.NewMemoryDatabase(), nil)) + trie, _ := NewStateTrie(TrieID(types.EmptyStateRootHash), NewDatabase(rawdb.NewMemoryDatabase(), nil)) return trie } @@ -39,7 +42,7 @@ func newEmptySecure() *StateTrie { func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) { // Create an empty trie triedb := NewDatabase(rawdb.NewMemoryDatabase(), nil) - trie, _ := NewStateTrie(TrieID(types.EmptyZkTrieRootHash), triedb) + trie, _ := NewStateTrie(TrieID(types.EmptyStateRootHash), triedb) // Fill it with some arbitrary data content := make(map[string][]byte) @@ -61,7 +64,7 @@ func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) { } } root, nodes, _ := trie.Commit(false) - if err := triedb.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil); err != nil { + if err := triedb.Update(root, types.EmptyStateRootHash, 0, trienode.NewWithNodeSet(nodes), nil); err != nil { panic(fmt.Errorf("failed to commit db %v", err)) } // Re-create the trie based on the new state diff --git a/trie/stacktrie.go b/trie/stacktrie.go index 13143911862b..3a64e8319dd9 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -344,15 +344,16 @@ func (t *StackTrie) insert(st *stNode, key, value []byte, path []byte) { // This method also sets 'st.type' to hashedNode, and clears 'st.key'. func (t *StackTrie) hash(st *stNode, path []byte) { var ( - blob []byte // RLP-encoded node blob - internal [][]byte // List of node paths covered by the extension node + emptyHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + blob []byte // RLP-encoded node blob + internal [][]byte // List of node paths covered by the extension node ) switch st.typ { case hashedNode: return case emptyNode: - st.val = types.EmptyLegacyTrieRootHash.Bytes() + st.val = emptyHash.Bytes() st.key = st.key[:0] st.typ = hashedNode return diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go index 453402f4a56c..b9cda9d8ebea 100644 --- a/trie/stacktrie_test.go +++ b/trie/stacktrie_test.go @@ -421,6 +421,15 @@ func buildPartialTree(entries []*kv, t *testing.T) map[string]common.Hash { return nodes } +type kv struct { + k, v []byte + t bool +} + +func (k *kv) cmp(other *kv) int { + return bytes.Compare(k.k, other.k) +} + func TestPartialStackTrie(t *testing.T) { for round := 0; round < 100; round++ { var ( diff --git a/trie/sync.go b/trie/sync.go index 44c028c04474..07509b1121b2 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -211,7 +211,7 @@ func NewSync(root common.Hash, database ethdb.KeyValueReader, callback LeafCallb // hex format and contain all the parent path if it's layered trie node. func (s *Sync) AddSubTrie(root common.Hash, path []byte, parent common.Hash, parentPath []byte, callback LeafCallback) { // Short circuit if the trie is empty or already known - if root == types.EmptyLegacyTrieRootHash { + if root == types.EmptyRootHash { return } if s.membatch.hasNode(path) { diff --git a/trie/sync_test.go b/trie/sync_test.go index 3cb1c01b92ee..a20c5fb09030 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2015 The go-ethereum Authors // This file is part of the go-ethereum library. // @@ -35,7 +38,7 @@ func makeTestTrie(scheme string) (ethdb.Database, *Database, *StateTrie, map[str // Create an empty trie db := rawdb.NewMemoryDatabase() triedb := newTestDatabase(db, scheme) - trie, _ := NewStateTrie(TrieID(types.EmptyZkTrieRootHash), triedb) + trie, _ := NewStateTrie(TrieID(types.EmptyStateRootHash), triedb) // Fill it with some arbitrary data content := make(map[string][]byte) @@ -57,7 +60,7 @@ func makeTestTrie(scheme string) (ethdb.Database, *Database, *StateTrie, map[str } } root, nodes, _ := trie.Commit(false) - if err := triedb.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil); err != nil { + if err := triedb.Update(root, types.EmptyStateRootHash, 0, trienode.NewWithNodeSet(nodes), nil); err != nil { panic(fmt.Errorf("failed to commit db %v", err)) } if err := triedb.Commit(root, false); err != nil { @@ -137,9 +140,9 @@ func TestEmptySync(t *testing.T) { // dbD := newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.PathScheme) emptyA := NewEmpty(dbA) - emptyB, _ := New(TrieID(types.EmptyZkTrieRootHash), dbB) + emptyB, _ := New(TrieID(types.EmptyRootHash), dbB) // emptyC := NewEmpty(dbC) - // emptyD, _ := New(TrieID(types.EmptyLegacyTrieRootHash), dbD) + // emptyD, _ := New(TrieID(types.EmptyRootHash), dbD) // for i, trie := range []*Trie{emptyA, emptyB, emptyC, emptyD} { // sync := NewSync(trie.Hash(), memorydb.New(), nil, []*Database{dbA, dbB, dbC, dbD}[i].Scheme()) @@ -811,7 +814,7 @@ func testPivotMove(t *testing.T, scheme string, tiny bool) { var ( srcDisk = rawdb.NewMemoryDatabase() srcTrieDB = newTestDatabase(srcDisk, scheme) - srcTrie, _ = New(TrieID(types.EmptyZkTrieRootHash), srcTrieDB) + srcTrie, _ = New(TrieID(types.EmptyRootHash), srcTrieDB) deleteFn = func(key []byte, tr *Trie, states map[string][]byte) { tr.Delete(key) @@ -845,7 +848,7 @@ func testPivotMove(t *testing.T, scheme string, tiny bool) { writeFn([]byte{0x13, 0x44}, nil, srcTrie, stateA) rootA, nodesA, _ := srcTrie.Commit(false) - if err := srcTrieDB.Update(rootA, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodesA), nil); err != nil { + if err := srcTrieDB.Update(rootA, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodesA), nil); err != nil { panic(err) } if err := srcTrieDB.Commit(rootA, false); err != nil { diff --git a/trie/tracer_test.go b/trie/tracer_test.go index 36b57efcf4a2..47775020276c 100644 --- a/trie/tracer_test.go +++ b/trie/tracer_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2022 The go-ethereum Authors // This file is part of the go-ethereum library. // @@ -71,7 +74,7 @@ func testTrieTracer(t *testing.T, vals []struct{ k, v string }) { insertSet := copySet(trie.tracer.inserts) // copy before commit deleteSet := copySet(trie.tracer.deletes) // copy before commit root, nodes, _ := trie.Commit(false) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) seen := setKeys(iterNodes(db, root)) if !compareSet(insertSet, seen) { @@ -137,7 +140,7 @@ func testAccessList(t *testing.T, vals []struct{ k, v string }) { trie.MustUpdate([]byte(val.k), []byte(val.v)) } root, nodes, _ := trie.Commit(false) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) trie, _ = New(TrieID(root), db) if err := verifyAccessList(orig, trie, nodes); err != nil { @@ -219,7 +222,7 @@ func TestAccessListLeak(t *testing.T) { trie.MustUpdate([]byte(val.k), []byte(val.v)) } root, nodes, _ := trie.Commit(false) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) var cases = []struct { op func(tr *Trie) @@ -269,7 +272,7 @@ func TestTinyTree(t *testing.T) { trie.MustUpdate([]byte(val.k), randBytes(32)) } root, set, _ := trie.Commit(false) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(set), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(set), nil) parent := root trie, _ = New(TrieID(root), db) diff --git a/trie/trie.go b/trie/trie.go index 836443ce4b44..5839e8d6cc4e 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -89,7 +89,7 @@ func New(id *ID, db *Database) (*Trie, error) { reader: reader, tracer: newTracer(), } - if id.Root != (common.Hash{}) && id.Root != types.EmptyLegacyTrieRootHash { + if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash { rootnode, err := trie.resolveAndTrack(id.Root[:], nil) if err != nil { return nil, err @@ -101,7 +101,7 @@ func New(id *ID, db *Database) (*Trie, error) { // NewEmpty is a shortcut to create empty tree. It's mostly used in tests. func NewEmpty(db *Database) *Trie { - tr, _ := New(TrieID(types.EmptyZkTrieRootHash), db) + tr, _ := New(TrieID(types.EmptyRootHash), db) return tr } @@ -619,13 +619,13 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) if t.root == nil { paths := t.tracer.deletedNodes() if len(paths) == 0 { - return types.EmptyLegacyTrieRootHash, nil, nil // case (a) + return types.EmptyRootHash, nil, nil // case (a) } nodes := trienode.NewNodeSet(t.owner) for _, path := range paths { nodes.AddNode([]byte(path), trienode.NewDeleted()) } - return types.EmptyLegacyTrieRootHash, nodes, nil // case (b) + return types.EmptyRootHash, nodes, nil // case (b) } // Derive the hash for all dirty nodes first. We hold the assumption // in the following procedure that all nodes are hashed. @@ -650,7 +650,7 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) // hashRoot calculates the root hash of the given trie func (t *Trie) hashRoot() (node, node) { if t.root == nil { - return hashNode(types.EmptyLegacyTrieRootHash.Bytes()), nil + return hashNode(types.EmptyRootHash.Bytes()), nil } // If the number of changes is below 100, we let one thread handle it h := newHasher(t.unhashed >= 100) diff --git a/trie/trie_reader.go b/trie/trie_reader.go index 73dd7d4d515d..72c14a7301fa 100644 --- a/trie/trie_reader.go +++ b/trie/trie_reader.go @@ -19,6 +19,7 @@ package trie import ( "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/core/types" + "github.com/scroll-tech/go-ethereum/log" "github.com/scroll-tech/go-ethereum/trie/triestate" ) @@ -45,14 +46,10 @@ type trieReader struct { // newTrieReader initializes the trie reader with the given node reader. func newTrieReader(stateRoot, owner common.Hash, db *Database) (*trieReader, error) { - // if stateRoot == (common.Hash{}) || stateRoot == types.EmptyRootHash { - // if stateRoot == (common.Hash{}) { - // log.Error("Zero state root hash!") - // } - // return &trieReader{owner: owner}, nil - // } - if stateRoot == types.EmptyZkTrieRootHash { - // log.Error("Zero state root hash!") + if stateRoot == (common.Hash{}) || stateRoot == types.EmptyRootHash { + if stateRoot == (common.Hash{}) { + log.Error("Zero state root hash!") + } return &trieReader{owner: owner}, nil } reader, err := db.Reader(stateRoot) diff --git a/trie/trie_test.go b/trie/trie_test.go index 9de9125070bc..cf09b0fe9ffd 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1,3 +1,6 @@ +//go:build all_tests +// +build all_tests + // Copyright 2014 The go-ethereum Authors // This file is part of the go-ethereum library. // @@ -48,7 +51,7 @@ func init() { func TestEmptyTrie(t *testing.T) { trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), nil)) res := trie.Hash() - exp := types.EmptyLegacyTrieRootHash + exp := types.EmptyRootHash if res != exp { t.Errorf("expected %x got %x", exp, res) } @@ -95,7 +98,7 @@ func testMissingNode(t *testing.T, memonly bool, scheme string) { updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer") updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf") root, nodes, _ := trie.Commit(false) - triedb.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + triedb.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) if !memonly { triedb.Commit(root, false) @@ -209,7 +212,7 @@ func TestGet(t *testing.T) { return } root, nodes, _ := trie.Commit(false) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) trie, _ = New(TrieID(root), db) } } @@ -281,7 +284,7 @@ func TestReplication(t *testing.T) { updateString(trie, val.k, val.v) } root, nodes, _ := trie.Commit(false) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) // create a new trie on top of the database and check that lookups work. trie2, err := New(TrieID(root), db) @@ -300,7 +303,7 @@ func TestReplication(t *testing.T) { // recreate the trie after commit if nodes != nil { - db.Update(hash, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(hash, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) } trie2, err = New(TrieID(hash), db) if err != nil { @@ -467,7 +470,7 @@ func runRandTest(rt randTest) bool { scheme = rawdb.PathScheme } var ( - origin = types.EmptyLegacyTrieRootHash + origin = types.EmptyRootHash triedb = newTestDatabase(rawdb.NewMemoryDatabase(), scheme) tr = NewEmpty(triedb) values = make(map[string]string) // tracks content of the trie @@ -492,7 +495,7 @@ func runRandTest(rt randTest) bool { } case opProve: hash := tr.Hash() - if hash == types.EmptyLegacyTrieRootHash { + if hash == types.EmptyRootHash { continue } proofDb := rawdb.NewMemoryDatabase() @@ -764,7 +767,7 @@ func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) { for i := 0; i < len(accounts); i++ { var ( nonce = uint64(random.Int63()) - root = types.EmptyLegacyTrieRootHash + root = types.EmptyRootHash codekeccak = codehash.EmptyKeccakCodeHash codeposeidon = codehash.EmptyPoseidonCodeHash ) @@ -862,7 +865,7 @@ func TestCommitSequence(t *testing.T) { } // Flush trie -> database root, nodes, _ := trie.Commit(false) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) // Flush memdb -> disk (sponge) db.Commit(root, false) if got, exp := s.sponge.Sum(nil), tc.expWriteSeqHash; !bytes.Equal(got, exp) { @@ -903,7 +906,7 @@ func TestCommitSequenceRandomBlobs(t *testing.T) { } // Flush trie -> database root, nodes, _ := trie.Commit(false) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) // Flush memdb -> disk (sponge) db.Commit(root, false) if got, exp := s.sponge.Sum(nil), tc.expWriteSeqHash; !bytes.Equal(got, exp) { @@ -946,7 +949,7 @@ func TestCommitSequenceStackTrie(t *testing.T) { // Flush trie -> database root, nodes, _ := trie.Commit(false) // Flush memdb -> disk (sponge) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) db.Commit(root, false) // And flush stacktrie -> disk stRoot := stTrie.Commit() @@ -994,7 +997,7 @@ func TestCommitSequenceSmallRoot(t *testing.T) { // Flush trie -> database root, nodes, _ := trie.Commit(false) // Flush memdb -> disk (sponge) - db.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) db.Commit(root, false) // And flush stacktrie -> disk stRoot := stTrie.Commit() @@ -1163,7 +1166,7 @@ func benchmarkDerefRootFixedSize(b *testing.B, addresses [][20]byte, accounts [] } h := trie.Hash() root, nodes, _ := trie.Commit(false) - triedb.Update(root, types.EmptyLegacyTrieRootHash, 0, trienode.NewWithNodeSet(nodes), nil) + triedb.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) b.StartTimer() triedb.Dereference(h) b.StopTimer() diff --git a/trie/triedb/hashdb/database.go b/trie/triedb/hashdb/database.go index 504566b7c483..ca6d79731a7b 100644 --- a/trie/triedb/hashdb/database.go +++ b/trie/triedb/hashdb/database.go @@ -30,7 +30,6 @@ import ( "github.com/scroll-tech/go-ethereum/ethdb" "github.com/scroll-tech/go-ethereum/log" "github.com/scroll-tech/go-ethereum/metrics" - "github.com/scroll-tech/go-ethereum/rlp" "github.com/scroll-tech/go-ethereum/trie/trienode" "github.com/scroll-tech/go-ethereum/trie/triestate" ) @@ -575,8 +574,7 @@ func (db *Database) Update(root common.Hash, parent common.Hash, block uint64, n // Ensure the parent state is present and signal a warning if not. if parent != types.EmptyRootHash { if blob, _ := db.Node(parent); len(blob) == 0 { - // Silence the warning because it is not applicable to zktrie. - // log.Error("parent state is not present") + log.Error("parent state is not present") } } db.lock.Lock() @@ -611,8 +609,8 @@ func (db *Database) Update(root common.Hash, parent common.Hash, block uint64, n // to an account trie leaf. if set, present := nodes.Sets[common.Hash{}]; present { for _, n := range set.Leaves { - var account types.StateAccount - if err := rlp.DecodeBytes(n.Blob, &account); err != nil { + account, err := types.UnmarshalStateAccount(n.Blob) + if err != nil { return err } if account.Root != types.EmptyRootHash { diff --git a/trie/triedb/hashdb/database_supplement.go b/trie/triedb/hashdb/database_supplement.go deleted file mode 100644 index 841bcc371163..000000000000 --- a/trie/triedb/hashdb/database_supplement.go +++ /dev/null @@ -1,15 +0,0 @@ -package hashdb - -import ( - "sync" - - "github.com/VictoriaMetrics/fastcache" -) - -func (db *Database) GetLock() *sync.RWMutex { - return &db.lock -} - -func (db *Database) GetCleans() *fastcache.Cache { - return db.cleans -} diff --git a/trie/util.go b/trie/util.go new file mode 100644 index 000000000000..79f87e4aacb4 --- /dev/null +++ b/trie/util.go @@ -0,0 +1,117 @@ +package trie + +import ( + "math/big" + + "github.com/scroll-tech/go-ethereum/crypto/poseidon" +) + +// HashElemsWithDomain performs a recursive poseidon hash over the array of ElemBytes, each hash +// reduce 2 fieds into one, with a specified domain field which would be used in +// every recursiving call +func HashElemsWithDomain(domain, fst, snd *big.Int, elems ...*big.Int) (*Hash, error) { + + l := len(elems) + baseH, err := poseidon.HashFixedWithDomain([]*big.Int{fst, snd}, domain) + if err != nil { + return nil, err + } + if l == 0 { + return NewHashFromBigInt(baseH), nil + } else if l == 1 { + return HashElemsWithDomain(domain, baseH, elems[0]) + } + + tmp := make([]*big.Int, (l+1)/2) + for i := range tmp { + if (i+1)*2 > l { + tmp[i] = elems[i*2] + } else { + h, err := poseidon.HashFixedWithDomain(elems[i*2:(i+1)*2], domain) + if err != nil { + return nil, err + } + tmp[i] = h + } + } + + return HashElemsWithDomain(domain, baseH, tmp[0], tmp[1:]...) +} + +// HashElems call HashElemsWithDomain with a domain of HASH_DOMAIN_ELEMS_BASE(256)* +func HashElems(fst, snd *big.Int, elems ...*big.Int) (*Hash, error) { + + return HashElemsWithDomain(big.NewInt(int64(len(elems)*HASH_DOMAIN_ELEMS_BASE)+HASH_DOMAIN_BYTE32), + fst, snd, elems...) +} + +// HandlingElemsAndByte32 hash an arry mixed with field and byte32 elements, turn each byte32 into +// field elements first then calculate the hash with HashElems +func HandlingElemsAndByte32(flagArray uint32, elems []Byte32) (*Hash, error) { + + ret := make([]*big.Int, len(elems)) + var err error + + for i, elem := range elems { + if flagArray&(1<. - package trie import ( + "bytes" + "errors" "fmt" - - zktrie "github.com/scroll-tech/zktrie/trie" - zkt "github.com/scroll-tech/zktrie/types" + "io" + "maps" + "math/big" + "sync" "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/core/types" - "github.com/scroll-tech/go-ethereum/crypto/poseidon" "github.com/scroll-tech/go-ethereum/ethdb" - "github.com/scroll-tech/go-ethereum/log" "github.com/scroll-tech/go-ethereum/trie/trienode" ) -var magicHash []byte = []byte("THIS IS THE MAGIC INDEX FOR ZKTRIE") +const ( + // NodeKeyValidBytes is the number of least significant bytes in the node key + // that are considered valid to addressing the leaf node, and thus limits the + // maximum trie depth to NodeKeyValidBytes * 8. + // We need to truncate the node key because the key is the output of Poseidon + // hash and the key space doesn't fully occupy the range of power of two. It can + // lead to an ambiguous bit representation of the key in the finite field + // causing a soundness issue in the zk circuit. + NodeKeyValidBytes = 31 + + // proofFlagsLen is the byte length of the flags in the proof header + // (first 32 bytes). + proofFlagsLen = 2 +) + +var ( + magicHash = []byte("THIS IS THE MAGIC INDEX FOR ZKTRIE") + magicSMTBytes = []byte("THIS IS SOME MAGIC BYTES FOR SMT m1rRXgP2xpDI") + + // ErrNodeKeyAlreadyExists is used when a node key already exists. + ErrInvalidField = errors.New("Key not inside the Finite Field") + // ErrNodeKeyAlreadyExists is used when a node key already exists. + ErrNodeKeyAlreadyExists = errors.New("key already exists") + // ErrKeyNotFound is used when a key is not found in the ZkTrie. + ErrKeyNotFound = errors.New("key not found in ZkTrie") + // ErrNodeBytesBadSize is used when the data of a node has an incorrect + // size and can't be parsed. + ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB") + // ErrReachedMaxLevel is used when a traversal of the MT reaches the + // maximum level. + ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree") + // ErrInvalidNodeFound is used when an invalid node is found and can't + // be parsed. + ErrInvalidNodeFound = errors.New("found an invalid node in the DB") + // ErrInvalidProofBytes is used when a serialized proof is invalid. + ErrInvalidProofBytes = errors.New("the serialized proof is invalid") + // ErrEntryIndexAlreadyExists is used when the entry index already + // exists in the tree. + ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree") + // ErrNotWritable is used when the ZkTrie is not writable and a + // write function is called + ErrNotWritable = errors.New("merkle Tree not writable") +) + +// StateTrie is just an alias for ZkTrie now +type StateTrie = ZkTrie -// wrap zktrie for trie interface +// ZkTrie is the struct with the main elements of the ZkTrie type ZkTrie struct { - *zktrie.ZkTrie - db *ZktrieDatabase + lock sync.RWMutex + owner common.Hash + reader *trieReader + rootKey *Hash + maxLevels int + + // Preimage store + preimages *preimageStore + secKeyCache map[string][]byte + + // Flag whether the commit operation is already performed. If so the + // trie is not usable(latest states is invisible). + committed bool + dirtyIndex *big.Int + dirtyStorage map[Hash]*Node +} + +// NewStateTrie is just an alias for NewZkTrie now +var NewStateTrie = NewZkTrie + +// NewZkTrie loads a new ZkTrie. If in the storage already exists one +// will open that one, if not, will create a new one. +func NewZkTrie(id *ID, db *Database) (*ZkTrie, error) { + reader, err := newTrieReader(id.StateRoot, id.Owner, db) + if err != nil { + return nil, err + } + + mt := ZkTrie{ + owner: id.Owner, + reader: reader, + maxLevels: NodeKeyValidBytes * 8, + dirtyIndex: big.NewInt(0), + dirtyStorage: make(map[Hash]*Node), + preimages: db.preimages, + secKeyCache: make(map[string][]byte), + } + mt.rootKey = NewHashFromBytes(id.Root.Bytes()) + if *mt.rootKey != HashZero { + _, err := mt.GetNodeByHash(mt.rootKey) + if err != nil { + return nil, err + } + } + return &mt, nil +} + +// Root returns the MerkleRoot +func (mt *ZkTrie) Root() (*Hash, error) { + mt.lock.Lock() + defer mt.lock.Unlock() + return mt.root() +} + +func (mt *ZkTrie) root() (*Hash, error) { + // short circuit if there are no nodes to hash + if mt.dirtyIndex.Cmp(big.NewInt(0)) == 0 { + return mt.rootKey, nil + } + + hashedDirtyStorage := make(map[Hash]*Node) + rootKey, err := mt.calcCommitment(mt.rootKey, hashedDirtyStorage, new(sync.Mutex)) + if err != nil { + return nil, err + } + + mt.rootKey = rootKey + mt.dirtyIndex = big.NewInt(0) + mt.dirtyStorage = hashedDirtyStorage + return mt.rootKey, nil +} + +// Hash returns the root hash of SecureBinaryTrie. It does not write to the +// database and can be used even if the trie doesn't have one. +func (mt *ZkTrie) Hash() common.Hash { + root, err := mt.Root() + if err != nil { + panic("root failed in trie.Hash") + } + return common.BytesToHash(root.Bytes()) +} + +// MaxLevels returns the MT maximum level +func (mt *ZkTrie) MaxLevels() int { + return mt.maxLevels +} + +// TryUpdate updates a nodeKey & value into the ZkTrie. Where the `k` determines the +// path from the Root to the Leaf. This also return the updated leaf node +func (mt *ZkTrie) TryUpdate(key []byte, vFlag uint32, vPreimage []Byte32) error { + // Short circuit if the trie is already committed and not usable. + if mt.committed { + return ErrCommitted + } + + secureKey, err := ToSecureKey(key) + if err != nil { + return err + } + nodeKey := NewHashFromBigInt(secureKey) + + // verify that k are valid and fit inside the Finite Field. + if !CheckBigIntInField(nodeKey.BigInt()) { + return ErrInvalidField + } + + newLeafNode := NewLeafNode(nodeKey, vFlag, vPreimage) + path := getPath(mt.maxLevels, nodeKey[:]) + + mt.lock.Lock() + defer mt.lock.Unlock() + + mt.secKeyCache[string(nodeKey.Bytes())] = key + + newRootKey, _, err := mt.addLeaf(newLeafNode, mt.rootKey, 0, path) + // sanity check + if err == ErrEntryIndexAlreadyExists { + panic("Encounter unexpected errortype: ErrEntryIndexAlreadyExists") + } else if err != nil { + return err + } + if newRootKey != nil { + mt.rootKey = newRootKey + } + return nil +} + +// UpdateStorage updates the storage with the given key and value +func (mt *ZkTrie) UpdateStorage(_ common.Address, key, value []byte) error { + return mt.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes(value)}) +} + +// UpdateAccount updates the account with the given address and account +func (mt *ZkTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { + value, flag := acc.MarshalFields() + accValue := make([]Byte32, 0, len(value)) + for _, v := range value { + accValue = append(accValue, *NewByte32FromBytes(v[:])) + } + return mt.TryUpdate(address.Bytes(), flag, accValue) +} + +// UpdateContractCode updates the contract code with the given address and code +func (mt *ZkTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error { + return nil +} + +// pushLeaf recursively pushes an existing oldLeaf down until its path diverges +// from newLeaf, at which point both leafs are stored, all while updating the +// path. pushLeaf returns the node hash of the parent of the oldLeaf and newLeaf +func (mt *ZkTrie) pushLeaf(newLeaf *Node, oldLeaf *Node, lvl int, + pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) { + if lvl > mt.maxLevels-2 { + return nil, ErrReachedMaxLevel + } + var newParentNode *Node + if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper! + // notice the node corresponding to return hash is always branch + nextNodeHash, err := mt.pushLeaf(newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf) + if err != nil { + return nil, err + } + if pathNewLeaf[lvl] { // go right + newParentNode = NewParentNode(NodeTypeBranch_1, &HashZero, nextNodeHash) + } else { // go left + newParentNode = NewParentNode(NodeTypeBranch_2, nextNodeHash, &HashZero) + } + + newParentNodeKey := mt.newDirtyNodeKey() + mt.dirtyStorage[*newParentNodeKey] = newParentNode + return newParentNodeKey, nil + } + oldLeafHash, err := oldLeaf.NodeHash() + if err != nil { + return nil, err + } + newLeafHash, err := newLeaf.NodeHash() + if err != nil { + return nil, err + } + + if pathNewLeaf[lvl] { + newParentNode = NewParentNode(NodeTypeBranch_0, oldLeafHash, newLeafHash) + } else { + newParentNode = NewParentNode(NodeTypeBranch_0, newLeafHash, oldLeafHash) + } + // We can add newLeaf now. We don't need to add oldLeaf because it's + // already in the tree. + mt.dirtyStorage[*newLeafHash] = newLeaf + newParentNodeKey := mt.newDirtyNodeKey() + mt.dirtyStorage[*newParentNodeKey] = newParentNode + return newParentNodeKey, nil +} + +// Commit calculates the root for the entire trie and persist all the dirty nodes +func (mt *ZkTrie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) { + mt.lock.Lock() + defer mt.lock.Unlock() + + // Short circuit if the trie is already committed and not usable. + if mt.committed { + return common.Hash{}, nil, ErrCommitted + } + + root, err := mt.root() + if err != nil { + return common.Hash{}, nil, err + } + + nodeSet := trienode.NewNodeSet(mt.owner) + if err := mt.commit(root, nil, nodeSet, collectLeaf); err != nil { + return common.Hash{}, nil, err + } + + mt.dirtyStorage = make(map[Hash]*Node) + mt.committed = true + + return common.BytesToHash(root.Bytes()), nodeSet, nil } -func init() { - zkt.InitHashScheme(poseidon.HashFixedWithDomain) +// Commit calculates the root for the entire trie and persist all the dirty nodes +func (mt *ZkTrie) commit(nodeHash *Hash, path []byte, nodeSet *trienode.NodeSet, collectLeaf bool) error { + node := mt.dirtyStorage[*nodeHash] + if node == nil { + return nil + } + + if node.Type == NodeTypeLeaf_New { + if mt.preimages != nil { + mt.preimages.insertPreimage(map[common.Hash][]byte{ + common.BytesToHash(nodeHash.Bytes()): node.NodeKey.Bytes(), + }) + } + if collectLeaf { + nodeSet.AddLeaf(common.BytesToHash(nodeHash.Bytes()), node.Data()) + } + } + if node.ChildL != nil { + if err := mt.commit(node.ChildL, append(path, byte(0)), nodeSet, collectLeaf); err != nil { + return err + } + } + if node.ChildR != nil { + if err := mt.commit(node.ChildR, append(path, byte(1)), nodeSet, collectLeaf); err != nil { + return err + } + } + nodeSet.AddNode(path, trienode.New(common.BytesToHash(nodeHash.Bytes()), node.CanonicalValue())) + return nil } -func sanityCheckByte32Key(b []byte) { - if len(b) != 32 && len(b) != 20 { - panic(fmt.Errorf("do not support length except for 120bit and 256bit now. data: %v len: %v", b, len(b))) +// addLeaf recursively adds a newLeaf in the MT while updating the path, and returns the key +// of the new added leaf. +func (mt *ZkTrie) addLeaf(newLeaf *Node, currNodeKey *Hash, + lvl int, path []bool) (*Hash, bool, error) { + var err error + if lvl > mt.maxLevels-1 { + return nil, false, ErrReachedMaxLevel + } + n, err := mt.getNode(currNodeKey) + if err != nil { + return nil, false, err + } + switch n.Type { + case NodeTypeEmpty_New: + newLeafHash, err := newLeaf.NodeHash() + if err != nil { + return nil, false, err + } + + mt.dirtyStorage[*newLeafHash] = newLeaf + return newLeafHash, true, nil + case NodeTypeLeaf_New: + newLeafHash, err := newLeaf.NodeHash() + if err != nil { + return nil, false, err + } + + if bytes.Equal(currNodeKey[:], newLeafHash[:]) { + // do nothing, duplicate entry + return nil, true, nil + } else if bytes.Equal(newLeaf.NodeKey.Bytes(), n.NodeKey.Bytes()) { + // update the existing leaf + mt.dirtyStorage[*newLeafHash] = newLeaf + return newLeafHash, true, nil + } + newSubTrieRootHash, err := mt.pushLeaf(newLeaf, n, lvl, path, getPath(mt.maxLevels, n.NodeKey[:])) + return newSubTrieRootHash, false, err + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + // We need to go deeper, continue traversing the tree, left or + // right depending on path + branchRight := path[lvl] + childSubTrieRoot := n.ChildL + if branchRight { + childSubTrieRoot = n.ChildR + } + newChildSubTrieRoot, isTerminal, err := mt.addLeaf(newLeaf, childSubTrieRoot, lvl+1, path) + if err != nil { + return nil, false, err + } + + // do nothing, if child subtrie was not modified + if newChildSubTrieRoot == nil { + return nil, false, nil + } + + newNodetype := n.Type + if !isTerminal { + newNodetype = newNodetype.DeduceUpgradeType(branchRight) + } + + var newNode *Node + if branchRight { + newNode = NewParentNode(newNodetype, n.ChildL, newChildSubTrieRoot) + } else { + newNode = NewParentNode(newNodetype, newChildSubTrieRoot, n.ChildR) + } + + // if current node is already dirty, modify in-place + // else create a new dirty sub-trie + newCurTrieRootKey := mt.newDirtyNodeKey() + mt.dirtyStorage[*newCurTrieRootKey] = newNode + return newCurTrieRootKey, false, err + case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: + panic("encounter unsupported deprecated node type") + default: + return nil, false, ErrInvalidNodeFound } } -// NewZkTrie creates a trie -// NewZkTrie bypasses all the buffer mechanism in *Database, it directly uses the -// underlying diskdb -func NewZkTrie(root common.Hash, db *ZktrieDatabase) (*ZkTrie, error) { - tr, err := zktrie.NewZkTrie(*zkt.NewByte32FromBytes(root.Bytes()), db) +// newDirtyNodeKey increments the dirtyIndex and creates a new dirty node key +func (mt *ZkTrie) newDirtyNodeKey() *Hash { + mt.dirtyIndex.Add(mt.dirtyIndex, BigOne) + return NewHashFromBigInt(mt.dirtyIndex) +} + +// isDirtyNode returns if the node with the given key is dirty or not +func (mt *ZkTrie) isDirtyNode(nodeKey *Hash) bool { + _, found := mt.dirtyStorage[*nodeKey] + return found +} + +// calcCommitment calculates the commitment for the given sub trie +func (mt *ZkTrie) calcCommitment(rootKey *Hash, hashedDirtyNodes map[Hash]*Node, commitLock *sync.Mutex) (*Hash, error) { + if !mt.isDirtyNode(rootKey) { + return rootKey, nil + } + + root, err := mt.getNode(rootKey) + if err != nil { + return nil, err + } + + switch root.Type { + case NodeTypeEmpty: + return &HashZero, nil + case NodeTypeLeaf_New: + // leaves are already hashed, we just need to persist it + break + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + leftDone := make(chan struct{}) + var leftErr error + go func() { + root.ChildL, leftErr = mt.calcCommitment(root.ChildL, hashedDirtyNodes, commitLock) + close(leftDone) + }() + root.ChildR, err = mt.calcCommitment(root.ChildR, hashedDirtyNodes, commitLock) + if err != nil { + return nil, err + } + <-leftDone + if leftErr != nil { + return nil, leftErr + } + default: + return nil, errors.New(fmt.Sprint("unexpected node type", root.Type)) + } + + rootHash, err := root.NodeHash() if err != nil { return nil, err } - return &ZkTrie{tr, db}, nil + + commitLock.Lock() + defer commitLock.Unlock() + hashedDirtyNodes[*rootHash] = root + return rootHash, nil +} + +func (mt *ZkTrie) tryGet(nodeKey *Hash) (*Node, error) { + + path := getPath(mt.maxLevels, nodeKey[:]) + var nextKey Hash + nextKey.Set(mt.rootKey) + n := new(Node) + //sanity check + lastNodeType := NodeTypeBranch_3 + for i := 0; i < mt.maxLevels; i++ { + err := mt.getNodeTo(&nextKey, n) + if err != nil { + return nil, err + } + //sanity check + if i > 0 && n.IsTerminal() { + if lastNodeType == NodeTypeBranch_3 { + panic("parent node has invalid type: children are not terminal") + } else if path[i-1] && lastNodeType == NodeTypeBranch_1 { + panic("parent node has invalid type: right child is not terminal") + } else if !path[i-1] && lastNodeType == NodeTypeBranch_2 { + panic("parent node has invalid type: left child is not terminal") + } + } + + lastNodeType = n.Type + switch n.Type { + case NodeTypeEmpty_New: + return NewEmptyNode(), ErrKeyNotFound + case NodeTypeLeaf_New: + if bytes.Equal(nodeKey[:], n.NodeKey[:]) { + return n, nil + } + return n, ErrKeyNotFound + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + if path[i] { + nextKey.Set(n.ChildR) + } else { + nextKey.Set(n.ChildL) + } + case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: + panic("encounter deprecated node types") + default: + return nil, ErrInvalidNodeFound + } + } + + return nil, ErrReachedMaxLevel } -// Get returns the value for key stored in the trie. +// TryGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. -func (t *ZkTrie) Get(key []byte) []byte { - sanityCheckByte32Key(key) - res, err := t.TryGet(key) +// If a node was not found in the database, a MissingNodeError is returned. +func (mt *ZkTrie) TryGet(key []byte) ([]byte, error) { + mt.lock.RLock() + defer mt.lock.RUnlock() + + // Short circuit if the trie is already committed and not usable. + if mt.committed { + return nil, ErrCommitted + } + + secureK, err := ToSecureKey(key) if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + return nil, err + } + + node, err := mt.tryGet(NewHashFromBigInt(secureK)) + if err == ErrKeyNotFound { + // according to https://github.com/ethereum/go-ethereum/blob/37f9d25ba027356457953eab5f181c98b46e9988/trie/trie.go#L135 + return nil, nil + } else if err != nil { + return nil, err } - return res + return node.Data(), nil +} + +// GetStorage returns the value for key stored in the trie. +func (mt *ZkTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { + return mt.TryGet(key) } -func (t *ZkTrie) GetAccount(address common.Address) (*types.StateAccount, error) { +// GetAccount returns the account for the given address. +func (mt *ZkTrie) GetAccount(address common.Address) (*types.StateAccount, error) { key := address.Bytes() - sanityCheckByte32Key(key) - res, err := t.TryGet(key) + res, err := mt.TryGet(key) if res == nil || err != nil { return nil, err } return types.UnmarshalStateAccount(res) } -func (t *ZkTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { - sanityCheckByte32Key(key) - return t.TryGet(key) +// GetKey returns the key for the given hash. +func (mt *ZkTrie) GetKey(hashKey []byte) []byte { + mt.lock.RLock() + defer mt.lock.RUnlock() + return mt.getKey(hashKey) } -func (t *ZkTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { - return t.TryUpdateAccount(address.Bytes(), acc) +// GetKey returns the key for the given hash. +func (mt *ZkTrie) getKey(hashKey []byte) []byte { + if key, ok := mt.secKeyCache[string(hashKey)]; ok { + return key + } + if mt.preimages == nil { + return nil + } + return mt.preimages.preimage(common.BytesToHash(hashKey)) } -// TryUpdateAccount will abstract the write of an account to the -// secure trie. -func (t *ZkTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error { - sanityCheckByte32Key(key) - value, flag := acc.MarshalFields() - return t.ZkTrie.TryUpdate(key, flag, value) +// Delete removes the specified Key from the ZkTrie and updates the path +// from the deleted key to the Root with the new values. This method removes +// the key from the ZkTrie, but does not remove the old nodes from the +// key-value database; this means that if the tree is accessed by an old Root +// where the key was not deleted yet, the key will still exist. If is desired +// to remove the key-values from the database that are not under the current +// Root, an option could be to dump all the leafs (using mt.DumpLeafs) and +// import them in a new ZkTrie in a new database (using +// mt.ImportDumpedLeafs), but this will lose all the Root history of the +// ZkTrie +func (mt *ZkTrie) TryDelete(key []byte) error { + secureKey, err := ToSecureKey(key) + if err != nil { + return err + } + + nodeKey := NewHashFromBigInt(secureKey) + + // verify that k is valid and fit inside the Finite Field. + if !CheckBigIntInField(nodeKey.BigInt()) { + return ErrInvalidField + } + + mt.lock.Lock() + defer mt.lock.Unlock() + + // Short circuit if the trie is already committed and not usable. + if mt.committed { + return ErrCommitted + } + + //mitigate the create-delete issue: do not delete unexisted key + if r, _ := mt.tryGet(nodeKey); r == nil { + return nil + } + + newRootKey, _, err := mt.tryDelete(mt.rootKey, nodeKey, getPath(mt.maxLevels, nodeKey[:])) + if err != nil && !errors.Is(err, ErrKeyNotFound) { + return err + } + if newRootKey != nil { + mt.rootKey = newRootKey + } + return nil } -// Update associates key with value in the trie. Subsequent calls to -// Get will return value. If value has length zero, any existing value -// is deleted from the trie and calls to Get will return nil. -// -// The value bytes must not be modified by the caller while they are -// stored in the trie. -func (t *ZkTrie) Update(key, value []byte) { - if err := t.TryUpdate(key, value); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) +func (mt *ZkTrie) tryDelete(rootKey *Hash, nodeKey *Hash, path []bool) (*Hash, bool, error) { + root, err := mt.getNode(rootKey) + if err != nil { + return nil, false, err + } + + switch root.Type { + case NodeTypeEmpty_New: + return nil, false, ErrKeyNotFound + case NodeTypeLeaf_New: + if bytes.Equal(nodeKey[:], root.NodeKey[:]) { + return &HashZero, true, nil + } + return nil, false, ErrKeyNotFound + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + branchRight := path[0] + childKey, siblingKey := root.ChildL, root.ChildR + if branchRight { + childKey, siblingKey = root.ChildR, root.ChildL + } + + newChildKey, newChildIsTerminal, err := mt.tryDelete(childKey, nodeKey, path[1:]) + if err != nil { + return nil, false, err + } + + siblingIsTerminal := root.Type == NodeTypeBranch_0 || + (branchRight && root.Type == NodeTypeBranch_1) || + (!branchRight && root.Type == NodeTypeBranch_2) + + leftChild, rightChild := newChildKey, siblingKey + leftIsTerminal, rightIsTerminal := newChildIsTerminal, siblingIsTerminal + if branchRight { + leftChild, rightChild = siblingKey, newChildKey + leftIsTerminal, rightIsTerminal = siblingIsTerminal, newChildIsTerminal + } + + var newNodeType NodeType + if leftIsTerminal && rightIsTerminal { + leftIsEmpty := bytes.Equal(HashZero[:], (*leftChild)[:]) + rightIsEmpty := bytes.Equal(HashZero[:], (*rightChild)[:]) + + // if both children are terminal and one of them is empty, prune the root node + // and send return the non-empty child + if leftIsEmpty || rightIsEmpty { + if leftIsEmpty { + return rightChild, true, nil + } + return leftChild, true, nil + } else { + newNodeType = NodeTypeBranch_0 + } + } else if leftIsTerminal { + newNodeType = NodeTypeBranch_1 + } else if rightIsTerminal { + newNodeType = NodeTypeBranch_2 + } else { + newNodeType = NodeTypeBranch_3 + } + + newRootKey := mt.newDirtyNodeKey() + mt.dirtyStorage[*newRootKey] = NewParentNode(newNodeType, leftChild, rightChild) + return newRootKey, false, nil + default: + panic("encounter unsupported deprecated node type") } } -// NOTE: value is restricted to length of bytes32. -// we override the underlying zktrie's TryUpdate method -func (t *ZkTrie) TryUpdate(key, value []byte) error { - sanityCheckByte32Key(key) - return t.ZkTrie.TryUpdate(key, 1, []zkt.Byte32{*zkt.NewByte32FromBytes(value)}) +// DeleteAccount removes the account with the given address from the trie. +func (mt *ZkTrie) DeleteAccount(address common.Address) error { + return mt.TryDelete(address.Bytes()) } -func (t *ZkTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error { - return nil +// DeleteStorage removes the key from the trie. +func (mt *ZkTrie) DeleteStorage(_ common.Address, key []byte) error { + return mt.TryDelete(key) } -func (t *ZkTrie) UpdateStorage(_ common.Address, key, value []byte) error { - return t.TryUpdate(key, value) +// GetLeafNode is more underlying method than TryGet, which obtain an leaf node +// or nil if not exist +func (mt *ZkTrie) GetLeafNode(key []byte) (*Node, error) { + mt.lock.RLock() + defer mt.lock.RUnlock() + + // Short circuit if the trie is already committed and not usable. + if mt.committed { + return nil, ErrCommitted + } + + secureKey, err := ToSecureKey(key) + if err != nil { + return nil, err + } + + nodeKey := NewHashFromBigInt(secureKey) + + n, err := mt.tryGet(nodeKey) + return n, err } -// Delete removes any existing value for key from the trie. -func (t *ZkTrie) Delete(key []byte) { - sanityCheckByte32Key(key) - if err := t.TryDelete(key); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) +// GetNodeByHash gets a node by node hash from the MT. Empty nodes are not stored in the +// tree; they are all the same and assumed to always exist. +// for non exist key, return (NewEmptyNode(), nil) +func (mt *ZkTrie) GetNodeByHash(nodeHash *Hash) (*Node, error) { + mt.lock.RLock() + defer mt.lock.RUnlock() + + // Short circuit if the trie is already committed and not usable. + if mt.committed { + return nil, ErrCommitted } + return mt.getNode(nodeHash) } -func (t *ZkTrie) DeleteAccount(address common.Address) error { - key := address.Bytes() - sanityCheckByte32Key(key) - return t.TryDelete(key) +func (mt *ZkTrie) getNodeTo(nodeHash *Hash, node *Node) error { + if bytes.Equal(nodeHash[:], HashZero[:]) { + *node = *NewEmptyNode() + return nil + } + if dirtyNode, found := mt.dirtyStorage[*nodeHash]; found { + *node = *dirtyNode.Copy() + return nil + } + + var hash common.Hash + hash.SetBytes(nodeHash.Bytes()) + nBytes, err := mt.reader.node(nil, hash) + if err != nil { + return err + } + return node.SetBytes(nBytes) +} + +func (mt *ZkTrie) getNode(nodeHash *Hash) (*Node, error) { + var n Node + if err := mt.getNodeTo(nodeHash, &n); err != nil { + return nil, err + } + return &n, nil +} + +// getPath returns the binary path, from the root to the leaf. +func getPath(numLevels int, k []byte) []bool { + path := make([]bool, numLevels) + for n := 0; n < numLevels; n++ { + path[n] = TestBit(k[:], uint(n)) + } + return path +} + +// NodeAux contains the auxiliary node used in a non-existence proof. +type NodeAux struct { + Key *Hash // Key is the node key + Value *Hash // Value is the value hash in the node +} + +// Proof defines the required elements for a MT proof of existence or +// non-existence. +type Proof struct { + // existence indicates wether this is a proof of existence or + // non-existence. + Existence bool + // depth indicates how deep in the tree the proof goes. + depth uint + // notempties is a bitmap of non-empty Siblings found in Siblings. + notempties [HashByteLen - proofFlagsLen]byte + // Siblings is a list of non-empty sibling node hashes. + Siblings []*Hash + // NodeInfos is a list of nod types along mpt path + NodeInfos []NodeType + // NodeKey record the key of node and path + NodeKey *Hash + // NodeAux contains the auxiliary information of the lowest common ancestor + // node in a non-existence proof. + NodeAux *NodeAux } -func (t *ZkTrie) DeleteStorage(_ common.Address, key []byte) error { - sanityCheckByte32Key(key) - return t.TryDelete(key) +// BuildZkTrieProof prove uniformed way to turn some data collections into Proof struct +func BuildZkTrieProof(rootHash *Hash, k *big.Int, lvl int, getNode func(key *Hash) (*Node, error)) (*Proof, + *Node, error) { + + p := &Proof{} + var siblingHash *Hash + + p.NodeKey = NewHashFromBigInt(k) + kHash := p.NodeKey + path := getPath(lvl, kHash[:]) + + nextHash := rootHash + for p.depth = 0; p.depth < uint(lvl); p.depth++ { + n, err := getNode(nextHash) + if err != nil { + return nil, nil, err + } + p.NodeInfos = append(p.NodeInfos, n.Type) + switch n.Type { + case NodeTypeEmpty_New: + return p, n, nil + case NodeTypeLeaf_New: + if bytes.Equal(kHash[:], n.NodeKey[:]) { + p.Existence = true + return p, n, nil + } + vHash, err := n.ValueHash() + // We found a leaf whose entry didn't match hIndex + p.NodeAux = &NodeAux{Key: n.NodeKey, Value: vHash} + return p, n, err + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + if path[p.depth] { + nextHash = n.ChildR + siblingHash = n.ChildL + } else { + nextHash = n.ChildL + siblingHash = n.ChildR + } + case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: + panic("encounter deprecated node types") + default: + return nil, nil, ErrInvalidNodeFound + } + if !bytes.Equal(siblingHash[:], HashZero[:]) { + SetBitBigEndian(p.notempties[:], p.depth) + p.Siblings = append(p.Siblings, siblingHash) + } + } + return nil, nil, ErrKeyNotFound + } -// GetKey returns the preimage of a hashed key that was -// previously used to store a value. -func (t *ZkTrie) GetKey(kHashBytes []byte) []byte { - // TODO: use a kv cache in memory - k, err := zkt.NewBigIntFromHashBytes(kHashBytes) +// VerifyProof verifies the Merkle Proof for the entry and root. +// nodeHash can be nil when try to verify a nonexistent proof +func VerifyProofZkTrie(rootHash *Hash, proof *Proof, node *Node) bool { + var nodeHash *Hash + var err error + if node == nil { + if proof.NodeAux != nil { + nodeHash, err = LeafHash(proof.NodeAux.Key, proof.NodeAux.Value) + } else { + nodeHash = &HashZero + } + } else { + nodeHash, err = node.NodeHash() + } + if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + return false } - if t.db.db.preimages != nil { - return t.db.db.preimages.preimage(common.BytesToHash(k.Bytes())) + + rootFromProof, err := proof.rootFromProof(nodeHash, proof.NodeKey) + if err != nil { + return false } - return nil + return bytes.Equal(rootHash[:], rootFromProof[:]) } -// Commit writes all nodes and the secure hash pre-images to the trie's database. -// Nodes are stored with their sha3 hash as the key. -// -// Committing flushes nodes from memory. Subsequent Get calls will load nodes -// from the database. -func (t *ZkTrie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) { - if err := t.ZkTrie.Commit(); err != nil { - return common.Hash{}, nil, err +// Verify the proof and calculate the root, nodeHash can be nil when try to verify +// a nonexistent proof +func (proof *Proof) Verify(nodeHash *Hash) (*Hash, error) { + if proof.Existence { + if nodeHash == nil { + return nil, ErrKeyNotFound + } + return proof.rootFromProof(nodeHash, proof.NodeKey) + } else { + if proof.NodeAux == nil { + return proof.rootFromProof(&HashZero, proof.NodeKey) + } else { + if bytes.Equal(proof.NodeKey[:], proof.NodeAux.Key[:]) { + return nil, fmt.Errorf("non-existence proof being checked against hIndex equal to nodeAux") + } + midHash, err := LeafHash(proof.NodeAux.Key, proof.NodeAux.Value) + if err != nil { + return nil, err + } + return proof.rootFromProof(midHash, proof.NodeKey) + } } - return t.Hash(), nil, nil + } -// Hash returns the root hash of SecureBinaryTrie. It does not write to the -// database and can be used even if the trie doesn't have one. -func (t *ZkTrie) Hash() common.Hash { - var hash common.Hash - hash.SetBytes(t.ZkTrie.Hash()) - return hash +func (proof *Proof) rootFromProof(nodeHash, nodeKey *Hash) (*Hash, error) { + var err error + + sibIdx := len(proof.Siblings) - 1 + path := getPath(int(proof.depth), nodeKey[:]) + for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- { + var siblingHash *Hash + if TestBitBigEndian(proof.notempties[:], uint(lvl)) { + siblingHash = proof.Siblings[sibIdx] + sibIdx-- + } else { + siblingHash = &HashZero + } + curType := proof.NodeInfos[lvl] + if path[lvl] { + nodeHash, err = NewParentNode(curType, siblingHash, nodeHash).NodeHash() + if err != nil { + return nil, err + } + } else { + nodeHash, err = NewParentNode(curType, nodeHash, siblingHash).NodeHash() + if err != nil { + return nil, err + } + } + } + return nodeHash, nil +} + +// walk is a helper recursive function to iterate over all tree branches +func (mt *ZkTrie) walk(nodeHash *Hash, f func(*Node)) error { + n, err := mt.getNode(nodeHash) + if err != nil { + return err + } + if n.IsTerminal() { + f(n) + } else { + f(n) + if err := mt.walk(n.ChildL, f); err != nil { + return err + } + if err := mt.walk(n.ChildR, f); err != nil { + return err + } + } + return nil } -// Copy returns a copy of SecureBinaryTrie. -func (t *ZkTrie) Copy() *ZkTrie { - return &ZkTrie{t.ZkTrie.Copy(), t.db} +// Walk iterates over all the branches of a ZkTrie with the given rootHash +// if rootHash is nil, it will get the current RootHash of the current state of +// the ZkTrie. For each node, it calls the f function given in the +// parameters. See some examples of the Walk function usage in the +// ZkTrie.go and merkletree_test.go +func (mt *ZkTrie) Walk(rootHash *Hash, f func(*Node)) error { + var err error + if rootHash == nil { + rootHash, err = mt.Root() + if err != nil { + return err + } + } + mt.lock.RLock() + defer mt.lock.RUnlock() + + err = mt.walk(rootHash, f) + return err } -// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration -// starts at the key after the given start key. -func (t *ZkTrie) NodeIterator(start []byte) (NodeIterator, error) { - /// FIXME - panic("not implemented") +// GraphViz uses Walk function to generate a string GraphViz representation of +// the tree and writes it to w +func (mt *ZkTrie) GraphViz(w io.Writer, rootHash *Hash) error { + if rootHash == nil { + var err error + rootHash, err = mt.Root() + if err != nil { + return err + } + } + + mt.lock.RLock() + defer mt.lock.RUnlock() + + fmt.Fprintf(w, + "--------\nGraphViz of the ZkTrie with RootHash "+rootHash.BigInt().String()+"\n") + + fmt.Fprintf(w, `digraph hierarchy { +node [fontname=Monospace,fontsize=10,shape=box] +`) + cnt := 0 + var errIn error + err := mt.walk(rootHash, func(n *Node) { + hash, err := n.NodeHash() + if err != nil { + errIn = err + } + switch n.Type { + case NodeTypeEmpty_New: + case NodeTypeLeaf_New: + fmt.Fprintf(w, "\"%v\" [style=filled];\n", hash.String()) + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + lr := [2]string{n.ChildL.String(), n.ChildR.String()} + emptyNodes := "" + for i := range lr { + if lr[i] == "0" { + lr[i] = fmt.Sprintf("empty%v", cnt) + emptyNodes += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", lr[i]) + cnt++ + } + } + fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hash.String(), lr[0], lr[1]) + fmt.Fprint(w, emptyNodes) + case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: + panic("encounter unsupported deprecated node type") + default: + } + }) + fmt.Fprintf(w, "}\n") + + fmt.Fprintf(w, + "End of GraphViz of the ZkTrie with RootHash "+rootHash.BigInt().String()+"\n--------\n") + + if errIn != nil { + return errIn + } + return err } -// hashKey returns the hash of key as an ephemeral buffer. -// The caller must not hold onto the return value because it will become -// invalid on the next call to hashKey or secKey. -/*func (t *ZkTrie) hashKey(key []byte) []byte { - if len(key) != 32 { - panic("non byte32 input to hashKey") - } - low16 := new(big.Int).SetBytes(key[:16]) - high16 := new(big.Int).SetBytes(key[16:]) - hash, err := poseidon.Hash([]*big.Int{low16, high16}) - if err != nil { - panic(err) - } - return hash.Bytes() +// Copy creates a new independent zkTrie from the given trie +func (mt *ZkTrie) Copy() *ZkTrie { + mt.lock.RLock() + defer mt.lock.RUnlock() + + // Deep copy in-memory dirty nodes + newDirtyStorage := make(map[Hash]*Node, len(mt.dirtyStorage)) + for key, dirtyNode := range mt.dirtyStorage { + newDirtyStorage[key] = dirtyNode.Copy() + } + + newRootKey := *mt.rootKey + return &ZkTrie{ + reader: mt.reader, + maxLevels: mt.maxLevels, + dirtyIndex: new(big.Int).Set(mt.dirtyIndex), + dirtyStorage: newDirtyStorage, + rootKey: &newRootKey, + committed: mt.committed, + preimages: mt.preimages, + secKeyCache: maps.Clone(mt.secKeyCache), + } } -*/ // Prove constructs a merkle proof for key. The result contains all encoded nodes // on the path to the value at key. The value itself is also included in the last @@ -214,49 +1048,220 @@ func (t *ZkTrie) NodeIterator(start []byte) (NodeIterator, error) { // nodes of the longest existing prefix of the key (at least the root node), ending // with the node that proves the absence of the key. // func (t *ZkTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { -func (t *ZkTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { +func (mt *ZkTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { fromLevel := uint(0) - err := t.ZkTrie.Prove(key, fromLevel, func(n *zktrie.Node) error { + err := mt.ProveWithDeletion(key, fromLevel, func(n *Node) error { nodeHash, err := n.NodeHash() if err != nil { return err } - if n.Type == zktrie.NodeTypeLeaf_New { - preImage := t.GetKey(n.NodeKey.Bytes()) + if n.Type == NodeTypeLeaf_New { + preImage := mt.getKey(n.NodeKey.Bytes()) if len(preImage) > 0 { - n.KeyPreimage = &zkt.Byte32{} + n.KeyPreimage = &Byte32{} copy(n.KeyPreimage[:], preImage) - //return fmt.Errorf("key preimage not found for [%x] ref %x", n.NodeKey.Bytes(), k.Bytes()) } } return proofDb.Put(nodeHash[:], n.Value()) - }) + }, nil) if err != nil { return err } // we put this special kv pair in db so we can distinguish the type and // make suitable Proof - return proofDb.Put(magicHash, zktrie.ProofMagicBytes()) + return proofDb.Put(magicHash, magicSMTBytes) +} + +// DecodeProof try to decode a node bytes, return can be nil for any non-node data (magic code) +func DecodeSMTProof(data []byte) (*Node, error) { + + if bytes.Equal(magicSMTBytes, data) { + //skip magic bytes node + return nil, nil + } + + return NewNodeFromBytes(data) +} + +// ProveWithDeletion constructs a merkle proof for key. The result contains all encoded nodes +// on the path to the value at key. The value itself is also included in the last +// node and can be retrieved by verifying the proof. +// +// If the trie does not contain a value for key, the returned proof contains all +// nodes of the longest existing prefix of the key (at least the root node), ending +// with the node that proves the absence of the key. +// +// If the trie contain value for key, the onHit is called BEFORE writeNode being called, +// both the hitted leaf node and its sibling node is provided as arguments so caller +// would receive enough information for launch a deletion and calculate the new root +// base on the proof data +// Also notice the sibling can be nil if the trie has only one leaf +func (mt *ZkTrie) ProveWithDeletion(key []byte, fromLevel uint, writeNode func(*Node) error, onHit func(*Node, *Node)) error { + secureKey, err := ToSecureKey(key) + if err != nil { + return err + } + + nodeKey := NewHashFromBigInt(secureKey) + var prev *Node + return mt.prove(nodeKey, fromLevel, func(n *Node) (err error) { + defer func() { + if err == nil { + err = writeNode(n) + } + prev = n + }() + + if prev != nil { + switch prev.Type { + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + default: + // sanity check: we should stop after obtain leaf/empty + panic("unexpected behavior in prove") + } + } + + if onHit == nil { + return + } + + // check and call onhit + if n.Type == NodeTypeLeaf_New && bytes.Equal(n.NodeKey.Bytes(), nodeKey.Bytes()) { + if prev == nil { + // for sole element trie + onHit(n, nil) + } else { + var sibling, nHash *Hash + nHash, err = n.NodeHash() + if err != nil { + return + } + + if bytes.Equal(nHash.Bytes(), prev.ChildL.Bytes()) { + sibling = prev.ChildR + } else { + sibling = prev.ChildL + } + + if siblingNode, err := mt.getNode(sibling); err == nil { + onHit(n, siblingNode) + } else { + onHit(n, nil) + } + } + + } + return + }) +} + +// Prove constructs a merkle proof for SMT, it respect the protocol used by the ethereum-trie +// but save the node data with a compact form +func (mt *ZkTrie) prove(kHash *Hash, fromLevel uint, writeNode func(*Node) error) error { + // force root hash calculation if needed + if _, err := mt.Root(); err != nil { + return err + } + + mt.lock.RLock() + defer mt.lock.RUnlock() + + // Short circuit if the trie is already committed and not usable. + if mt.committed { + return ErrCommitted + } + + path := getPath(mt.maxLevels, kHash[:]) + var nodes []*Node + var lastN *Node + tn := mt.rootKey + for i := 0; i < mt.maxLevels; i++ { + n, err := mt.getNode(tn) + if err != nil { + fmt.Println("get node fail", err, tn.Hex(), + lastN.ChildL.Hex(), + lastN.ChildR.Hex(), + path, + i, + ) + return err + } + nodeHash := tn + lastN = n + + finished := true + switch n.Type { + case NodeTypeEmpty_New: + case NodeTypeLeaf_New: + // notice even we found a leaf whose entry didn't match the expected k, + // we still include it as the proof of absence + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + finished = false + if path[i] { + tn = n.ChildR + } else { + tn = n.ChildL + } + case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: + panic("encounter deprecated node types") + default: + return ErrInvalidNodeFound + } + + nCopy := n.Copy() + nCopy.nodeHash = nodeHash + nodes = append(nodes, nCopy) + if finished { + break + } + } + + for _, n := range nodes { + if fromLevel > 0 { + fromLevel-- + continue + } + + // TODO: notice here we may have broken some implicit on the proofDb: + // the key is not kecca(value) and it even can not be derived from + // the value by any means without a actually decoding + if err := writeNode(n); err != nil { + return err + } + } + + return nil +} + +// NodeIterator returns an iterator that returns nodes of the trie. Iteration +// starts at the key after the given start key. And error will be returned +// if fails to create node iterator. +func (mt *ZkTrie) NodeIterator(start []byte) (NodeIterator, error) { + // Short circuit if the trie is already committed and not usable. + if mt.committed { + return nil, ErrCommitted + } + return nil, errors.New("not implemented") } // VerifyProof checks merkle proofs. The given proof must contain the value for // key in a trie with the given root hash. VerifyProof returns an error if the // proof contains invalid trie nodes or the wrong value. func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { - h := zkt.NewHashFromBytes(rootHash.Bytes()) - k, err := zkt.ToSecureKey(key) + h := NewHashFromBytes(rootHash.Bytes()) + k, err := ToSecureKey(key) if err != nil { return nil, err } - proof, n, err := zktrie.BuildZkTrieProof(h, k, len(key)*8, func(key *zkt.Hash) (*zktrie.Node, error) { + proof, n, err := BuildZkTrieProof(h, k, len(key)*8, func(key *Hash) (*Node, error) { buf, _ := proofDb.Get(key[:]) if buf == nil { - return nil, zktrie.ErrKeyNotFound + return nil, ErrKeyNotFound } - n, err := zktrie.NewNodeFromBytes(buf) + n, err := NewNodeFromBytes(buf) return n, err }) @@ -267,9 +1272,56 @@ func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueRead return nil, nil } - if zktrie.VerifyProofZkTrie(h, proof, n) { + if VerifyProofZkTrie(h, proof, n) { return n.Data(), nil } else { return nil, fmt.Errorf("bad proof node %v", proof) } } + +// MustDelete deletes the key from the trie and panics if it fails. +func (mt *ZkTrie) MustDelete(key []byte) { + if err := mt.TryDelete(key); err != nil { + panic(err) + } +} + +// MustUpdate updates the key with the given value and panics if it fails. +func (mt *ZkTrie) MustUpdate(key, value []byte) { + if err := mt.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes(value)}); err != nil { + panic(err) + } +} + +// MustGet returns the value for key stored in the trie and panics if it fails. +func (mt *ZkTrie) MustGet(key []byte) []byte { + v, err := mt.TryGet(key) + if err != nil { + panic(err) + } + return v +} + +// MustNodeIterator returns an iterator that returns nodes of the trie and panics if it fails. +func (mt *ZkTrie) MustNodeIterator(start []byte) NodeIterator { + itr, err := mt.NodeIterator(start) + if err != nil { + panic(err) + } + return itr +} + +// GetAccountByHash does the same thing as GetAccount, however it expects an +// account hash that is the hash of address. This constitutes an abstraction +// leak, since the client code needs to know the key format. +func (mt *ZkTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { + return nil, errors.New("not implemented") +} + +// GetNode attempts to retrieve a trie node by compact-encoded path. It is not +// possible to use keybyte-encoding as the path might contain odd nibbles. +// If the specified trie node is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. +func (mt *ZkTrie) GetNode(path []byte) ([]byte, int, error) { + return nil, 0, errors.New("not implemented") +} diff --git a/trie/zk_trie_database.go b/trie/zk_trie_database.go deleted file mode 100644 index b4c3fbbc268e..000000000000 --- a/trie/zk_trie_database.go +++ /dev/null @@ -1,172 +0,0 @@ -package trie - -import ( - "math/big" - - "github.com/syndtr/goleveldb/leveldb" - - zktrie "github.com/scroll-tech/zktrie/trie" - - "github.com/scroll-tech/go-ethereum/common" - "github.com/scroll-tech/go-ethereum/ethdb" - "github.com/scroll-tech/go-ethereum/trie/triedb/hashdb" -) - -// ZktrieDatabase Database adaptor implements zktrie.ZktrieDatbase -// It also reverses the bit order of the key being persisted. -// This ensures that the adjacent leaf in zktrie maintains minimal -// distance when persisted with dictionary order in LevelDB. -// Consequently, this optimizes the snapshot operation, allowing it -// to iterate through adjacent leaves at a reduced cost. - -type ZktrieDatabase struct { - db *Database - prefix []byte -} - -func NewZktrieDatabase(diskdb ethdb.Database) *ZktrieDatabase { - db := NewDatabase(diskdb, nil) - db.config.IsUsingZktrie = true - return &ZktrieDatabase{db: db, prefix: []byte{}} -} - -// adhoc wrapper... -func NewZktrieDatabaseFromTriedb(db *Database) *ZktrieDatabase { - db.config.IsUsingZktrie = true - return &ZktrieDatabase{db: db, prefix: []byte{}} -} - -// Put saves a key:value into the Storage -func (l *ZktrieDatabase) Put(k, v []byte) error { - k = bitReverse(k) - l.db.GetLock().Lock() - l.db.rawDirties.Put(Concat(l.prefix, k[:]), v) - l.db.GetLock().Unlock() - return nil -} - -// Get retrieves a value from a key in the Storage -func (l *ZktrieDatabase) Get(key []byte) ([]byte, error) { - key = bitReverse(key) - concatKey := Concat(l.prefix, key[:]) - l.db.GetLock().RLock() - value, ok := l.db.rawDirties.Get(concatKey) - l.db.GetLock().RUnlock() - if ok { - return value, nil - } - - if l.db.GetCleans() != nil { - if enc := l.db.GetCleans().Get(nil, concatKey); enc != nil { - hashdb.MemcacheCleanHitMeter.Mark(1) - hashdb.MemcacheCleanReadMeter.Mark(int64(len(enc))) - return enc, nil - } - } - - v, err := l.db.diskdb.Get(concatKey) - if err == leveldb.ErrNotFound { - return nil, zktrie.ErrKeyNotFound - } - if l.db.GetCleans() != nil { - l.db.GetCleans().Set(concatKey[:], v) - hashdb.MemcacheCleanMissMeter.Mark(1) - hashdb.MemcacheCleanWriteMeter.Mark(int64(len(v))) - } - return v, err -} - -func (l *ZktrieDatabase) UpdatePreimage(preimage []byte, hashField *big.Int) { - db := l.db - if db.preimages != nil { // Ugly direct check but avoids the below write lock - // we must copy the input key - db.preimages.insertPreimage(map[common.Hash][]byte{common.BytesToHash(hashField.Bytes()): common.CopyBytes(preimage)}) - } -} - -// Iterate implements the method Iterate of the interface Storage -func (l *ZktrieDatabase) Iterate(f func([]byte, []byte) (bool, error)) error { - iter := l.db.diskdb.NewIterator(l.prefix, nil) - defer iter.Release() - for iter.Next() { - localKey := bitReverse(iter.Key()[len(l.prefix):]) - if cont, err := f(localKey, iter.Value()); err != nil { - return err - } else if !cont { - break - } - } - iter.Release() - return iter.Error() -} - -// Close implements the method Close of the interface Storage -func (l *ZktrieDatabase) Close() { - // FIXME: is this correct? - if err := l.db.diskdb.Close(); err != nil { - panic(err) - } -} - -// List implements the method List of the interface Storage -func (l *ZktrieDatabase) List(limit int) ([]KV, error) { - ret := []KV{} - err := l.Iterate(func(key []byte, value []byte) (bool, error) { - ret = append(ret, KV{K: Clone(key), V: Clone(value)}) - if len(ret) == limit { - return false, nil - } - return true, nil - }) - return ret, err -} - -func bitReverseForNibble(b byte) byte { - switch b { - case 0: - return 0 - case 1: - return 8 - case 2: - return 4 - case 3: - return 12 - case 4: - return 2 - case 5: - return 10 - case 6: - return 6 - case 7: - return 14 - case 8: - return 1 - case 9: - return 9 - case 10: - return 5 - case 11: - return 13 - case 12: - return 3 - case 13: - return 11 - case 14: - return 7 - case 15: - return 15 - default: - panic("unexpected input") - } -} - -func bitReverse(inp []byte) (out []byte) { - l := len(inp) - out = make([]byte, l) - - for i, b := range inp { - out[l-i-1] = bitReverseForNibble(b&15)<<4 + bitReverseForNibble(b>>4) - } - - return -} diff --git a/trie/zk_trie_database_test.go b/trie/zk_trie_database_test.go deleted file mode 100644 index 6d8c15e6fa27..000000000000 --- a/trie/zk_trie_database_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package trie - -import ( - "bytes" - "testing" - - "github.com/scroll-tech/go-ethereum/common" -) - -// grep from `feat/snap` -func reverseBitInPlace(b []byte) { - var v [8]uint8 - for i := 0; i < len(b); i++ { - for j := 0; j < 8; j++ { - v[j] = (b[i] >> j) & 1 - } - var tmp uint8 = 0 - for j := 0; j < 8; j++ { - tmp |= v[8-j-1] << j - } - b[i] = tmp - } -} - -func reverseBytesInPlace(b []byte) { - for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { - b[i], b[j] = b[j], b[i] - } -} - -func TestBitReverse(t *testing.T) { - for _, testBytes := range [][]byte{ - common.FromHex("7b908cce3bc16abb3eac5dff6c136856526f15225f74ce860a2bec47912a5492"), - common.FromHex("fac65cd2ad5e301083d0310dd701b5faaff1364cbe01cdbfaf4ec3609bb4149e"), - common.FromHex("55791f6ec2f83fee512a2d3d4b505784fdefaea89974e10440d01d62a18a298a"), - common.FromHex("5ab775b64d86a8058bb71c3c765d0f2158c14bbeb9cb32a65eda793a7e95e30f"), - common.FromHex("ccb464abf67804538908c62431b3a6788e8dc6dee62aff9bfe6b10136acfceac"), - common.FromHex("b908adff17a5aa9d6787324c39014a74b04cef7fba6a92aeb730f48da1ca665d"), - } { - b1 := bitReverse(testBytes) - reverseBitInPlace(testBytes) - reverseBytesInPlace(testBytes) - if !bytes.Equal(b1, testBytes) { - t.Errorf("unexpected bit reversed %x vs %x", b1, testBytes) - } - } -} - -func TestBitDoubleReverse(t *testing.T) { - for _, testBytes := range [][]byte{ - common.FromHex("7b908cce3bc16abb3eac5dff6c136856526f15225f74ce860a2bec47912a5492"), - common.FromHex("fac65cd2ad5e301083d0310dd701b5faaff1364cbe01cdbfaf4ec3609bb4149e"), - common.FromHex("55791f6ec2f83fee512a2d3d4b505784fdefaea89974e10440d01d62a18a298a"), - common.FromHex("5ab775b64d86a8058bb71c3c765d0f2158c14bbeb9cb32a65eda793a7e95e30f"), - common.FromHex("ccb464abf67804538908c62431b3a6788e8dc6dee62aff9bfe6b10136acfceac"), - common.FromHex("b908adff17a5aa9d6787324c39014a74b04cef7fba6a92aeb730f48da1ca665d"), - } { - b := bitReverse(bitReverse(testBytes)) - if !bytes.Equal(b, testBytes) { - t.Errorf("unexpected double bit reversed %x vs %x", b, testBytes) - } - } -} diff --git a/trie/zk_trie_impl_test.go b/trie/zk_trie_impl_test.go deleted file mode 100644 index 3da9fe91932c..000000000000 --- a/trie/zk_trie_impl_test.go +++ /dev/null @@ -1,289 +0,0 @@ -package trie - -import ( - "math/big" - "testing" - - "github.com/iden3/go-iden3-crypto/constants" - cryptoUtils "github.com/iden3/go-iden3-crypto/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - zktrie "github.com/scroll-tech/zktrie/trie" - zkt "github.com/scroll-tech/zktrie/types" - - "github.com/scroll-tech/go-ethereum/common" - "github.com/scroll-tech/go-ethereum/core/rawdb" - "github.com/scroll-tech/go-ethereum/core/types" -) - -// we do not need zktrie impl anymore, only made a wrapper for adapting testing -type zkTrieImplTestWrapper struct { - *zktrie.ZkTrieImpl -} - -func newZkTrieImpl(storage *ZktrieDatabase, maxLevels int) (*zkTrieImplTestWrapper, error) { - return newZkTrieImplWithRoot(storage, &zkt.HashZero, maxLevels) -} - -// NewZkTrieImplWithRoot loads a new ZkTrieImpl. If in the storage already exists one -// will open that one, if not, will create a new one. -func newZkTrieImplWithRoot(storage *ZktrieDatabase, root *zkt.Hash, maxLevels int) (*zkTrieImplTestWrapper, error) { - impl, err := zktrie.NewZkTrieImplWithRoot(storage, root, maxLevels) - if err != nil { - return nil, err - } - - return &zkTrieImplTestWrapper{impl}, nil -} - -// AddWord -// Deprecated: Add a Bytes32 kv to ZkTrieImpl, only for testing -func (mt *zkTrieImplTestWrapper) AddWord(kPreimage, vPreimage *zkt.Byte32) error { - k, err := kPreimage.Hash() - if err != nil { - return err - } - - if v, _ := mt.TryGet(k.Bytes()); v != nil { - return zktrie.ErrEntryIndexAlreadyExists - } - - return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBigInt(k), 1, []zkt.Byte32{*vPreimage}) -} - -// GetLeafNodeByWord -// Deprecated: Get a Bytes32 kv to ZkTrieImpl, only for testing -func (mt *zkTrieImplTestWrapper) GetLeafNodeByWord(kPreimage *zkt.Byte32) (*zktrie.Node, error) { - k, err := kPreimage.Hash() - if err != nil { - return nil, err - } - return mt.ZkTrieImpl.GetLeafNode(zkt.NewHashFromBigInt(k)) -} - -// Deprecated: only for testing -func (mt *zkTrieImplTestWrapper) UpdateWord(kPreimage, vPreimage *zkt.Byte32) error { - k, err := kPreimage.Hash() - if err != nil { - return err - } - - return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBigInt(k), 1, []zkt.Byte32{*vPreimage}) -} - -// Deprecated: only for testing -func (mt *zkTrieImplTestWrapper) DeleteWord(kPreimage *zkt.Byte32) error { - k, err := kPreimage.Hash() - if err != nil { - return err - } - return mt.ZkTrieImpl.TryDelete(zkt.NewHashFromBigInt(k)) -} - -func (mt *zkTrieImplTestWrapper) TryGet(key []byte) ([]byte, error) { - return mt.ZkTrieImpl.TryGet(zkt.NewHashFromBytes(key)) -} - -func (mt *zkTrieImplTestWrapper) TryDelete(key []byte) error { - return mt.ZkTrieImpl.TryDelete(zkt.NewHashFromBytes(key)) -} - -// TryUpdateAccount will abstract the write of an account to the trie -func (mt *zkTrieImplTestWrapper) TryUpdateAccount(key []byte, acc *types.StateAccount) error { - value, flag := acc.MarshalFields() - return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBytes(key), flag, value) -} - -// NewHashFromHex returns a *Hash representation of the given hex string -func NewHashFromHex(h string) (*zkt.Hash, error) { - return zkt.NewHashFromCheckedBytes(common.FromHex(h)) -} - -type Fatalable interface { - Fatal(args ...interface{}) -} - -func newTestingMerkle(f Fatalable, numLevels int) *zkTrieImplTestWrapper { - mt, err := newZkTrieImpl(NewZktrieDatabase(rawdb.NewMemoryDatabase()), numLevels) - if err != nil { - f.Fatal(err) - return nil - } - return mt -} - -func TestHashParsers(t *testing.T) { - h0 := zkt.NewHashFromBigInt(big.NewInt(0)) - assert.Equal(t, "0", h0.String()) - h1 := zkt.NewHashFromBigInt(big.NewInt(1)) - assert.Equal(t, "1", h1.String()) - h10 := zkt.NewHashFromBigInt(big.NewInt(10)) - assert.Equal(t, "10", h10.String()) - - h7l := zkt.NewHashFromBigInt(big.NewInt(1234567)) - assert.Equal(t, "1234567", h7l.String()) - h8l := zkt.NewHashFromBigInt(big.NewInt(12345678)) - assert.Equal(t, "12345678...", h8l.String()) - - b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10) //nolint:lll - assert.True(t, ok) - h := zkt.NewHashFromBigInt(b) - assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String()) //nolint:lll - assert.Equal(t, "49322979...", h.String()) - assert.Equal(t, "0ae794eb9c3d8bbb9002e993fc2ed301dcbd2af5508ed072c375e861f1aa5b26", h.Hex()) - - b1, err := zkt.NewBigIntFromHashBytes(b.Bytes()) - assert.Nil(t, err) - assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String()) - - b2, err := zkt.NewHashFromCheckedBytes(b.Bytes()) - assert.Nil(t, err) - assert.Equal(t, b.String(), b2.BigInt().String()) - - h2, err := NewHashFromHex(h.Hex()) - assert.Nil(t, err) - assert.Equal(t, h, h2) - _, err = NewHashFromHex("0x12") - assert.NotNil(t, err) - - // check limits - a := new(big.Int).Sub(constants.Q, big.NewInt(1)) - testHashParsers(t, a) - a = big.NewInt(int64(1)) - testHashParsers(t, a) -} - -func testHashParsers(t *testing.T, a *big.Int) { - require.True(t, cryptoUtils.CheckBigIntInField(a)) - h := zkt.NewHashFromBigInt(a) - assert.Equal(t, a, h.BigInt()) - hFromBytes, err := zkt.NewHashFromCheckedBytes(h.Bytes()) - assert.Nil(t, err) - assert.Equal(t, h, hFromBytes) - assert.Equal(t, a, hFromBytes.BigInt()) - assert.Equal(t, a.String(), hFromBytes.BigInt().String()) - hFromHex, err := NewHashFromHex(h.Hex()) - assert.Nil(t, err) - assert.Equal(t, h, hFromHex) - - aBIFromHBytes, err := zkt.NewBigIntFromHashBytes(h.Bytes()) - assert.Nil(t, err) - assert.Equal(t, a, aBIFromHBytes) - assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String()) -} - -func TestMerkleTree_AddUpdateGetWord(t *testing.T) { - mt := newTestingMerkle(t, 10) - err := mt.AddWord(&zkt.Byte32{1}, &zkt.Byte32{2}) - assert.Nil(t, err) - err = mt.AddWord(&zkt.Byte32{3}, &zkt.Byte32{4}) - assert.Nil(t, err) - err = mt.AddWord(&zkt.Byte32{5}, &zkt.Byte32{6}) - assert.Nil(t, err) - err = mt.AddWord(&zkt.Byte32{5}, &zkt.Byte32{7}) - assert.Equal(t, zktrie.ErrEntryIndexAlreadyExists, err) - - node, err := mt.GetLeafNodeByWord(&zkt.Byte32{1}) - assert.Nil(t, err) - assert.Equal(t, len(node.ValuePreimage), 1) - assert.Equal(t, (&zkt.Byte32{2})[:], node.ValuePreimage[0][:]) - node, err = mt.GetLeafNodeByWord(&zkt.Byte32{3}) - assert.Nil(t, err) - assert.Equal(t, len(node.ValuePreimage), 1) - assert.Equal(t, (&zkt.Byte32{4})[:], node.ValuePreimage[0][:]) - node, err = mt.GetLeafNodeByWord(&zkt.Byte32{5}) - assert.Nil(t, err) - assert.Equal(t, len(node.ValuePreimage), 1) - assert.Equal(t, (&zkt.Byte32{6})[:], node.ValuePreimage[0][:]) - - err = mt.UpdateWord(&zkt.Byte32{1}, &zkt.Byte32{7}) - assert.Nil(t, err) - err = mt.UpdateWord(&zkt.Byte32{3}, &zkt.Byte32{8}) - assert.Nil(t, err) - err = mt.UpdateWord(&zkt.Byte32{5}, &zkt.Byte32{9}) - assert.Nil(t, err) - - node, err = mt.GetLeafNodeByWord(&zkt.Byte32{1}) - assert.Nil(t, err) - assert.Equal(t, len(node.ValuePreimage), 1) - assert.Equal(t, (&zkt.Byte32{7})[:], node.ValuePreimage[0][:]) - node, err = mt.GetLeafNodeByWord(&zkt.Byte32{3}) - assert.Nil(t, err) - assert.Equal(t, len(node.ValuePreimage), 1) - assert.Equal(t, (&zkt.Byte32{8})[:], node.ValuePreimage[0][:]) - node, err = mt.GetLeafNodeByWord(&zkt.Byte32{5}) - assert.Nil(t, err) - assert.Equal(t, len(node.ValuePreimage), 1) - assert.Equal(t, (&zkt.Byte32{9})[:], node.ValuePreimage[0][:]) - _, err = mt.GetLeafNodeByWord(&zkt.Byte32{100}) - assert.Equal(t, zktrie.ErrKeyNotFound, err) -} - -func TestMerkleTree_UpdateAccount(t *testing.T) { - mt := newTestingMerkle(t, 10) - - acc1 := &types.StateAccount{ - Nonce: 1, - Balance: big.NewInt(10000000), - Root: common.HexToHash("22fb59aa5410ed465267023713ab42554c250f394901455a3366e223d5f7d147"), - KeccakCodeHash: common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(), - PoseidonCodeHash: common.HexToHash("0c0a77f6e063b4b62eb7d9ed6f427cf687d8d0071d751850cfe5d136bc60d3ab").Bytes(), - CodeSize: 0, - } - err := mt.TryUpdateAccount(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes(), acc1) - assert.Nil(t, err) - - acc2 := &types.StateAccount{ - Nonce: 5, - Balance: big.NewInt(50000000), - Root: common.HexToHash("0"), - KeccakCodeHash: common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(), - PoseidonCodeHash: common.HexToHash("05d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(), - CodeSize: 5, - } - err = mt.TryUpdateAccount(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes(), acc2) - assert.Nil(t, err) - - bt, err := mt.TryGet(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes()) - assert.Nil(t, err) - - acc, err := types.UnmarshalStateAccount(bt) - assert.Nil(t, err) - assert.Equal(t, acc1.Nonce, acc.Nonce) - assert.Equal(t, acc1.Balance.Uint64(), acc.Balance.Uint64()) - assert.Equal(t, acc1.Root.Bytes(), acc.Root.Bytes()) - assert.Equal(t, acc1.KeccakCodeHash, acc.KeccakCodeHash) - assert.Equal(t, acc1.PoseidonCodeHash, acc.PoseidonCodeHash) - assert.Equal(t, acc1.CodeSize, acc.CodeSize) - - bt, err = mt.TryGet(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes()) - assert.Nil(t, err) - - acc, err = types.UnmarshalStateAccount(bt) - assert.Nil(t, err) - assert.Equal(t, acc2.Nonce, acc.Nonce) - assert.Equal(t, acc2.Balance.Uint64(), acc.Balance.Uint64()) - assert.Equal(t, acc2.Root.Bytes(), acc.Root.Bytes()) - assert.Equal(t, acc2.KeccakCodeHash, acc.KeccakCodeHash) - assert.Equal(t, acc2.PoseidonCodeHash, acc.PoseidonCodeHash) - assert.Equal(t, acc2.CodeSize, acc.CodeSize) - - bt, err = mt.TryGet(common.HexToAddress("0x8dE13967F19410A7991D63c2c0179feBFDA0c261").Bytes()) - assert.Nil(t, err) - assert.Nil(t, bt) - - err = mt.TryDelete(common.HexToHash("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes()) - assert.Nil(t, err) - - bt, err = mt.TryGet(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes()) - assert.Nil(t, err) - assert.Nil(t, bt) - - err = mt.TryDelete(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes()) - assert.Nil(t, err) - - bt, err = mt.TryGet(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes()) - assert.Nil(t, err) - assert.Nil(t, bt) -} diff --git a/trie/zk_trie_node.go b/trie/zk_trie_node.go new file mode 100644 index 000000000000..950a77e69ee9 --- /dev/null +++ b/trie/zk_trie_node.go @@ -0,0 +1,405 @@ +package trie + +import ( + "encoding/binary" + "fmt" + "math/big" + "reflect" + "slices" + "unsafe" + + "github.com/scroll-tech/go-ethereum/common" +) + +// NodeType defines the type of node in the MT. +type NodeType byte + +const ( + // NodeTypeParent indicates the type of parent Node that has children. + NodeTypeParent NodeType = 0 + // NodeTypeLeaf indicates the type of a leaf Node that contains a key & + // value. + NodeTypeLeaf NodeType = 1 + // NodeTypeEmpty indicates the type of an empty Node. + NodeTypeEmpty NodeType = 2 + + // DBEntryTypeRoot indicates the type of a DB entry that indicates the + // current Root of a MerkleTree + DBEntryTypeRoot NodeType = 3 + + NodeTypeLeaf_New NodeType = 4 + NodeTypeEmpty_New NodeType = 5 + // branch node for both child are terminal nodes + NodeTypeBranch_0 NodeType = 6 + // branch node for left child is terminal node and right child is branch + NodeTypeBranch_1 NodeType = 7 + // branch node for left child is branch node and right child is terminal + NodeTypeBranch_2 NodeType = 8 + // branch node for both child are branch nodes + NodeTypeBranch_3 NodeType = 9 +) + +// DeduceUploadType deduce a new branch type from current branch when one of its child become non-terminal +func (n NodeType) DeduceUpgradeType(goRight bool) NodeType { + if goRight { + switch n { + case NodeTypeBranch_0: + return NodeTypeBranch_1 + case NodeTypeBranch_1: + return n + case NodeTypeBranch_2, NodeTypeBranch_3: + return NodeTypeBranch_3 + } + } else { + switch n { + case NodeTypeBranch_0: + return NodeTypeBranch_2 + case NodeTypeBranch_1, NodeTypeBranch_3: + return NodeTypeBranch_3 + case NodeTypeBranch_2: + return n + } + } + + panic(fmt.Errorf("invalid NodeType: %d", n)) +} + +// DeduceDowngradeType deduce a new branch type from current branch when one of its child become terminal +func (n NodeType) DeduceDowngradeType(atRight bool) NodeType { + if atRight { + switch n { + case NodeTypeBranch_1: + return NodeTypeBranch_0 + case NodeTypeBranch_3: + return NodeTypeBranch_2 + case NodeTypeBranch_0, NodeTypeBranch_2: + panic(fmt.Errorf("can not downgrade a node with terminal child (%d)", n)) + } + } else { + switch n { + case NodeTypeBranch_3: + return NodeTypeBranch_1 + case NodeTypeBranch_2: + return NodeTypeBranch_0 + case NodeTypeBranch_0, NodeTypeBranch_1: + panic(fmt.Errorf("can not downgrade a node with terminal child (%d)", n)) + } + } + panic(fmt.Errorf("invalid NodeType: %d", n)) +} + +// Node is the struct that represents a node in the MT. The node should not be +// modified after creation because the cached key won't be updated. +type Node struct { + // Type is the type of node in the tree. + Type NodeType + // ChildL is the node hash of the left child of a parent node. + ChildL *Hash + // ChildR is the node hash of the right child of a parent node. + ChildR *Hash + // NodeKey is the node's key stored in a leaf node. + NodeKey *Hash + // ValuePreimage can store at most 256 byte32 as fields (represnted by BIG-ENDIAN integer) + // and the first 24 can be compressed (each bytes32 consider as 2 fields), in hashing the compressed + // elemments would be calculated first + ValuePreimage []Byte32 + // CompressedFlags use each bit for indicating the compressed flag for the first 24 fields + CompressedFlags uint32 + // nodeHash is the cache of the hash of the node to avoid recalculating + nodeHash *Hash + // valueHash is the cache of the hash of valuePreimage to avoid recalculating, only valid for leaf node + valueHash *Hash + // KeyPreimage is the original key value that derives the NodeKey, kept here only for proof + KeyPreimage *Byte32 +} + +// NewLeafNode creates a new leaf node. +func NewLeafNode(k *Hash, valueFlags uint32, valuePreimage []Byte32) *Node { + return &Node{Type: NodeTypeLeaf_New, NodeKey: k, CompressedFlags: valueFlags, ValuePreimage: valuePreimage} +} + +// NewParentNode creates a new parent node. +func NewParentNode(ntype NodeType, childL *Hash, childR *Hash) *Node { + return &Node{Type: ntype, ChildL: childL, ChildR: childR} +} + +// NewEmptyNode creates a new empty node. +func NewEmptyNode() *Node { + return &Node{Type: NodeTypeEmpty_New} +} + +// NewNodeFromBytes creates a new node by parsing the input []byte. +func NewNodeFromBytes(b []byte) (*Node, error) { + var n Node + if err := n.SetBytes(b); err != nil { + return nil, err + } + return &n, nil +} + +// LeafHash computes the key of a leaf node given the hIndex and hValue of the +// entry of the leaf. +func LeafHash(k, v *Hash) (*Hash, error) { + return HashElemsWithDomain(big.NewInt(int64(NodeTypeLeaf_New)), k.BigInt(), v.BigInt()) +} + +func (n *Node) SetBytes(b []byte) error { + if len(b) < 1 { + return ErrNodeBytesBadSize + } + nType := NodeType(b[0]) + b = b[1:] + switch nType { + case NodeTypeParent, NodeTypeBranch_0, + NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + if len(b) != 2*HashByteLen { + return ErrNodeBytesBadSize + } + + childL := n.ChildL + childR := n.ChildR + + if childL == nil { + childL = NewHashFromBytes(b[:HashByteLen]) + } else { + childL.SetBytes(b[:HashByteLen]) + } + + if childR == nil { + childR = NewHashFromBytes(b[HashByteLen : HashByteLen*2]) + } else { + childR.SetBytes(b[HashByteLen : HashByteLen*2]) + } + + *n = Node{ + Type: nType, + ChildL: childL, + ChildR: childR, + } + case NodeTypeLeaf, NodeTypeLeaf_New: + if len(b) < HashByteLen+4 { + return ErrNodeBytesBadSize + } + nodeKey := NewHashFromBytes(b[0:HashByteLen]) + mark := binary.LittleEndian.Uint32(b[HashByteLen : HashByteLen+4]) + preimageLen := int(mark & 255) + compressedFlags := mark >> 8 + valuePreimage := slices.Grow(n.ValuePreimage[0:], preimageLen) + curPos := HashByteLen + 4 + if len(b) < curPos+preimageLen*32+1 { + return ErrNodeBytesBadSize + } + for i := 0; i < preimageLen; i++ { + var byte32 Byte32 + copy(byte32[:], b[i*32+curPos:(i+1)*32+curPos]) + valuePreimage = append(valuePreimage, byte32) + } + curPos += preimageLen * 32 + preImageSize := int(b[curPos]) + curPos += 1 + + var keyPreimage *Byte32 + if preImageSize != 0 { + if len(b) < curPos+preImageSize { + return ErrNodeBytesBadSize + } + + keyPreimage = n.KeyPreimage + if keyPreimage == nil { + keyPreimage = new(Byte32) + } + copy(keyPreimage[:], b[curPos:curPos+preImageSize]) + } + + *n = Node{ + Type: nType, + NodeKey: nodeKey, + CompressedFlags: compressedFlags, + ValuePreimage: valuePreimage, + KeyPreimage: keyPreimage, + } + case NodeTypeEmpty, NodeTypeEmpty_New: + *n = Node{Type: nType} + default: + return ErrInvalidNodeFound + } + return nil +} + +// IsTerminal returns if the node is 'terminated', i.e. empty or leaf node +func (n *Node) IsTerminal() bool { + switch n.Type { + case NodeTypeEmpty_New, NodeTypeLeaf_New: + return true + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + return false + case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: + panic("encounter deprecated node types") + default: + panic(fmt.Errorf("encounter unknown node types %d", n.Type)) + } + +} + +// NodeHash computes the hash digest of the node by hashing the content in a +// specific way for each type of node. This key is used as the hash of the +// Merkle tree for each node. +func (n *Node) NodeHash() (*Hash, error) { + if n.nodeHash == nil { // Cache the key to avoid repeated hash computations. + // NOTE: We are not using the type to calculate the hash! + switch n.Type { + case NodeTypeBranch_0, + NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: // H(ChildL || ChildR) + var err error + n.nodeHash, err = HashElemsWithDomain(big.NewInt(int64(n.Type)), + n.ChildL.BigInt(), n.ChildR.BigInt()) + if err != nil { + return nil, err + } + case NodeTypeLeaf_New: + var err error + n.valueHash, err = HandlingElemsAndByte32(n.CompressedFlags, n.ValuePreimage) + if err != nil { + return nil, err + } + + n.nodeHash, err = LeafHash(n.NodeKey, n.valueHash) + if err != nil { + return nil, err + } + + case NodeTypeEmpty_New: // Zero + n.nodeHash = &HashZero + case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: + panic("encounter deprecated node types") + default: + n.nodeHash = &HashZero + } + } + return n.nodeHash, nil +} + +// ValueHash computes the hash digest of the value stored in the leaf node. For +// other node types, it returns the zero hash. +func (n *Node) ValueHash() (*Hash, error) { + if n.Type != NodeTypeLeaf_New { + return &HashZero, nil + } + if _, err := n.NodeHash(); err != nil { + return nil, err + } + return n.valueHash, nil +} + +// Data returns the wrapped data inside LeafNode and cast them into bytes +// for other node type it just return nil +func (n *Node) Data() []byte { + switch n.Type { + case NodeTypeLeaf_New: + var data []byte + hdata := (*reflect.SliceHeader)(unsafe.Pointer(&data)) + //TODO: uintptr(reflect.ValueOf(n.ValuePreimage).UnsafePointer()) should be more elegant but only available until go 1.18 + hdata.Data = uintptr(unsafe.Pointer(&n.ValuePreimage[0])) + hdata.Len = 32 * len(n.ValuePreimage) + hdata.Cap = hdata.Len + return data + default: + return nil + } +} + +// CanonicalValue returns the byte form of a node required to be persisted, and strip unnecessary fields +// from the encoding (current only KeyPreimage for Leaf node) to keep a minimum size for content being +// stored in backend storage +func (n *Node) CanonicalValue() []byte { + switch n.Type { + case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: // {Type || ChildL || ChildR} + bytes := []byte{byte(n.Type)} + bytes = append(bytes, n.ChildL.Bytes()...) + bytes = append(bytes, n.ChildR.Bytes()...) + return bytes + case NodeTypeLeaf_New: // {Type || Data...} + bytes := []byte{byte(n.Type)} + bytes = append(bytes, n.NodeKey.Bytes()...) + tmp := make([]byte, 4) + compressedFlag := (n.CompressedFlags << 8) + uint32(len(n.ValuePreimage)) + binary.LittleEndian.PutUint32(tmp, compressedFlag) + bytes = append(bytes, tmp...) + for _, elm := range n.ValuePreimage { + bytes = append(bytes, elm[:]...) + } + bytes = append(bytes, 0) + return bytes + case NodeTypeEmpty_New: // { Type } + return []byte{byte(n.Type)} + case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: + panic("encounter deprecated node types") + default: + return []byte{} + } +} + +// Value returns the encoded bytes of a node, include all information of it +func (n *Node) Value() []byte { + outBytes := n.CanonicalValue() + switch n.Type { + case NodeTypeLeaf_New: // {Type || Data...} + if n.KeyPreimage != nil { + outBytes[len(outBytes)-1] = byte(len(n.KeyPreimage)) + outBytes = append(outBytes, n.KeyPreimage[:]...) + } + } + + return outBytes +} + +// String outputs a string representation of a node (different for each type). +func (n *Node) String() string { + switch n.Type { + // {Type || ChildL || ChildR} + case NodeTypeBranch_0: + return fmt.Sprintf("Parent L(t):%s R(t):%s", n.ChildL, n.ChildR) + case NodeTypeBranch_1: + return fmt.Sprintf("Parent L(t):%s R:%s", n.ChildL, n.ChildR) + case NodeTypeBranch_2: + return fmt.Sprintf("Parent L:%s R(t):%s", n.ChildL, n.ChildR) + case NodeTypeBranch_3: + return fmt.Sprintf("Parent L:%s R:%s", n.ChildL, n.ChildR) + case NodeTypeLeaf_New: // {Type || Data...} + return fmt.Sprintf("Leaf I:%v Items: %d, First:%v", n.NodeKey, len(n.ValuePreimage), n.ValuePreimage[0]) + case NodeTypeEmpty_New: // {} + return "Empty" + case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: + return "deprecated Node" + default: + return "Invalid Node" + } +} + +// Copy creates a new Node instance from the given node +func (n *Node) Copy() *Node { + newNode, err := NewNodeFromBytes(n.Value()) + if err != nil { + panic("failed to copy trie node") + } + return newNode +} + +type ZkChildResolver struct{} + +// ForEach iterates over the children of a node and calls the given function +// note: original implementation from geth works recursively, but our Node definition +// doesn't allow that. So we only iterate over the children of the current node, which +// should be fine. +func (r ZkChildResolver) ForEach(node []byte, onChild func(common.Hash)) { + switch NodeType(node[0]) { + case NodeTypeParent, NodeTypeBranch_0, + NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: + + var childHash common.Hash + childHash.SetBytes(node[1 : HashByteLen+1]) + onChild(childHash) + childHash.SetBytes(node[HashByteLen+1 : HashByteLen*2+1]) + onChild(childHash) + } +} diff --git a/trie/zk_trie_node_test.go b/trie/zk_trie_node_test.go new file mode 100644 index 000000000000..1cd5daa385ca --- /dev/null +++ b/trie/zk_trie_node_test.go @@ -0,0 +1,240 @@ +package trie + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewNode(t *testing.T) { + t.Run("Test NewEmptyNode", func(t *testing.T) { + node := NewEmptyNode() + assert.Equal(t, NodeTypeEmpty_New, node.Type) + + hash, err := node.NodeHash() + assert.NoError(t, err) + assert.Equal(t, &HashZero, hash) + + hash, err = node.ValueHash() + assert.NoError(t, err) + assert.Equal(t, &HashZero, hash) + }) + + t.Run("Test NewLeafNode", func(t *testing.T) { + k := NewHashFromBytes(bytes.Repeat([]byte("0"), 32)) + vp := []Byte32{*NewByte32FromBytes(bytes.Repeat([]byte("b"), 32))} + node := NewLeafNode(k, 1, vp) + assert.Equal(t, NodeTypeLeaf_New, node.Type) + assert.Equal(t, uint32(1), node.CompressedFlags) + assert.Equal(t, vp, node.ValuePreimage) + + hash, err := node.NodeHash() + assert.NoError(t, err) + assert.Equal(t, "2536e274d373c4ca79bc85c6aa140fe911eb7fe04939e1311004bbaf3c13c32a", hash.Hex()) + + hash, err = node.ValueHash() + assert.NoError(t, err) + hashFromVp, err := vp[0].Hash() + assert.NoError(t, err) + assert.Equal(t, hashFromVp.Text(16), hash.Hex()) + }) + + t.Run("Test NewParentNode", func(t *testing.T) { + k := NewHashFromBytes(bytes.Repeat([]byte("0"), 32)) + node := NewParentNode(NodeTypeBranch_3, k, k) + assert.Equal(t, NodeTypeBranch_3, node.Type) + assert.Equal(t, k, node.ChildL) + assert.Equal(t, k, node.ChildR) + + hash, err := node.NodeHash() + assert.NoError(t, err) + assert.Equal(t, "242d3e8a6a7683f9858a08cdf1db2a4448638c168e32168ef4e5e9e2e8794629", hash.Hex()) + + hash, err = node.ValueHash() + assert.NoError(t, err) + assert.Equal(t, &HashZero, hash) + }) + + t.Run("Test NewParentNodeWithEmptyChild", func(t *testing.T) { + k := NewHashFromBytes(bytes.Repeat([]byte("0"), 32)) + r, err := NewEmptyNode().NodeHash() + assert.NoError(t, err) + node := NewParentNode(NodeTypeBranch_2, k, r) + + assert.Equal(t, NodeTypeBranch_2, node.Type) + assert.Equal(t, k, node.ChildL) + assert.Equal(t, r, node.ChildR) + + hash, err := node.NodeHash() + assert.NoError(t, err) + assert.Equal(t, "005bc4e8f3b3f2ff0b980d4f3c32973de6a01f89ddacb08b0e7903d1f1f0c50f", hash.Hex()) + + hash, err = node.ValueHash() + assert.NoError(t, err) + assert.Equal(t, &HashZero, hash) + }) + + t.Run("Test Invalid Node", func(t *testing.T) { + node := &Node{Type: 99} + + invalidNodeHash, err := node.NodeHash() + assert.NoError(t, err) + assert.Equal(t, &HashZero, invalidNodeHash) + }) +} + +func TestNewNodeFromBytes(t *testing.T) { + t.Run("ParentNode", func(t *testing.T) { + k1 := NewHashFromBytes(bytes.Repeat([]byte("0"), 32)) + k2 := NewHashFromBytes(bytes.Repeat([]byte("0"), 32)) + node := NewParentNode(NodeTypeBranch_0, k1, k2) + b := node.Value() + + node, err := NewNodeFromBytes(b) + assert.NoError(t, err) + + assert.Equal(t, NodeTypeBranch_0, node.Type) + assert.Equal(t, k1, node.ChildL) + assert.Equal(t, k2, node.ChildR) + + hash, err := node.NodeHash() + assert.NoError(t, err) + assert.Equal(t, "12b90fefb7b19131d25980a38ca92edb66bb91828d305836e4ab7e961165c83f", hash.Hex()) + + hash, err = node.ValueHash() + assert.NoError(t, err) + assert.Equal(t, &HashZero, hash) + }) + + t.Run("LeafNode", func(t *testing.T) { + k := NewHashFromBytes(bytes.Repeat([]byte("0"), 32)) + vp := make([]Byte32, 1) + node := NewLeafNode(k, 1, vp) + + node.KeyPreimage = NewByte32FromBytes(bytes.Repeat([]byte("b"), 32)) + + nodeBytes := node.Value() + newNode, err := NewNodeFromBytes(nodeBytes) + assert.NoError(t, err) + + assert.Equal(t, node.Type, newNode.Type) + assert.Equal(t, node.NodeKey, newNode.NodeKey) + assert.Equal(t, node.ValuePreimage, newNode.ValuePreimage) + assert.Equal(t, node.KeyPreimage, newNode.KeyPreimage) + + hash, err := node.NodeHash() + assert.NoError(t, err) + assert.Equal(t, "2f7094f04ed1592909311471ba67d84d7d11e2438c055f4d5d43189390c5cf5a", hash.Hex()) + + hash, err = node.ValueHash() + assert.NoError(t, err) + hashFromVp, err := vp[0].Hash() + + assert.Equal(t, NewHashFromBigInt(hashFromVp), hash) + }) + + t.Run("EmptyNode", func(t *testing.T) { + node := NewEmptyNode() + b := node.Value() + + node, err := NewNodeFromBytes(b) + assert.NoError(t, err) + + assert.Equal(t, NodeTypeEmpty_New, node.Type) + + hash, err := node.NodeHash() + assert.NoError(t, err) + assert.Equal(t, &HashZero, hash) + + hash, err = node.ValueHash() + assert.NoError(t, err) + assert.Equal(t, &HashZero, hash) + }) + + t.Run("BadSize", func(t *testing.T) { + testCases := [][]byte{ + {}, + {0, 1, 2}, + func() []byte { + b := make([]byte, HashByteLen+3) + b[0] = byte(NodeTypeLeaf) + return b + }(), + func() []byte { + k := NewHashFromBytes([]byte{1, 2, 3, 4, 5}) + vp := make([]Byte32, 1) + node := NewLeafNode(k, 1, vp) + b := node.Value() + return b[:len(b)-32] + }(), + func() []byte { + k := NewHashFromBytes([]byte{1, 2, 3, 4, 5}) + vp := make([]Byte32, 1) + node := NewLeafNode(k, 1, vp) + node.KeyPreimage = NewByte32FromBytes([]byte{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37}) + + b := node.Value() + return b[:len(b)-1] + }(), + } + + for _, b := range testCases { + node, err := NewNodeFromBytes(b) + assert.ErrorIs(t, err, ErrNodeBytesBadSize) + assert.Nil(t, node) + } + }) + + t.Run("InvalidType", func(t *testing.T) { + b := []byte{255} + + node, err := NewNodeFromBytes(b) + assert.ErrorIs(t, err, ErrInvalidNodeFound) + assert.Nil(t, node) + }) +} + +func TestNodeValueAndData(t *testing.T) { + k := NewHashFromBytes(bytes.Repeat([]byte("a"), 32)) + vp := []Byte32{*NewByte32FromBytes(bytes.Repeat([]byte("b"), 32))} + + node := NewLeafNode(k, 1, vp) + canonicalValue := node.CanonicalValue() + assert.Equal(t, []byte{0x4, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x1, 0x1, 0x0, 0x0, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x0}, canonicalValue) + assert.Equal(t, []byte{0x4, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x1, 0x1, 0x0, 0x0, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x0}, node.Value()) + node.KeyPreimage = NewByte32FromBytes(bytes.Repeat([]byte("c"), 32)) + assert.Equal(t, []byte{0x4, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x1, 0x1, 0x0, 0x0, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x20, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63}, node.Value()) + assert.Equal(t, []byte{0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62}, node.Data()) + + parentNode := NewParentNode(NodeTypeBranch_3, k, k) + canonicalValue = parentNode.CanonicalValue() + assert.Equal(t, []byte{0x9, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61}, canonicalValue) + assert.Nil(t, parentNode.Data()) + + emptyNode := &Node{Type: NodeTypeEmpty_New} + assert.Equal(t, []byte{byte(emptyNode.Type)}, emptyNode.CanonicalValue()) + assert.Nil(t, emptyNode.Data()) + + invalidNode := &Node{Type: 99} + assert.Equal(t, []byte{}, invalidNode.CanonicalValue()) + assert.Nil(t, invalidNode.Data()) +} + +func TestNodeString(t *testing.T) { + k := NewHashFromBytes(bytes.Repeat([]byte("a"), 32)) + vp := []Byte32{*NewByte32FromBytes(bytes.Repeat([]byte("b"), 32))} + + leafNode := NewLeafNode(k, 1, vp) + assert.Equal(t, fmt.Sprintf("Leaf I:%v Items: %d, First:%v", leafNode.NodeKey, len(leafNode.ValuePreimage), leafNode.ValuePreimage[0]), leafNode.String()) + + parentNode := NewParentNode(NodeTypeBranch_3, k, k) + assert.Equal(t, fmt.Sprintf("Parent L:%s R:%s", parentNode.ChildL, parentNode.ChildR), parentNode.String()) + + emptyNode := NewEmptyNode() + assert.Equal(t, "Empty", emptyNode.String()) + + invalidNode := &Node{Type: 99} + assert.Equal(t, "Invalid Node", invalidNode.String()) +} diff --git a/trie/zk_trie_proof_test.go b/trie/zk_trie_proof_test.go index bab8950ec217..4701a0229cb9 100644 --- a/trie/zk_trie_proof_test.go +++ b/trie/zk_trie_proof_test.go @@ -18,17 +18,14 @@ package trie import ( "bytes" + "crypto/rand" mrand "math/rand" "testing" "time" "github.com/stretchr/testify/assert" - zkt "github.com/scroll-tech/zktrie/types" - "github.com/scroll-tech/go-ethereum/common" - "github.com/scroll-tech/go-ethereum/core/rawdb" - "github.com/scroll-tech/go-ethereum/crypto" "github.com/scroll-tech/go-ethereum/ethdb/memorydb" ) @@ -43,13 +40,8 @@ func makeSMTProvers(mt *ZkTrie) []func(key []byte) *memorydb.Database { // Create a direct trie based Merkle prover provers = append(provers, func(key []byte) *memorydb.Database { - word := zkt.NewByte32FromBytesPaddingZero(key) - k, err := word.Hash() - if err != nil { - panic(err) - } proofDB := memorydb.New() - err = mt.Prove(common.BytesToHash(k.Bytes()).Bytes(), proofDB) + err := mt.Prove(key, proofDB) if err != nil { panic(err) } @@ -64,14 +56,14 @@ func verifyValue(proveVal []byte, vPreimage []byte) bool { } func TestSMTOneElementProof(t *testing.T) { - tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase(rawdb.NewMemoryDatabase())) - mt := &zkTrieImplTestWrapper{tr.Tree()} - err := mt.UpdateWord( - zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("k"), 32)), - zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32)), + mt, _ := newTestingMerkle(t) + err := mt.TryUpdate( + NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("k"), 32)).Bytes(), + 1, + []Byte32{*NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32))}, ) assert.Nil(t, err) - for i, prover := range makeSMTProvers(tr) { + for i, prover := range makeSMTProvers(mt) { keyBytes := bytes.Repeat([]byte("k"), 32) proof := prover(keyBytes) if proof == nil { @@ -84,7 +76,7 @@ func TestSMTOneElementProof(t *testing.T) { root, err := mt.Root() assert.NoError(t, err) - val, err := VerifyProof(common.BytesToHash(root.Bytes()), keyBytes, proof) + val, err := VerifyProofSMT(common.BytesToHash(root.Bytes()), keyBytes, proof) if err != nil { t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) } @@ -96,21 +88,22 @@ func TestSMTOneElementProof(t *testing.T) { func TestSMTProof(t *testing.T) { mt, vals := randomZktrie(t, 500) - root, err := mt.Tree().Root() + root, err := mt.Root() assert.NoError(t, err) for i, prover := range makeSMTProvers(mt) { - for _, kv := range vals { - proof := prover(kv.k) + for kStr, v := range vals { + k := []byte(kStr) + proof := prover(k) if proof == nil { - t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) + t.Fatalf("prover %d: missing key %x while constructing proof", i, k) } - val, err := VerifyProof(common.BytesToHash(root.Bytes()), kv.k, proof) + val, err := VerifyProofSMT(common.BytesToHash(root.Bytes()), k, proof) if err != nil { - t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x\n", i, kv.k, err, proof) + t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x\n", i, k, err, proof) } - if !verifyValue(val, zkt.NewByte32FromBytesPaddingZero(kv.v)[:]) { - t.Fatalf("prover %d: verified value mismatch for key %x, want %x, get %x", i, kv.k, kv.v, val) + if !verifyValue(val, NewByte32FromBytesPaddingZero(v)[:]) { + t.Fatalf("prover %d: verified value mismatch for key %x, want %x, get %x", i, k, v, val) } } } @@ -118,29 +111,30 @@ func TestSMTProof(t *testing.T) { func TestSMTBadProof(t *testing.T) { mt, vals := randomZktrie(t, 500) - root, err := mt.Tree().Root() + root, err := mt.Root() assert.NoError(t, err) for i, prover := range makeSMTProvers(mt) { - for _, kv := range vals { - proof := prover(kv.k) + for kStr, _ := range vals { + k := []byte(kStr) + proof := prover(k) if proof == nil { t.Fatalf("prover %d: nil proof", i) } it := proof.NewIterator(nil, nil) - for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ { + for i, d := 0, mrand.Intn(proof.Len()-1); i <= d; i++ { + it.Next() + } + if bytes.Equal(it.Key(), magicHash) { it.Next() } + key := it.Key() - val, _ := proof.Get(key) proof.Delete(key) it.Release() - mutateByte(val) - proof.Put(crypto.Keccak256(val), val) - - if _, err := VerifyProof(common.BytesToHash(root.Bytes()), kv.k, proof); err == nil { - t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k) + if value, err := VerifyProof(common.BytesToHash(root.Bytes()), k, proof); err == nil && value != nil { + t.Fatalf("prover %d: expected proof to fail for key %x", i, k) } } } @@ -149,15 +143,15 @@ func TestSMTBadProof(t *testing.T) { // Tests that missing keys can also be proven. The test explicitly uses a single // entry trie and checks for missing keys both before and after the single entry. func TestSMTMissingKeyProof(t *testing.T) { - tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase(rawdb.NewMemoryDatabase())) - mt := &zkTrieImplTestWrapper{tr.Tree()} - err := mt.UpdateWord( - zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("k"), 32)), - zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32)), + mt, _ := newTestingMerkle(t) + err := mt.TryUpdate( + NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("k"), 32)).Bytes(), + 1, + []Byte32{*NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32))}, ) assert.Nil(t, err) - prover := makeSMTProvers(tr)[0] + prover := makeSMTProvers(mt)[0] for i, key := range []string{"a", "j", "l", "z"} { keyBytes := bytes.Repeat([]byte(key), 32) @@ -170,7 +164,7 @@ func TestSMTMissingKeyProof(t *testing.T) { root, err := mt.Root() assert.NoError(t, err) - val, err := VerifyProof(common.BytesToHash(root.Bytes()), keyBytes, proof) + val, err := VerifyProofSMT(common.BytesToHash(root.Bytes()), keyBytes, proof) if err != nil { t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) } @@ -180,117 +174,35 @@ func TestSMTMissingKeyProof(t *testing.T) { } } -func randomZktrie(t *testing.T, n int) (*ZkTrie, map[string]*kv) { - tr, err := NewZkTrie(common.Hash{}, NewZktrieDatabase(rawdb.NewMemoryDatabase())) - if err != nil { - panic(err) +func randomZktrie(t *testing.T, n int) (*ZkTrie, map[string][]byte) { + randBytes := func(len int) []byte { + buf := make([]byte, len) + if n, err := rand.Read(buf); n != len || err != nil { + panic(err) + } + return buf } - mt := &zkTrieImplTestWrapper{tr.Tree()} - vals := make(map[string]*kv) + + mt, _ := newTestingMerkle(t) + vals := make(map[string][]byte) for i := byte(0); i < 100; i++ { - value := &kv{common.LeftPadBytes([]byte{i}, 32), bytes.Repeat([]byte{i}, 32), false} - value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), bytes.Repeat([]byte{i}, 32), false} + key, value := common.LeftPadBytes([]byte{i}, 32), NewByte32FromBytes(bytes.Repeat([]byte{i}, 32)) + key2, value2 := common.LeftPadBytes([]byte{i + 10}, 32), NewByte32FromBytes(bytes.Repeat([]byte{i}, 32)) - err = mt.UpdateWord(zkt.NewByte32FromBytesPaddingZero(value.k), zkt.NewByte32FromBytesPaddingZero(value.v)) + err := mt.TryUpdate(key, 1, []Byte32{*value}) assert.Nil(t, err) - err = mt.UpdateWord(zkt.NewByte32FromBytesPaddingZero(value2.k), zkt.NewByte32FromBytesPaddingZero(value2.v)) + err = mt.TryUpdate(key2, 1, []Byte32{*value2}) assert.Nil(t, err) - vals[string(value.k)] = value - vals[string(value2.k)] = value2 + vals[string(key)] = value.Bytes() + vals[string(key2)] = value2.Bytes() } for i := 0; i < n; i++ { - value := &kv{randBytes(32), randBytes(20), false} - err = mt.UpdateWord(zkt.NewByte32FromBytesPaddingZero(value.k), zkt.NewByte32FromBytesPaddingZero(value.v)) + key, value := randBytes(32), NewByte32FromBytes(randBytes(20)) + err := mt.TryUpdate(key, 1, []Byte32{*value}) assert.Nil(t, err) - vals[string(value.k)] = value + vals[string(key)] = value.Bytes() } - return tr, vals -} - -// Tests that new "proof trace" feature -func TestProofWithDeletion(t *testing.T) { - tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase(rawdb.NewMemoryDatabase())) - mt := &zkTrieImplTestWrapper{tr.Tree()} - key1 := bytes.Repeat([]byte("l"), 32) - key2 := bytes.Repeat([]byte("m"), 32) - err := mt.UpdateWord( - zkt.NewByte32FromBytesPaddingZero(key1), - zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32)), - ) - assert.NoError(t, err) - err = mt.UpdateWord( - zkt.NewByte32FromBytesPaddingZero(key2), - zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("n"), 32)), - ) - assert.NoError(t, err) - - proof := memorydb.New() - s_key1, err := zkt.ToSecureKeyBytes(key1) - assert.NoError(t, err) - - proofTracer := tr.NewProofTracer() - - err = proofTracer.Prove(s_key1.Bytes(), proof) - assert.NoError(t, err) - nd, err := tr.TryGet(key2) - assert.NoError(t, err) - - s_key2, err := zkt.ToSecureKeyBytes(bytes.Repeat([]byte("x"), 32)) - assert.NoError(t, err) - - err = proofTracer.Prove(s_key2.Bytes(), proof) - assert.NoError(t, err) - //assert.Equal(t, len(sibling1), len(delTracer.GetProofs())) - - siblings, err := proofTracer.GetDeletionProofs() - assert.NoError(t, err) - assert.Equal(t, 0, len(siblings)) - - proofTracer.MarkDeletion(s_key1.Bytes()) - siblings, err = proofTracer.GetDeletionProofs() - assert.NoError(t, err) - assert.Equal(t, 1, len(siblings)) - l := len(siblings[0]) - // a hacking to grep the value part directly from the encoded leaf node, - // notice the sibling of key `k*32`` is just the leaf of key `m*32` - assert.Equal(t, siblings[0][l-33:l-1], nd) - - // Marking a key that is currently not hit (but terminated by an empty node) - // also causes it to be added to the deletion proof - proofTracer.MarkDeletion(s_key2.Bytes()) - siblings, err = proofTracer.GetDeletionProofs() - assert.NoError(t, err) - assert.Equal(t, 2, len(siblings)) - - key3 := bytes.Repeat([]byte("x"), 32) - err = mt.UpdateWord( - zkt.NewByte32FromBytesPaddingZero(key3), - zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("z"), 32)), - ) - assert.NoError(t, err) - - proofTracer = tr.NewProofTracer() - err = proofTracer.Prove(s_key1.Bytes(), proof) - assert.NoError(t, err) - err = proofTracer.Prove(s_key2.Bytes(), proof) - assert.NoError(t, err) - - proofTracer.MarkDeletion(s_key1.Bytes()) - siblings, err = proofTracer.GetDeletionProofs() - assert.NoError(t, err) - assert.Equal(t, 1, len(siblings)) - - proofTracer.MarkDeletion(s_key2.Bytes()) - siblings, err = proofTracer.GetDeletionProofs() - assert.NoError(t, err) - assert.Equal(t, 2, len(siblings)) - - // one of the siblings is just leaf for key2, while - // another one must be a middle node - match1 := bytes.Equal(siblings[0][l-33:l-1], nd) - match2 := bytes.Equal(siblings[1][l-33:l-1], nd) - assert.True(t, match1 || match2) - assert.False(t, match1 && match2) + return mt, vals } diff --git a/trie/zk_trie_test.go b/trie/zk_trie_test.go index 61e5d33427f2..db03180b96c0 100644 --- a/trie/zk_trie_test.go +++ b/trie/zk_trie_test.go @@ -1,122 +1,763 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - package trie import ( "bytes" - "encoding/binary" - "io/ioutil" - "os" + "math/big" "runtime" "sync" "testing" - "github.com/stretchr/testify/assert" - - zkt "github.com/scroll-tech/zktrie/types" - + "github.com/iden3/go-iden3-crypto/constants" "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/core/rawdb" - "github.com/scroll-tech/go-ethereum/trie/triedb/hashdb" + "github.com/scroll-tech/go-ethereum/core/types" + "github.com/scroll-tech/go-ethereum/trie/trienode" + "github.com/stretchr/testify/assert" ) -func newEmptyZkTrie() *ZkTrie { - trie, _ := NewZkTrie( - common.Hash{}, - &ZktrieDatabase{ - db: NewDatabase(rawdb.NewMemoryDatabase(), - &Config{Preimages: true}), - prefix: []byte{}, - }, - ) - return trie +func newTestingMerkle(t *testing.T) (*ZkTrie, *Database) { + db := NewDatabase(rawdb.NewMemoryDatabase(), HashDefaults) + return newTestingMerkleWithDb(t, common.Hash{}, db) +} + +func newTestingMerkleWithDb(t *testing.T, root common.Hash, db *Database) (*ZkTrie, *Database) { + maxLevels := NodeKeyValidBytes * 8 + mt, err := NewZkTrie(TrieID(root), db) + if err != nil { + t.Fatal(err) + return nil, nil + } + assert.Equal(t, maxLevels, mt.MaxLevels()) + return mt, db +} + +func TestMerkleTree_Init(t *testing.T) { + maxLevels := 248 + t.Run("Test NewZkTrieImpl", func(t *testing.T) { + mt, _ := newTestingMerkle(t) + mtRoot, err := mt.Root() + assert.NoError(t, err) + assert.Equal(t, HashZero.Bytes(), mtRoot.Bytes()) + }) + + t.Run("Test NewZkTrieImplWithRoot with zero hash root", func(t *testing.T) { + mt, _ := newTestingMerkle(t) + mtRoot, err := mt.Root() + assert.NoError(t, err) + assert.Equal(t, HashZero.Bytes(), mtRoot.Bytes()) + }) + + t.Run("Test NewZkTrieImplWithRoot with non-zero hash root and node exists", func(t *testing.T) { + mt1, db := newTestingMerkle(t) + mt1Root, err := mt1.Root() + assert.NoError(t, err) + assert.Equal(t, HashZero.Bytes(), mt1Root.Bytes()) + err = mt1.TryUpdate([]byte{1}, 1, []Byte32{{byte(1)}}) + assert.NoError(t, err) + mt1Root, err = mt1.Root() + assert.NoError(t, err) + assert.Equal(t, "1525946038598ec48c663db06fa4f0b68ba40b80d7b1ddce3206d4857ac4a47c", mt1Root.Hex()) + rootHash, nodeSet, err := mt1.Commit(false) + assert.NoError(t, err) + assert.NoError(t, db.Update(rootHash, common.Hash{}, 0, trienode.NewWithNodeSet(nodeSet), nil)) + assert.NoError(t, db.Commit(rootHash, false)) + + mt2, _ := newTestingMerkleWithDb(t, rootHash, db) + assert.Equal(t, maxLevels, mt2.maxLevels) + mt2Root, err := mt2.Root() + assert.NoError(t, err) + assert.Equal(t, "1525946038598ec48c663db06fa4f0b68ba40b80d7b1ddce3206d4857ac4a47c", mt2Root.Hex()) + }) +} + +func TestMerkleTree_AddUpdateGetWord(t *testing.T) { + mt, _ := newTestingMerkle(t) + + testData := []struct { + key byte + initialVal byte + updatedVal byte + }{ + {1, 2, 7}, + {3, 4, 8}, + {5, 6, 9}, + } + + for _, td := range testData { + err := mt.TryUpdate([]byte{td.key}, 1, []Byte32{{td.initialVal}}) + assert.NoError(t, err) + + node, err := mt.GetLeafNode([]byte{td.key}) + assert.NoError(t, err) + assert.Equal(t, 1, len(node.ValuePreimage)) + assert.Equal(t, (&Byte32{td.initialVal})[:], node.ValuePreimage[0][:]) + } + + for _, td := range testData { + err := mt.TryUpdate([]byte{td.key}, 1, []Byte32{{td.updatedVal}}) + assert.NoError(t, err) + + node, err := mt.GetLeafNode([]byte{td.key}) + assert.NoError(t, err) + assert.Equal(t, 1, len(node.ValuePreimage)) + assert.Equal(t, (&Byte32{td.updatedVal})[:], node.ValuePreimage[0][:]) + } + + _, err := mt.GetLeafNode([]byte{100}) + assert.Equal(t, ErrKeyNotFound, err) } -// makeTestSecureTrie creates a large enough secure trie for testing. -func makeTestZkTrie() (*ZktrieDatabase, *ZkTrie, map[string][]byte) { - // Create an empty trie - triedb := NewZktrieDatabase(rawdb.NewMemoryDatabase()) - trie, _ := NewZkTrie(common.Hash{}, triedb) - - // Fill it with some arbitrary data - content := make(map[string][]byte) - for i := byte(0); i < 255; i++ { - // Map the same data under multiple keys - key, val := common.LeftPadBytes([]byte{1, i}, 32), bytes.Repeat([]byte{i}, 32) - content[string(key)] = val - trie.Update(key, val) - - key, val = common.LeftPadBytes([]byte{2, i}, 32), bytes.Repeat([]byte{i}, 32) - content[string(key)] = val - trie.Update(key, val) - - // Add some other data to inflate the trie - for j := byte(3); j < 13; j++ { - key, val = common.LeftPadBytes([]byte{j, i}, 32), bytes.Repeat([]byte{j, i}, 16) - content[string(key)] = val - trie.Update(key, val) +func TestMerkleTree_Deletion(t *testing.T) { + t.Run("Check root consistency", func(t *testing.T) { + var err error + mt, _ := newTestingMerkle(t) + hashes := make([]*Hash, 7) + hashes[0], err = mt.Root() + assert.NoError(t, err) + + for i := 0; i < 6; i++ { + err := mt.TryUpdate([]byte{byte(i)}, 1, []Byte32{{byte(i)}}) + assert.NoError(t, err) + hashes[i+1], err = mt.Root() + assert.NoError(t, err) + } + + for i := 5; i >= 0; i-- { + err := mt.TryDelete([]byte{byte(i)}) + assert.NoError(t, err) + root, err := mt.Root() + assert.NoError(t, err) + assert.Equal(t, hashes[i], root, i) } + }) +} + +func TestZkTrieImpl_Add(t *testing.T) { + k1 := NewByte32FromBytes([]byte{1}) + k2 := NewByte32FromBytes([]byte{2}) + k3 := NewByte32FromBytes([]byte{3}) + + kvMap := map[*Byte32]*Byte32{ + k1: NewByte32FromBytes([]byte{1}), + k2: NewByte32FromBytes([]byte{2}), + k3: NewByte32FromBytes([]byte{3}), } - trie.Commit(false) - // Return the generated trie - return triedb, trie, content + t.Run("Add 1 and 2 in different orders", func(t *testing.T) { + orders := [][]*Byte32{ + {k1, k2}, + {k2, k1}, + } + + roots := make([]*Hash, len(orders)) + for i, order := range orders { + mt, _ := newTestingMerkle(t) + for _, key := range order { + value := kvMap[key] + err := mt.TryUpdate(key.Bytes(), 1, []Byte32{*value}) + assert.NoError(t, err) + } + var err error + roots[i], err = mt.Root() + assert.NoError(t, err) + } + + assert.Equal(t, "254d2db0dc83bbd21708e2af65597e14bce405b38867cedea74a5e3b3be4271a", roots[0].Hex()) + assert.Equal(t, roots[0], roots[1]) + }) + + t.Run("Add 1, 2, 3 in different orders", func(t *testing.T) { + orders := [][]*Byte32{ + {k1, k2, k3}, + {k1, k3, k2}, + {k2, k1, k3}, + {k2, k3, k1}, + {k3, k1, k2}, + {k3, k2, k1}, + } + + roots := make([]*Hash, len(orders)) + for i, order := range orders { + mt, _ := newTestingMerkle(t) + for _, key := range order { + value := kvMap[key] + err := mt.TryUpdate(key.Bytes(), 1, []Byte32{*value}) + assert.NoError(t, err) + } + var err error + roots[i], err = mt.Root() + assert.NoError(t, err) + } + + for i := 1; i < len(roots); i++ { + assert.Equal(t, "0274b9caacecfaaaffa25a00c1c17bd91b9a0fc590aedc06ef22c8d2ba7c76a7", roots[0].Hex()) + assert.Equal(t, roots[0], roots[i]) + } + }) } -func TestZktrieDelete(t *testing.T) { - t.Skip("var-len kv not supported") - trie := newEmptyZkTrie() - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - {"shaman", "horse"}, - {"doge", "coin"}, - {"ether", ""}, - {"dog", "puppy"}, - {"shaman", ""}, +func TestZkTrieImpl_Update(t *testing.T) { + k1 := []byte{1} + k2 := []byte{2} + k3 := []byte{3} + + t.Run("Update 1", func(t *testing.T) { + mt1, _ := newTestingMerkle(t) + err := mt1.TryUpdate(k1, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + root1, err := mt1.Root() + assert.NoError(t, err) + + mt2, _ := newTestingMerkle(t) + err = mt2.TryUpdate(k1, 1, []Byte32{*NewByte32FromBytes([]byte{2})}) + assert.NoError(t, err) + err = mt2.TryUpdate(k1, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + root2, err := mt2.Root() + assert.NoError(t, err) + + assert.Equal(t, root1, root2) + }) + + t.Run("Update 2", func(t *testing.T) { + mt1, _ := newTestingMerkle(t) + err := mt1.TryUpdate(k1, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + err = mt1.TryUpdate(k2, 1, []Byte32{*NewByte32FromBytes([]byte{2})}) + assert.NoError(t, err) + root1, err := mt1.Root() + assert.NoError(t, err) + + mt2, _ := newTestingMerkle(t) + err = mt2.TryUpdate(k1, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + err = mt2.TryUpdate(k2, 1, []Byte32{*NewByte32FromBytes([]byte{3})}) + assert.NoError(t, err) + err = mt2.TryUpdate(k2, 1, []Byte32{*NewByte32FromBytes([]byte{2})}) + assert.NoError(t, err) + root2, err := mt2.Root() + assert.NoError(t, err) + + assert.Equal(t, root1, root2) + }) + + t.Run("Update 1, 2, 3", func(t *testing.T) { + mt1, _ := newTestingMerkle(t) + mt2, _ := newTestingMerkle(t) + keys := [][]byte{k1, k2, k3} + for i, key := range keys { + err := mt1.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{byte(i)})}) + assert.NoError(t, err) + } + for i, key := range keys { + err := mt2.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{byte(i + 3)})}) + assert.NoError(t, err) + } + for i, key := range keys { + err := mt1.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{byte(i + 6)})}) + assert.NoError(t, err) + err = mt2.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{byte(i + 6)})}) + assert.NoError(t, err) + } + + root1, err := mt1.Root() + assert.NoError(t, err) + root2, err := mt2.Root() + assert.NoError(t, err) + + assert.Equal(t, root1, root2) + }) + + t.Run("Update same value", func(t *testing.T) { + mt, _ := newTestingMerkle(t) + keys := [][]byte{k1, k2, k3} + for _, key := range keys { + err := mt.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + err = mt.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + node, err := mt.GetLeafNode(key) + assert.NoError(t, err) + assert.Equal(t, 1, len(node.ValuePreimage)) + assert.Equal(t, NewByte32FromBytes([]byte{1}).Bytes(), node.ValuePreimage[0][:]) + } + }) + + t.Run("Update non-existent word", func(t *testing.T) { + mt, _ := newTestingMerkle(t) + err := mt.TryUpdate(k1, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + node, err := mt.GetLeafNode(k1) + assert.NoError(t, err) + assert.Equal(t, 1, len(node.ValuePreimage)) + assert.Equal(t, NewByte32FromBytes([]byte{1}).Bytes(), node.ValuePreimage[0][:]) + }) +} + +func TestZkTrieImpl_Delete(t *testing.T) { + k1 := []byte{1} + k2 := []byte{2} + k3 := []byte{3} + k4 := []byte{4} + + t.Run("Test deletion leads to empty tree", func(t *testing.T) { + emptyMT, _ := newTestingMerkle(t) + emptyMTRoot, err := emptyMT.Root() + assert.NoError(t, err) + + mt1, _ := newTestingMerkle(t) + err = mt1.TryUpdate(k1, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + err = mt1.TryDelete(k1) + assert.NoError(t, err) + mt1Root, err := mt1.Root() + assert.NoError(t, err) + assert.Equal(t, HashZero, *mt1Root) + assert.Equal(t, emptyMTRoot, mt1Root) + + keys := [][]byte{k1, k2, k3, k4} + mt2, _ := newTestingMerkle(t) + for _, key := range keys { + err := mt2.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + } + for _, key := range keys { + err := mt2.TryDelete(key) + assert.NoError(t, err) + } + mt2Root, err := mt2.Root() + assert.NoError(t, err) + assert.Equal(t, HashZero, *mt2Root) + assert.Equal(t, emptyMTRoot, mt2Root) + + mt3, _ := newTestingMerkle(t) + for _, key := range keys { + err := mt3.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + } + for i := len(keys) - 1; i >= 0; i-- { + err := mt3.TryDelete(keys[i]) + assert.NoError(t, err) + } + mt3Root, err := mt3.Root() + assert.NoError(t, err) + assert.Equal(t, HashZero, *mt3Root) + assert.Equal(t, emptyMTRoot, mt3Root) + }) + + t.Run("Test equivalent trees after deletion", func(t *testing.T) { + keys := [][]byte{k1, k2, k3, k4} + + mt1, _ := newTestingMerkle(t) + for i, key := range keys { + err := mt1.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{byte(i + 1)})}) + assert.NoError(t, err) + } + err := mt1.TryDelete(k1) + assert.NoError(t, err) + err = mt1.TryDelete(k2) + assert.NoError(t, err) + + mt2, _ := newTestingMerkle(t) + err = mt2.TryUpdate(k3, 1, []Byte32{*NewByte32FromBytes([]byte{byte(3)})}) + assert.NoError(t, err) + err = mt2.TryUpdate(k4, 1, []Byte32{*NewByte32FromBytes([]byte{byte(4)})}) + assert.NoError(t, err) + + mt1Root, err := mt1.Root() + assert.NoError(t, err) + mt2Root, err := mt2.Root() + assert.NoError(t, err) + + assert.Equal(t, mt1Root, mt2Root) + + mt3, _ := newTestingMerkle(t) + for i, key := range keys { + err := mt3.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes([]byte{byte(i + 1)})}) + assert.NoError(t, err) + } + err = mt3.TryDelete(k1) + assert.NoError(t, err) + err = mt3.TryDelete(k3) + assert.NoError(t, err) + mt4, _ := newTestingMerkle(t) + err = mt4.TryUpdate(k2, 1, []Byte32{*NewByte32FromBytes([]byte{2})}) + assert.NoError(t, err) + err = mt4.TryUpdate(k4, 1, []Byte32{*NewByte32FromBytes([]byte{4})}) + assert.NoError(t, err) + + mt3Root, err := mt3.Root() + assert.NoError(t, err) + mt4Root, err := mt4.Root() + assert.NoError(t, err) + + assert.Equal(t, mt3Root, mt4Root) + }) + + t.Run("Test repeat deletion", func(t *testing.T) { + mt, _ := newTestingMerkle(t) + err := mt.TryUpdate(k1, 1, []Byte32{*NewByte32FromBytes([]byte{1})}) + assert.NoError(t, err) + err = mt.TryDelete(k1) + assert.NoError(t, err) + err = mt.TryDelete(k1) + assert.NoError(t, err) + }) + + t.Run("Test deletion of non-existent node", func(t *testing.T) { + mt, _ := newTestingMerkle(t) + err := mt.TryDelete(k1) + assert.NoError(t, err) + }) +} + +func TestMerkleTree_BuildAndVerifyZkTrieProof(t *testing.T) { + zkTrie, _ := newTestingMerkle(t) + + testData := []struct { + key *big.Int + value byte + }{ + {big.NewInt(1), 2}, + {big.NewInt(3), 4}, + {big.NewInt(5), 6}, + {big.NewInt(7), 8}, + {big.NewInt(9), 10}, } - for _, val := range vals { - if val.v != "" { - trie.Update([]byte(val.k), []byte(val.v)) - } else { - trie.Delete([]byte(val.k)) + + nonExistentKey := big.NewInt(11) + + for _, td := range testData { + err := zkTrie.TryUpdate([]byte{byte(td.key.Int64())}, 1, []Byte32{{td.value}}) + assert.NoError(t, err) + } + _, err := zkTrie.Root() + assert.NoError(t, err) + + t.Run("Test with existent key", func(t *testing.T) { + for _, td := range testData { + + node, err := zkTrie.GetLeafNode([]byte{byte(td.key.Int64())}) + assert.NoError(t, err) + assert.Equal(t, 1, len(node.ValuePreimage)) + assert.Equal(t, (&Byte32{td.value})[:], node.ValuePreimage[0][:]) + proof, node, err := BuildZkTrieProof(zkTrie.rootKey, td.key, 10, zkTrie.GetNodeByHash) + assert.NoError(t, err) + + valid := VerifyProofZkTrie(zkTrie.rootKey, proof, node) + assert.True(t, valid) } + }) + + t.Run("Test with non-existent key", func(t *testing.T) { + proof, node, err := BuildZkTrieProof(zkTrie.rootKey, nonExistentKey, 10, zkTrie.GetNodeByHash) + assert.NoError(t, err) + assert.False(t, proof.Existence) + valid := VerifyProofZkTrie(zkTrie.rootKey, proof, node) + assert.True(t, valid) + nodeAnother, err := zkTrie.GetLeafNode([]byte{byte(big.NewInt(1).Int64())}) + assert.NoError(t, err) + valid = VerifyProofZkTrie(zkTrie.rootKey, proof, nodeAnother) + assert.False(t, valid) + + hash, err := proof.Verify(node.nodeHash) + assert.NoError(t, err) + assert.Equal(t, hash[:], zkTrie.rootKey[:]) + }) +} + +func TestMerkleTree_GraphViz(t *testing.T) { + mt, _ := newTestingMerkle(t) + + var buffer bytes.Buffer + err := mt.GraphViz(&buffer, nil) + assert.NoError(t, err) + assert.Equal(t, "--------\nGraphViz of the ZkTrie with RootHash 0\ndigraph hierarchy {\nnode [fontname=Monospace,fontsize=10,shape=box]\n}\nEnd of GraphViz of the ZkTrie with RootHash 0\n--------\n", buffer.String()) + buffer.Reset() + + key1 := []byte{1} //0b1 + err = mt.TryUpdate(key1, 1, []Byte32{{1}}) + assert.NoError(t, err) + key2 := []byte{3} //0b11 + err = mt.TryUpdate(key2, 1, []Byte32{{3}}) + assert.NoError(t, err) + + err = mt.GraphViz(&buffer, nil) + assert.NoError(t, err) + assert.Equal(t, "--------\nGraphViz of the ZkTrie with RootHash 1210085283654691963881487672127167617844540538182653450104829037534096200821\ndigraph hierarchy {\nnode [fontname=Monospace,fontsize=10,shape=box]\n\"12100852...\" -> {\"95649672...\" \"20807384...\"}\n\"95649672...\" [style=filled];\n\"20807384...\" [style=filled];\n}\nEnd of GraphViz of the ZkTrie with RootHash 1210085283654691963881487672127167617844540538182653450104829037534096200821\n--------\n", buffer.String()) + buffer.Reset() +} + +func TestZkTrie_GetUpdateDelete(t *testing.T) { + mt, _ := newTestingMerkle(t) + val, err := mt.TryGet([]byte("key")) + assert.NoError(t, err) + assert.Nil(t, val) + assert.Equal(t, common.Hash{}, mt.Hash()) + + err = mt.TryUpdate([]byte("key"), 1, []Byte32{{1}}) + assert.NoError(t, err) + expected := common.HexToHash("0x0b9402772b5bfa4c7caaaeb14f489ae201d536c430c3fc29abb0fde923cd1df4") + assert.Equal(t, expected, mt.Hash()) + + val, err = mt.TryGet([]byte("key")) + assert.NoError(t, err) + assert.Equal(t, (&Byte32{1}).Bytes(), val) + + err = mt.TryDelete([]byte("key")) + assert.NoError(t, err) + assert.Equal(t, common.Hash{}, mt.Hash()) + + val, err = mt.TryGet([]byte("key")) + assert.NoError(t, err) + assert.Nil(t, val) +} + +func TestZkTrie_Copy(t *testing.T) { + mt, _ := newTestingMerkle(t) + + mt.TryUpdate([]byte("key"), 1, []Byte32{{1}}) + + copyTrie := mt.Copy() + val, err := copyTrie.TryGet([]byte("key")) + assert.NoError(t, err) + assert.Equal(t, (&Byte32{1}).Bytes(), val) +} + +func TestZkTrie_ProveAndProveWithDeletion(t *testing.T) { + mt, _ := newTestingMerkle(t) + + keys := []string{"key1", "key2", "key3", "key4", "key5"} + for i, keyStr := range keys { + key := make([]byte, 32) + copy(key, []byte(keyStr)) + + err := mt.TryUpdate(key, uint32(i+1), []Byte32{{byte(uint32(i + 1))}}) + assert.NoError(t, err) + + writeNode := func(n *Node) error { + return nil + } + + k, err := ToSecureKey(key) + assert.NoError(t, err) + + for j := 0; j <= i; j++ { + err = mt.ProveWithDeletion(NewHashFromBigInt(k).Bytes(), uint(j), writeNode, nil) + assert.NoError(t, err) + } + } +} + +func newHashFromHex(h string) (*Hash, error) { + return NewHashFromCheckedBytes(common.FromHex(h)) +} + +func TestHashParsers(t *testing.T) { + h0 := NewHashFromBigInt(big.NewInt(0)) + assert.Equal(t, "0", h0.String()) + h1 := NewHashFromBigInt(big.NewInt(1)) + assert.Equal(t, "1", h1.String()) + h10 := NewHashFromBigInt(big.NewInt(10)) + assert.Equal(t, "10", h10.String()) + + h7l := NewHashFromBigInt(big.NewInt(1234567)) + assert.Equal(t, "1234567", h7l.String()) + h8l := NewHashFromBigInt(big.NewInt(12345678)) + assert.Equal(t, "12345678...", h8l.String()) + + b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10) //nolint:lll + assert.True(t, ok) + h := NewHashFromBigInt(b) + assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String()) //nolint:lll + assert.Equal(t, "49322979...", h.String()) + assert.Equal(t, "0ae794eb9c3d8bbb9002e993fc2ed301dcbd2af5508ed072c375e861f1aa5b26", h.Hex()) + + b1, err := NewBigIntFromHashBytes(b.Bytes()) + assert.Nil(t, err) + assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String()) + + b2, err := NewHashFromCheckedBytes(b.Bytes()) + assert.Nil(t, err) + assert.Equal(t, b.String(), b2.BigInt().String()) + + h2, err := newHashFromHex(h.Hex()) + assert.Nil(t, err) + assert.Equal(t, h, h2) + _, err = newHashFromHex("0x12") + assert.NotNil(t, err) + + // check limits + a := new(big.Int).Sub(constants.Q, big.NewInt(1)) + testHashParsers(t, a) + a = big.NewInt(int64(1)) + testHashParsers(t, a) +} + +func testHashParsers(t *testing.T, a *big.Int) { + h := NewHashFromBigInt(a) + assert.Equal(t, a, h.BigInt()) + hFromBytes, err := NewHashFromCheckedBytes(h.Bytes()) + assert.Nil(t, err) + assert.Equal(t, h, hFromBytes) + assert.Equal(t, a, hFromBytes.BigInt()) + assert.Equal(t, a.String(), hFromBytes.BigInt().String()) + hFromHex, err := newHashFromHex(h.Hex()) + assert.Nil(t, err) + assert.Equal(t, h, hFromHex) + + aBIFromHBytes, err := NewBigIntFromHashBytes(h.Bytes()) + assert.Nil(t, err) + assert.Equal(t, a, aBIFromHBytes) + assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String()) +} + +func TestMerkleTree_AddUpdateGetWord_2(t *testing.T) { + mt, _ := newTestingMerkle(t) + err := mt.TryUpdate([]byte{1}, 1, []Byte32{{2}}) + assert.Nil(t, err) + err = mt.TryUpdate([]byte{3}, 1, []Byte32{{4}}) + assert.Nil(t, err) + err = mt.TryUpdate([]byte{5}, 1, []Byte32{{6}}) + assert.Nil(t, err) + + mt.GetLeafNode([]byte{1}) + node, err := mt.GetLeafNode([]byte{1}) + assert.Nil(t, err) + assert.Equal(t, len(node.ValuePreimage), 1) + assert.Equal(t, (&Byte32{2})[:], node.ValuePreimage[0][:]) + node, err = mt.GetLeafNode([]byte{3}) + assert.Nil(t, err) + assert.Equal(t, len(node.ValuePreimage), 1) + assert.Equal(t, (&Byte32{4})[:], node.ValuePreimage[0][:]) + node, err = mt.GetLeafNode([]byte{5}) + assert.Nil(t, err) + assert.Equal(t, len(node.ValuePreimage), 1) + assert.Equal(t, (&Byte32{6})[:], node.ValuePreimage[0][:]) + + err = mt.TryUpdate([]byte{1}, 1, []Byte32{{7}}) + assert.Nil(t, err) + err = mt.TryUpdate([]byte{3}, 1, []Byte32{{8}}) + assert.Nil(t, err) + err = mt.TryUpdate([]byte{5}, 1, []Byte32{{9}}) + assert.Nil(t, err) + + node, err = mt.GetLeafNode([]byte{1}) + assert.Nil(t, err) + assert.Equal(t, len(node.ValuePreimage), 1) + assert.Equal(t, (&Byte32{7})[:], node.ValuePreimage[0][:]) + node, err = mt.GetLeafNode([]byte{3}) + assert.Nil(t, err) + assert.Equal(t, len(node.ValuePreimage), 1) + assert.Equal(t, (&Byte32{8})[:], node.ValuePreimage[0][:]) + node, err = mt.GetLeafNode([]byte{5}) + assert.Nil(t, err) + assert.Equal(t, len(node.ValuePreimage), 1) + assert.Equal(t, (&Byte32{9})[:], node.ValuePreimage[0][:]) + _, err = mt.GetLeafNode([]byte{100}) + assert.Equal(t, ErrKeyNotFound, err) +} + +func TestMerkleTree_UpdateAccount(t *testing.T) { + mt, _ := newTestingMerkle(t) + + acc1 := &types.StateAccount{ + Nonce: 1, + Balance: big.NewInt(10000000), + Root: common.HexToHash("22fb59aa5410ed465267023713ab42554c250f394901455a3366e223d5f7d147"), + KeccakCodeHash: common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(), + PoseidonCodeHash: common.HexToHash("0c0a77f6e063b4b62eb7d9ed6f427cf687d8d0071d751850cfe5d136bc60d3ab").Bytes(), + CodeSize: 0, + } + value, flag := acc1.MarshalFields() + accValue := []Byte32{} + for _, v := range value { + accValue = append(accValue, *NewByte32FromBytes(v[:])) + } + err := mt.TryUpdate(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes(), flag, accValue) + assert.Nil(t, err) + + acc2 := &types.StateAccount{ + Nonce: 5, + Balance: big.NewInt(50000000), + Root: common.HexToHash("0"), + KeccakCodeHash: common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(), + PoseidonCodeHash: common.HexToHash("05d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(), + CodeSize: 5, } - hash := trie.Hash() - exp := common.HexToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d") - if hash != exp { - t.Errorf("expected %x got %x", exp, hash) + value, flag = acc2.MarshalFields() + accValue = []Byte32{} + for _, v := range value { + accValue = append(accValue, *NewByte32FromBytes(v[:])) } + err = mt.TryUpdate(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes(), flag, accValue) + assert.Nil(t, err) + + bt, err := mt.TryGet(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes()) + assert.Nil(t, err) + + acc, err := types.UnmarshalStateAccount(bt) + assert.Nil(t, err) + assert.Equal(t, acc1.Nonce, acc.Nonce) + assert.Equal(t, acc1.Balance.Uint64(), acc.Balance.Uint64()) + assert.Equal(t, acc1.Root.Bytes(), acc.Root.Bytes()) + assert.Equal(t, acc1.KeccakCodeHash, acc.KeccakCodeHash) + assert.Equal(t, acc1.PoseidonCodeHash, acc.PoseidonCodeHash) + assert.Equal(t, acc1.CodeSize, acc.CodeSize) + + bt, err = mt.TryGet(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes()) + assert.Nil(t, err) + + acc, err = types.UnmarshalStateAccount(bt) + assert.Nil(t, err) + assert.Equal(t, acc2.Nonce, acc.Nonce) + assert.Equal(t, acc2.Balance.Uint64(), acc.Balance.Uint64()) + assert.Equal(t, acc2.Root.Bytes(), acc.Root.Bytes()) + assert.Equal(t, acc2.KeccakCodeHash, acc.KeccakCodeHash) + assert.Equal(t, acc2.PoseidonCodeHash, acc.PoseidonCodeHash) + assert.Equal(t, acc2.CodeSize, acc.CodeSize) + + bt, err = mt.TryGet(common.HexToAddress("0x8dE13967F19410A7991D63c2c0179feBFDA0c261").Bytes()) + assert.Nil(t, err) + assert.Nil(t, bt) + + err = mt.TryDelete(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes()) + assert.Nil(t, err) + + bt, err = mt.TryGet(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes()) + assert.Nil(t, err) + assert.Nil(t, bt) + + err = mt.TryDelete(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes()) + assert.Nil(t, err) + + bt, err = mt.TryGet(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes()) + assert.Nil(t, err) + assert.Nil(t, bt) +} + +func TestDecodeSMTProof(t *testing.T) { + node, err := DecodeSMTProof(magicSMTBytes) + assert.NoError(t, err) + assert.Nil(t, node) + + k1 := NewHashFromBytes([]byte{1, 2, 3, 4, 5}) + k2 := NewHashFromBytes([]byte{6, 7, 8, 9, 0}) + origNode := NewParentNode(NodeTypeBranch_0, k1, k2) + node, err = DecodeSMTProof(origNode.Value()) + assert.NoError(t, err) + assert.Equal(t, origNode.Value(), node.Value()) } func TestZktrieGetKey(t *testing.T) { - trie := newEmptyZkTrie() + trie, _ := newTestingMerkle(t) key := []byte("0a1b2c3d4e5f6g7h8i9j0a1b2c3d4e5f") value := []byte("9j8i7h6g5f4e3d2c1b0a9j8i7h6g5f4e") - trie.Update(key, value) + trie.TryUpdate(key, 1, []Byte32{*NewByte32FromBytes(value)}) - kPreimage := zkt.NewByte32FromBytesPaddingZero(key) + kPreimage := NewByte32FromBytesPaddingZero(key) kHash, err := kPreimage.Hash() assert.Nil(t, err) - - if !bytes.Equal(trie.Get(key), value) { - t.Errorf("Get did not return bar") - } if k := trie.GetKey(kHash.Bytes()); !bytes.Equal(k, key) { t.Errorf("GetKey returned %q, want %q", k, key) } @@ -124,7 +765,7 @@ func TestZktrieGetKey(t *testing.T) { func TestZkTrieConcurrency(t *testing.T) { // Create an initial trie and copy if for concurrent access - _, trie, _ := makeTestZkTrie() + trie, _ := newTestingMerkle(t) threads := runtime.NumCPU() tries := make([]*ZkTrie, threads) @@ -141,15 +782,15 @@ func TestZkTrieConcurrency(t *testing.T) { for j := byte(0); j < 255; j++ { // Map the same data under multiple keys key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), bytes.Repeat([]byte{j}, 32) - tries[index].Update(key, val) + tries[index].TryUpdate(key, 1, []Byte32{*NewByte32FromBytes(val)}) key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), bytes.Repeat([]byte{j}, 32) - tries[index].Update(key, val) + tries[index].TryUpdate(key, 1, []Byte32{*NewByte32FromBytes(val)}) // Add some other data to inflate the trie for k := byte(3); k < 13; k++ { key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), bytes.Repeat([]byte{k, j}, 16) - tries[index].Update(key, val) + tries[index].TryUpdate(key, 1, []Byte32{*NewByte32FromBytes(val)}) } } tries[index].Commit(false) @@ -159,109 +800,23 @@ func TestZkTrieConcurrency(t *testing.T) { pend.Wait() } -func tempDBZK(b *testing.B) (string, *Database) { - dir, err := ioutil.TempDir("", "zktrie-bench") - assert.NoError(b, err) - - diskdb, err := rawdb.NewLevelDBDatabase(dir, 256, 0, "", false) - assert.NoError(b, err) - config := &Config{ - Preimages: true, - HashDB: &hashdb.Config{CleanCacheSize: 256}, - IsUsingZktrie: true, - } - return dir, NewDatabase(diskdb, config) -} - -const benchElemCountZk = 10000 - -func BenchmarkZkTrieGet(b *testing.B) { - dir, tmpdb := tempDBZK(b) - zkTrie, _ := NewZkTrie(common.Hash{}, NewZktrieDatabaseFromTriedb(tmpdb)) - defer func() { - ldb := zkTrie.db.db.diskdb - ldb.Close() - os.RemoveAll(dir) - }() - - k := make([]byte, 32) - for i := 0; i < benchElemCountZk; i++ { - binary.LittleEndian.PutUint64(k, uint64(i)) - - err := zkTrie.TryUpdate(k, k) - assert.NoError(b, err) - } - - zkTrie.db.db.Commit(common.Hash{}, true) - b.ResetTimer() - for i := 0; i < b.N; i++ { - binary.LittleEndian.PutUint64(k, uint64(i)) - _, err := zkTrie.TryGet(k) - assert.NoError(b, err) - } - b.StopTimer() -} - -func BenchmarkZkTrieUpdate(b *testing.B) { - dir, tmpdb := tempDBZK(b) - zkTrie, _ := NewZkTrie(common.Hash{}, NewZktrieDatabaseFromTriedb(tmpdb)) - defer func() { - ldb := zkTrie.db.db.diskdb - ldb.Close() - os.RemoveAll(dir) - }() - - k := make([]byte, 32) - v := make([]byte, 32) - b.ReportAllocs() - - for i := 0; i < benchElemCountZk; i++ { - binary.LittleEndian.PutUint64(k, uint64(i)) - err := zkTrie.TryUpdate(k, k) - assert.NoError(b, err) - } - binary.LittleEndian.PutUint64(k, benchElemCountZk/2) - - //zkTrie.Commit(false) - zkTrie.db.db.Commit(common.Hash{}, true) - b.ResetTimer() - for i := 0; i < b.N; i++ { - binary.LittleEndian.PutUint64(k, uint64(i)) - binary.LittleEndian.PutUint64(v, 0xffffffff+uint64(i)) - err := zkTrie.TryUpdate(k, v) - assert.NoError(b, err) - } - b.StopTimer() -} - func TestZkTrieDelete(t *testing.T) { - key := make([]byte, 32) - value := make([]byte, 32) - trie1 := newEmptyZkTrie() + trie1, _ := newTestingMerkle(t) var count int = 6 var hashes []common.Hash hashes = append(hashes, trie1.Hash()) for i := 0; i < count; i++ { - binary.LittleEndian.PutUint64(key, uint64(i)) - binary.LittleEndian.PutUint64(value, uint64(i)) - err := trie1.TryUpdate(key, value) + err := trie1.TryUpdate([]byte{byte(i)}, 1, []Byte32{{byte(i)}}) assert.NoError(t, err) hashes = append(hashes, trie1.Hash()) } - // binary.LittleEndian.PutUint64(key, uint64(0xffffff)) - // err := trie1.TryDelete(key) - // assert.Equal(t, err, zktrie.ErrKeyNotFound) - - trie1.Commit(false) - for i := count - 1; i >= 0; i-- { - binary.LittleEndian.PutUint64(key, uint64(i)) - v, err := trie1.TryGet(key) + v, err := trie1.TryGet([]byte{byte(i)}) assert.NoError(t, err) assert.NotEmpty(t, v) - err = trie1.TryDelete(key) + err = trie1.TryDelete([]byte{byte(i)}) assert.NoError(t, err) hash := trie1.Hash() assert.Equal(t, hashes[i].Hex(), hash.Hex())