From 29d972db8790fe60da7bf20de58a7e8d496a28cd Mon Sep 17 00:00:00 2001 From: lightsing Date: Fri, 13 Sep 2024 14:59:38 +0800 Subject: [PATCH] add gc to trie --- src/db/kv/btree_map.rs | 32 ++++++++++++------- src/db/kv/hash_map.rs | 32 ++++++++++++------- src/db/kv/mod.rs | 16 ++++++++++ src/db/kv/sled.rs | 33 ++++++++++++------- src/lib.rs | 2 +- src/trie/zktrie/imp.rs | 69 ++++++++++++++++++++++++++++++++++++++++ src/trie/zktrie/mod.rs | 6 ++++ src/trie/zktrie/tests.rs | 11 ++++++- 8 files changed, 163 insertions(+), 38 deletions(-) diff --git a/src/db/kv/btree_map.rs b/src/db/kv/btree_map.rs index 39318d0..c66e6cb 100644 --- a/src/db/kv/btree_map.rs +++ b/src/db/kv/btree_map.rs @@ -33,18 +33,6 @@ impl BTreeMapDb { Self { gc_enabled, db } } - /// Enable or disable garbage collection. - #[inline] - pub fn set_gc_enabled(&mut self, gc_enabled: bool) { - self.gc_enabled = gc_enabled; - } - - /// Check if garbage collection is enabled. - #[inline] - pub fn is_gc_enabled(&self) -> bool { - self.gc_enabled - } - /// Get the inner `BTreeMap`. pub fn inner(&self) -> &BTreeMap, Box<[u8]>> { &self.db @@ -77,6 +65,10 @@ impl KVDatabase for BTreeMapDb { Ok(self.db.get(k)) } + fn set_gc_enabled(&mut self, gc_enabled: bool) { + self.gc_enabled = gc_enabled; + } + fn gc_enabled(&self) -> bool { self.gc_enabled } @@ -90,6 +82,22 @@ impl KVDatabase for BTreeMapDb { Ok(()) } + fn retain(&mut self, mut f: F) -> Result<(), Self::Error> + where + F: FnMut(&[u8], &[u8]) -> bool, + { + let mut removed = 0; + self.db.retain(|k, v| { + let keep = f(k, v); + if !keep { + removed += 1; + } + keep + }); + trace!("{} key-value pairs removed", removed); + Ok(()) + } + fn extend, Box<[u8]>)>>( &mut self, other: T, diff --git a/src/db/kv/hash_map.rs b/src/db/kv/hash_map.rs index c7ac527..8bb591e 100644 --- a/src/db/kv/hash_map.rs +++ b/src/db/kv/hash_map.rs @@ -31,18 +31,6 @@ impl HashMapDb { Self { gc_enabled, db } } - /// Enable or disable garbage collection. - #[inline] - pub fn set_gc_enabled(&mut self, gc_enabled: bool) { - self.gc_enabled = gc_enabled; - } - - /// Check if garbage collection is enabled. - #[inline] - pub fn is_gc_enabled(&self) -> bool { - self.gc_enabled - } - /// Get the inner [`HashMap`](std::collections::HashMap). pub fn inner(&self) -> &HashMap, Box<[u8]>> { &self.db @@ -75,6 +63,10 @@ impl KVDatabase for HashMapDb { Ok(self.db.get(k)) } + fn set_gc_enabled(&mut self, gc_enabled: bool) { + self.gc_enabled = gc_enabled; + } + fn gc_enabled(&self) -> bool { self.gc_enabled } @@ -88,6 +80,22 @@ impl KVDatabase for HashMapDb { Ok(()) } + fn retain(&mut self, mut f: F) -> Result<(), Self::Error> + where + F: FnMut(&[u8], &[u8]) -> bool, + { + let mut removed = 0; + self.db.retain(|k, v| { + let keep = f(k, v); + if !keep { + removed += 1; + } + keep + }); + trace!("{} key-value pairs removed", removed); + Ok(()) + } + fn extend, Box<[u8]>)>>( &mut self, other: T, diff --git a/src/db/kv/mod.rs b/src/db/kv/mod.rs index 5d312f9..3e3e7a9 100644 --- a/src/db/kv/mod.rs +++ b/src/db/kv/mod.rs @@ -37,6 +37,9 @@ pub trait KVDatabase { /// Returns `Ok(None)` if the key is not present. fn get(&self, k: &[u8]) -> Result>, Self::Error>; + /// Check if garbage collection is enabled. + fn set_gc_enabled(&mut self, _gc_enabled: bool) {} + /// Check if garbage collection is enabled. fn gc_enabled(&self) -> bool { false @@ -60,6 +63,19 @@ pub trait KVDatabase { Ok(()) } + /// Retain only the key-value pairs that satisfy the predicate. + /// + /// # Note + /// + /// Same as [`KVDatabase::remove`], this method is best-effort and should not be relied on + /// to determine if the key was present or not. + fn retain(&mut self, _f: F) -> Result<(), Self::Error> + where + F: FnMut(&[u8], &[u8]) -> bool, + { + Ok(()) + } + /// Extend the database with the key-value pairs from the iterator. fn extend, Box<[u8]>)>>( &mut self, diff --git a/src/db/kv/sled.rs b/src/db/kv/sled.rs index 357f2e8..652fd68 100644 --- a/src/db/kv/sled.rs +++ b/src/db/kv/sled.rs @@ -41,18 +41,6 @@ impl SledDb { Self { gc_enabled, db } } - /// Enable or disable garbage collection. - #[inline] - pub fn set_gc_enabled(&mut self, gc_enabled: bool) { - self.gc_enabled = gc_enabled; - } - - /// Check if garbage collection is enabled. - #[inline] - pub fn is_gc_enabled(&self) -> bool { - self.gc_enabled - } - /// Get the inner [`sled::Tree`] pub fn inner(&self) -> &sled::Tree { &self.db @@ -83,6 +71,10 @@ impl KVDatabase for SledDb { self.db.get(k) } + fn set_gc_enabled(&mut self, gc_enabled: bool) { + self.gc_enabled = gc_enabled; + } + fn gc_enabled(&self) -> bool { self.gc_enabled } @@ -96,6 +88,23 @@ impl KVDatabase for SledDb { Ok(()) } + fn retain(&mut self, mut f: F) -> Result<(), Self::Error> + where + F: FnMut(&[u8], &[u8]) -> bool, + { + let mut removed = 0; + let mut batch = Batch::default(); + for entry in self.db.iter() { + let (k, v) = entry?; + if !f(k.as_ref(), v.as_ref()) { + batch.remove(k); + removed += 1; + } + } + trace!("{} key-value pairs removed", removed); + self.db.apply_batch(batch) + } + fn extend, Box<[u8]>)>>( &mut self, other: T, diff --git a/src/lib.rs b/src/lib.rs index 8537168..223da5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,7 +20,7 @@ //! // HashMap as backend kv database and NoCacheHasher as key hasher. //! type ZkTrie = trie::ZkTrie; //! -//! let mut trie = ZkTrie::new(HashMapDb::new(), NoCacheHasher); +//! let mut trie = ZkTrie::new(HashMapDb::default(), NoCacheHasher); //! // or this is default mode //! // let mut trie = ZkTrie::default(); //! diff --git a/src/trie/zktrie/imp.rs b/src/trie/zktrie/imp.rs index 0a1c84f..ef7e97a 100644 --- a/src/trie/zktrie/imp.rs +++ b/src/trie/zktrie/imp.rs @@ -135,6 +135,10 @@ impl> ZkTrie { /// Garbage collect the trie pub fn gc(&mut self) -> Result<(), H, Db> { + if !self.db.gc_enabled() { + warn!("garbage collection is disabled"); + return Ok(()); + } let is_dirty = self.is_dirty(); let mut removed = 0; self.gc_nodes @@ -162,6 +166,42 @@ impl> ZkTrie { Ok(()) } + /// Run full garbage collection + pub fn full_gc(&mut self) -> Result<(), H, Db> { + if self.is_dirty() { + warn!("dirty nodes found, commit before run full_gc"); + return Ok(()); + } + let gc_enabled = self.db.gc_enabled(); + self.db.set_gc_enabled(true); + + // traverse the trie and collect all nodes + let mut nodes = HashSet::new(); + for node in self.iter() { + let node = node?; + nodes.insert( + *node + .get_or_calculate_node_hash() + .map_err(ZkTrieError::Hash)?, + ); + } + + self.db + .retain(|k, _| nodes.contains(k)) + .map_err(ZkTrieError::Db)?; + self.db.set_gc_enabled(gc_enabled); + + Ok(()) + } + + /// Get an iterator of the trie + pub fn iter(&self) -> ZkTrieIterator { + ZkTrieIterator { + trie: self, + stack: vec![self.root.clone()], + } + } + /// Get a node from the trie by node hash #[instrument(level = "trace", skip(self, node_hash), ret)] pub fn get_node_by_hash(&self, node_hash: impl Into) -> Result, H, Db> { @@ -503,6 +543,35 @@ impl> ZkTrie { } } +impl<'a, H: HashScheme, Db: KVDatabase, K: KeyHasher> Debug for ZkTrieIterator<'a, H, Db, K> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ZkTrieIterator") + .field("trie", &self.trie) + .finish() + } +} + +impl<'a, H: HashScheme, Db: KVDatabase, K: KeyHasher> Iterator for ZkTrieIterator<'a, H, Db, K> { + type Item = Result, H, Db>; + + fn next(&mut self) -> Option { + if let Some(node_hash) = self.stack.pop() { + return match self.trie.get_node_by_hash(node_hash) { + Ok(node) => { + if node.is_branch() { + let branch = node.as_branch().expect("infalible"); + self.stack.push(branch.child_left().clone()); + self.stack.push(branch.child_right().clone()); + } + Some(Ok(node)) + } + Err(e) => Some(Err(e)), + }; + } + None + } +} + #[inline(always)] fn get_path(node_key: &ZkHash, level: usize) -> bool { node_key.as_slice()[HASH_SIZE - level / 8 - 1] & (1 << (level % 8)) != 0 diff --git a/src/trie/zktrie/mod.rs b/src/trie/zktrie/mod.rs index 7581015..4369067 100644 --- a/src/trie/zktrie/mod.rs +++ b/src/trie/zktrie/mod.rs @@ -26,6 +26,12 @@ pub struct ZkTrie { _hash_scheme: std::marker::PhantomData, } +/// An iterator over the zkTrie. +pub struct ZkTrieIterator<'a, H, Db, K> { + trie: &'a ZkTrie, + stack: Vec, +} + /// Errors that can occur when using a zkTrie. #[derive(Debug, thiserror::Error)] pub enum ZkTrieError { diff --git a/src/trie/zktrie/tests.rs b/src/trie/zktrie/tests.rs index a8df169..19a4760 100644 --- a/src/trie/zktrie/tests.rs +++ b/src/trie/zktrie/tests.rs @@ -3,6 +3,7 @@ use crate::hash::poseidon::tests::gen_random_bytes; use rand::random; use rand::seq::SliceRandom; use std::fmt::Display; +use std::hash::Hash; use zktrie::HashField; use zktrie_rust::{db::SimpleDb, hash::AsHash, types::TrieHashScheme}; @@ -66,6 +67,14 @@ fn test_random() { trie.commit().unwrap(); } + trie.full_gc().unwrap(); + + for (k, _) in keys.iter() { + let node_key = >::hash(&NoCacheHasher, k).unwrap(); + // full gc didn't delete anything unexpected + trie.get_node_by_key(&node_key).unwrap(); + } + assert_eq!(old_trie.root().as_ref(), trie.root.unwrap_ref().as_slice()); for (k, old_key) in keys.choose_multiple(&mut rand::thread_rng(), 10) { @@ -82,7 +91,7 @@ fn test_random() { // println!("New:"); // println!("{}", trie); - trie.gc().unwrap(); + trie.full_gc().unwrap(); assert_eq!(old_trie.root().as_ref(), trie.root.unwrap_ref().as_slice()); }