From ece0c6ee9051356aee6b78d17ed8a6190bc479ac Mon Sep 17 00:00:00 2001 From: Nazar Mokrynskyi Date: Sat, 10 Feb 2024 00:29:52 +0200 Subject: [PATCH] Refactor `p[1|2]_affines::mult()` to use `rayon` --- bindings/rust/src/pippenger.rs | 37 +++++++++++++--------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/bindings/rust/src/pippenger.rs b/bindings/rust/src/pippenger.rs index c7fb798c..e9adb370 100644 --- a/bindings/rust/src/pippenger.rs +++ b/bindings/rust/src/pippenger.rs @@ -128,12 +128,10 @@ macro_rules! pippenger_mult_impl { panic!("scalars length mismatch"); } - let pool = mt::da_pool(); - let ncpus = pool.max_count(); + let ncpus = rayon::current_num_threads(); if ncpus < 2 || npoints < 32 { - let p: [*const $point_affine; 2] = - [&self.points[0], ptr::null()]; - let s: [*const u8; 2] = [&scalars[0], ptr::null()]; + let p = [self.points.as_ptr(), ptr::null()]; + let s = [scalars.as_ptr(), ptr::null()]; let mut scratch = vec![0u64; unsafe { $scratch_sizeof(npoints) / 8 }]; @@ -141,11 +139,11 @@ macro_rules! pippenger_mult_impl { unsafe { $multi_scalar_mult( &mut ret, - &p[0], + p.as_ptr(), npoints, - &s[0], + s.as_ptr(), nbits, - &mut scratch[0], + scratch.as_mut_ptr(), ); } return ret; @@ -184,22 +182,15 @@ macro_rules! pippenger_mult_impl { let points = &self.points[..]; let sz = unsafe { $scratch_sizeof(0) / 8 }; - let mut row_sync: Vec = Vec::with_capacity(ny); - row_sync.resize_with(ny, Default::default); - let row_sync = Arc::new(row_sync); - let counter = Arc::new(AtomicUsize::new(0)); + let mut row_sync = Vec::with_capacity(ny); + row_sync.resize_with(ny, AtomicUsize::default); + let counter = AtomicUsize::new(0); let (tx, rx) = channel(); - let n_workers = core::cmp::min(ncpus, total); - for _ in 0..n_workers { - let tx = tx.clone(); - let counter = counter.clone(); - let row_sync = row_sync.clone(); - - pool.joined_execute(move || { + rayon::scope(|scope| { + scope.spawn_broadcast(move |_scope, _ctx| { let mut scratch = vec![0u64; sz << (window - 1)]; - let mut p: [*const $point_affine; 2] = - [ptr::null(), ptr::null()]; - let mut s: [*const u8; 2] = [ptr::null(), ptr::null()]; + let mut p = [ptr::null(), ptr::null()]; + let mut s = [ptr::null(), ptr::null()]; loop { let work = counter.fetch_add(1, Ordering::Relaxed); @@ -231,7 +222,7 @@ macro_rules! pippenger_mult_impl { } } }); - } + }); let mut ret = <$point>::default(); let mut rows = vec![false; ny];