From 9004d7013afde40373f5064cdfa4f60e7c44d184 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:15:01 -0500 Subject: [PATCH] refactor: define auxiliary functions for the batching of sumcheck claims (Arecibo backport) (#273) * `snark.rs`: factor out batch evaluation Sumcheck (#106) * Factor batch eval * Comments * refactor: Remove nedless pass-by-value instances, turn on clippy (#122) * refactor: Remove nedless pass-by-value instances, turn on the corresponding clippy lint - Updated function parameters in `ppsnark.rs` and `snark.rs` from `Vec` to `&[G::Scalar]` (introduced in #106), - Modified the `prove` function in `ppsnark.rs` to also convert `T_row`, `W_row`, `T_col`, and `W_col` from `Vec` to slices (`&[G::Scalar]`). - Enhanced the `xclippy` alias in `.cargo/config` by adding `-Wclippy::checked_conversions`, `-Wclippy::needless_pass_by_value`, and `-Wclippy::unnecessary_mut_passed` and reorganizing its elements. * `PolyEval{Instance, Witness}::pad` Accept `Vec` rather that `&[T]` to avoid copies Co-authored-by: Francois Garillot --------- Co-authored-by: Adrian Hamelink --------- Co-authored-by: Adrian Hamelink Co-authored-by: Adrian Hamelink --- src/spartan/mod.rs | 23 ++-- src/spartan/ppsnark.rs | 22 ++-- src/spartan/snark.rs | 280 ++++++++++++++++++++++++----------------- 3 files changed, 188 insertions(+), 137 deletions(-) diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs index 8bec5131f..3e7c514e8 100644 --- a/src/spartan/mod.rs +++ b/src/spartan/mod.rs @@ -33,16 +33,13 @@ pub struct PolyEvalWitness { } impl PolyEvalWitness { - fn pad(W: &[PolyEvalWitness]) -> Vec> { + fn pad(mut W: Vec>) -> Vec> { // determine the maximum size if let Some(n) = W.iter().map(|w| w.p.len()).max() { - W.iter() - .map(|w| { - let mut p = vec![E::Scalar::ZERO; n]; - p[..w.p.len()].copy_from_slice(&w.p); - PolyEvalWitness { p } - }) - .collect() + W.iter_mut().for_each(|w| { + w.p.resize(n, E::Scalar::ZERO); + }); + W } else { Vec::new() } @@ -94,14 +91,14 @@ pub struct PolyEvalInstance { } impl PolyEvalInstance { - fn pad(U: &[PolyEvalInstance]) -> Vec> { + fn pad(U: Vec>) -> Vec> { // determine the maximum size if let Some(ell) = U.iter().map(|u| u.x.len()).max() { - U.iter() - .map(|u| { + U.into_iter() + .map(|mut u| { let mut x = vec![E::Scalar::ZERO; ell - u.x.len()]; - x.extend(u.x.clone()); - PolyEvalInstance { c: u.c, x, e: u.e } + x.append(&mut u.x); + PolyEvalInstance { x, ..u } }) .collect() } else { diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index 3b21608f0..44dc30e42 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -281,11 +281,11 @@ impl MemorySumcheckInstance { pub fn new( ck: &CommitmentKey, r: &E::Scalar, - T_row: Vec, - W_row: Vec, + T_row: &[E::Scalar], + W_row: &[E::Scalar], ts_row: Vec, - T_col: Vec, - W_col: Vec, + T_col: &[E::Scalar], + W_col: &[E::Scalar], ts_col: Vec, transcript: &mut E::TE, ) -> Result<(Self, [Commitment; 4], [Vec; 4]), NovaError> { @@ -362,8 +362,8 @@ impl MemorySumcheckInstance { ((t_plus_r_inv_row, w_plus_r_inv_row), (t_plus_r_row, w_plus_r_row)), ((t_plus_r_inv_col, w_plus_r_inv_col), (t_plus_r_col, w_plus_r_col)), ) = rayon::join( - || helper(&T_row, &W_row, &ts_row, r), - || helper(&T_col, &W_col, &ts_col, r), + || helper(T_row, W_row, &ts_row, r), + || helper(T_col, W_col, &ts_col, r), ); let t_plus_r_inv_row = t_plus_r_inv_row?; @@ -1068,7 +1068,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let u: PolyEvalInstance = PolyEvalInstance::batch(&comm_vec, &tau_coords, &eval_vec, &c); // we now need to prove three claims - // (1) 0 = \sum_x poly_tau(x) * (poly_Az(x) * poly_Bz(x) - poly_uCz_E(x)), and eval_Az_at_tau + r * eval_Az_at_tau + r^2 * eval_Cz_at_tau = (Az+r*Bz+r^2*Cz)(tau) + // (1) 0 = \sum_x poly_tau(x) * (poly_Az(x) * poly_Bz(x) - poly_uCz_E(x)), and eval_Az_at_tau + r * eval_Bz_at_tau + r^2 * eval_Cz_at_tau = (Az+r*Bz+r^2*Cz)(tau) // (2) eval_Az_at_tau + c * eval_Bz_at_tau + c^2 * eval_Cz_at_tau = \sum_y L_row(y) * (val_A(y) + c * val_B(y) + c^2 * val_C(y)) * L_col(y) // (3) L_row(i) = eq(tau, row(i)) and L_col(i) = z(col(i)) let gamma = transcript.squeeze(b"g")?; @@ -1139,11 +1139,11 @@ impl> RelaxedR1CSSNARKTrait for Relax MemorySumcheckInstance::new( ck, &r, - T_row, - W_row, + &T_row, + &W_row, pk.S_repr.ts_row.clone(), - T_col, - W_col, + &T_col, + &W_col, pk.S_repr.ts_col.clone(), &mut transcript, ) diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs index 4ef833bb2..104c5dfd9 100644 --- a/src/spartan/snark.rs +++ b/src/spartan/snark.rs @@ -287,69 +287,18 @@ impl> RelaxedR1CSSNARKTrait for Relax let (w_vec, u_vec): (Vec>, Vec>) = w_u_vec.into_iter().unzip(); - let w_vec_padded = PolyEvalWitness::pad(&w_vec); // pad the polynomials to be of the same size - let u_vec_padded = PolyEvalInstance::pad(&u_vec); // pad the evaluation points - - // generate a challenge - let rho = transcript.squeeze(b"r")?; - let num_claims = w_vec_padded.len(); - let powers_of_rho = powers::(&rho, num_claims); - let claim_batch_joint = u_vec_padded - .iter() - .zip(powers_of_rho.iter()) - .map(|(u, p)| u.e * p) - .sum(); - - let mut polys_left: Vec> = w_vec_padded - .iter() - .map(|w| MultilinearPolynomial::new(w.p.clone())) - .collect(); - let mut polys_right: Vec> = u_vec_padded - .iter() - .map(|u| MultilinearPolynomial::new(EqPolynomial::new(u.x.clone()).evals())) - .collect(); - - let num_rounds_z = u_vec_padded[0].x.len(); - let comb_func = |poly_A_comp: &E::Scalar, poly_B_comp: &E::Scalar| -> E::Scalar { - *poly_A_comp * *poly_B_comp - }; - let (sc_proof_batch, r_z, claims_batch) = SumcheckProof::prove_quad_batch( - &claim_batch_joint, - num_rounds_z, - &mut polys_left, - &mut polys_right, - &powers_of_rho, - comb_func, - &mut transcript, - )?; - - let (claims_batch_left, _): (Vec, Vec) = claims_batch; - - transcript.absorb(b"l", &claims_batch_left.as_slice()); - // we now combine evaluation claims at the same point rz into one - let gamma = transcript.squeeze(b"g")?; - let powers_of_gamma: Vec = powers::(&gamma, num_claims); - let comm_joint = u_vec_padded - .iter() - .zip(powers_of_gamma.iter()) - .map(|(u, g_i)| u.c * *g_i) - .fold(Commitment::::default(), |acc, item| acc + item); - let poly_joint = PolyEvalWitness::weighted_sum(&w_vec_padded, &powers_of_gamma); - let eval_joint = claims_batch_left - .iter() - .zip(powers_of_gamma.iter()) - .map(|(e, g_i)| *e * *g_i) - .sum(); + let (batched_u, batched_w, sc_proof_batch, claims_batch_left) = + batch_eval_prove(u_vec, w_vec, &mut transcript)?; let eval_arg = EE::prove( ck, &pk.pk_ee, &mut transcript, - &comm_joint, - &poly_joint.p, - &r_z, - &eval_joint, + &batched_u.c, + &batched_w.p, + &batched_u.x, + &batched_u.e, )?; Ok(RelaxedR1CSSNARK { @@ -484,70 +433,175 @@ impl> RelaxedR1CSSNARKTrait for Relax }, ]; - let u_vec_padded = PolyEvalInstance::pad(&u_vec); // pad the evaluation points - - // generate a challenge - let rho = transcript.squeeze(b"r")?; - let num_claims = u_vec.len(); - let powers_of_rho = powers::(&rho, num_claims); - let claim_batch_joint = u_vec - .iter() - .zip(powers_of_rho.iter()) - .map(|(u, p)| u.e * p) - .sum(); - - let num_rounds_z = u_vec_padded[0].x.len(); - let (claim_batch_final, r_z) = - self - .sc_proof_batch - .verify(claim_batch_joint, num_rounds_z, 2, &mut transcript)?; - - let claim_batch_final_expected = { - let poly_rz = EqPolynomial::new(r_z.clone()); - let evals = u_vec_padded - .iter() - .map(|u| poly_rz.evaluate(&u.x)) - .collect::>(); - - evals - .iter() - .zip(self.evals_batch.iter()) - .zip(powers_of_rho.iter()) - .map(|((e_i, p_i), rho_i)| *e_i * *p_i * rho_i) - .sum() - }; - - if claim_batch_final != claim_batch_final_expected { - return Err(NovaError::InvalidSumcheckProof); - } - - transcript.absorb(b"l", &self.evals_batch.as_slice()); - - // we now combine evaluation claims at the same point rz into one - let gamma = transcript.squeeze(b"g")?; - let powers_of_gamma: Vec = powers::(&gamma, num_claims); - let comm_joint = u_vec_padded - .iter() - .zip(powers_of_gamma.iter()) - .map(|(u, g_i)| u.c * *g_i) - .fold(Commitment::::default(), |acc, item| acc + item); - let eval_joint = self - .evals_batch - .iter() - .zip(powers_of_gamma.iter()) - .map(|(e, g_i)| *e * *g_i) - .sum(); + let batched_u = batch_eval_verify( + u_vec, + &mut transcript, + &self.sc_proof_batch, + &self.evals_batch, + )?; // verify EE::verify( &vk.vk_ee, &mut transcript, - &comm_joint, - &r_z, - &eval_joint, + &batched_u.c, + &batched_u.x, + &batched_u.e, &self.eval_arg, )?; Ok(()) } } + +/// Proves a batch of polynomial evaluation claims using Sumcheck +/// reducing them to a single claim at the same point. +fn batch_eval_prove( + u_vec: Vec>, + w_vec: Vec>, + transcript: &mut E::TE, +) -> Result< + ( + PolyEvalInstance, + PolyEvalWitness, + SumcheckProof, + Vec, + ), + NovaError, +> { + assert_eq!(u_vec.len(), w_vec.len()); + + let w_vec_padded = PolyEvalWitness::pad(w_vec); // pad the polynomials to be of the same size + let u_vec_padded = PolyEvalInstance::pad(u_vec); // pad the evaluation points + + // generate a challenge + let rho = transcript.squeeze(b"r")?; + let num_claims = w_vec_padded.len(); + let powers_of_rho = powers::(&rho, num_claims); + let claim_batch_joint = u_vec_padded + .iter() + .zip(powers_of_rho.iter()) + .map(|(u, p)| u.e * p) + .sum(); + + let mut polys_left: Vec> = w_vec_padded + .iter() + .map(|w| MultilinearPolynomial::new(w.p.clone())) + .collect(); + let mut polys_right: Vec> = u_vec_padded + .iter() + .map(|u| MultilinearPolynomial::new(EqPolynomial::new(u.x.clone()).evals())) + .collect(); + + let num_rounds_z = u_vec_padded[0].x.len(); + let comb_func = + |poly_A_comp: &E::Scalar, poly_B_comp: &E::Scalar| -> E::Scalar { *poly_A_comp * *poly_B_comp }; + let (sc_proof_batch, r_z, claims_batch) = SumcheckProof::prove_quad_batch( + &claim_batch_joint, + num_rounds_z, + &mut polys_left, + &mut polys_right, + &powers_of_rho, + comb_func, + transcript, + )?; + + let (claims_batch_left, _): (Vec, Vec) = claims_batch; + + transcript.absorb(b"l", &claims_batch_left.as_slice()); + + // we now combine evaluation claims at the same point rz into one + let gamma = transcript.squeeze(b"g")?; + let powers_of_gamma: Vec = powers::(&gamma, num_claims); + let comm_joint = u_vec_padded + .iter() + .zip(powers_of_gamma.iter()) + .map(|(u, g_i)| u.c * *g_i) + .fold(Commitment::::default(), |acc, item| acc + item); + let poly_joint = PolyEvalWitness::weighted_sum(&w_vec_padded, &powers_of_gamma); + let eval_joint = claims_batch_left + .iter() + .zip(powers_of_gamma.iter()) + .map(|(e, g_i)| *e * *g_i) + .sum(); + + Ok(( + PolyEvalInstance:: { + c: comm_joint, + x: r_z, + e: eval_joint, + }, + poly_joint, + sc_proof_batch, + claims_batch_left, + )) +} + +/// Verifies a batch of polynomial evaluation claims using Sumcheck +/// reducing them to a single claim at the same point. +fn batch_eval_verify( + u_vec: Vec>, + transcript: &mut E::TE, + sc_proof_batch: &SumcheckProof, + evals_batch: &[E::Scalar], +) -> Result, NovaError> { + assert_eq!(evals_batch.len(), evals_batch.len()); + + let u_vec_padded = PolyEvalInstance::pad(u_vec); // pad the evaluation points + + // generate a challenge + let rho = transcript.squeeze(b"r")?; + let num_claims: usize = u_vec_padded.len(); + let powers_of_rho = powers::(&rho, num_claims); + let claim_batch_joint = u_vec_padded + .iter() + .zip(powers_of_rho.iter()) + .map(|(u, p)| u.e * p) + .sum(); + + let num_rounds_z = u_vec_padded[0].x.len(); + + let (claim_batch_final, r_z) = + sc_proof_batch.verify(claim_batch_joint, num_rounds_z, 2, transcript)?; + + let claim_batch_final_expected = { + let poly_rz = EqPolynomial::new(r_z.clone()); + let evals = u_vec_padded + .iter() + .map(|u| poly_rz.evaluate(&u.x)) + .collect::>(); + + evals + .iter() + .zip(evals_batch.iter()) + .zip(powers_of_rho.iter()) + .map(|((e_i, p_i), rho_i)| *e_i * *p_i * rho_i) + .sum() + }; + + if claim_batch_final != claim_batch_final_expected { + return Err(NovaError::InvalidSumcheckProof); + } + + transcript.absorb(b"l", &evals_batch); + + // we now combine evaluation claims at the same point rz into one + let gamma = transcript.squeeze(b"g")?; + let powers_of_gamma: Vec = powers::(&gamma, num_claims); + let comm_joint = u_vec_padded + .iter() + .zip(powers_of_gamma.iter()) + .map(|(u, g_i)| u.c * *g_i) + .fold(Commitment::::default(), |acc, item| acc + item); + let eval_joint = evals_batch + .iter() + .zip(powers_of_gamma.iter()) + .map(|(e, g_i)| *e * *g_i) + .sum(); + + Ok(PolyEvalInstance:: { + c: comm_joint, + x: r_z, + e: eval_joint, + }) +}