From 49d0e145881a97d0c3163b9a40cde658cf1672dd Mon Sep 17 00:00:00 2001 From: Avi Dessauer Date: Fri, 8 Mar 2024 18:56:44 -0600 Subject: [PATCH] Store to database --- src/lib.rs | 76 ++++++++++++++++++++++++++++++++++++-------- src/stored.rs | 41 +++++++++++++++++++----- src/stored/merkle.rs | 8 ++--- 3 files changed, 100 insertions(+), 25 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b9300a3..25df319 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,7 @@ use sha2::{Digest, Sha256}; pub use stored::Store; use stored::{ merkle::{Snapshot, SnapshotBuilder}, - Node, NodeHash, + DatabaseSet, Node, NodeHash, }; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] @@ -258,7 +258,7 @@ impl Branch { } } -impl Branch> { +impl Branch> { fn new_at_branch( word_idx: usize, branch_word_or_prefix: u32, @@ -424,7 +424,42 @@ pub struct Transaction { pub current_root: TrieRoot, } -impl, V: Debug + AsRef<[u8]>> Transaction { +impl<'a, Db: DatabaseSet, V: Clone + AsRef<[u8]>> Transaction, V> { + /// Write modified nodes to the database and return the root hash. + /// Calling this method will write all modified nodes to the database. + /// Calling this method again will rewrite the nodes to the database. + /// + /// Caching writes is the responsibility of the `DatabaseSet` implementation. + pub fn commit(&self) -> Result { + let store_modified_branch = + &mut |hash: &NodeHash, branch: &Branch>, left: NodeHash, right: NodeHash| { + let branch = Branch { + left, + right, + mask: branch.mask, + prior_word: branch.prior_word, + prefix: branch.prefix.clone(), + }; + + self.data_store + .db + .set(*hash, Node::Branch(branch)) + .map_err(|e| e.into()) + }; + + let store_modified_leaf = &mut |hash: &NodeHash, leaf: &Leaf| { + self.data_store + .db + .set(*hash, Node::Leaf(leaf.clone())) + .map_err(|e| e.into()) + }; + + let root_hash = self.calc_root_hash_inner(store_modified_branch, store_modified_leaf)?; + Ok(root_hash) + } +} + +impl, V: AsRef<[u8]>> Transaction { pub fn new(root: TrieRoot, data_store: S) -> Self { Transaction { current_root: root, @@ -433,29 +468,44 @@ impl, V: Debug + AsRef<[u8]>> Transaction { } /// TODO a version of this that writes to the database. - pub fn calc_root_hash(&self) -> Result { - let mut on_modified_leaf = |_: &_, _: &_| {}; - let mut on_modified_branch = |_: &_, _: &_| {}; - + pub fn calc_root_hash_inner( + &self, + on_modified_branch: &mut impl FnMut( + &NodeHash, + &Branch>, + NodeHash, + NodeHash, + ) -> Result<(), String>, + on_modified_leaf: &mut impl FnMut(&NodeHash, &Leaf) -> Result<(), String>, + ) -> Result { let root_hash = match &self.current_root { TrieRoot::Empty => return Ok([0; 32]), TrieRoot::Node(node_ref) => Self::calc_root_hash_node( &self.data_store, node_ref, - &mut on_modified_leaf, - &mut on_modified_branch, + on_modified_leaf, + on_modified_branch, )?, }; Ok(root_hash) } + pub fn calc_root_hash(&self) -> Result { + self.calc_root_hash_inner(&mut |_, _, _, _| Ok(()), &mut |_, _| Ok(())) + } + /// TODO use this to store nodes in the data base fn calc_root_hash_node( data_store: &S, node_ref: &NodeRef, - on_modified_leaf: &mut impl FnMut(&NodeHash, &Leaf), - on_modified_branch: &mut impl FnMut(&NodeHash, &Branch>), + on_modified_leaf: &mut impl FnMut(&NodeHash, &Leaf) -> Result<(), String>, + on_modified_branch: &mut impl FnMut( + &NodeHash, + &Branch>, + NodeHash, + NodeHash, + ) -> Result<(), String>, ) -> Result { // TODO use a stack instead of recursion match node_ref { @@ -474,13 +524,13 @@ impl, V: Debug + AsRef<[u8]>> Transaction { )?; let hash = branch.hash_branch(&left, &right); - on_modified_branch(&hash, branch); + on_modified_branch(&hash, branch, left, right)?; Ok(hash) } NodeRef::ModLeaf(leaf) => { let hash = leaf.hash_leaf(); - on_modified_leaf(&hash, leaf); + on_modified_leaf(&hash, leaf)?; Ok(hash) } NodeRef::Stored(stored_idx) => { diff --git a/src/stored.rs b/src/stored.rs index a5cf4f0..e47985a 100644 --- a/src/stored.rs +++ b/src/stored.rs @@ -1,5 +1,6 @@ pub mod merkle; +use std::cell::RefCell; use std::hash::Hash; use alloc::{collections::BTreeMap, fmt::Debug, string::String}; @@ -18,10 +19,20 @@ pub trait Store { fn get_node(&self, hash_idx: Idx) -> Result, &Leaf>, Self::Error>; } -pub trait Database { - type Error: Into + Debug; +pub trait DatabaseGet { + type GetError: Into + Debug; + + fn get(&self, hash: &NodeHash) -> Result, Leaf>, Self::GetError>; +} + +pub trait DatabaseSet: DatabaseGet { + type SetError: Into + Debug; - fn get(&self, hash: &NodeHash) -> Result, Leaf>, Self::Error>; + fn set( + &self, + hash: NodeHash, + node: Node, Leaf>, + ) -> Result<(), Self::GetError>; } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] @@ -47,24 +58,38 @@ pub type NodeHash = [u8; 32]; #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct MemoryDb { - leaves: BTreeMap, Leaf>>, + leaves: RefCell, Leaf>>>, } impl MemoryDb { pub fn empty() -> Self { Self { - leaves: BTreeMap::new(), + leaves: RefCell::default(), } } } -impl Database for MemoryDb { - type Error = Error; +impl DatabaseGet for MemoryDb { + type GetError = Error; - fn get(&self, hash_idx: &NodeHash) -> Result, Leaf>, Self::Error> { + fn get(&self, hash_idx: &NodeHash) -> Result, Leaf>, Self::GetError> { self.leaves + .borrow() .get(hash_idx) .cloned() .ok_or(Error::NodeNotFound) } } + +impl DatabaseSet for MemoryDb { + type SetError = Error; + + fn set( + &self, + hash_idx: NodeHash, + node: Node, Leaf>, + ) -> Result<(), Self::SetError> { + self.leaves.borrow_mut().insert(hash_idx, node); + Ok(()) + } +} diff --git a/src/stored/merkle.rs b/src/stored/merkle.rs index cdb0ae3..194b2dd 100644 --- a/src/stored/merkle.rs +++ b/src/stored/merkle.rs @@ -5,7 +5,7 @@ use bumpalo::Bump; use crate::{Branch, Leaf}; -use super::{Database, Error, Idx, Node, NodeHash, Store}; +use super::{DatabaseGet, Error, Idx, Node, NodeHash, Store}; /// A snapshot of the merkle trie /// @@ -104,7 +104,7 @@ impl> Store for Snapshot { #[derive(Clone, Debug)] pub struct SnapshotBuilder<'a, Db, V> { - db: Db, + pub db: Db, bump: &'a Bump, /// The root of the trie is always at index 0 @@ -113,7 +113,7 @@ pub struct SnapshotBuilder<'a, Db, V> { type NodeHashMaybeNode<'a, V> = (NodeHash, Option, &'a Leaf>>); -impl<'a, Db: Database, V: Clone> Store for SnapshotBuilder<'a, Db, V> { +impl<'a, Db: DatabaseGet, V: Clone> Store for SnapshotBuilder<'a, Db, V> { type Error = Error; fn get_unvisted_hash(&self, hash_idx: Idx) -> Result<&NodeHash, Self::Error> { @@ -203,7 +203,7 @@ impl<'a, Db, V: Clone> SnapshotBuilder<'a, Db, V> { Error, > where - Db: Database, + Db: DatabaseGet, { let Ok(node) = db.get(hash) else { return Err(Error::NodeNotFound);