Skip to content

Commit

Permalink
Refactor p[1|2]_affines::add() to use rayon, remove threadpool dependency and unnecessary/unused `no-threads` feature
Browse files Browse the repository at this point in the history
  • Loading branch information
nazar-pc committed Feb 9, 2024
1 parent ece0c6e commit 20eb420
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 103 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ jobs:
echo '--- test portable'
echo
cargo test --release --features=portable
echo '--- test no-threads'
echo
cargo test --release --features=no-threads
echo '--- test serde-secret'
echo
cargo test --release --features=serde-secret
Expand Down
6 changes: 0 additions & 6 deletions bindings/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ portable = []
# Enable ADX even if the host CPU doesn't support it.
# Binary can be executed on Broadwell+ and Ryzen+ systems.
force-adx = []
# Suppress multi-threading.
# Engaged on wasm32 target architecture automatically.
no-threads = []
# Add support for serializing SecretKey, not suitable for production.
serde-secret = ["serde"]

Expand All @@ -49,9 +46,6 @@ zeroize = { version = "^1.1", features = ["zeroize_derive"] }
rayon = "1.8.1"
serde = { version = "1.0.152", optional = true }

[target.'cfg(not(any(target_arch="wasm32", target_os="none", target_os="unknown", target_os="uefi")))'.dependencies]
threadpool = "^1.8.1"

[dev-dependencies]
rand = "0.8"
rand_chacha = "0.3"
Expand Down
3 changes: 0 additions & 3 deletions bindings/rust/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ fn main() {

if !target_no_std {
println!("cargo:rustc-cfg=feature=\"std\"");
if target_arch.eq("wasm32") || target_os.eq("unknown") {
println!("cargo:rustc-cfg=feature=\"no-threads\"");
}
}
println!("cargo:rerun-if-env-changed=BLST_TEST_NO_STD");

Expand Down
69 changes: 1 addition & 68 deletions bindings/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,78 +19,11 @@ use rayon::prelude::*;
use zeroize::Zeroize;

#[cfg(feature = "std")]
use std::sync::{atomic::*, mpsc::channel, Arc, Mutex};
use std::sync::{atomic::*, mpsc::channel, Mutex};

#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};

/// Extension trait for pool-like types that can accept a job whose closure
/// is not required to be `'static`.
///
/// The `'any` lifetime only bounds the closure type `F`; nothing in this
/// signature ties it to the pool, so how long the job may run relative to
/// the data it borrows is left to each implementation and its callers.
trait ThreadPoolExt {
/// Submits `job` for execution. The closure must be `Send`, is consumed
/// (`FnOnce`), and produces no return value.
fn joined_execute<'any, F>(&self, job: F)
where
F: FnOnce() + Send + 'any;
}

// Threaded flavour of `mt`: compiled only when the `std` feature is on and
// the `no-threads` feature is off.
#[cfg(all(not(feature = "no-threads"), feature = "std"))]
mod mt {
use super::*;
use core::mem::transmute;
use std::sync::{Mutex, Once};
use threadpool::ThreadPool;

/// Returns a clone of a lazily created, process-wide `ThreadPool`.
///
/// The pool is built with `ThreadPool::default()` on the first call and
/// stored behind a raw pointer in a `static mut`; nothing in this function
/// ever frees it.
pub fn da_pool() -> ThreadPool {
// `Once` guarantees the initialisation closure below runs at most once.
static INIT: Once = Once::new();
// Null until `INIT` has run; afterwards points at the boxed mutex.
static mut POOL: *const Mutex<ThreadPool> =
0 as *const Mutex<ThreadPool>;

INIT.call_once(|| {
let pool = Mutex::new(ThreadPool::default());
// The `Box` is transmuted straight into the raw pointer, so the
// allocation is intentionally never dropped.
unsafe { POOL = transmute(Box::new(pool)) };
});
// Lock only long enough to clone the pool out of the mutex.
// NOTE(review): presumably the clone is a handle to the same underlying
// pool (threadpool crate semantics) — confirm.
unsafe { (*POOL).lock().unwrap().clone() }
}

// Boxed, sendable one-shot job whose captured borrows live for `'any`.
type Thunk<'any> = Box<dyn FnOnce() + Send + 'any>;

impl ThreadPoolExt for ThreadPool {
/// Hands `job` to `ThreadPool::execute` after erasing its `'scope`
/// lifetime to `'static`.
///
/// NOTE(review): nothing here waits for the job to finish; soundness
/// relies on the caller keeping everything `job` borrows alive until the
/// job has completed, as the original comment below states.
fn joined_execute<'scope, F>(&self, job: F)
where
F: FnOnce() + Send + 'scope,
{
// Bypass 'lifetime limitations by brute force. It works,
// because we explicitly join the threads...
self.execute(unsafe {
transmute::<Thunk<'scope>, Thunk<'static>>(Box::new(job))
})
}
}
}

// Single-threaded flavour of `mt`: compiled only when both the `no-threads`
// and `std` features are on. It exposes the same `da_pool`, `max_count` and
// `joined_execute` names as the threaded flavour.
#[cfg(all(feature = "no-threads", feature = "std"))]
mod mt {
use super::*;

/// Stand-in "pool" with no fields and no worker threads.
pub struct EmptyPool {}

/// Returns a fresh `EmptyPool`; there is no shared state to initialise.
pub fn da_pool() -> EmptyPool {
EmptyPool {}
}

impl EmptyPool {
/// Always reports a worker count of 1.
pub fn max_count(&self) -> usize {
1
}
}

impl ThreadPoolExt for EmptyPool {
/// Runs `job` immediately on the calling thread and returns once it has
/// finished, so no lifetime erasure is needed here.
fn joined_execute<'scope, F>(&self, job: F)
where
F: FnOnce() + Send + 'scope,
{
job()
}
}
}

include!("bindings.rs");

impl PartialEq for blst_p1 {
Expand Down
48 changes: 25 additions & 23 deletions bindings/rust/src/pippenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,29 +259,26 @@ macro_rules! pippenger_mult_impl {
pub fn add(&self) -> $point {
let npoints = self.points.len();

let pool = mt::da_pool();
let ncpus = pool.max_count();
let ncpus = rayon::current_num_threads();
if ncpus < 2 || npoints < 384 {
let p: [*const _; 2] = [&self.points[0], ptr::null()];
let p = [self.points.as_ptr(), ptr::null()];
let mut ret = <$point>::default();
unsafe { $add(&mut ret, &p[0], npoints) };
return ret;
}

let (tx, rx) = channel();
let counter = Arc::new(AtomicUsize::new(0));
let ret = Mutex::new(None::<$point>);
let counter = AtomicUsize::new(0);
let nchunks = (npoints + 255) / 256;
let chunk = npoints / nchunks + 1;

let n_workers = core::cmp::min(ncpus, nchunks);
for _ in 0..n_workers {
let tx = tx.clone();
let counter = counter.clone();

pool.joined_execute(move || {
rayon::scope(|scope| {
let ret = &ret;
scope.spawn_broadcast(move |_scope, _ctx| {
let mut processed = 0;
let mut acc = <$point>::default();
let mut chunk = chunk;
let mut p: [*const _; 2] = [ptr::null(), ptr::null()];
let mut p = [ptr::null(), ptr::null()];

loop {
let work =
Expand All @@ -298,19 +295,24 @@ macro_rules! pippenger_mult_impl {
$add(t.as_mut_ptr(), &p[0], chunk);
$add_or_double(&mut acc, &acc, t.as_ptr());
};
processed += 1;
}
tx.send(acc).expect("disaster");
});
}

let mut ret = rx.recv().unwrap();
for _ in 1..n_workers {
unsafe {
$add_or_double(&mut ret, &ret, &rx.recv().unwrap())
};
}
if processed > 0 {
let mut ret = ret.lock().unwrap();
match ret.as_mut() {
Some(ret) => {
unsafe { $add_or_double(ret, ret, &acc) };
}
None => {
ret.replace(acc);
}
}
}
})
});

ret
let mut ret = ret.lock().unwrap();
ret.take().unwrap()
}
}

Expand Down

0 comments on commit 20eb420

Please sign in to comment.