Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper ZK treatment in plonky2 #1625

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions plonky2/src/batch_fri/oracle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(not(feature = "std"))]
use alloc::{format, vec::Vec};
use alloc::{format, vec, vec::Vec};

use itertools::Itertools;
use plonky2_field::extension::Extendable;
Expand All @@ -19,6 +19,7 @@ use crate::hash::batch_merkle_tree::BatchMerkleTree;
use crate::hash::hash_types::RichField;
use crate::iop::challenger::Challenger;
use crate::plonk::config::GenericConfig;
use crate::plonk::plonk_common::PlonkOracle;
use crate::timed;
use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree;
Expand Down Expand Up @@ -151,9 +152,15 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
// where the `k_i`s are chosen such that each power of `alpha` appears only once in the final sum.
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
// There are usually two batches for the openings at `zeta` and `g * zeta`.
// The oracles used in Plonky2 are given in `FRI_ORACLES` in `plonky2/src/plonk/plonk_common.rs`.
for FriBatchInfo { point, polynomials } in &instance.batches {
for (idx, FriBatchInfo { point, polynomials }) in instance.batches.iter().enumerate() {
let is_zk = fri_params.hiding;
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
.sum();
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
// Collect the coefficients of all the polynomials in `polynomials`.
let polys_coeff = polynomials.iter().map(|fri_poly| {
let polys_coeff = polynomials[..last_poly].iter().map(|fri_poly| {
&oracles[fri_poly.oracle_index].polynomials[fri_poly.polynomial_index]
});
let composition_poly = timed!(
Expand All @@ -165,6 +172,28 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
quotient.coeffs.push(F::Extension::ZERO); // pad back to power of two
alpha.shift_poly(&mut final_poly);
final_poly += quotient;

if is_zk && idx == 0 {
let degree = 1 << degree_bits[i];
let mut composition_poly = PolynomialCoeffs::empty();
polynomials[last_poly..]
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
.iter()
.enumerate()
.for_each(|(i, fri_poly)| {
let mut cur_coeffs = oracles[fri_poly.oracle_index].polynomials
[fri_poly.polynomial_index]
.coeffs
.clone();
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; degree * i]);
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; 2 * degree - cur_coeffs.len()]);
composition_poly += PolynomialCoeffs { coeffs: cur_coeffs };
});

alpha.shift_poly(&mut final_poly);
final_poly += composition_poly.to_extension();
}
}

assert_eq!(final_poly.len(), 1 << degree_bits[i]);
Expand Down
8 changes: 6 additions & 2 deletions plonky2/src/batch_fri/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ pub fn batch_fri_proof<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>,
fri_params: &FriParams,
timing: &mut TimingTree,
) -> FriProof<F, C::Hasher, D> {
let n = lde_polynomial_coeffs.len();
assert_eq!(lde_polynomial_values[0].len(), n);
let mut n = lde_polynomial_coeffs.len();
assert_eq!(lde_polynomial_values[0].len(), lde_polynomial_coeffs.len());
// The polynomial vectors should be sorted by degree, from largest to smallest, with no duplicate degrees.
assert!(lde_polynomial_values
.windows(2)
Expand All @@ -49,6 +49,10 @@ pub fn batch_fri_proof<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>,
}
assert_eq!(cur_poly_index, lde_polynomial_values.len());

if fri_params.hiding {
n /= 2;
}

// Commit phase
let (trees, final_coeffs) = timed!(
timing,
Expand Down
48 changes: 43 additions & 5 deletions plonky2/src/batch_fri/recursive_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use alloc::{format, vec::Vec};

use itertools::Itertools;
use plonky2_field::types::Field;

use crate::field::extension::Extendable;
use crate::fri::proof::{
Expand All @@ -15,6 +16,7 @@ use crate::iop::ext_target::{flatten_target, ExtensionTarget};
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::config::{AlgebraicHasher, GenericConfig};
use crate::plonk::plonk_common::PlonkOracle;
use crate::util::reducing::ReducingFactorTarget;
use crate::with_context;

Expand Down Expand Up @@ -62,7 +64,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
PrecomputedReducedOpeningsTarget::from_os_and_alpha(
opn,
challenges.fri_alpha,
self
self,
params.hiding,
)
);
precomputed_reduced_evals.push(pre);
Expand Down Expand Up @@ -165,13 +168,20 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut alpha = ReducingFactorTarget::new(alpha);
let mut sum = self.zero_extension();

for (batch, reduced_openings) in instance[index]
for (idx, (batch, reduced_openings)) in instance[index]
.batches
.iter()
.zip(&precomputed_reduced_evals.reduced_openings_at_point)
.enumerate()
{
let FriBatchInfoTarget { point, polynomials } = batch;
let evals = polynomials
let is_zk = params.hiding;
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
.sum();
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
let evals = polynomials[..last_poly]
.iter()
.map(|p| {
let poly_blinding = instance[index].oracles[p.oracle_index].blinding;
Expand All @@ -184,6 +194,30 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let denominator = self.sub_extension(subgroup_x, *point);
sum = alpha.shift(sum, self);
sum = self.div_add_extension(numerator, denominator, sum);

if is_zk && idx == 0 {
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, p)| {
let poly_blinding = instance[index].oracles[p.oracle_index].blinding;
let salted = params.hiding && poly_blinding;
let eval = proof.unsalted_eval(p.oracle_index, p.polynomial_index, salted);
sum = alpha.shift(sum, self);
let val = self
.constant_extension(F::Extension::from_canonical_u32((i == 0) as u32));
let power =
self.exp_power_of_2_extension(subgroup_x, i * params.degree_bits);
let pi =
self.constant_extension(F::Extension::from_canonical_u32(i as u32));
let power = self.mul_extension(power, pi);
let shift_val = self.add_extension(val, power);

let eval_extension = eval.to_ext_target(self.zero());
let tmp = self.mul_extension(eval_extension, shift_val);
sum = self.add_extension(sum, tmp);
});
}
}

sum
Expand All @@ -210,7 +244,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
Self::assert_noncanonical_indices_ok(&params.config);
let mut x_index_bits = self.low_bits(x_index, n, F::BITS);

let cap_index =
let initial_cap_index =
self.le_sum(x_index_bits[x_index_bits.len() - params.config.cap_height..].iter());
with_context!(
self,
Expand All @@ -221,7 +255,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
&x_index_bits,
&round_proof.initial_trees_proof,
initial_merkle_caps,
cap_index
initial_cap_index
)
);

Expand Down Expand Up @@ -252,6 +286,10 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
);
batch_index += 1;

let cap_index = self.le_sum(
x_index_bits[x_index_bits.len() + params.hiding as usize - params.config.cap_height..]
.iter(),
);
for (i, &arity_bits) in params.reduction_arity_bits.iter().enumerate() {
let evals = &round_proof.steps[i].evals;

Expand Down
31 changes: 28 additions & 3 deletions plonky2/src/batch_fri/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::hash::hash_types::RichField;
use crate::hash::merkle_proofs::{verify_batch_merkle_proof_to_cap, verify_merkle_proof_to_cap};
use crate::hash::merkle_tree::MerkleCap;
use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::plonk_common::PlonkOracle;
use crate::util::reducing::ReducingFactor;
use crate::util::reverse_bits;

Expand Down Expand Up @@ -46,7 +47,8 @@ pub fn verify_batch_fri_proof<

let mut precomputed_reduced_evals = Vec::with_capacity(openings.len());
for opn in openings {
let pre = PrecomputedReducedOpenings::from_os_and_alpha(opn, challenges.fri_alpha);
let pre =
PrecomputedReducedOpenings::from_os_and_alpha(opn, challenges.fri_alpha, params.hiding);
precomputed_reduced_evals.push(pre);
}
let degree_bits = degree_bits
Expand Down Expand Up @@ -123,13 +125,20 @@ fn batch_fri_combine_initial<
let mut alpha = ReducingFactor::new(alpha);
let mut sum = F::Extension::ZERO;

for (batch, reduced_openings) in instances[index]
for (idx, (batch, reduced_openings)) in instances[index]
.batches
.iter()
.zip(&precomputed_reduced_evals.reduced_openings_at_point)
.enumerate()
{
let FriBatchInfo { point, polynomials } = batch;
let evals = polynomials
let is_zk = params.hiding;
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
.sum();
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
let evals = polynomials[..last_poly]
.iter()
.map(|p| {
let poly_blinding = instances[index].oracles[p.oracle_index].blinding;
Expand All @@ -142,6 +151,22 @@ fn batch_fri_combine_initial<
let denominator = subgroup_x - *point;
sum = alpha.shift(sum);
sum += numerator / denominator;

if is_zk && idx == 0 {
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, p)| {
let poly_blinding = instances[index].oracles[p.oracle_index].blinding;
let salted = params.hiding && poly_blinding;
let eval = proof.unsalted_eval(p.oracle_index, p.polynomial_index, salted);
sum = alpha.shift(sum);
let shift_val = F::Extension::from_canonical_usize((i == 0) as usize)
+ subgroup_x.exp_power_of_2(i * params.degree_bits)
* F::Extension::from_canonical_usize(i);
sum += F::Extension::from_basefield(eval) * shift_val;
});
}
}

sum
Expand Down
5 changes: 3 additions & 2 deletions plonky2/src/fri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ impl FriConfig {
self.rate_bits,
self.cap_height,
self.num_query_rounds,
hiding,
);
FriParams {
config: self.clone(),
Expand Down Expand Up @@ -87,7 +88,7 @@ pub struct FriParams {

impl FriParams {
pub fn total_arities(&self) -> usize {
self.reduction_arity_bits.iter().sum()
self.reduction_arity_bits.iter().sum::<usize>()
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
}

pub(crate) fn max_arity_bits(&self) -> Option<usize> {
Expand All @@ -103,7 +104,7 @@ impl FriParams {
}

pub fn final_poly_bits(&self) -> usize {
self.degree_bits - self.total_arities()
self.degree_bits + self.hiding as usize - self.total_arities()
}

pub fn final_poly_len(&self) -> usize {
Expand Down
36 changes: 33 additions & 3 deletions plonky2/src/fri/oracle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(not(feature = "std"))]
use alloc::{format, vec::Vec};
use alloc::{format, vec, vec::Vec};

use itertools::Itertools;
use plonky2_field::types::Field;
Expand All @@ -17,6 +17,7 @@ use crate::hash::hash_types::RichField;
use crate::hash::merkle_tree::MerkleTree;
use crate::iop::challenger::Challenger;
use crate::plonk::config::GenericConfig;
use crate::plonk::plonk_common::PlonkOracle;
use crate::timed;
use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree;
Expand Down Expand Up @@ -194,9 +195,16 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
// where the `k_i`s are chosen such that each power of `alpha` appears only once in the final sum.
// There are usually two batches for the openings at `zeta` and `g * zeta`.
// The oracles used in Plonky2 are given in `FRI_ORACLES` in `plonky2/src/plonk/plonk_common.rs`.
for FriBatchInfo { point, polynomials } in &instance.batches {
let is_zk = fri_params.hiding;

for (idx, FriBatchInfo { point, polynomials }) in instance.batches.iter().enumerate() {
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
.sum();
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
// Collect the coefficients of all the polynomials in `polynomials`.
let polys_coeff = polynomials.iter().map(|fri_poly| {
let polys_coeff = polynomials[..last_poly].iter().map(|fri_poly| {
&oracles[fri_poly.oracle_index].polynomials[fri_poly.polynomial_index]
});
let composition_poly = timed!(
Expand All @@ -208,6 +216,28 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
quotient.coeffs.push(F::Extension::ZERO); // pad back to power of two
alpha.shift_poly(&mut final_poly);
final_poly += quotient;

if is_zk && idx == 0 {
let degree = 1 << oracles[0].degree_log;
let mut composition_poly = PolynomialCoeffs::empty();
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, fri_poly)| {
let mut cur_coeffs = oracles[fri_poly.oracle_index].polynomials
[fri_poly.polynomial_index]
.coeffs
.clone();
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; degree * i]);
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; 2 * degree - cur_coeffs.len()]);
composition_poly += PolynomialCoeffs { coeffs: cur_coeffs };
});

alpha.shift_poly(&mut final_poly);
final_poly += composition_poly.to_extension();
}
}

let lde_final_poly = final_poly.lde(fri_params.config.rate_bits);
Expand Down
Loading