Skip to content

Commit

Permalink
feat: performance improvement and benchmark (#5)
Browse files Browse the repository at this point in the history
* add benchmark

* clippy

* wrap arc

* fix
  • Loading branch information
lightsing authored Sep 20, 2024
1 parent 74b91ce commit d749dca
Show file tree
Hide file tree
Showing 10 changed files with 667 additions and 50 deletions.
334 changes: 334 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 10 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@ default-features = false
features = ["scroll", "scroll-poseidon-codehash"]
optional = true


[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
ctor = "0.2"
rand = "0.8"
rand = { version = "0.8", features = ["small_rng"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
zktrie = { git = "https://github.com/scroll-tech/zktrie.git", branch = "main", features = ["rs_zktrie"] }
zktrie_rust = { git = "https://github.com/scroll-tech/zktrie.git", branch = "main" }

[[bench]]
name = "node"
harness = false

[[bench]]
name = "trie"
harness = false

[features]
default = ["bn254", "hashbrown"]

Expand Down
88 changes: 88 additions & 0 deletions benches/node.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#![allow(missing_docs)]
use criterion::{criterion_group, criterion_main, Criterion};
use poseidon_bn254::{hash_with_domain, Fr, PrimeField};
use rand::prelude::*;
use zktrie::HashField;
use zktrie_ng::trie::NodeType;
use zktrie_ng::{
hash::{poseidon::Poseidon, HashScheme},
trie::Node,
};
use zktrie_rust::hash::AsHash;

type OldNode = zktrie_rust::types::Node<AsHash<HashField>>;

fn bench_parse_node_inner(c: &mut Criterion, name: &str, node_bytes: Vec<u8>) {
let mut group = c.benchmark_group(name);
group.bench_with_input("zktrie-ng", &node_bytes, |b, node_bytes| {
b.iter(|| {
let node = Node::<Poseidon>::try_from(node_bytes.as_slice()).unwrap();
*node.get_or_calculate_node_hash().unwrap()
});
});
group.bench_with_input("zktrie", &node_bytes, |b, node_bytes| {
b.iter(|| {
OldNode::new_node_from_bytes(node_bytes)
.unwrap()
.calc_node_hash()
.unwrap()
.node_hash()
.unwrap()
});
});
group.finish();
}

fn bench_parse_node(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(42);

let account_leaf = {
let key: [u8; 20] = rng.gen();
let values: [[u8; 32]; 5] = rng.gen();
Node::<Poseidon>::new_leaf(
Poseidon::hash_bytes(&key).unwrap(),
values.to_vec(),
0b11111,
None,
)
.unwrap()
};
bench_parse_node_inner(c, "Parse Account Node", account_leaf.canonical_value(false));

let storage_leaf = {
let key: [u8; 32] = rng.gen();
let values: [[u8; 32]; 1] = rng.gen();
Node::<Poseidon>::new_leaf(
Poseidon::hash_bytes(&key).unwrap(),
values.to_vec(),
0b1,
None,
)
.unwrap()
};

bench_parse_node_inner(c, "Parse Storage Node", storage_leaf.canonical_value(false));

let branch_node = Node::<Poseidon>::new_branch(
NodeType::BranchLTRT,
*account_leaf.get_or_calculate_node_hash().unwrap(),
*storage_leaf.get_or_calculate_node_hash().unwrap(),
);

bench_parse_node_inner(c, "Parse Branch Node", branch_node.canonical_value(false));
}

fn poseidon_hash_scheme(a: &[u8; 32], b: &[u8; 32], domain: &[u8; 32]) -> Option<[u8; 32]> {
let a = Fr::from_repr_vartime(*a)?;
let b = Fr::from_repr_vartime(*b)?;
let domain = Fr::from_repr_vartime(*domain)?;
Some(hash_with_domain(&[a, b], domain).to_repr())
}

fn criterion_benchmark(c: &mut Criterion) {
zktrie::init_hash_scheme_simple(poseidon_hash_scheme);
bench_parse_node(c);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
148 changes: 148 additions & 0 deletions benches/trie.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#![allow(missing_docs)]

use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use poseidon_bn254::{hash_with_domain, Fr, PrimeField};
use rand::prelude::*;
use std::hint::black_box;
use zktrie::HashField;
use zktrie_ng::db::HashMapDb;
use zktrie_ng::hash::key_hasher::NoCacheHasher;
use zktrie_ng::{
hash::{poseidon::Poseidon, HashScheme},
trie::ZkTrie,
};
use zktrie_rust::{db::SimpleDb, hash::AsHash, types::TrieHashScheme};

type NodeOld = zktrie_rust::types::Node<AsHash<HashField>>;
type TrieOld =
zktrie_rust::raw::ZkTrieImpl<AsHash<HashField>, SimpleDb, { Poseidon::TRIE_MAX_LEVELS }>;

fn bench_trie_update(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(42);
let mut group = c.benchmark_group("Trie Update");

let k: [u8; 20] = rng.gen();
let values: [[u8; 32]; 5] = rng.gen();
let values = values.to_vec();

group.bench_with_input("zktrie-ng", &(k, values.clone()), |b, (k, values)| {
b.iter_batched(
|| {
let trie = ZkTrie::default();
(trie, k, values.clone())
},
|(mut trie, k, values)| {
trie.raw_update(k, values, black_box(0b11111)).unwrap();
},
BatchSize::SmallInput,
);
});

group.bench_with_input("zktrie", &(k, values), |b, (k, values)| {
b.iter_batched(
|| {
let trie = TrieOld::new_zktrie_impl(SimpleDb::new()).unwrap();
(trie, k, values.clone())
},
|(mut trie, k, values)| {
let key = NodeOld::hash_bytes(k).unwrap();
trie.try_update(&key, black_box(0b11111), values).unwrap();
},
BatchSize::SmallInput,
);
});
}

fn bench_trie_operation(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(42);

let mut trie = ZkTrie::default();
let mut trie_old = TrieOld::new_zktrie_impl(SimpleDb::new()).unwrap();

let mut keys = vec![];

for _ in 0..100 {
let k: [u8; 20] = rng.gen();
let values: [[u8; 32]; 5] = rng.gen();
let values = values.to_vec();

trie.raw_update(k, values.clone(), 0b11111).unwrap();
let key = NodeOld::hash_bytes(&k).unwrap();
trie_old.try_update(&key, 0b11111, values).unwrap();
keys.push((Poseidon::hash_bytes(&k).unwrap(), key));
}

trie.commit().unwrap();
trie_old.prepare_root().unwrap();
trie_old.commit().unwrap();

let mut group = c.benchmark_group("Trie Get");
keys.shuffle(&mut rng);
group.bench_with_input("zktrie-ng", &(&trie, &keys[..10]), |b, (trie, keys)| {
b.iter(|| {
keys.iter()
.map(|(key, _)| trie.get_node_by_key(key).unwrap())
.collect::<Vec<_>>()
});
});
group.bench_with_input("zktrie", &(&trie_old, &keys[..10]), |b, (trie, keys)| {
b.iter(|| {
keys.iter()
.map(|(_, key)| trie.try_get(key).unwrap())
.collect::<Vec<_>>()
});
});
group.finish();

let mut group = c.benchmark_group("Trie Delete");
keys.shuffle(&mut rng);
group.bench_with_input("zktrie-ng", &(&trie, &keys[..10]), |b, (trie, keys)| {
b.iter_batched(
|| {
let root = *trie.root().unwrap_ref();
let db = HashMapDb::from_map(false, trie.db().inner().clone());
let trie = ZkTrie::<Poseidon>::new_with_root(db, NoCacheHasher, root).unwrap();
(trie, keys)
},
|(mut trie, keys)| {
for (key, _) in keys.iter() {
trie.delete_by_node_key(*key).unwrap();
}
},
BatchSize::SmallInput,
)
});
group.bench_with_input("zktrie", &(&trie_old, &keys[..10]), |b, (trie, keys)| {
b.iter_batched(
|| {
let root = trie.root();
let db = trie.get_db().clone();
let trie = TrieOld::new_zktrie_impl_with_root(db, root).unwrap();
(trie, keys)
},
|(mut trie, keys)| {
for (_, key) in keys.iter() {
trie.try_delete(key).unwrap();
}
},
BatchSize::SmallInput,
)
});
group.finish();
}

fn poseidon_hash_scheme(a: &[u8; 32], b: &[u8; 32], domain: &[u8; 32]) -> Option<[u8; 32]> {
let a = Fr::from_repr_vartime(*a)?;
let b = Fr::from_repr_vartime(*b)?;
let domain = Fr::from_repr_vartime(*domain)?;
Some(hash_with_domain(&[a, b], domain).to_repr())
}

fn criterion_benchmark(c: &mut Criterion) {
zktrie::init_hash_scheme_simple(poseidon_hash_scheme);
bench_trie_update(c);
bench_trie_operation(c);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
//!
//! trie.raw_update(&[1u8; 32], vec![[1u8; 32]], 1).unwrap();
//!
//! let values: [[u8; 32]; 1] = trie.get(&[1u8; 32]).unwrap();
//! let values: [[u8; 32]; 1] = trie.get(&[1u8; 32]).unwrap().unwrap();
//! assert_eq!(values[0], [1u8; 32]);
//!
//! // zkTrie is lazy, won't update the backend database until `commit` is called.
Expand Down
4 changes: 2 additions & 2 deletions src/scroll_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
//!
//! let trie_account = Account::from_revm_account_with_storage_root(account, storage_root);
//!
//! trie.update(address.as_ref(), trie_account).unwrap();
//! trie.update(address, trie_account).unwrap();
//!
//! let account: Account = trie.get(address.as_ref()).unwrap();
//! let account: Account = trie.get(address).unwrap().unwrap();
//!
//! assert_eq!(trie_account, account);
//! ```
Expand Down
43 changes: 28 additions & 15 deletions src/trie/node/imp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl LeafNode {

/// Get the value preimages stored in a leaf node.
#[inline]
pub fn into_value_preimages(self) -> Arc<[[u8; 32]]> {
pub fn into_value_preimages(self) -> Box<[[u8; 32]]> {
self.value_preimages
}

Expand All @@ -120,8 +120,9 @@ impl LeafNode {

/// Get the `value_hash`
#[inline]
pub fn value_hash<H: HashScheme>(&self) -> &ZkHash {
&self.value_hash
pub fn get_or_calc_value_hash<H: HashScheme>(&self) -> Result<&ZkHash, H::Error> {
self.value_hash
.get_or_try_init(|| H::hash_bytes_array(&self.value_preimages, self.compress_flags))
}
}

Expand Down Expand Up @@ -166,6 +167,12 @@ impl BranchNode {
&self.child_right
}

/// Into the parts
#[inline]
pub fn as_parts(&self) -> (NodeType, &LazyNodeHash, &LazyNodeHash) {
(self.node_type, &self.child_left, &self.child_right)
}

/// Into the parts
#[inline]
pub fn into_parts(self) -> (NodeType, LazyNodeHash, LazyNodeHash) {
Expand Down Expand Up @@ -203,11 +210,11 @@ impl<H: HashScheme> Node<H> {
) -> Self {
Node {
node_hash: Arc::new(OnceCell::new()),
data: NodeKind::Branch(BranchNode {
data: NodeKind::Branch(Arc::new(BranchNode {
node_type,
child_left: child_left.into(),
child_right: child_right.into(),
}),
})),
_hash_scheme: std::marker::PhantomData,
}
}
Expand All @@ -219,17 +226,17 @@ impl<H: HashScheme> Node<H> {
compress_flags: u32,
node_key_preimage: Option<[u8; 32]>,
) -> Result<Self, H::Error> {
let value_hash = H::hash_bytes_array(&value_preimages, compress_flags)?;
let node_hash = H::hash(Leaf as u64, [node_key, value_hash])?;
// let value_hash = H::hash_bytes_array(&value_preimages, compress_flags)?;
// let node_hash = H::hash(Leaf as u64, [node_key, value_hash])?;
Ok(Node {
node_hash: Arc::new(OnceCell::with_value(node_hash)),
data: NodeKind::Leaf(LeafNode {
node_hash: Arc::new(OnceCell::new()),
data: NodeKind::Leaf(Arc::new(LeafNode {
node_key,
node_key_preimage,
value_preimages: Arc::from(value_preimages.into_boxed_slice()),
value_preimages: value_preimages.into_boxed_slice(),
compress_flags,
value_hash,
}),
value_hash: OnceCell::new(),
})),
_hash_scheme: std::marker::PhantomData,
})
}
Expand All @@ -244,7 +251,13 @@ impl<H: HashScheme> Node<H> {
#[inline]
pub fn get_or_calculate_node_hash(&self) -> Result<&ZkHash, H::Error> {
match self.data {
NodeKind::Empty | NodeKind::Leaf(_) => Ok(unsafe { self.node_hash.get_unchecked() }),
NodeKind::Empty => Ok(unsafe { self.node_hash.get_unchecked() }),
NodeKind::Leaf(ref leaf) => {
let value_hash = leaf.get_or_calc_value_hash::<H>()?;
Ok(self
.node_hash
.get_or_try_init(|| H::hash(Leaf as u64, [leaf.node_key, *value_hash]))?)
}
NodeKind::Branch(ref branch) => {
let left = branch.child_left.unwrap_ref();
let right = branch.child_right.unwrap_ref();
Expand Down Expand Up @@ -319,7 +332,7 @@ impl<H: HashScheme> Node<H> {

/// Try into a leaf node.
#[inline]
pub fn into_leaf(self) -> Option<LeafNode> {
pub fn into_leaf(self) -> Option<Arc<LeafNode>> {
match self.data {
NodeKind::Leaf(leaf) => Some(leaf),
_ => None,
Expand All @@ -328,7 +341,7 @@ impl<H: HashScheme> Node<H> {

/// Try into a branch node.
#[inline]
pub fn into_branch(self) -> Option<BranchNode> {
pub fn into_branch(self) -> Option<Arc<BranchNode>> {
match self.data {
NodeKind::Branch(branch) => Some(branch),
_ => None,
Expand Down
Loading

0 comments on commit d749dca

Please sign in to comment.