Skip to content

Commit

Permalink
add gc to trie
Browse files Browse the repository at this point in the history
  • Loading branch information
lightsing committed Sep 13, 2024
1 parent 1f7cf66 commit 29d972d
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 38 deletions.
32 changes: 20 additions & 12 deletions src/db/kv/btree_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,6 @@ impl BTreeMapDb {
Self { gc_enabled, db }
}

/// Enable or disable garbage collection.
#[inline]
pub fn set_gc_enabled(&mut self, gc_enabled: bool) {
self.gc_enabled = gc_enabled;
}

/// Check if garbage collection is enabled.
#[inline]
pub fn is_gc_enabled(&self) -> bool {
self.gc_enabled
}

/// Get the inner `BTreeMap`.
pub fn inner(&self) -> &BTreeMap<Box<[u8]>, Box<[u8]>> {
&self.db
Expand Down Expand Up @@ -77,6 +65,10 @@ impl KVDatabase for BTreeMapDb {
Ok(self.db.get(k))
}

fn set_gc_enabled(&mut self, gc_enabled: bool) {
self.gc_enabled = gc_enabled;
}

fn gc_enabled(&self) -> bool {
self.gc_enabled
}
Expand All @@ -90,6 +82,22 @@ impl KVDatabase for BTreeMapDb {
Ok(())
}

fn retain<F>(&mut self, mut f: F) -> Result<(), Self::Error>
where
F: FnMut(&[u8], &[u8]) -> bool,
{
let mut removed = 0;
self.db.retain(|k, v| {
let keep = f(k, v);
if !keep {
removed += 1;
}
keep
});
trace!("{} key-value pairs removed", removed);
Ok(())
}

fn extend<T: IntoIterator<Item = (Box<[u8]>, Box<[u8]>)>>(
&mut self,
other: T,
Expand Down
32 changes: 20 additions & 12 deletions src/db/kv/hash_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,6 @@ impl HashMapDb {
Self { gc_enabled, db }
}

/// Enable or disable garbage collection.
#[inline]
pub fn set_gc_enabled(&mut self, gc_enabled: bool) {
self.gc_enabled = gc_enabled;
}

/// Check if garbage collection is enabled.
#[inline]
pub fn is_gc_enabled(&self) -> bool {
self.gc_enabled
}

/// Get the inner [`HashMap`](std::collections::HashMap).
pub fn inner(&self) -> &HashMap<Box<[u8]>, Box<[u8]>> {
&self.db
Expand Down Expand Up @@ -75,6 +63,10 @@ impl KVDatabase for HashMapDb {
Ok(self.db.get(k))
}

fn set_gc_enabled(&mut self, gc_enabled: bool) {
self.gc_enabled = gc_enabled;
}

fn gc_enabled(&self) -> bool {
self.gc_enabled
}
Expand All @@ -88,6 +80,22 @@ impl KVDatabase for HashMapDb {
Ok(())
}

fn retain<F>(&mut self, mut f: F) -> Result<(), Self::Error>
where
F: FnMut(&[u8], &[u8]) -> bool,
{
let mut removed = 0;
self.db.retain(|k, v| {
let keep = f(k, v);
if !keep {
removed += 1;
}
keep
});
trace!("{} key-value pairs removed", removed);
Ok(())
}

fn extend<T: IntoIterator<Item = (Box<[u8]>, Box<[u8]>)>>(
&mut self,
other: T,
Expand Down
16 changes: 16 additions & 0 deletions src/db/kv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ pub trait KVDatabase {
/// Returns `Ok(None)` if the key is not present.
fn get(&self, k: &[u8]) -> Result<Option<impl AsRef<[u8]>>, Self::Error>;

/// Check if garbage collection is enabled.
fn set_gc_enabled(&mut self, _gc_enabled: bool) {}

/// Check if garbage collection is enabled.
fn gc_enabled(&self) -> bool {
false
Expand All @@ -60,6 +63,19 @@ pub trait KVDatabase {
Ok(())
}

/// Retain only the key-value pairs that satisfy the predicate.
///
/// # Note
///
/// Same as [`KVDatabase::remove`], this method is best-effort and should not be relied on
/// to determine if the key was present or not.
fn retain<F>(&mut self, _f: F) -> Result<(), Self::Error>
where
F: FnMut(&[u8], &[u8]) -> bool,
{
Ok(())
}

/// Extend the database with the key-value pairs from the iterator.
fn extend<T: IntoIterator<Item = (Box<[u8]>, Box<[u8]>)>>(
&mut self,
Expand Down
33 changes: 21 additions & 12 deletions src/db/kv/sled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,6 @@ impl SledDb {
Self { gc_enabled, db }
}

/// Enable or disable garbage collection.
#[inline]
pub fn set_gc_enabled(&mut self, gc_enabled: bool) {
self.gc_enabled = gc_enabled;
}

/// Check if garbage collection is enabled.
#[inline]
pub fn is_gc_enabled(&self) -> bool {
self.gc_enabled
}

/// Get the inner [`sled::Tree`]
pub fn inner(&self) -> &sled::Tree {
&self.db
Expand Down Expand Up @@ -83,6 +71,10 @@ impl KVDatabase for SledDb {
self.db.get(k)
}

fn set_gc_enabled(&mut self, gc_enabled: bool) {
self.gc_enabled = gc_enabled;
}

fn gc_enabled(&self) -> bool {
self.gc_enabled
}
Expand All @@ -96,6 +88,23 @@ impl KVDatabase for SledDb {
Ok(())
}

fn retain<F>(&mut self, mut f: F) -> Result<(), Self::Error>
where
F: FnMut(&[u8], &[u8]) -> bool,
{
let mut removed = 0;
let mut batch = Batch::default();
for entry in self.db.iter() {
let (k, v) = entry?;
if !f(k.as_ref(), v.as_ref()) {
batch.remove(k);
removed += 1;
}
}
trace!("{} key-value pairs removed", removed);
self.db.apply_batch(batch)
}

fn extend<T: IntoIterator<Item = (Box<[u8]>, Box<[u8]>)>>(
&mut self,
other: T,
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
//! // HashMap as backend kv database and NoCacheHasher as key hasher.
//! type ZkTrie = trie::ZkTrie<Poseidon, HashMapDb, NoCacheHasher>;
//!
//! let mut trie = ZkTrie::new(HashMapDb::new(), NoCacheHasher);
//! let mut trie = ZkTrie::new(HashMapDb::default(), NoCacheHasher);
//! // or this is default mode
//! // let mut trie = ZkTrie::default();
//!
Expand Down
69 changes: 69 additions & 0 deletions src/trie/zktrie/imp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ impl<H: HashScheme, Db: KVDatabase, K: KeyHasher<H>> ZkTrie<H, Db, K> {

/// Garbage collect the trie
pub fn gc(&mut self) -> Result<(), H, Db> {
if !self.db.gc_enabled() {
warn!("garbage collection is disabled");
return Ok(());
}
let is_dirty = self.is_dirty();
let mut removed = 0;
self.gc_nodes
Expand Down Expand Up @@ -162,6 +166,42 @@ impl<H: HashScheme, Db: KVDatabase, K: KeyHasher<H>> ZkTrie<H, Db, K> {
Ok(())
}

/// Run full garbage collection
pub fn full_gc(&mut self) -> Result<(), H, Db> {
if self.is_dirty() {
warn!("dirty nodes found, commit before run full_gc");
return Ok(());
}
let gc_enabled = self.db.gc_enabled();
self.db.set_gc_enabled(true);

// traverse the trie and collect all nodes
let mut nodes = HashSet::new();
for node in self.iter() {
let node = node?;
nodes.insert(
*node
.get_or_calculate_node_hash()
.map_err(ZkTrieError::Hash)?,
);
}

self.db
.retain(|k, _| nodes.contains(k))
.map_err(ZkTrieError::Db)?;
self.db.set_gc_enabled(gc_enabled);

Ok(())
}

/// Get an iterator of the trie
pub fn iter(&self) -> ZkTrieIterator<H, Db, K> {
ZkTrieIterator {
trie: self,
stack: vec![self.root.clone()],
}
}

/// Get a node from the trie by node hash
#[instrument(level = "trace", skip(self, node_hash), ret)]
pub fn get_node_by_hash(&self, node_hash: impl Into<LazyNodeHash>) -> Result<Node<H>, H, Db> {
Expand Down Expand Up @@ -503,6 +543,35 @@ impl<H: HashScheme, Db: KVDatabase, K: KeyHasher<H>> ZkTrie<H, Db, K> {
}
}

impl<'a, H: HashScheme, Db: KVDatabase, K: KeyHasher<H>> Debug for ZkTrieIterator<'a, H, Db, K> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ZkTrieIterator")
.field("trie", &self.trie)
.finish()
}
}

impl<'a, H: HashScheme, Db: KVDatabase, K: KeyHasher<H>> Iterator for ZkTrieIterator<'a, H, Db, K> {
type Item = Result<Node<H>, H, Db>;

fn next(&mut self) -> Option<Self::Item> {
if let Some(node_hash) = self.stack.pop() {
return match self.trie.get_node_by_hash(node_hash) {
Ok(node) => {
if node.is_branch() {
let branch = node.as_branch().expect("infalible");
self.stack.push(branch.child_left().clone());
self.stack.push(branch.child_right().clone());
}
Some(Ok(node))
}
Err(e) => Some(Err(e)),
};
}
None
}
}

#[inline(always)]
fn get_path(node_key: &ZkHash, level: usize) -> bool {
node_key.as_slice()[HASH_SIZE - level / 8 - 1] & (1 << (level % 8)) != 0
Expand Down
6 changes: 6 additions & 0 deletions src/trie/zktrie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ pub struct ZkTrie<H = Poseidon, Db = HashMapDb, K = NoCacheHasher> {
_hash_scheme: std::marker::PhantomData<H>,
}

/// An iterator over the zkTrie.
pub struct ZkTrieIterator<'a, H, Db, K> {
trie: &'a ZkTrie<H, Db, K>,
stack: Vec<LazyNodeHash>,
}

/// Errors that can occur when using a zkTrie.
#[derive(Debug, thiserror::Error)]
pub enum ZkTrieError<HashErr, DbErr> {
Expand Down
11 changes: 10 additions & 1 deletion src/trie/zktrie/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::hash::poseidon::tests::gen_random_bytes;
use rand::random;
use rand::seq::SliceRandom;
use std::fmt::Display;
use std::hash::Hash;
use zktrie::HashField;
use zktrie_rust::{db::SimpleDb, hash::AsHash, types::TrieHashScheme};

Expand Down Expand Up @@ -66,6 +67,14 @@ fn test_random() {
trie.commit().unwrap();
}

trie.full_gc().unwrap();

for (k, _) in keys.iter() {
let node_key = <NoCacheHasher as KeyHasher<Poseidon>>::hash(&NoCacheHasher, k).unwrap();
// full gc didn't delete anything unexpected
trie.get_node_by_key(&node_key).unwrap();
}

assert_eq!(old_trie.root().as_ref(), trie.root.unwrap_ref().as_slice());

for (k, old_key) in keys.choose_multiple(&mut rand::thread_rng(), 10) {
Expand All @@ -82,7 +91,7 @@ fn test_random() {
// println!("New:");
// println!("{}", trie);

trie.gc().unwrap();
trie.full_gc().unwrap();

assert_eq!(old_trie.root().as_ref(), trie.root.unwrap_ref().as_slice());
}
Expand Down

0 comments on commit 29d972d

Please sign in to comment.