From 610cac0331ca683e1b667f9e59666bd4a0902427 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Thu, 11 Jul 2024 22:17:06 -0700 Subject: [PATCH 01/44] HyperKZG verifier circuit: get started --- jolt-core/Cargo.toml | 5 + jolt-core/src/circuits/mod.rs | 1 + .../poly/commitment/commitment_scheme.rs | 22 ++++ .../src/circuits/poly/commitment/hyperkzg.rs | 124 ++++++++++++++++++ jolt-core/src/circuits/poly/commitment/mod.rs | 2 + jolt-core/src/circuits/poly/mod.rs | 1 + jolt-core/src/lib.rs | 1 + 7 files changed, 156 insertions(+) create mode 100644 jolt-core/src/circuits/mod.rs create mode 100644 jolt-core/src/circuits/poly/commitment/commitment_scheme.rs create mode 100644 jolt-core/src/circuits/poly/commitment/hyperkzg.rs create mode 100644 jolt-core/src/circuits/poly/commitment/mod.rs create mode 100644 jolt-core/src/circuits/poly/mod.rs diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index 218243092..cc114d1ae 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -20,9 +20,13 @@ license-file = "LICENSE" keywords = ["SNARK", "cryptography", "proofs"] [dependencies] +ark-bls12-381 = "0.4.0" ark-bn254 = "0.4.0" +ark-crypto-primitives = { version = "0.4.0", default-features = false, features = ["snark", "sponge"] } ark-ec = { version = "0.4.2", default-features = false } ark-ff = { version = "0.4.2", default-features = false } +ark-groth16 = { version = "0.4.0" } +ark-relations = { version = "0.4.0", default-features = false } ark-serialize = { version = "0.4.2", default-features = false, features = [ "derive", ] } @@ -45,6 +49,7 @@ rayon = { version = "^1.8.0", optional = true } rgb = "0.8.37" serde = { version = "1.0.*", default-features = false } sha3 = "0.10.8" +#sigma0-polymath = { git = "https://github.com/sigma0-xyz/polymath", default-features = false, features = ["std", "parallel"] } smallvec = "1.13.1" strum = "0.25.0" strum_macros = "0.25.2" diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs new file mode 100644 index 000000000..0fc158571 --- /dev/null +++ b/jolt-core/src/circuits/mod.rs @@ -0,0 +1 @@ +pub mod poly; diff --git a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs new file mode 100644 index 000000000..bc5862c5a --- /dev/null +++ b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs @@ -0,0 +1,22 @@ +// use crate::utils::errors::ProofVerifyError; +// use crate::utils::transcript::ProofTranscript; +// +// pub trait CommitmentScheme { +// fn verify( +// proof: &Self::Proof, +// setup: &Self::Setup, +// transcript: &mut ProofTranscript, +// opening_point: &[Self::Field], // point at which the polynomial is evaluated +// opening: &Self::Field, // evaluation \widetilde{Z}(r) +// commitment: &Self::Commitment, +// ) -> Result<(), ProofVerifyError>; +// +// fn batch_verify( +// batch_proof: &Self::BatchedProof, +// setup: &Self::Setup, +// opening_point: &[Self::Field], +// openings: &[Self::Field], +// commitments: &[&Self::Commitment], +// transcript: &mut ProofTranscript, +// ) -> Result<(), ProofVerifyError>; +// } diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs new file mode 100644 index 000000000..6cbe13264 --- /dev/null +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -0,0 +1,124 @@ +use ark_bn254::{Bn254, Fr}; +use ark_ff::Field; +// We'll use these interfaces to construct our circuit. 
+use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; + +use crate::poly::commitment::hyperkzg::{HyperKZGCommitment, HyperKZGProof, HyperKZGVerifierKey}; + +#[derive(Default)] +struct HyperKZGVerifierCircuit { + _f: std::marker::PhantomData, + // TODO fill in +} + +impl HyperKZGVerifierCircuit { + pub(crate) fn public_inputs( + &self, + vk: &HyperKZGVerifierKey, + comm: &HyperKZGCommitment, + point: &Vec, + eval: &Fr, + proof: &HyperKZGProof, + ) -> Vec { + // TODO fill in + vec![] + } +} + +impl ConstraintSynthesizer for HyperKZGVerifierCircuit { + fn generate_constraints(self, cs: ConstraintSystemRef) -> Result<(), SynthesisError> { + // TODO fill in + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use ark_bn254::{Bn254, Fr}; + use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; + use rand_core::SeedableRng; + + use crate::poly::commitment::hyperkzg::{ + HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey, + }; + use crate::poly::dense_mlpoly::DensePolynomial; + use crate::utils::errors::ProofVerifyError; + use crate::utils::transcript::ProofTranscript; + + use super::*; + + #[test] + fn test_hyperkzg_eval() { + type Groth16 = ark_groth16::Groth16; + + // Test with poly(X1, X2) = 1 + X1 + X2 + X1*X2 + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(0); + let srs = HyperKZGSRS::setup(&mut rng, 3); + let (pk, vk): (HyperKZGProverKey, HyperKZGVerifierKey) = srs.trim(3); + + // poly is in eval. representation; evaluated at [(0,0), (0,1), (1,0), (1,1)] + let poly = DensePolynomial::new(vec![Fr::from(1), Fr::from(2), Fr::from(2), Fr::from(4)]); + + let (cpk, cvk) = { + let circuit = HyperKZGVerifierCircuit::default(); + + Groth16::setup(circuit, &mut rng).unwrap() + }; + let pvk = Groth16::process_vk(&cvk).unwrap(); + + let C = HyperKZG::commit(&pk, &poly).unwrap(); + + let mut test_inner = |point: Vec, eval: Fr| -> Result<(), ProofVerifyError> { + let mut tr = ProofTranscript::new(b"TestEval"); + let proof = HyperKZG::open(&pk, &poly, &point, &eval, &mut tr).unwrap(); + let mut tr = ProofTranscript::new(b"TestEval"); + HyperKZG::verify(&vk, &C, &point, &eval, &proof, &mut tr)?; + + // Create an instance of our circuit (with the + // witness) + let verifier_circuit = HyperKZGVerifierCircuit::default(); + let instance = verifier_circuit.public_inputs(&vk, &C, &point, &eval, &proof); + + // Create a groth16 proof with our parameters. + let proof = Groth16::prove(&cpk, verifier_circuit, &mut rng) + .map_err(|e| ProofVerifyError::InternalError)?; + let result = Groth16::verify_with_processed_vk(&pvk, &instance, &proof); + match result { + Ok(true) => Ok(()), + Ok(false) => Err(ProofVerifyError::InternalError), + Err(_) => Err(ProofVerifyError::InternalError), + } + }; + + // Call the prover with a (point, eval) pair. 
+ // The prover does not recompute so it may produce a proof, but it should not verify + let point = vec![Fr::from(0), Fr::from(0)]; + let eval = Fr::from(1); + assert!(test_inner(point, eval).is_ok()); + + let point = vec![Fr::from(0), Fr::from(1)]; + let eval = Fr::from(2); + assert!(test_inner(point, eval).is_ok()); + + let point = vec![Fr::from(1), Fr::from(1)]; + let eval = Fr::from(4); + assert!(test_inner(point, eval).is_ok()); + + let point = vec![Fr::from(0), Fr::from(2)]; + let eval = Fr::from(3); + assert!(test_inner(point, eval).is_ok()); + + let point = vec![Fr::from(2), Fr::from(2)]; + let eval = Fr::from(9); + assert!(test_inner(point, eval).is_ok()); + + // Try a couple incorrect evaluations and expect failure + let point = vec![Fr::from(2), Fr::from(2)]; + let eval = Fr::from(50); + assert!(test_inner(point, eval).is_err()); + + let point = vec![Fr::from(0), Fr::from(2)]; + let eval = Fr::from(4); + assert!(test_inner(point, eval).is_err()); + } +} diff --git a/jolt-core/src/circuits/poly/commitment/mod.rs b/jolt-core/src/circuits/poly/commitment/mod.rs new file mode 100644 index 000000000..0b1c33650 --- /dev/null +++ b/jolt-core/src/circuits/poly/commitment/mod.rs @@ -0,0 +1,2 @@ +mod commitment_scheme; +pub mod hyperkzg; diff --git a/jolt-core/src/circuits/poly/mod.rs b/jolt-core/src/circuits/poly/mod.rs new file mode 100644 index 000000000..1f9123814 --- /dev/null +++ b/jolt-core/src/circuits/poly/mod.rs @@ -0,0 +1 @@ +pub mod commitment; diff --git a/jolt-core/src/lib.rs b/jolt-core/src/lib.rs index 965f9f090..69cd9a629 100644 --- a/jolt-core/src/lib.rs +++ b/jolt-core/src/lib.rs @@ -16,6 +16,7 @@ pub mod benches; #[cfg(feature = "host")] pub mod host; +pub mod circuits; pub mod field; pub mod jolt; pub mod lasso; From 39c14de48c7fc308136ed636ba3722764f343a61 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Thu, 25 Jul 2024 10:37:23 -0700 Subject: [PATCH 02/44] WIP: HyperKZGVerifierGadget Also, add PairingGadget trait. 
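As a rough illustration of how the new trait is meant to be used (this is a sketch only, not code in this patch; the helper name and the <E, ConstraintF> parameter order of PairingGadget are assumptions based on the bounds below), a KZG-style opening check reduces to a two-term pairing product being the identity in GT, which the trait expresses via prepare_g1/prepare_g2 and multi_pairing:

    use ark_ec::pairing::Pairing;
    use ark_ff::PrimeField;
    use ark_r1cs_std::prelude::*;
    use ark_relations::r1cs::SynthesisError;

    use crate::circuits::pairing::PairingGadget;

    /// Hypothetical helper (not part of this patch): enforce e(a, b) * e(c, d) == 1 in-circuit.
    fn enforce_pairing_product_is_one<E, ConstraintF, PG>(
        a: &PG::G1Var,
        b: &PG::G2Var,
        c: &PG::G1Var,
        d: &PG::G2Var,
    ) -> Result<(), SynthesisError>
    where
        E: Pairing,
        ConstraintF: PrimeField,
        PG: PairingGadget<E, ConstraintF>,
    {
        // Put the G1/G2 inputs into their precomputed ("prepared") form.
        let p = [PG::prepare_g1(a)?, PG::prepare_g1(c)?];
        let q = [PG::prepare_g2(b)?, PG::prepare_g2(d)?];
        // multi_pairing = final_exponentiation(miller_loop(p, q)).
        let product = PG::multi_pairing(&p, &q)?;
        // The pairing product must equal the identity of GT.
        let one = <PG::GTVar as FieldVar<E::TargetField, ConstraintF>>::one();
        product.enforce_equal(&one)
    }

HyperKZGVerifierGadget::verify and the circuit's generate_constraints are still left unimplemented (todo!() / TODO) in this patch; the sketch above only shows the shape of the final check they are expected to reduce to.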
--- jolt-core/Cargo.toml | 7 +- jolt-core/src/circuits/mod.rs | 1 + jolt-core/src/circuits/pairing/mod.rs | 73 +++++++ .../poly/commitment/commitment_scheme.rs | 53 +++--- .../src/circuits/poly/commitment/hyperkzg.rs | 180 +++++++++++++++--- 5 files changed, 264 insertions(+), 50 deletions(-) create mode 100644 jolt-core/src/circuits/pairing/mod.rs diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index cc114d1ae..ba7648c83 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -22,10 +22,10 @@ keywords = ["SNARK", "cryptography", "proofs"] [dependencies] ark-bls12-381 = "0.4.0" ark-bn254 = "0.4.0" -ark-crypto-primitives = { version = "0.4.0", default-features = false, features = ["snark", "sponge"] } +ark-crypto-primitives = { version = "0.4.0", default-features = false, features = ["snark", "sponge", "r1cs"] } ark-ec = { version = "0.4.2", default-features = false } ark-ff = { version = "0.4.2", default-features = false } -ark-groth16 = { version = "0.4.0" } +ark-r1cs-std = { version = "0.4.0" } ark-relations = { version = "0.4.0", default-features = false } ark-serialize = { version = "0.4.2", default-features = false, features = [ "derive", @@ -49,7 +49,6 @@ rayon = { version = "^1.8.0", optional = true } rgb = "0.8.37" serde = { version = "1.0.*", default-features = false } sha3 = "0.10.8" -#sigma0-polymath = { git = "https://github.com/sigma0-xyz/polymath", default-features = false, features = ["std", "parallel"] } smallvec = "1.13.1" strum = "0.25.0" strum_macros = "0.25.2" @@ -76,8 +75,10 @@ tokio = { version = "1.38.0", optional = true } [dev-dependencies] +ark-groth16 = { version = "0.4.0" } criterion = { version = "0.5.1", features = ["html_reports"] } iai-callgrind = "0.10.2" +#sigma0-polymath = { git = "https://github.com/sigma0-xyz/polymath" } [build-dependencies] common = { path = "../common" } diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index 0fc158571..0d8caf690 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -1 +1,2 @@ +mod pairing; pub mod poly; diff --git a/jolt-core/src/circuits/pairing/mod.rs b/jolt-core/src/circuits/pairing/mod.rs new file mode 100644 index 000000000..d24727022 --- /dev/null +++ b/jolt-core/src/circuits/pairing/mod.rs @@ -0,0 +1,73 @@ +use ark_ec::pairing::Pairing; +use ark_ff::PrimeField; +use ark_r1cs_std::prelude::*; +use ark_relations::r1cs::SynthesisError; +use ark_std::fmt::Debug; + +/// Specifies the constraints for computing a pairing in the bilinear group +/// `E`. +pub trait PairingGadget { + /// A variable representing an element of `G1`. + /// This is the R1CS equivalent of `E::G1Projective`. + type G1Var: CurveVar; + + /// A variable representing an element of `G2`. + /// This is the R1CS equivalent of `E::G2Projective`. + type G2Var: CurveVar; + + /// A variable representing an element of `GT`. + /// This is the R1CS equivalent of `E::GT`. + type GTVar: FieldVar; + + /// A variable representing cached precomputation that can speed up + /// pairings computations. This is the R1CS equivalent of + /// `E::G1Prepared`. + type G1PreparedVar: ToBytesGadget + + AllocVar + + Clone + + Debug; + /// A variable representing cached precomputation that can speed up + /// pairings computations. This is the R1CS equivalent of + /// `E::G2Prepared`. + type G2PreparedVar: ToBytesGadget + + AllocVar + + Clone + + Debug; + + /// Computes a multi-miller loop between elements + /// of `p` and `q`. 
+ fn miller_loop( + p: &[Self::G1PreparedVar], + q: &[Self::G2PreparedVar], + ) -> Result; + + /// Computes a final exponentiation over `p`. + fn final_exponentiation(p: &Self::GTVar) -> Result; + + /// Computes a pairing over `p` and `q`. + #[tracing::instrument(target = "r1cs")] + fn pairing( + p: Self::G1PreparedVar, + q: Self::G2PreparedVar, + ) -> Result { + let tmp = Self::miller_loop(&[p], &[q])?; + Self::final_exponentiation(&tmp) + } + + /// Computes a product of pairings over the elements in `p` and `q`. + #[must_use] + #[tracing::instrument(target = "r1cs")] + fn multi_pairing( + p: &[Self::G1PreparedVar], + q: &[Self::G2PreparedVar], + ) -> Result { + let miller_result = Self::miller_loop(p, q)?; + Self::final_exponentiation(&miller_result) + } + + /// Performs the precomputation to generate `Self::G1PreparedVar`. + fn prepare_g1(q: &Self::G1Var) -> Result; + + /// Performs the precomputation to generate `Self::G2PreparedVar`. + fn prepare_g2(q: &Self::G2Var) -> Result; +} diff --git a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs index bc5862c5a..abbee6cf7 100644 --- a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs +++ b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs @@ -1,22 +1,31 @@ -// use crate::utils::errors::ProofVerifyError; -// use crate::utils::transcript::ProofTranscript; -// -// pub trait CommitmentScheme { -// fn verify( -// proof: &Self::Proof, -// setup: &Self::Setup, -// transcript: &mut ProofTranscript, -// opening_point: &[Self::Field], // point at which the polynomial is evaluated -// opening: &Self::Field, // evaluation \widetilde{Z}(r) -// commitment: &Self::Commitment, -// ) -> Result<(), ProofVerifyError>; -// -// fn batch_verify( -// batch_proof: &Self::BatchedProof, -// setup: &Self::Setup, -// opening_point: &[Self::Field], -// openings: &[Self::Field], -// commitments: &[&Self::Commitment], -// transcript: &mut ProofTranscript, -// ) -> Result<(), ProofVerifyError>; -// } +use ark_crypto_primitives::sponge::constraints::SpongeWithGadget; +use ark_ec::pairing::Pairing; +use ark_ff::PrimeField; +use ark_r1cs_std::fields::fp::FpVar; +use ark_r1cs_std::prelude::*; +use ark_relations::r1cs::SynthesisError; + +use crate::poly::commitment::commitment_scheme::CommitmentScheme; + +pub trait CommitmentVerifierGadget< + F: PrimeField, + ConstraintF: PrimeField, + C: CommitmentScheme, +> +{ + type VerifyingKeyVar: AllocVar + Clone; + type ProofVar: AllocVar + Clone; + type CommitmentVar: AllocVar + Clone; + + // type Field: FieldVar; // TODO replace FpVar with Field: FieldVar + type TranscriptGadget: SpongeWithGadget + Clone; // TODO requires F: PrimeField, we want to generalize to JoltField + + fn verify( + proof: &Self::ProofVar, + vk: &Self::VerifyingKeyVar, + transcript: &mut Self::TranscriptGadget, + opening_point: &[FpVar], + opening: &FpVar, + commitment: &Self::CommitmentVar, + ) -> Result, SynthesisError>; +} diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 6cbe13264..581939e77 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -1,9 +1,133 @@ -use ark_bn254::{Bn254, Fr}; -use ark_ff::Field; -// We'll use these interfaces to construct our circuit. 
-use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; +use std::borrow::Borrow; + +use crate::circuits::pairing::PairingGadget; +use crate::circuits::poly::commitment::commitment_scheme::CommitmentVerifierGadget; +use crate::field::JoltField; +use crate::poly::commitment::hyperkzg::{ + HyperKZG, HyperKZGCommitment, HyperKZGProof, HyperKZGProverKey, HyperKZGVerifierKey, +}; +use ark_bn254::{Bn254, Fr as BN254Fr}; +use ark_crypto_primitives::sponge::poseidon::PoseidonSponge; +use ark_ec::pairing::Pairing; +use ark_ff::{Field, PrimeField}; +use ark_r1cs_std::boolean::Boolean; +use ark_r1cs_std::fields::fp::FpVar; +use ark_r1cs_std::pairing::PairingVar; +use ark_r1cs_std::prelude::*; +use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, Namespace, SynthesisError}; +use ark_std::marker::PhantomData; + +#[derive(Clone)] +pub struct HyperKZGProofVar> { + _e: PhantomData, + _p: PhantomData
<P>
, + _constraint_f: PhantomData, + // TODO fill in +} + +impl AllocVar, ConstraintF> + for HyperKZGProofVar +where + E: Pairing, + ConstraintF: PrimeField, + P: PairingGadget, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + todo!() + } +} + +#[derive(Clone)] +pub struct HyperKZGCommitmentVar< + E: Pairing, + ConstraintF: PrimeField, + P: PairingGadget, +> { + _e: PhantomData, + _p: PhantomData
<P>
, + _constraint_f: PhantomData, + // TODO fill in +} + +impl AllocVar, ConstraintF> + for HyperKZGCommitmentVar +where + E: Pairing, + ConstraintF: PrimeField, + P: PairingGadget, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + todo!() + } +} -use crate::poly::commitment::hyperkzg::{HyperKZGCommitment, HyperKZGProof, HyperKZGVerifierKey}; +#[derive(Clone)] +pub struct HyperKZGVerifierKeyVar< + E: Pairing, + ConstraintF: PrimeField, + P: PairingGadget, +> { + _e: PhantomData, + _p: PhantomData
<P>
, + _constraint_f: PhantomData, + // TODO fill in +} + +impl, ConstraintF: PrimeField> + AllocVar<(HyperKZGProverKey, HyperKZGVerifierKey), ConstraintF> + for HyperKZGVerifierKeyVar +{ + fn new_variable, HyperKZGVerifierKey)>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + todo!() + } +} + +pub struct HyperKZGVerifierGadget +where + E: Pairing, + P: PairingGadget, +{ + _e: PhantomData, + _p: PhantomData
<P>
, + _constraint_f: PhantomData, +} + +impl CommitmentVerifierGadget> + for HyperKZGVerifierGadget +where + E: Pairing, + P: PairingGadget + Clone, + ConstraintF: PrimeField, + F: PrimeField + JoltField, +{ + type VerifyingKeyVar = HyperKZGVerifierKeyVar; + type ProofVar = HyperKZGProofVar; + type CommitmentVar = HyperKZGCommitmentVar; + type TranscriptGadget = PoseidonSponge; + + fn verify( + proof: &Self::ProofVar, + vk: &Self::VerifyingKeyVar, + transcript: &mut Self::TranscriptGadget, + opening_point: &[FpVar], + opening: &FpVar, + commitment: &Self::CommitmentVar, + ) -> Result, SynthesisError> { + todo!() + } +} #[derive(Default)] struct HyperKZGVerifierCircuit { @@ -16,8 +140,8 @@ impl HyperKZGVerifierCircuit { &self, vk: &HyperKZGVerifierKey, comm: &HyperKZGCommitment, - point: &Vec, - eval: &Fr, + point: &Vec, + eval: &BN254Fr, proof: &HyperKZGProof, ) -> Vec { // TODO fill in @@ -34,7 +158,8 @@ impl ConstraintSynthesizer for HyperKZGVerifierCircuit { #[cfg(test)] mod tests { - use ark_bn254::{Bn254, Fr}; + use ark_bls12_381::Bls12_381; + use ark_bn254::{Bn254, Fr as BN254Fr}; use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; use rand_core::SeedableRng; @@ -49,7 +174,7 @@ mod tests { #[test] fn test_hyperkzg_eval() { - type Groth16 = ark_groth16::Groth16; + type Groth16 = ark_groth16::Groth16; // Test with poly(X1, X2) = 1 + X1 + X2 + X1*X2 let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(0); @@ -57,7 +182,12 @@ mod tests { let (pk, vk): (HyperKZGProverKey, HyperKZGVerifierKey) = srs.trim(3); // poly is in eval. representation; evaluated at [(0,0), (0,1), (1,0), (1,1)] - let poly = DensePolynomial::new(vec![Fr::from(1), Fr::from(2), Fr::from(2), Fr::from(4)]); + let poly = DensePolynomial::new(vec![ + BN254Fr::from(1), + BN254Fr::from(2), + BN254Fr::from(2), + BN254Fr::from(4), + ]); let (cpk, cvk) = { let circuit = HyperKZGVerifierCircuit::default(); @@ -68,7 +198,7 @@ mod tests { let C = HyperKZG::commit(&pk, &poly).unwrap(); - let mut test_inner = |point: Vec, eval: Fr| -> Result<(), ProofVerifyError> { + let mut test_inner = |point: Vec, eval: BN254Fr| -> Result<(), ProofVerifyError> { let mut tr = ProofTranscript::new(b"TestEval"); let proof = HyperKZG::open(&pk, &poly, &point, &eval, &mut tr).unwrap(); let mut tr = ProofTranscript::new(b"TestEval"); @@ -92,33 +222,33 @@ mod tests { // Call the prover with a (point, eval) pair. 
// The prover does not recompute so it may produce a proof, but it should not verify - let point = vec![Fr::from(0), Fr::from(0)]; - let eval = Fr::from(1); + let point = vec![BN254Fr::from(0), BN254Fr::from(0)]; + let eval = BN254Fr::from(1); assert!(test_inner(point, eval).is_ok()); - let point = vec![Fr::from(0), Fr::from(1)]; - let eval = Fr::from(2); + let point = vec![BN254Fr::from(0), BN254Fr::from(1)]; + let eval = BN254Fr::from(2); assert!(test_inner(point, eval).is_ok()); - let point = vec![Fr::from(1), Fr::from(1)]; - let eval = Fr::from(4); + let point = vec![BN254Fr::from(1), BN254Fr::from(1)]; + let eval = BN254Fr::from(4); assert!(test_inner(point, eval).is_ok()); - let point = vec![Fr::from(0), Fr::from(2)]; - let eval = Fr::from(3); + let point = vec![BN254Fr::from(0), BN254Fr::from(2)]; + let eval = BN254Fr::from(3); assert!(test_inner(point, eval).is_ok()); - let point = vec![Fr::from(2), Fr::from(2)]; - let eval = Fr::from(9); + let point = vec![BN254Fr::from(2), BN254Fr::from(2)]; + let eval = BN254Fr::from(9); assert!(test_inner(point, eval).is_ok()); // Try a couple incorrect evaluations and expect failure - let point = vec![Fr::from(2), Fr::from(2)]; - let eval = Fr::from(50); + let point = vec![BN254Fr::from(2), BN254Fr::from(2)]; + let eval = BN254Fr::from(50); assert!(test_inner(point, eval).is_err()); - let point = vec![Fr::from(0), Fr::from(2)]; - let eval = Fr::from(4); + let point = vec![BN254Fr::from(0), BN254Fr::from(2)]; + let eval = BN254Fr::from(4); assert!(test_inner(point, eval).is_err()); } } From d948df5fc0db345ae87dfe7f4f04ce9302097e3f Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 28 Jul 2024 15:51:28 -0700 Subject: [PATCH 03/44] WIP: PairingGadget --- jolt-core/Cargo.toml | 2 + jolt-core/src/circuits/fields/fp2.rs | 27 + jolt-core/src/circuits/fields/mod.rs | 3 + .../circuits/fields/quadratic_extension.rs | 589 +++++++++++ jolt-core/src/circuits/groups/curves/mod.rs | 1 + .../curves/short_weierstrass/bls12/mod.rs | 314 ++++++ .../groups/curves/short_weierstrass/mod.rs | 975 ++++++++++++++++++ .../short_weierstrass/non_zero_affine.rs | 399 +++++++ jolt-core/src/circuits/groups/mod.rs | 1 + jolt-core/src/circuits/mod.rs | 4 +- jolt-core/src/circuits/pairing/bls12/mod.rs | 178 ++++ jolt-core/src/circuits/pairing/mod.rs | 70 ++ 12 files changed, 2562 insertions(+), 1 deletion(-) create mode 100644 jolt-core/src/circuits/fields/fp2.rs create mode 100644 jolt-core/src/circuits/fields/mod.rs create mode 100644 jolt-core/src/circuits/fields/quadratic_extension.rs create mode 100644 jolt-core/src/circuits/groups/curves/mod.rs create mode 100644 jolt-core/src/circuits/groups/curves/short_weierstrass/bls12/mod.rs create mode 100644 jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs create mode 100644 jolt-core/src/circuits/groups/curves/short_weierstrass/non_zero_affine.rs create mode 100644 jolt-core/src/circuits/groups/mod.rs create mode 100644 jolt-core/src/circuits/pairing/bls12/mod.rs diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index ba7648c83..b511c8df7 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -33,6 +33,7 @@ ark-serialize = { version = "0.4.2", default-features = false, features = [ ark-std = { version = "0.4.0" } binius-field = { git = "https://gitlab.com/UlvetannaOSS/binius", package = "binius_field"} clap = { version = "4.3.10", features = ["derive"] } +derivative = { version = "2" } enum_dispatch = "0.3.12" fixedbitset = "0.5.0" itertools = "0.10.0" @@ -76,6 +77,7 @@ tokio = { version = 
"1.38.0", optional = true } [dev-dependencies] ark-groth16 = { version = "0.4.0" } +ark-test-curves = { version = "0.4.0", default-features = false, features = ["bls12_381_curve", "mnt6_753"] } criterion = { version = "0.5.1", features = ["html_reports"] } iai-callgrind = "0.10.2" #sigma0-polymath = { git = "https://github.com/sigma0-xyz/polymath" } diff --git a/jolt-core/src/circuits/fields/fp2.rs b/jolt-core/src/circuits/fields/fp2.rs new file mode 100644 index 000000000..87050befe --- /dev/null +++ b/jolt-core/src/circuits/fields/fp2.rs @@ -0,0 +1,27 @@ +use crate::circuits::fields::quadratic_extension::*; +use ark_ff::fields::{Fp2Config, Fp2ConfigWrapper, QuadExtConfig}; +use ark_ff::PrimeField; +use ark_r1cs_std::fields::fp::FpVar; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; + +/// A quadratic extension field constructed over a prime field. +/// This is the R1CS equivalent of `ark_ff::Fp2
<P>`.
+pub type Fp2Var<P, ConstraintF> = QuadExtVar<
+    NonNativeFieldVar<<P as Fp2Config>::Fp, ConstraintF>,
+    ConstraintF,
+    Fp2ConfigWrapper<P>,
+>;
+
+impl<P, ConstraintF> QuadExtVarConfig<NonNativeFieldVar<P::Fp, ConstraintF>, ConstraintF>
+    for Fp2ConfigWrapper<P>
+where + P: Fp2Config, + ConstraintF: PrimeField, +{ + fn mul_base_field_var_by_frob_coeff( + fe: &mut NonNativeFieldVar, + power: usize, + ) { + *fe *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + } +} diff --git a/jolt-core/src/circuits/fields/mod.rs b/jolt-core/src/circuits/fields/mod.rs new file mode 100644 index 000000000..90f54f974 --- /dev/null +++ b/jolt-core/src/circuits/fields/mod.rs @@ -0,0 +1,3 @@ +// pub mod fp12; +pub mod fp2; +pub mod quadratic_extension; diff --git a/jolt-core/src/circuits/fields/quadratic_extension.rs b/jolt-core/src/circuits/fields/quadratic_extension.rs new file mode 100644 index 000000000..40be6e238 --- /dev/null +++ b/jolt-core/src/circuits/fields/quadratic_extension.rs @@ -0,0 +1,589 @@ +use ark_ff::{ + fields::{Field, QuadExtConfig, QuadExtField}, + PrimeField, Zero, +}; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; +use ark_std::vec::Vec; +use core::{borrow::Borrow, marker::PhantomData}; + +use ark_r1cs_std::{ + fields::{fp::FpVar, FieldOpsBounds, FieldVar}, + impl_bounded_ops, + prelude::*, + ToConstraintFieldGadget, +}; +use derivative::Derivative; + +/// This struct is the `R1CS` equivalent of the quadratic extension field type +/// in `ark-ff`, i.e. `ark_ff::QuadExtField`. +#[derive(Derivative)] +#[derivative(Debug(bound = "BF: core::fmt::Debug"), Clone(bound = "BF: Clone"))] +#[must_use] +pub struct QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + /// The zero-th coefficient of this field element. + pub c0: BF, + /// The first coefficient of this field element. + pub c1: BF, + #[derivative(Debug = "ignore")] + _params: PhantomData
<P>
, + #[derivative(Debug = "ignore")] + _constraint_f: PhantomData, +} + +/// This trait describes parameters that are used to implement arithmetic for +/// `QuadExtVar`. +pub trait QuadExtVarConfig: QuadExtConfig +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'a> &'a BF: FieldOpsBounds<'a, Self::BaseField, BF>, +{ + /// Multiply the base field of the `QuadExtVar` by the appropriate Frobenius + /// coefficient. This is equivalent to + /// `Self::mul_base_field_by_frob_coeff(power)`. + fn mul_base_field_var_by_frob_coeff(fe: &mut BF, power: usize); +} + +impl QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + /// Constructs a `QuadExtVar` from the underlying coefficients. + pub fn new(c0: BF, c1: BF) -> Self { + Self { + c0, + c1, + _params: PhantomData, + _constraint_f: PhantomData, + } + } + + /// Multiplies a variable of the base field by the quadratic nonresidue + /// `P::NONRESIDUE` that is used to construct the extension field. + #[inline] + pub fn mul_base_field_by_nonresidue(fe: &BF) -> Result { + Ok(fe * P::NONRESIDUE) + } + + /// Multiplies `self` by a constant from the base field. + #[inline] + pub fn mul_by_base_field_constant(&self, fe: P::BaseField) -> Self { + let c0 = self.c0.clone() * fe; + let c1 = self.c1.clone() * fe; + QuadExtVar::new(c0, c1) + } + + /// Sets `self = self.mul_by_base_field_constant(fe)`. + #[inline] + pub fn mul_assign_by_base_field_constant(&mut self, fe: P::BaseField) { + *self = (&*self).mul_by_base_field_constant(fe); + } + + /// This is only to be used when the element is *known* to be in the + /// cyclotomic subgroup. + #[inline] + pub fn unitary_inverse(&self) -> Result { + Ok(Self::new(self.c0.clone(), self.c1.negate()?)) + } + + /// This is only to be used when the element is *known* to be in the + /// cyclotomic subgroup. + #[inline] + #[tracing::instrument(target = "r1cs", skip(exponent))] + pub fn cyclotomic_exp(&self, exponent: impl AsRef<[u64]>) -> Result + where + Self: FieldVar, ConstraintF>, + { + let mut res = Self::one(); + let self_inverse = self.unitary_inverse()?; + + let mut found_nonzero = false; + let naf = ark_ff::biginteger::arithmetic::find_naf(exponent.as_ref()); + + for &value in naf.iter().rev() { + if found_nonzero { + res.square_in_place()?; + } + + if value != 0 { + found_nonzero = true; + + if value > 0 { + res *= self; + } else { + res *= &self_inverse; + } + } + } + + Ok(res) + } +} + +impl R1CSVar for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + type Value = QuadExtField
<P>
; + + fn cs(&self) -> ConstraintSystemRef { + [&self.c0, &self.c1].cs() + } + + #[inline] + fn value(&self) -> Result { + match (self.c0.value(), self.c1.value()) { + (Ok(c0), Ok(c1)) => Ok(QuadExtField::new(c0, c1)), + (..) => Err(SynthesisError::AssignmentMissing), + } + } +} + +impl From> for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn from(other: Boolean) -> Self { + let c0 = BF::from(other); + let c1 = BF::zero(); + Self::new(c0, c1) + } +} + +impl<'a, BF, ConstraintF, P> FieldOpsBounds<'a, QuadExtField
<P>
, QuadExtVar> + for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ +} +impl<'a, BF, ConstraintF, P> FieldOpsBounds<'a, QuadExtField
<P>
, QuadExtVar> + for &'a QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, + P: QuadExtVarConfig, +{ +} + +impl FieldVar, ConstraintF> for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn zero() -> Self { + let c0 = BF::zero(); + let c1 = BF::zero(); + Self::new(c0, c1) + } + + fn one() -> Self { + let c0 = BF::one(); + let c1 = BF::zero(); + Self::new(c0, c1) + } + + fn constant(other: QuadExtField
<P>
) -> Self { + let c0 = BF::constant(other.c0); + let c1 = BF::constant(other.c1); + Self::new(c0, c1) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn double(&self) -> Result { + let c0 = self.c0.double()?; + let c1 = self.c1.double()?; + Ok(Self::new(c0, c1)) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn negate(&self) -> Result { + let mut result = self.clone(); + result.c0.negate_in_place()?; + result.c1.negate_in_place()?; + Ok(result) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn square(&self) -> Result { + // From Libsnark/fp2_gadget.tcc + // Complex multiplication for Fp2: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + + // v0 = c0 - c1 + let mut v0 = &self.c0 - &self.c1; + // v3 = c0 - beta * c1 + let v3 = &self.c0 - &Self::mul_base_field_by_nonresidue(&self.c1)?; + // v2 = c0 * c1 + let v2 = &self.c0 * &self.c1; + + // v0 = (v0 * v3) + v2 + v0 *= &v3; + v0 += &v2; + + let c0 = &v0 + &Self::mul_base_field_by_nonresidue(&v2)?; + let c1 = v2.double()?; + + Ok(Self::new(c0, c1)) + } + + #[tracing::instrument(target = "r1cs")] + fn mul_equals(&self, other: &Self, result: &Self) -> Result<(), SynthesisError> { + // Karatsuba multiplication for Fp2: + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // result.c0 = v0 + non_residue * v1 + // result.c1 = (A.c0 + A.c1) * (B.c0 + B.c1) - v0 - v1 + // Enforced with 3 constraints: + // A.c1 * B.c1 = v1 + // A.c0 * B.c0 = result.c0 - non_residue * v1 + // (A.c0+A.c1)*(B.c0+B.c1) = result.c1 + result.c0 + (1 - non_residue) * v1 + // Reference: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + // Compute v1 + let v1 = &self.c1 * &other.c1; + + // Perform second check + let non_residue_times_v1 = Self::mul_base_field_by_nonresidue(&v1)?; + let rhs = &result.c0 - &non_residue_times_v1; + self.c0.mul_equals(&other.c0, &rhs)?; + + // Last check + let a0_plus_a1 = &self.c0 + &self.c1; + let b0_plus_b1 = &other.c0 + &other.c1; + let one_minus_non_residue_v1 = &v1 - &non_residue_times_v1; + + let tmp = &(&result.c1 + &result.c0) + &one_minus_non_residue_v1; + a0_plus_a1.mul_equals(&b0_plus_b1, &tmp)?; + + Ok(()) + } + + #[tracing::instrument(target = "r1cs")] + fn inverse(&self) -> Result { + let mode = if self.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let inverse = Self::new_variable( + self.cs(), + || { + self.value() + .map(|f| f.inverse().unwrap_or_else(QuadExtField::zero)) + }, + mode, + )?; + self.mul_equals(&inverse, &Self::one())?; + Ok(inverse) + } + + #[tracing::instrument(target = "r1cs")] + fn frobenius_map(&self, power: usize) -> Result { + let mut result = self.clone(); + result.c0.frobenius_map_in_place(power)?; + result.c1.frobenius_map_in_place(power)?; + P::mul_base_field_var_by_frob_coeff(&mut result.c1, power); + Ok(result) + } +} + +impl_bounded_ops!( + QuadExtVar, + QuadExtField
<P>
, + Add, + add, + AddAssign, + add_assign, + |this: &'a QuadExtVar, other: &'a QuadExtVar| { + let c0 = &this.c0 + &other.c0; + let c1 = &this.c1 + &other.c1; + QuadExtVar::new(c0, c1) + }, + |this: &'a QuadExtVar, other: QuadExtField
<P>
| { + this + QuadExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: QuadExtVarConfig), + for <'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF> +); +impl_bounded_ops!( + QuadExtVar, + QuadExtField
<P>
, + Sub, + sub, + SubAssign, + sub_assign, + |this: &'a QuadExtVar, other: &'a QuadExtVar| { + let c0 = &this.c0 - &other.c0; + let c1 = &this.c1 - &other.c1; + QuadExtVar::new(c0, c1) + }, + |this: &'a QuadExtVar, other: QuadExtField
<P>
| { + this - QuadExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: QuadExtVarConfig), + for <'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF> +); +impl_bounded_ops!( + QuadExtVar, + QuadExtField
<P>
, + Mul, + mul, + MulAssign, + mul_assign, + |this: &'a QuadExtVar, other: &'a QuadExtVar| { + // Karatsuba multiplication for Fp2: + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // result.c0 = v0 + non_residue * v1 + // result.c1 = (A.c0 + A.c1) * (B.c0 + B.c1) - v0 - v1 + // Enforced with 3 constraints: + // A.c1 * B.c1 = v1 + // A.c0 * B.c0 = result.c0 - non_residue * v1 + // (A.c0+A.c1)*(B.c0+B.c1) = result.c1 + result.c0 + (1 - non_residue) * v1 + // Reference: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + let mut result = this.clone(); + let v0 = &this.c0 * &other.c0; + let v1 = &this.c1 * &other.c1; + + result.c1 += &this.c0; + result.c1 *= &other.c0 + &other.c1; + result.c1 -= &v0; + result.c1 -= &v1; + result.c0 = v0 + &QuadExtVar::::mul_base_field_by_nonresidue(&v1).unwrap(); + result + }, + |this: &'a QuadExtVar, other: QuadExtField
<P>
| { + this * QuadExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: QuadExtVarConfig), + for <'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF> +); + +impl EqGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let b0 = self.c0.is_eq(&other.c0)?; + let b1 = self.c1.is_eq(&other.c1)?; + b0.and(&b1) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.c0.conditional_enforce_equal(&other.c0, condition)?; + self.c1.conditional_enforce_equal(&other.c1, condition)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let is_equal = self.is_eq(other)?; + is_equal + .and(condition)? + .enforce_equal(&Boolean::Constant(false)) + } +} + +impl ToBitsGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bits_le(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_bits_le()?; + let mut c1 = self.c1.to_bits_le()?; + c0.append(&mut c1); + Ok(c0) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_non_unique_bits_le()?; + let mut c1 = self.c1.to_non_unique_bits_le()?; + c0.append(&mut c1); + Ok(c0) + } +} + +impl ToBytesGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_bytes()?; + let mut c1 = self.c1.to_bytes()?; + c0.append(&mut c1); + Ok(c0) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_non_unique_bytes()?; + let mut c1 = self.c1.to_non_unique_bytes()?; + c0.append(&mut c1); + Ok(c0) + } +} + +impl ToConstraintFieldGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + BF: ToConstraintFieldGadget, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let mut res = Vec::new(); + + res.extend_from_slice(&self.c0.to_constraint_field()?); + res.extend_from_slice(&self.c1.to_constraint_field()?); + + Ok(res) + } +} + +impl CondSelectGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, + P: QuadExtVarConfig, +{ + #[inline] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let c0 = BF::conditionally_select(cond, &true_value.c0, &false_value.c0)?; + let c1 = BF::conditionally_select(cond, &true_value.c1, &false_value.c1)?; + Ok(Self::new(c0, c1)) + } +} + +impl TwoBitLookupGadget for QuadExtVar +where + BF: FieldVar + + TwoBitLookupGadget, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + type TableConstant = QuadExtField
<P>
; + + #[tracing::instrument(target = "r1cs")] + fn two_bit_lookup( + b: &[Boolean], + c: &[Self::TableConstant], + ) -> Result { + let c0s = c.iter().map(|f| f.c0).collect::>(); + let c1s = c.iter().map(|f| f.c1).collect::>(); + let c0 = BF::two_bit_lookup(b, &c0s)?; + let c1 = BF::two_bit_lookup(b, &c1s)?; + Ok(Self::new(c0, c1)) + } +} + +impl ThreeBitCondNegLookupGadget for QuadExtVar +where + BF: FieldVar + + ThreeBitCondNegLookupGadget, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + type TableConstant = QuadExtField
<P>
; + + #[tracing::instrument(target = "r1cs")] + fn three_bit_cond_neg_lookup( + b: &[Boolean], + b0b1: &Boolean, + c: &[Self::TableConstant], + ) -> Result { + let c0s = c.iter().map(|f| f.c0).collect::>(); + let c1s = c.iter().map(|f| f.c1).collect::>(); + let c0 = BF::three_bit_cond_neg_lookup(b, b0b1, &c0s)?; + let c1 = BF::three_bit_cond_neg_lookup(b, b0b1, &c1s)?; + Ok(Self::new(c0, c1)) + } +} + +impl AllocVar, ConstraintF> for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let (c0, c1) = match f() { + Ok(fe) => (Ok(fe.borrow().c0), Ok(fe.borrow().c1)), + Err(_) => ( + Err(SynthesisError::AssignmentMissing), + Err(SynthesisError::AssignmentMissing), + ), + }; + + let c0 = BF::new_variable(ark_relations::ns!(cs, "c0"), || c0, mode)?; + let c1 = BF::new_variable(ark_relations::ns!(cs, "c1"), || c1, mode)?; + Ok(Self::new(c0, c1)) + } +} diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs new file mode 100644 index 000000000..dbd7e00e0 --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -0,0 +1 @@ +pub mod short_weierstrass; diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12/mod.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12/mod.rs new file mode 100644 index 000000000..84b4fc25f --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12/mod.rs @@ -0,0 +1,314 @@ +use ark_ec::{ + bls12::{Bls12Config, G1Prepared, G2Prepared, TwistType}, + short_weierstrass::Affine as GroupAffine, +}; +use ark_ff::{BitIteratorBE, Field, One}; +use ark_relations::r1cs::{Namespace, SynthesisError}; + +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; +use ark_r1cs_std::prelude::*; +use ark_r1cs_std::{ + fields::{fp::FpVar, FieldVar}, + R1CSVar, +}; +use core::fmt::Debug; +use derivative::Derivative; + +use crate::circuits::fields::fp2::Fp2Var; +use crate::circuits::groups::curves::short_weierstrass::*; +use ark_std::vec::Vec; + +/// Represents a projective point in G1. +pub type G1Var = ProjectiveVar< +
<P as Bls12Config>
::G1Config, + ConstraintF, + NonNativeFieldVar<
<P as Bls12Config>
::Fp, ConstraintF>, +>; + +/// Represents an affine point on G1. Should be used only for comparison and +/// when a canonical representation of a point is required, and not for +/// arithmetic. +pub type G1AffineVar = AffineVar< +
<P as Bls12Config>
::G1Config, + ConstraintF, + NonNativeFieldVar<
<P as Bls12Config>
::Fp, ConstraintF>, +>; + +/// Represents a projective point in G2. +pub type G2Var = + ProjectiveVar<
<P as Bls12Config>
::G2Config, ConstraintF, Fp2G>; +/// Represents an affine point on G2. Should be used only for comparison and +/// when a canonical representation of a point is required, and not for +/// arithmetic. +pub type G2AffineVar = + AffineVar<
<P as Bls12Config>
::G2Config, ConstraintF, Fp2G>; + +/// Represents the cached precomputation that can be performed on a G1 element +/// which enables speeding up pairing computation. +#[derive(Derivative)] +#[derivative( + Clone(bound = "G1Var: Clone"), + Debug(bound = "G1Var: Debug") +)] +pub struct G1PreparedVar( + pub AffineVar>, +); + +impl G1PreparedVar { + /// Returns the value assigned to `self` in the underlying constraint + /// system. + pub fn value(&self) -> Result, SynthesisError> { + let x = self.0.x.value()?; + let y = self.0.y.value()?; + let infinity = self.0.infinity.value()?; + let g = infinity + .then_some(GroupAffine::identity()) + .unwrap_or(GroupAffine::new(x, y)) + .into(); + Ok(g) + } + + /// Constructs `Self` from a `G1Var`. + pub fn from_group_var(q: &G1Var) -> Result { + let g = q.to_affine()?; + Ok(Self(g)) + } +} + +impl AllocVar, ConstraintF> + for G1PreparedVar +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let g1_prep = f().map(|b| b.borrow().0); + + let x = NonNativeFieldVar::new_variable( + ark_relations::ns!(cs, "x"), + || g1_prep.map(|g| g.x), + mode, + )?; + let y = NonNativeFieldVar::new_variable( + ark_relations::ns!(cs, "y"), + || g1_prep.map(|g| g.y), + mode, + )?; + let infinity = Boolean::new_variable( + ark_relations::ns!(cs, "inf"), + || g1_prep.map(|g| g.infinity), + mode, + )?; + let g = AffineVar::new(x, y, infinity); + Ok(Self(g)) + } +} + +impl ToBytesGadget + for G1PreparedVar +{ + #[inline] + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let mut bytes = self.0.x.to_bytes()?; + let y_bytes = self.0.y.to_bytes()?; + let inf_bytes = self.0.infinity.to_bytes()?; + bytes.extend_from_slice(&y_bytes); + bytes.extend_from_slice(&inf_bytes); + Ok(bytes) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { + let mut bytes = self.0.x.to_non_unique_bytes()?; + let y_bytes = self.0.y.to_non_unique_bytes()?; + let inf_bytes = self.0.infinity.to_non_unique_bytes()?; + bytes.extend_from_slice(&y_bytes); + bytes.extend_from_slice(&inf_bytes); + Ok(bytes) + } +} + +type Fp2G = Fp2Var<
<P as Bls12Config>
::Fp2Config, ConstraintF>; +type LCoeff = (Fp2G, Fp2G); +/// Represents the cached precomputation that can be performed on a G2 element +/// which enables speeding up pairing computation. +#[derive(Derivative)] +#[derivative( + Clone(bound = "Fp2Var: Clone"), + Debug(bound = "Fp2Var: Debug") +)] +pub struct G2PreparedVar { + #[doc(hidden)] + pub ell_coeffs: Vec>, +} + +impl AllocVar, ConstraintF> for G2PreparedVar +where + P: Bls12Config, + ConstraintF: PrimeField, +{ + #[tracing::instrument(target = "r1cs", skip(cs, f, mode))] + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let g2_prep = f().map(|b| { + let projective_coeffs = &b.borrow().ell_coeffs; + match P::TWIST_TYPE { + TwistType::M => { + let mut z_s = projective_coeffs + .iter() + .map(|(_, _, z)| *z) + .collect::>(); + ark_ff::fields::batch_inversion(&mut z_s); + projective_coeffs + .iter() + .zip(z_s) + .map(|((x, y, _), z_inv)| (*x * &z_inv, *y * &z_inv)) + .collect::>() + } + TwistType::D => { + let mut z_s = projective_coeffs + .iter() + .map(|(z, ..)| *z) + .collect::>(); + ark_ff::fields::batch_inversion(&mut z_s); + projective_coeffs + .iter() + .zip(z_s) + .map(|((_, x, y), z_inv)| (*x * &z_inv, *y * &z_inv)) + .collect::>() + } + } + }); + + let l = Vec::new_variable( + ark_relations::ns!(cs, "l"), + || { + g2_prep + .clone() + .map(|c| c.iter().map(|(l, _)| *l).collect::>()) + }, + mode, + )?; + let r = Vec::new_variable( + ark_relations::ns!(cs, "r"), + || g2_prep.map(|c| c.iter().map(|(_, r)| *r).collect::>()), + mode, + )?; + let ell_coeffs = l.into_iter().zip(r).collect(); + Ok(Self { ell_coeffs }) + } +} + +impl ToBytesGadget for G2PreparedVar +where + P: Bls12Config, + ConstraintF: PrimeField, +{ + #[inline] + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let mut bytes = Vec::new(); + for coeffs in &self.ell_coeffs { + bytes.extend_from_slice(&coeffs.0.to_bytes()?); + bytes.extend_from_slice(&coeffs.1.to_bytes()?); + } + Ok(bytes) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { + let mut bytes = Vec::new(); + for coeffs in &self.ell_coeffs { + bytes.extend_from_slice(&coeffs.0.to_non_unique_bytes()?); + bytes.extend_from_slice(&coeffs.1.to_non_unique_bytes()?); + } + Ok(bytes) + } +} + +impl G2PreparedVar +where + P: Bls12Config, + ConstraintF: PrimeField, +{ + /// Constructs `Self` from a `G2Var`. + #[tracing::instrument(target = "r1cs")] + pub fn from_group_var(q: &G2Var) -> Result { + let q = q.to_affine()?; + let two_inv = P::Fp::one().double().inverse().unwrap(); + // Enforce that `q` is not the point at infinity. + q.infinity.enforce_not_equal(&Boolean::Constant(true))?; + let mut ell_coeffs = vec![]; + let mut r = q.clone(); + + for i in BitIteratorBE::new(P::X).skip(1) { + ell_coeffs.push(Self::double(&mut r, &two_inv)?); + + if i { + ell_coeffs.push(Self::add(&mut r, &q)?); + } + } + + Ok(Self { ell_coeffs }) + } + + #[tracing::instrument(target = "r1cs")] + fn double( + r: &mut G2AffineVar, + two_inv: &P::Fp, + ) -> Result, SynthesisError> { + let a = r.y.inverse()?; + let mut b = r.x.square()?; + let b_tmp = b.clone(); + b.mul_assign_by_base_field_constant(*two_inv); + b += &b_tmp; + + let c = &a * &b; + let d = r.x.double()?; + let x3 = c.square()? 
- &d; + let e = &c * &r.x - &r.y; + let c_x3 = &c * &x3; + let y3 = &e - &c_x3; + let mut f = c; + f.negate_in_place()?; + r.x = x3; + r.y = y3; + match P::TWIST_TYPE { + TwistType::M => Ok((e, f)), + TwistType::D => Ok((f, e)), + } + } + + #[tracing::instrument(target = "r1cs")] + fn add( + r: &mut G2AffineVar, + q: &G2AffineVar, + ) -> Result, SynthesisError> { + let a = (&q.x - &r.x).inverse()?; + let b = &q.y - &r.y; + let c = &a * &b; + let d = &r.x + &q.x; + let x3 = c.square()? - &d; + + let e = (&r.x - &x3) * &c; + let y3 = e - &r.y; + let g = &c * &r.x - &r.y; + let mut f = c; + f.negate_in_place()?; + r.x = x3; + r.y = y3; + match P::TWIST_TYPE { + TwistType::M => Ok((g, f)), + TwistType::D => Ok((f, g)), + } + } +} diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs new file mode 100644 index 000000000..f27170ec3 --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs @@ -0,0 +1,975 @@ +use ark_ec::{ + short_weierstrass::{ + Affine as SWAffine, Projective as SWProjective, SWCurveConfig as SWModelParameters, + }, + AffineRepr, CurveGroup, +}; +use ark_ff::{BigInteger, BitIteratorBE, Field, One, PrimeField, Zero}; +use ark_r1cs_std::impl_bounded_ops; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; +use ark_std::{borrow::Borrow, marker::PhantomData, ops::Mul}; +use derivative::Derivative; +use non_zero_affine::NonZeroAffineVar; + +use ark_r1cs_std::{fields::fp::FpVar, prelude::*, ToConstraintFieldGadget}; + +use ark_std::vec::Vec; +use binius_field::PackedField; + +/// This module provides a generic implementation of G1 and G2 for +/// the [\[BLS12]\]() family of bilinear groups. +pub mod bls12; + +/// This module provides a generic implementation of elliptic curve operations +/// for points on short-weierstrass curves in affine coordinates that **are +/// not** equal to zero. +/// +/// Note: this module is **unsafe** in general: it can synthesize unsatisfiable +/// or underconstrained constraint systems when a represented point _is_ equal +/// to zero. The [ProjectiveVar] gadget is the recommended way of working with +/// elliptic curve points. +pub mod non_zero_affine; +/// An implementation of arithmetic for Short Weierstrass curves that relies on +/// the complete formulae derived in the paper of +/// [[Renes, Costello, Batina 2015]](). +#[derive(Derivative)] +#[derivative(Debug, Clone)] +#[must_use] +pub struct ProjectiveVar< + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +> where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + /// The x-coordinate. + pub x: F, + /// The y-coordinate. + pub y: F, + /// The z-coordinate. + pub z: F, + #[derivative(Debug = "ignore")] + _params: PhantomData
<P>
, + #[derivative(Debug = "ignore")] + _constraint_f: PhantomData, +} + +/// An affine representation of a curve point. +#[derive(Derivative)] +#[derivative(Debug, Clone)] +#[must_use] +pub struct AffineVar< + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +> where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + /// The x-coordinate. + pub x: F, + /// The y-coordinate. + pub y: F, + /// Is `self` the point at infinity. + pub infinity: Boolean, + #[derivative(Debug = "ignore")] + _params: PhantomData
<P>
, + #[derivative(Debug = "ignore")] + _constraint_f: PhantomData, +} + +impl AffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + fn new(x: F, y: F, infinity: Boolean) -> Self { + Self { + x, + y, + infinity, + _params: PhantomData, + _constraint_f: PhantomData, + } + } + + /// Returns the value assigned to `self` in the underlying + /// constraint system. + pub fn value(&self) -> Result, SynthesisError> { + Ok(match self.infinity.value()? { + true => SWAffine::identity(), + false => SWAffine::new(self.x.value()?, self.y.value()?), + }) + } +} + +impl ToConstraintFieldGadget for AffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, + F: ToConstraintFieldGadget, +{ + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let mut res = Vec::>::new(); + + res.extend_from_slice(&self.x.to_constraint_field()?); + res.extend_from_slice(&self.y.to_constraint_field()?); + res.extend_from_slice(&self.infinity.to_constraint_field()?); + + Ok(res) + } +} + +impl R1CSVar for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + type Value = SWProjective
<P>
; + + fn cs(&self) -> ConstraintSystemRef { + self.x.cs().or(self.y.cs()).or(self.z.cs()) + } + + fn value(&self) -> Result { + let (x, y, z) = (self.x.value()?, self.y.value()?, self.z.value()?); + let result = if let Some(z_inv) = z.inverse() { + SWAffine::new(x * &z_inv, y * &z_inv) + } else { + SWAffine::identity() + }; + Ok(result.into()) + } +} + +impl ProjectiveVar +where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +{ + /// Constructs `Self` from an `(x, y, z)` coordinate triple. + pub fn new(x: F, y: F, z: F) -> Self { + Self { + x, + y, + z, + _params: PhantomData, + _constraint_f: PhantomData, + } + } + + /// Convert this point into affine form. + #[tracing::instrument(target = "r1cs")] + pub fn to_affine(&self) -> Result, SynthesisError> { + if self.is_constant() { + let point = self.value()?.into_affine(); + let x = F::new_constant(ConstraintSystemRef::None, point.x)?; + let y = F::new_constant(ConstraintSystemRef::None, point.y)?; + let infinity = Boolean::constant(point.infinity); + Ok(AffineVar::new(x, y, infinity)) + } else { + let cs = self.cs(); + let infinity = self.is_zero()?; + let zero_x = F::zero(); + let zero_y = F::one(); + // Allocate a variable whose value is either `self.z.inverse()` if the inverse + // exists, and is zero otherwise. + let z_inv = F::new_witness(ark_relations::ns!(cs, "z_inverse"), || { + Ok(self.z.value()?.inverse().unwrap_or_else(P::BaseField::zero)) + })?; + // The inverse exists if `!self.is_zero()`. + // This means that `z_inv * self.z = 1` if `self.is_not_zero()`, and + // `z_inv * self.z = 0` if `self.is_zero()`. + // + // Thus, `z_inv * self.z = !self.is_zero()`. + z_inv.mul_equals(&self.z, &F::from(infinity.not()))?; + + let non_zero_x = &self.x * &z_inv; + let non_zero_y = &self.y * &z_inv; + + let x = infinity.select(&zero_x, &non_zero_x)?; + let y = infinity.select(&zero_y, &non_zero_y)?; + + Ok(AffineVar::new(x, y, infinity)) + } + } + + /// Allocates a new variable without performing an on-curve check, which is + /// useful if the variable is known to be on the curve (eg., if the point + /// is a constant or is a public input). + #[tracing::instrument(target = "r1cs", skip(cs, f))] + pub fn new_variable_omit_on_curve_check( + cs: impl Into>, + f: impl FnOnce() -> Result, SynthesisError>, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + + let (x, y, z) = match f() { + Ok(ge) => { + let ge = ge.into_affine(); + if ge.is_zero() { + ( + Ok(P::BaseField::zero()), + Ok(P::BaseField::one()), + Ok(P::BaseField::zero()), + ) + } else { + (Ok(ge.x), Ok(ge.y), Ok(P::BaseField::one())) + } + } + _ => ( + Err(SynthesisError::AssignmentMissing), + Err(SynthesisError::AssignmentMissing), + Err(SynthesisError::AssignmentMissing), + ), + }; + + let x = F::new_variable(ark_relations::ns!(cs, "x"), || x, mode)?; + let y = F::new_variable(ark_relations::ns!(cs, "y"), || y, mode)?; + let z = F::new_variable(ark_relations::ns!(cs, "z"), || z, mode)?; + + Ok(Self::new(x, y, z)) + } + + /// Mixed addition, which is useful when `other = (x2, y2)` is known to have + /// z = 1. + #[tracing::instrument(target = "r1cs", skip(self, other))] + pub(crate) fn add_mixed( + &self, + other: &NonZeroAffineVar, + ) -> Result { + // Complete mixed addition formula from Renes-Costello-Batina 2015 + // Algorithm 2 + // (https://eprint.iacr.org/2015/1060). 
+ // Below, comments at the end of a line denote the corresponding + // step(s) of the algorithm + // + // Adapted from code in + // https://github.com/RustCrypto/elliptic-curves/blob/master/p256/src/arithmetic/projective.rs + let three_b = P::COEFF_B.double() + &P::COEFF_B; + let (x1, y1, z1) = (&self.x, &self.y, &self.z); + let (x2, y2) = (&other.x, &other.y); + + let xx = x1 * x2; // 1 + let yy = y1 * y2; // 2 + let xy_pairs = ((x1 + y1) * &(x2 + y2)) - (&xx + &yy); // 4, 5, 6, 7, 8 + let xz_pairs = (x2 * z1) + x1; // 8, 9 + let yz_pairs = (y2 * z1) + y1; // 10, 11 + + let axz = mul_by_coeff_a::(&xz_pairs); // 12 + + let bz3_part = &axz + z1 * three_b; // 13, 14 + + let yy_m_bz3 = &yy - &bz3_part; // 15 + let yy_p_bz3 = &yy + &bz3_part; // 16 + + let azz = mul_by_coeff_a::(z1); // 20 + let xx3_p_azz = xx.double().unwrap() + &xx + &azz; // 18, 19, 22 + + let bxz3 = &xz_pairs * three_b; // 21 + let b3_xz_pairs = mul_by_coeff_a::(&(&xx - &azz)) + &bxz3; // 23, 24, 25 + + let x = (&yy_m_bz3 * &xy_pairs) - &yz_pairs * &b3_xz_pairs; // 28,29, 30 + let y = (&yy_p_bz3 * &yy_m_bz3) + &xx3_p_azz * b3_xz_pairs; // 17, 26, 27 + let z = (&yy_p_bz3 * &yz_pairs) + xy_pairs * xx3_p_azz; // 31, 32, 33 + + Ok(ProjectiveVar::new(x, y, z)) + } + + /// Computes a scalar multiplication with a little-endian scalar of size + /// `P::ScalarField::MODULUS_BITS`. + #[tracing::instrument( + target = "r1cs", + skip(self, mul_result, multiple_of_power_of_two, bits) + )] + fn fixed_scalar_mul_le( + &self, + mul_result: &mut Self, + multiple_of_power_of_two: &mut NonZeroAffineVar, + bits: &[&Boolean], + ) -> Result<(), SynthesisError> { + let scalar_modulus_bits = ::MODULUS_BIT_SIZE as usize; + + assert!(scalar_modulus_bits >= bits.len()); + let split_len = ark_std::cmp::min(scalar_modulus_bits - 2, bits.len()); + let (affine_bits, proj_bits) = bits.split_at(split_len); + // Computes the standard little-endian double-and-add algorithm + // (Algorithm 3.26, Guide to Elliptic Curve Cryptography) + // + // We rely on *incomplete* affine formulae for partially computing this. + // However, we avoid exceptional edge cases because we partition the scalar + // into two chunks: one guaranteed to be less than p - 2, and the rest. + // We only use incomplete formulae for the first chunk, which means we avoid + // exceptions: + // + // `add_unchecked(a, b)` is incomplete when either `b.is_zero()`, or when + // `b = ±a`. During scalar multiplication, we don't hit either case: + // * `b = ±a`: `b = accumulator = k * a`, where `2 <= k < p - 1`. This implies + // that `k != p ± 1`, and so `b != (p ± 1) * a`. Because the group is finite, + // this in turn means that `b != ±a`, as required. + // * `a` or `b` is zero: for `a`, we handle the zero case after the loop; for + // `b`, notice that it is monotonically increasing, and furthermore, equals `k + // * a`, where `k != p = 0 mod p`. + + // Unlike normal double-and-add, here we start off with a non-zero + // `accumulator`, because `NonZeroAffineVar::add_unchecked` doesn't + // support addition with `zero`. In more detail, we initialize + // `accumulator` to be the initial value of `multiple_of_power_of_two`. + // This ensures that all unchecked additions of `accumulator` with later + // values of `multiple_of_power_of_two` are safe. However, to do this + // correctly, we need to perform two steps: + // * We must skip the LSB, and instead proceed assuming that it was 1. 
Later, we + // will conditionally subtract the initial value of `accumulator`: if LSB == + // 0: subtract initial_acc_value; else, subtract 0. + // * Because we are assuming the first bit, we must double + // `multiple_of_power_of_two`. + + let mut accumulator = multiple_of_power_of_two.clone(); + let initial_acc_value = accumulator.into_projective(); + + // The powers start at 2 (instead of 1) because we're skipping the first bit. + multiple_of_power_of_two.double_in_place()?; + + // As mentioned, we will skip the LSB, and will later handle it via a + // conditional subtraction. + for bit in affine_bits.iter().skip(1) { + if bit.is_constant() { + if *bit == &Boolean::TRUE { + accumulator = accumulator.add_unchecked(&multiple_of_power_of_two)?; + } + } else { + let temp = accumulator.add_unchecked(&multiple_of_power_of_two)?; + accumulator = bit.select(&temp, &accumulator)?; + } + multiple_of_power_of_two.double_in_place()?; + } + // Perform conditional subtraction: + + // We can convert to projective safely because the result is guaranteed to be + // non-zero by the condition on `affine_bits.len()`, and by the fact + // that `accumulator` is non-zero + let result = accumulator.into_projective(); + // If bits[0] is 0, then we have to subtract `self`; else, we subtract zero. + let subtrahend = bits[0].select(&Self::zero(), &initial_acc_value)?; + *mul_result += result - subtrahend; + + // Now, let's finish off the rest of the bits using our complete formulae + for bit in proj_bits { + if bit.is_constant() { + if *bit == &Boolean::TRUE { + *mul_result += &multiple_of_power_of_two.into_projective(); + } + } else { + let temp = &*mul_result + &multiple_of_power_of_two.into_projective(); + *mul_result = bit.select(&temp, &mul_result)?; + } + multiple_of_power_of_two.double_in_place()?; + } + Ok(()) + } +} + +impl CurveVar, ConstraintF> for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + fn zero() -> Self { + Self::new(F::zero(), F::one(), F::zero()) + } + + fn is_zero(&self) -> Result, SynthesisError> { + self.z.is_zero() + } + + fn constant(g: SWProjective
<P>
) -> Self { + let cs = ConstraintSystemRef::None; + Self::new_variable_omit_on_curve_check(cs, || Ok(g), AllocationMode::Constant).unwrap() + } + + #[tracing::instrument(target = "r1cs", skip(cs, f))] + fn new_variable_omit_prime_order_check( + cs: impl Into>, + f: impl FnOnce() -> Result, SynthesisError>, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + // Curve equation in projective form: + // E: Y² * Z = X³ + aX * Z² + bZ³ + // + // This can be re-written as + // E: Y² * Z - bZ³ = X³ + aX * Z² + // E: Z * (Y² - bZ²) = X * (X² + aZ²) + // so, compute X², Y², Z², + // compute temp = X * (X² + aZ²) + // check Z.mul_equals((Y² - bZ²), temp) + // + // A total of 5 multiplications + + let g = Self::new_variable_omit_on_curve_check(cs, f, mode)?; + + if mode != AllocationMode::Constant { + // Perform on-curve check. + let b = P::COEFF_B; + let a = P::COEFF_A; + + let x2 = g.x.square()?; + let y2 = g.y.square()?; + let z2 = g.z.square()?; + let t = &g.x * (x2 + &z2 * a); + + g.z.mul_equals(&(y2 - z2 * b), &t)?; + } + Ok(g) + } + + /// Enforce that `self` is in the prime-order subgroup. + /// + /// Does so by multiplying by the prime order, and checking that the result + /// is unchanged. + // TODO: at the moment this doesn't work, because the addition and doubling + // formulae are incomplete for even-order points. + #[tracing::instrument(target = "r1cs")] + fn enforce_prime_order(&self) -> Result<(), SynthesisError> { + unimplemented!("cannot enforce prime order"); + // let r_minus_1 = (-P::ScalarField::one()).into_bigint(); + + // let mut result = Self::zero(); + // for b in BitIteratorBE::without_leading_zeros(r_minus_1) { + // result.double_in_place()?; + + // if b { + // result += self; + // } + // } + // self.negate()?.enforce_equal(&result)?; + // Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn double_in_place(&mut self) -> Result<(), SynthesisError> { + // Complete doubling formula from Renes-Costello-Batina 2015 + // Algorithm 3 + // (https://eprint.iacr.org/2015/1060). + // Below, comments at the end of a line denote the corresponding + // step(s) of the algorithm + // + // Adapted from code in + // https://github.com/RustCrypto/elliptic-curves/blob/master/p256/src/arithmetic/projective.rs + let three_b = P::COEFF_B.double() + &P::COEFF_B; + + let xx = self.x.square()?; // 1 + let yy = self.y.square()?; // 2 + let zz = self.z.square()?; // 3 + let xy2 = (&self.x * &self.y).double()?; // 4, 5 + let xz2 = (&self.x * &self.z).double()?; // 6, 7 + + let axz2 = mul_by_coeff_a::(&xz2); // 8 + + let bzz3_part = &axz2 + &zz * three_b; // 9, 10 + let yy_m_bzz3 = &yy - &bzz3_part; // 11 + let yy_p_bzz3 = &yy + &bzz3_part; // 12 + let y_frag = yy_p_bzz3 * &yy_m_bzz3; // 13 + let x_frag = yy_m_bzz3 * &xy2; // 14 + + let bxz3 = xz2 * three_b; // 15 + let azz = mul_by_coeff_a::(&zz); // 16 + let b3_xz_pairs = mul_by_coeff_a::(&(&xx - &azz)) + &bxz3; // 15, 16, 17, 18, 19 + let xx3_p_azz = (xx.double()? 
+ &xx + &azz) * &b3_xz_pairs; // 23, 24, 25 + + let y = y_frag + &xx3_p_azz; // 26, 27 + let yz2 = (&self.y * &self.z).double()?; // 28, 29 + let x = x_frag - &(b3_xz_pairs * &yz2); // 30, 31 + let z = (yz2 * &yy).double()?.double()?; // 32, 33, 34 + self.x = x; + self.y = y; + self.z = z; + Ok(()) + } + + #[tracing::instrument(target = "r1cs")] + fn negate(&self) -> Result { + Ok(Self::new(self.x.clone(), self.y.negate()?, self.z.clone())) + } + + /// Computes `bits * self`, where `bits` is a little-endian + /// `Boolean` representation of a scalar. + #[tracing::instrument(target = "r1cs", skip(bits))] + fn scalar_mul_le<'a>( + &self, + bits: impl Iterator>, + ) -> Result { + if self.is_constant() { + if self.value().unwrap().is_zero() { + return Ok(self.clone()); + } + } + let self_affine = self.to_affine()?; + let (x, y, infinity) = (self_affine.x, self_affine.y, self_affine.infinity); + // We first handle the non-zero case, and then later + // will conditionally select zero if `self` was zero. + let non_zero_self = NonZeroAffineVar::new(x, y); + + let mut bits = bits.collect::>(); + if bits.len() == 0 { + return Ok(Self::zero()); + } + // Remove unnecessary constant zeros in the most-significant positions. + bits = bits + .into_iter() + // We iterate from the MSB down. + .rev() + // Skip leading zeros, if they are constants. + .skip_while(|b| b.is_constant() && (b.value().unwrap() == false)) + .collect(); + // After collecting we are in big-endian form; we have to reverse to get back to + // little-endian. + bits.reverse(); + + let scalar_modulus_bits = ::MODULUS_BIT_SIZE; + let mut mul_result = Self::zero(); + let mut power_of_two_times_self = non_zero_self; + // We chunk up `bits` into `p`-sized chunks. + for bits in bits.chunks(scalar_modulus_bits as usize) { + self.fixed_scalar_mul_le(&mut mul_result, &mut power_of_two_times_self, bits)?; + } + + // The foregoing algorithm relies on incomplete addition, and so does not + // work when the input (`self`) is zero. We hence have to perform + // a check to ensure that if the input is zero, then so is the output. + // The cost of this check should be less than the benefit of using + // mixed addition in almost all cases. + infinity.select(&Self::zero(), &mul_result) + } + + #[tracing::instrument(target = "r1cs", skip(scalar_bits_with_bases))] + fn precomputed_base_scalar_mul_le<'a, I, B>( + &mut self, + scalar_bits_with_bases: I, + ) -> Result<(), SynthesisError> + where + I: Iterator)>, + B: Borrow>, + { + // We just ignore the provided bases and use the faster scalar multiplication. + let (bits, bases): (Vec<_>, Vec<_>) = scalar_bits_with_bases + .map(|(b, c)| (b.borrow().clone(), *c)) + .unzip(); + let base = bases[0]; + *self = Self::constant(base).scalar_mul_le(bits.iter())?; + Ok(()) + } +} + +impl ToConstraintFieldGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, + F: ToConstraintFieldGadget, +{ + fn to_constraint_field(&self) -> Result>, SynthesisError> { + self.to_affine()?.to_constraint_field() + } +} + +fn mul_by_coeff_a< + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +>( + f: &F, +) -> F +where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + if !P::COEFF_A.is_zero() { + f * P::COEFF_A + } else { + F::zero() + } +} + +impl_bounded_ops!( + ProjectiveVar, + SWProjective
<P>
, + Add, + add, + AddAssign, + add_assign, + |mut this: &'a ProjectiveVar, mut other: &'a ProjectiveVar| { + // Implement complete addition for Short Weierstrass curves, following + // the complete addition formula from Renes-Costello-Batina 2015 + // (https://eprint.iacr.org/2015/1060). + // + // We special case handling of constants to get better constraint weight. + if this.is_constant() { + // we'll just act like `other` is constant. + core::mem::swap(&mut this, &mut other); + } + + if other.is_constant() { + // The value should exist because `other` is a constant. + let other = other.value().unwrap(); + if other.is_zero() { + // this + 0 = this + this.clone() + } else { + // We'll use mixed addition to add non-zero constants. + let x = F::constant(other.x); + let y = F::constant(other.y); + this.add_mixed(&NonZeroAffineVar::new(x, y)).unwrap() + } + } else { + // Complete addition formula from Renes-Costello-Batina 2015 + // Algorithm 1 + // (https://eprint.iacr.org/2015/1060). + // Below, comments at the end of a line denote the corresponding + // step(s) of the algorithm + // + // Adapted from code in + // https://github.com/RustCrypto/elliptic-curves/blob/master/p256/src/arithmetic/projective.rs + let three_b = P::COEFF_B.double() + &P::COEFF_B; + let (x1, y1, z1) = (&this.x, &this.y, &this.z); + let (x2, y2, z2) = (&other.x, &other.y, &other.z); + + let xx = x1 * x2; // 1 + let yy = y1 * y2; // 2 + let zz = z1 * z2; // 3 + let xy_pairs = ((x1 + y1) * &(x2 + y2)) - (&xx + &yy); // 4, 5, 6, 7, 8 + let xz_pairs = ((x1 + z1) * &(x2 + z2)) - (&xx + &zz); // 9, 10, 11, 12, 13 + let yz_pairs = ((y1 + z1) * &(y2 + z2)) - (&yy + &zz); // 14, 15, 16, 17, 18 + + let axz = mul_by_coeff_a::(&xz_pairs); // 19 + + let bzz3_part = &axz + &zz * three_b; // 20, 21 + + let yy_m_bzz3 = &yy - &bzz3_part; // 22 + let yy_p_bzz3 = &yy + &bzz3_part; // 23 + + let azz = mul_by_coeff_a::(&zz); + let xx3_p_azz = xx.double().unwrap() + &xx + &azz; // 25, 26, 27, 29 + + let bxz3 = &xz_pairs * three_b; // 28 + let b3_xz_pairs = mul_by_coeff_a::(&(&xx - &azz)) + &bxz3; // 30, 31, 32 + + let x = (&yy_m_bzz3 * &xy_pairs) - &yz_pairs * &b3_xz_pairs; // 35, 39, 40 + let y = (&yy_p_bzz3 * &yy_m_bzz3) + &xx3_p_azz * b3_xz_pairs; // 24, 36, 37, 38 + let z = (&yy_p_bzz3 * &yz_pairs) + xy_pairs * xx3_p_azz; // 41, 42, 43 + + ProjectiveVar::new(x, y, z) + } + + }, + |this: &'a ProjectiveVar, other: SWProjective
<P>
| { + this + ProjectiveVar::constant(other) + }, + (ConstraintF: PrimeField, F: FieldVar, P: SWModelParameters), + for <'b> &'b F: FieldOpsBounds<'b, P::BaseField, F>, +); + +impl_bounded_ops!( + ProjectiveVar, + SWProjective
<P>
, + Sub, + sub, + SubAssign, + sub_assign, + |this: &'a ProjectiveVar, other: &'a ProjectiveVar| this + other.negate().unwrap(), + |this: &'a ProjectiveVar, other: SWProjective
<P>
| this - ProjectiveVar::constant(other), + (ConstraintF: PrimeField, F: FieldVar, P: SWModelParameters), + for <'b> &'b F: FieldOpsBounds<'b, P::BaseField, F> +); + +impl<'a, P, ConstraintF, F> GroupOpsBounds<'a, SWProjective
<P>
, ProjectiveVar> + for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'b> &'b F: FieldOpsBounds<'b, P::BaseField, F>, +{ +} + +impl<'a, P, ConstraintF, F> GroupOpsBounds<'a, SWProjective
<P>
, ProjectiveVar> + for &'a ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'b> &'b F: FieldOpsBounds<'b, P::BaseField, F>, +{ +} + +impl CondSelectGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let x = cond.select(&true_value.x, &false_value.x)?; + let y = cond.select(&true_value.y, &false_value.y)?; + let z = cond.select(&true_value.z, &false_value.z)?; + + Ok(Self::new(x, y, z)) + } +} + +impl EqGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let x_equal = (&self.x * &other.z).is_eq(&(&other.x * &self.z))?; + let y_equal = (&self.y * &other.z).is_eq(&(&other.y * &self.z))?; + let coordinates_equal = x_equal.and(&y_equal)?; + let both_are_zero = self.is_zero()?.and(&other.is_zero()?)?; + both_are_zero.or(&coordinates_equal) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let x_equal = (&self.x * &other.z).is_eq(&(&other.x * &self.z))?; + let y_equal = (&self.y * &other.z).is_eq(&(&other.y * &self.z))?; + let coordinates_equal = x_equal.and(&y_equal)?; + let both_are_zero = self.is_zero()?.and(&other.is_zero()?)?; + both_are_zero + .or(&coordinates_equal)? + .conditional_enforce_equal(&Boolean::Constant(true), condition)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let is_equal = self.is_eq(other)?; + is_equal + .and(condition)? + .enforce_equal(&Boolean::Constant(false)) + } +} + +impl AllocVar, ConstraintF> for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + Self::new_variable( + cs, + || f().map(|b| SWProjective::from((*b.borrow()).clone())), + mode, + ) + } +} + +impl AllocVar, ConstraintF> for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let f = || Ok(*f()?.borrow()); + match mode { + AllocationMode::Constant => Self::new_variable_omit_prime_order_check(cs, f, mode), + AllocationMode::Input => Self::new_variable_omit_prime_order_check(cs, f, mode), + AllocationMode::Witness => { + // if cofactor.is_even(): + // divide until you've removed all even factors + // else: + // just directly use double and add. 
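+                // Illustrative sketch (hypothetical numbers, not any real curve): if
+                // COFACTOR were 12 = 2^2 * 3, the loop below would strip the factor
+                // 2^2, leaving `cofactor = 3` and `power_of_2 = 2`. Those doublings
+                // are cheap and are re-applied after allocation, so when the cofactor
+                // branch is taken, the double-and-add loop only has to process the
+                // odd part of the cofactor.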
+ let mut power_of_2: u32 = 0; + let mut cofactor = P::COFACTOR.to_vec(); + while cofactor[0] % 2 == 0 { + div2(&mut cofactor); + power_of_2 += 1; + } + + let cofactor_weight = BitIteratorBE::new(cofactor.as_slice()) + .filter(|b| *b) + .count(); + let modulus_minus_1 = (-P::ScalarField::one()).into_bigint(); // r - 1 + let modulus_minus_1_weight = + BitIteratorBE::new(modulus_minus_1).filter(|b| *b).count(); + + // We pick the most efficient method of performing the prime order check: + // If the cofactor has lower hamming weight than the scalar field's modulus, + // we first multiply by the inverse of the cofactor, and then, after allocating, + // multiply by the cofactor. This ensures the resulting point has no cofactors + // + // Else, we multiply by the scalar field's modulus and ensure that the result + // equals the identity. + + let (mut ge, iter) = if cofactor_weight < modulus_minus_1_weight { + let ge = Self::new_variable_omit_prime_order_check( + ark_relations::ns!(cs, "Witness without subgroup check with cofactor mul"), + || f().map(|g| g.borrow().into_affine().mul_by_cofactor_inv().into()), + mode, + )?; + ( + ge, + BitIteratorBE::without_leading_zeros(cofactor.as_slice()), + ) + } else { + let ge = Self::new_variable_omit_prime_order_check( + ark_relations::ns!(cs, "Witness without subgroup check with `r` check"), + || { + f().map(|g| { + let g = g.into_affine(); + let mut power_of_two = P::ScalarField::one().into_bigint(); + power_of_two.muln(power_of_2); + let power_of_two_inv = P::ScalarField::from_bigint(power_of_two) + .and_then(|n| n.inverse()) + .unwrap(); + g.mul(power_of_two_inv) + }) + }, + mode, + )?; + + ( + ge, + BitIteratorBE::without_leading_zeros(modulus_minus_1.as_ref()), + ) + }; + // Remove the even part of the cofactor + for _ in 0..power_of_2 { + ge.double_in_place()?; + } + + let mut result = Self::zero(); + for b in iter { + result.double_in_place()?; + + if b { + result += &ge + } + } + if cofactor_weight < modulus_minus_1_weight { + Ok(result) + } else { + ge.enforce_equal(&ge)?; + Ok(ge) + } + } + } + } +} + +#[inline] +fn div2(limbs: &mut [u64]) { + let mut t = 0; + for i in limbs.iter_mut().rev() { + let t2 = *i << 63; + *i >>= 1; + *i |= t; + t = t2; + } +} + +impl ToBitsGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bits_le(&self) -> Result>, SynthesisError> { + let g = self.to_affine()?; + let mut bits = g.x.to_bits_le()?; + let y_bits = g.y.to_bits_le()?; + bits.extend_from_slice(&y_bits); + bits.push(g.infinity); + Ok(bits) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { + let g = self.to_affine()?; + let mut bits = g.x.to_non_unique_bits_le()?; + let y_bits = g.y.to_non_unique_bits_le()?; + bits.extend_from_slice(&y_bits); + bits.push(g.infinity); + Ok(bits) + } +} + +impl ToBytesGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let g = self.to_affine()?; + let mut bytes = g.x.to_bytes()?; + let y_bytes = g.y.to_bytes()?; + let inf_bytes = g.infinity.to_bytes()?; + bytes.extend_from_slice(&y_bytes); + bytes.extend_from_slice(&inf_bytes); + Ok(bytes) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bytes(&self) 
-> Result>, SynthesisError> { + let g = self.to_affine()?; + let mut bytes = g.x.to_non_unique_bytes()?; + let y_bytes = g.y.to_non_unique_bytes()?; + let inf_bytes = g.infinity.to_non_unique_bytes()?; + bytes.extend_from_slice(&y_bytes); + bytes.extend_from_slice(&inf_bytes); + Ok(bytes) + } +} diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/non_zero_affine.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/non_zero_affine.rs new file mode 100644 index 000000000..120760fb2 --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/non_zero_affine.rs @@ -0,0 +1,399 @@ +use super::*; +use ark_ec::Group; +use ark_std::ops::Add; +use derivative::Derivative; + +/// An affine representation of a prime order curve point that is guaranteed +/// to *not* be the point at infinity. +#[derive(Derivative)] +#[derivative(Debug, Clone)] +#[must_use] +pub struct NonZeroAffineVar< + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +> where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + /// The x-coordinate. + pub x: F, + /// The y-coordinate. + pub y: F, + #[derivative(Debug = "ignore")] + _params: PhantomData
<P>
, + #[derivative(Debug = "ignore")] + _constraint_f: PhantomData, +} + +impl NonZeroAffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + pub fn new(x: F, y: F) -> Self { + Self { + x, + y, + _params: PhantomData, + _constraint_f: PhantomData, + } + } + + /// Converts self into a non-zero projective point. + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn into_projective(&self) -> ProjectiveVar { + ProjectiveVar::new(self.x.clone(), self.y.clone(), F::one()) + } + + /// Performs an addition without checking that other != ±self. + #[tracing::instrument(target = "r1cs", skip(self, other))] + pub fn add_unchecked(&self, other: &Self) -> Result { + if [self, other].is_constant() { + let result = self.value()?.add(other.value()?).into_affine(); + Ok(Self::new(F::constant(result.x), F::constant(result.y))) + } else { + let (x1, y1) = (&self.x, &self.y); + let (x2, y2) = (&other.x, &other.y); + // Then, + // slope lambda := (y2 - y1)/(x2 - x1); + // x3 = lambda^2 - x1 - x2; + // y3 = lambda * (x1 - x3) - y1 + let numerator = y2 - y1; + let denominator = x2 - x1; + // It's okay to use `unchecked` here, because the precondition of + // `add_unchecked` is that self != ±other, which means that + // `numerator` and `denominator` are both non-zero. + let lambda = numerator.mul_by_inverse_unchecked(&denominator)?; + let x3 = lambda.square()? - x1 - x2; + let y3 = lambda * &(x1 - &x3) - y1; + Ok(Self::new(x3, y3)) + } + } + + /// Doubles `self`. As this is a prime order curve point, + /// the output is guaranteed to not be the point at infinity. + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn double(&self) -> Result { + if [self].is_constant() { + let result = SWProjective::
<P>
::from(self.value()?) + .double() + .into_affine(); + // Panic if the result is zero. + assert!(!result.is_zero()); + Ok(Self::new(F::constant(result.x), F::constant(result.y))) + } else { + let (x1, y1) = (&self.x, &self.y); + let x1_sqr = x1.square()?; + // Then, + // tangent lambda := (3 * x1^2 + a) / (2 * y1); + // x3 = lambda^2 - 2x1 + // y3 = lambda * (x1 - x3) - y1 + let numerator = x1_sqr.double()? + &x1_sqr + P::COEFF_A; + let denominator = y1.double()?; + // It's okay to use `unchecked` here, because the precondition of `double` is + // that self != zero. + let lambda = numerator.mul_by_inverse_unchecked(&denominator)?; + let x3 = lambda.square()? - x1.double()?; + let y3 = lambda * &(x1 - &x3) - y1; + Ok(Self::new(x3, y3)) + } + } + + /// Computes `(self + other) + self`. This method requires only 5 + /// constraints, less than the 7 required when computing via + /// `self.double() + other`. + /// + /// This follows the formulae from [\[ELM03\]](https://arxiv.org/abs/math/0208038). + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn double_and_add_unchecked(&self, other: &Self) -> Result { + if [self].is_constant() || other.is_constant() { + self.double()?.add_unchecked(other) + } else { + // It's okay to use `unchecked` the precondition is that `self != ±other` (i.e. + // same logic as in `add_unchecked`) + let (x1, y1) = (&self.x, &self.y); + let (x2, y2) = (&other.x, &other.y); + + // Calculate self + other: + // slope lambda := (y2 - y1)/(x2 - x1); + // x3 = lambda^2 - x1 - x2; + // y3 = lambda * (x1 - x3) - y1 + let numerator = y2 - y1; + let denominator = x2 - x1; + let lambda_1 = numerator.mul_by_inverse_unchecked(&denominator)?; + + let x3 = lambda_1.square()? - x1 - x2; + + // Calculate final addition slope: + let lambda_2 = + (lambda_1 + y1.double()?.mul_by_inverse_unchecked(&(&x3 - x1))?).negate()?; + + let x4 = lambda_2.square()? - x1 - x3; + let y4 = lambda_2 * &(x1 - &x4) - y1; + Ok(Self::new(x4, y4)) + } + } + + /// Doubles `self` in place. + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn double_in_place(&mut self) -> Result<(), SynthesisError> { + *self = self.double()?; + Ok(()) + } +} + +impl R1CSVar for NonZeroAffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + type Value = SWAffine
<P>
; + + fn cs(&self) -> ConstraintSystemRef { + self.x.cs().or(self.y.cs()) + } + + fn value(&self) -> Result, SynthesisError> { + Ok(SWAffine::new(self.x.value()?, self.y.value()?)) + } +} + +impl CondSelectGadget for NonZeroAffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let x = cond.select(&true_value.x, &false_value.x)?; + let y = cond.select(&true_value.y, &false_value.y)?; + + Ok(Self::new(x, y)) + } +} + +impl EqGadget for NonZeroAffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let x_equal = self.x.is_eq(&other.x)?; + let y_equal = self.y.is_eq(&other.y)?; + x_equal.and(&y_equal) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let x_equal = self.x.is_eq(&other.x)?; + let y_equal = self.y.is_eq(&other.y)?; + let coordinates_equal = x_equal.and(&y_equal)?; + coordinates_equal.conditional_enforce_equal(&Boolean::Constant(true), condition)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn enforce_equal(&self, other: &Self) -> Result<(), SynthesisError> { + self.x.enforce_equal(&other.x)?; + self.y.enforce_equal(&other.y)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let is_equal = self.is_eq(other)?; + is_equal + .and(condition)? + .enforce_equal(&Boolean::Constant(false)) + } +} + +#[cfg(test)] +mod test_non_zero_affine { + use ark_ec::{models::short_weierstrass::SWCurveConfig, CurveGroup}; + use ark_r1cs_std::{ + alloc::AllocVar, + eq::EqGadget, + fields::fp::{AllocatedFp, FpVar}, + groups::{ + curves::short_weierstrass::{non_zero_affine::NonZeroAffineVar, ProjectiveVar}, + CurveVar, + }, + R1CSVar, + }; + use ark_relations::r1cs::ConstraintSystem; + use ark_std::{vec::Vec, One}; + use ark_test_curves::bls12_381::{g1::Config as G1Config, Fq}; + + #[test] + fn correctness_test_1() { + let cs = ConstraintSystem::::new_ref(); + + let x = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.x)).unwrap(), + ); + let y = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.y)).unwrap(), + ); + + // The following code uses `double` and `add` (`add_unchecked`) to compute + // (1 + 2 + ... 
+ 2^9) G + + let sum_a = { + let mut a = ProjectiveVar::>::new( + x.clone(), + y.clone(), + FpVar::Constant(Fq::one()), + ); + + let mut double_sequence = Vec::new(); + double_sequence.push(a.clone()); + + for _ in 1..10 { + a = a.double().unwrap(); + double_sequence.push(a.clone()); + } + + let mut sum = double_sequence[0].clone(); + for elem in double_sequence.iter().skip(1) { + sum = sum + elem; + } + + let sum = sum.value().unwrap().into_affine(); + (sum.x, sum.y) + }; + + let sum_b = { + let mut a = NonZeroAffineVar::>::new(x, y); + + let mut double_sequence = Vec::new(); + double_sequence.push(a.clone()); + + for _ in 1..10 { + a = a.double().unwrap(); + double_sequence.push(a.clone()); + } + + let mut sum = double_sequence[0].clone(); + for elem in double_sequence.iter().skip(1) { + sum = sum.add_unchecked(&elem).unwrap(); + } + + (sum.x.value().unwrap(), sum.y.value().unwrap()) + }; + + assert_eq!(sum_a.0, sum_b.0); + assert_eq!(sum_a.1, sum_b.1); + } + + #[test] + fn correctness_test_2() { + let cs = ConstraintSystem::::new_ref(); + + let x = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.x)).unwrap(), + ); + let y = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.y)).unwrap(), + ); + + // The following code tests `double_and_add`. + let sum_a = { + let a = ProjectiveVar::>::new( + x.clone(), + y.clone(), + FpVar::Constant(Fq::one()), + ); + + let mut cur = a.clone(); + cur.double_in_place().unwrap(); + for _ in 1..10 { + cur.double_in_place().unwrap(); + cur = cur + &a; + } + + let sum = cur.value().unwrap().into_affine(); + (sum.x, sum.y) + }; + + let sum_b = { + let a = NonZeroAffineVar::>::new(x, y); + + let mut cur = a.double().unwrap(); + for _ in 1..10 { + cur = cur.double_and_add_unchecked(&a).unwrap(); + } + + (cur.x.value().unwrap(), cur.y.value().unwrap()) + }; + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(sum_a.0, sum_b.0); + assert_eq!(sum_a.1, sum_b.1); + } + + #[test] + fn correctness_test_eq() { + let cs = ConstraintSystem::::new_ref(); + + let x = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.x)).unwrap(), + ); + let y = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.y)).unwrap(), + ); + + let a = NonZeroAffineVar::>::new(x, y); + + let n = 10; + + let a_multiples: Vec>> = + std::iter::successors(Some(a.clone()), |acc| Some(acc.add_unchecked(&a).unwrap())) + .take(n) + .collect(); + + let all_equal: Vec>> = (0..n / 2) + .map(|i| { + a_multiples[i] + .add_unchecked(&a_multiples[n - i - 1]) + .unwrap() + }) + .collect(); + + for i in 0..n - 1 { + a_multiples[i] + .enforce_not_equal(&a_multiples[i + 1]) + .unwrap(); + } + for i in 0..all_equal.len() - 1 { + all_equal[i].enforce_equal(&all_equal[i + 1]).unwrap(); + } + + assert!(cs.is_satisfied().unwrap()); + } +} diff --git a/jolt-core/src/circuits/groups/mod.rs b/jolt-core/src/circuits/groups/mod.rs new file mode 100644 index 000000000..26b097205 --- /dev/null +++ b/jolt-core/src/circuits/groups/mod.rs @@ -0,0 +1 @@ +pub mod curves; diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index 0d8caf690..b933bbbda 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -1,2 +1,4 @@ -mod pairing; +pub mod fields; +pub mod groups; +pub mod pairing; pub mod poly; diff --git a/jolt-core/src/circuits/pairing/bls12/mod.rs b/jolt-core/src/circuits/pairing/bls12/mod.rs new file mode 100644 index 000000000..0b3b51c97 --- /dev/null +++ 
b/jolt-core/src/circuits/pairing/bls12/mod.rs @@ -0,0 +1,178 @@ +use super::PairingGadget; +use crate::circuits::fields::fp2::Fp2Var; +use crate::circuits::groups::curves::short_weierstrass::bls12::{ + G1AffineVar, G1PreparedVar, G1Var, G2PreparedVar, G2Var, +}; +use ark_ec::bls12::{Bls12, Bls12Config, TwistType}; +use ark_ff::{BitIteratorBE, PrimeField}; +use ark_r1cs_std::fields::fp12::Fp12Var; +// use crate::circuits::fields::fp12::Fp12Var; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; +use ark_r1cs_std::prelude::FieldVar; +use ark_relations::r1cs::SynthesisError; +use ark_std::marker::PhantomData; + +/// Specifies the constraints for computing a pairing in a BLS12 bilinear group. +pub struct PairingVar(PhantomData<(P, ConstraintF)>) +where + P: Bls12Config, + ConstraintF: PrimeField; + +type Fp2V = Fp2Var<
<P as Bls12Config>
::Fp2Config, ConstraintF>; + +impl PairingVar +where + P: Bls12Config, + ConstraintF: PrimeField, +{ + // Evaluate the line function at point p. + #[tracing::instrument(target = "r1cs")] + fn ell( + f: &mut Fp12Var, + coeffs: &(Fp2V, Fp2V), + p: &G1AffineVar, + ) -> Result<(), SynthesisError> { + let zero = NonNativeFieldVar::::zero(); + + match P::TWIST_TYPE { + TwistType::M => { + let c0 = coeffs.0.clone(); + let mut c1 = coeffs.1.clone(); + let c2 = Fp2V::::new(p.y.clone(), zero); + + c1.c0 *= &p.x; + c1.c1 *= &p.x; + *f = f.mul_by_014(&c0, &c1, &c2)?; + Ok(()) + } + TwistType::D => { + let c0 = Fp2V::::new(p.y.clone(), zero); + let mut c1 = coeffs.0.clone(); + let c2 = coeffs.1.clone(); + + c1.c0 *= &p.x; + c1.c1 *= &p.x; + *f = f.mul_by_034(&c0, &c1, &c2)?; + Ok(()) + } + } + } + + #[tracing::instrument(target = "r1cs")] + fn exp_by_x(f: &Fp12Var) -> Result, SynthesisError> { + let mut result = f.optimized_cyclotomic_exp(P::X)?; + if P::X_IS_NEGATIVE { + result = result.unitary_inverse()?; + } + Ok(result) + } +} + +impl PairingGadget, ConstraintF> + for PairingVar +{ + type G1Var = G1Var; + type G2Var = G2Var; + type GTVar = Fp12Var; + type G1PreparedVar = G1PreparedVar; + type G2PreparedVar = G2PreparedVar; + + #[tracing::instrument(target = "r1cs")] + fn miller_loop( + ps: &[Self::G1PreparedVar], + qs: &[Self::G2PreparedVar], + ) -> Result { + let mut pairs = vec![]; + for (p, q) in ps.iter().zip(qs.iter()) { + pairs.push((p, q.ell_coeffs.iter())); + } + let mut f = Self::GTVar::one(); + + for i in BitIteratorBE::new(P::X).skip(1) { + f.square_in_place()?; + + for &mut (p, ref mut coeffs) in pairs.iter_mut() { + Self::ell(&mut f, coeffs.next().unwrap(), &p.0)?; + } + + if i { + for &mut (p, ref mut coeffs) in pairs.iter_mut() { + Self::ell(&mut f, &coeffs.next().unwrap(), &p.0)?; + } + } + } + + if P::X_IS_NEGATIVE { + f = f.unitary_inverse()?; + } + + Ok(f) + } + + #[tracing::instrument(target = "r1cs")] + fn final_exponentiation(f: &Self::GTVar) -> Result { + // Computing the final exponentation following + // https://eprint.iacr.org/2016/130.pdf. + // We don't use their "faster" formula because it is difficult to make + // it work for curves with odd `P::X`. + // Hence we implement the slower algorithm from Table 1 below. 
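+        // For reference, the full exponent (p^12 - 1)/r factors as
+        //     (p^6 - 1) * (p^2 + 1) * ((p^4 - p^2 + 1)/r),
+        // so the code below first computes the "easy part" f^((p^6 - 1)(p^2 + 1))
+        // with a conjugation, an inversion and a Frobenius map, and then evaluates
+        // the remaining "hard part" via the addition chain from Table 1.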
+ + let f1 = f.unitary_inverse()?; + + f.inverse().and_then(|mut f2| { + // f2 = f^(-1); + // r = f^(p^6 - 1) + let mut r = f1; + r *= &f2; + + // f2 = f^(p^6 - 1) + f2 = r.clone(); + // r = f^((p^6 - 1)(p^2)) + r.frobenius_map_in_place(2)?; + + // r = f^((p^6 - 1)(p^2) + (p^6 - 1)) + // r = f^((p^6 - 1)(p^2 + 1)) + r *= &f2; + + // Hard part of the final exponentation is below: + // From https://eprint.iacr.org/2016/130.pdf, Table 1 + let mut y0 = r.cyclotomic_square()?; + y0 = y0.unitary_inverse()?; + + let mut y5 = Self::exp_by_x(&r)?; + + let mut y1 = y5.cyclotomic_square()?; + let mut y3 = y0 * &y5; + y0 = Self::exp_by_x(&y3)?; + let y2 = Self::exp_by_x(&y0)?; + let mut y4 = Self::exp_by_x(&y2)?; + y4 *= &y1; + y1 = Self::exp_by_x(&y4)?; + y3 = y3.unitary_inverse()?; + y1 *= &y3; + y1 *= &r; + y3 = r.clone(); + y3 = y3.unitary_inverse()?; + y0 *= &r; + y0.frobenius_map_in_place(3)?; + y4 *= &y3; + y4.frobenius_map_in_place(1)?; + y5 *= &y2; + y5.frobenius_map_in_place(2)?; + y5 *= &y0; + y5 *= &y4; + y5 *= &y1; + Ok(y5) + }) + } + + #[tracing::instrument(target = "r1cs")] + fn prepare_g1(p: &Self::G1Var) -> Result { + Self::G1PreparedVar::from_group_var(p) + } + + #[tracing::instrument(target = "r1cs")] + fn prepare_g2(q: &Self::G2Var) -> Result { + Self::G2PreparedVar::from_group_var(q) + } +} diff --git a/jolt-core/src/circuits/pairing/mod.rs b/jolt-core/src/circuits/pairing/mod.rs index d24727022..d5f5ba30a 100644 --- a/jolt-core/src/circuits/pairing/mod.rs +++ b/jolt-core/src/circuits/pairing/mod.rs @@ -1,3 +1,5 @@ +pub mod bls12; + use ark_ec::pairing::Pairing; use ark_ff::PrimeField; use ark_r1cs_std::prelude::*; @@ -71,3 +73,71 @@ pub trait PairingGadget { /// Performs the precomputation to generate `Self::G2PreparedVar`. fn prepare_g2(q: &Self::G2Var) -> Result; } + +#[cfg(test)] +mod tests { + use ark_bls12_381::Bls12_381; + use ark_bn254::Bn254; + use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; + use ark_ec::pairing::Pairing; + use ark_ec::Group; + use ark_ff::{Field, PrimeField}; + use ark_groth16::Groth16; + use ark_r1cs_std::pairing::bls12::PairingVar; + use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; + use ark_std::marker::PhantomData; + use ark_std::rand::Rng; + use ark_std::test_rng; + use rand_core::{RngCore, SeedableRng}; + + struct PairingCheckCircuit { + _constraint_f: PhantomData, + r: Option, + r_g2: Option, + } + + impl ConstraintSynthesizer + for PairingCheckCircuit + { + fn generate_constraints( + self, + cs: ConstraintSystemRef, + ) -> Result<(), SynthesisError> { + // TODO use PairingVar to generate constraints + // let r_g1 = PairingVar::::new_input(cs.clone(), || { + // Ok(E::G1Projective::prime_subgroup_generator()) + // })?; + Ok(()) + } + } + + #[test] + fn test_pairing_check_circuit() { + let c = PairingCheckCircuit:: { + _constraint_f: PhantomData, + r: None, + r_g2: None, + }; + + // This is not cryptographically safe, use + // `OsRng` (for example) in production software. 
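+        // A possible (untested) sketch of that, assuming the `getrandom` feature of
+        // `rand_core` so that `rand_core::OsRng` is in scope:
+        //     let mut rng = ark_std::rand::rngs::StdRng::from_rng(rand_core::OsRng).unwrap();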
+ let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(test_rng().next_u64()); + + let (pk, vk) = Groth16::::setup(c, &mut rng).unwrap(); + + let pvk = Groth16::::process_vk(&vk).unwrap(); + + let r = rng.gen(); + let r_g2 = ::G2::generator() * &r; + + let c = PairingCheckCircuit:: { + _constraint_f: PhantomData, + r: Some(r), + r_g2: Some(r_g2.into()), + }; + + let proof = Groth16::::prove(&pk, c, &mut rng).unwrap(); + + assert!(Groth16::::verify_with_processed_vk(&pvk, &[], &proof).unwrap()); + } +} From 2143514a8b9c3e3fa1b5c1c76f8524ba19eaada5 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 28 Jul 2024 17:01:23 -0700 Subject: [PATCH 04/44] WIP: PairingGadget: it compiles! --- .../src/circuits/fields/cubic_extension.rs | 604 ++++++++++++++++++ jolt-core/src/circuits/fields/fp12.rs | 189 ++++++ jolt-core/src/circuits/fields/fp6_3over2.rs | 105 +++ jolt-core/src/circuits/fields/mod.rs | 4 +- .../circuits/fields/quadratic_extension.rs | 5 +- jolt-core/src/circuits/pairing/bls12/mod.rs | 12 +- jolt-core/src/circuits/pairing/mod.rs | 8 +- 7 files changed, 915 insertions(+), 12 deletions(-) create mode 100644 jolt-core/src/circuits/fields/cubic_extension.rs create mode 100644 jolt-core/src/circuits/fields/fp12.rs create mode 100644 jolt-core/src/circuits/fields/fp6_3over2.rs diff --git a/jolt-core/src/circuits/fields/cubic_extension.rs b/jolt-core/src/circuits/fields/cubic_extension.rs new file mode 100644 index 000000000..c66572c5e --- /dev/null +++ b/jolt-core/src/circuits/fields/cubic_extension.rs @@ -0,0 +1,604 @@ +use ark_ff::{ + fields::{CubicExtField, Field}, + CubicExtConfig, PrimeField, Zero, +}; +use ark_r1cs_std::{ + fields::{fp::FpVar, FieldOpsBounds, FieldVar}, + impl_bounded_ops, + prelude::*, + ToConstraintFieldGadget, +}; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; +use ark_std::vec::Vec; +use core::{borrow::Borrow, marker::PhantomData}; +use derivative::Derivative; + +/// This struct is the `R1CS` equivalent of the cubic extension field type +/// in `ark-ff`, i.e. `ark_ff::CubicExtField`. +#[derive(Derivative)] +#[derivative(Debug(bound = "BF: core::fmt::Debug"), Clone(bound = "BF: Clone"))] +#[must_use] +pub struct CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + /// The zero-th coefficient of this field element. + pub c0: BF, + /// The first coefficient of this field element. + pub c1: BF, + /// The second coefficient of this field element. + pub c2: BF, + #[derivative(Debug = "ignore")] + _params: PhantomData<(P, ConstraintF)>, +} + +/// This trait describes parameters that are used to implement arithmetic for +/// `CubicExtVar`. +pub trait CubicExtVarConfig: CubicExtConfig +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'a> &'a BF: FieldOpsBounds<'a, Self::BaseField, BF>, +{ + /// Multiply the base field of the `CubicExtVar` by the appropriate + /// Frobenius coefficient. This is equivalent to + /// `Self::mul_base_field_by_frob_coeff(c1, c2, power)`. + fn mul_base_field_vars_by_frob_coeff(c1: &mut BF, c2: &mut BF, power: usize); +} + +impl CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + /// Constructs a `CubicExtVar` from the underlying coefficients. 
+ #[inline] + pub fn new(c0: BF, c1: BF, c2: BF) -> Self { + let _params = PhantomData; + Self { + c0, + c1, + c2, + _params, + } + } + + /// Multiplies a variable of the base field by the cubic nonresidue + /// `P::NONRESIDUE` that is used to construct the extension field. + #[inline] + pub fn mul_base_field_by_nonresidue(fe: &BF) -> Result { + Ok(fe * P::NONRESIDUE) + } + + /// Multiplies `self` by a constant from the base field. + #[inline] + pub fn mul_by_base_field_constant(&self, fe: P::BaseField) -> Self { + let c0 = &self.c0 * fe; + let c1 = &self.c1 * fe; + let c2 = &self.c2 * fe; + Self::new(c0, c1, c2) + } + + /// Sets `self = self.mul_by_base_field_constant(fe)`. + #[inline] + pub fn mul_assign_by_base_field_constant(&mut self, fe: P::BaseField) { + *self = (&*self).mul_by_base_field_constant(fe); + } +} + +impl R1CSVar for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, + P: CubicExtVarConfig, +{ + type Value = CubicExtField
<P>
; + + fn cs(&self) -> ConstraintSystemRef { + [&self.c0, &self.c1, &self.c2].cs() + } + + #[inline] + fn value(&self) -> Result { + match (self.c0.value(), self.c1.value(), self.c2.value()) { + (Ok(c0), Ok(c1), Ok(c2)) => Ok(CubicExtField::new(c0, c1, c2)), + (..) => Err(SynthesisError::AssignmentMissing), + } + } +} + +impl From> for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn from(other: Boolean) -> Self { + let c0 = BF::from(other); + let c1 = BF::zero(); + let c2 = BF::zero(); + Self::new(c0, c1, c2) + } +} + +impl<'a, BF, ConstraintF, P> FieldOpsBounds<'a, CubicExtField
<P>
, CubicExtVar> + for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ +} +impl<'a, BF, ConstraintF, P> FieldOpsBounds<'a, CubicExtField
<P>
, CubicExtVar> + for &'a CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ +} + +impl FieldVar, ConstraintF> for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn zero() -> Self { + let c0 = BF::zero(); + let c1 = BF::zero(); + let c2 = BF::zero(); + Self::new(c0, c1, c2) + } + + fn one() -> Self { + let c0 = BF::one(); + let c1 = BF::zero(); + let c2 = BF::zero(); + Self::new(c0, c1, c2) + } + + fn constant(other: CubicExtField
<P>
) -> Self { + let c0 = BF::constant(other.c0); + let c1 = BF::constant(other.c1); + let c2 = BF::constant(other.c2); + Self::new(c0, c1, c2) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn double(&self) -> Result { + let c0 = self.c0.double()?; + let c1 = self.c1.double()?; + let c2 = self.c2.double()?; + Ok(Self::new(c0, c1, c2)) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn negate(&self) -> Result { + let mut result = self.clone(); + result.c0.negate_in_place()?; + result.c1.negate_in_place()?; + result.c2.negate_in_place()?; + Ok(result) + } + + /// Use the Chung-Hasan asymmetric squaring formula. + /// + /// (Devegili OhEig Scott Dahab --- Multiplication and Squaring on + /// Abstract Pairing-Friendly + /// Fields.pdf; Section 4 (CH-SQR2)) + #[inline] + #[tracing::instrument(target = "r1cs")] + fn square(&self) -> Result { + let a = self.c0.clone(); + let b = self.c1.clone(); + let c = self.c2.clone(); + + let s0 = a.square()?; + let ab = &a * &b; + let s1 = ab.double()?; + let s2 = (&a - &b + &c).square()?; + let s3 = (&b * &c).double()?; + let s4 = c.square()?; + + let c0 = Self::mul_base_field_by_nonresidue(&s3)? + &s0; + let c1 = Self::mul_base_field_by_nonresidue(&s4)? + &s1; + let c2 = s1 + &s2 + &s3 - &s0 - &s4; + + Ok(Self::new(c0, c1, c2)) + } + + #[tracing::instrument(target = "r1cs")] + fn mul_equals(&self, other: &Self, result: &Self) -> Result<(), SynthesisError> { + // Karatsuba multiplication for cubic extensions: + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // v2 = A.c2 * B.c2 + // result.c0 = v0 + β((a1 + a2)(b1 + b2) − v1 − v2) + // result.c1 = (a0 + a1)(b0 + b1) − v0 − v1 + βv2 + // result.c2 = (a0 + a2)(b0 + b2) − v0 + v1 − v2, + // We enforce this with six constraints: + // + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // v2 = A.c2 * B.c2 + // + // result.c0 - v0 + \beta*(v1 + v2) = β(a1 + a2)(b1 + b2)) + // result.c1 + v0 + v1 - βv2 = (a0 + a1)(b0 + b1) + // result.c2 + v0 - v1 + v2 = (a0 + a2)(b0 + b2) + // Reference: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + // + // This implementation adapted from + // https://github.com/ZencashOfficial/ginger-lib/blob/development/r1cs/gadgets/std/src/fields/fp3.rs + let v0 = &self.c0 * &other.c0; + let v1 = &self.c1 * &other.c1; + let v2 = &self.c2 * &other.c2; + + // Check c0 + let nr_a1_plus_a2 = (&self.c1 + &self.c2) * P::NONRESIDUE; + let b1_plus_b2 = &other.c1 + &other.c2; + let nr_v1 = &v1 * P::NONRESIDUE; + let nr_v2 = &v2 * P::NONRESIDUE; + let to_check = &result.c0 - &v0 + &nr_v1 + &nr_v2; + nr_a1_plus_a2.mul_equals(&b1_plus_b2, &to_check)?; + + // Check c1 + let a0_plus_a1 = &self.c0 + &self.c1; + let b0_plus_b1 = &other.c0 + &other.c1; + let to_check = &result.c1 - &nr_v2 + &v0 + &v1; + a0_plus_a1.mul_equals(&b0_plus_b1, &to_check)?; + + // Check c2 + let a0_plus_a2 = &self.c0 + &self.c2; + let b0_plus_b2 = &other.c0 + &other.c2; + let to_check = &result.c2 + &v0 - &v1 + &v2; + a0_plus_a2.mul_equals(&b0_plus_b2, &to_check)?; + Ok(()) + } + + #[tracing::instrument(target = "r1cs")] + fn inverse(&self) -> Result { + let mode = if self.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let inverse = Self::new_variable( + self.cs(), + || { + self.value() + .map(|f| f.inverse().unwrap_or_else(CubicExtField::zero)) + }, + mode, + )?; + self.mul_equals(&inverse, &Self::one())?; + Ok(inverse) + } + + #[tracing::instrument(target = "r1cs")] + fn frobenius_map(&self, power: usize) -> Result 
{ + let mut result = self.clone(); + result.c0.frobenius_map_in_place(power)?; + result.c1.frobenius_map_in_place(power)?; + result.c2.frobenius_map_in_place(power)?; + + P::mul_base_field_vars_by_frob_coeff(&mut result.c1, &mut result.c2, power); + Ok(result) + } +} + +impl_bounded_ops!( + CubicExtVar, + CubicExtField
<P>
, + Add, + add, + AddAssign, + add_assign, + |this: &'a CubicExtVar, other: &'a CubicExtVar| { + let c0 = &this.c0 + &other.c0; + let c1 = &this.c1 + &other.c1; + let c2 = &this.c2 + &other.c2; + CubicExtVar::new(c0, c1, c2) + }, + |this: &'a CubicExtVar, other: CubicExtField
<P>
| { + this + CubicExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: CubicExtVarConfig), + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +); +impl_bounded_ops!( + CubicExtVar, + CubicExtField
<P>
, + Sub, + sub, + SubAssign, + sub_assign, + |this: &'a CubicExtVar, other: &'a CubicExtVar| { + let c0 = &this.c0 - &other.c0; + let c1 = &this.c1 - &other.c1; + let c2 = &this.c2 - &other.c2; + CubicExtVar::new(c0, c1, c2) + }, + |this: &'a CubicExtVar, other: CubicExtField
<P>
| { + this - CubicExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: CubicExtVarConfig), + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +); +impl_bounded_ops!( + CubicExtVar, + CubicExtField
<P>
, + Mul, + mul, + MulAssign, + mul_assign, + |this: &'a CubicExtVar, other: &'a CubicExtVar| { + // Karatsuba multiplication for cubic extensions: + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // v2 = A.c2 * B.c2 + // result.c0 = v0 + β((a1 + a2)(b1 + b2) − v1 − v2) + // result.c1 = (a0 + a1)(b0 + b1) − v0 − v1 + βv2 + // result.c2 = (a0 + a2)(b0 + b2) − v0 + v1 − v2, + // + // Reference: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + let v0 = &this.c0 * &other.c0; + let v1 = &this.c1 * &other.c1; + let v2 = &this.c2 * &other.c2; + let c0 = + (((&this.c1 + &this.c2) * (&other.c1 + &other.c2) - &v1 - &v2) * P::NONRESIDUE) + &v0 ; + let c1 = + (&this.c0 + &this.c1) * (&other.c0 + &other.c1) - &v0 - &v1 + (&v2 * P::NONRESIDUE); + let c2 = + (&this.c0 + &this.c2) * (&other.c0 + &other.c2) - &v0 + &v1 - &v2; + + CubicExtVar::new(c0, c1, c2) + }, + |this: &'a CubicExtVar, other: CubicExtField
<P>
| { + this * CubicExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: CubicExtVarConfig), + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +); + +impl EqGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let b0 = self.c0.is_eq(&other.c0)?; + let b1 = self.c1.is_eq(&other.c1)?; + let b2 = self.c2.is_eq(&other.c2)?; + b0.and(&b1)?.and(&b2) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.c0.conditional_enforce_equal(&other.c0, condition)?; + self.c1.conditional_enforce_equal(&other.c1, condition)?; + self.c2.conditional_enforce_equal(&other.c2, condition)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let is_equal = self.is_eq(other)?; + is_equal + .and(condition)? + .enforce_equal(&Boolean::Constant(false)) + } +} + +impl ToBitsGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bits_le(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_bits_le()?; + let mut c1 = self.c1.to_bits_le()?; + let mut c2 = self.c2.to_bits_le()?; + c0.append(&mut c1); + c0.append(&mut c2); + Ok(c0) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_non_unique_bits_le()?; + let mut c1 = self.c1.to_non_unique_bits_le()?; + let mut c2 = self.c2.to_non_unique_bits_le()?; + c0.append(&mut c1); + c0.append(&mut c2); + Ok(c0) + } +} + +impl ToBytesGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_bytes()?; + let mut c1 = self.c1.to_bytes()?; + let mut c2 = self.c2.to_bytes()?; + c0.append(&mut c1); + c0.append(&mut c2); + + Ok(c0) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_non_unique_bytes()?; + let mut c1 = self.c1.to_non_unique_bytes()?; + let mut c2 = self.c2.to_non_unique_bytes()?; + + c0.append(&mut c1); + c0.append(&mut c2); + + Ok(c0) + } +} + +impl ToConstraintFieldGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + BF: ToConstraintFieldGadget, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let mut res = Vec::new(); + + res.extend_from_slice(&self.c0.to_constraint_field()?); + res.extend_from_slice(&self.c1.to_constraint_field()?); + res.extend_from_slice(&self.c2.to_constraint_field()?); + + Ok(res) + } +} + +impl CondSelectGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + 
true_value: &Self, + false_value: &Self, + ) -> Result { + let c0 = BF::conditionally_select(cond, &true_value.c0, &false_value.c0)?; + let c1 = BF::conditionally_select(cond, &true_value.c1, &false_value.c1)?; + let c2 = BF::conditionally_select(cond, &true_value.c2, &false_value.c2)?; + Ok(Self::new(c0, c1, c2)) + } +} + +impl TwoBitLookupGadget for CubicExtVar +where + BF: FieldVar + + TwoBitLookupGadget, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + type TableConstant = CubicExtField
<P>
; + + #[tracing::instrument(target = "r1cs")] + fn two_bit_lookup( + b: &[Boolean], + c: &[Self::TableConstant], + ) -> Result { + let c0s = c.iter().map(|f| f.c0).collect::>(); + let c1s = c.iter().map(|f| f.c1).collect::>(); + let c2s = c.iter().map(|f| f.c2).collect::>(); + let c0 = BF::two_bit_lookup(b, &c0s)?; + let c1 = BF::two_bit_lookup(b, &c1s)?; + let c2 = BF::two_bit_lookup(b, &c2s)?; + Ok(Self::new(c0, c1, c2)) + } +} + +impl ThreeBitCondNegLookupGadget + for CubicExtVar +where + BF: FieldVar + + ThreeBitCondNegLookupGadget, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + type TableConstant = CubicExtField
<P>
; + + #[tracing::instrument(target = "r1cs")] + fn three_bit_cond_neg_lookup( + b: &[Boolean], + b0b1: &Boolean, + c: &[Self::TableConstant], + ) -> Result { + let c0s = c.iter().map(|f| f.c0).collect::>(); + let c1s = c.iter().map(|f| f.c1).collect::>(); + let c2s = c.iter().map(|f| f.c2).collect::>(); + let c0 = BF::three_bit_cond_neg_lookup(b, b0b1, &c0s)?; + let c1 = BF::three_bit_cond_neg_lookup(b, b0b1, &c1s)?; + let c2 = BF::three_bit_cond_neg_lookup(b, b0b1, &c2s)?; + Ok(Self::new(c0, c1, c2)) + } +} + +impl AllocVar, ConstraintF> for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + + use SynthesisError::*; + let (c0, c1, c2) = match f() { + Ok(fe) => (Ok(fe.borrow().c0), Ok(fe.borrow().c1), Ok(fe.borrow().c2)), + Err(_) => ( + Err(AssignmentMissing), + Err(AssignmentMissing), + Err(AssignmentMissing), + ), + }; + + let c0 = BF::new_variable(ark_relations::ns!(cs, "c0"), || c0, mode)?; + let c1 = BF::new_variable(ark_relations::ns!(cs, "c1"), || c1, mode)?; + let c2 = BF::new_variable(ark_relations::ns!(cs, "c2"), || c2, mode)?; + Ok(Self::new(c0, c1, c2)) + } +} diff --git a/jolt-core/src/circuits/fields/fp12.rs b/jolt-core/src/circuits/fields/fp12.rs new file mode 100644 index 000000000..017deefe1 --- /dev/null +++ b/jolt-core/src/circuits/fields/fp12.rs @@ -0,0 +1,189 @@ +use crate::circuits::fields::{fp2::Fp2Var, fp6_3over2::Fp6Var, quadratic_extension::*}; +use ark_ff::{ + fields::{fp12_2over3over2::*, Field}, + fp6_3over2::Fp6Config, + PrimeField, QuadExtConfig, +}; +use ark_r1cs_std::fields::FieldVar; +use ark_relations::r1cs::SynthesisError; + +/// A degree-12 extension field constructed as the tower of a +/// quadratic extension over a cubic extension over a quadratic extension field. +/// This is the R1CS equivalent of `ark_ff::fp12_2over3over2::Fp12
<P>
`. +pub type Fp12Var = QuadExtVar< + Fp6Var<
<P as Fp12Config>
::Fp6Config, ConstraintF>, + ConstraintF, + Fp12ConfigWrapper
<P>
, +>; + +type Fp2Config
<P>
= <
<P as Fp12Config>
::Fp6Config as Fp6Config>::Fp2Config; + +impl QuadExtVarConfig, ConstraintF> + for Fp12ConfigWrapper
<P>
+where + P: Fp12Config, + ConstraintF: PrimeField, +{ + fn mul_base_field_var_by_frob_coeff(fe: &mut Fp6Var, power: usize) { + fe.c0 *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + fe.c1 *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + fe.c2 *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + } +} + +impl Fp12Var +where + P: Fp12Config, + ConstraintF: PrimeField, +{ + /// Multiplies by a sparse element of the form `(c0 = (c0, c1, 0), c1 = (0, + /// d1, 0))`. + #[inline] + pub fn mul_by_014( + &self, + c0: &Fp2Var, ConstraintF>, + c1: &Fp2Var, ConstraintF>, + d1: &Fp2Var, ConstraintF>, + ) -> Result { + let v0 = self.c0.mul_by_c0_c1_0(&c0, &c1)?; + let v1 = self.c1.mul_by_0_c1_0(&d1)?; + let new_c0 = Self::mul_base_field_by_nonresidue(&v1)? + &v0; + + let new_c1 = (&self.c0 + &self.c1).mul_by_c0_c1_0(&c0, &(c1 + d1))? - &v0 - &v1; + Ok(Self::new(new_c0, new_c1)) + } + + /// Multiplies by a sparse element of the form `(c0 = (c0, 0, 0), c1 = (d0, + /// d1, 0))`. + #[inline] + pub fn mul_by_034( + &self, + c0: &Fp2Var, ConstraintF>, + d0: &Fp2Var, ConstraintF>, + d1: &Fp2Var, ConstraintF>, + ) -> Result { + let a0 = &self.c0.c0 * c0; + let a1 = &self.c0.c1 * c0; + let a2 = &self.c0.c2 * c0; + let a = Fp6Var::new(a0, a1, a2); + let b = self.c1.mul_by_c0_c1_0(&d0, &d1)?; + + let c0 = c0 + d0; + let c1 = d1; + let e = (&self.c0 + &self.c1).mul_by_c0_c1_0(&c0, &c1)?; + let new_c1 = e - (&a + &b); + let new_c0 = Self::mul_base_field_by_nonresidue(&b)? + &a; + + Ok(Self::new(new_c0, new_c1)) + } + + /// Squares `self` when `self` is in the cyclotomic subgroup. + pub fn cyclotomic_square(&self) -> Result { + if characteristic_square_mod_6_is_one(Fp12::
<P>
::characteristic()) { + let fp2_nr = ::NONRESIDUE; + + let z0 = &self.c0.c0; + let z4 = &self.c0.c1; + let z3 = &self.c0.c2; + let z2 = &self.c1.c0; + let z1 = &self.c1.c1; + let z5 = &self.c1.c2; + + // t0 + t1*y = (z0 + z1*y)^2 = a^2 + let tmp = z0 * z1; + let t0 = { + let tmp1 = z0 + z1; + let tmp2 = z1 * fp2_nr + z0; + let tmp4 = &tmp * fp2_nr + &tmp; + tmp1 * tmp2 - tmp4 + }; + let t1 = tmp.double()?; + + // t2 + t3*y = (z2 + z3*y)^2 = b^2 + let tmp = z2 * z3; + let t2 = { + // (z2 + &z3) * &(z2 + &(fp2_nr * &z3)) - &tmp - &(tmp * &fp2_nr); + let tmp1 = z2 + z3; + let tmp2 = z3 * fp2_nr + z2; + let tmp4 = &tmp * fp2_nr + &tmp; + tmp1 * tmp2 - tmp4 + }; + let t3 = tmp.double()?; + + // t4 + t5*y = (z4 + z5*y)^2 = c^2 + let tmp = z4 * z5; + let t4 = { + // (z4 + &z5) * &(z4 + &(fp2_nr * &z5)) - &tmp - &(tmp * &fp2_nr); + let tmp1 = z4 + z5; + let tmp2 = (z5 * fp2_nr) + z4; + let tmp4 = (&tmp * fp2_nr) + &tmp; + (tmp1 * tmp2) - tmp4 + }; + let t5 = tmp.double()?; + + // for A + + // z0 = 3 * t0 - 2 * z0 + let c0_c0 = (&t0 - z0).double()? + &t0; + + // z1 = 3 * t1 + 2 * z1 + let c1_c1 = (&t1 + z1).double()? + &t1; + + // for B + + // z2 = 3 * (xi * t5) + 2 * z2 + let c1_c0 = { + let tmp = &t5 * fp2_nr; + (z2 + &tmp).double()? + &tmp + }; + + // z3 = 3 * t4 - 2 * z3 + let c0_c2 = (&t4 - z3).double()? + &t4; + + // for C + + // z4 = 3 * t2 - 2 * z4 + let c0_c1 = (&t2 - z4).double()? + &t2; + + // z5 = 3 * t3 + 2 * z5 + let c1_c2 = (&t3 + z5).double()? + &t3; + let c0 = Fp6Var::new(c0_c0, c0_c1, c0_c2); + let c1 = Fp6Var::new(c1_c0, c1_c1, c1_c2); + + Ok(Self::new(c0, c1)) + } else { + self.square() + } + } + + /// Like `Self::cyclotomic_exp`, but additionally uses cyclotomic squaring. + pub fn optimized_cyclotomic_exp( + &self, + exponent: impl AsRef<[u64]>, + ) -> Result { + use ark_ff::biginteger::arithmetic::find_naf; + let mut res = Self::one(); + let self_inverse = self.unitary_inverse()?; + + let mut found_nonzero = false; + let naf = find_naf(exponent.as_ref()); + + for &value in naf.iter().rev() { + if found_nonzero { + res = res.cyclotomic_square()?; + } + + if value != 0 { + found_nonzero = true; + + if value > 0 { + res *= self; + } else { + res *= &self_inverse; + } + } + } + + Ok(res) + } +} diff --git a/jolt-core/src/circuits/fields/fp6_3over2.rs b/jolt-core/src/circuits/fields/fp6_3over2.rs new file mode 100644 index 000000000..5d600c93e --- /dev/null +++ b/jolt-core/src/circuits/fields/fp6_3over2.rs @@ -0,0 +1,105 @@ +use crate::circuits::fields::{cubic_extension::*, fp2::*}; +use ark_ff::{ + fields::{fp6_3over2::*, Fp2}, + CubicExtConfig, PrimeField, +}; +use ark_relations::r1cs::SynthesisError; +use ark_std::ops::MulAssign; + +/// A sextic extension field constructed as the tower of a +/// cubic extension over a quadratic extension field. +/// This is the R1CS equivalent of `ark_ff::fp6_3over3::Fp6
<P>`.
+pub type Fp6Var<P, ConstraintF> =
+    CubicExtVar<Fp2Var<<P as Fp6Config>::Fp2Config, ConstraintF>, ConstraintF, Fp6ConfigWrapper<P>>;
+
+impl<P, ConstraintF> CubicExtVarConfig<Fp2Var<P::Fp2Config, ConstraintF>, ConstraintF>
+    for Fp6ConfigWrapper<P>
+where + P: Fp6Config, + ConstraintF: PrimeField, +{ + fn mul_base_field_vars_by_frob_coeff( + c1: &mut Fp2Var, + c2: &mut Fp2Var, + power: usize, + ) { + *c1 *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + *c2 *= Self::FROBENIUS_COEFF_C2[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + } +} + +impl Fp6Var +where + P: Fp6Config, + ConstraintF: PrimeField, +{ + /// Multiplies `self` by a sparse element which has `c0 == c2 == zero`. + pub fn mul_by_0_c1_0( + &self, + c1: &Fp2Var, + ) -> Result { + // Karatsuba multiplication + // v0 = a0 * b0 = 0 + + // v1 = a1 * b1 + let v1 = &self.c1 * c1; + + // v2 = a2 * b2 = 0 + + let a1_plus_a2 = &self.c1 + &self.c2; + let b1_plus_b2 = c1.clone(); + + let a0_plus_a1 = &self.c0 + &self.c1; + + // c0 = (NONRESIDUE * ((a1 + a2)*(b1 + b2) - v1 - v2)) + v0 + // = NONRESIDUE * ((a1 + a2) * b1 - v1) + let c0 = &(a1_plus_a2 * &b1_plus_b2 - &v1) * P::NONRESIDUE; + + // c1 = (a0 + a1) * (b0 + b1) - v0 - v1 + NONRESIDUE * v2 + // = (a0 + a1) * b1 - v1 + let c1 = a0_plus_a1 * c1 - &v1; + // c2 = (a0 + a2) * (b0 + b2) - v0 - v2 + v1 + // = v1 + let c2 = v1; + Ok(Self::new(c0, c1, c2)) + } + + /// Multiplies `self` by a sparse element which has `c2 == zero`. + pub fn mul_by_c0_c1_0( + &self, + c0: &Fp2Var, + c1: &Fp2Var, + ) -> Result { + let v0 = &self.c0 * c0; + let v1 = &self.c1 * c1; + // v2 = 0. + + let a1_plus_a2 = &self.c1 + &self.c2; + let a0_plus_a1 = &self.c0 + &self.c1; + let a0_plus_a2 = &self.c0 + &self.c2; + + let b1_plus_b2 = c1.clone(); + let b0_plus_b1 = c0 + c1; + let b0_plus_b2 = c0.clone(); + + let c0 = (&a1_plus_a2 * &b1_plus_b2 - &v1) * P::NONRESIDUE + &v0; + + let c1 = a0_plus_a1 * &b0_plus_b1 - &v0 - &v1; + + let c2 = a0_plus_a2 * &b0_plus_b2 - &v0 + &v1; + + Ok(Self::new(c0, c1, c2)) + } +} + +impl MulAssign> for Fp6Var +where + P: Fp6Config, + ConstraintF: PrimeField, +{ + fn mul_assign(&mut self, other: Fp2) { + self.c0 *= other; + self.c1 *= other; + self.c2 *= other; + } +} diff --git a/jolt-core/src/circuits/fields/mod.rs b/jolt-core/src/circuits/fields/mod.rs index 90f54f974..2e88f53b8 100644 --- a/jolt-core/src/circuits/fields/mod.rs +++ b/jolt-core/src/circuits/fields/mod.rs @@ -1,3 +1,5 @@ -// pub mod fp12; +mod cubic_extension; +pub mod fp12; pub mod fp2; +mod fp6_3over2; pub mod quadratic_extension; diff --git a/jolt-core/src/circuits/fields/quadratic_extension.rs b/jolt-core/src/circuits/fields/quadratic_extension.rs index 40be6e238..84c4028cb 100644 --- a/jolt-core/src/circuits/fields/quadratic_extension.rs +++ b/jolt-core/src/circuits/fields/quadratic_extension.rs @@ -31,9 +31,7 @@ where /// The first coefficient of this field element. pub c1: BF, #[derivative(Debug = "ignore")] - _params: PhantomData
<P>
, - #[derivative(Debug = "ignore")] - _constraint_f: PhantomData, + _params: PhantomData<(P, ConstraintF)>, } /// This trait describes parameters that are used to implement arithmetic for @@ -63,7 +61,6 @@ where c0, c1, _params: PhantomData, - _constraint_f: PhantomData, } } diff --git a/jolt-core/src/circuits/pairing/bls12/mod.rs b/jolt-core/src/circuits/pairing/bls12/mod.rs index 0b3b51c97..ca30b9b76 100644 --- a/jolt-core/src/circuits/pairing/bls12/mod.rs +++ b/jolt-core/src/circuits/pairing/bls12/mod.rs @@ -5,8 +5,8 @@ use crate::circuits::groups::curves::short_weierstrass::bls12::{ }; use ark_ec::bls12::{Bls12, Bls12Config, TwistType}; use ark_ff::{BitIteratorBE, PrimeField}; -use ark_r1cs_std::fields::fp12::Fp12Var; -// use crate::circuits::fields::fp12::Fp12Var; +// use ark_r1cs_std::fields::fp12::Fp12Var; +use crate::circuits::fields::fp12::Fp12Var; use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; use ark_r1cs_std::prelude::FieldVar; use ark_relations::r1cs::SynthesisError; @@ -28,7 +28,7 @@ where // Evaluate the line function at point p. #[tracing::instrument(target = "r1cs")] fn ell( - f: &mut Fp12Var, + f: &mut Fp12Var, coeffs: &(Fp2V, Fp2V), p: &G1AffineVar, ) -> Result<(), SynthesisError> { @@ -59,7 +59,9 @@ where } #[tracing::instrument(target = "r1cs")] - fn exp_by_x(f: &Fp12Var) -> Result, SynthesisError> { + fn exp_by_x( + f: &Fp12Var, + ) -> Result, SynthesisError> { let mut result = f.optimized_cyclotomic_exp(P::X)?; if P::X_IS_NEGATIVE { result = result.unitary_inverse()?; @@ -73,7 +75,7 @@ impl PairingGadget, Constraint { type G1Var = G1Var; type G2Var = G2Var; - type GTVar = Fp12Var; + type GTVar = Fp12Var; type G1PreparedVar = G1PreparedVar; type G2PreparedVar = G2PreparedVar; diff --git a/jolt-core/src/circuits/pairing/mod.rs b/jolt-core/src/circuits/pairing/mod.rs index d5f5ba30a..818c24ff9 100644 --- a/jolt-core/src/circuits/pairing/mod.rs +++ b/jolt-core/src/circuits/pairing/mod.rs @@ -11,11 +11,15 @@ use ark_std::fmt::Debug; pub trait PairingGadget { /// A variable representing an element of `G1`. /// This is the R1CS equivalent of `E::G1Projective`. - type G1Var: CurveVar; + type G1Var: CurveVar + + AllocVar + + AllocVar; /// A variable representing an element of `G2`. /// This is the R1CS equivalent of `E::G2Projective`. - type G2Var: CurveVar; + type G2Var: CurveVar + + AllocVar + + AllocVar; /// A variable representing an element of `GT`. /// This is the R1CS equivalent of `E::GT`. From 99313853113eb69a05e79bfb2d5660338d8c8e0a Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 28 Jul 2024 19:59:49 -0700 Subject: [PATCH 05/44] WIP: PairingGadget: test circuit compiles and runs (extremely slowly) We don't know how many constraints yet. 
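Note (not part of the patch series): one way to measure the circuit size without running the full Groth16 setup is to synthesize the circuit on a bare constraint system in setup mode, where witness closures are not evaluated. A minimal sketch, assuming the `DemoCircuit` alias defined in the test below (its `PairingGadget` type arguments are elided here):

    use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystem, SynthesisMode};

    fn count_constraints() {
        // Setup mode: `None` witness assignments are fine, only the constraint shape is recorded.
        let cs = ConstraintSystem::<ark_bn254::Fr>::new_ref();
        cs.set_mode(SynthesisMode::Setup);

        let circuit = DemoCircuit {
            r: None,
            r_g2: None,
            _params: core::marker::PhantomData,
        };
        circuit.generate_constraints(cs.clone()).unwrap();

        println!("num_constraints = {}", cs.num_constraints());
    }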
--- jolt-core/Cargo.toml | 1 + jolt-core/src/circuits/pairing/bls12/mod.rs | 12 ++- jolt-core/src/circuits/pairing/mod.rs | 82 ++++++++++++++++----- 3 files changed, 71 insertions(+), 24 deletions(-) diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index b511c8df7..182b319fb 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -103,6 +103,7 @@ default = [ "rayon", ] host = ["dep:reqwest", "dep:tokio"] +print-trace = [ "ark-std/print-trace" ] [target.'cfg(not(target_arch = "wasm32"))'.dependencies] memory-stats = "1.0.0" diff --git a/jolt-core/src/circuits/pairing/bls12/mod.rs b/jolt-core/src/circuits/pairing/bls12/mod.rs index ca30b9b76..819590835 100644 --- a/jolt-core/src/circuits/pairing/bls12/mod.rs +++ b/jolt-core/src/circuits/pairing/bls12/mod.rs @@ -1,26 +1,24 @@ -use super::PairingGadget; +use crate::circuits::fields::fp12::Fp12Var; use crate::circuits::fields::fp2::Fp2Var; use crate::circuits::groups::curves::short_weierstrass::bls12::{ G1AffineVar, G1PreparedVar, G1Var, G2PreparedVar, G2Var, }; use ark_ec::bls12::{Bls12, Bls12Config, TwistType}; use ark_ff::{BitIteratorBE, PrimeField}; -// use ark_r1cs_std::fields::fp12::Fp12Var; -use crate::circuits::fields::fp12::Fp12Var; use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; use ark_r1cs_std::prelude::FieldVar; use ark_relations::r1cs::SynthesisError; use ark_std::marker::PhantomData; /// Specifies the constraints for computing a pairing in a BLS12 bilinear group. -pub struct PairingVar(PhantomData<(P, ConstraintF)>) +pub struct PairingGadget(PhantomData<(P, ConstraintF)>) where P: Bls12Config, ConstraintF: PrimeField; type Fp2V = Fp2Var<
<P as Bls12Config>
::Fp2Config, ConstraintF>; -impl PairingVar +impl PairingGadget where P: Bls12Config, ConstraintF: PrimeField, @@ -70,8 +68,8 @@ where } } -impl PairingGadget, ConstraintF> - for PairingVar +impl super::PairingGadget, ConstraintF> + for PairingGadget { type G1Var = G1Var; type G2Var = G2Var; diff --git a/jolt-core/src/circuits/pairing/mod.rs b/jolt-core/src/circuits/pairing/mod.rs index 818c24ff9..add6b2402 100644 --- a/jolt-core/src/circuits/pairing/mod.rs +++ b/jolt-core/src/circuits/pairing/mod.rs @@ -80,68 +80,116 @@ pub trait PairingGadget { #[cfg(test)] mod tests { + use super::*; use ark_bls12_381::Bls12_381; use ark_bn254::Bn254; use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; use ark_ec::pairing::Pairing; use ark_ec::Group; - use ark_ff::{Field, PrimeField}; + use ark_ff::PrimeField; use ark_groth16::Groth16; - use ark_r1cs_std::pairing::bls12::PairingVar; + use ark_r1cs_std::prelude::*; use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; use ark_std::marker::PhantomData; use ark_std::rand::Rng; - use ark_std::test_rng; + use ark_std::{end_timer, start_timer, test_rng}; use rand_core::{RngCore, SeedableRng}; - struct PairingCheckCircuit { - _constraint_f: PhantomData, + struct PairingCheckCircuit + where + E: Pairing, + ConstraintF: PrimeField, + P: PairingGadget, + { r: Option, - r_g2: Option, + r_g2: Option, + _params: PhantomData<(ConstraintF, P)>, } - impl ConstraintSynthesizer - for PairingCheckCircuit + impl ConstraintSynthesizer + for PairingCheckCircuit + where + E: Pairing, + ConstraintF: PrimeField, + P: PairingGadget, { fn generate_constraints( self, cs: ConstraintSystemRef, ) -> Result<(), SynthesisError> { // TODO use PairingVar to generate constraints - // let r_g1 = PairingVar::::new_input(cs.clone(), || { - // Ok(E::G1Projective::prime_subgroup_generator()) - // })?; - Ok(()) + + let r_g1 = P::G1Var::new_witness(cs.clone(), || { + Ok(E::G1::generator() * self.r.ok_or(SynthesisError::AssignmentMissing)?) + })?; + let r_g2 = P::G2Var::new_witness(cs.clone(), || { + Ok(self.r_g2.ok_or(SynthesisError::AssignmentMissing)?) + })?; + + let r_g1_prepared = P::prepare_g1(&r_g1)?; + let r_g2_prepared = P::prepare_g2(&r_g2)?; + + let one_g2_prepared = P::G2PreparedVar::new_constant( + cs.clone(), + &E::G2Prepared::from(E::G2::generator()), + )?; + let minus_one_g1_prepared = P::G1PreparedVar::new_constant( + cs.clone(), + &E::G1Prepared::from(-E::G1::generator()), + )?; + + let result = P::multi_pairing( + &[r_g1_prepared, minus_one_g1_prepared], + &[one_g2_prepared, r_g2_prepared], + )?; + + result.enforce_equal(&P::GTVar::one()) } } #[test] fn test_pairing_check_circuit() { - let c = PairingCheckCircuit:: { - _constraint_f: PhantomData, + type DemoCircuit = PairingCheckCircuit< + Bls12_381, + ark_bn254::Fr, + bls12::PairingGadget, + >; + + let c = DemoCircuit { r: None, r_g2: None, + _params: PhantomData, }; // This is not cryptographically safe, use // `OsRng` (for example) in production software. 
let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(test_rng().next_u64()); + let setup_timer = start_timer!(|| "Groth16::setup"); let (pk, vk) = Groth16::::setup(c, &mut rng).unwrap(); + end_timer!(setup_timer); + let process_vk_timer = start_timer!(|| "Groth16::process_vk"); let pvk = Groth16::::process_vk(&vk).unwrap(); + end_timer!(process_vk_timer); let r = rng.gen(); let r_g2 = ::G2::generator() * &r; - let c = PairingCheckCircuit:: { - _constraint_f: PhantomData, + let c = DemoCircuit { r: Some(r), r_g2: Some(r_g2.into()), + _params: PhantomData, }; + let prove_timer = start_timer!(|| "Groth16::prove"); let proof = Groth16::::prove(&pk, c, &mut rng).unwrap(); + end_timer!(prove_timer); + + let verify_timer = start_timer!(|| "Groth16::verify"); + let verify_result = Groth16::::verify_with_processed_vk(&pvk, &[], &proof); + end_timer!(verify_timer); - assert!(Groth16::::verify_with_processed_vk(&pvk, &[], &proof).unwrap()); + assert!(verify_result.unwrap()); } } From d2f71718368452ecd95c2e13a0e80c66e7e9fac3 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 28 Jul 2024 21:58:02 -0700 Subject: [PATCH 06/44] print out the number of constraints as they get synthesized --- jolt-core/src/circuits/pairing/mod.rs | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/jolt-core/src/circuits/pairing/mod.rs b/jolt-core/src/circuits/pairing/mod.rs index add6b2402..1a3944d5e 100644 --- a/jolt-core/src/circuits/pairing/mod.rs +++ b/jolt-core/src/circuits/pairing/mod.rs @@ -117,31 +117,41 @@ mod tests { self, cs: ConstraintSystemRef, ) -> Result<(), SynthesisError> { - // TODO use PairingVar to generate constraints + dbg!(cs.num_constraints()); let r_g1 = P::G1Var::new_witness(cs.clone(), || { Ok(E::G1::generator() * self.r.ok_or(SynthesisError::AssignmentMissing)?) })?; + dbg!(cs.num_constraints()); + + let r_g1_prepared = P::prepare_g1(&r_g1)?; + dbg!(cs.num_constraints()); + + let minus_one_g1_prepared = P::G1PreparedVar::new_constant( + cs.clone(), + &E::G1Prepared::from(-E::G1::generator()), + )?; + dbg!(cs.num_constraints()); + let r_g2 = P::G2Var::new_witness(cs.clone(), || { Ok(self.r_g2.ok_or(SynthesisError::AssignmentMissing)?) 
})?; + dbg!(cs.num_constraints()); - let r_g1_prepared = P::prepare_g1(&r_g1)?; let r_g2_prepared = P::prepare_g2(&r_g2)?; + dbg!(cs.num_constraints()); let one_g2_prepared = P::G2PreparedVar::new_constant( cs.clone(), &E::G2Prepared::from(E::G2::generator()), )?; - let minus_one_g1_prepared = P::G1PreparedVar::new_constant( - cs.clone(), - &E::G1Prepared::from(-E::G1::generator()), - )?; + dbg!(cs.num_constraints()); let result = P::multi_pairing( &[r_g1_prepared, minus_one_g1_prepared], &[one_g2_prepared, r_g2_prepared], )?; + dbg!(cs.num_constraints()); result.enforce_equal(&P::GTVar::one()) } From 925c4ec0912d2ded96316c1cd516834aa2b3a1b3 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Mon, 29 Jul 2024 10:52:06 -0700 Subject: [PATCH 07/44] ignore the heavy test --- jolt-core/src/circuits/pairing/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/jolt-core/src/circuits/pairing/mod.rs b/jolt-core/src/circuits/pairing/mod.rs index 1a3944d5e..36bec9d44 100644 --- a/jolt-core/src/circuits/pairing/mod.rs +++ b/jolt-core/src/circuits/pairing/mod.rs @@ -158,6 +158,7 @@ mod tests { } #[test] + #[ignore] fn test_pairing_check_circuit() { type DemoCircuit = PairingCheckCircuit< Bls12_381, From a73c4458ae8c864d98de12179604402885ae8408 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Fri, 2 Aug 2024 10:49:30 -0700 Subject: [PATCH 08/44] WIP: DelayedPairingCircuit --- jolt-core/src/circuits/groups/curves/mod.rs | 162 ++++++++++++++++++ .../groups/curves/short_weierstrass/bn254.rs | 7 + .../groups/curves/short_weierstrass/mod.rs | 3 + 3 files changed, 172 insertions(+) create mode 100644 jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index dbd7e00e0..30c491d27 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -1 +1,163 @@ +use crate::circuits::pairing::PairingGadget; + pub mod short_weierstrass; + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; + use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; + use ark_bls12_381::Bls12_381; + use ark_bn254::{Bn254, Fq, Fr}; + use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; + use ark_crypto_primitives::sponge::Absorb; + use ark_ec::bn::G1Projective; + use ark_ec::pairing::Pairing; + use ark_ec::short_weierstrass::{Projective, SWCurveConfig}; + use ark_ec::{CurveGroup, Group}; + use ark_ff::{PrimeField, ToConstraintField}; + use ark_groth16::Groth16; + use ark_r1cs_std::fields::fp::FpVar; + use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; + use ark_r1cs_std::prelude::*; + use ark_r1cs_std::ToConstraintFieldGadget; + use ark_relations::ns; + use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; + use ark_serialize::CanonicalSerialize; + use ark_std::marker::PhantomData; + use ark_std::rand::Rng; + use ark_std::{end_timer, start_timer, test_rng, UniformRand}; + use itertools::Itertools; + use rand_core::{RngCore, SeedableRng}; + use std::sync::{Arc, RwLock}; + + struct DelayedPairingCircuit + where + E: Pairing, + G1Var: CurveVar, + { + _params: PhantomData, + + // witness values + w_g1: [Option; 3], + d: Option, + + // public inputs + r_g1: Arc>>, + } + + impl ConstraintSynthesizer for DelayedPairingCircuit + where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, + { + fn generate_constraints( + self, + cs: 
ConstraintSystemRef, + ) -> Result<(), SynthesisError> { + dbg!(cs.num_constraints()); + + let d = FpVar::new_witness(ns!(cs, "d"), || { + self.d.ok_or(SynthesisError::AssignmentMissing) + })?; + dbg!(cs.num_constraints()); + + let w_g1 = (0..3) + .map(|i| { + G1Var::new_witness(ns!(cs, "w_g1"), || { + self.w_g1[i].ok_or(SynthesisError::AssignmentMissing) + }) + }) + .collect::, _>>()?; + dbg!(cs.num_constraints()); + + let d_square = d.square()?; + let d_to_k = [FpVar::one(), d, d_square]; + dbg!(cs.num_constraints()); + + let r_g1 = (0..3) + .map(|k| { + w_g1[k] + .clone() + .scalar_mul_le(d_to_k[k].to_bits_le()?.iter()) + }) + .collect::, _>>()? + .iter() + .fold(G1Var::zero(), |acc, x| acc + x); + dbg!(cs.num_constraints()); + + let r_g1_opt = r_g1.value().ok(); + + let mut r_value_opt = self.r_g1.write().unwrap(); + *r_value_opt = r_g1_opt.clone(); + drop(r_value_opt); + + let cf_vec = r_g1.to_constraint_field()?; + + for cf in cf_vec.iter() { + let cf_input = FpVar::new_input(ns!(cs, "r_g1_input"), || cf.value())?; + cf_input.enforce_equal(&cf)?; + } + + dbg!(cs.num_constraints()); + + Ok(()) + } + } + + #[test] + fn test_delayed_pairing_circuit() { + type DemoCircuit = DelayedPairingCircuit; + + let circuit = DemoCircuit { + _params: PhantomData, + w_g1: [None; 3], + d: None, + r_g1: Arc::new(RwLock::new(None)), + }; + + // This is not cryptographically safe, use + // `OsRng` (for example) in production software. + let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(test_rng().next_u64()); + + let setup_timer = start_timer!(|| "Groth16::setup"); + let (pk, vk) = Groth16::::setup(circuit, &mut rng).unwrap(); + end_timer!(setup_timer); + + let process_vk_timer = start_timer!(|| "Groth16::process_vk"); + let pvk = Groth16::::process_vk(&vk).unwrap(); + end_timer!(process_vk_timer); + + let r_g1_lock = Arc::new(RwLock::new(None)); + let c_init_values = DemoCircuit { + _params: PhantomData, + w_g1: [Some(rng.gen()); 3], + d: Some(rng.gen()), + r_g1: r_g1_lock.clone(), + }; + + let prove_timer = start_timer!(|| "Groth16::prove"); + let proof = Groth16::::prove(&pk, c_init_values, &mut rng).unwrap(); + end_timer!(prove_timer); + + let r_g1_opt_read = r_g1_lock.read().unwrap(); + let r_g1 = dbg!(*r_g1_opt_read).unwrap(); + + let public_input = get_public_input(&r_g1); + + let verify_timer = start_timer!(|| "Groth16::verify"); + let verify_result = Groth16::::verify_with_processed_vk(&pvk, &public_input, &proof); + end_timer!(verify_timer); + + assert!(verify_result.unwrap()); + } + + fn get_public_input(g1: &ark_bn254::G1Projective) -> Vec { + G1Var::constant(g1.clone()) + .to_constraint_field() + .unwrap() + .iter() + .map(|x| x.value().unwrap()) + .collect::>() + } +} diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs new file mode 100644 index 000000000..3ca4a9838 --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs @@ -0,0 +1,7 @@ +use crate::circuits::groups::curves::short_weierstrass::ProjectiveVar; +use ark_bn254::{Bn254, Fq, Fr}; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; + +pub type FBaseVar = NonNativeFieldVar; + +pub type G1Var = ProjectiveVar; diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs index f27170ec3..735e044a5 100644 --- a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs +++ 
b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs @@ -20,6 +20,8 @@ use binius_field::PackedField; /// the [\[BLS12]\]() family of bilinear groups. pub mod bls12; +pub mod bn254; + /// This module provides a generic implementation of elliptic curve operations /// for points on short-weierstrass curves in affine coordinates that **are /// not** equal to zero. @@ -29,6 +31,7 @@ pub mod bls12; /// to zero. The [ProjectiveVar] gadget is the recommended way of working with /// elliptic curve points. pub mod non_zero_affine; + /// An implementation of arithmetic for Short Weierstrass curves that relies on /// the complete formulae derived in the paper of /// [[Renes, Costello, Batina 2015]](). From 2f495c03f4c8b82d834fe4e48b21fb52527fe3f8 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sat, 3 Aug 2024 11:17:48 -0700 Subject: [PATCH 09/44] WIP: DelayedPairingCircuit --- jolt-core/src/circuits/groups/curves/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 30c491d27..c38a60c2c 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -75,7 +75,7 @@ mod tests { let d_to_k = [FpVar::one(), d, d_square]; dbg!(cs.num_constraints()); - let r_g1 = (0..3) + let r_g1 = (1..3) .map(|k| { w_g1[k] .clone() @@ -83,7 +83,7 @@ mod tests { }) .collect::, _>>()? .iter() - .fold(G1Var::zero(), |acc, x| acc + x); + .fold(w_g1[0].clone(), |acc, x| acc + x); dbg!(cs.num_constraints()); let r_g1_opt = r_g1.value().ok(); From 8e6b3d46b65e3d5fd2c67b915682327206979264 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sat, 3 Aug 2024 11:19:01 -0700 Subject: [PATCH 10/44] WIP: LoadedSNARK --- jolt-core/src/circuits/mod.rs | 140 ++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index b933bbbda..d7ab5bcc2 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -1,4 +1,144 @@ +use ark_crypto_primitives::snark::SNARK; +use ark_ec::pairing::Pairing; +use ark_ec::{AffineRepr, VariableBaseMSM}; +use ark_ff::PrimeField; +use ark_relations::r1cs::ConstraintSynthesizer; +use ark_serialize::CanonicalSerialize; +use ark_std::rand::{CryptoRng, RngCore}; +use ark_std::{One, Zero}; +use std::ops::Neg; + pub mod fields; pub mod groups; pub mod pairing; pub mod poly; + +/// Describes G1 elements to be used in a multi-pairing. +/// The verifier is responsible for ensuring that the sum of the pairings is zero. +/// The verifier needs to use appropriate G2 elements from the verification key or the proof +/// (depending on the protocol). +pub struct DelayedPairingDef { + /// Left pairing G1 element offset in the public input. + pub l_g1_offset: usize, + /// Right pairing G1 element offset in the public input. This element is, by convention, always used + /// in the multi-pairing computation with coefficient `-1`. + pub r_g1_offset: usize, +} + +/// Describes a block of G1 elements `Gᵢ` and scalars in `sᵢ` the public input, such that `∑ᵢ sᵢ·Gᵢ == 0`. +/// It's the verifiers responsibility to ensure that the sum is zero. +pub struct DelayedMSMDef { + /// Length is the number of G1 elements in the MSM. + pub length: usize, + /// MSM G1 elements offset in the public input. 
G1 elements are stored as sequences of scalar field elements + /// encoding the compressed coordinates of the G1 points (which would natively be numbers in the base field). + /// The offset is in the number of scalar field elements in the public input before the G1 elements block. + pub g1_offset: usize, + /// MSM scalars offset in the public input. The scalar at index `length-1` is, by convention, always `-1`, + /// so we can save one public input element. + /// The offset is in the number of scalar field elements in the public input before the scalars block. + pub scalar_offset: usize, +} + +pub struct LoadedSNARKProof +where + E: Pairing, + S: SNARK, +{ + pub snark_proof: S::Proof, + /// Delayed pairing G1 elements in the public input. + pub delayed_pairings: Vec, + /// Delayed MSM G1 and scalar blocks in the public input. + pub delayed_msms: Vec, +} + +pub trait LoadedSNARK +where + E: Pairing, + S: SNARK, +{ + type Circuit: ConstraintSynthesizer; + + fn prove( + circuit_pk: &S::ProvingKey, + circuit: Self::Circuit, + rng: &mut R, + ) -> Result, S::Error>; + + fn msm_inputs( + msm_defs: &[DelayedMSMDef], + public_input: &[E::ScalarField], + ) -> Result, Vec)>, S::Error> { + msm_defs + .iter() + .map(|msm_def| { + let g1_offset = msm_def.g1_offset; + let msm_length = msm_def.length; + let g1s = Self::g1_elements(public_input, g1_offset, msm_length); + let scalars = [ + &public_input[msm_def.scalar_offset..msm_def.scalar_offset + msm_length - 1], + &[-E::ScalarField::one()], + ] + .concat(); + Ok((g1s, scalars)) + }) + .collect() + } + + fn g1_elements( + public_input: &[::ScalarField], + g1_offset: usize, + length: usize, + ) -> Vec<::G1Affine> { + let g1_element_size = g1_affine_size_in_scalar_field_elements::(); + public_input[g1_offset..g1_offset + length * g1_element_size] + .chunks(g1_element_size) + .map(|chunk| g1_affine_from_scalar_field::(chunk)) + .collect() + } + + fn pairing_inputs( + pvk: &S::ProcessedVerifyingKey, + public_input: &[E::ScalarField], + proof: &LoadedSNARKProof, + ) -> Result, Vec)>, S::Error>; + + fn verify( + pvk: &S::ProcessedVerifyingKey, + public_input: &[E::ScalarField], + proof: &LoadedSNARKProof, + ) -> Result { + let r = S::verify_with_processed_vk(pvk, public_input, &proof.snark_proof)?; + if !r { + return Ok(false); + } + + let msms = Self::msm_inputs(&proof.delayed_msms, public_input)?; + for (g1s, scalars) in msms { + assert_eq!(g1s.len(), scalars.len()); + let r = E::G1::msm_unchecked(&g1s, &scalars); + if !r.is_zero() { + return Ok(false); + } + } + + let pairings = Self::pairing_inputs(pvk, public_input, &proof)?; + for (g1s, g2s) in pairings { + assert_eq!(g1s.len(), g2s.len()); + let r = E::multi_pairing(&g1s, &g2s); + if !r.is_zero() { + return Ok(false); + } + } + + Ok(true) + } +} + +fn g1_affine_size_in_scalar_field_elements() -> usize { + todo!() +} + +fn g1_affine_from_scalar_field(_s: &[E::ScalarField]) -> E::G1Affine { + todo!() +} From 5fa5e4821d6db65f381a0444e3ad173c9c7f5436 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sat, 3 Aug 2024 12:13:24 -0700 Subject: [PATCH 11/44] WIP: LoadedSNARK --- jolt-core/src/circuits/mod.rs | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index d7ab5bcc2..39d98d02b 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -40,6 +40,8 @@ pub struct DelayedMSMDef { pub scalar_offset: usize, } +// TODO move delayed pairing and msm defs to `LoadedSNARKVerifierKey`: the 
layout is known ahead of time. + pub struct LoadedSNARKProof where E: Pairing, @@ -74,6 +76,7 @@ where .map(|msm_def| { let g1_offset = msm_def.g1_offset; let msm_length = msm_def.length; + assert!(msm_length > 1); let g1s = Self::g1_elements(public_input, g1_offset, msm_length); let scalars = [ &public_input[msm_def.scalar_offset..msm_def.scalar_offset + msm_length - 1], @@ -101,7 +104,28 @@ where pvk: &S::ProcessedVerifyingKey, public_input: &[E::ScalarField], proof: &LoadedSNARKProof, - ) -> Result, Vec)>, S::Error>; + ) -> Result, Vec)>, S::Error> { + let g1_vectors = proof + .delayed_pairings + .iter() + .map(|pairing_def| { + let l_g1 = Self::g1_elements(public_input, pairing_def.l_g1_offset, 1)[0]; + let r_g1 = Self::g1_elements(public_input, pairing_def.r_g1_offset, 1)[0]; + + vec![l_g1.into_group(), -r_g1.into_group()] + }) + .collect::>(); + Ok(g1_vectors + .into_iter() + .zip(Self::g2_elements(pvk, public_input, proof)) + .collect()) + } + + fn g2_elements( + pvk: &::ScalarField>>::ProcessedVerifyingKey, + public_input: &[::ScalarField], + proof: &LoadedSNARKProof, + ) -> Vec>; fn verify( pvk: &S::ProcessedVerifyingKey, From f1ed212b6ad8e6bd7d034afa681f58d4310a8055 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sat, 3 Aug 2024 16:30:52 -0700 Subject: [PATCH 12/44] WIP: LoadedSNARK trait --- jolt-core/src/circuits/mod.rs | 91 +++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 31 deletions(-) diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index 39d98d02b..9addab8cd 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -1,12 +1,22 @@ use ark_crypto_primitives::snark::SNARK; -use ark_ec::pairing::Pairing; -use ark_ec::{AffineRepr, VariableBaseMSM}; +use ark_ec::{ + pairing::Pairing, + short_weierstrass::{Affine, SWCurveConfig}, + AffineRepr, CurveConfig, VariableBaseMSM, +}; use ark_ff::PrimeField; +use ark_r1cs_std::{ + fields::nonnative::params::{get_params, OptimizationType}, + fields::nonnative::AllocatedNonNativeFieldVar, +}; use ark_relations::r1cs::ConstraintSynthesizer; use ark_serialize::CanonicalSerialize; -use ark_std::rand::{CryptoRng, RngCore}; -use ark_std::{One, Zero}; -use std::ops::Neg; +use ark_std::{ + iterable::Iterable, + ops::Neg, + rand::{CryptoRng, RngCore}, + One, Zero, +}; pub mod fields; pub mod groups; @@ -40,23 +50,22 @@ pub struct DelayedMSMDef { pub scalar_offset: usize, } -// TODO move delayed pairing and msm defs to `LoadedSNARKVerifierKey`: the layout is known ahead of time. - -pub struct LoadedSNARKProof +pub struct LoadedSNARKVerifyingKey where E: Pairing, S: SNARK, { - pub snark_proof: S::Proof, + pub snark_pvk: S::ProcessedVerifyingKey, /// Delayed pairing G1 elements in the public input. pub delayed_pairings: Vec, /// Delayed MSM G1 and scalar blocks in the public input. 
pub delayed_msms: Vec, } -pub trait LoadedSNARK +pub trait LoadedSNARK where - E: Pairing, + E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, + P: SWCurveConfig, S: SNARK, { type Circuit: ConstraintSynthesizer; @@ -65,7 +74,7 @@ where circuit_pk: &S::ProvingKey, circuit: Self::Circuit, rng: &mut R, - ) -> Result, S::Error>; + ) -> Result; fn msm_inputs( msm_defs: &[DelayedMSMDef], @@ -89,55 +98,55 @@ where } fn g1_elements( - public_input: &[::ScalarField], + public_input: &[E::ScalarField], g1_offset: usize, length: usize, - ) -> Vec<::G1Affine> { + ) -> Vec { let g1_element_size = g1_affine_size_in_scalar_field_elements::(); public_input[g1_offset..g1_offset + length * g1_element_size] .chunks(g1_element_size) - .map(|chunk| g1_affine_from_scalar_field::(chunk)) + .map(|chunk| g1_affine_from_scalar_field::(chunk)) .collect() } fn pairing_inputs( - pvk: &S::ProcessedVerifyingKey, + vk: &LoadedSNARKVerifyingKey, public_input: &[E::ScalarField], - proof: &LoadedSNARKProof, + proof: &S::Proof, ) -> Result, Vec)>, S::Error> { - let g1_vectors = proof + let g1_vectors = vk .delayed_pairings .iter() .map(|pairing_def| { let l_g1 = Self::g1_elements(public_input, pairing_def.l_g1_offset, 1)[0]; let r_g1 = Self::g1_elements(public_input, pairing_def.r_g1_offset, 1)[0]; - vec![l_g1.into_group(), -r_g1.into_group()] + vec![l_g1.into(), (-r_g1).into()] }) - .collect::>(); + .collect::>>(); Ok(g1_vectors .into_iter() - .zip(Self::g2_elements(pvk, public_input, proof)) + .zip(Self::g2_elements(vk, public_input, proof)) .collect()) } fn g2_elements( - pvk: &::ScalarField>>::ProcessedVerifyingKey, + vk: &LoadedSNARKVerifyingKey, public_input: &[::ScalarField], - proof: &LoadedSNARKProof, + proof: &S::Proof, ) -> Vec>; fn verify( - pvk: &S::ProcessedVerifyingKey, + vk: &LoadedSNARKVerifyingKey, public_input: &[E::ScalarField], - proof: &LoadedSNARKProof, + proof: &S::Proof, ) -> Result { - let r = S::verify_with_processed_vk(pvk, public_input, &proof.snark_proof)?; + let r = S::verify_with_processed_vk(&vk.snark_pvk, public_input, proof)?; if !r { return Ok(false); } - let msms = Self::msm_inputs(&proof.delayed_msms, public_input)?; + let msms = Self::msm_inputs(&vk.delayed_msms, public_input)?; for (g1s, scalars) in msms { assert_eq!(g1s.len(), scalars.len()); let r = E::G1::msm_unchecked(&g1s, &scalars); @@ -146,7 +155,7 @@ where } } - let pairings = Self::pairing_inputs(pvk, public_input, &proof)?; + let pairings = Self::pairing_inputs(vk, public_input, &proof)?; for (g1s, g2s) in pairings { assert_eq!(g1s.len(), g2s.len()); let r = E::multi_pairing(&g1s, &g2s); @@ -160,9 +169,29 @@ where } fn g1_affine_size_in_scalar_field_elements() -> usize { - todo!() + let params = get_params( + E::BaseField::MODULUS_BIT_SIZE as usize, + E::ScalarField::MODULUS_BIT_SIZE as usize, + OptimizationType::Weight, + ); + params.num_limbs * 2 + 1 } -fn g1_affine_from_scalar_field(_s: &[E::ScalarField]) -> E::G1Affine { - todo!() +fn g1_affine_from_scalar_field(s: &[E::ScalarField]) -> E::G1Affine +where + E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, + P: SWCurveConfig, +{ + let base_field_size_in_limbs = (s.len() - 1) / 2; + let x = AllocatedNonNativeFieldVar::::limbs_to_value( + s[..base_field_size_in_limbs].to_vec(), + OptimizationType::Weight, + ); + let y = AllocatedNonNativeFieldVar::::limbs_to_value( + s[base_field_size_in_limbs..s.len() - 1].to_vec(), + OptimizationType::Weight, + ); + let infinity = !s[s.len() - 1].is_zero(); + + Affine { x, y, infinity } } From 
f1e7f393aa16b3ff285ba09b7ceabc4fb773e821 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sat, 3 Aug 2024 18:58:48 -0700 Subject: [PATCH 13/44] WIP: LoadedSNARK trait --- jolt-core/src/circuits/mod.rs | 74 ++++++++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 19 deletions(-) diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index 9addab8cd..c0c70f055 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -10,11 +10,12 @@ use ark_r1cs_std::{ fields::nonnative::AllocatedNonNativeFieldVar, }; use ark_relations::r1cs::ConstraintSynthesizer; -use ark_serialize::CanonicalSerialize; +use ark_serialize::{CanonicalSerialize, SerializationError, Valid}; use ark_std::{ iterable::Iterable, ops::Neg, rand::{CryptoRng, RngCore}, + result::Result, One, Zero, }; @@ -62,6 +63,20 @@ where pub delayed_msms: Vec, } +#[derive(thiserror::Error, Debug)] +pub enum LoadedSNARKError +where + E: Pairing, + S: SNARK, +{ + /// Wraps `S::Error`. + #[error(transparent)] + SNARKError(S::Error), + /// Wraps `SerializationError`. + #[error(transparent)] + SerializationError(#[from] SerializationError), +} + pub trait LoadedSNARK where E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, @@ -74,19 +89,23 @@ where circuit_pk: &S::ProvingKey, circuit: Self::Circuit, rng: &mut R, - ) -> Result; + ) -> Result>; fn msm_inputs( msm_defs: &[DelayedMSMDef], public_input: &[E::ScalarField], - ) -> Result, Vec)>, S::Error> { + ) -> Result, Vec)>, SerializationError> { msm_defs .iter() .map(|msm_def| { let g1_offset = msm_def.g1_offset; let msm_length = msm_def.length; - assert!(msm_length > 1); - let g1s = Self::g1_elements(public_input, g1_offset, msm_length); + assert!(msm_length > 1); // TODO make it a verifier key validity error + let g1s = Self::g1_elements(public_input, g1_offset, msm_length)?; + + if public_input.len() < msm_def.scalar_offset + msm_length - 1 { + return Err(SerializationError::InvalidData); + }; let scalars = [ &public_input[msm_def.scalar_offset..msm_def.scalar_offset + msm_length - 1], &[-E::ScalarField::one()], @@ -101,8 +120,12 @@ where public_input: &[E::ScalarField], g1_offset: usize, length: usize, - ) -> Vec { + ) -> Result, SerializationError> { let g1_element_size = g1_affine_size_in_scalar_field_elements::(); + if public_input.len() < g1_offset + length * g1_element_size { + return Err(SerializationError::InvalidData); + }; + public_input[g1_offset..g1_offset + length * g1_element_size] .chunks(g1_element_size) .map(|chunk| g1_affine_from_scalar_field::(chunk)) @@ -113,20 +136,20 @@ where vk: &LoadedSNARKVerifyingKey, public_input: &[E::ScalarField], proof: &S::Proof, - ) -> Result, Vec)>, S::Error> { + ) -> Result, Vec)>, SerializationError> { let g1_vectors = vk .delayed_pairings .iter() .map(|pairing_def| { - let l_g1 = Self::g1_elements(public_input, pairing_def.l_g1_offset, 1)[0]; - let r_g1 = Self::g1_elements(public_input, pairing_def.r_g1_offset, 1)[0]; + let l_g1 = Self::g1_elements(public_input, pairing_def.l_g1_offset, 1)?[0]; + let r_g1 = Self::g1_elements(public_input, pairing_def.r_g1_offset, 1)?[0]; - vec![l_g1.into(), (-r_g1).into()] + Ok(vec![l_g1.into(), (-r_g1).into()]) }) - .collect::>>(); - Ok(g1_vectors + .collect::>, SerializationError>>(); + Ok(g1_vectors? .into_iter() - .zip(Self::g2_elements(vk, public_input, proof)) + .zip(Self::g2_elements(vk, public_input, proof)?) 
.collect()) } @@ -134,14 +157,15 @@ where vk: &LoadedSNARKVerifyingKey, public_input: &[::ScalarField], proof: &S::Proof, - ) -> Vec>; + ) -> Result>, SerializationError>; fn verify( vk: &LoadedSNARKVerifyingKey, public_input: &[E::ScalarField], proof: &S::Proof, - ) -> Result { - let r = S::verify_with_processed_vk(&vk.snark_pvk, public_input, proof)?; + ) -> Result> { + let r = S::verify_with_processed_vk(&vk.snark_pvk, public_input, proof) + .map_err(|e| LoadedSNARKError::SNARKError(e))?; if !r { return Ok(false); } @@ -177,11 +201,18 @@ fn g1_affine_size_in_scalar_field_elements() -> usize { params.num_limbs * 2 + 1 } -fn g1_affine_from_scalar_field(s: &[E::ScalarField]) -> E::G1Affine +fn g1_affine_from_scalar_field( + s: &[E::ScalarField], +) -> Result where E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, P: SWCurveConfig, { + let infinity = !s[s.len() - 1].is_zero(); + if infinity { + return Ok(E::G1Affine::zero()); + } + let base_field_size_in_limbs = (s.len() - 1) / 2; let x = AllocatedNonNativeFieldVar::::limbs_to_value( s[..base_field_size_in_limbs].to_vec(), @@ -191,7 +222,12 @@ where s[base_field_size_in_limbs..s.len() - 1].to_vec(), OptimizationType::Weight, ); - let infinity = !s[s.len() - 1].is_zero(); - Affine { x, y, infinity } + let affine = Affine { + x, + y, + infinity: false, + }; + affine.check()?; + Ok(affine) } From 3f4313ed89ef42cb7f00b4b831dc834af73ea0d3 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sat, 3 Aug 2024 19:23:59 -0700 Subject: [PATCH 14/44] OffloadedSNARK trait --- jolt-core/src/circuits/mod.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index c0c70f055..c3e0468a6 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -51,7 +51,7 @@ pub struct DelayedMSMDef { pub scalar_offset: usize, } -pub struct LoadedSNARKVerifyingKey +pub struct OffloadedSNARKVerifyingKey where E: Pairing, S: SNARK, @@ -64,7 +64,7 @@ where } #[derive(thiserror::Error, Debug)] -pub enum LoadedSNARKError +pub enum OffloadedSNARKError where E: Pairing, S: SNARK, @@ -77,7 +77,7 @@ where SerializationError(#[from] SerializationError), } -pub trait LoadedSNARK +pub trait OffloadedSNARK where E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, P: SWCurveConfig, @@ -89,7 +89,7 @@ where circuit_pk: &S::ProvingKey, circuit: Self::Circuit, rng: &mut R, - ) -> Result>; + ) -> Result>; fn msm_inputs( msm_defs: &[DelayedMSMDef], @@ -133,7 +133,7 @@ where } fn pairing_inputs( - vk: &LoadedSNARKVerifyingKey, + vk: &OffloadedSNARKVerifyingKey, public_input: &[E::ScalarField], proof: &S::Proof, ) -> Result, Vec)>, SerializationError> { @@ -154,18 +154,18 @@ where } fn g2_elements( - vk: &LoadedSNARKVerifyingKey, + vk: &OffloadedSNARKVerifyingKey, public_input: &[::ScalarField], proof: &S::Proof, ) -> Result>, SerializationError>; fn verify( - vk: &LoadedSNARKVerifyingKey, + vk: &OffloadedSNARKVerifyingKey, public_input: &[E::ScalarField], proof: &S::Proof, - ) -> Result> { + ) -> Result> { let r = S::verify_with_processed_vk(&vk.snark_pvk, public_input, proof) - .map_err(|e| LoadedSNARKError::SNARKError(e))?; + .map_err(|e| OffloadedSNARKError::SNARKError(e))?; if !r { return Ok(false); } From ab84dca92ee41b1d286f6b16cca966db76451924 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 4 Aug 2024 11:05:26 -0700 Subject: [PATCH 15/44] WIP: implement OffloadedSNARK trait for a demo circuit --- 
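Note (not part of the patch series): the convention in `OffloadedSNARK::verify` above is that each deferred MSM block satisfies ∑ᵢ sᵢ·Gᵢ = 0, with the last scalar fixed to -1 and therefore omitted from the public input. A minimal sketch of the native check the verifier runs per block, assuming `g1s` and `scalars` were recovered from the public input as in `msm_inputs`:

    use ark_ec::{pairing::Pairing, VariableBaseMSM};
    use ark_std::Zero;

    fn check_deferred_msm<E: Pairing>(g1s: &[E::G1Affine], scalars: &[E::ScalarField]) -> bool {
        // Computes s_0·G_0 + … + s_{n-2}·G_{n-2} - G_{n-1} and accepts iff it is the group identity.
        E::G1::msm_unchecked(g1s, scalars).is_zero()
    }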
jolt-core/src/circuits/groups/curves/mod.rs | 64 ++++++++++++++++++--- 1 file changed, 57 insertions(+), 7 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index c38a60c2c..40d1832c8 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -7,13 +7,14 @@ mod tests { use super::*; use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; + use crate::circuits::{OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey}; use ark_bls12_381::Bls12_381; use ark_bn254::{Bn254, Fq, Fr}; use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; use ark_crypto_primitives::sponge::Absorb; use ark_ec::bn::G1Projective; use ark_ec::pairing::Pairing; - use ark_ec::short_weierstrass::{Projective, SWCurveConfig}; + use ark_ec::short_weierstrass::{Affine, Projective, SWCurveConfig}; use ark_ec::{CurveGroup, Group}; use ark_ff::{PrimeField, ToConstraintField}; use ark_groth16::Groth16; @@ -23,13 +24,14 @@ mod tests { use ark_r1cs_std::ToConstraintFieldGadget; use ark_relations::ns; use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; - use ark_serialize::CanonicalSerialize; + use ark_serialize::{CanonicalSerialize, SerializationError}; use ark_std::marker::PhantomData; use ark_std::rand::Rng; + use ark_std::rc::Rc; + use ark_std::sync::RwLock; use ark_std::{end_timer, start_timer, test_rng, UniformRand}; use itertools::Itertools; - use rand_core::{RngCore, SeedableRng}; - use std::sync::{Arc, RwLock}; + use rand_core::{CryptoRng, RngCore, SeedableRng}; struct DelayedPairingCircuit where @@ -43,7 +45,8 @@ mod tests { d: Option, // public inputs - r_g1: Arc>>, + r_g1: Rc>>, + g1s: Rc>>>, } impl ConstraintSynthesizer for DelayedPairingCircuit @@ -105,15 +108,60 @@ mod tests { } } + struct DelayedPairingCircuitSNARK + where + E: Pairing, + P: SWCurveConfig, + S: SNARK, + G1Var: CurveVar, + { + _params: PhantomData<(E, P, S, G1Var)>, + } + + impl OffloadedSNARK for DelayedPairingCircuitSNARK + where + E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, + P: SWCurveConfig, + S: SNARK, + G1Var: CurveVar + ToConstraintFieldGadget, + { + type Circuit = DelayedPairingCircuit; + + fn prove( + circuit_pk: &S::ProvingKey, + circuit: Self::Circuit, + rng: &mut R, + ) -> Result> { + // TODO place the G1 elements into the public input + + let proof = S::prove(circuit_pk, circuit, rng) + .map_err(|e| OffloadedSNARKError::SNARKError(e))?; + + Ok(proof) + } + + fn g2_elements( + vk: &OffloadedSNARKVerifyingKey, + public_input: &[::ScalarField], + proof: &S::Proof, + ) -> Result>, SerializationError> { + // TODO get the G2 elements from the verifying key + Ok(vec![]) + } + } + #[test] fn test_delayed_pairing_circuit() { type DemoCircuit = DelayedPairingCircuit; + type DemoCircuitSNARK = DelayedPairingCircuitSNARK, G1Var>; + let circuit = DemoCircuit { _params: PhantomData, w_g1: [None; 3], d: None, - r_g1: Arc::new(RwLock::new(None)), + r_g1: Rc::new(RwLock::new(None)), + g1s: Rc::new(Default::default()), }; // This is not cryptographically safe, use @@ -128,12 +176,14 @@ mod tests { let pvk = Groth16::::process_vk(&vk).unwrap(); end_timer!(process_vk_timer); - let r_g1_lock = Arc::new(RwLock::new(None)); + let r_g1_lock = Rc::new(RwLock::new(None)); + let g1s = Rc::new(RwLock::new(None)); let c_init_values = DemoCircuit { _params: PhantomData, 
w_g1: [Some(rng.gen()); 3], d: Some(rng.gen()), r_g1: r_g1_lock.clone(), + g1s: g1s.clone(), }; let prove_timer = start_timer!(|| "Groth16::prove"); From 3bd319f8f87453087a46dfebcfbe9f3f827f12e6 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 4 Aug 2024 19:21:27 -0700 Subject: [PATCH 16/44] Offloaded circuit successfully verifies --- jolt-core/src/circuits/groups/curves/mod.rs | 162 +++++++++++++------- jolt-core/src/circuits/mod.rs | 49 ++++-- 2 files changed, 143 insertions(+), 68 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 40d1832c8..74ffdf97e 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -7,7 +7,10 @@ mod tests { use super::*; use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; - use crate::circuits::{OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey}; + use crate::circuits::{ + OffloadedData, OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey, + PublicInputRef, + }; use ark_bls12_381::Bls12_381; use ark_bn254::{Bn254, Fq, Fr}; use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; @@ -15,7 +18,7 @@ mod tests { use ark_ec::bn::G1Projective; use ark_ec::pairing::Pairing; use ark_ec::short_weierstrass::{Affine, Projective, SWCurveConfig}; - use ark_ec::{CurveGroup, Group}; + use ark_ec::{CurveGroup, Group, VariableBaseMSM}; use ark_ff::{PrimeField, ToConstraintField}; use ark_groth16::Groth16; use ark_r1cs_std::fields::fp::FpVar; @@ -25,11 +28,13 @@ mod tests { use ark_relations::ns; use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; use ark_serialize::{CanonicalSerialize, SerializationError}; + use ark_std::cell::OnceCell; use ark_std::marker::PhantomData; + use ark_std::ops::Deref; use ark_std::rand::Rng; use ark_std::rc::Rc; use ark_std::sync::RwLock; - use ark_std::{end_timer, start_timer, test_rng, UniformRand}; + use ark_std::{end_timer, start_timer, test_rng, One, UniformRand}; use itertools::Itertools; use rand_core::{CryptoRng, RngCore, SeedableRng}; @@ -45,8 +50,7 @@ mod tests { d: Option, // public inputs - r_g1: Rc>>, - g1s: Rc>>>, + offloaded_data: Rc>>, } impl ConstraintSynthesizer for DelayedPairingCircuit @@ -75,50 +79,101 @@ mod tests { dbg!(cs.num_constraints()); let d_square = d.square()?; - let d_to_k = [FpVar::one(), d, d_square]; + let d_k = [FpVar::one(), d, d_square]; dbg!(cs.num_constraints()); - let r_g1 = (1..3) - .map(|k| { - w_g1[k] - .clone() - .scalar_mul_le(d_to_k[k].to_bits_le()?.iter()) - }) - .collect::, _>>()? + // `None` in setup mode, `Some>` in proving mode. 
+ let msm_g1_values = w_g1 .iter() - .fold(w_g1[0].clone(), |acc, x| acc + x); - dbg!(cs.num_constraints()); + .map(|g1| g1.value().ok().map(|g1| g1.into_affine())) + .collect::>>(); - let r_g1_opt = r_g1.value().ok(); + let d_k_values = d_k + .iter() + .map(|d| d.value().ok()) + .collect::>>(); + + let (full_msm_value, r_g1_value) = msm_g1_values + .clone() + .zip(d_k_values) + .map(|(g1s, d_k)| { + let r_g1 = E::G1::msm_unchecked(&g1s, &d_k); + let minus_one = -E::ScalarField::one(); + ( + ( + [g1s, vec![r_g1.into()]].concat(), + [d_k, vec![minus_one]].concat(), + ), + r_g1, + ) + }) + .unzip(); - let mut r_value_opt = self.r_g1.write().unwrap(); - *r_value_opt = r_g1_opt.clone(); - drop(r_value_opt); + let r_g1_var = G1Var::new_witness(ns!(cs, "r_g1"), || { + r_g1_value.ok_or(SynthesisError::AssignmentMissing) + })?; - let cf_vec = r_g1.to_constraint_field()?; + if let Some(msm_value) = full_msm_value { + self.offloaded_data + .set(OffloadedData { + msms: vec![msm_value], + pairings: vec![], + }) + .unwrap(); + }; - for cf in cf_vec.iter() { - let cf_input = FpVar::new_input(ns!(cs, "r_g1_input"), || cf.value())?; - cf_input.enforce_equal(&cf)?; + // write d_k to public_input + for x in d_k { + let d_k_input = FpVar::new_input(ns!(cs, "d_k"), || x.value())?; + d_k_input.enforce_equal(&x)?; } + dbg!(cs.num_constraints()); + // write w_g1 to public_input + for g1 in w_g1 { + let f_vec = g1.to_constraint_field()?; + + for f in f_vec.iter() { + let f_input = FpVar::new_input(ns!(cs, "w_g1"), || f.value())?; + f_input.enforce_equal(f)?; + } + } + + // write r_g1 to public_input + { + let f_vec = r_g1_var.to_constraint_field()?; + + for f in f_vec.iter() { + let f_input = FpVar::new_input(ns!(cs, "r_g1"), || f.value())?; + f_input.enforce_equal(f)?; + } + } dbg!(cs.num_constraints()); Ok(()) } } - struct DelayedPairingCircuitSNARK + impl PublicInputRef> for DelayedPairingCircuit + where + E: Pairing, + G1Var: CurveVar, + { + fn public_input_ref(&self) -> Rc>> { + self.offloaded_data.clone() + } + } + + struct DelayedPairingCircuitSNARK where E: Pairing, - P: SWCurveConfig, S: SNARK, G1Var: CurveVar, { - _params: PhantomData<(E, P, S, G1Var)>, + _params: PhantomData<(E, S, G1Var)>, } - impl OffloadedSNARK for DelayedPairingCircuitSNARK + impl OffloadedSNARK for DelayedPairingCircuitSNARK where E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, P: SWCurveConfig, @@ -127,19 +182,6 @@ mod tests { { type Circuit = DelayedPairingCircuit; - fn prove( - circuit_pk: &S::ProvingKey, - circuit: Self::Circuit, - rng: &mut R, - ) -> Result> { - // TODO place the G1 elements into the public input - - let proof = S::prove(circuit_pk, circuit, rng) - .map_err(|e| OffloadedSNARKError::SNARKError(e))?; - - Ok(proof) - } - fn g2_elements( vk: &OffloadedSNARKVerifyingKey, public_input: &[::ScalarField], @@ -154,14 +196,13 @@ mod tests { fn test_delayed_pairing_circuit() { type DemoCircuit = DelayedPairingCircuit; - type DemoCircuitSNARK = DelayedPairingCircuitSNARK, G1Var>; + type DemoCircuitSNARK = DelayedPairingCircuitSNARK, G1Var>; let circuit = DemoCircuit { _params: PhantomData, w_g1: [None; 3], d: None, - r_g1: Rc::new(RwLock::new(None)), - g1s: Rc::new(Default::default()), + offloaded_data: Default::default(), }; // This is not cryptographically safe, use @@ -176,24 +217,21 @@ mod tests { let pvk = Groth16::::process_vk(&vk).unwrap(); end_timer!(process_vk_timer); - let r_g1_lock = Rc::new(RwLock::new(None)); - let g1s = Rc::new(RwLock::new(None)); let c_init_values = DemoCircuit { 
_params: PhantomData, w_g1: [Some(rng.gen()); 3], d: Some(rng.gen()), - r_g1: r_g1_lock.clone(), - g1s: g1s.clone(), + offloaded_data: Default::default(), }; + let data_ref = c_init_values.public_input_ref(); let prove_timer = start_timer!(|| "Groth16::prove"); let proof = Groth16::::prove(&pk, c_init_values, &mut rng).unwrap(); end_timer!(prove_timer); - let r_g1_opt_read = r_g1_lock.read().unwrap(); - let r_g1 = dbg!(*r_g1_opt_read).unwrap(); + let pi_data = dbg!(data_ref.get()).unwrap(); - let public_input = get_public_input(&r_g1); + let public_input = build_public_input(pi_data); let verify_timer = start_timer!(|| "Groth16::verify"); let verify_result = Groth16::::verify_with_processed_vk(&pvk, &public_input, &proof); @@ -202,12 +240,24 @@ mod tests { assert!(verify_result.unwrap()); } - fn get_public_input(g1: &ark_bn254::G1Projective) -> Vec { - G1Var::constant(g1.clone()) - .to_constraint_field() - .unwrap() + fn build_public_input(data: &OffloadedData) -> Vec { + let scalars = &data.msms[0].1; + + let scalar_vec = scalars[..scalars.len() - 1].to_vec(); // remove the last element (always `-1`) + + let msm_g1_vec = data.msms[0] + .0 .iter() - .map(|x| x.value().unwrap()) - .collect::>() + .map(|&g1| { + G1Var::constant(g1.into()) + .to_constraint_field() + .unwrap() + .iter() + .map(|x| x.value().unwrap()) + .collect::>() + }) + .concat(); + + [scalar_vec, msm_g1_vec].concat() } } diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index c3e0468a6..503323e68 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -12,9 +12,11 @@ use ark_r1cs_std::{ use ark_relations::r1cs::ConstraintSynthesizer; use ark_serialize::{CanonicalSerialize, SerializationError, Valid}; use ark_std::{ + cell::OnceCell, iterable::Iterable, ops::Neg, rand::{CryptoRng, RngCore}, + rc::Rc, result::Result, One, Zero, }; @@ -29,11 +31,11 @@ pub mod poly; /// The verifier needs to use appropriate G2 elements from the verification key or the proof /// (depending on the protocol). pub struct DelayedPairingDef { - /// Left pairing G1 element offset in the public input. - pub l_g1_offset: usize, - /// Right pairing G1 element offset in the public input. This element is, by convention, always used - /// in the multi-pairing computation with coefficient `-1`. - pub r_g1_offset: usize, + /// Offsets of the G1 elements in the public input. The G1 elements are stored as sequences of scalar field elements + /// encoding the compressed coordinates of the G1 points (which would natively be numbers in the base field). + /// The offsets are in the number of scalar field elements in the public input before the G1 elements block. + /// The last element, by convention, is always used in the multi-pairing computation with coefficient `-1`. + pub g1_offsets: Vec, } /// Describes a block of G1 elements `Gᵢ` and scalars in `sᵢ` the public input, such that `∑ᵢ sᵢ·Gᵢ == 0`. 
@@ -63,6 +65,16 @@ where pub delayed_msms: Vec, } +#[derive(Debug)] +pub struct OffloadedData { + pub msms: Vec<(Vec, Vec)>, + pub pairings: Vec>, +} + +pub trait PublicInputRef { + fn public_input_ref(&self) -> Rc>; +} + #[derive(thiserror::Error, Debug)] pub enum OffloadedSNARKError where @@ -77,19 +89,21 @@ where SerializationError(#[from] SerializationError), } -pub trait OffloadedSNARK +pub trait OffloadedSNARK where E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, P: SWCurveConfig, S: SNARK, { - type Circuit: ConstraintSynthesizer; + type Circuit: ConstraintSynthesizer + PublicInputRef>; fn prove( circuit_pk: &S::ProvingKey, circuit: Self::Circuit, rng: &mut R, - ) -> Result>; + ) -> Result> { + S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e)) + } fn msm_inputs( msm_defs: &[DelayedMSMDef], @@ -141,10 +155,21 @@ where .delayed_pairings .iter() .map(|pairing_def| { - let l_g1 = Self::g1_elements(public_input, pairing_def.l_g1_offset, 1)?[0]; - let r_g1 = Self::g1_elements(public_input, pairing_def.r_g1_offset, 1)?[0]; - - Ok(vec![l_g1.into(), (-r_g1).into()]) + let last_index = pairing_def.g1_offsets.len() - 1; + let g1s = pairing_def + .g1_offsets + .iter() + .enumerate() + .map(|(i, &offset)| { + let g1 = Self::g1_elements(public_input, offset, 1)?[0]; + if i == last_index { + Ok((-g1).into()) + } else { + Ok(g1.into()) + } + }) + .collect::, _>>(); + g1s }) .collect::>, SerializationError>>(); Ok(g1_vectors? From 18b9886d4dc80938ad806f081139ed699784cad3 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 6 Aug 2024 19:52:43 -0700 Subject: [PATCH 17/44] Offloaded SNARK takes shape --- jolt-core/src/circuits/groups/curves/mod.rs | 69 +++---- jolt-core/src/circuits/mod.rs | 199 +++++++++++++------- 2 files changed, 154 insertions(+), 114 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 74ffdf97e..16206790f 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -8,8 +8,8 @@ mod tests { use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; use crate::circuits::{ - OffloadedData, OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey, - PublicInputRef, + DelayedMSMDef, OffloadedData, OffloadedSNARK, OffloadedSNARKError, + OffloadedSNARKVerifyingKey, PublicInputRef, }; use ark_bls12_381::Bls12_381; use ark_bn254::{Bn254, Fq, Fr}; @@ -38,7 +38,7 @@ mod tests { use itertools::Itertools; use rand_core::{CryptoRng, RngCore, SeedableRng}; - struct DelayedPairingCircuit + struct DelayedOpsCircuit where E: Pairing, G1Var: CurveVar, @@ -53,7 +53,7 @@ mod tests { offloaded_data: Rc>>, } - impl ConstraintSynthesizer for DelayedPairingCircuit + impl ConstraintSynthesizer for DelayedOpsCircuit where E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, @@ -117,7 +117,6 @@ mod tests { self.offloaded_data .set(OffloadedData { msms: vec![msm_value], - pairings: vec![], }) .unwrap(); }; @@ -154,17 +153,17 @@ mod tests { } } - impl PublicInputRef> for DelayedPairingCircuit + impl PublicInputRef for DelayedOpsCircuit where E: Pairing, - G1Var: CurveVar, + G1Var: CurveVar + ToConstraintFieldGadget, { fn public_input_ref(&self) -> Rc>> { self.offloaded_data.clone() } } - struct DelayedPairingCircuitSNARK + struct DelayedOpsCircuitSNARK where E: Pairing, S: SNARK, @@ -173,18 +172,27 @@ mod tests { _params: 
PhantomData<(E, S, G1Var)>, } - impl OffloadedSNARK for DelayedPairingCircuitSNARK + impl OffloadedSNARK for DelayedOpsCircuitSNARK where E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, P: SWCurveConfig, S: SNARK, G1Var: CurveVar + ToConstraintFieldGadget, { - type Circuit = DelayedPairingCircuit; + type Circuit = DelayedOpsCircuit; + + fn offloaded_setup( + snark_vk: S::ProcessedVerifyingKey, + ) -> Result, OffloadedSNARKError> { + Ok(OffloadedSNARKVerifyingKey { + snark_pvk: snark_vk, + delayed_pairings: vec![], // TODO none yet + }) + } fn g2_elements( vk: &OffloadedSNARKVerifyingKey, - public_input: &[::ScalarField], + public_input: &[E::ScalarField], proof: &S::Proof, ) -> Result>, SerializationError> { // TODO get the G2 elements from the verifying key @@ -194,9 +202,9 @@ mod tests { #[test] fn test_delayed_pairing_circuit() { - type DemoCircuit = DelayedPairingCircuit; + type DemoCircuit = DelayedOpsCircuit; - type DemoCircuitSNARK = DelayedPairingCircuitSNARK, G1Var>; + type DemoSNARK = DelayedOpsCircuitSNARK, G1Var>; let circuit = DemoCircuit { _params: PhantomData, @@ -210,11 +218,12 @@ mod tests { let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(test_rng().next_u64()); let setup_timer = start_timer!(|| "Groth16::setup"); - let (pk, vk) = Groth16::::setup(circuit, &mut rng).unwrap(); + let (pk, vk) = DemoSNARK::setup(circuit, &mut rng).unwrap(); end_timer!(setup_timer); let process_vk_timer = start_timer!(|| "Groth16::process_vk"); - let pvk = Groth16::::process_vk(&vk).unwrap(); + // let pvk = DemoSNARK::process_vk(&vk).unwrap(); + let pvk = vk; end_timer!(process_vk_timer); let c_init_values = DemoCircuit { @@ -223,41 +232,15 @@ mod tests { d: Some(rng.gen()), offloaded_data: Default::default(), }; - let data_ref = c_init_values.public_input_ref(); let prove_timer = start_timer!(|| "Groth16::prove"); - let proof = Groth16::::prove(&pk, c_init_values, &mut rng).unwrap(); + let proof = DemoSNARK::prove(&pk, c_init_values, &mut rng).unwrap(); end_timer!(prove_timer); - let pi_data = dbg!(data_ref.get()).unwrap(); - - let public_input = build_public_input(pi_data); - let verify_timer = start_timer!(|| "Groth16::verify"); - let verify_result = Groth16::::verify_with_processed_vk(&pvk, &public_input, &proof); + let verify_result = DemoSNARK::verify_with_processed_vk(&pvk, &[], &proof); end_timer!(verify_timer); assert!(verify_result.unwrap()); } - - fn build_public_input(data: &OffloadedData) -> Vec { - let scalars = &data.msms[0].1; - - let scalar_vec = scalars[..scalars.len() - 1].to_vec(); // remove the last element (always `-1`) - - let msm_g1_vec = data.msms[0] - .0 - .iter() - .map(|&g1| { - G1Var::constant(g1.into()) - .to_constraint_field() - .unwrap() - .iter() - .map(|x| x.value().unwrap()) - .collect::>() - }) - .concat(); - - [scalar_vec, msm_g1_vec].concat() - } } diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index 503323e68..0b6ca18b2 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -5,12 +5,14 @@ use ark_ec::{ AffineRepr, CurveConfig, VariableBaseMSM, }; use ark_ff::PrimeField; +use ark_r1cs_std::prelude::*; use ark_r1cs_std::{ fields::nonnative::params::{get_params, OptimizationType}, fields::nonnative::AllocatedNonNativeFieldVar, + ToConstraintFieldGadget, }; use ark_relations::r1cs::ConstraintSynthesizer; -use ark_serialize::{CanonicalSerialize, SerializationError, Valid}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError, Valid}; use ark_std::{ 
cell::OnceCell, iterable::Iterable, @@ -20,6 +22,7 @@ use ark_std::{ result::Result, One, Zero, }; +use itertools::Itertools; pub mod fields; pub mod groups; @@ -30,6 +33,7 @@ pub mod poly; /// The verifier is responsible for ensuring that the sum of the pairings is zero. /// The verifier needs to use appropriate G2 elements from the verification key or the proof /// (depending on the protocol). +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] pub struct DelayedPairingDef { /// Offsets of the G1 elements in the public input. The G1 elements are stored as sequences of scalar field elements /// encoding the compressed coordinates of the G1 points (which would natively be numbers in the base field). @@ -40,6 +44,7 @@ pub struct DelayedPairingDef { /// Describes a block of G1 elements `Gᵢ` and scalars in `sᵢ` the public input, such that `∑ᵢ sᵢ·Gᵢ == 0`. /// It's the verifiers responsibility to ensure that the sum is zero. +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] pub struct DelayedMSMDef { /// Length is the number of G1 elements in the MSM. pub length: usize, @@ -53,81 +58,137 @@ pub struct DelayedMSMDef { pub scalar_offset: usize, } +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] pub struct OffloadedSNARKVerifyingKey where E: Pairing, S: SNARK, { pub snark_pvk: S::ProcessedVerifyingKey, - /// Delayed pairing G1 elements in the public input. pub delayed_pairings: Vec, - /// Delayed MSM G1 and scalar blocks in the public input. - pub delayed_msms: Vec, } -#[derive(Debug)] +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct OffloadedSNARKProof +where + E: Pairing, + S: SNARK, +{ + pub snark_proof: S::Proof, + pub offloaded_data: OffloadedData, +} + +#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] pub struct OffloadedData { pub msms: Vec<(Vec, Vec)>, - pub pairings: Vec>, } -pub trait PublicInputRef { - fn public_input_ref(&self) -> Rc>; +pub trait PublicInputRef +where + E: Pairing, +{ + fn public_input_ref(&self) -> Rc>>; } #[derive(thiserror::Error, Debug)] -pub enum OffloadedSNARKError +pub enum OffloadedSNARKError where - E: Pairing, - S: SNARK, + Err: 'static + ark_std::error::Error, { - /// Wraps `S::Error`. + /// Wraps `Err`. #[error(transparent)] - SNARKError(S::Error), + SNARKError(Err), /// Wraps `SerializationError`. 
#[error(transparent)] SerializationError(#[from] SerializationError), } -pub trait OffloadedSNARK +pub trait OffloadedSNARK where E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, P: SWCurveConfig, S: SNARK, + G1Var: CurveVar + ToConstraintFieldGadget, { - type Circuit: ConstraintSynthesizer + PublicInputRef>; + type Circuit: ConstraintSynthesizer + PublicInputRef; + + fn setup, R: RngCore + CryptoRng>( + circuit: C, + rng: &mut R, + ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> + { + Self::circuit_specific_setup(circuit, rng) + } + + fn circuit_specific_setup, R: RngCore + CryptoRng>( + circuit: C, + rng: &mut R, + ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> + { + let (pk, snark_vk) = S::circuit_specific_setup(circuit, rng) + .map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let snark_pvk = S::process_vk(&snark_vk).map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let vk = Self::offloaded_setup(snark_pvk)?; + Ok((pk, vk)) + } + + fn offloaded_setup( + snark_vk: S::ProcessedVerifyingKey, + ) -> Result, OffloadedSNARKError>; fn prove( circuit_pk: &S::ProvingKey, circuit: Self::Circuit, rng: &mut R, - ) -> Result> { - S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e)) + ) -> Result, OffloadedSNARKError> { + let public_input_ref = circuit.public_input_ref(); + let proof = + S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e))?; + Ok(OffloadedSNARKProof { + snark_proof: proof, + offloaded_data: public_input_ref.get().unwrap().clone(), + }) } - fn msm_inputs( - msm_defs: &[DelayedMSMDef], + fn verify( + vk: &OffloadedSNARKVerifyingKey, public_input: &[E::ScalarField], - ) -> Result, Vec)>, SerializationError> { - msm_defs - .iter() - .map(|msm_def| { - let g1_offset = msm_def.g1_offset; - let msm_length = msm_def.length; - assert!(msm_length > 1); // TODO make it a verifier key validity error - let g1s = Self::g1_elements(public_input, g1_offset, msm_length)?; + proof: &OffloadedSNARKProof, + ) -> Result> { + Self::verify_with_processed_vk(vk, public_input, proof) + } - if public_input.len() < msm_def.scalar_offset + msm_length - 1 { - return Err(SerializationError::InvalidData); - }; - let scalars = [ - &public_input[msm_def.scalar_offset..msm_def.scalar_offset + msm_length - 1], - &[-E::ScalarField::one()], - ] - .concat(); - Ok((g1s, scalars)) - }) - .collect() + fn verify_with_processed_vk( + vk: &OffloadedSNARKVerifyingKey, + public_input: &[E::ScalarField], + proof: &OffloadedSNARKProof, + ) -> Result> { + let public_input = build_public_input::(public_input, &proof.offloaded_data); + + let r = S::verify_with_processed_vk(&vk.snark_pvk, &public_input, &proof.snark_proof) + .map_err(|e| OffloadedSNARKError::SNARKError(e))?; + if !r { + return Ok(false); + } + + for (g1s, scalars) in &proof.offloaded_data.msms { + assert_eq!(g1s.len(), scalars.len()); + let r = E::G1::msm_unchecked(&g1s, &scalars); + if !r.is_zero() { + return Ok(false); + } + } + + let pairings = Self::pairing_inputs(vk, &public_input, &proof.snark_proof)?; + for (g1s, g2s) in pairings { + assert_eq!(g1s.len(), g2s.len()); + let r = E::multi_pairing(&g1s, &g2s); + if !r.is_zero() { + return Ok(false); + } + } + + Ok(true) } fn g1_elements( @@ -179,42 +240,10 @@ where } fn g2_elements( - vk: &OffloadedSNARKVerifyingKey, - public_input: &[::ScalarField], - proof: &S::Proof, - ) -> Result>, SerializationError>; - - fn verify( vk: &OffloadedSNARKVerifyingKey, public_input: 
&[E::ScalarField], proof: &S::Proof, - ) -> Result> { - let r = S::verify_with_processed_vk(&vk.snark_pvk, public_input, proof) - .map_err(|e| OffloadedSNARKError::SNARKError(e))?; - if !r { - return Ok(false); - } - - let msms = Self::msm_inputs(&vk.delayed_msms, public_input)?; - for (g1s, scalars) in msms { - assert_eq!(g1s.len(), scalars.len()); - let r = E::G1::msm_unchecked(&g1s, &scalars); - if !r.is_zero() { - return Ok(false); - } - } - - let pairings = Self::pairing_inputs(vk, public_input, &proof)?; - for (g1s, g2s) in pairings { - assert_eq!(g1s.len(), g2s.len()); - let r = E::multi_pairing(&g1s, &g2s); - if !r.is_zero() { - return Ok(false); - } - } - - Ok(true) - } + ) -> Result>, SerializationError>; } fn g1_affine_size_in_scalar_field_elements() -> usize { @@ -256,3 +285,31 @@ where affine.check()?; Ok(affine) } + +fn build_public_input( + public_input: &[E::ScalarField], + data: &OffloadedData, +) -> Vec +where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, +{ + let scalars = &data.msms[0].1; + + let scalar_vec = scalars[..scalars.len() - 1].to_vec(); // remove the last element (always `-1`) + + let msm_g1_vec = data.msms[0] + .0 + .iter() + .map(|&g1| { + G1Var::constant(g1.into()) + .to_constraint_field() + .unwrap() + .iter() + .map(|x| x.value().unwrap()) + .collect::>() + }) + .concat(); + + [public_input.to_vec(), scalar_vec, msm_g1_vec].concat() +} From 6a5aa97b4918b08770d8eb4b621733f74aa6743b Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 6 Aug 2024 21:02:16 -0700 Subject: [PATCH 18/44] OffloadedSNARK: move to module `snark` --- jolt-core/src/circuits/groups/curves/mod.rs | 6 +- jolt-core/src/circuits/mod.rs | 311 -------------------- jolt-core/src/lib.rs | 1 + jolt-core/src/snark/mod.rs | 288 ++++++++++++++++++ 4 files changed, 292 insertions(+), 314 deletions(-) create mode 100644 jolt-core/src/snark/mod.rs diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 16206790f..01837a26b 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -7,9 +7,9 @@ mod tests { use super::*; use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; - use crate::circuits::{ - DelayedMSMDef, OffloadedData, OffloadedSNARK, OffloadedSNARKError, - OffloadedSNARKVerifyingKey, PublicInputRef, + use crate::snark::{ + OffloadedData, OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey, + PublicInputRef, }; use ark_bls12_381::Bls12_381; use ark_bn254::{Bn254, Fq, Fr}; diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index 0b6ca18b2..b933bbbda 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -1,315 +1,4 @@ -use ark_crypto_primitives::snark::SNARK; -use ark_ec::{ - pairing::Pairing, - short_weierstrass::{Affine, SWCurveConfig}, - AffineRepr, CurveConfig, VariableBaseMSM, -}; -use ark_ff::PrimeField; -use ark_r1cs_std::prelude::*; -use ark_r1cs_std::{ - fields::nonnative::params::{get_params, OptimizationType}, - fields::nonnative::AllocatedNonNativeFieldVar, - ToConstraintFieldGadget, -}; -use ark_relations::r1cs::ConstraintSynthesizer; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError, Valid}; -use ark_std::{ - cell::OnceCell, - iterable::Iterable, - ops::Neg, - rand::{CryptoRng, RngCore}, - rc::Rc, - result::Result, - One, Zero, -}; -use itertools::Itertools; - pub mod 
fields; pub mod groups; pub mod pairing; pub mod poly; - -/// Describes G1 elements to be used in a multi-pairing. -/// The verifier is responsible for ensuring that the sum of the pairings is zero. -/// The verifier needs to use appropriate G2 elements from the verification key or the proof -/// (depending on the protocol). -#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] -pub struct DelayedPairingDef { - /// Offsets of the G1 elements in the public input. The G1 elements are stored as sequences of scalar field elements - /// encoding the compressed coordinates of the G1 points (which would natively be numbers in the base field). - /// The offsets are in the number of scalar field elements in the public input before the G1 elements block. - /// The last element, by convention, is always used in the multi-pairing computation with coefficient `-1`. - pub g1_offsets: Vec, -} - -/// Describes a block of G1 elements `Gᵢ` and scalars in `sᵢ` the public input, such that `∑ᵢ sᵢ·Gᵢ == 0`. -/// It's the verifiers responsibility to ensure that the sum is zero. -#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] -pub struct DelayedMSMDef { - /// Length is the number of G1 elements in the MSM. - pub length: usize, - /// MSM G1 elements offset in the public input. G1 elements are stored as sequences of scalar field elements - /// encoding the compressed coordinates of the G1 points (which would natively be numbers in the base field). - /// The offset is in the number of scalar field elements in the public input before the G1 elements block. - pub g1_offset: usize, - /// MSM scalars offset in the public input. The scalar at index `length-1` is, by convention, always `-1`, - /// so we can save one public input element. - /// The offset is in the number of scalar field elements in the public input before the scalars block. - pub scalar_offset: usize, -} - -#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] -pub struct OffloadedSNARKVerifyingKey -where - E: Pairing, - S: SNARK, -{ - pub snark_pvk: S::ProcessedVerifyingKey, - pub delayed_pairings: Vec, -} - -#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] -pub struct OffloadedSNARKProof -where - E: Pairing, - S: SNARK, -{ - pub snark_proof: S::Proof, - pub offloaded_data: OffloadedData, -} - -#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] -pub struct OffloadedData { - pub msms: Vec<(Vec, Vec)>, -} - -pub trait PublicInputRef -where - E: Pairing, -{ - fn public_input_ref(&self) -> Rc>>; -} - -#[derive(thiserror::Error, Debug)] -pub enum OffloadedSNARKError -where - Err: 'static + ark_std::error::Error, -{ - /// Wraps `Err`. - #[error(transparent)] - SNARKError(Err), - /// Wraps `SerializationError`. 
- #[error(transparent)] - SerializationError(#[from] SerializationError), -} - -pub trait OffloadedSNARK -where - E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, - P: SWCurveConfig, - S: SNARK, - G1Var: CurveVar + ToConstraintFieldGadget, -{ - type Circuit: ConstraintSynthesizer + PublicInputRef; - - fn setup, R: RngCore + CryptoRng>( - circuit: C, - rng: &mut R, - ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> - { - Self::circuit_specific_setup(circuit, rng) - } - - fn circuit_specific_setup, R: RngCore + CryptoRng>( - circuit: C, - rng: &mut R, - ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> - { - let (pk, snark_vk) = S::circuit_specific_setup(circuit, rng) - .map_err(|e| OffloadedSNARKError::SNARKError(e))?; - let snark_pvk = S::process_vk(&snark_vk).map_err(|e| OffloadedSNARKError::SNARKError(e))?; - let vk = Self::offloaded_setup(snark_pvk)?; - Ok((pk, vk)) - } - - fn offloaded_setup( - snark_vk: S::ProcessedVerifyingKey, - ) -> Result, OffloadedSNARKError>; - - fn prove( - circuit_pk: &S::ProvingKey, - circuit: Self::Circuit, - rng: &mut R, - ) -> Result, OffloadedSNARKError> { - let public_input_ref = circuit.public_input_ref(); - let proof = - S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e))?; - Ok(OffloadedSNARKProof { - snark_proof: proof, - offloaded_data: public_input_ref.get().unwrap().clone(), - }) - } - - fn verify( - vk: &OffloadedSNARKVerifyingKey, - public_input: &[E::ScalarField], - proof: &OffloadedSNARKProof, - ) -> Result> { - Self::verify_with_processed_vk(vk, public_input, proof) - } - - fn verify_with_processed_vk( - vk: &OffloadedSNARKVerifyingKey, - public_input: &[E::ScalarField], - proof: &OffloadedSNARKProof, - ) -> Result> { - let public_input = build_public_input::(public_input, &proof.offloaded_data); - - let r = S::verify_with_processed_vk(&vk.snark_pvk, &public_input, &proof.snark_proof) - .map_err(|e| OffloadedSNARKError::SNARKError(e))?; - if !r { - return Ok(false); - } - - for (g1s, scalars) in &proof.offloaded_data.msms { - assert_eq!(g1s.len(), scalars.len()); - let r = E::G1::msm_unchecked(&g1s, &scalars); - if !r.is_zero() { - return Ok(false); - } - } - - let pairings = Self::pairing_inputs(vk, &public_input, &proof.snark_proof)?; - for (g1s, g2s) in pairings { - assert_eq!(g1s.len(), g2s.len()); - let r = E::multi_pairing(&g1s, &g2s); - if !r.is_zero() { - return Ok(false); - } - } - - Ok(true) - } - - fn g1_elements( - public_input: &[E::ScalarField], - g1_offset: usize, - length: usize, - ) -> Result, SerializationError> { - let g1_element_size = g1_affine_size_in_scalar_field_elements::(); - if public_input.len() < g1_offset + length * g1_element_size { - return Err(SerializationError::InvalidData); - }; - - public_input[g1_offset..g1_offset + length * g1_element_size] - .chunks(g1_element_size) - .map(|chunk| g1_affine_from_scalar_field::(chunk)) - .collect() - } - - fn pairing_inputs( - vk: &OffloadedSNARKVerifyingKey, - public_input: &[E::ScalarField], - proof: &S::Proof, - ) -> Result, Vec)>, SerializationError> { - let g1_vectors = vk - .delayed_pairings - .iter() - .map(|pairing_def| { - let last_index = pairing_def.g1_offsets.len() - 1; - let g1s = pairing_def - .g1_offsets - .iter() - .enumerate() - .map(|(i, &offset)| { - let g1 = Self::g1_elements(public_input, offset, 1)?[0]; - if i == last_index { - Ok((-g1).into()) - } else { - Ok(g1.into()) - } - }) - .collect::, _>>(); - g1s - }) - .collect::>, 
SerializationError>>(); - Ok(g1_vectors? - .into_iter() - .zip(Self::g2_elements(vk, public_input, proof)?) - .collect()) - } - - fn g2_elements( - vk: &OffloadedSNARKVerifyingKey, - public_input: &[E::ScalarField], - proof: &S::Proof, - ) -> Result>, SerializationError>; -} - -fn g1_affine_size_in_scalar_field_elements() -> usize { - let params = get_params( - E::BaseField::MODULUS_BIT_SIZE as usize, - E::ScalarField::MODULUS_BIT_SIZE as usize, - OptimizationType::Weight, - ); - params.num_limbs * 2 + 1 -} - -fn g1_affine_from_scalar_field( - s: &[E::ScalarField], -) -> Result -where - E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, - P: SWCurveConfig, -{ - let infinity = !s[s.len() - 1].is_zero(); - if infinity { - return Ok(E::G1Affine::zero()); - } - - let base_field_size_in_limbs = (s.len() - 1) / 2; - let x = AllocatedNonNativeFieldVar::::limbs_to_value( - s[..base_field_size_in_limbs].to_vec(), - OptimizationType::Weight, - ); - let y = AllocatedNonNativeFieldVar::::limbs_to_value( - s[base_field_size_in_limbs..s.len() - 1].to_vec(), - OptimizationType::Weight, - ); - - let affine = Affine { - x, - y, - infinity: false, - }; - affine.check()?; - Ok(affine) -} - -fn build_public_input( - public_input: &[E::ScalarField], - data: &OffloadedData, -) -> Vec -where - E: Pairing, - G1Var: CurveVar + ToConstraintFieldGadget, -{ - let scalars = &data.msms[0].1; - - let scalar_vec = scalars[..scalars.len() - 1].to_vec(); // remove the last element (always `-1`) - - let msm_g1_vec = data.msms[0] - .0 - .iter() - .map(|&g1| { - G1Var::constant(g1.into()) - .to_constraint_field() - .unwrap() - .iter() - .map(|x| x.value().unwrap()) - .collect::>() - }) - .concat(); - - [public_input.to_vec(), scalar_vec, msm_g1_vec].concat() -} diff --git a/jolt-core/src/lib.rs b/jolt-core/src/lib.rs index 69cd9a629..21c7a36cf 100644 --- a/jolt-core/src/lib.rs +++ b/jolt-core/src/lib.rs @@ -23,5 +23,6 @@ pub mod lasso; pub mod msm; pub mod poly; pub mod r1cs; +pub mod snark; pub mod subprotocols; pub mod utils; diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs new file mode 100644 index 000000000..61fdb192f --- /dev/null +++ b/jolt-core/src/snark/mod.rs @@ -0,0 +1,288 @@ +use ark_crypto_primitives::snark::SNARK; +use ark_ec::pairing::Pairing; +use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; +use ark_ec::{AffineRepr, VariableBaseMSM}; +use ark_ff::{PrimeField, Zero}; +use ark_r1cs_std::fields::nonnative::params::{get_params, OptimizationType}; +use ark_r1cs_std::fields::nonnative::AllocatedNonNativeFieldVar; +use ark_r1cs_std::groups::CurveVar; +use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; +use ark_relations::r1cs::ConstraintSynthesizer; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError, Valid}; +use itertools::Itertools; +use rand_core::{CryptoRng, RngCore}; +use std::cell::OnceCell; +use std::rc::Rc; + +/// Describes G1 elements to be used in a multi-pairing. +/// The verifier is responsible for ensuring that the sum of the pairings is zero. +/// The verifier needs to use appropriate G2 elements from the verification key or the proof +/// (depending on the protocol). +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct DelayedPairingDef { + /// Offsets of the G1 elements in the public input. The G1 elements are stored as sequences of scalar field elements + /// encoding the compressed coordinates of the G1 points (which would natively be numbers in the base field). 
+ /// The offsets are in the number of scalar field elements in the public input before the G1 elements block. + /// The last element, by convention, is always used in the multi-pairing computation with coefficient `-1`. + pub g1_offsets: Vec, +} + +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct OffloadedSNARKVerifyingKey +where + E: Pairing, + S: SNARK, +{ + pub snark_pvk: S::ProcessedVerifyingKey, + pub delayed_pairings: Vec, +} + +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct OffloadedSNARKProof +where + E: Pairing, + S: SNARK, +{ + pub snark_proof: S::Proof, + pub offloaded_data: OffloadedData, +} + +#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] +pub struct OffloadedData { + /// Blocks of G1 elements `Gᵢ` and scalars in `sᵢ` the public input, such that `∑ᵢ sᵢ·Gᵢ == 0`. + /// It's the verifiers responsibility to ensure that the sum is zero. + /// The scalar at index `length-1` is, by convention, always `-1`, so + /// we save one public input element per MSM. + pub msms: Vec<(Vec, Vec)>, +} + +pub trait PublicInputRef +where + E: Pairing, +{ + fn public_input_ref(&self) -> Rc>>; +} + +#[derive(thiserror::Error, Debug)] +pub enum OffloadedSNARKError +where + Err: 'static + ark_std::error::Error, +{ + /// Wraps `Err`. + #[error(transparent)] + SNARKError(Err), + /// Wraps `SerializationError`. + #[error(transparent)] + SerializationError(#[from] SerializationError), +} + +pub trait OffloadedSNARK +where + E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, + P: SWCurveConfig, + S: SNARK, + G1Var: CurveVar + ToConstraintFieldGadget, +{ + type Circuit: ConstraintSynthesizer + PublicInputRef; + + fn setup, R: RngCore + CryptoRng>( + circuit: C, + rng: &mut R, + ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> + { + Self::circuit_specific_setup(circuit, rng) + } + + fn circuit_specific_setup, R: RngCore + CryptoRng>( + circuit: C, + rng: &mut R, + ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> + { + let (pk, snark_vk) = S::circuit_specific_setup(circuit, rng) + .map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let snark_pvk = S::process_vk(&snark_vk).map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let vk = Self::offloaded_setup(snark_pvk)?; + Ok((pk, vk)) + } + + fn offloaded_setup( + snark_vk: S::ProcessedVerifyingKey, + ) -> Result, OffloadedSNARKError>; + + fn prove( + circuit_pk: &S::ProvingKey, + circuit: Self::Circuit, + rng: &mut R, + ) -> Result, OffloadedSNARKError> { + let public_input_ref = circuit.public_input_ref(); + let proof = + S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e))?; + Ok(OffloadedSNARKProof { + snark_proof: proof, + offloaded_data: public_input_ref.get().unwrap().clone(), + }) + } + + fn verify( + vk: &OffloadedSNARKVerifyingKey, + public_input: &[E::ScalarField], + proof: &OffloadedSNARKProof, + ) -> Result> { + Self::verify_with_processed_vk(vk, public_input, proof) + } + + fn verify_with_processed_vk( + vk: &OffloadedSNARKVerifyingKey, + public_input: &[E::ScalarField], + proof: &OffloadedSNARKProof, + ) -> Result> { + let public_input = build_public_input::(public_input, &proof.offloaded_data); + + let r = S::verify_with_processed_vk(&vk.snark_pvk, &public_input, &proof.snark_proof) + .map_err(|e| OffloadedSNARKError::SNARKError(e))?; + if !r { + return Ok(false); + } + + for (g1s, scalars) in &proof.offloaded_data.msms { + assert_eq!(g1s.len(), scalars.len()); + let r = 
E::G1::msm_unchecked(&g1s, &scalars); + if !r.is_zero() { + return Ok(false); + } + } + + let pairings = Self::pairing_inputs(vk, &public_input, &proof.snark_proof)?; + for (g1s, g2s) in pairings { + assert_eq!(g1s.len(), g2s.len()); + let r = E::multi_pairing(&g1s, &g2s); + if !r.is_zero() { + return Ok(false); + } + } + + Ok(true) + } + + fn g1_elements( + public_input: &[E::ScalarField], + g1_offset: usize, + length: usize, + ) -> Result, SerializationError> { + let g1_element_size = g1_affine_size_in_scalar_field_elements::(); + if public_input.len() < g1_offset + length * g1_element_size { + return Err(SerializationError::InvalidData); + }; + + public_input[g1_offset..g1_offset + length * g1_element_size] + .chunks(g1_element_size) + .map(|chunk| g1_affine_from_scalar_field::(chunk)) + .collect() + } + + fn pairing_inputs( + vk: &OffloadedSNARKVerifyingKey, + public_input: &[E::ScalarField], + proof: &S::Proof, + ) -> Result, Vec)>, SerializationError> { + let g1_vectors = vk + .delayed_pairings + .iter() + .map(|pairing_def| { + let last_index = pairing_def.g1_offsets.len() - 1; + let g1s = pairing_def + .g1_offsets + .iter() + .enumerate() + .map(|(i, &offset)| { + let g1 = Self::g1_elements(public_input, offset, 1)?[0]; + if i == last_index { + Ok((-g1).into()) + } else { + Ok(g1.into()) + } + }) + .collect::, _>>(); + g1s + }) + .collect::>, SerializationError>>(); + Ok(g1_vectors? + .into_iter() + .zip(Self::g2_elements(vk, public_input, proof)?) + .collect()) + } + + fn g2_elements( + vk: &OffloadedSNARKVerifyingKey, + public_input: &[E::ScalarField], + proof: &S::Proof, + ) -> Result>, SerializationError>; +} + +fn g1_affine_size_in_scalar_field_elements() -> usize { + let params = get_params( + E::BaseField::MODULUS_BIT_SIZE as usize, + E::ScalarField::MODULUS_BIT_SIZE as usize, + OptimizationType::Weight, + ); + params.num_limbs * 2 + 1 +} + +fn g1_affine_from_scalar_field( + s: &[E::ScalarField], +) -> Result +where + E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, + P: SWCurveConfig, +{ + let infinity = !s[s.len() - 1].is_zero(); + if infinity { + return Ok(E::G1Affine::zero()); + } + + let base_field_size_in_limbs = (s.len() - 1) / 2; + let x = AllocatedNonNativeFieldVar::::limbs_to_value( + s[..base_field_size_in_limbs].to_vec(), + OptimizationType::Weight, + ); + let y = AllocatedNonNativeFieldVar::::limbs_to_value( + s[base_field_size_in_limbs..s.len() - 1].to_vec(), + OptimizationType::Weight, + ); + + let affine = Affine { + x, + y, + infinity: false, + }; + affine.check()?; + Ok(affine) +} + +fn build_public_input( + public_input: &[E::ScalarField], + data: &OffloadedData, +) -> Vec +where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, +{ + let scalars = &data.msms[0].1; + + let scalar_vec = scalars[..scalars.len() - 1].to_vec(); // remove the last element (always `-1`) + + let msm_g1_vec = data.msms[0] + .0 + .iter() + .map(|&g1| { + G1Var::constant(g1.into()) + .to_constraint_field() + .unwrap() + .iter() + .map(|x| x.value().unwrap()) + .collect::>() + }) + .concat(); + + [public_input.to_vec(), scalar_vec, msm_g1_vec].concat() +} From 2020be1ea3ed572997a1e1416e4cd28d83fa731a Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 6 Aug 2024 21:31:58 -0700 Subject: [PATCH 19/44] OffloadedSNARK: cleanup --- jolt-core/src/circuits/groups/curves/mod.rs | 8 ++-- jolt-core/src/snark/mod.rs | 48 ++++++++++++--------- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs 
b/jolt-core/src/circuits/groups/curves/mod.rs index 01837a26b..0099ff23e 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -8,8 +8,8 @@ mod tests { use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; use crate::snark::{ - OffloadedData, OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey, - PublicInputRef, + OffloadedData, OffloadedDataRef, OffloadedSNARK, OffloadedSNARKError, + OffloadedSNARKVerifyingKey, }; use ark_bls12_381::Bls12_381; use ark_bn254::{Bn254, Fq, Fr}; @@ -153,12 +153,12 @@ mod tests { } } - impl PublicInputRef for DelayedOpsCircuit + impl OffloadedDataRef for DelayedOpsCircuit where E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, { - fn public_input_ref(&self) -> Rc>> { + fn offloaded_data_ref(&self) -> Rc>> { self.offloaded_data.clone() } } diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index 61fdb192f..d801ef5ec 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -19,7 +19,7 @@ use std::rc::Rc; /// The verifier needs to use appropriate G2 elements from the verification key or the proof /// (depending on the protocol). #[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] -pub struct DelayedPairingDef { +pub struct OffloadedPairingDef { /// Offsets of the G1 elements in the public input. The G1 elements are stored as sequences of scalar field elements /// encoding the compressed coordinates of the G1 points (which would natively be numbers in the base field). /// The offsets are in the number of scalar field elements in the public input before the G1 elements block. @@ -34,7 +34,7 @@ where S: SNARK, { pub snark_pvk: S::ProcessedVerifyingKey, - pub delayed_pairings: Vec, + pub delayed_pairings: Vec, } #[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] @@ -56,11 +56,11 @@ pub struct OffloadedData { pub msms: Vec<(Vec, Vec)>, } -pub trait PublicInputRef +pub trait OffloadedDataRef where E: Pairing, { - fn public_input_ref(&self) -> Rc>>; + fn offloaded_data_ref(&self) -> Rc>>; } #[derive(thiserror::Error, Debug)] @@ -83,7 +83,7 @@ where S: SNARK, G1Var: CurveVar + ToConstraintFieldGadget, { - type Circuit: ConstraintSynthesizer + PublicInputRef; + type Circuit: ConstraintSynthesizer + OffloadedDataRef; fn setup, R: RngCore + CryptoRng>( circuit: C, @@ -114,12 +114,13 @@ where circuit: Self::Circuit, rng: &mut R, ) -> Result, OffloadedSNARKError> { - let public_input_ref = circuit.public_input_ref(); + // Get the "pointer" to the offloaded data. `S::prove` will populate it. 
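+        // (The circuit holds an `Rc<OnceCell<OffloadedData<E>>>`: `generate_constraints`, executed inside `S::prove`, fills the cell, and it is read back below via `.get()` to attach the offloaded data to the proof.)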
+ let offloaded_data_ref = circuit.offloaded_data_ref(); let proof = S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e))?; Ok(OffloadedSNARKProof { snark_proof: proof, - offloaded_data: public_input_ref.get().unwrap().clone(), + offloaded_data: offloaded_data_ref.get().unwrap().clone(), }) } @@ -267,22 +268,29 @@ where E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, { - let scalars = &data.msms[0].1; - - let scalar_vec = scalars[..scalars.len() - 1].to_vec(); // remove the last element (always `-1`) - - let msm_g1_vec = data.msms[0] - .0 + let appended_data = data + .msms .iter() - .map(|&g1| { - G1Var::constant(g1.into()) - .to_constraint_field() - .unwrap() + .map(|msm| { + let scalars = &msm.1; + let scalar_vec = scalars[..scalars.len() - 1].to_vec(); // remove the last element (always `-1`) + + let msm_g1_vec = msm + .0 .iter() - .map(|x| x.value().unwrap()) - .collect::>() + .map(|&g1| { + G1Var::constant(g1.into()) + .to_constraint_field() + .unwrap() + .iter() + .map(|x| x.value().unwrap()) + .collect::>() + }) + .concat(); + + [scalar_vec, msm_g1_vec].concat() }) .concat(); - [public_input.to_vec(), scalar_vec, msm_g1_vec].concat() + [public_input.to_vec(), appended_data].concat() } From b940345a972d7cf3b47ba950d2b3a2c0f99a021a Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Thu, 8 Aug 2024 16:01:43 -0700 Subject: [PATCH 20/44] OffloadedSNARK: OffloadedMSM gadget --- jolt-core/src/circuits/groups/curves/mod.rs | 104 ++++++------------- jolt-core/src/circuits/mod.rs | 1 + jolt-core/src/circuits/offloaded/mod.rs | 109 ++++++++++++++++++++ jolt-core/src/snark/mod.rs | 39 +++++-- 4 files changed, 171 insertions(+), 82 deletions(-) create mode 100644 jolt-core/src/circuits/offloaded/mod.rs diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 0099ff23e..feda7e83a 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -7,8 +7,9 @@ mod tests { use super::*; use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; + use crate::circuits::offloaded::OffloadedMSMGadget; use crate::snark::{ - OffloadedData, OffloadedDataRef, OffloadedSNARK, OffloadedSNARKError, + OffloadedData, OffloadedDataCircuit, OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey, }; use ark_bls12_381::Bls12_381; @@ -29,6 +30,7 @@ mod tests { use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; use ark_serialize::{CanonicalSerialize, SerializationError}; use ark_std::cell::OnceCell; + use ark_std::cell::{Cell, RefCell}; use ark_std::marker::PhantomData; use ark_std::ops::Deref; use ark_std::rand::Rng; @@ -49,8 +51,14 @@ mod tests { w_g1: [Option; 3], d: Option, - // public inputs - offloaded_data: Rc>>, + // deferred fns to write offloaded data to public_input + deferred_fns: RefCell< + Vec< + Box< + dyn FnOnce() -> Result<(Vec, Vec), SynthesisError>, + >, + >, + >, } impl ConstraintSynthesizer for DelayedOpsCircuit @@ -82,84 +90,32 @@ mod tests { let d_k = [FpVar::one(), d, d_square]; dbg!(cs.num_constraints()); - // `None` in setup mode, `Some>` in proving mode. 
- let msm_g1_values = w_g1 - .iter() - .map(|g1| g1.value().ok().map(|g1| g1.into_affine())) - .collect::>>(); - - let d_k_values = d_k - .iter() - .map(|d| d.value().ok()) - .collect::>>(); - - let (full_msm_value, r_g1_value) = msm_g1_values - .clone() - .zip(d_k_values) - .map(|(g1s, d_k)| { - let r_g1 = E::G1::msm_unchecked(&g1s, &d_k); - let minus_one = -E::ScalarField::one(); - ( - ( - [g1s, vec![r_g1.into()]].concat(), - [d_k, vec![minus_one]].concat(), - ), - r_g1, - ) - }) - .unzip(); - - let r_g1_var = G1Var::new_witness(ns!(cs, "r_g1"), || { - r_g1_value.ok_or(SynthesisError::AssignmentMissing) - })?; - - if let Some(msm_value) = full_msm_value { - self.offloaded_data - .set(OffloadedData { - msms: vec![msm_value], - }) - .unwrap(); - }; - - // write d_k to public_input - for x in d_k { - let d_k_input = FpVar::new_input(ns!(cs, "d_k"), || x.value())?; - d_k_input.enforce_equal(&x)?; - } - dbg!(cs.num_constraints()); - - // write w_g1 to public_input - for g1 in w_g1 { - let f_vec = g1.to_constraint_field()?; - - for f in f_vec.iter() { - let f_input = FpVar::new_input(ns!(cs, "w_g1"), || f.value())?; - f_input.enforce_equal(f)?; - } - } - - // write r_g1 to public_input - { - let f_vec = r_g1_var.to_constraint_field()?; - - for f in f_vec.iter() { - let f_input = FpVar::new_input(ns!(cs, "r_g1"), || f.value())?; - f_input.enforce_equal(f)?; - } - } - dbg!(cs.num_constraints()); + let _ = OffloadedMSMGadget::msm( + &self, + ns!(cs, "msm").cs(), + w_g1.as_slice(), + d_k.as_slice(), + )?; Ok(()) } } - impl OffloadedDataRef for DelayedOpsCircuit + impl OffloadedDataCircuit for DelayedOpsCircuit where E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, { - fn offloaded_data_ref(&self) -> Rc>> { - self.offloaded_data.clone() + fn deferred_fns_ref( + &self, + ) -> &RefCell< + Vec< + Box< + dyn FnOnce() -> Result<(Vec, Vec), SynthesisError>, + >, + >, + > { + &self.deferred_fns } } @@ -210,7 +166,7 @@ mod tests { _params: PhantomData, w_g1: [None; 3], d: None, - offloaded_data: Default::default(), + deferred_fns: Default::default(), }; // This is not cryptographically safe, use @@ -230,7 +186,7 @@ mod tests { _params: PhantomData, w_g1: [Some(rng.gen()); 3], d: Some(rng.gen()), - offloaded_data: Default::default(), + deferred_fns: Default::default(), }; let prove_timer = start_timer!(|| "Groth16::prove"); diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index b933bbbda..5be66eec6 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -1,4 +1,5 @@ pub mod fields; pub mod groups; +pub mod offloaded; pub mod pairing; pub mod poly; diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs new file mode 100644 index 000000000..ae3167b08 --- /dev/null +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -0,0 +1,109 @@ +use crate::snark::OffloadedDataCircuit; +use ark_ec::pairing::Pairing; +use ark_ec::{CurveGroup, VariableBaseMSM}; +use ark_ff::{One, PrimeField}; +use ark_r1cs_std::alloc::AllocVar; +use ark_r1cs_std::eq::EqGadget; +use ark_r1cs_std::fields::fp::FpVar; +use ark_r1cs_std::fields::FieldVar; +use ark_r1cs_std::groups::CurveVar; +use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; +use ark_relations::ns; +use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError}; +use std::marker::PhantomData; + +pub struct OffloadedMSMGadget +where + E: Pairing, + Circuit: OffloadedDataCircuit, + ConstraintF: PrimeField, + FVar: FieldVar + ToConstraintFieldGadget, + G1Var: CurveVar + 
ToConstraintFieldGadget, +{ + _params: PhantomData<(E, ConstraintF, FVar, G1Var, Circuit)>, +} + +impl OffloadedMSMGadget +where + E: Pairing, + Circuit: OffloadedDataCircuit, + ConstraintF: PrimeField, + FVar: FieldVar + ToConstraintFieldGadget, + G1Var: CurveVar + ToConstraintFieldGadget, +{ + pub fn msm( + circuit: &Circuit, + cs: ConstraintSystemRef, + g1s: &[G1Var], + scalars: &[FVar], + ) -> Result { + let g1_values = g1s + .iter() + .map(|g1| g1.value().ok().map(|g1| g1.into_affine())) + .collect::>>(); + + let scalar_values = scalars + .iter() + .map(|s| s.value().ok()) + .collect::>>(); + + let (full_msm_value, msm_g1_value) = g1_values + .zip(scalar_values) + .map(|(g1s, scalars)| { + let r_g1 = E::G1::msm_unchecked(&g1s, &scalars); + let minus_one = -ConstraintF::one(); + ( + ( + [g1s, vec![r_g1.into()]].concat(), + [scalars, vec![minus_one]].concat(), + ), + r_g1, + ) + }) + .unzip(); + + let msm_g1_var = G1Var::new_witness(ns!(cs, "msm_g1"), || { + msm_g1_value.ok_or(SynthesisError::AssignmentMissing) + })?; + + { + let g1s = g1s.to_vec(); + let scalars = scalars.to_vec(); + let msm_g1_var = msm_g1_var.clone(); + + circuit.defer_msm(move || { + // write scalars to public_input + for x in scalars { + let scalar_input = FVar::new_input(ns!(cs, "scalar"), || x.value())?; + scalar_input.enforce_equal(&x)?; + } + dbg!(cs.num_constraints()); + + // write g1s to public_input + for g1 in g1s { + let f_vec = g1.to_constraint_field()?; + + for f in f_vec.iter() { + let f_input = FpVar::new_input(ns!(cs, "g1s"), || f.value())?; + f_input.enforce_equal(f)?; + } + } + + // write msm_g1 to public_input + { + let f_vec = msm_g1_var.to_constraint_field()?; + + for f in f_vec.iter() { + let f_input = FpVar::new_input(ns!(cs, "msm_g1"), || f.value())?; + f_input.enforce_equal(f)?; + } + } + dbg!(cs.num_constraints()); + + full_msm_value.ok_or(SynthesisError::AssignmentMissing) + }) + }; + + Ok(msm_g1_var) + } +} diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index d801ef5ec..eead49508 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -7,12 +7,11 @@ use ark_r1cs_std::fields::nonnative::params::{get_params, OptimizationType}; use ark_r1cs_std::fields::nonnative::AllocatedNonNativeFieldVar; use ark_r1cs_std::groups::CurveVar; use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; -use ark_relations::r1cs::ConstraintSynthesizer; +use ark_relations::r1cs::{ConstraintSynthesizer, SynthesisError}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError, Valid}; +use ark_std::cell::RefCell; use itertools::Itertools; use rand_core::{CryptoRng, RngCore}; -use std::cell::OnceCell; -use std::rc::Rc; /// Describes G1 elements to be used in a multi-pairing. /// The verifier is responsible for ensuring that the sum of the pairings is zero. 
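The `OffloadedMSMGadget` introduced just above allocates the claimed result `msm_g1` as a witness and pushes the actual exposure of the scalars, the points and the result through the public input into a closure registered with `defer_msm`; those deferred closures are drained once the main constraints have been generated. A minimal standalone sketch of that deferral pattern, with placeholder `u64`/`i64` stand-ins for group elements and scalars (illustrative only — the real closures also report `SynthesisError`s and sit behind an `Rc`):

use std::cell::RefCell;

// Stand-in for (Vec<G1Affine>, Vec<ScalarField>) in the real code.
type MsmBlock = (Vec<u64>, Vec<i64>);

#[derive(Default)]
struct DemoCircuit {
    // Closures return None in setup mode, when no witness values are assigned.
    deferred_fns: RefCell<Vec<Box<dyn FnOnce() -> Option<MsmBlock>>>>,
}

impl DemoCircuit {
    // Analogue of `OffloadedDataCircuit::defer_msm`: gadgets register work here.
    fn defer_msm(&self, f: impl FnOnce() -> Option<MsmBlock> + 'static) {
        self.deferred_fns.borrow_mut().push(Box::new(f));
    }

    // Analogue of `run_deferred`: drain and run the closures after constraint
    // generation; a single None (setup mode) makes the whole result None.
    fn run_deferred(&self) -> Option<Vec<MsmBlock>> {
        self.deferred_fns.take().into_iter().map(|f| f()).collect()
    }
}

fn main() {
    let circuit = DemoCircuit::default();
    circuit.defer_msm(|| Some((vec![1, 2, 3], vec![5, 7, -1])));
    assert_eq!(
        circuit.run_deferred(),
        Some(vec![(vec![1, 2, 3], vec![5, 7, -1])])
    );
}
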
@@ -56,11 +55,32 @@ pub struct OffloadedData { pub msms: Vec<(Vec, Vec)>, } -pub trait OffloadedDataRef +pub trait OffloadedDataCircuit where E: Pairing, { - fn offloaded_data_ref(&self) -> Rc>>; + fn deferred_fns_ref( + &self, + ) -> &RefCell< + Vec Result<(Vec, Vec), SynthesisError>>>, + >; + + fn defer_msm( + &self, + f: impl FnOnce() -> Result<(Vec, Vec), SynthesisError> + 'static, + ) { + self.deferred_fns_ref().borrow_mut().push(Box::new(f)); + } + + fn run_deferred(&self) -> Result, SynthesisError> { + let deferred_fns = self.deferred_fns_ref().take(); + let msms = deferred_fns + .into_iter() + .map(|f| f()) + .collect::, _>>()?; + + Ok(OffloadedData { msms }) + } } #[derive(thiserror::Error, Debug)] @@ -74,6 +94,8 @@ where /// Wraps `SerializationError`. #[error(transparent)] SerializationError(#[from] SerializationError), + #[error(transparent)] + SynthesisError(#[from] SynthesisError), } pub trait OffloadedSNARK @@ -83,7 +105,7 @@ where S: SNARK, G1Var: CurveVar + ToConstraintFieldGadget, { - type Circuit: ConstraintSynthesizer + OffloadedDataRef; + type Circuit: ConstraintSynthesizer + OffloadedDataCircuit; fn setup, R: RngCore + CryptoRng>( circuit: C, @@ -115,12 +137,13 @@ where rng: &mut R, ) -> Result, OffloadedSNARKError> { // Get the "pointer" to the offloaded data. `S::prove` will populate it. - let offloaded_data_ref = circuit.offloaded_data_ref(); + let offloaded_data = circuit.run_deferred()?; + let proof = S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e))?; Ok(OffloadedSNARKProof { snark_proof: proof, - offloaded_data: offloaded_data_ref.get().unwrap().clone(), + offloaded_data, }) } From 8d305c80cc11b7b176a4226d1f33b90204dfb408 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Fri, 9 Aug 2024 12:44:19 -0700 Subject: [PATCH 21/44] OffloadedSNARK: OffloadedMSM gadget works --- jolt-core/src/circuits/groups/curves/mod.rs | 35 ++---- jolt-core/src/circuits/offloaded/mod.rs | 12 +- jolt-core/src/snark/mod.rs | 129 +++++++++++++++----- 3 files changed, 117 insertions(+), 59 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index feda7e83a..eea50bfb9 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -9,7 +9,7 @@ mod tests { use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; use crate::circuits::offloaded::OffloadedMSMGadget; use crate::snark::{ - OffloadedData, OffloadedDataCircuit, OffloadedSNARK, OffloadedSNARKError, + DeferredFnsRef, OffloadedData, OffloadedDataCircuit, OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey, }; use ark_bls12_381::Bls12_381; @@ -52,13 +52,7 @@ mod tests { d: Option, // deferred fns to write offloaded data to public_input - deferred_fns: RefCell< - Vec< - Box< - dyn FnOnce() -> Result<(Vec, Vec), SynthesisError>, - >, - >, - >, + deferred_fns_ref: DeferredFnsRef, } impl ConstraintSynthesizer for DelayedOpsCircuit @@ -90,12 +84,9 @@ mod tests { let d_k = [FpVar::one(), d, d_square]; dbg!(cs.num_constraints()); - let _ = OffloadedMSMGadget::msm( - &self, - ns!(cs, "msm").cs(), - w_g1.as_slice(), - d_k.as_slice(), - )?; + let _ = + OffloadedMSMGadget::msm(&self, ns!(cs, "msm"), w_g1.as_slice(), d_k.as_slice())?; + dbg!(cs.num_constraints()); Ok(()) } @@ -106,16 +97,8 @@ mod tests { E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, { - fn deferred_fns_ref( - &self, - ) -> &RefCell< - Vec< - Box< - dyn FnOnce() -> Result<(Vec, Vec), 
SynthesisError>, - >, - >, - > { - &self.deferred_fns + fn deferred_fns_ref(&self) -> &DeferredFnsRef { + &self.deferred_fns_ref } } @@ -166,7 +149,7 @@ mod tests { _params: PhantomData, w_g1: [None; 3], d: None, - deferred_fns: Default::default(), + deferred_fns_ref: Default::default(), }; // This is not cryptographically safe, use @@ -186,7 +169,7 @@ mod tests { _params: PhantomData, w_g1: [Some(rng.gen()); 3], d: Some(rng.gen()), - deferred_fns: Default::default(), + deferred_fns_ref: Default::default(), }; let prove_timer = start_timer!(|| "Groth16::prove"); diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index ae3167b08..6ab7a81ae 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -9,7 +9,7 @@ use ark_r1cs_std::fields::FieldVar; use ark_r1cs_std::groups::CurveVar; use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; use ark_relations::ns; -use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError}; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use std::marker::PhantomData; pub struct OffloadedMSMGadget @@ -33,10 +33,13 @@ where { pub fn msm( circuit: &Circuit, - cs: ConstraintSystemRef, + cs: impl Into>, g1s: &[G1Var], scalars: &[FVar], ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let g1_values = g1s .iter() .map(|g1| g1.value().ok().map(|g1| g1.into_affine())) @@ -70,6 +73,8 @@ where let g1s = g1s.to_vec(); let scalars = scalars.to_vec(); let msm_g1_var = msm_g1_var.clone(); + let ns = ns!(cs, "deferred_msm"); + let cs = ns.cs(); circuit.defer_msm(move || { // write scalars to public_input @@ -100,9 +105,10 @@ where } dbg!(cs.num_constraints()); - full_msm_value.ok_or(SynthesisError::AssignmentMissing) + Ok(full_msm_value) }) }; + dbg!(cs.num_constraints()); Ok(msm_g1_var) } diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index eead49508..c25a17997 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -1,15 +1,20 @@ use ark_crypto_primitives::snark::SNARK; -use ark_ec::pairing::Pairing; -use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; -use ark_ec::{AffineRepr, VariableBaseMSM}; +use ark_ec::{ + pairing::Pairing, + short_weierstrass::{Affine, SWCurveConfig}, + AffineRepr, VariableBaseMSM, +}; use ark_ff::{PrimeField, Zero}; -use ark_r1cs_std::fields::nonnative::params::{get_params, OptimizationType}; -use ark_r1cs_std::fields::nonnative::AllocatedNonNativeFieldVar; -use ark_r1cs_std::groups::CurveVar; -use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; -use ark_relations::r1cs::{ConstraintSynthesizer, SynthesisError}; +use ark_r1cs_std::{ + fields::nonnative::params::{get_params, OptimizationType}, + fields::nonnative::AllocatedNonNativeFieldVar, + groups::CurveVar, + R1CSVar, ToConstraintFieldGadget, +}; +use ark_relations::r1cs; +use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError, Valid}; -use ark_std::cell::RefCell; +use ark_std::{cell::OnceCell, cell::RefCell, marker::PhantomData, rc::Rc}; use itertools::Itertools; use rand_core::{CryptoRng, RngCore}; @@ -55,32 +60,35 @@ pub struct OffloadedData { pub msms: Vec<(Vec, Vec)>, } +pub type DeferredFn = + dyn FnOnce() -> Result, Vec)>, SynthesisError>; + +pub type DeferredFnsRef = Rc< + RefCell< + Vec< + Box< + dyn FnOnce() -> Result< + Option<(Vec, Vec)>, + SynthesisError, + >, + >, + >, + >, +>; + pub trait 
OffloadedDataCircuit where E: Pairing, { - fn deferred_fns_ref( - &self, - ) -> &RefCell< - Vec Result<(Vec, Vec), SynthesisError>>>, - >; + fn deferred_fns_ref(&self) -> &DeferredFnsRef; fn defer_msm( &self, - f: impl FnOnce() -> Result<(Vec, Vec), SynthesisError> + 'static, + f: impl FnOnce() -> Result, Vec)>, SynthesisError> + + 'static, ) { self.deferred_fns_ref().borrow_mut().push(Box::new(f)); } - - fn run_deferred(&self) -> Result, SynthesisError> { - let deferred_fns = self.deferred_fns_ref().take(); - let msms = deferred_fns - .into_iter() - .map(|f| f()) - .collect::, _>>()?; - - Ok(OffloadedData { msms }) - } } #[derive(thiserror::Error, Debug)] @@ -98,6 +106,53 @@ where SynthesisError(#[from] SynthesisError), } +struct WrappedCircuit +where + E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, + P: SWCurveConfig, + C: ConstraintSynthesizer + OffloadedDataCircuit, +{ + _params: PhantomData<(E, P)>, + circuit: C, + offloaded_data_ref: Rc>>, +} + +fn run_deferred( + deferred_fns: Vec< + Box< + dyn FnOnce() -> Result, Vec)>, SynthesisError>, + >, + >, +) -> Result>, SynthesisError> { + let msms = deferred_fns + .into_iter() + .map(|f| f()) + .collect::>, _>>()?; + + Ok(msms.map(|msms| OffloadedData { msms })) +} + +impl ConstraintSynthesizer for WrappedCircuit +where + E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, + P: SWCurveConfig, + C: ConstraintSynthesizer + OffloadedDataCircuit, +{ + fn generate_constraints(self, cs: ConstraintSystemRef) -> r1cs::Result<()> { + let deferred_fns_ref = self.circuit.deferred_fns_ref().clone(); + + let offloaded_data_ref = self.offloaded_data_ref.clone(); + + self.circuit.generate_constraints(cs)?; + + if let Some(offloaded_data) = run_deferred::(deferred_fns_ref.take())? { + offloaded_data_ref.set(offloaded_data).unwrap(); + }; + + Ok(()) + } +} + pub trait OffloadedSNARK where E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, @@ -107,11 +162,16 @@ where { type Circuit: ConstraintSynthesizer + OffloadedDataCircuit; - fn setup, R: RngCore + CryptoRng>( - circuit: C, + fn setup( + circuit: Self::Circuit, rng: &mut R, ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> { + let circuit = WrappedCircuit { + _params: PhantomData, + circuit, + offloaded_data_ref: Default::default(), + }; Self::circuit_specific_setup(circuit, rng) } @@ -122,8 +182,11 @@ where { let (pk, snark_vk) = S::circuit_specific_setup(circuit, rng) .map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let snark_pvk = S::process_vk(&snark_vk).map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let vk = Self::offloaded_setup(snark_pvk)?; + Ok((pk, vk)) } @@ -136,14 +199,20 @@ where circuit: Self::Circuit, rng: &mut R, ) -> Result, OffloadedSNARKError> { - // Get the "pointer" to the offloaded data. `S::prove` will populate it. 
- let offloaded_data = circuit.run_deferred()?; + let circuit = WrappedCircuit { + _params: PhantomData, + circuit, + offloaded_data_ref: Default::default(), + }; + + let offloaded_data_ref = circuit.offloaded_data_ref.clone(); let proof = S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e))?; + Ok(OffloadedSNARKProof { snark_proof: proof, - offloaded_data, + offloaded_data: offloaded_data_ref.get().unwrap().clone(), }) } From 517c865ec3ff00041d93933e5235a40ea2d5836e Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Fri, 9 Aug 2024 17:43:10 -0700 Subject: [PATCH 22/44] WIP: HyperKZG gadget --- jolt-core/src/circuits/groups/curves/mod.rs | 3 - .../poly/commitment/commitment_scheme.rs | 15 +- .../src/circuits/poly/commitment/hyperkzg.rs | 227 ++++++++++-------- jolt-core/src/poly/commitment/hyperkzg.rs | 2 +- 4 files changed, 131 insertions(+), 116 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index eea50bfb9..96c4e628f 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -1,5 +1,3 @@ -use crate::circuits::pairing::PairingGadget; - pub mod short_weierstrass; #[cfg(test)] @@ -12,7 +10,6 @@ mod tests { DeferredFnsRef, OffloadedData, OffloadedDataCircuit, OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey, }; - use ark_bls12_381::Bls12_381; use ark_bn254::{Bn254, Fq, Fr}; use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; use ark_crypto_primitives::sponge::Absorb; diff --git a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs index abbee6cf7..1f0170b5b 100644 --- a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs +++ b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs @@ -7,25 +7,22 @@ use ark_relations::r1cs::SynthesisError; use crate::poly::commitment::commitment_scheme::CommitmentScheme; -pub trait CommitmentVerifierGadget< - F: PrimeField, +pub trait CommitmentVerifierGadget +where ConstraintF: PrimeField, - C: CommitmentScheme, -> + C: CommitmentScheme, { type VerifyingKeyVar: AllocVar + Clone; type ProofVar: AllocVar + Clone; type CommitmentVar: AllocVar + Clone; - - // type Field: FieldVar; // TODO replace FpVar with Field: FieldVar - type TranscriptGadget: SpongeWithGadget + Clone; // TODO requires F: PrimeField, we want to generalize to JoltField + type TranscriptGadget: SpongeWithGadget + Clone; fn verify( proof: &Self::ProofVar, vk: &Self::VerifyingKeyVar, transcript: &mut Self::TranscriptGadget, - opening_point: &[FpVar], - opening: &FpVar, + opening_point: &[FpVar], + opening: &FpVar, commitment: &Self::CommitmentVar, ) -> Result, SynthesisError>; } diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 581939e77..159b72621 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -1,12 +1,10 @@ use std::borrow::Borrow; -use crate::circuits::pairing::PairingGadget; use crate::circuits::poly::commitment::commitment_scheme::CommitmentVerifierGadget; use crate::field::JoltField; use crate::poly::commitment::hyperkzg::{ HyperKZG, HyperKZGCommitment, HyperKZGProof, HyperKZGProverKey, HyperKZGVerifierKey, }; -use ark_bn254::{Bn254, Fr as BN254Fr}; use ark_crypto_primitives::sponge::poseidon::PoseidonSponge; use ark_ec::pairing::Pairing; use ark_ff::{Field, PrimeField}; @@ -18,19 
+16,19 @@ use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, Namespace, use ark_std::marker::PhantomData; #[derive(Clone)] -pub struct HyperKZGProofVar> { - _e: PhantomData, - _p: PhantomData
<P>
, - _constraint_f: PhantomData, +pub struct HyperKZGProofVar +where + E: Pairing, + ConstraintF: PrimeField, +{ + _params: PhantomData<(E, ConstraintF)>, // TODO fill in } -impl AllocVar, ConstraintF> - for HyperKZGProofVar +impl AllocVar, ConstraintF> for HyperKZGProofVar where E: Pairing, ConstraintF: PrimeField, - P: PairingGadget, { fn new_variable>>( cs: impl Into>, @@ -42,23 +40,20 @@ where } #[derive(Clone)] -pub struct HyperKZGCommitmentVar< +pub struct HyperKZGCommitmentVar +where E: Pairing, ConstraintF: PrimeField, - P: PairingGadget, -> { - _e: PhantomData, - _p: PhantomData
<P>
, - _constraint_f: PhantomData, +{ + _params: PhantomData<(E, ConstraintF)>, // TODO fill in } -impl AllocVar, ConstraintF> - for HyperKZGCommitmentVar +impl AllocVar, ConstraintF> + for HyperKZGCommitmentVar where E: Pairing, ConstraintF: PrimeField, - P: PairingGadget, { fn new_variable>>( cs: impl Into>, @@ -70,20 +65,20 @@ where } #[derive(Clone)] -pub struct HyperKZGVerifierKeyVar< +pub struct HyperKZGVerifierKeyVar +where E: Pairing, ConstraintF: PrimeField, - P: PairingGadget, -> { - _e: PhantomData, - _p: PhantomData
, - _constraint_f: PhantomData, +{ + _params: PhantomData<(E, ConstraintF)>, // TODO fill in } -impl, ConstraintF: PrimeField> - AllocVar<(HyperKZGProverKey, HyperKZGVerifierKey), ConstraintF> - for HyperKZGVerifierKeyVar +impl AllocVar<(HyperKZGProverKey, HyperKZGVerifierKey), ConstraintF> + for HyperKZGVerifierKeyVar +where + E: Pairing, + ConstraintF: PrimeField, { fn new_variable, HyperKZGVerifierKey)>>( cs: impl Into>, @@ -94,120 +89,146 @@ impl, ConstraintF: PrimeField> } } -pub struct HyperKZGVerifierGadget +pub struct HyperKZGVerifierGadget where E: Pairing, - P: PairingGadget, + ConstraintF: PrimeField + JoltField, { - _e: PhantomData, - _p: PhantomData
, - _constraint_f: PhantomData, + _params: PhantomData<(E, ConstraintF)>, } -impl CommitmentVerifierGadget> - for HyperKZGVerifierGadget +impl CommitmentVerifierGadget> + for HyperKZGVerifierGadget where - E: Pairing, - P: PairingGadget + Clone, - ConstraintF: PrimeField, - F: PrimeField + JoltField, + E: Pairing, + ConstraintF: PrimeField + JoltField, { - type VerifyingKeyVar = HyperKZGVerifierKeyVar; - type ProofVar = HyperKZGProofVar; - type CommitmentVar = HyperKZGCommitmentVar; - type TranscriptGadget = PoseidonSponge; + type VerifyingKeyVar = HyperKZGVerifierKeyVar; + type ProofVar = HyperKZGProofVar; + type CommitmentVar = HyperKZGCommitmentVar; + type TranscriptGadget = PoseidonSponge; fn verify( proof: &Self::ProofVar, vk: &Self::VerifyingKeyVar, transcript: &mut Self::TranscriptGadget, - opening_point: &[FpVar], - opening: &FpVar, + opening_point: &[FpVar], + opening: &FpVar, commitment: &Self::CommitmentVar, ) -> Result, SynthesisError> { todo!() } } -#[derive(Default)] -struct HyperKZGVerifierCircuit { - _f: std::marker::PhantomData, - // TODO fill in -} - -impl HyperKZGVerifierCircuit { - pub(crate) fn public_inputs( - &self, - vk: &HyperKZGVerifierKey, - comm: &HyperKZGCommitment, - point: &Vec, - eval: &BN254Fr, - proof: &HyperKZGProof, - ) -> Vec { - // TODO fill in - vec![] - } -} - -impl ConstraintSynthesizer for HyperKZGVerifierCircuit { - fn generate_constraints(self, cs: ConstraintSystemRef) -> Result<(), SynthesisError> { - // TODO fill in - Ok(()) - } -} - #[cfg(test)] mod tests { - use ark_bls12_381::Bls12_381; - use ark_bn254::{Bn254, Fr as BN254Fr}; - use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; - use rand_core::SeedableRng; - use crate::poly::commitment::hyperkzg::{ HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey, }; use crate::poly::dense_mlpoly::DensePolynomial; use crate::utils::errors::ProofVerifyError; use crate::utils::transcript::ProofTranscript; + use ark_bn254::{Bn254, Fr}; + use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; + use ark_r1cs_std::ToConstraintFieldGadget; + use ark_relations::ns; + use ark_std::rand::Rng; + use rand_core::{CryptoRng, RngCore, SeedableRng}; use super::*; + struct HyperKZGVerifierCircuit + where + E: Pairing, + { + pcs_vk: Option>, + commitment: Option>, + point: Option>, + eval: Option, + pcs_proof: Option>, + } + + impl HyperKZGVerifierCircuit + where + E: Pairing, + { + pub(crate) fn public_inputs(&self) -> Vec { + Boolean::::TRUE + .to_constraint_field() + .unwrap() + .iter() + .map(|x| x.value().unwrap()) + .collect::>() + } + } + + impl ConstraintSynthesizer for HyperKZGVerifierCircuit + where + E: Pairing, + { + fn generate_constraints( + self, + cs: ConstraintSystemRef, + ) -> Result<(), SynthesisError> { + // TODO fill in + + let _ = Boolean::new_input(ns!(cs, "verification_result"), || Ok(true))?; + + Ok(()) + } + } + #[test] fn test_hyperkzg_eval() { - type Groth16 = ark_groth16::Groth16; + type Groth16 = ark_groth16::Groth16; // Test with poly(X1, X2) = 1 + X1 + X2 + X1*X2 let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(0); let srs = HyperKZGSRS::setup(&mut rng, 3); - let (pk, vk): (HyperKZGProverKey, HyperKZGVerifierKey) = srs.trim(3); + let (pcs_pk, pcs_vk): (HyperKZGProverKey, HyperKZGVerifierKey) = srs.trim(3); // poly is in eval. 
representation; evaluated at [(0,0), (0,1), (1,0), (1,1)] let poly = DensePolynomial::new(vec![ - BN254Fr::from(1), - BN254Fr::from(2), - BN254Fr::from(2), - BN254Fr::from(4), + ark_bn254::Fr::from(1), + ark_bn254::Fr::from(2), + ark_bn254::Fr::from(2), + ark_bn254::Fr::from(4), ]); let (cpk, cvk) = { - let circuit = HyperKZGVerifierCircuit::default(); + let circuit = HyperKZGVerifierCircuit:: { + pcs_vk: None, + commitment: None, + point: None, + eval: None, + pcs_proof: None, + }; Groth16::setup(circuit, &mut rng).unwrap() }; let pvk = Groth16::process_vk(&cvk).unwrap(); - let C = HyperKZG::commit(&pk, &poly).unwrap(); + let C = HyperKZG::commit(&pcs_pk, &poly).unwrap(); - let mut test_inner = |point: Vec, eval: BN254Fr| -> Result<(), ProofVerifyError> { + let test_inner = |point: Vec, eval: Fr| -> Result<(), ProofVerifyError> { let mut tr = ProofTranscript::new(b"TestEval"); - let proof = HyperKZG::open(&pk, &poly, &point, &eval, &mut tr).unwrap(); + let hkzg_proof = HyperKZG::open(&pcs_pk, &poly, &point, &eval, &mut tr).unwrap(); let mut tr = ProofTranscript::new(b"TestEval"); - HyperKZG::verify(&vk, &C, &point, &eval, &proof, &mut tr)?; + HyperKZG::verify(&pcs_vk, &C, &point, &eval, &hkzg_proof, &mut tr)?; // Create an instance of our circuit (with the // witness) - let verifier_circuit = HyperKZGVerifierCircuit::default(); - let instance = verifier_circuit.public_inputs(&vk, &C, &point, &eval, &proof); + let verifier_circuit = HyperKZGVerifierCircuit:: { + pcs_vk: Some(pcs_vk.clone()), + commitment: Some(C.clone()), + point: Some(point.clone()), + eval: Some(eval), + pcs_proof: Some(hkzg_proof), + }; + let instance = verifier_circuit.public_inputs(); + + let mut rng = + ark_std::rand::rngs::StdRng::seed_from_u64(ark_std::test_rng().next_u64()); // Create a groth16 proof with our parameters. let proof = Groth16::prove(&cpk, verifier_circuit, &mut rng) @@ -222,33 +243,33 @@ mod tests { // Call the prover with a (point, eval) pair. 
// The prover does not recompute so it may produce a proof, but it should not verify - let point = vec![BN254Fr::from(0), BN254Fr::from(0)]; - let eval = BN254Fr::from(1); + let point = vec![ark_bn254::Fr::from(0), ark_bn254::Fr::from(0)]; + let eval = ark_bn254::Fr::from(1); assert!(test_inner(point, eval).is_ok()); - let point = vec![BN254Fr::from(0), BN254Fr::from(1)]; - let eval = BN254Fr::from(2); + let point = vec![ark_bn254::Fr::from(0), ark_bn254::Fr::from(1)]; + let eval = ark_bn254::Fr::from(2); assert!(test_inner(point, eval).is_ok()); - let point = vec![BN254Fr::from(1), BN254Fr::from(1)]; - let eval = BN254Fr::from(4); + let point = vec![ark_bn254::Fr::from(1), ark_bn254::Fr::from(1)]; + let eval = ark_bn254::Fr::from(4); assert!(test_inner(point, eval).is_ok()); - let point = vec![BN254Fr::from(0), BN254Fr::from(2)]; - let eval = BN254Fr::from(3); + let point = vec![ark_bn254::Fr::from(0), ark_bn254::Fr::from(2)]; + let eval = ark_bn254::Fr::from(3); assert!(test_inner(point, eval).is_ok()); - let point = vec![BN254Fr::from(2), BN254Fr::from(2)]; - let eval = BN254Fr::from(9); + let point = vec![ark_bn254::Fr::from(2), ark_bn254::Fr::from(2)]; + let eval = ark_bn254::Fr::from(9); assert!(test_inner(point, eval).is_ok()); // Try a couple incorrect evaluations and expect failure - let point = vec![BN254Fr::from(2), BN254Fr::from(2)]; - let eval = BN254Fr::from(50); + let point = vec![ark_bn254::Fr::from(2), ark_bn254::Fr::from(2)]; + let eval = ark_bn254::Fr::from(50); assert!(test_inner(point, eval).is_err()); - let point = vec![BN254Fr::from(0), BN254Fr::from(2)]; - let eval = BN254Fr::from(4); + let point = vec![ark_bn254::Fr::from(0), ark_bn254::Fr::from(2)]; + let eval = ark_bn254::Fr::from(4); assert!(test_inner(point, eval).is_err()); } } diff --git a/jolt-core/src/poly/commitment/hyperkzg.rs b/jolt-core/src/poly/commitment/hyperkzg.rs index 11e814b48..ad6bd0905 100644 --- a/jolt-core/src/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/poly/commitment/hyperkzg.rs @@ -58,7 +58,7 @@ pub struct HyperKZGVerifierKey { pub kzg_vk: KZGVerifierKey
, } -#[derive(Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] pub struct HyperKZGCommitment<P: Pairing>(P::G1Affine); impl<P: Pairing> AppendToTranscript for HyperKZGCommitment<P>
{ From 9579c317cefbe1d79ae40acda442ceb9880f97cd Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sat, 10 Aug 2024 19:57:25 -0700 Subject: [PATCH 23/44] WIP: HyperKZG gadget --- jolt-core/src/circuits/mod.rs | 1 + .../poly/commitment/commitment_scheme.rs | 6 +- .../src/circuits/poly/commitment/hyperkzg.rs | 163 ++++++++++++------ jolt-core/src/circuits/transcript/mock.rs | 103 +++++++++++ jolt-core/src/circuits/transcript/mod.rs | 1 + 5 files changed, 218 insertions(+), 56 deletions(-) create mode 100644 jolt-core/src/circuits/transcript/mock.rs create mode 100644 jolt-core/src/circuits/transcript/mod.rs diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index 5be66eec6..8b2a3e53f 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -3,3 +3,4 @@ pub mod groups; pub mod offloaded; pub mod pairing; pub mod poly; +pub mod transcript; diff --git a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs index 1f0170b5b..2b716828f 100644 --- a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs +++ b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs @@ -7,20 +7,20 @@ use ark_relations::r1cs::SynthesisError; use crate::poly::commitment::commitment_scheme::CommitmentScheme; -pub trait CommitmentVerifierGadget +pub trait CommitmentVerifierGadget where ConstraintF: PrimeField, C: CommitmentScheme, + S: SpongeWithGadget, { type VerifyingKeyVar: AllocVar + Clone; type ProofVar: AllocVar + Clone; type CommitmentVar: AllocVar + Clone; - type TranscriptGadget: SpongeWithGadget + Clone; fn verify( proof: &Self::ProofVar, vk: &Self::VerifyingKeyVar, - transcript: &mut Self::TranscriptGadget, + transcript: &mut S::Var, opening_point: &[FpVar], opening: &FpVar, commitment: &Self::CommitmentVar, diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 159b72621..bbdea5d8c 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -1,11 +1,10 @@ -use std::borrow::Borrow; - use crate::circuits::poly::commitment::commitment_scheme::CommitmentVerifierGadget; use crate::field::JoltField; use crate::poly::commitment::hyperkzg::{ HyperKZG, HyperKZGCommitment, HyperKZGProof, HyperKZGProverKey, HyperKZGVerifierKey, }; -use ark_crypto_primitives::sponge::poseidon::PoseidonSponge; +use ark_crypto_primitives::sponge::constraints::{CryptographicSpongeVar, SpongeWithGadget}; +use ark_crypto_primitives::sponge::CryptographicSponge; use ark_ec::pairing::Pairing; use ark_ff::{Field, PrimeField}; use ark_r1cs_std::boolean::Boolean; @@ -13,6 +12,7 @@ use ark_r1cs_std::fields::fp::FpVar; use ark_r1cs_std::pairing::PairingVar; use ark_r1cs_std::prelude::*; use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, Namespace, SynthesisError}; +use ark_std::borrow::Borrow; use ark_std::marker::PhantomData; #[derive(Clone)] @@ -35,11 +35,14 @@ where f: impl FnOnce() -> Result, mode: AllocationMode, ) -> Result { - todo!() + // TODO implement + Ok(Self { + _params: PhantomData, + }) } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct HyperKZGCommitmentVar where E: Pairing, @@ -60,11 +63,14 @@ where f: impl FnOnce() -> Result, mode: AllocationMode, ) -> Result { - todo!() + // TODO implement + Ok(Self { + _params: PhantomData, + }) } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct HyperKZGVerifierKeyVar where E: Pairing, @@ -85,67 +91,75 @@ where 
f: impl FnOnce() -> Result, mode: AllocationMode, ) -> Result { - todo!() + // TODO implement + Ok(Self { + _params: PhantomData, + }) } } -pub struct HyperKZGVerifierGadget +pub struct HyperKZGVerifierGadget where E: Pairing, ConstraintF: PrimeField + JoltField, + S: SpongeWithGadget, { - _params: PhantomData<(E, ConstraintF)>, + _params: PhantomData<(E, ConstraintF, S)>, } -impl CommitmentVerifierGadget> - for HyperKZGVerifierGadget +impl CommitmentVerifierGadget, S> + for HyperKZGVerifierGadget where E: Pairing, ConstraintF: PrimeField + JoltField, + S: SpongeWithGadget, { type VerifyingKeyVar = HyperKZGVerifierKeyVar; type ProofVar = HyperKZGProofVar; type CommitmentVar = HyperKZGCommitmentVar; - type TranscriptGadget = PoseidonSponge; fn verify( proof: &Self::ProofVar, vk: &Self::VerifyingKeyVar, - transcript: &mut Self::TranscriptGadget, + transcript: &mut S::Var, opening_point: &[FpVar], opening: &FpVar, commitment: &Self::CommitmentVar, ) -> Result, SynthesisError> { - todo!() + Ok(Boolean::TRUE) } } #[cfg(test)] mod tests { + use super::*; + use crate::circuits::transcript::mock::{MockSponge, MockSpongeVar}; use crate::poly::commitment::hyperkzg::{ HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey, }; use crate::poly::dense_mlpoly::DensePolynomial; use crate::utils::errors::ProofVerifyError; use crate::utils::transcript::ProofTranscript; - use ark_bn254::{Bn254, Fr}; + use ark_bn254::Bn254; use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; + use ark_crypto_primitives::sponge::constraints::CryptographicSpongeVar; + use ark_crypto_primitives::sponge::poseidon::constraints::PoseidonSpongeVar; + use ark_crypto_primitives::sponge::poseidon::{PoseidonConfig, PoseidonDefaultConfigField}; use ark_r1cs_std::ToConstraintFieldGadget; use ark_relations::ns; use ark_std::rand::Rng; use rand_core::{CryptoRng, RngCore, SeedableRng}; - use super::*; - struct HyperKZGVerifierCircuit where E: Pairing, { - pcs_vk: Option>, + pcs_pk_vk: Option<(HyperKZGProverKey, HyperKZGVerifierKey)>, commitment: Option>, point: Option>, eval: Option, pcs_proof: Option>, + expected_result: Option, } impl HyperKZGVerifierCircuit @@ -164,15 +178,55 @@ mod tests { impl ConstraintSynthesizer for HyperKZGVerifierCircuit where - E: Pairing, + E: Pairing, { fn generate_constraints( self, cs: ConstraintSystemRef, ) -> Result<(), SynthesisError> { - // TODO fill in - - let _ = Boolean::new_input(ns!(cs, "verification_result"), || Ok(true))?; + let vk_var = + HyperKZGVerifierKeyVar::::new_witness(ns!(cs, "vk"), || { + self.pcs_pk_vk.ok_or(SynthesisError::AssignmentMissing) + })?; + + let commitment_var = HyperKZGCommitmentVar::::new_witness( + ns!(cs, "commitment"), + || self.commitment.ok_or(SynthesisError::AssignmentMissing), + )?; + + let point_var = Vec::>::new_witness(ns!(cs, "point"), || { + if cs.is_in_setup_mode() { + return Ok(vec![]); // dummy value // TODO is there a better way? 
+ } + self.point.ok_or(SynthesisError::AssignmentMissing) + })?; + + let eval_var = FpVar::::new_witness(ns!(cs, "eval"), || { + self.eval.ok_or(SynthesisError::AssignmentMissing) + })?; + + let proof_var = + HyperKZGProofVar::::new_witness(ns!(cs, "proof"), || { + self.pcs_proof.ok_or(SynthesisError::AssignmentMissing) + })?; + + let mut transcript_var = MockSpongeVar::new(ns!(cs, "transcript").cs(), &()); + + let r = + HyperKZGVerifierGadget::>::verify( + &proof_var, + &vk_var, + &mut transcript_var, + &point_var, + &eval_var, + &commitment_var, + )?; + + let r_input = Boolean::new_input(ns!(cs, "verification_result"), || { + self.expected_result + .ok_or(SynthesisError::AssignmentMissing) + })?; + r.enforce_equal(&r_input)?; Ok(()) } @@ -197,11 +251,12 @@ mod tests { let (cpk, cvk) = { let circuit = HyperKZGVerifierCircuit:: { - pcs_vk: None, + pcs_pk_vk: None, commitment: None, point: None, eval: None, pcs_proof: None, + expected_result: None, }; Groth16::setup(circuit, &mut rng).unwrap() @@ -210,36 +265,38 @@ mod tests { let C = HyperKZG::commit(&pcs_pk, &poly).unwrap(); - let test_inner = |point: Vec, eval: Fr| -> Result<(), ProofVerifyError> { - let mut tr = ProofTranscript::new(b"TestEval"); - let hkzg_proof = HyperKZG::open(&pcs_pk, &poly, &point, &eval, &mut tr).unwrap(); - let mut tr = ProofTranscript::new(b"TestEval"); - HyperKZG::verify(&pcs_vk, &C, &point, &eval, &hkzg_proof, &mut tr)?; - - // Create an instance of our circuit (with the - // witness) - let verifier_circuit = HyperKZGVerifierCircuit:: { - pcs_vk: Some(pcs_vk.clone()), - commitment: Some(C.clone()), - point: Some(point.clone()), - eval: Some(eval), - pcs_proof: Some(hkzg_proof), + let test_inner = + |point: Vec, eval: ark_bn254::Fr| -> Result<(), ProofVerifyError> { + let mut tr = ProofTranscript::new(b"TestEval"); + let hkzg_proof = HyperKZG::open(&pcs_pk, &poly, &point, &eval, &mut tr).unwrap(); + let mut tr = ProofTranscript::new(b"TestEval"); + HyperKZG::verify(&pcs_vk, &C, &point, &eval, &hkzg_proof, &mut tr)?; + + // Create an instance of our circuit (with the + // witness) + let verifier_circuit = HyperKZGVerifierCircuit:: { + pcs_pk_vk: Some((pcs_pk.clone(), pcs_vk.clone())), + commitment: Some(C.clone()), + point: Some(point.clone()), + eval: Some(eval), + pcs_proof: Some(hkzg_proof), + expected_result: Some(true), + }; + let instance = verifier_circuit.public_inputs(); + + let mut rng = + ark_std::rand::rngs::StdRng::seed_from_u64(ark_std::test_rng().next_u64()); + + // Create a groth16 proof with our parameters. + let proof = Groth16::prove(&cpk, verifier_circuit, &mut rng) + .map_err(|e| ProofVerifyError::InternalError)?; + let result = Groth16::verify_with_processed_vk(&pvk, &instance, &proof); + match result { + Ok(true) => Ok(()), + Ok(false) => Err(ProofVerifyError::InternalError), + Err(_) => Err(ProofVerifyError::InternalError), + } }; - let instance = verifier_circuit.public_inputs(); - - let mut rng = - ark_std::rand::rngs::StdRng::seed_from_u64(ark_std::test_rng().next_u64()); - - // Create a groth16 proof with our parameters. - let proof = Groth16::prove(&cpk, verifier_circuit, &mut rng) - .map_err(|e| ProofVerifyError::InternalError)?; - let result = Groth16::verify_with_processed_vk(&pvk, &instance, &proof); - match result { - Ok(true) => Ok(()), - Ok(false) => Err(ProofVerifyError::InternalError), - Err(_) => Err(ProofVerifyError::InternalError), - } - }; // Call the prover with a (point, eval) pair. 
// The prover does not recompute so it may produce a proof, but it should not verify diff --git a/jolt-core/src/circuits/transcript/mock.rs b/jolt-core/src/circuits/transcript/mock.rs new file mode 100644 index 000000000..a2df9a767 --- /dev/null +++ b/jolt-core/src/circuits/transcript/mock.rs @@ -0,0 +1,103 @@ +use ark_crypto_primitives::sponge::constraints::{ + AbsorbGadget, CryptographicSpongeVar, SpongeWithGadget, +}; +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; +use ark_ff::PrimeField; +use ark_r1cs_std::boolean::Boolean; +use ark_r1cs_std::fields::fp::FpVar; +use ark_r1cs_std::prelude::UInt8; +use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError}; +use std::marker::PhantomData; + +#[derive(Clone)] +pub struct MockSponge +where + ConstraintF: PrimeField, +{ + _params: PhantomData, +} + +impl CryptographicSponge for MockSponge +where + ConstraintF: PrimeField, +{ + type Config = (); + + fn new(params: &Self::Config) -> Self { + Self { + _params: PhantomData, + } + } + + fn absorb(&mut self, input: &impl Absorb) { + todo!() + } + + fn squeeze_bytes(&mut self, num_bytes: usize) -> Vec { + todo!() + } + + fn squeeze_bits(&mut self, num_bits: usize) -> Vec { + todo!() + } +} + +impl SpongeWithGadget for MockSponge +where + ConstraintF: PrimeField, +{ + type Var = MockSpongeVar; +} + +#[derive(Clone)] +pub struct MockSpongeVar +where + ConstraintF: PrimeField, +{ + _params: PhantomData, + cs: ConstraintSystemRef, +} + +impl CryptographicSpongeVar> + for MockSpongeVar +where + ConstraintF: PrimeField, +{ + type Parameters = (); + + fn new(cs: ConstraintSystemRef, params: &Self::Parameters) -> Self { + Self { + _params: PhantomData, + cs, + } + } + + fn cs(&self) -> ConstraintSystemRef { + self.cs.clone() + } + + fn absorb(&mut self, input: &impl AbsorbGadget) -> Result<(), SynthesisError> { + todo!() + } + + fn squeeze_bytes( + &mut self, + num_bytes: usize, + ) -> Result>, SynthesisError> { + todo!() + } + + fn squeeze_bits( + &mut self, + num_bits: usize, + ) -> Result>, SynthesisError> { + todo!() + } + + fn squeeze_field_elements( + &mut self, + num_elements: usize, + ) -> Result>, SynthesisError> { + todo!() + } +} diff --git a/jolt-core/src/circuits/transcript/mod.rs b/jolt-core/src/circuits/transcript/mod.rs new file mode 100644 index 000000000..9afc1d5e9 --- /dev/null +++ b/jolt-core/src/circuits/transcript/mod.rs @@ -0,0 +1 @@ +pub mod mock; From b045294f234c9b10820474edeb2332bff6927534 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 11 Aug 2024 11:30:00 -0700 Subject: [PATCH 24/44] MockSponge --- .../src/circuits/poly/commitment/hyperkzg.rs | 3 +- jolt-core/src/circuits/transcript/mock.rs | 31 ++++++++++++++----- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index bbdea5d8c..53b3c8bfd 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -210,7 +210,8 @@ mod tests { self.pcs_proof.ok_or(SynthesisError::AssignmentMissing) })?; - let mut transcript_var = MockSpongeVar::new(ns!(cs, "transcript").cs(), &()); + let mut transcript_var = + MockSpongeVar::new(ns!(cs, "transcript").cs(), &(b"TestEval".as_slice())); let r = HyperKZGVerifierGadget::>::verify( diff --git a/jolt-core/src/circuits/transcript/mock.rs b/jolt-core/src/circuits/transcript/mock.rs index a2df9a767..2054fb5ef 100644 --- a/jolt-core/src/circuits/transcript/mock.rs +++ 
b/jolt-core/src/circuits/transcript/mock.rs @@ -1,3 +1,5 @@ +use crate::field::JoltField; +use crate::utils::transcript::ProofTranscript; use ark_crypto_primitives::sponge::constraints::{ AbsorbGadget, CryptographicSpongeVar, SpongeWithGadget, }; @@ -5,21 +7,22 @@ use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; use ark_r1cs_std::boolean::Boolean; use ark_r1cs_std::fields::fp::FpVar; -use ark_r1cs_std::prelude::UInt8; +use ark_r1cs_std::prelude::*; +use ark_r1cs_std::R1CSVar; use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError}; use std::marker::PhantomData; #[derive(Clone)] pub struct MockSponge where - ConstraintF: PrimeField, + ConstraintF: PrimeField + JoltField, { _params: PhantomData, } impl CryptographicSponge for MockSponge where - ConstraintF: PrimeField, + ConstraintF: PrimeField + JoltField, { type Config = (); @@ -44,7 +47,7 @@ where impl SpongeWithGadget for MockSponge where - ConstraintF: PrimeField, + ConstraintF: PrimeField + JoltField, { type Var = MockSpongeVar; } @@ -56,19 +59,21 @@ where { _params: PhantomData, cs: ConstraintSystemRef, + transcript: ProofTranscript, } impl CryptographicSpongeVar> for MockSpongeVar where - ConstraintF: PrimeField, + ConstraintF: PrimeField + JoltField, { - type Parameters = (); + type Parameters = (&'static [u8]); fn new(cs: ConstraintSystemRef, params: &Self::Parameters) -> Self { Self { _params: PhantomData, cs, + transcript: ProofTranscript::new(params), } } @@ -77,7 +82,13 @@ where } fn absorb(&mut self, input: &impl AbsorbGadget) -> Result<(), SynthesisError> { - todo!() + let fs = input + .to_sponge_field_elements()? + .iter() + .map(|f| f.value()) + .collect::, _>>()?; + self.transcript.append_scalars(&fs); + Ok(()) } fn squeeze_bytes( @@ -98,6 +109,10 @@ where &mut self, num_elements: usize, ) -> Result>, SynthesisError> { - todo!() + self.transcript + .challenge_vector::(num_elements) + .iter() + .map(|&f| FpVar::new_witness(self.cs(), || Ok(f))) + .collect() } } From b64d33a9b65faeb4f60973b051d5eb5f732e2653 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 11 Aug 2024 11:51:43 -0700 Subject: [PATCH 25/44] Fix point assignment --- .../src/circuits/poly/commitment/hyperkzg.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 53b3c8bfd..76c592fbf 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -156,7 +156,7 @@ mod tests { { pcs_pk_vk: Option<(HyperKZGProverKey, HyperKZGVerifierKey)>, commitment: Option>, - point: Option>, + point: Vec>, eval: Option, pcs_proof: Option>, expected_result: Option, @@ -194,12 +194,13 @@ mod tests { || self.commitment.ok_or(SynthesisError::AssignmentMissing), )?; - let point_var = Vec::>::new_witness(ns!(cs, "point"), || { - if cs.is_in_setup_mode() { - return Ok(vec![]); // dummy value // TODO is there a better way? 
- } - self.point.ok_or(SynthesisError::AssignmentMissing) - })?; + let point_var = self + .point + .iter() + .map(|&x| { + FpVar::new_witness(ns!(cs, ""), || x.ok_or(SynthesisError::AssignmentMissing)) + }) + .collect::, _>>()?; let eval_var = FpVar::::new_witness(ns!(cs, "eval"), || { self.eval.ok_or(SynthesisError::AssignmentMissing) @@ -254,7 +255,7 @@ mod tests { let circuit = HyperKZGVerifierCircuit:: { pcs_pk_vk: None, commitment: None, - point: None, + point: vec![None, None], eval: None, pcs_proof: None, expected_result: None, @@ -278,7 +279,7 @@ mod tests { let verifier_circuit = HyperKZGVerifierCircuit:: { pcs_pk_vk: Some((pcs_pk.clone(), pcs_vk.clone())), commitment: Some(C.clone()), - point: Some(point.clone()), + point: point.into_iter().map(|x| Some(x)).collect(), eval: Some(eval), pcs_proof: Some(hkzg_proof), expected_result: Some(true), From 715349bee5c56472adcdb40c3e4be2439a32f5e0 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 11 Aug 2024 14:06:48 -0700 Subject: [PATCH 26/44] Return public input vector based on expected result --- jolt-core/src/circuits/poly/commitment/hyperkzg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 76c592fbf..8ec519774 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -167,7 +167,7 @@ mod tests { E: Pairing, { pub(crate) fn public_inputs(&self) -> Vec { - Boolean::::TRUE + Boolean::::constant(self.expected_result.unwrap()) // panics if None .to_constraint_field() .unwrap() .iter() From ee33022c3ccaf201de02b19781f9d3cb34fa07ae Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 13 Aug 2024 17:15:05 -0700 Subject: [PATCH 27/44] WIP: HyperKZG gadget --- .../src/circuits/poly/commitment/hyperkzg.rs | 204 ++++++++++++------ jolt-core/src/circuits/transcript/mock.rs | 25 ++- jolt-core/src/circuits/transcript/mod.rs | 72 +++++++ jolt-core/src/poly/commitment/hyperkzg.rs | 18 +- jolt-core/src/utils/transcript.rs | 2 +- 5 files changed, 247 insertions(+), 74 deletions(-) diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 8ec519774..6e6361059 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -1,9 +1,12 @@ use crate::circuits::poly::commitment::commitment_scheme::CommitmentVerifierGadget; +use crate::circuits::transcript::ImplAbsorb; use crate::field::JoltField; use crate::poly::commitment::hyperkzg::{ HyperKZG, HyperKZGCommitment, HyperKZGProof, HyperKZGProverKey, HyperKZGVerifierKey, }; -use ark_crypto_primitives::sponge::constraints::{CryptographicSpongeVar, SpongeWithGadget}; +use ark_crypto_primitives::sponge::constraints::{ + AbsorbGadget, CryptographicSpongeVar, SpongeWithGadget, +}; use ark_crypto_primitives::sponge::CryptographicSponge; use ark_ec::pairing::Pairing; use ark_ff::{Field, PrimeField}; @@ -11,61 +14,81 @@ use ark_r1cs_std::boolean::Boolean; use ark_r1cs_std::fields::fp::FpVar; use ark_r1cs_std::pairing::PairingVar; use ark_r1cs_std::prelude::*; +use ark_r1cs_std::ToConstraintFieldGadget; +use ark_relations::ns; use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, Namespace, SynthesisError}; use ark_std::borrow::Borrow; +use ark_std::iterable::Iterable; use ark_std::marker::PhantomData; #[derive(Clone)] -pub struct HyperKZGProofVar +pub struct HyperKZGProofVar 
where E: Pairing, - ConstraintF: PrimeField, + G1Var: CurveVar, { - _params: PhantomData<(E, ConstraintF)>, - // TODO fill in + pub com: Vec, + pub w: Vec, + pub v: Vec>>, } -impl AllocVar, ConstraintF> for HyperKZGProofVar +impl AllocVar, E::ScalarField> for HyperKZGProofVar where - E: Pairing, - ConstraintF: PrimeField, + E: Pairing, + G1Var: CurveVar, { fn new_variable>>( - cs: impl Into>, + cs: impl Into>, f: impl FnOnce() -> Result, mode: AllocationMode, ) -> Result { - // TODO implement - Ok(Self { - _params: PhantomData, - }) + let ns = cs.into(); + let cs = ns.cs(); + + let proof_hold = f()?; + let proof = proof_hold.borrow(); + + let com = proof + .com + .iter() + .map(|&x| G1Var::new_variable(ns!(cs, "com").clone(), || Ok(x), mode)) + .collect::, _>>()?; + let w = proof + .w + .iter() + .map(|&x| G1Var::new_variable(ns!(cs, "w").clone(), || Ok(x), mode)) + .collect::, _>>()?; + let v = proof + .v + .iter() + .map(|v_i| { + v_i.iter() + .map(|&v_ij| FpVar::new_variable(ns!(cs, "v_ij"), || Ok(v_ij), mode)) + .collect::, _>>() + }) + .collect::, _>>()?; + + Ok(Self { com, w, v }) } } #[derive(Clone, Debug)] -pub struct HyperKZGCommitmentVar -where - E: Pairing, - ConstraintF: PrimeField, -{ - _params: PhantomData<(E, ConstraintF)>, - // TODO fill in +pub struct HyperKZGCommitmentVar { + pub c: G1Var, } -impl AllocVar, ConstraintF> - for HyperKZGCommitmentVar +impl AllocVar, E::ScalarField> for HyperKZGCommitmentVar where - E: Pairing, - ConstraintF: PrimeField, + E: Pairing, + G1Var: CurveVar, { fn new_variable>>( - cs: impl Into>, + cs: impl Into>, f: impl FnOnce() -> Result, mode: AllocationMode, ) -> Result { - // TODO implement Ok(Self { - _params: PhantomData, + c: G1Var::new_variable(cs, || Ok(f()?.borrow().0), mode)?, }) } } @@ -98,25 +121,26 @@ where } } -pub struct HyperKZGVerifierGadget +pub struct HyperKZGVerifierGadget where - E: Pairing, - ConstraintF: PrimeField + JoltField, - S: SpongeWithGadget, + E: Pairing, + S: SpongeWithGadget, + G1Var: CurveVar + ToConstraintFieldGadget, { - _params: PhantomData<(E, ConstraintF, S)>, + _params: PhantomData<(E, S, G1Var)>, } -impl CommitmentVerifierGadget, S> - for HyperKZGVerifierGadget +impl CommitmentVerifierGadget, S> + for HyperKZGVerifierGadget where E: Pairing, ConstraintF: PrimeField + JoltField, S: SpongeWithGadget, + G1Var: CurveVar + ToConstraintFieldGadget, { type VerifyingKeyVar = HyperKZGVerifierKeyVar; - type ProofVar = HyperKZGProofVar; - type CommitmentVar = HyperKZGCommitmentVar; + type ProofVar = HyperKZGProofVar; + type CommitmentVar = HyperKZGCommitmentVar; fn verify( proof: &Self::ProofVar, @@ -126,6 +150,49 @@ where opening: &FpVar, commitment: &Self::CommitmentVar, ) -> Result, SynthesisError> { + let ell = opening_point.len(); + + transcript.absorb( + &proof + .com + .iter() + .map(|com| ImplAbsorb::wrap(com)) + .collect::>(), + )?; + + let r = transcript + .squeeze_field_elements(1)? 
+ .into_iter() + .next() + .unwrap(); + + let u = vec![r.clone(), r.negate()?, r.clone() * &r]; + + let com = [vec![commitment.c.clone()], proof.com.clone()].concat(); + + let v = &proof.v; + if v.len() != 3 { + return Err(SynthesisError::Unsatisfiable); + } + if ell != v[0].len() || ell != v[1].len() || ell != v[2].len() { + return Err(SynthesisError::Unsatisfiable); + } + + let x = opening_point; + let y = [v[2].clone(), vec![opening.clone()]].concat(); + + let one = FpVar::Constant(E::ScalarField::one()); + let two = FpVar::Constant(E::ScalarField::from(2u128)); + for i in 0..ell { + (&two * &r * &y[i + 1]).enforce_equal( + &(&r * (&one - &x[ell - i - 1]) * (&v[0][i] + &v[1][i]) + + &x[ell - i - 1] * (&v[0][i] - &v[1][i])), + )?; + } + + dbg!(); + + // TODO implement Ok(Boolean::TRUE) } } @@ -133,6 +200,7 @@ where #[cfg(test)] mod tests { use super::*; + use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; use crate::circuits::transcript::mock::{MockSponge, MockSpongeVar}; use crate::poly::commitment::hyperkzg::{ HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey, @@ -148,23 +216,28 @@ mod tests { use ark_r1cs_std::ToConstraintFieldGadget; use ark_relations::ns; use ark_std::rand::Rng; + use ark_std::Zero; use rand_core::{CryptoRng, RngCore, SeedableRng}; - struct HyperKZGVerifierCircuit + #[derive(Debug)] + struct HyperKZGVerifierCircuit where E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, { + _params: PhantomData, pcs_pk_vk: Option<(HyperKZGProverKey, HyperKZGVerifierKey)>, commitment: Option>, point: Vec>, eval: Option, - pcs_proof: Option>, + pcs_proof: HyperKZGProof, expected_result: Option, } - impl HyperKZGVerifierCircuit + impl HyperKZGVerifierCircuit where E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, { pub(crate) fn public_inputs(&self) -> Vec { Boolean::::constant(self.expected_result.unwrap()) // panics if None @@ -176,9 +249,10 @@ mod tests { } } - impl ConstraintSynthesizer for HyperKZGVerifierCircuit + impl ConstraintSynthesizer for HyperKZGVerifierCircuit where - E: Pairing, + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, { fn generate_constraints( self, @@ -189,10 +263,10 @@ mod tests { self.pcs_pk_vk.ok_or(SynthesisError::AssignmentMissing) })?; - let commitment_var = HyperKZGCommitmentVar::::new_witness( - ns!(cs, "commitment"), - || self.commitment.ok_or(SynthesisError::AssignmentMissing), - )?; + let commitment_var = + HyperKZGCommitmentVar::::new_witness(ns!(cs, "commitment"), || { + self.commitment.ok_or(SynthesisError::AssignmentMissing) + })?; let point_var = self .point @@ -207,22 +281,19 @@ mod tests { })?; let proof_var = - HyperKZGProofVar::::new_witness(ns!(cs, "proof"), || { - self.pcs_proof.ok_or(SynthesisError::AssignmentMissing) - })?; + HyperKZGProofVar::::new_witness(ns!(cs, "proof"), || Ok(self.pcs_proof))?; let mut transcript_var = MockSpongeVar::new(ns!(cs, "transcript").cs(), &(b"TestEval".as_slice())); - let r = - HyperKZGVerifierGadget::>::verify( - &proof_var, - &vk_var, - &mut transcript_var, - &point_var, - &eval_var, - &commitment_var, - )?; + let r = HyperKZGVerifierGadget::, G1Var>::verify( + &proof_var, + &vk_var, + &mut transcript_var, + &point_var, + &eval_var, + &commitment_var, + )?; let r_input = Boolean::new_input(ns!(cs, "verification_result"), || { self.expected_result @@ -251,13 +322,15 @@ mod tests { ark_bn254::Fr::from(4), ]); + let size = 2usize; let (cpk, cvk) = { - let circuit = HyperKZGVerifierCircuit:: { + let circuit = HyperKZGVerifierCircuit:: { + _params: 
PhantomData, pcs_pk_vk: None, commitment: None, - point: vec![None, None], + point: vec![None; size], eval: None, - pcs_proof: None, + pcs_proof: HyperKZGProof::empty(size), expected_result: None, }; @@ -271,17 +344,21 @@ mod tests { |point: Vec, eval: ark_bn254::Fr| -> Result<(), ProofVerifyError> { let mut tr = ProofTranscript::new(b"TestEval"); let hkzg_proof = HyperKZG::open(&pcs_pk, &poly, &point, &eval, &mut tr).unwrap(); + + println!("Verifying natively..."); + let mut tr = ProofTranscript::new(b"TestEval"); HyperKZG::verify(&pcs_vk, &C, &point, &eval, &hkzg_proof, &mut tr)?; // Create an instance of our circuit (with the // witness) - let verifier_circuit = HyperKZGVerifierCircuit:: { + let verifier_circuit = HyperKZGVerifierCircuit:: { + _params: PhantomData, pcs_pk_vk: Some((pcs_pk.clone(), pcs_vk.clone())), commitment: Some(C.clone()), point: point.into_iter().map(|x| Some(x)).collect(), eval: Some(eval), - pcs_proof: Some(hkzg_proof), + pcs_proof: hkzg_proof, expected_result: Some(true), }; let instance = verifier_circuit.public_inputs(); @@ -289,9 +366,12 @@ mod tests { let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(ark_std::test_rng().next_u64()); + println!("Verifying in-circuit..."); + // Create a groth16 proof with our parameters. let proof = Groth16::prove(&cpk, verifier_circuit, &mut rng) .map_err(|e| ProofVerifyError::InternalError)?; + let result = Groth16::verify_with_processed_vk(&pvk, &instance, &proof); match result { Ok(true) => Ok(()), diff --git a/jolt-core/src/circuits/transcript/mock.rs b/jolt-core/src/circuits/transcript/mock.rs index 2054fb5ef..d78708f06 100644 --- a/jolt-core/src/circuits/transcript/mock.rs +++ b/jolt-core/src/circuits/transcript/mock.rs @@ -1,3 +1,4 @@ +use crate::circuits::transcript::IS_SLICE; use crate::field::JoltField; use crate::utils::transcript::ProofTranscript; use ark_crypto_primitives::sponge::constraints::{ @@ -10,6 +11,7 @@ use ark_r1cs_std::fields::fp::FpVar; use ark_r1cs_std::prelude::*; use ark_r1cs_std::R1CSVar; use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError}; +use ark_std::any::Any; use std::marker::PhantomData; #[derive(Clone)] @@ -57,9 +59,8 @@ pub struct MockSpongeVar where ConstraintF: PrimeField, { - _params: PhantomData, cs: ConstraintSystemRef, - transcript: ProofTranscript, + pub transcript: ProofTranscript, } impl CryptographicSpongeVar> @@ -71,7 +72,6 @@ where fn new(cs: ConstraintSystemRef, params: &Self::Parameters) -> Self { Self { - _params: PhantomData, cs, transcript: ProofTranscript::new(params), } @@ -82,12 +82,22 @@ where } fn absorb(&mut self, input: &impl AbsorbGadget) -> Result<(), SynthesisError> { - let fs = input - .to_sponge_field_elements()? 
+ let bytes = input.to_sponge_bytes()?; + let is_slice = IS_SLICE.take(); + let fs = bytes .iter() - .map(|f| f.value()) + .map(|f| match self.cs.is_in_setup_mode() { + true => Ok(0u8), + false => f.value(), + }) .collect::, _>>()?; - self.transcript.append_scalars(&fs); + if is_slice { + self.transcript.append_message(b"begin_append_vector"); + } + self.transcript.append_bytes(&fs); + if is_slice { + self.transcript.append_message(b"end_append_vector"); + } Ok(()) } @@ -109,6 +119,7 @@ where &mut self, num_elements: usize, ) -> Result>, SynthesisError> { + dbg!(&self.transcript.n_rounds); self.transcript .challenge_vector::(num_elements) .iter() diff --git a/jolt-core/src/circuits/transcript/mod.rs b/jolt-core/src/circuits/transcript/mod.rs index 9afc1d5e9..146414770 100644 --- a/jolt-core/src/circuits/transcript/mod.rs +++ b/jolt-core/src/circuits/transcript/mod.rs @@ -1 +1,73 @@ +use crate::utils::transcript::ProofTranscript; +use ark_crypto_primitives::sponge::constraints::AbsorbGadget; +use ark_ff::PrimeField; +use ark_r1cs_std::fields::fp::FpVar; +use ark_r1cs_std::prelude::*; +use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; +use ark_relations::ns; +use ark_relations::r1cs::SynthesisError; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_std::Zero; +use std::cell::RefCell; +use std::fmt::Debug; +use std::marker::PhantomData; + pub mod mock; + +pub struct ImplAbsorb<'a, T, F>(&'a T, PhantomData) +where + T: R1CSVar, + F: PrimeField; + +impl<'a, T, F> ImplAbsorb<'a, T, F> +where + T: R1CSVar, + F: PrimeField, +{ + pub fn wrap(t: &'a T) -> Self { + Self(t, PhantomData) + } +} + +thread_local! { + static IS_SLICE: RefCell = RefCell::new(false); +} + +impl<'a, T, F> AbsorbGadget for ImplAbsorb<'a, T, F> +where + T: R1CSVar + Debug, + F: PrimeField, +{ + fn to_sponge_bytes(&self) -> Result>, SynthesisError> { + let mut buf = vec![]; + + let t_value = match self.0.cs().is_in_setup_mode() { + true => T::Value::zero(), + false => self.0.value()?, + }; + + t_value + .serialize_compressed(&mut buf) + .map_err(|e| SynthesisError::Unsatisfiable)?; + + buf.into_iter() + .map(|b| UInt8::new_witness(ns!(self.0.cs(), "sponge_byte"), || Ok(b))) + .collect::, _>>() + } + + fn batch_to_sponge_bytes(batch: &[Self]) -> Result>, SynthesisError> + where + Self: Sized, + { + IS_SLICE.set(true); + let mut result = Vec::new(); + for item in batch { + result.append(&mut (item.to_sponge_bytes()?)) + } + Ok(result) + } + + fn to_sponge_field_elements(&self) -> Result>, SynthesisError> { + unimplemented!("should not be called") + } +} diff --git a/jolt-core/src/poly/commitment/hyperkzg.rs b/jolt-core/src/poly/commitment/hyperkzg.rs index ad6bd0905..d55240219 100644 --- a/jolt-core/src/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/poly/commitment/hyperkzg.rs @@ -59,7 +59,7 @@ pub struct HyperKZGVerifierKey { } #[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] -pub struct HyperKZGCommitment(P::G1Affine); +pub struct HyperKZGCommitment(pub P::G1Affine); impl AppendToTranscript for HyperKZGCommitment
{ fn append_to_transcript(&self, transcript: &mut ProofTranscript) { @@ -69,9 +69,19 @@ impl<P: Pairing> AppendToTranscript for HyperKZGCommitment<P>
{ #[derive(Clone, CanonicalSerialize, CanonicalDeserialize, Debug)] pub struct HyperKZGProof<P: Pairing> { - com: Vec<P::G1Affine>, - w: Vec<P::G1Affine>, - v: Vec<Vec<P::ScalarField>>, + pub com: Vec<P::G1Affine>, + pub w: Vec<P::G1Affine>, + pub v: Vec<Vec<P::ScalarField>>, +} + +impl<P: Pairing> HyperKZGProof<P>
{ + pub fn empty(size: usize) -> Self { + Self { + com: vec![P::G1Affine::zero(); size - 1], + w: vec![P::G1Affine::zero(); 3], + v: vec![vec![P::ScalarField::zero(); size]; 3], + } + } } // On input f(x) and u compute the witness polynomial used to prove diff --git a/jolt-core/src/utils/transcript.rs b/jolt-core/src/utils/transcript.rs index df7c6e865..faf498366 100644 --- a/jolt-core/src/utils/transcript.rs +++ b/jolt-core/src/utils/transcript.rs @@ -7,7 +7,7 @@ pub struct ProofTranscript { // Ethereum compatible 256 bit running state state: [u8; 32], // We append an ordinal to each invoke of the hash - n_rounds: u32, + pub n_rounds: u32, } impl ProofTranscript { From e6b81aed743aae099adc1024d726496983db4fd2 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 13 Aug 2024 19:59:28 -0700 Subject: [PATCH 28/44] WIP: HyperKZG gadget --- .../src/circuits/poly/commitment/hyperkzg.rs | 70 +++++++++++++++++-- 1 file changed, 65 insertions(+), 5 deletions(-) diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 6e6361059..2e2f1298a 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -20,6 +20,7 @@ use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, Namespace, use ark_std::borrow::Borrow; use ark_std::iterable::Iterable; use ark_std::marker::PhantomData; +use ark_std::One; #[derive(Clone)] pub struct HyperKZGProofVar @@ -152,10 +153,11 @@ where ) -> Result, SynthesisError> { let ell = opening_point.len(); + let HyperKZGProofVar { com, w, v } = proof; + let HyperKZGCommitmentVar { c } = commitment; + transcript.absorb( - &proof - .com - .iter() + &com.iter() .map(|com| ImplAbsorb::wrap(com)) .collect::>(), )?; @@ -168,12 +170,14 @@ where let u = vec![r.clone(), r.negate()?, r.clone() * &r]; - let com = [vec![commitment.c.clone()], proof.com.clone()].concat(); + let com = [vec![c.clone()], com.clone()].concat(); - let v = &proof.v; if v.len() != 3 { return Err(SynthesisError::Unsatisfiable); } + if w.len() != 3 { + return Err(SynthesisError::Unsatisfiable); + } if ell != v[0].len() || ell != v[1].len() || ell != v[2].len() { return Err(SynthesisError::Unsatisfiable); } @@ -190,6 +194,37 @@ where )?; } + // kzg_verify_batch + + transcript.absorb(&v.iter().flatten().cloned().collect::>())?; + let q_powers = q_powers::(transcript, ell)?; + + transcript.absorb( + &proof + .w + .iter() + .map(|g| ImplAbsorb::wrap(g)) + .collect::>(), + )?; + let d = transcript + .squeeze_field_elements(1)? + .into_iter() + .next() + .unwrap(); + + let q_power_multiplier = one + &d + &d.square()?; + + let b_u_i = v + .iter() + .map(|v_i| { + let mut b_u_i = v_i[0].clone(); + for i in 1..ell { + b_u_i += &q_powers[i] * &v_i[i]; + } + b_u_i + }) + .collect::>(); + dbg!(); // TODO implement @@ -197,6 +232,29 @@ where } } +fn q_powers>( + transcript: &mut S::Var, + ell: usize, +) -> Result>, SynthesisError> { + let q = transcript + .squeeze_field_elements(1)? 
+ .into_iter() + .next() + .unwrap(); + + let q_powers = [vec![FpVar::Constant(E::ScalarField::one()), q.clone()], { + let mut q_power = q.clone(); + (1..ell) + .map(|i| { + q_power *= &q; + q_power.clone() + }) + .collect() + }] + .concat(); + Ok(q_powers) +} + #[cfg(test)] mod tests { use super::*; @@ -301,6 +359,8 @@ mod tests { })?; r.enforce_equal(&r_input)?; + dbg!(cs.num_constraints()); + Ok(()) } } From f988dde376984569da3c11ed808ca5cf5ed49240 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Wed, 14 Aug 2024 13:51:24 -0700 Subject: [PATCH 29/44] WIP: HyperKZG gadget: adjust type parameters in OffloadedDataCircuit --- jolt-core/src/circuits/groups/curves/mod.rs | 6 +- jolt-core/src/circuits/offloaded/mod.rs | 26 ++++---- .../src/circuits/poly/commitment/hyperkzg.rs | 16 ++++- jolt-core/src/snark/mod.rs | 64 +++++++------------ 4 files changed, 55 insertions(+), 57 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 96c4e628f..7abd3fa46 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -49,7 +49,7 @@ mod tests { d: Option, // deferred fns to write offloaded data to public_input - deferred_fns_ref: DeferredFnsRef, + deferred_fns_ref: DeferredFnsRef, } impl ConstraintSynthesizer for DelayedOpsCircuit @@ -89,12 +89,12 @@ mod tests { } } - impl OffloadedDataCircuit for DelayedOpsCircuit + impl OffloadedDataCircuit for DelayedOpsCircuit where E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, { - fn deferred_fns_ref(&self) -> &DeferredFnsRef { + fn deferred_fns_ref(&self) -> &DeferredFnsRef { &self.deferred_fns_ref } } diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index 6ab7a81ae..90198a420 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -12,31 +12,31 @@ use ark_relations::ns; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use std::marker::PhantomData; -pub struct OffloadedMSMGadget +pub struct OffloadedMSMGadget where - E: Pairing, - Circuit: OffloadedDataCircuit, + Circuit: OffloadedDataCircuit, ConstraintF: PrimeField, FVar: FieldVar + ToConstraintFieldGadget, - G1Var: CurveVar + ToConstraintFieldGadget, + G: CurveGroup, + GVar: CurveVar + ToConstraintFieldGadget, { - _params: PhantomData<(E, ConstraintF, FVar, G1Var, Circuit)>, + _params: PhantomData<(ConstraintF, FVar, G, GVar, Circuit)>, } -impl OffloadedMSMGadget +impl OffloadedMSMGadget where - E: Pairing, - Circuit: OffloadedDataCircuit, + Circuit: OffloadedDataCircuit, ConstraintF: PrimeField, FVar: FieldVar + ToConstraintFieldGadget, - G1Var: CurveVar + ToConstraintFieldGadget, + G: CurveGroup, + GVar: CurveVar + ToConstraintFieldGadget, { pub fn msm( circuit: &Circuit, cs: impl Into>, - g1s: &[G1Var], + g1s: &[GVar], scalars: &[FVar], - ) -> Result { + ) -> Result { let ns = cs.into(); let cs = ns.cs(); @@ -53,7 +53,7 @@ where let (full_msm_value, msm_g1_value) = g1_values .zip(scalar_values) .map(|(g1s, scalars)| { - let r_g1 = E::G1::msm_unchecked(&g1s, &scalars); + let r_g1 = G::msm_unchecked(&g1s, &scalars); let minus_one = -ConstraintF::one(); ( ( @@ -65,7 +65,7 @@ where }) .unzip(); - let msm_g1_var = G1Var::new_witness(ns!(cs, "msm_g1"), || { + let msm_g1_var = GVar::new_witness(ns!(cs, "msm_g1"), || { msm_g1_value.ok_or(SynthesisError::AssignmentMissing) })?; diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs 
b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 2e2f1298a..cdaa6663f 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -4,6 +4,7 @@ use crate::field::JoltField; use crate::poly::commitment::hyperkzg::{ HyperKZG, HyperKZGCommitment, HyperKZGProof, HyperKZGProverKey, HyperKZGVerifierKey, }; +use crate::snark::OffloadedDataCircuit; use ark_crypto_primitives::sponge::constraints::{ AbsorbGadget, CryptographicSpongeVar, SpongeWithGadget, }; @@ -264,6 +265,7 @@ mod tests { HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey, }; use crate::poly::dense_mlpoly::DensePolynomial; + use crate::snark::{DeferredFnsRef, OffloadedDataCircuit}; use crate::utils::errors::ProofVerifyError; use crate::utils::transcript::ProofTranscript; use ark_bn254::Bn254; @@ -277,13 +279,13 @@ mod tests { use ark_std::Zero; use rand_core::{CryptoRng, RngCore, SeedableRng}; - #[derive(Debug)] struct HyperKZGVerifierCircuit where E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, { _params: PhantomData, + deferred_fns_ref: DeferredFnsRef, pcs_pk_vk: Option<(HyperKZGProverKey, HyperKZGVerifierKey)>, commitment: Option>, point: Vec>, @@ -292,6 +294,16 @@ mod tests { expected_result: Option, } + impl OffloadedDataCircuit for HyperKZGVerifierCircuit + where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, + { + fn deferred_fns_ref(&self) -> &DeferredFnsRef { + &self.deferred_fns_ref + } + } + impl HyperKZGVerifierCircuit where E: Pairing, @@ -386,6 +398,7 @@ mod tests { let (cpk, cvk) = { let circuit = HyperKZGVerifierCircuit:: { _params: PhantomData, + deferred_fns_ref: Default::default(), pcs_pk_vk: None, commitment: None, point: vec![None; size], @@ -414,6 +427,7 @@ mod tests { // witness) let verifier_circuit = HyperKZGVerifierCircuit:: { _params: PhantomData, + deferred_fns_ref: Default::default(), pcs_pk_vk: Some((pcs_pk.clone(), pcs_vk.clone())), commitment: Some(C.clone()), point: point.into_iter().map(|x| Some(x)).collect(), diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index c25a17997..640132ac8 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -2,7 +2,7 @@ use ark_crypto_primitives::snark::SNARK; use ark_ec::{ pairing::Pairing, short_weierstrass::{Affine, SWCurveConfig}, - AffineRepr, VariableBaseMSM, + AffineRepr, CurveGroup, VariableBaseMSM, }; use ark_ff::{PrimeField, Zero}; use ark_r1cs_std::{ @@ -48,43 +48,32 @@ where S: SNARK, { pub snark_proof: S::Proof, - pub offloaded_data: OffloadedData, + pub offloaded_data: OffloadedData, } #[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] -pub struct OffloadedData { +pub struct OffloadedData { /// Blocks of G1 elements `Gᵢ` and scalars in `sᵢ` the public input, such that `∑ᵢ sᵢ·Gᵢ == 0`. /// It's the verifiers responsibility to ensure that the sum is zero. /// The scalar at index `length-1` is, by convention, always `-1`, so /// we save one public input element per MSM. 
- pub msms: Vec<(Vec, Vec)>, + pub msms: Vec<(Vec, Vec)>, } -pub type DeferredFn = - dyn FnOnce() -> Result, Vec)>, SynthesisError>; - -pub type DeferredFnsRef = Rc< - RefCell< - Vec< - Box< - dyn FnOnce() -> Result< - Option<(Vec, Vec)>, - SynthesisError, - >, - >, - >, - >, ->; +pub type DeferredFn = + dyn FnOnce() -> Result, Vec)>, SynthesisError>; + +pub type DeferredFnsRef = Rc>>>>; -pub trait OffloadedDataCircuit +pub trait OffloadedDataCircuit where - E: Pairing, + G: CurveGroup, { - fn deferred_fns_ref(&self) -> &DeferredFnsRef; + fn deferred_fns_ref(&self) -> &DeferredFnsRef; fn defer_msm( &self, - f: impl FnOnce() -> Result, Vec)>, SynthesisError> + f: impl FnOnce() -> Result, Vec)>, SynthesisError> + 'static, ) { self.deferred_fns_ref().borrow_mut().push(Box::new(f)); @@ -106,15 +95,13 @@ where SynthesisError(#[from] SynthesisError), } -struct WrappedCircuit +struct WrappedCircuit where - E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, - P: SWCurveConfig, - C: ConstraintSynthesizer + OffloadedDataCircuit, + E: Pairing, + C: ConstraintSynthesizer + OffloadedDataCircuit, { - _params: PhantomData<(E, P)>, circuit: C, - offloaded_data_ref: Rc>>, + offloaded_data_ref: Rc>>, } fn run_deferred( @@ -123,7 +110,7 @@ fn run_deferred( dyn FnOnce() -> Result, Vec)>, SynthesisError>, >, >, -) -> Result>, SynthesisError> { +) -> Result>, SynthesisError> { let msms = deferred_fns .into_iter() .map(|f| f()) @@ -132,11 +119,10 @@ fn run_deferred( Ok(msms.map(|msms| OffloadedData { msms })) } -impl ConstraintSynthesizer for WrappedCircuit +impl ConstraintSynthesizer for WrappedCircuit where - E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, - P: SWCurveConfig, - C: ConstraintSynthesizer + OffloadedDataCircuit, + E: Pairing, + C: ConstraintSynthesizer + OffloadedDataCircuit, { fn generate_constraints(self, cs: ConstraintSystemRef) -> r1cs::Result<()> { let deferred_fns_ref = self.circuit.deferred_fns_ref().clone(); @@ -160,15 +146,14 @@ where S: SNARK, G1Var: CurveVar + ToConstraintFieldGadget, { - type Circuit: ConstraintSynthesizer + OffloadedDataCircuit; + type Circuit: ConstraintSynthesizer + OffloadedDataCircuit; fn setup( circuit: Self::Circuit, rng: &mut R, ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> { - let circuit = WrappedCircuit { - _params: PhantomData, + let circuit: WrappedCircuit = WrappedCircuit { circuit, offloaded_data_ref: Default::default(), }; @@ -199,8 +184,7 @@ where circuit: Self::Circuit, rng: &mut R, ) -> Result, OffloadedSNARKError> { - let circuit = WrappedCircuit { - _params: PhantomData, + let circuit: WrappedCircuit = WrappedCircuit { circuit, offloaded_data_ref: Default::default(), }; @@ -354,7 +338,7 @@ where fn build_public_input( public_input: &[E::ScalarField], - data: &OffloadedData, + data: &OffloadedData, ) -> Vec where E: Pairing, From 23b474ffcd864f58dc5fbf3b7bca3006a20373cb Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Wed, 14 Aug 2024 14:35:40 -0700 Subject: [PATCH 30/44] WIP: HyperKZG gadget: factor out MSMGadget trait --- jolt-core/src/circuits/groups/curves/mod.rs | 9 ++- jolt-core/src/circuits/offloaded/mod.rs | 60 ++++++++++++++----- .../poly/commitment/commitment_scheme.rs | 13 ++-- 3 files changed, 56 insertions(+), 26 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 7abd3fa46..1884fb55f 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ 
-5,7 +5,7 @@ mod tests { use super::*; use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; - use crate::circuits::offloaded::OffloadedMSMGadget; + use crate::circuits::offloaded::{MSMGadget, OffloadedMSMGadget}; use crate::snark::{ DeferredFnsRef, OffloadedData, OffloadedDataCircuit, OffloadedSNARK, OffloadedSNARKError, OffloadedSNARKVerifyingKey, @@ -81,8 +81,11 @@ mod tests { let d_k = [FpVar::one(), d, d_square]; dbg!(cs.num_constraints()); - let _ = - OffloadedMSMGadget::msm(&self, ns!(cs, "msm"), w_g1.as_slice(), d_k.as_slice())?; + let _ = OffloadedMSMGadget::new(self).msm( + ns!(cs, "msm"), + w_g1.as_slice(), + d_k.as_slice(), + )?; dbg!(cs.num_constraints()); Ok(()) diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index 90198a420..d4975f743 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -12,28 +12,42 @@ use ark_relations::ns; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use std::marker::PhantomData; -pub struct OffloadedMSMGadget +pub struct OffloadedMSMGadget where Circuit: OffloadedDataCircuit, - ConstraintF: PrimeField, - FVar: FieldVar + ToConstraintFieldGadget, - G: CurveGroup, - GVar: CurveVar + ToConstraintFieldGadget, + FVar: FieldVar + ToConstraintFieldGadget, + G: CurveGroup, + GVar: CurveVar + ToConstraintFieldGadget, { - _params: PhantomData<(ConstraintF, FVar, G, GVar, Circuit)>, + _params: PhantomData<(FVar, G, GVar)>, + circuit: Circuit, } -impl OffloadedMSMGadget +impl OffloadedMSMGadget where Circuit: OffloadedDataCircuit, - ConstraintF: PrimeField, - FVar: FieldVar + ToConstraintFieldGadget, - G: CurveGroup, - GVar: CurveVar + ToConstraintFieldGadget, + FVar: FieldVar + ToConstraintFieldGadget, + G: CurveGroup, + GVar: CurveVar + ToConstraintFieldGadget, { - pub fn msm( - circuit: &Circuit, - cs: impl Into>, + pub fn new(circuit: Circuit) -> Self { + Self { + _params: PhantomData, + circuit, + } + } +} + +impl MSMGadget for OffloadedMSMGadget +where + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + G: CurveGroup, + GVar: CurveVar + ToConstraintFieldGadget, +{ + fn msm( + &self, + cs: impl Into>, g1s: &[GVar], scalars: &[FVar], ) -> Result { @@ -54,7 +68,7 @@ where .zip(scalar_values) .map(|(g1s, scalars)| { let r_g1 = G::msm_unchecked(&g1s, &scalars); - let minus_one = -ConstraintF::one(); + let minus_one = -G::ScalarField::one(); ( ( [g1s, vec![r_g1.into()]].concat(), @@ -76,7 +90,7 @@ where let ns = ns!(cs, "deferred_msm"); let cs = ns.cs(); - circuit.defer_msm(move || { + self.circuit.defer_msm(move || { // write scalars to public_input for x in scalars { let scalar_input = FVar::new_input(ns!(cs, "scalar"), || x.value())?; @@ -113,3 +127,17 @@ where Ok(msm_g1_var) } } + +pub trait MSMGadget +where + FVar: FieldVar + ToConstraintFieldGadget, + G: CurveGroup, + GVar: CurveVar + ToConstraintFieldGadget, +{ + fn msm( + &self, + cs: impl Into>, + g1s: &[GVar], + scalars: &[FVar], + ) -> Result; +} diff --git a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs index 2b716828f..e59cd586a 100644 --- a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs +++ b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs @@ -1,3 +1,4 @@ +use crate::poly::commitment::commitment_scheme::CommitmentScheme; use 
ark_crypto_primitives::sponge::constraints::SpongeWithGadget; use ark_ec::pairing::Pairing; use ark_ff::PrimeField; @@ -5,17 +6,15 @@ use ark_r1cs_std::fields::fp::FpVar; use ark_r1cs_std::prelude::*; use ark_relations::r1cs::SynthesisError; -use crate::poly::commitment::commitment_scheme::CommitmentScheme; - -pub trait CommitmentVerifierGadget +pub trait CommitmentVerifierGadget where ConstraintF: PrimeField, - C: CommitmentScheme, + CS: CommitmentScheme, S: SpongeWithGadget, { - type VerifyingKeyVar: AllocVar + Clone; - type ProofVar: AllocVar + Clone; - type CommitmentVar: AllocVar + Clone; + type VerifyingKeyVar: AllocVar + Clone; + type ProofVar: AllocVar + Clone; + type CommitmentVar: AllocVar + Clone; fn verify( proof: &Self::ProofVar, From 395c891785e35c9abc34758afbabbd15597f6025 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Wed, 14 Aug 2024 16:13:17 -0700 Subject: [PATCH 31/44] WIP: HyperKZG gadget: prep to use OffloadedMSMGadget --- jolt-core/src/circuits/groups/curves/mod.rs | 2 +- jolt-core/src/circuits/offloaded/mod.rs | 11 ++-- .../poly/commitment/commitment_scheme.rs | 1 + .../src/circuits/poly/commitment/hyperkzg.rs | 64 ++++++++++++++----- 4 files changed, 57 insertions(+), 21 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 1884fb55f..eaec2f961 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -81,7 +81,7 @@ mod tests { let d_k = [FpVar::one(), d, d_square]; dbg!(cs.num_constraints()); - let _ = OffloadedMSMGadget::new(self).msm( + let _ = OffloadedMSMGadget::new(&self).msm( ns!(cs, "msm"), w_g1.as_slice(), d_k.as_slice(), diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index d4975f743..0cc2e8d95 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -12,7 +12,7 @@ use ark_relations::ns; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use std::marker::PhantomData; -pub struct OffloadedMSMGadget +pub struct OffloadedMSMGadget<'a, FVar, G, GVar, Circuit> where Circuit: OffloadedDataCircuit, FVar: FieldVar + ToConstraintFieldGadget, @@ -20,17 +20,17 @@ where GVar: CurveVar + ToConstraintFieldGadget, { _params: PhantomData<(FVar, G, GVar)>, - circuit: Circuit, + circuit: &'a Circuit, } -impl OffloadedMSMGadget +impl<'a, FVar, G, GVar, Circuit> OffloadedMSMGadget<'a, FVar, G, GVar, Circuit> where Circuit: OffloadedDataCircuit, FVar: FieldVar + ToConstraintFieldGadget, G: CurveGroup, GVar: CurveVar + ToConstraintFieldGadget, { - pub fn new(circuit: Circuit) -> Self { + pub fn new(circuit: &'a Circuit) -> Self { Self { _params: PhantomData, circuit, @@ -38,7 +38,8 @@ where } } -impl MSMGadget for OffloadedMSMGadget +impl<'a, FVar, G, GVar, Circuit> MSMGadget + for OffloadedMSMGadget<'a, FVar, G, GVar, Circuit> where Circuit: OffloadedDataCircuit, FVar: FieldVar + ToConstraintFieldGadget, diff --git a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs index e59cd586a..e5d9061d0 100644 --- a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs +++ b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs @@ -17,6 +17,7 @@ where type CommitmentVar: AllocVar + Clone; fn verify( + &self, proof: &Self::ProofVar, vk: &Self::VerifyingKeyVar, transcript: &mut S::Var, diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs 
b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index cdaa6663f..bb01f1ead 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -1,3 +1,4 @@ +use crate::circuits::offloaded::OffloadedMSMGadget; use crate::circuits::poly::commitment::commitment_scheme::CommitmentVerifierGadget; use crate::circuits::transcript::ImplAbsorb; use crate::field::JoltField; @@ -123,35 +124,53 @@ where } } -pub struct HyperKZGVerifierGadget +pub struct HyperKZGVerifierGadget<'a, E, S, G1Var, Circuit> where E: Pairing, S: SpongeWithGadget, G1Var: CurveVar + ToConstraintFieldGadget, + Circuit: OffloadedDataCircuit, { _params: PhantomData<(E, S, G1Var)>, + circuit: &'a Circuit, } -impl CommitmentVerifierGadget, S> - for HyperKZGVerifierGadget +impl<'a, E, S, G1Var, Circuit> HyperKZGVerifierGadget<'a, E, S, G1Var, Circuit> where - E: Pairing, - ConstraintF: PrimeField + JoltField, - S: SpongeWithGadget, + E: Pairing, + S: SpongeWithGadget, + G1Var: CurveVar + ToConstraintFieldGadget, + Circuit: OffloadedDataCircuit, +{ + pub fn new(circuit: &'a Circuit) -> Self { + Self { + _params: PhantomData, + circuit, + } + } +} + +impl<'a, E, S, G1Var, Circuit> CommitmentVerifierGadget, S> + for HyperKZGVerifierGadget<'a, E, S, G1Var, Circuit> +where + E: Pairing, + S: SpongeWithGadget, G1Var: CurveVar + ToConstraintFieldGadget, + Circuit: OffloadedDataCircuit, { - type VerifyingKeyVar = HyperKZGVerifierKeyVar; + type VerifyingKeyVar = HyperKZGVerifierKeyVar; type ProofVar = HyperKZGProofVar; type CommitmentVar = HyperKZGCommitmentVar; fn verify( + &self, proof: &Self::ProofVar, vk: &Self::VerifyingKeyVar, transcript: &mut S::Var, - opening_point: &[FpVar], - opening: &FpVar, + opening_point: &[FpVar], + opening: &FpVar, commitment: &Self::CommitmentVar, - ) -> Result, SynthesisError> { + ) -> Result, SynthesisError> { let ell = opening_point.len(); let HyperKZGProofVar { com, w, v } = proof; @@ -226,6 +245,9 @@ where }) .collect::>(); + let msm_gadget = + OffloadedMSMGadget::, E::G1, G1Var, Circuit>::new(self.circuit); + dbg!(); // TODO implement @@ -330,12 +352,16 @@ mod tests { ) -> Result<(), SynthesisError> { let vk_var = HyperKZGVerifierKeyVar::::new_witness(ns!(cs, "vk"), || { - self.pcs_pk_vk.ok_or(SynthesisError::AssignmentMissing) + self.pcs_pk_vk + .clone() + .ok_or(SynthesisError::AssignmentMissing) })?; let commitment_var = HyperKZGCommitmentVar::::new_witness(ns!(cs, "commitment"), || { - self.commitment.ok_or(SynthesisError::AssignmentMissing) + self.commitment + .clone() + .ok_or(SynthesisError::AssignmentMissing) })?; let point_var = self @@ -350,13 +376,21 @@ mod tests { self.eval.ok_or(SynthesisError::AssignmentMissing) })?; - let proof_var = - HyperKZGProofVar::::new_witness(ns!(cs, "proof"), || Ok(self.pcs_proof))?; + let proof_var = HyperKZGProofVar::::new_witness(ns!(cs, "proof"), || { + Ok(self.pcs_proof.clone()) + })?; let mut transcript_var = MockSpongeVar::new(ns!(cs, "transcript").cs(), &(b"TestEval".as_slice())); - let r = HyperKZGVerifierGadget::, G1Var>::verify( + let h_kzg = HyperKZGVerifierGadget::< + E, + MockSponge, + G1Var, + HyperKZGVerifierCircuit, + >::new(&self); + + let r = h_kzg.verify( &proof_var, &vk_var, &mut transcript_var, From b475d8415fb849bd8f91b2af1a7b1b4032d75747 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Wed, 14 Aug 2024 16:32:06 -0700 Subject: [PATCH 32/44] WIP: HyperKZG verifier test: migrate to OffloadedSNARK --- .../src/circuits/poly/commitment/hyperkzg.rs | 51 +++++++++++++++++-- 1 file 
changed, 46 insertions(+), 5 deletions(-) diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index bb01f1ead..cc00d4677 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -287,7 +287,10 @@ mod tests { HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey, }; use crate::poly::dense_mlpoly::DensePolynomial; - use crate::snark::{DeferredFnsRef, OffloadedDataCircuit}; + use crate::snark::{ + DeferredFnsRef, OffloadedDataCircuit, OffloadedSNARK, OffloadedSNARKError, + OffloadedSNARKVerifyingKey, + }; use crate::utils::errors::ProofVerifyError; use crate::utils::transcript::ProofTranscript; use ark_bn254::Bn254; @@ -295,8 +298,10 @@ mod tests { use ark_crypto_primitives::sponge::constraints::CryptographicSpongeVar; use ark_crypto_primitives::sponge::poseidon::constraints::PoseidonSpongeVar; use ark_crypto_primitives::sponge::poseidon::{PoseidonConfig, PoseidonDefaultConfigField}; + use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; use ark_r1cs_std::ToConstraintFieldGadget; use ark_relations::ns; + use ark_serialize::SerializationError; use ark_std::rand::Rng; use ark_std::Zero; use rand_core::{CryptoRng, RngCore, SeedableRng}; @@ -411,9 +416,46 @@ mod tests { } } + struct HyperKZGVerifier + where + E: Pairing, + S: SNARK, + G1Var: CurveVar, + { + _params: PhantomData<(E, S, G1Var)>, + } + + impl OffloadedSNARK for HyperKZGVerifier + where + E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, + P: SWCurveConfig, + S: SNARK, + G1Var: CurveVar + ToConstraintFieldGadget, + { + type Circuit = HyperKZGVerifierCircuit; + + fn offloaded_setup( + snark_vk: S::ProcessedVerifyingKey, + ) -> Result, OffloadedSNARKError> { + Ok(OffloadedSNARKVerifyingKey { + snark_pvk: snark_vk, + delayed_pairings: vec![], + }) + } + + fn g2_elements( + vk: &OffloadedSNARKVerifyingKey, + public_input: &[E::ScalarField], + proof: &S::Proof, + ) -> Result>, SerializationError> { + Ok(vec![]) + } + } + #[test] fn test_hyperkzg_eval() { type Groth16 = ark_groth16::Groth16; + type VerifierSNARK = HyperKZGVerifier; // Test with poly(X1, X2) = 1 + X1 + X2 + X1*X2 let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(0); @@ -441,9 +483,8 @@ mod tests { expected_result: None, }; - Groth16::setup(circuit, &mut rng).unwrap() + VerifierSNARK::setup(circuit, &mut rng).unwrap() }; - let pvk = Groth16::process_vk(&cvk).unwrap(); let C = HyperKZG::commit(&pcs_pk, &poly).unwrap(); @@ -477,10 +518,10 @@ mod tests { println!("Verifying in-circuit..."); // Create a groth16 proof with our parameters. 
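// `VerifierSNARK` is the `OffloadedSNARK` wrapper from `jolt-core/src/snark`:
// roughly, `prove` wraps the circuit in a `WrappedCircuit`, runs the deferred
// MSM closures after synthesis, and returns their (G1 bases, scalars) blocks
// alongside the inner Groth16 proof, so that `verify_with_processed_vk` can both
// run the inner SNARK verifier and re-check natively that each offloaded block
// sums to zero.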
- let proof = Groth16::prove(&cpk, verifier_circuit, &mut rng) + let proof = VerifierSNARK::prove(&cpk, verifier_circuit, &mut rng) .map_err(|e| ProofVerifyError::InternalError)?; - let result = Groth16::verify_with_processed_vk(&pvk, &instance, &proof); + let result = VerifierSNARK::verify_with_processed_vk(&cvk, &instance, &proof); match result { Ok(true) => Ok(()), Ok(false) => Err(ProofVerifyError::InternalError), From 3433150437e33c474193f216df49ce77b733a162 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Thu, 15 Aug 2024 12:35:51 -0700 Subject: [PATCH 33/44] WIP: HyperKZG verifier: the only thing left is the pairing --- jolt-core/src/circuits/offloaded/mod.rs | 2 + .../src/circuits/poly/commitment/hyperkzg.rs | 95 ++++++++++++------- jolt-core/src/snark/mod.rs | 5 +- 3 files changed, 65 insertions(+), 37 deletions(-) diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index 0cc2e8d95..09df4b2ed 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -108,6 +108,7 @@ where f_input.enforce_equal(f)?; } } + dbg!(cs.num_constraints()); // write msm_g1 to public_input { @@ -119,6 +120,7 @@ where } } dbg!(cs.num_constraints()); + dbg!(cs.num_instance_variables()); Ok(full_msm_value) }) diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index cc00d4677..ca41a5e37 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -1,4 +1,4 @@ -use crate::circuits::offloaded::OffloadedMSMGadget; +use crate::circuits::offloaded::{MSMGadget, OffloadedMSMGadget}; use crate::circuits::poly::commitment::commitment_scheme::CommitmentVerifierGadget; use crate::circuits::transcript::ImplAbsorb; use crate::field::JoltField; @@ -28,7 +28,6 @@ use ark_std::One; pub struct HyperKZGProofVar where E: Pairing, - G1Var: CurveVar, { pub com: Vec, pub w: Vec, @@ -97,29 +96,26 @@ where } #[derive(Clone, Debug)] -pub struct HyperKZGVerifierKeyVar -where - E: Pairing, - ConstraintF: PrimeField, -{ - _params: PhantomData<(E, ConstraintF)>, - // TODO fill in +pub struct HyperKZGVerifierKeyVar { + pub g1: G1Var, + // pub g2: G2Var, + // pub beta_g2: G2Var, } -impl AllocVar<(HyperKZGProverKey, HyperKZGVerifierKey), ConstraintF> - for HyperKZGVerifierKeyVar +impl AllocVar<(HyperKZGProverKey, HyperKZGVerifierKey), E::ScalarField> + for HyperKZGVerifierKeyVar where E: Pairing, - ConstraintF: PrimeField, + G1Var: CurveVar, { fn new_variable, HyperKZGVerifierKey)>>( - cs: impl Into>, + cs: impl Into>, f: impl FnOnce() -> Result, mode: AllocationMode, ) -> Result { // TODO implement Ok(Self { - _params: PhantomData, + g1: G1Var::new_variable(cs, || Ok(f()?.borrow().1.kzg_vk.g1), mode)?, }) } } @@ -150,15 +146,16 @@ where } } -impl<'a, E, S, G1Var, Circuit> CommitmentVerifierGadget, S> +impl<'a, E, S, F, G1Var, Circuit> CommitmentVerifierGadget, S> for HyperKZGVerifierGadget<'a, E, S, G1Var, Circuit> where - E: Pairing, - S: SpongeWithGadget, - G1Var: CurveVar + ToConstraintFieldGadget, + F: PrimeField + JoltField, + E: Pairing, + S: SpongeWithGadget, + G1Var: CurveVar + ToConstraintFieldGadget, Circuit: OffloadedDataCircuit, { - type VerifyingKeyVar = HyperKZGVerifierKeyVar; + type VerifyingKeyVar = HyperKZGVerifierKeyVar; type ProofVar = HyperKZGProofVar; type CommitmentVar = HyperKZGCommitmentVar; @@ -167,14 +164,16 @@ where proof: &Self::ProofVar, vk: &Self::VerifyingKeyVar, transcript: &mut S::Var, - 
opening_point: &[FpVar], - opening: &FpVar, + opening_point: &[FpVar], + opening: &FpVar, commitment: &Self::CommitmentVar, - ) -> Result, SynthesisError> { + ) -> Result, SynthesisError> { let ell = opening_point.len(); + assert!(ell >= 2); let HyperKZGProofVar { com, w, v } = proof; let HyperKZGCommitmentVar { c } = commitment; + let HyperKZGVerifierKeyVar { g1 } = vk; transcript.absorb( &com.iter() @@ -198,15 +197,15 @@ where if w.len() != 3 { return Err(SynthesisError::Unsatisfiable); } - if ell != v[0].len() || ell != v[1].len() || ell != v[2].len() { + if ell != v[0].len() || ell != v[1].len() || ell != v[2].len() || ell != com.len() { return Err(SynthesisError::Unsatisfiable); } let x = opening_point; let y = [v[2].clone(), vec![opening.clone()]].concat(); - let one = FpVar::Constant(E::ScalarField::one()); - let two = FpVar::Constant(E::ScalarField::from(2u128)); + let one = FpVar::one(); + let two = FpVar::Constant(F::from(2u128)); for i in 0..ell { (&two * &r * &y[i + 1]).enforce_equal( &(&r * (&one - &x[ell - i - 1]) * (&v[0][i] + &v[1][i]) @@ -232,9 +231,14 @@ where .next() .unwrap(); - let q_power_multiplier = one + &d + &d.square()?; + let d_square = d.square()?; + let q_power_multiplier = one + &d + &d_square; + let q_powers_multiplied = q_powers + .iter() + .map(|q_i| q_i * &q_power_multiplier) + .collect::>(); - let b_u_i = v + let b_u = v .iter() .map(|v_i| { let mut b_u_i = v_i[0].clone(); @@ -245,9 +249,29 @@ where }) .collect::>(); - let msm_gadget = - OffloadedMSMGadget::, E::G1, G1Var, Circuit>::new(self.circuit); + let msm_gadget = OffloadedMSMGadget::, E::G1, G1Var, Circuit>::new(self.circuit); + + let g1s = &[com.as_slice(), w.as_slice(), &[g1.clone()]].concat(); + let scalars = &[ + q_powers_multiplied.as_slice(), + &[ + u[0].clone(), + &u[1] * &d, + &u[2] * &d_square, + (&b_u[0] + &d * &b_u[1] + &d_square * &b_u[2]).negate()?, + ], + ] + .concat(); + debug_assert_eq!(g1s.len(), scalars.len()); + + let l_g1 = msm_gadget.msm(ns!(transcript.cs(), "l_g1"), g1s, scalars)?; + dbg!(); + let g1s = w.as_slice(); + let scalars = &[FpVar::one(), d, d_square]; + debug_assert_eq!(g1s.len(), scalars.len()); + + let r_g1 = msm_gadget.msm(ns!(transcript.cs(), "r_g1"), g1s, scalars)?; dbg!(); // TODO implement @@ -267,7 +291,7 @@ fn q_powers>( let q_powers = [vec![FpVar::Constant(E::ScalarField::one()), q.clone()], { let mut q_power = q.clone(); - (1..ell) + (2..ell) .map(|i| { q_power *= &q; q_power.clone() @@ -355,12 +379,11 @@ mod tests { self, cs: ConstraintSystemRef, ) -> Result<(), SynthesisError> { - let vk_var = - HyperKZGVerifierKeyVar::::new_witness(ns!(cs, "vk"), || { - self.pcs_pk_vk - .clone() - .ok_or(SynthesisError::AssignmentMissing) - })?; + let vk_var = HyperKZGVerifierKeyVar::::new_witness(ns!(cs, "vk"), || { + self.pcs_pk_vk + .clone() + .ok_or(SynthesisError::AssignmentMissing) + })?; let commitment_var = HyperKZGCommitmentVar::::new_witness(ns!(cs, "commitment"), || { diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index 640132ac8..d80626c8b 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -114,7 +114,10 @@ fn run_deferred( let msms = deferred_fns .into_iter() .map(|f| f()) - .collect::>, _>>()?; + .collect::, _>>()?; + + // can't collect into `Option>` above: it short-circuits on the first None + let msms = msms.into_iter().collect::>>(); Ok(msms.map(|msms| OffloadedData { msms })) } From f4212254cac4c1671d128b909a4be99cb62bacc8 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Thu, 15 Aug 2024 15:42:36 
-0700 Subject: [PATCH 34/44] WIP: HyperKZG verifier: remove some debugging statements --- jolt-core/src/circuits/offloaded/mod.rs | 2 -- jolt-core/src/circuits/poly/commitment/hyperkzg.rs | 2 -- 2 files changed, 4 deletions(-) diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index 09df4b2ed..f44dc28f3 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -97,7 +97,6 @@ where let scalar_input = FVar::new_input(ns!(cs, "scalar"), || x.value())?; scalar_input.enforce_equal(&x)?; } - dbg!(cs.num_constraints()); // write g1s to public_input for g1 in g1s { @@ -108,7 +107,6 @@ where f_input.enforce_equal(f)?; } } - dbg!(cs.num_constraints()); // write msm_g1 to public_input { diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index ca41a5e37..7520f111e 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -262,10 +262,8 @@ where ], ] .concat(); - debug_assert_eq!(g1s.len(), scalars.len()); let l_g1 = msm_gadget.msm(ns!(transcript.cs(), "l_g1"), g1s, scalars)?; - dbg!(); let g1s = w.as_slice(); let scalars = &[FpVar::one(), d, d_square]; From 31e87a4af085a9264ce2ae8522411c110f4f3b5b Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Thu, 15 Aug 2024 17:24:46 -0700 Subject: [PATCH 35/44] WIP: HyperKZG verifier: add G2 elements to OffloadSNARKVerifyingKey --- jolt-core/src/circuits/groups/curves/mod.rs | 12 +++------- .../src/circuits/poly/commitment/hyperkzg.rs | 24 ++++++++----------- jolt-core/src/snark/mod.rs | 20 +++++++++++----- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index eaec2f961..41143cb89 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -37,6 +37,7 @@ mod tests { use itertools::Itertools; use rand_core::{CryptoRng, RngCore, SeedableRng}; + #[derive(Clone)] struct DelayedOpsCircuit where E: Pairing, @@ -121,22 +122,15 @@ mod tests { type Circuit = DelayedOpsCircuit; fn offloaded_setup( + circuit: Self::Circuit, snark_vk: S::ProcessedVerifyingKey, ) -> Result, OffloadedSNARKError> { Ok(OffloadedSNARKVerifyingKey { snark_pvk: snark_vk, delayed_pairings: vec![], // TODO none yet + g2_elements: vec![], }) } - - fn g2_elements( - vk: &OffloadedSNARKVerifyingKey, - public_input: &[E::ScalarField], - proof: &S::Proof, - ) -> Result>, SerializationError> { - // TODO get the G2 elements from the verifying key - Ok(vec![]) - } } #[test] diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 7520f111e..78eb80750 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -308,6 +308,7 @@ mod tests { use crate::poly::commitment::hyperkzg::{ HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey, }; + use crate::poly::commitment::kzg::KZGVerifierKey; use crate::poly::dense_mlpoly::DensePolynomial; use crate::snark::{ DeferredFnsRef, OffloadedDataCircuit, OffloadedSNARK, OffloadedSNARKError, @@ -328,6 +329,7 @@ mod tests { use ark_std::Zero; use rand_core::{CryptoRng, RngCore, SeedableRng}; + #[derive(Clone)] struct HyperKZGVerifierCircuit where E: Pairing, @@ -335,7 +337,7 @@ mod tests { { _params: PhantomData, deferred_fns_ref: DeferredFnsRef, - pcs_pk_vk: 
Option<(HyperKZGProverKey, HyperKZGVerifierKey)>, + pcs_pk_vk: (HyperKZGProverKey, HyperKZGVerifierKey), commitment: Option>, point: Vec>, eval: Option, @@ -378,9 +380,7 @@ mod tests { cs: ConstraintSystemRef, ) -> Result<(), SynthesisError> { let vk_var = HyperKZGVerifierKeyVar::::new_witness(ns!(cs, "vk"), || { - self.pcs_pk_vk - .clone() - .ok_or(SynthesisError::AssignmentMissing) + Ok(self.pcs_pk_vk.clone()) })?; let commitment_var = @@ -456,21 +456,17 @@ mod tests { type Circuit = HyperKZGVerifierCircuit; fn offloaded_setup( + circuit: Self::Circuit, snark_vk: S::ProcessedVerifyingKey, ) -> Result, OffloadedSNARKError> { + let KZGVerifierKey { g1, g2, beta_g2 } = circuit.pcs_pk_vk.1.kzg_vk; + Ok(OffloadedSNARKVerifyingKey { snark_pvk: snark_vk, delayed_pairings: vec![], + g2_elements: vec![vec![g2, beta_g2]], }) } - - fn g2_elements( - vk: &OffloadedSNARKVerifyingKey, - public_input: &[E::ScalarField], - proof: &S::Proof, - ) -> Result>, SerializationError> { - Ok(vec![]) - } } #[test] @@ -496,7 +492,7 @@ mod tests { let circuit = HyperKZGVerifierCircuit:: { _params: PhantomData, deferred_fns_ref: Default::default(), - pcs_pk_vk: None, + pcs_pk_vk: (pcs_pk.clone(), pcs_vk.clone()), commitment: None, point: vec![None; size], eval: None, @@ -524,7 +520,7 @@ mod tests { let verifier_circuit = HyperKZGVerifierCircuit:: { _params: PhantomData, deferred_fns_ref: Default::default(), - pcs_pk_vk: Some((pcs_pk.clone(), pcs_vk.clone())), + pcs_pk_vk: (pcs_pk.clone(), pcs_vk.clone()), commitment: Some(C.clone()), point: point.into_iter().map(|x| Some(x)).collect(), eval: Some(eval), diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index d80626c8b..2c07bf2ef 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -39,6 +39,7 @@ where { pub snark_pvk: S::ProcessedVerifyingKey, pub delayed_pairings: Vec, + pub g2_elements: Vec>, } #[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] @@ -65,7 +66,7 @@ pub type DeferredFn = pub type DeferredFnsRef = Rc>>>>; -pub trait OffloadedDataCircuit +pub trait OffloadedDataCircuit: Clone where G: CurveGroup, { @@ -163,22 +164,24 @@ where Self::circuit_specific_setup(circuit, rng) } - fn circuit_specific_setup, R: RngCore + CryptoRng>( - circuit: C, + fn circuit_specific_setup( + circuit: WrappedCircuit, rng: &mut R, ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> { + let offloaded_circuit = circuit.circuit.clone(); let (pk, snark_vk) = S::circuit_specific_setup(circuit, rng) .map_err(|e| OffloadedSNARKError::SNARKError(e))?; let snark_pvk = S::process_vk(&snark_vk).map_err(|e| OffloadedSNARKError::SNARKError(e))?; - let vk = Self::offloaded_setup(snark_pvk)?; + let vk = Self::offloaded_setup(offloaded_circuit, snark_pvk)?; Ok((pk, vk)) } fn offloaded_setup( + circuit: Self::Circuit, snark_vk: S::ProcessedVerifyingKey, ) -> Result, OffloadedSNARKError>; @@ -288,7 +291,7 @@ where .collect::>, SerializationError>>(); Ok(g1_vectors? .into_iter() - .zip(Self::g2_elements(vk, public_input, proof)?) 
+ .zip(Self::g2_elements(vk, public_input, proof)) .collect()) } @@ -296,7 +299,12 @@ where vk: &OffloadedSNARKVerifyingKey, public_input: &[E::ScalarField], proof: &S::Proof, - ) -> Result>, SerializationError>; + ) -> Vec> { + vk.g2_elements + .iter() + .map(|g2s| g2s.iter().map(|g2| g2.into_group()).collect::>()) + .collect::>>() + } } fn g1_affine_size_in_scalar_field_elements() -> usize { From 4ab444d56edb52cb2fac6d4387f42f65f41aff38 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Sun, 18 Aug 2024 15:37:09 -0700 Subject: [PATCH 36/44] WIP: HyperKZG verifier: make pairings work --- jolt-core/src/circuits/groups/curves/mod.rs | 13 +- jolt-core/src/circuits/offloaded/mod.rs | 138 ++++++++++++++++-- .../src/circuits/poly/commitment/hyperkzg.rs | 92 +++++++----- jolt-core/src/circuits/transcript/mock.rs | 28 ++-- jolt-core/src/circuits/transcript/mod.rs | 5 +- jolt-core/src/snark/mod.rs | 85 +++++++---- 6 files changed, 266 insertions(+), 95 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 41143cb89..98fdbc3fe 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -121,16 +121,9 @@ mod tests { { type Circuit = DelayedOpsCircuit; - fn offloaded_setup( - circuit: Self::Circuit, - snark_vk: S::ProcessedVerifyingKey, - ) -> Result, OffloadedSNARKError> { - Ok(OffloadedSNARKVerifyingKey { - snark_pvk: snark_vk, - delayed_pairings: vec![], // TODO none yet - g2_elements: vec![], - }) - } + // fn pairing_setup(circuit: Self::Circuit) -> Vec> { + // vec![] + // } } #[test] diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index f44dc28f3..91f70614b 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -3,6 +3,7 @@ use ark_ec::pairing::Pairing; use ark_ec::{CurveGroup, VariableBaseMSM}; use ark_ff::{One, PrimeField}; use ark_r1cs_std::alloc::AllocVar; +use ark_r1cs_std::boolean::Boolean; use ark_r1cs_std::eq::EqGadget; use ark_r1cs_std::fields::fp::FpVar; use ark_r1cs_std::fields::FieldVar; @@ -10,8 +11,24 @@ use ark_r1cs_std::groups::CurveVar; use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; use ark_relations::ns; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; +use ark_serialize::Valid; +use ark_std::Zero; use std::marker::PhantomData; +pub trait MSMGadget +where + FVar: FieldVar + ToConstraintFieldGadget, + G: CurveGroup, + GVar: CurveVar + ToConstraintFieldGadget, +{ + fn msm( + &self, + cs: impl Into>, + g1s: &[GVar], + scalars: &[FVar], + ) -> Result; +} + pub struct OffloadedMSMGadget<'a, FVar, G, GVar, Circuit> where Circuit: OffloadedDataCircuit, @@ -98,10 +115,13 @@ where scalar_input.enforce_equal(&x)?; } + let mut offsets = vec![]; + // write g1s to public_input for g1 in g1s { let f_vec = g1.to_constraint_field()?; + offsets.push(cs.num_instance_variables() - 1); for f in f_vec.iter() { let f_input = FpVar::new_input(ns!(cs, "g1s"), || f.value())?; f_input.enforce_equal(f)?; @@ -110,8 +130,10 @@ where // write msm_g1 to public_input { + dbg!(cs.num_instance_variables() - 1); let f_vec = msm_g1_var.to_constraint_field()?; + offsets.push(cs.num_instance_variables() - 1); for f in f_vec.iter() { let f_input = FpVar::new_input(ns!(cs, "msm_g1"), || f.value())?; f_input.enforce_equal(f)?; @@ -120,7 +142,7 @@ where dbg!(cs.num_constraints()); dbg!(cs.num_instance_variables()); - Ok(full_msm_value) + Ok((full_msm_value, offsets)) }) }; 
dbg!(cs.num_constraints()); @@ -129,16 +151,114 @@ where } } -pub trait MSMGadget +pub trait PairingGadget where - FVar: FieldVar + ToConstraintFieldGadget, - G: CurveGroup, - GVar: CurveVar + ToConstraintFieldGadget, + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, { - fn msm( + fn multi_pairing_is_zero( &self, - cs: impl Into>, + cs: impl Into>, + g1s: &[G1Var], + g2s: &[E::G2Affine], + ) -> Result<(), SynthesisError>; +} + +pub struct OffloadedPairingGadget<'a, E, FVar, GVar, Circuit> +where + E: Pairing, + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + GVar: CurveVar + ToConstraintFieldGadget, +{ + _params: PhantomData<(E, FVar, GVar)>, + circuit: &'a Circuit, +} + +impl<'a, E, FVar, GVar, Circuit> OffloadedPairingGadget<'a, E, FVar, GVar, Circuit> +where + E: Pairing, + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + GVar: CurveVar + ToConstraintFieldGadget, +{ + pub(crate) fn new(circuit: &'a Circuit) -> Self { + Self { + _params: PhantomData, + circuit, + } + } +} + +impl<'a, E, FVar, GVar, Circuit> PairingGadget + for OffloadedPairingGadget<'a, E, FVar, GVar, Circuit> +where + E: Pairing, + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + GVar: CurveVar + ToConstraintFieldGadget, +{ + fn multi_pairing_is_zero( + &self, + cs: impl Into>, g1s: &[GVar], - scalars: &[FVar], - ) -> Result; + g2s: &[E::G2Affine], + ) -> Result<(), SynthesisError> { + let ns = cs.into(); + let cs = ns.cs(); + + let g1_values_opt = g1s + .iter() + .map(|g1| g1.value().ok()) + .collect::>>(); + + let g2_values = g2s; + + let is_zero_opt = + g1_values_opt.map(|g1_values| E::multi_pairing(dbg!(&g1_values), g2_values).is_zero()); + if let Some(false) = is_zero_opt { + dbg!("multi_pairing_is_zero: false"); + return Err(SynthesisError::Unsatisfiable); + } + + // { + // let g1s = g1s.to_vec(); + // let ns = ns!(cs, "deferred_pairing"); + // let cs = ns.cs(); + // + // self.circuit.defer_msm(move || { + // let mut offsets = vec![]; + // + // // write g1s to public_input + // for g1 in g1s { + // let f_vec = g1.to_constraint_field()?; + // + // offsets.push(cs.num_instance_variables() - 1); + // for f in f_vec.iter() { + // let f_input = FpVar::new_input(ns!(cs, "g1s"), || f.value())?; + // f_input.enforce_equal(f)?; + // } + // } + // + // // write g2s to public_input + // for g2 in g2s { + // let f_vec = g2.to_constraint_field()?; + // + // offsets.push(cs.num_instance_variables() - 1); + // for f in f_vec.iter() { + // let f_input = FpVar::new_input(ns!(cs, "g2s"), || f.value())?; + // f_input.enforce_equal(f)?; + // } + // } + // + // dbg!(cs.num_constraints()); + // dbg!(cs.num_instance_variables()); + // + // Ok(()) + // }) + // } + + dbg!("multi_pairing_is_zero: success"); + Ok(()) + } } diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 78eb80750..eacff1978 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -1,4 +1,6 @@ -use crate::circuits::offloaded::{MSMGadget, OffloadedMSMGadget}; +use crate::circuits::offloaded::{ + MSMGadget, OffloadedMSMGadget, OffloadedPairingGadget, PairingGadget, +}; use crate::circuits::poly::commitment::commitment_scheme::CommitmentVerifierGadget; use crate::circuits::transcript::ImplAbsorb; use crate::field::JoltField; @@ -129,6 +131,7 @@ where { _params: PhantomData<(E, S, G1Var)>, circuit: &'a Circuit, + g2_elements: Vec, } impl<'a, E, S, 
G1Var, Circuit> HyperKZGVerifierGadget<'a, E, S, G1Var, Circuit> @@ -138,10 +141,11 @@ where G1Var: CurveVar + ToConstraintFieldGadget, Circuit: OffloadedDataCircuit, { - pub fn new(circuit: &'a Circuit) -> Self { + pub fn new(circuit: &'a Circuit, g2_elements: Vec) -> Self { Self { _params: PhantomData, circuit, + g2_elements, } } } @@ -187,6 +191,8 @@ where .next() .unwrap(); + dbg!(r.value()); + let u = vec![r.clone(), r.negate()?, r.clone() * &r]; let com = [vec![c.clone()], com.clone()].concat(); @@ -215,22 +221,25 @@ where // kzg_verify_batch - transcript.absorb(&v.iter().flatten().cloned().collect::>())?; - let q_powers = q_powers::(transcript, ell)?; - transcript.absorb( - &proof - .w - .iter() - .map(|g| ImplAbsorb::wrap(g)) + &v.iter() + .flatten() + .map(|v_ij| ImplAbsorb::wrap(v_ij)) .collect::>(), )?; + let q_powers = q_powers::(transcript, ell)?; + + dbg!(q_powers.value()); + + transcript.absorb(&w.iter().map(|g| ImplAbsorb::wrap(g)).collect::>())?; let d = transcript .squeeze_field_elements(1)? .into_iter() .next() .unwrap(); + dbg!(d.value()); + let d_square = d.square()?; let q_power_multiplier = one + &d + &d_square; let q_powers_multiplied = q_powers @@ -250,9 +259,11 @@ where .collect::>(); let msm_gadget = OffloadedMSMGadget::, E::G1, G1Var, Circuit>::new(self.circuit); + let pairing_gadget = + OffloadedPairingGadget::, G1Var, Circuit>::new(self.circuit); - let g1s = &[com.as_slice(), w.as_slice(), &[g1.clone()]].concat(); - let scalars = &[ + let l_g1s = &[com.as_slice(), w.as_slice(), &[g1.clone()]].concat(); + let l_scalars = &[ q_powers_multiplied.as_slice(), &[ u[0].clone(), @@ -262,17 +273,29 @@ where ], ] .concat(); + debug_assert_eq!(l_g1s.len(), l_scalars.len()); - let l_g1 = msm_gadget.msm(ns!(transcript.cs(), "l_g1"), g1s, scalars)?; + dbg!(transcript.cs().num_instance_variables() - 1); + let l_g1 = msm_gadget.msm(ns!(transcript.cs(), "l_g1"), l_g1s, l_scalars)?; - let g1s = w.as_slice(); - let scalars = &[FpVar::one(), d, d_square]; - debug_assert_eq!(g1s.len(), scalars.len()); + dbg!(w.as_slice().value()); - let r_g1 = msm_gadget.msm(ns!(transcript.cs(), "r_g1"), g1s, scalars)?; + let r_g1s = w.as_slice(); + let r_scalars = &[FpVar::one().negate()?, d.negate()?, d_square.negate()?]; + debug_assert_eq!(r_g1s.len(), r_scalars.len()); + + dbg!(transcript.cs().num_instance_variables() - 1); + let r_g1 = msm_gadget.msm(ns!(transcript.cs(), "r_g1"), r_g1s, r_scalars)?; + + // (dbg!(l_g1.value()), dbg!(r_g1.value())); + + pairing_gadget.multi_pairing_is_zero( + ns!(transcript.cs(), "multi_pairing"), + &[l_g1, r_g1], + self.g2_elements.as_slice(), + )?; dbg!(); - // TODO implement Ok(Boolean::TRUE) } } @@ -311,8 +334,8 @@ mod tests { use crate::poly::commitment::kzg::KZGVerifierKey; use crate::poly::dense_mlpoly::DensePolynomial; use crate::snark::{ - DeferredFnsRef, OffloadedDataCircuit, OffloadedSNARK, OffloadedSNARKError, - OffloadedSNARKVerifyingKey, + DeferredFnsRef, OffloadedDataCircuit, OffloadedPairingDef, OffloadedSNARK, + OffloadedSNARKError, OffloadedSNARKVerifyingKey, }; use crate::utils::errors::ProofVerifyError; use crate::utils::transcript::ProofTranscript; @@ -409,14 +432,14 @@ mod tests { let mut transcript_var = MockSpongeVar::new(ns!(cs, "transcript").cs(), &(b"TestEval".as_slice())); - let h_kzg = HyperKZGVerifierGadget::< - E, - MockSponge, - G1Var, - HyperKZGVerifierCircuit, - >::new(&self); + let kzg_vk = self.pcs_pk_vk.1.kzg_vk; + let hyper_kzg = + HyperKZGVerifierGadget::, G1Var, Self>::new( + &self, + vec![kzg_vk.g2, kzg_vk.beta_g2], + ); 
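// The `verify` call below re-runs the HyperKZG folding checks in-circuit and
// builds the two group elements `l_g1` and `r_g1` through the offloaded MSM
// gadget; the final acceptance condition is that the multi-pairing of
// (l_g1, r_g1) with (g2, beta_g2) is the identity in GT, i.e.
// e(l_g1, g2) * e(r_g1, beta_g2) == 1, which is why the gadget is handed
// `vec![kzg_vk.g2, kzg_vk.beta_g2]` here. (Sketch only; the exact scalar
// bookkeeping, including the negated `r_scalars`, is in `verify` above.)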
- let r = h_kzg.verify( + let r = hyper_kzg.verify( &proof_var, &vk_var, &mut transcript_var, @@ -455,18 +478,11 @@ mod tests { { type Circuit = HyperKZGVerifierCircuit; - fn offloaded_setup( - circuit: Self::Circuit, - snark_vk: S::ProcessedVerifyingKey, - ) -> Result, OffloadedSNARKError> { - let KZGVerifierKey { g1, g2, beta_g2 } = circuit.pcs_pk_vk.1.kzg_vk; - - Ok(OffloadedSNARKVerifyingKey { - snark_pvk: snark_vk, - delayed_pairings: vec![], - g2_elements: vec![vec![g2, beta_g2]], - }) - } + // fn pairing_setup(circuit: Self::Circuit) -> Vec> { + // let KZGVerifierKey { g1, g2, beta_g2 } = circuit.pcs_pk_vk.1.kzg_vk; + // + // vec![vec![g2, beta_g2]] + // } } #[test] diff --git a/jolt-core/src/circuits/transcript/mock.rs b/jolt-core/src/circuits/transcript/mock.rs index d78708f06..6e4eda286 100644 --- a/jolt-core/src/circuits/transcript/mock.rs +++ b/jolt-core/src/circuits/transcript/mock.rs @@ -1,4 +1,4 @@ -use crate::circuits::transcript::IS_SLICE; +use crate::circuits::transcript::SLICE; use crate::field::JoltField; use crate::utils::transcript::ProofTranscript; use ark_crypto_primitives::sponge::constraints::{ @@ -83,21 +83,30 @@ where fn absorb(&mut self, input: &impl AbsorbGadget) -> Result<(), SynthesisError> { let bytes = input.to_sponge_bytes()?; - let is_slice = IS_SLICE.take(); - let fs = bytes + let bs = bytes .iter() .map(|f| match self.cs.is_in_setup_mode() { true => Ok(0u8), false => f.value(), }) .collect::, _>>()?; - if is_slice { - self.transcript.append_message(b"begin_append_vector"); - } - self.transcript.append_bytes(&fs); - if is_slice { - self.transcript.append_message(b"end_append_vector"); + + let slice_opt = SLICE.take(); + match slice_opt { + Some(slice_len) => { + self.transcript.append_message(b"begin_append_vector"); + if slice_len != 0 { + for chunk in bs.chunks(bs.len() / slice_len) { + self.transcript.append_bytes(chunk); + } + } + self.transcript.append_message(b"end_append_vector"); + } + None => { + self.transcript.append_bytes(&bs); + } } + Ok(()) } @@ -119,7 +128,6 @@ where &mut self, num_elements: usize, ) -> Result>, SynthesisError> { - dbg!(&self.transcript.n_rounds); self.transcript .challenge_vector::(num_elements) .iter() diff --git a/jolt-core/src/circuits/transcript/mod.rs b/jolt-core/src/circuits/transcript/mod.rs index 146414770..12f66d886 100644 --- a/jolt-core/src/circuits/transcript/mod.rs +++ b/jolt-core/src/circuits/transcript/mod.rs @@ -7,6 +7,7 @@ use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; use ark_relations::ns; use ark_relations::r1cs::SynthesisError; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_std::iterable::Iterable; use ark_std::Zero; use std::cell::RefCell; use std::fmt::Debug; @@ -30,7 +31,7 @@ where } thread_local! { - static IS_SLICE: RefCell = RefCell::new(false); + static SLICE: RefCell> = RefCell::new(None); } impl<'a, T, F> AbsorbGadget for ImplAbsorb<'a, T, F> @@ -59,7 +60,7 @@ where where Self: Sized, { - IS_SLICE.set(true); + SLICE.set(Some(batch.len())); let mut result = Vec::new(); for item in batch { result.append(&mut (item.to_sponge_bytes()?)) diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index 2c07bf2ef..e9ea1be0b 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -23,12 +23,16 @@ use rand_core::{CryptoRng, RngCore}; /// The verifier needs to use appropriate G2 elements from the verification key or the proof /// (depending on the protocol). 
#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] -pub struct OffloadedPairingDef { +pub struct OffloadedPairingDef +where + E: Pairing, +{ /// Offsets of the G1 elements in the public input. The G1 elements are stored as sequences of scalar field elements /// encoding the compressed coordinates of the G1 points (which would natively be numbers in the base field). - /// The offsets are in the number of scalar field elements in the public input before the G1 elements block. + /// The offsets are in the number of scalar field elements in the public input before the G1 element. /// The last element, by convention, is always used in the multi-pairing computation with coefficient `-1`. pub g1_offsets: Vec, + pub g2_elements: Vec, } #[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] @@ -38,8 +42,7 @@ where S: SNARK, { pub snark_pvk: S::ProcessedVerifyingKey, - pub delayed_pairings: Vec, - pub g2_elements: Vec>, + pub delayed_pairings: Vec>, } #[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] @@ -61,8 +64,15 @@ pub struct OffloadedData { pub msms: Vec<(Vec, Vec)>, } -pub type DeferredFn = - dyn FnOnce() -> Result, Vec)>, SynthesisError>; +pub enum DeferredOp { + MSM(Option<(Vec, Vec)>), + Pairing(OffloadedPairingDef), +} + +pub type DeferredFn = dyn FnOnce() -> Result< + (Option<(Vec, Vec)>, Vec), + SynthesisError, +>; pub type DeferredFnsRef = Rc>>>>; @@ -74,8 +84,10 @@ where fn defer_msm( &self, - f: impl FnOnce() -> Result, Vec)>, SynthesisError> - + 'static, + f: impl FnOnce() -> Result< + (Option<(Vec, Vec)>, Vec), + SynthesisError, + > + 'static, ) { self.deferred_fns_ref().borrow_mut().push(Box::new(f)); } @@ -103,24 +115,21 @@ where { circuit: C, offloaded_data_ref: Rc>>, + g1_offsets_ref: Rc>>>, } fn run_deferred( - deferred_fns: Vec< - Box< - dyn FnOnce() -> Result, Vec)>, SynthesisError>, - >, - >, -) -> Result>, SynthesisError> { - let msms = deferred_fns + deferred_fns: Vec>>, +) -> Result<(Option>, Vec>), SynthesisError> { + let (msms, g1_offsets) = deferred_fns .into_iter() .map(|f| f()) - .collect::, _>>()?; + .collect::, Vec<_>), _>>()?; // can't collect into `Option>` above: it short-circuits on the first None let msms = msms.into_iter().collect::>>(); - Ok(msms.map(|msms| OffloadedData { msms })) + Ok((msms.map(|msms| OffloadedData { msms }), g1_offsets)) } impl ConstraintSynthesizer for WrappedCircuit @@ -132,12 +141,16 @@ where let deferred_fns_ref = self.circuit.deferred_fns_ref().clone(); let offloaded_data_ref = self.offloaded_data_ref.clone(); + let g1_offsets_ref = self.g1_offsets_ref.clone(); self.circuit.generate_constraints(cs)?; - if let Some(offloaded_data) = run_deferred::(deferred_fns_ref.take())? 
{ + let (offloaded_data, g1_offsets) = run_deferred::(deferred_fns_ref.take())?; + + if let Some(offloaded_data) = offloaded_data { offloaded_data_ref.set(offloaded_data).unwrap(); - }; + } + g1_offsets_ref.set(g1_offsets).unwrap(); Ok(()) } @@ -160,6 +173,7 @@ where let circuit: WrappedCircuit = WrappedCircuit { circuit, offloaded_data_ref: Default::default(), + g1_offsets_ref: Default::default(), }; Self::circuit_specific_setup(circuit, rng) } @@ -169,21 +183,33 @@ where rng: &mut R, ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> { + let g1_offsets_ref = circuit.g1_offsets_ref.clone(); + let offloaded_circuit = circuit.circuit.clone(); + let (pk, snark_vk) = S::circuit_specific_setup(circuit, rng) .map_err(|e| OffloadedSNARKError::SNARKError(e))?; let snark_pvk = S::process_vk(&snark_vk).map_err(|e| OffloadedSNARKError::SNARKError(e))?; - let vk = Self::offloaded_setup(offloaded_circuit, snark_pvk)?; + // let g2_vecs = Self::pairing_setup(offloaded_circuit); + // let delayed_pairings = g2_vecs + // .into_iter() + // .zip(g1_offsets_ref.get().unwrap().clone().into_iter()) + // .map(|(g2_elements, g1_offsets)| OffloadedPairingDef { + // g1_offsets, + // g2_elements, + // }) + // .collect(); + let vk = OffloadedSNARKVerifyingKey { + snark_pvk, + delayed_pairings: vec![], + }; Ok((pk, vk)) } - fn offloaded_setup( - circuit: Self::Circuit, - snark_vk: S::ProcessedVerifyingKey, - ) -> Result, OffloadedSNARKError>; + // fn pairing_setup(circuit: Self::Circuit) -> Vec>; fn prove( circuit_pk: &S::ProvingKey, @@ -193,6 +219,7 @@ where let circuit: WrappedCircuit = WrappedCircuit { circuit, offloaded_data_ref: Default::default(), + g1_offsets_ref: Default::default(), }; let offloaded_data_ref = circuit.offloaded_data_ref.clone(); @@ -300,9 +327,15 @@ where public_input: &[E::ScalarField], proof: &S::Proof, ) -> Vec> { - vk.g2_elements + vk.delayed_pairings .iter() - .map(|g2s| g2s.iter().map(|g2| g2.into_group()).collect::>()) + .map(|pairing_def| { + pairing_def + .g2_elements + .iter() + .map(|g2| g2.into_group()) + .collect::>() + }) .collect::>>() } } From b7e983ab8f4ace928cef171ed58e70c747b50ab1 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Mon, 19 Aug 2024 14:33:59 -0700 Subject: [PATCH 37/44] WIP: HyperKZG verifier: offload pairings data --- jolt-core/src/circuits/offloaded/mod.rs | 14 ++----- .../src/circuits/poly/commitment/hyperkzg.rs | 13 ------ jolt-core/src/snark/mod.rs | 42 +++++-------------- 3 files changed, 15 insertions(+), 54 deletions(-) diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index 91f70614b..7a6a61b0e 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -115,13 +115,10 @@ where scalar_input.enforce_equal(&x)?; } - let mut offsets = vec![]; - // write g1s to public_input for g1 in g1s { let f_vec = g1.to_constraint_field()?; - offsets.push(cs.num_instance_variables() - 1); for f in f_vec.iter() { let f_input = FpVar::new_input(ns!(cs, "g1s"), || f.value())?; f_input.enforce_equal(f)?; @@ -133,7 +130,6 @@ where dbg!(cs.num_instance_variables() - 1); let f_vec = msm_g1_var.to_constraint_field()?; - offsets.push(cs.num_instance_variables() - 1); for f in f_vec.iter() { let f_input = FpVar::new_input(ns!(cs, "msm_g1"), || f.value())?; f_input.enforce_equal(f)?; @@ -142,7 +138,7 @@ where dbg!(cs.num_constraints()); dbg!(cs.num_instance_variables()); - Ok((full_msm_value, offsets)) + Ok(full_msm_value) }) }; dbg!(cs.num_constraints()); @@ -214,10 
+210,9 @@ where let g2_values = g2s; - let is_zero_opt = - g1_values_opt.map(|g1_values| E::multi_pairing(dbg!(&g1_values), g2_values).is_zero()); - if let Some(false) = is_zero_opt { - dbg!("multi_pairing_is_zero: false"); + if let Some(false) = + g1_values_opt.map(|g1_values| E::multi_pairing(&g1_values, g2_values).is_zero()) + { return Err(SynthesisError::Unsatisfiable); } @@ -258,7 +253,6 @@ where // }) // } - dbg!("multi_pairing_is_zero: success"); Ok(()) } } diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index eacff1978..210913bac 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -191,8 +191,6 @@ where .next() .unwrap(); - dbg!(r.value()); - let u = vec![r.clone(), r.negate()?, r.clone() * &r]; let com = [vec![c.clone()], com.clone()].concat(); @@ -229,8 +227,6 @@ where )?; let q_powers = q_powers::(transcript, ell)?; - dbg!(q_powers.value()); - transcript.absorb(&w.iter().map(|g| ImplAbsorb::wrap(g)).collect::>())?; let d = transcript .squeeze_field_elements(1)? @@ -238,8 +234,6 @@ where .next() .unwrap(); - dbg!(d.value()); - let d_square = d.square()?; let q_power_multiplier = one + &d + &d_square; let q_powers_multiplied = q_powers @@ -275,26 +269,19 @@ where .concat(); debug_assert_eq!(l_g1s.len(), l_scalars.len()); - dbg!(transcript.cs().num_instance_variables() - 1); let l_g1 = msm_gadget.msm(ns!(transcript.cs(), "l_g1"), l_g1s, l_scalars)?; - dbg!(w.as_slice().value()); - let r_g1s = w.as_slice(); let r_scalars = &[FpVar::one().negate()?, d.negate()?, d_square.negate()?]; debug_assert_eq!(r_g1s.len(), r_scalars.len()); - dbg!(transcript.cs().num_instance_variables() - 1); let r_g1 = msm_gadget.msm(ns!(transcript.cs(), "r_g1"), r_g1s, r_scalars)?; - // (dbg!(l_g1.value()), dbg!(r_g1.value())); - pairing_gadget.multi_pairing_is_zero( ns!(transcript.cs(), "multi_pairing"), &[l_g1, r_g1], self.g2_elements.as_slice(), )?; - dbg!(); Ok(Boolean::TRUE) } diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index e9ea1be0b..9a06a8c4e 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -69,10 +69,8 @@ pub enum DeferredOp { Pairing(OffloadedPairingDef), } -pub type DeferredFn = dyn FnOnce() -> Result< - (Option<(Vec, Vec)>, Vec), - SynthesisError, ->; +pub type DeferredFn = + dyn FnOnce() -> Result, Vec)>, SynthesisError>; pub type DeferredFnsRef = Rc>>>>; @@ -84,10 +82,8 @@ where fn defer_msm( &self, - f: impl FnOnce() -> Result< - (Option<(Vec, Vec)>, Vec), - SynthesisError, - > + 'static, + f: impl FnOnce() -> Result, Vec)>, SynthesisError> + + 'static, ) { self.deferred_fns_ref().borrow_mut().push(Box::new(f)); } @@ -115,21 +111,20 @@ where { circuit: C, offloaded_data_ref: Rc>>, - g1_offsets_ref: Rc>>>, } fn run_deferred( deferred_fns: Vec>>, -) -> Result<(Option>, Vec>), SynthesisError> { - let (msms, g1_offsets) = deferred_fns +) -> Result>, SynthesisError> { + let msms = deferred_fns .into_iter() .map(|f| f()) - .collect::, Vec<_>), _>>()?; + .collect::, _>>()?; // can't collect into `Option>` above: it short-circuits on the first None let msms = msms.into_iter().collect::>>(); - Ok((msms.map(|msms| OffloadedData { msms }), g1_offsets)) + Ok(msms.map(|msms| OffloadedData { msms })) } impl ConstraintSynthesizer for WrappedCircuit @@ -141,16 +136,14 @@ where let deferred_fns_ref = self.circuit.deferred_fns_ref().clone(); let offloaded_data_ref = self.offloaded_data_ref.clone(); - let g1_offsets_ref = 
self.g1_offsets_ref.clone(); self.circuit.generate_constraints(cs)?; - let (offloaded_data, g1_offsets) = run_deferred::(deferred_fns_ref.take())?; + let offloaded_data = run_deferred::(deferred_fns_ref.take())?; if let Some(offloaded_data) = offloaded_data { offloaded_data_ref.set(offloaded_data).unwrap(); } - g1_offsets_ref.set(g1_offsets).unwrap(); Ok(()) } @@ -173,7 +166,6 @@ where let circuit: WrappedCircuit = WrappedCircuit { circuit, offloaded_data_ref: Default::default(), - g1_offsets_ref: Default::default(), }; Self::circuit_specific_setup(circuit, rng) } @@ -183,8 +175,6 @@ where rng: &mut R, ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> { - let g1_offsets_ref = circuit.g1_offsets_ref.clone(); - let offloaded_circuit = circuit.circuit.clone(); let (pk, snark_vk) = S::circuit_specific_setup(circuit, rng) @@ -209,8 +199,6 @@ where Ok((pk, vk)) } - // fn pairing_setup(circuit: Self::Circuit) -> Vec>; - fn prove( circuit_pk: &S::ProvingKey, circuit: Self::Circuit, @@ -219,7 +207,6 @@ where let circuit: WrappedCircuit = WrappedCircuit { circuit, offloaded_data_ref: Default::default(), - g1_offsets_ref: Default::default(), }; let offloaded_data_ref = circuit.offloaded_data_ref.clone(); @@ -316,17 +303,10 @@ where g1s }) .collect::>, SerializationError>>(); - Ok(g1_vectors? - .into_iter() - .zip(Self::g2_elements(vk, public_input, proof)) - .collect()) + Ok(g1_vectors?.into_iter().zip(Self::g2_elements(vk)).collect()) } - fn g2_elements( - vk: &OffloadedSNARKVerifyingKey, - public_input: &[E::ScalarField], - proof: &S::Proof, - ) -> Vec> { + fn g2_elements(vk: &OffloadedSNARKVerifyingKey) -> Vec> { vk.delayed_pairings .iter() .map(|pairing_def| { From 70a65c354e9a263a4b057890f3ebf340c2d8409d Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 20 Aug 2024 16:39:23 -0700 Subject: [PATCH 38/44] HyperKZG verifier works with offloaded data --- jolt-core/src/circuits/groups/curves/mod.rs | 171 -------------- jolt-core/src/circuits/offloaded/mod.rs | 122 +++++----- .../src/circuits/poly/commitment/hyperkzg.rs | 14 +- jolt-core/src/snark/mod.rs | 208 +++++++++++------- 4 files changed, 188 insertions(+), 327 deletions(-) diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs index 98fdbc3fe..dbd7e00e0 100644 --- a/jolt-core/src/circuits/groups/curves/mod.rs +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -1,172 +1 @@ pub mod short_weierstrass; - -#[cfg(test)] -mod tests { - use super::*; - use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; - use crate::circuits::groups::curves::short_weierstrass::{AffineVar, ProjectiveVar}; - use crate::circuits::offloaded::{MSMGadget, OffloadedMSMGadget}; - use crate::snark::{ - DeferredFnsRef, OffloadedData, OffloadedDataCircuit, OffloadedSNARK, OffloadedSNARKError, - OffloadedSNARKVerifyingKey, - }; - use ark_bn254::{Bn254, Fq, Fr}; - use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; - use ark_crypto_primitives::sponge::Absorb; - use ark_ec::bn::G1Projective; - use ark_ec::pairing::Pairing; - use ark_ec::short_weierstrass::{Affine, Projective, SWCurveConfig}; - use ark_ec::{CurveGroup, Group, VariableBaseMSM}; - use ark_ff::{PrimeField, ToConstraintField}; - use ark_groth16::Groth16; - use ark_r1cs_std::fields::fp::FpVar; - use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; - use ark_r1cs_std::prelude::*; - use ark_r1cs_std::ToConstraintFieldGadget; - use ark_relations::ns; - use ark_relations::r1cs::{ConstraintSynthesizer, 
ConstraintSystemRef, SynthesisError}; - use ark_serialize::{CanonicalSerialize, SerializationError}; - use ark_std::cell::OnceCell; - use ark_std::cell::{Cell, RefCell}; - use ark_std::marker::PhantomData; - use ark_std::ops::Deref; - use ark_std::rand::Rng; - use ark_std::rc::Rc; - use ark_std::sync::RwLock; - use ark_std::{end_timer, start_timer, test_rng, One, UniformRand}; - use itertools::Itertools; - use rand_core::{CryptoRng, RngCore, SeedableRng}; - - #[derive(Clone)] - struct DelayedOpsCircuit - where - E: Pairing, - G1Var: CurveVar, - { - _params: PhantomData, - - // witness values - w_g1: [Option; 3], - d: Option, - - // deferred fns to write offloaded data to public_input - deferred_fns_ref: DeferredFnsRef, - } - - impl ConstraintSynthesizer for DelayedOpsCircuit - where - E: Pairing, - G1Var: CurveVar + ToConstraintFieldGadget, - { - fn generate_constraints( - self, - cs: ConstraintSystemRef, - ) -> Result<(), SynthesisError> { - dbg!(cs.num_constraints()); - - let d = FpVar::new_witness(ns!(cs, "d"), || { - self.d.ok_or(SynthesisError::AssignmentMissing) - })?; - dbg!(cs.num_constraints()); - - let w_g1 = (0..3) - .map(|i| { - G1Var::new_witness(ns!(cs, "w_g1"), || { - self.w_g1[i].ok_or(SynthesisError::AssignmentMissing) - }) - }) - .collect::, _>>()?; - dbg!(cs.num_constraints()); - - let d_square = d.square()?; - let d_k = [FpVar::one(), d, d_square]; - dbg!(cs.num_constraints()); - - let _ = OffloadedMSMGadget::new(&self).msm( - ns!(cs, "msm"), - w_g1.as_slice(), - d_k.as_slice(), - )?; - dbg!(cs.num_constraints()); - - Ok(()) - } - } - - impl OffloadedDataCircuit for DelayedOpsCircuit - where - E: Pairing, - G1Var: CurveVar + ToConstraintFieldGadget, - { - fn deferred_fns_ref(&self) -> &DeferredFnsRef { - &self.deferred_fns_ref - } - } - - struct DelayedOpsCircuitSNARK - where - E: Pairing, - S: SNARK, - G1Var: CurveVar, - { - _params: PhantomData<(E, S, G1Var)>, - } - - impl OffloadedSNARK for DelayedOpsCircuitSNARK - where - E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, - P: SWCurveConfig, - S: SNARK, - G1Var: CurveVar + ToConstraintFieldGadget, - { - type Circuit = DelayedOpsCircuit; - - // fn pairing_setup(circuit: Self::Circuit) -> Vec> { - // vec![] - // } - } - - #[test] - fn test_delayed_pairing_circuit() { - type DemoCircuit = DelayedOpsCircuit; - - type DemoSNARK = DelayedOpsCircuitSNARK, G1Var>; - - let circuit = DemoCircuit { - _params: PhantomData, - w_g1: [None; 3], - d: None, - deferred_fns_ref: Default::default(), - }; - - // This is not cryptographically safe, use - // `OsRng` (for example) in production software. 
- let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(test_rng().next_u64()); - - let setup_timer = start_timer!(|| "Groth16::setup"); - let (pk, vk) = DemoSNARK::setup(circuit, &mut rng).unwrap(); - end_timer!(setup_timer); - - let process_vk_timer = start_timer!(|| "Groth16::process_vk"); - // let pvk = DemoSNARK::process_vk(&vk).unwrap(); - let pvk = vk; - end_timer!(process_vk_timer); - - let c_init_values = DemoCircuit { - _params: PhantomData, - w_g1: [Some(rng.gen()); 3], - d: Some(rng.gen()), - deferred_fns_ref: Default::default(), - }; - - let prove_timer = start_timer!(|| "Groth16::prove"); - let proof = DemoSNARK::prove(&pk, c_init_values, &mut rng).unwrap(); - end_timer!(prove_timer); - - let verify_timer = start_timer!(|| "Groth16::verify"); - let verify_result = DemoSNARK::verify_with_processed_vk(&pvk, &[], &proof); - end_timer!(verify_timer); - - assert!(verify_result.unwrap()); - } -} diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index 7a6a61b0e..4499a4240 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -1,4 +1,4 @@ -use crate::snark::OffloadedDataCircuit; +use crate::snark::{DeferredOpData, OffloadedData, OffloadedDataCircuit}; use ark_ec::pairing::Pairing; use ark_ec::{CurveGroup, VariableBaseMSM}; use ark_ff::{One, PrimeField}; @@ -29,23 +29,23 @@ where ) -> Result; } -pub struct OffloadedMSMGadget<'a, FVar, G, GVar, Circuit> +pub struct OffloadedMSMGadget<'a, FVar, E, GVar, Circuit> where - Circuit: OffloadedDataCircuit, - FVar: FieldVar + ToConstraintFieldGadget, - G: CurveGroup, - GVar: CurveVar + ToConstraintFieldGadget, + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + E: Pairing, + GVar: CurveVar + ToConstraintFieldGadget, { - _params: PhantomData<(FVar, G, GVar)>, + _params: PhantomData<(FVar, E, GVar)>, circuit: &'a Circuit, } -impl<'a, FVar, G, GVar, Circuit> OffloadedMSMGadget<'a, FVar, G, GVar, Circuit> +impl<'a, FVar, E, GVar, Circuit> OffloadedMSMGadget<'a, FVar, E, GVar, Circuit> where - Circuit: OffloadedDataCircuit, - FVar: FieldVar + ToConstraintFieldGadget, - G: CurveGroup, - GVar: CurveVar + ToConstraintFieldGadget, + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + E: Pairing, + GVar: CurveVar + ToConstraintFieldGadget, { pub fn new(circuit: &'a Circuit) -> Self { Self { @@ -55,17 +55,17 @@ where } } -impl<'a, FVar, G, GVar, Circuit> MSMGadget - for OffloadedMSMGadget<'a, FVar, G, GVar, Circuit> +impl<'a, FVar, E, GVar, Circuit> MSMGadget + for OffloadedMSMGadget<'a, FVar, E, GVar, Circuit> where - Circuit: OffloadedDataCircuit, - FVar: FieldVar + ToConstraintFieldGadget, - G: CurveGroup, - GVar: CurveVar + ToConstraintFieldGadget, + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + E: Pairing, + GVar: CurveVar + ToConstraintFieldGadget, { fn msm( &self, - cs: impl Into>, + cs: impl Into>, g1s: &[GVar], scalars: &[FVar], ) -> Result { @@ -85,8 +85,8 @@ where let (full_msm_value, msm_g1_value) = g1_values .zip(scalar_values) .map(|(g1s, scalars)| { - let r_g1 = G::msm_unchecked(&g1s, &scalars); - let minus_one = -G::ScalarField::one(); + let r_g1 = E::G1::msm_unchecked(&g1s, &scalars); + let minus_one = -E::ScalarField::one(); ( ( [g1s, vec![r_g1.into()]].concat(), @@ -108,7 +108,7 @@ where let ns = ns!(cs, "deferred_msm"); let cs = ns.cs(); - self.circuit.defer_msm(move || { + self.circuit.defer_op(move || { // write scalars to public_input for x in scalars { 
let scalar_input = FVar::new_input(ns!(cs, "scalar"), || x.value())?; @@ -138,7 +138,7 @@ where dbg!(cs.num_constraints()); dbg!(cs.num_instance_variables()); - Ok(full_msm_value) + Ok(DeferredOpData::MSM(full_msm_value)) }) }; dbg!(cs.num_constraints()); @@ -163,7 +163,7 @@ where pub struct OffloadedPairingGadget<'a, E, FVar, GVar, Circuit> where E: Pairing, - Circuit: OffloadedDataCircuit, + Circuit: OffloadedDataCircuit, FVar: FieldVar + ToConstraintFieldGadget, GVar: CurveVar + ToConstraintFieldGadget, { @@ -174,7 +174,7 @@ where impl<'a, E, FVar, GVar, Circuit> OffloadedPairingGadget<'a, E, FVar, GVar, Circuit> where E: Pairing, - Circuit: OffloadedDataCircuit, + Circuit: OffloadedDataCircuit, FVar: FieldVar + ToConstraintFieldGadget, GVar: CurveVar + ToConstraintFieldGadget, { @@ -190,7 +190,7 @@ impl<'a, E, FVar, GVar, Circuit> PairingGadget for OffloadedPairingGadget<'a, E, FVar, GVar, Circuit> where E: Pairing, - Circuit: OffloadedDataCircuit, + Circuit: OffloadedDataCircuit, FVar: FieldVar + ToConstraintFieldGadget, GVar: CurveVar + ToConstraintFieldGadget, { @@ -205,53 +205,41 @@ where let g1_values_opt = g1s .iter() - .map(|g1| g1.value().ok()) + .map(|g1| g1.value().ok().map(|g1| g1.into_affine())) .collect::>>(); let g2_values = g2s; - if let Some(false) = - g1_values_opt.map(|g1_values| E::multi_pairing(&g1_values, g2_values).is_zero()) - { - return Err(SynthesisError::Unsatisfiable); + for g1_values in g1_values_opt.iter() { + if !E::multi_pairing(g1_values, g2_values).is_zero() { + return Err(SynthesisError::Unsatisfiable); + } } - // { - // let g1s = g1s.to_vec(); - // let ns = ns!(cs, "deferred_pairing"); - // let cs = ns.cs(); - // - // self.circuit.defer_msm(move || { - // let mut offsets = vec![]; - // - // // write g1s to public_input - // for g1 in g1s { - // let f_vec = g1.to_constraint_field()?; - // - // offsets.push(cs.num_instance_variables() - 1); - // for f in f_vec.iter() { - // let f_input = FpVar::new_input(ns!(cs, "g1s"), || f.value())?; - // f_input.enforce_equal(f)?; - // } - // } - // - // // write g2s to public_input - // for g2 in g2s { - // let f_vec = g2.to_constraint_field()?; - // - // offsets.push(cs.num_instance_variables() - 1); - // for f in f_vec.iter() { - // let f_input = FpVar::new_input(ns!(cs, "g2s"), || f.value())?; - // f_input.enforce_equal(f)?; - // } - // } - // - // dbg!(cs.num_constraints()); - // dbg!(cs.num_instance_variables()); - // - // Ok(()) - // }) - // } + { + let g1_values_opt = g1_values_opt; + let g2_values = g2_values.to_vec(); + let g1s = g1s.to_vec(); + let ns = ns!(cs, "deferred_pairing"); + let cs = ns.cs(); + + self.circuit.defer_op(move || { + // write g1s to public_input + for g1 in g1s { + let f_vec = g1.to_constraint_field()?; + + for f in f_vec.iter() { + let f_input = FpVar::new_input(ns!(cs, "g1s"), || f.value())?; + f_input.enforce_equal(f)?; + } + } + + dbg!(cs.num_constraints()); + dbg!(cs.num_instance_variables()); + + Ok(DeferredOpData::Pairing(g1_values_opt, g2_values)) + }) + } Ok(()) } diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 210913bac..67b84c88c 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -127,7 +127,7 @@ where E: Pairing, S: SpongeWithGadget, G1Var: CurveVar + ToConstraintFieldGadget, - Circuit: OffloadedDataCircuit, + Circuit: OffloadedDataCircuit, { _params: PhantomData<(E, S, G1Var)>, circuit: &'a Circuit, @@ -139,7 +139,7 @@ where 
E: Pairing, S: SpongeWithGadget, G1Var: CurveVar + ToConstraintFieldGadget, - Circuit: OffloadedDataCircuit, + Circuit: OffloadedDataCircuit, { pub fn new(circuit: &'a Circuit, g2_elements: Vec) -> Self { Self { @@ -157,7 +157,7 @@ where E: Pairing, S: SpongeWithGadget, G1Var: CurveVar + ToConstraintFieldGadget, - Circuit: OffloadedDataCircuit, + Circuit: OffloadedDataCircuit, { type VerifyingKeyVar = HyperKZGVerifierKeyVar; type ProofVar = HyperKZGProofVar; @@ -252,7 +252,7 @@ where }) .collect::>(); - let msm_gadget = OffloadedMSMGadget::, E::G1, G1Var, Circuit>::new(self.circuit); + let msm_gadget = OffloadedMSMGadget::, E, G1Var, Circuit>::new(self.circuit); let pairing_gadget = OffloadedPairingGadget::, G1Var, Circuit>::new(self.circuit); @@ -346,7 +346,7 @@ mod tests { G1Var: CurveVar + ToConstraintFieldGadget, { _params: PhantomData, - deferred_fns_ref: DeferredFnsRef, + deferred_fns_ref: DeferredFnsRef, pcs_pk_vk: (HyperKZGProverKey, HyperKZGVerifierKey), commitment: Option>, point: Vec>, @@ -355,12 +355,12 @@ mod tests { expected_result: Option, } - impl OffloadedDataCircuit for HyperKZGVerifierCircuit + impl OffloadedDataCircuit for HyperKZGVerifierCircuit where E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, { - fn deferred_fns_ref(&self) -> &DeferredFnsRef { + fn deferred_fns_ref(&self) -> &DeferredFnsRef { &self.deferred_fns_ref } } diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index 9a06a8c4e..931e2880b 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -17,6 +17,7 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError use ark_std::{cell::OnceCell, cell::RefCell, marker::PhantomData, rc::Rc}; use itertools::Itertools; use rand_core::{CryptoRng, RngCore}; +use std::any::Any; /// Describes G1 elements to be used in a multi-pairing. /// The verifier is responsible for ensuring that the sum of the pairings is zero. @@ -52,39 +53,46 @@ where S: SNARK, { pub snark_proof: S::Proof, - pub offloaded_data: OffloadedData, + pub offloaded_data: ProofData, } #[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] -pub struct OffloadedData { +pub struct ProofData { /// Blocks of G1 elements `Gᵢ` and scalars in `sᵢ` the public input, such that `∑ᵢ sᵢ·Gᵢ == 0`. /// It's the verifiers responsibility to ensure that the sum is zero. /// The scalar at index `length-1` is, by convention, always `-1`, so /// we save one public input element per MSM. - pub msms: Vec<(Vec, Vec)>, + msms: Vec<(Vec, Vec)>, + /// Blocks of G1 elements `Gᵢ` in the public input, used in multi-pairings with + /// the corresponding G2 elements in the offloaded SNARK verification key. + /// It's the verifiers responsibility to ensure that the sum is zero. + /// The scalar at index `length-1` is, by convention, always `-1`, so + /// we save one public input element per MSM. 
+ pairing_g1s: Vec>, +} + +#[derive(Clone, Debug)] +pub struct OffloadedData { + proof_data: Option>, + setup_data: Vec>, } -pub enum DeferredOp { +pub enum DeferredOpData { MSM(Option<(Vec, Vec)>), - Pairing(OffloadedPairingDef), + Pairing(Option>, Vec), } -pub type DeferredFn = - dyn FnOnce() -> Result, Vec)>, SynthesisError>; +pub type DeferredFn = dyn FnOnce() -> Result, SynthesisError>; -pub type DeferredFnsRef = Rc>>>>; +pub type DeferredFnsRef = Rc>>>>; -pub trait OffloadedDataCircuit: Clone +pub trait OffloadedDataCircuit: Clone where - G: CurveGroup, + E: Pairing, { - fn deferred_fns_ref(&self) -> &DeferredFnsRef; + fn deferred_fns_ref(&self) -> &DeferredFnsRef; - fn defer_msm( - &self, - f: impl FnOnce() -> Result, Vec)>, SynthesisError> - + 'static, - ) { + fn defer_op(&self, f: impl FnOnce() -> Result, SynthesisError> + 'static) { self.deferred_fns_ref().borrow_mut().push(Box::new(f)); } } @@ -107,43 +115,76 @@ where struct WrappedCircuit where E: Pairing, - C: ConstraintSynthesizer + OffloadedDataCircuit, + C: ConstraintSynthesizer + OffloadedDataCircuit, { circuit: C, - offloaded_data_ref: Rc>>, + offloaded_data_ref: Rc>>, } +/// This is run both at setup and at proving time. +/// At setup time we only need to get G2 elements: we need them to form the verifying key. +/// At proving time we need to get G1 elements as well. fn run_deferred( - deferred_fns: Vec>>, -) -> Result>, SynthesisError> { - let msms = deferred_fns + deferred_fns: Vec>>, +) -> Result, SynthesisError> { + let op_data = deferred_fns .into_iter() .map(|f| f()) .collect::, _>>()?; - // can't collect into `Option>` above: it short-circuits on the first None - let msms = msms.into_iter().collect::>>(); + let op_data_by_type = op_data + .into_iter() + .into_grouping_map_by(|d| match d { + DeferredOpData::MSM(..) => 0, + DeferredOpData::Pairing(..) 
=> 1, + }) + .collect::>(); + + let msms = op_data_by_type + .get(&0) + .into_iter() + .flatten() + .map(|d| match d { + DeferredOpData::MSM(msm_opt) => msm_opt.clone(), + _ => unreachable!(), + }) + .collect::>>(); - Ok(msms.map(|msms| OffloadedData { msms })) + let (p_g1s, p_g2s): (Vec<_>, Vec<_>) = op_data_by_type + .get(&1) + .into_iter() + .flatten() + .map(|d| match d { + DeferredOpData::Pairing(g1s_opt, g2s) => (g1s_opt.clone(), g2s.clone()), + _ => unreachable!(), + }) + .unzip(); + let pairing_g1s = p_g1s.into_iter().collect::>>(); + + Ok(OffloadedData { + proof_data: msms + .zip(pairing_g1s) + .map(|(msms, pairing_g1s)| ProofData { msms, pairing_g1s }), + setup_data: p_g2s, + }) } impl ConstraintSynthesizer for WrappedCircuit where E: Pairing, - C: ConstraintSynthesizer + OffloadedDataCircuit, + C: ConstraintSynthesizer + OffloadedDataCircuit, { fn generate_constraints(self, cs: ConstraintSystemRef) -> r1cs::Result<()> { + // `self.circuit` will be consumed by `self.circuit.generate_constraints(cs)` + // so we need to clone the reference to the deferred functions let deferred_fns_ref = self.circuit.deferred_fns_ref().clone(); let offloaded_data_ref = self.offloaded_data_ref.clone(); self.circuit.generate_constraints(cs)?; - let offloaded_data = run_deferred::(deferred_fns_ref.take())?; - if let Some(offloaded_data) = offloaded_data { - offloaded_data_ref.set(offloaded_data).unwrap(); - } + offloaded_data_ref.set(offloaded_data).unwrap(); Ok(()) } @@ -156,7 +197,7 @@ where S: SNARK, G1Var: CurveVar + ToConstraintFieldGadget, { - type Circuit: ConstraintSynthesizer + OffloadedDataCircuit; + type Circuit: ConstraintSynthesizer + OffloadedDataCircuit; fn setup( circuit: Self::Circuit, @@ -175,25 +216,26 @@ where rng: &mut R, ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey), OffloadedSNARKError> { - let offloaded_circuit = circuit.circuit.clone(); + let offloaded_data_ref = circuit.offloaded_data_ref.clone(); let (pk, snark_vk) = S::circuit_specific_setup(circuit, rng) .map_err(|e| OffloadedSNARKError::SNARKError(e))?; let snark_pvk = S::process_vk(&snark_vk).map_err(|e| OffloadedSNARKError::SNARKError(e))?; - // let g2_vecs = Self::pairing_setup(offloaded_circuit); - // let delayed_pairings = g2_vecs - // .into_iter() - // .zip(g1_offsets_ref.get().unwrap().clone().into_iter()) - // .map(|(g2_elements, g1_offsets)| OffloadedPairingDef { - // g1_offsets, - // g2_elements, - // }) - // .collect(); + let setup_data = offloaded_data_ref.get().unwrap().clone().setup_data; + + let delayed_pairings = setup_data + .into_iter() + .map(|g2| OffloadedPairingDef { + g1_offsets: vec![], + g2_elements: g2, + }) + .collect(); + let vk = OffloadedSNARKVerifyingKey { snark_pvk, - delayed_pairings: vec![], + delayed_pairings, }; Ok((pk, vk)) @@ -214,9 +256,14 @@ where let proof = S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let proof_data = match offloaded_data_ref.get().unwrap().clone().proof_data { + Some(proof_data) => proof_data, + _ => unreachable!(), + }; + Ok(OffloadedSNARKProof { snark_proof: proof, - offloaded_data: offloaded_data_ref.get().unwrap().clone(), + offloaded_data: proof_data, }) } @@ -249,7 +296,7 @@ where } } - let pairings = Self::pairing_inputs(vk, &public_input, &proof.snark_proof)?; + let pairings = Self::pairing_inputs(vk, &proof.offloaded_data.pairing_g1s)?; for (g1s, g2s) in pairings { assert_eq!(g1s.len(), g2s.len()); let r = E::multi_pairing(&g1s, &g2s); @@ -279,31 +326,13 @@ where fn pairing_inputs( vk: 
&OffloadedSNARKVerifyingKey, - public_input: &[E::ScalarField], - proof: &S::Proof, + g1_vectors: &Vec>, ) -> Result, Vec)>, SerializationError> { - let g1_vectors = vk - .delayed_pairings - .iter() - .map(|pairing_def| { - let last_index = pairing_def.g1_offsets.len() - 1; - let g1s = pairing_def - .g1_offsets - .iter() - .enumerate() - .map(|(i, &offset)| { - let g1 = Self::g1_elements(public_input, offset, 1)?[0]; - if i == last_index { - Ok((-g1).into()) - } else { - Ok(g1.into()) - } - }) - .collect::, _>>(); - g1s - }) - .collect::>, SerializationError>>(); - Ok(g1_vectors?.into_iter().zip(Self::g2_elements(vk)).collect()) + Ok(g1_vectors + .into_iter() + .map(|g1_vec| g1_vec.into_iter().map(|&g1| g1.into()).collect()) + .zip(Self::g2_elements(vk)) + .collect()) } fn g2_elements(vk: &OffloadedSNARKVerifyingKey) -> Vec> { @@ -362,35 +391,50 @@ where fn build_public_input( public_input: &[E::ScalarField], - data: &OffloadedData, + data: &ProofData, ) -> Vec where E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, { - let appended_data = data + let msm_data = data .msms .iter() .map(|msm| { let scalars = &msm.1; let scalar_vec = scalars[..scalars.len() - 1].to_vec(); // remove the last element (always `-1`) - let msm_g1_vec = msm - .0 - .iter() - .map(|&g1| { - G1Var::constant(g1.into()) - .to_constraint_field() - .unwrap() - .iter() - .map(|x| x.value().unwrap()) - .collect::>() - }) - .concat(); + let g1s = &msm.0; + let msm_g1_vec = to_scalars::(g1s); [scalar_vec, msm_g1_vec].concat() }) .concat(); - [public_input.to_vec(), appended_data].concat() + let pairing_data = data + .pairing_g1s + .iter() + .map(|g1s| to_scalars::(g1s)) + .concat(); + + [public_input.to_vec(), msm_data, pairing_data].concat() +} + +fn to_scalars(g1s: &Vec) -> Vec +where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, +{ + let msm_g1_vec = g1s + .iter() + .map(|&g1| { + G1Var::constant(g1.into()) + .to_constraint_field() + .unwrap() + .iter() + .map(|x| x.value().unwrap()) + .collect::>() + }) + .concat(); + msm_g1_vec } From 678b04d705ec0530fe2384f55a858a9d570fa200 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 20 Aug 2024 16:54:14 -0700 Subject: [PATCH 39/44] Remove unneeded code --- .../curves/short_weierstrass/bls12/mod.rs | 314 ------------------ .../curves/short_weierstrass/bls12_381.rs | 7 + .../groups/curves/short_weierstrass/bn254.rs | 4 +- .../groups/curves/short_weierstrass/mod.rs | 5 +- jolt-core/src/circuits/mod.rs | 1 - jolt-core/src/circuits/pairing/bls12/mod.rs | 178 ---------- jolt-core/src/circuits/pairing/mod.rs | 206 ------------ 7 files changed, 10 insertions(+), 705 deletions(-) delete mode 100644 jolt-core/src/circuits/groups/curves/short_weierstrass/bls12/mod.rs create mode 100644 jolt-core/src/circuits/groups/curves/short_weierstrass/bls12_381.rs delete mode 100644 jolt-core/src/circuits/pairing/bls12/mod.rs delete mode 100644 jolt-core/src/circuits/pairing/mod.rs diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12/mod.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12/mod.rs deleted file mode 100644 index 84b4fc25f..000000000 --- a/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12/mod.rs +++ /dev/null @@ -1,314 +0,0 @@ -use ark_ec::{ - bls12::{Bls12Config, G1Prepared, G2Prepared, TwistType}, - short_weierstrass::Affine as GroupAffine, -}; -use ark_ff::{BitIteratorBE, Field, One}; -use ark_relations::r1cs::{Namespace, SynthesisError}; - -use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; -use 
ark_r1cs_std::prelude::*; -use ark_r1cs_std::{ - fields::{fp::FpVar, FieldVar}, - R1CSVar, -}; -use core::fmt::Debug; -use derivative::Derivative; - -use crate::circuits::fields::fp2::Fp2Var; -use crate::circuits::groups::curves::short_weierstrass::*; -use ark_std::vec::Vec; - -/// Represents a projective point in G1. -pub type G1Var = ProjectiveVar< -
::G1Config, - ConstraintF, - NonNativeFieldVar<
::Fp, ConstraintF>, ->; - -/// Represents an affine point on G1. Should be used only for comparison and -/// when a canonical representation of a point is required, and not for -/// arithmetic. -pub type G1AffineVar = AffineVar< -
::G1Config, - ConstraintF, - NonNativeFieldVar<
::Fp, ConstraintF>, ->; - -/// Represents a projective point in G2. -pub type G2Var = - ProjectiveVar<
::G2Config, ConstraintF, Fp2G>; -/// Represents an affine point on G2. Should be used only for comparison and -/// when a canonical representation of a point is required, and not for -/// arithmetic. -pub type G2AffineVar = - AffineVar<
::G2Config, ConstraintF, Fp2G>; - -/// Represents the cached precomputation that can be performed on a G1 element -/// which enables speeding up pairing computation. -#[derive(Derivative)] -#[derivative( - Clone(bound = "G1Var: Clone"), - Debug(bound = "G1Var: Debug") -)] -pub struct G1PreparedVar( - pub AffineVar>, -); - -impl G1PreparedVar { - /// Returns the value assigned to `self` in the underlying constraint - /// system. - pub fn value(&self) -> Result, SynthesisError> { - let x = self.0.x.value()?; - let y = self.0.y.value()?; - let infinity = self.0.infinity.value()?; - let g = infinity - .then_some(GroupAffine::identity()) - .unwrap_or(GroupAffine::new(x, y)) - .into(); - Ok(g) - } - - /// Constructs `Self` from a `G1Var`. - pub fn from_group_var(q: &G1Var) -> Result { - let g = q.to_affine()?; - Ok(Self(g)) - } -} - -impl AllocVar, ConstraintF> - for G1PreparedVar -{ - fn new_variable>>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - let g1_prep = f().map(|b| b.borrow().0); - - let x = NonNativeFieldVar::new_variable( - ark_relations::ns!(cs, "x"), - || g1_prep.map(|g| g.x), - mode, - )?; - let y = NonNativeFieldVar::new_variable( - ark_relations::ns!(cs, "y"), - || g1_prep.map(|g| g.y), - mode, - )?; - let infinity = Boolean::new_variable( - ark_relations::ns!(cs, "inf"), - || g1_prep.map(|g| g.infinity), - mode, - )?; - let g = AffineVar::new(x, y, infinity); - Ok(Self(g)) - } -} - -impl ToBytesGadget - for G1PreparedVar -{ - #[inline] - #[tracing::instrument(target = "r1cs")] - fn to_bytes(&self) -> Result>, SynthesisError> { - let mut bytes = self.0.x.to_bytes()?; - let y_bytes = self.0.y.to_bytes()?; - let inf_bytes = self.0.infinity.to_bytes()?; - bytes.extend_from_slice(&y_bytes); - bytes.extend_from_slice(&inf_bytes); - Ok(bytes) - } - - #[tracing::instrument(target = "r1cs")] - fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { - let mut bytes = self.0.x.to_non_unique_bytes()?; - let y_bytes = self.0.y.to_non_unique_bytes()?; - let inf_bytes = self.0.infinity.to_non_unique_bytes()?; - bytes.extend_from_slice(&y_bytes); - bytes.extend_from_slice(&inf_bytes); - Ok(bytes) - } -} - -type Fp2G = Fp2Var<
::Fp2Config, ConstraintF>; -type LCoeff = (Fp2G, Fp2G); -/// Represents the cached precomputation that can be performed on a G2 element -/// which enables speeding up pairing computation. -#[derive(Derivative)] -#[derivative( - Clone(bound = "Fp2Var: Clone"), - Debug(bound = "Fp2Var: Debug") -)] -pub struct G2PreparedVar { - #[doc(hidden)] - pub ell_coeffs: Vec>, -} - -impl AllocVar, ConstraintF> for G2PreparedVar -where - P: Bls12Config, - ConstraintF: PrimeField, -{ - #[tracing::instrument(target = "r1cs", skip(cs, f, mode))] - fn new_variable>>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - let g2_prep = f().map(|b| { - let projective_coeffs = &b.borrow().ell_coeffs; - match P::TWIST_TYPE { - TwistType::M => { - let mut z_s = projective_coeffs - .iter() - .map(|(_, _, z)| *z) - .collect::>(); - ark_ff::fields::batch_inversion(&mut z_s); - projective_coeffs - .iter() - .zip(z_s) - .map(|((x, y, _), z_inv)| (*x * &z_inv, *y * &z_inv)) - .collect::>() - } - TwistType::D => { - let mut z_s = projective_coeffs - .iter() - .map(|(z, ..)| *z) - .collect::>(); - ark_ff::fields::batch_inversion(&mut z_s); - projective_coeffs - .iter() - .zip(z_s) - .map(|((_, x, y), z_inv)| (*x * &z_inv, *y * &z_inv)) - .collect::>() - } - } - }); - - let l = Vec::new_variable( - ark_relations::ns!(cs, "l"), - || { - g2_prep - .clone() - .map(|c| c.iter().map(|(l, _)| *l).collect::>()) - }, - mode, - )?; - let r = Vec::new_variable( - ark_relations::ns!(cs, "r"), - || g2_prep.map(|c| c.iter().map(|(_, r)| *r).collect::>()), - mode, - )?; - let ell_coeffs = l.into_iter().zip(r).collect(); - Ok(Self { ell_coeffs }) - } -} - -impl ToBytesGadget for G2PreparedVar -where - P: Bls12Config, - ConstraintF: PrimeField, -{ - #[inline] - #[tracing::instrument(target = "r1cs")] - fn to_bytes(&self) -> Result>, SynthesisError> { - let mut bytes = Vec::new(); - for coeffs in &self.ell_coeffs { - bytes.extend_from_slice(&coeffs.0.to_bytes()?); - bytes.extend_from_slice(&coeffs.1.to_bytes()?); - } - Ok(bytes) - } - - #[tracing::instrument(target = "r1cs")] - fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { - let mut bytes = Vec::new(); - for coeffs in &self.ell_coeffs { - bytes.extend_from_slice(&coeffs.0.to_non_unique_bytes()?); - bytes.extend_from_slice(&coeffs.1.to_non_unique_bytes()?); - } - Ok(bytes) - } -} - -impl G2PreparedVar -where - P: Bls12Config, - ConstraintF: PrimeField, -{ - /// Constructs `Self` from a `G2Var`. - #[tracing::instrument(target = "r1cs")] - pub fn from_group_var(q: &G2Var) -> Result { - let q = q.to_affine()?; - let two_inv = P::Fp::one().double().inverse().unwrap(); - // Enforce that `q` is not the point at infinity. - q.infinity.enforce_not_equal(&Boolean::Constant(true))?; - let mut ell_coeffs = vec![]; - let mut r = q.clone(); - - for i in BitIteratorBE::new(P::X).skip(1) { - ell_coeffs.push(Self::double(&mut r, &two_inv)?); - - if i { - ell_coeffs.push(Self::add(&mut r, &q)?); - } - } - - Ok(Self { ell_coeffs }) - } - - #[tracing::instrument(target = "r1cs")] - fn double( - r: &mut G2AffineVar, - two_inv: &P::Fp, - ) -> Result, SynthesisError> { - let a = r.y.inverse()?; - let mut b = r.x.square()?; - let b_tmp = b.clone(); - b.mul_assign_by_base_field_constant(*two_inv); - b += &b_tmp; - - let c = &a * &b; - let d = r.x.double()?; - let x3 = c.square()? 
- &d; - let e = &c * &r.x - &r.y; - let c_x3 = &c * &x3; - let y3 = &e - &c_x3; - let mut f = c; - f.negate_in_place()?; - r.x = x3; - r.y = y3; - match P::TWIST_TYPE { - TwistType::M => Ok((e, f)), - TwistType::D => Ok((f, e)), - } - } - - #[tracing::instrument(target = "r1cs")] - fn add( - r: &mut G2AffineVar, - q: &G2AffineVar, - ) -> Result, SynthesisError> { - let a = (&q.x - &r.x).inverse()?; - let b = &q.y - &r.y; - let c = &a * &b; - let d = &r.x + &q.x; - let x3 = c.square()? - &d; - - let e = (&r.x - &x3) * &c; - let y3 = e - &r.y; - let g = &c * &r.x - &r.y; - let mut f = c; - f.negate_in_place()?; - r.x = x3; - r.y = y3; - match P::TWIST_TYPE { - TwistType::M => Ok((g, f)), - TwistType::D => Ok((f, g)), - } - } -} diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12_381.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12_381.rs new file mode 100644 index 000000000..ad77f2e22 --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12_381.rs @@ -0,0 +1,7 @@ +use crate::circuits::groups::curves::short_weierstrass::ProjectiveVar; +use ark_bls12_381::{g1, Fq, Fr}; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; + +pub type FBaseVar = NonNativeFieldVar; + +pub type G1Var = ProjectiveVar; diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs index 3ca4a9838..0a28be59b 100644 --- a/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs @@ -1,7 +1,7 @@ use crate::circuits::groups::curves::short_weierstrass::ProjectiveVar; -use ark_bn254::{Bn254, Fq, Fr}; +use ark_bn254::{g1, Fq, Fr}; use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; pub type FBaseVar = NonNativeFieldVar; -pub type G1Var = ProjectiveVar; +pub type G1Var = ProjectiveVar; diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs index 735e044a5..b26f4e3ad 100644 --- a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs @@ -16,10 +16,7 @@ use ark_r1cs_std::{fields::fp::FpVar, prelude::*, ToConstraintFieldGadget}; use ark_std::vec::Vec; use binius_field::PackedField; -/// This module provides a generic implementation of G1 and G2 for -/// the [\[BLS12]\]() family of bilinear groups. 
-pub mod bls12; - +pub mod bls12_381; pub mod bn254; /// This module provides a generic implementation of elliptic curve operations diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs index 8b2a3e53f..fd74b94db 100644 --- a/jolt-core/src/circuits/mod.rs +++ b/jolt-core/src/circuits/mod.rs @@ -1,6 +1,5 @@ pub mod fields; pub mod groups; pub mod offloaded; -pub mod pairing; pub mod poly; pub mod transcript; diff --git a/jolt-core/src/circuits/pairing/bls12/mod.rs b/jolt-core/src/circuits/pairing/bls12/mod.rs deleted file mode 100644 index 819590835..000000000 --- a/jolt-core/src/circuits/pairing/bls12/mod.rs +++ /dev/null @@ -1,178 +0,0 @@ -use crate::circuits::fields::fp12::Fp12Var; -use crate::circuits::fields::fp2::Fp2Var; -use crate::circuits::groups::curves::short_weierstrass::bls12::{ - G1AffineVar, G1PreparedVar, G1Var, G2PreparedVar, G2Var, -}; -use ark_ec::bls12::{Bls12, Bls12Config, TwistType}; -use ark_ff::{BitIteratorBE, PrimeField}; -use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; -use ark_r1cs_std::prelude::FieldVar; -use ark_relations::r1cs::SynthesisError; -use ark_std::marker::PhantomData; - -/// Specifies the constraints for computing a pairing in a BLS12 bilinear group. -pub struct PairingGadget(PhantomData<(P, ConstraintF)>) -where - P: Bls12Config, - ConstraintF: PrimeField; - -type Fp2V = Fp2Var<
::Fp2Config, ConstraintF>; - -impl PairingGadget -where - P: Bls12Config, - ConstraintF: PrimeField, -{ - // Evaluate the line function at point p. - #[tracing::instrument(target = "r1cs")] - fn ell( - f: &mut Fp12Var, - coeffs: &(Fp2V, Fp2V), - p: &G1AffineVar, - ) -> Result<(), SynthesisError> { - let zero = NonNativeFieldVar::::zero(); - - match P::TWIST_TYPE { - TwistType::M => { - let c0 = coeffs.0.clone(); - let mut c1 = coeffs.1.clone(); - let c2 = Fp2V::::new(p.y.clone(), zero); - - c1.c0 *= &p.x; - c1.c1 *= &p.x; - *f = f.mul_by_014(&c0, &c1, &c2)?; - Ok(()) - } - TwistType::D => { - let c0 = Fp2V::::new(p.y.clone(), zero); - let mut c1 = coeffs.0.clone(); - let c2 = coeffs.1.clone(); - - c1.c0 *= &p.x; - c1.c1 *= &p.x; - *f = f.mul_by_034(&c0, &c1, &c2)?; - Ok(()) - } - } - } - - #[tracing::instrument(target = "r1cs")] - fn exp_by_x( - f: &Fp12Var, - ) -> Result, SynthesisError> { - let mut result = f.optimized_cyclotomic_exp(P::X)?; - if P::X_IS_NEGATIVE { - result = result.unitary_inverse()?; - } - Ok(result) - } -} - -impl super::PairingGadget, ConstraintF> - for PairingGadget -{ - type G1Var = G1Var; - type G2Var = G2Var; - type GTVar = Fp12Var; - type G1PreparedVar = G1PreparedVar; - type G2PreparedVar = G2PreparedVar; - - #[tracing::instrument(target = "r1cs")] - fn miller_loop( - ps: &[Self::G1PreparedVar], - qs: &[Self::G2PreparedVar], - ) -> Result { - let mut pairs = vec![]; - for (p, q) in ps.iter().zip(qs.iter()) { - pairs.push((p, q.ell_coeffs.iter())); - } - let mut f = Self::GTVar::one(); - - for i in BitIteratorBE::new(P::X).skip(1) { - f.square_in_place()?; - - for &mut (p, ref mut coeffs) in pairs.iter_mut() { - Self::ell(&mut f, coeffs.next().unwrap(), &p.0)?; - } - - if i { - for &mut (p, ref mut coeffs) in pairs.iter_mut() { - Self::ell(&mut f, &coeffs.next().unwrap(), &p.0)?; - } - } - } - - if P::X_IS_NEGATIVE { - f = f.unitary_inverse()?; - } - - Ok(f) - } - - #[tracing::instrument(target = "r1cs")] - fn final_exponentiation(f: &Self::GTVar) -> Result { - // Computing the final exponentation following - // https://eprint.iacr.org/2016/130.pdf. - // We don't use their "faster" formula because it is difficult to make - // it work for curves with odd `P::X`. - // Hence we implement the slower algorithm from Table 1 below. 
- - let f1 = f.unitary_inverse()?; - - f.inverse().and_then(|mut f2| { - // f2 = f^(-1); - // r = f^(p^6 - 1) - let mut r = f1; - r *= &f2; - - // f2 = f^(p^6 - 1) - f2 = r.clone(); - // r = f^((p^6 - 1)(p^2)) - r.frobenius_map_in_place(2)?; - - // r = f^((p^6 - 1)(p^2) + (p^6 - 1)) - // r = f^((p^6 - 1)(p^2 + 1)) - r *= &f2; - - // Hard part of the final exponentation is below: - // From https://eprint.iacr.org/2016/130.pdf, Table 1 - let mut y0 = r.cyclotomic_square()?; - y0 = y0.unitary_inverse()?; - - let mut y5 = Self::exp_by_x(&r)?; - - let mut y1 = y5.cyclotomic_square()?; - let mut y3 = y0 * &y5; - y0 = Self::exp_by_x(&y3)?; - let y2 = Self::exp_by_x(&y0)?; - let mut y4 = Self::exp_by_x(&y2)?; - y4 *= &y1; - y1 = Self::exp_by_x(&y4)?; - y3 = y3.unitary_inverse()?; - y1 *= &y3; - y1 *= &r; - y3 = r.clone(); - y3 = y3.unitary_inverse()?; - y0 *= &r; - y0.frobenius_map_in_place(3)?; - y4 *= &y3; - y4.frobenius_map_in_place(1)?; - y5 *= &y2; - y5.frobenius_map_in_place(2)?; - y5 *= &y0; - y5 *= &y4; - y5 *= &y1; - Ok(y5) - }) - } - - #[tracing::instrument(target = "r1cs")] - fn prepare_g1(p: &Self::G1Var) -> Result { - Self::G1PreparedVar::from_group_var(p) - } - - #[tracing::instrument(target = "r1cs")] - fn prepare_g2(q: &Self::G2Var) -> Result { - Self::G2PreparedVar::from_group_var(q) - } -} diff --git a/jolt-core/src/circuits/pairing/mod.rs b/jolt-core/src/circuits/pairing/mod.rs deleted file mode 100644 index 36bec9d44..000000000 --- a/jolt-core/src/circuits/pairing/mod.rs +++ /dev/null @@ -1,206 +0,0 @@ -pub mod bls12; - -use ark_ec::pairing::Pairing; -use ark_ff::PrimeField; -use ark_r1cs_std::prelude::*; -use ark_relations::r1cs::SynthesisError; -use ark_std::fmt::Debug; - -/// Specifies the constraints for computing a pairing in the bilinear group -/// `E`. -pub trait PairingGadget { - /// A variable representing an element of `G1`. - /// This is the R1CS equivalent of `E::G1Projective`. - type G1Var: CurveVar - + AllocVar - + AllocVar; - - /// A variable representing an element of `G2`. - /// This is the R1CS equivalent of `E::G2Projective`. - type G2Var: CurveVar - + AllocVar - + AllocVar; - - /// A variable representing an element of `GT`. - /// This is the R1CS equivalent of `E::GT`. - type GTVar: FieldVar; - - /// A variable representing cached precomputation that can speed up - /// pairings computations. This is the R1CS equivalent of - /// `E::G1Prepared`. - type G1PreparedVar: ToBytesGadget - + AllocVar - + Clone - + Debug; - /// A variable representing cached precomputation that can speed up - /// pairings computations. This is the R1CS equivalent of - /// `E::G2Prepared`. - type G2PreparedVar: ToBytesGadget - + AllocVar - + Clone - + Debug; - - /// Computes a multi-miller loop between elements - /// of `p` and `q`. - fn miller_loop( - p: &[Self::G1PreparedVar], - q: &[Self::G2PreparedVar], - ) -> Result; - - /// Computes a final exponentiation over `p`. - fn final_exponentiation(p: &Self::GTVar) -> Result; - - /// Computes a pairing over `p` and `q`. - #[tracing::instrument(target = "r1cs")] - fn pairing( - p: Self::G1PreparedVar, - q: Self::G2PreparedVar, - ) -> Result { - let tmp = Self::miller_loop(&[p], &[q])?; - Self::final_exponentiation(&tmp) - } - - /// Computes a product of pairings over the elements in `p` and `q`. 
- #[must_use] - #[tracing::instrument(target = "r1cs")] - fn multi_pairing( - p: &[Self::G1PreparedVar], - q: &[Self::G2PreparedVar], - ) -> Result { - let miller_result = Self::miller_loop(p, q)?; - Self::final_exponentiation(&miller_result) - } - - /// Performs the precomputation to generate `Self::G1PreparedVar`. - fn prepare_g1(q: &Self::G1Var) -> Result; - - /// Performs the precomputation to generate `Self::G2PreparedVar`. - fn prepare_g2(q: &Self::G2Var) -> Result; -} - -#[cfg(test)] -mod tests { - use super::*; - use ark_bls12_381::Bls12_381; - use ark_bn254::Bn254; - use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; - use ark_ec::pairing::Pairing; - use ark_ec::Group; - use ark_ff::PrimeField; - use ark_groth16::Groth16; - use ark_r1cs_std::prelude::*; - use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; - use ark_std::marker::PhantomData; - use ark_std::rand::Rng; - use ark_std::{end_timer, start_timer, test_rng}; - use rand_core::{RngCore, SeedableRng}; - - struct PairingCheckCircuit - where - E: Pairing, - ConstraintF: PrimeField, - P: PairingGadget, - { - r: Option, - r_g2: Option, - _params: PhantomData<(ConstraintF, P)>, - } - - impl ConstraintSynthesizer - for PairingCheckCircuit - where - E: Pairing, - ConstraintF: PrimeField, - P: PairingGadget, - { - fn generate_constraints( - self, - cs: ConstraintSystemRef, - ) -> Result<(), SynthesisError> { - dbg!(cs.num_constraints()); - - let r_g1 = P::G1Var::new_witness(cs.clone(), || { - Ok(E::G1::generator() * self.r.ok_or(SynthesisError::AssignmentMissing)?) - })?; - dbg!(cs.num_constraints()); - - let r_g1_prepared = P::prepare_g1(&r_g1)?; - dbg!(cs.num_constraints()); - - let minus_one_g1_prepared = P::G1PreparedVar::new_constant( - cs.clone(), - &E::G1Prepared::from(-E::G1::generator()), - )?; - dbg!(cs.num_constraints()); - - let r_g2 = P::G2Var::new_witness(cs.clone(), || { - Ok(self.r_g2.ok_or(SynthesisError::AssignmentMissing)?) - })?; - dbg!(cs.num_constraints()); - - let r_g2_prepared = P::prepare_g2(&r_g2)?; - dbg!(cs.num_constraints()); - - let one_g2_prepared = P::G2PreparedVar::new_constant( - cs.clone(), - &E::G2Prepared::from(E::G2::generator()), - )?; - dbg!(cs.num_constraints()); - - let result = P::multi_pairing( - &[r_g1_prepared, minus_one_g1_prepared], - &[one_g2_prepared, r_g2_prepared], - )?; - dbg!(cs.num_constraints()); - - result.enforce_equal(&P::GTVar::one()) - } - } - - #[test] - #[ignore] - fn test_pairing_check_circuit() { - type DemoCircuit = PairingCheckCircuit< - Bls12_381, - ark_bn254::Fr, - bls12::PairingGadget, - >; - - let c = DemoCircuit { - r: None, - r_g2: None, - _params: PhantomData, - }; - - // This is not cryptographically safe, use - // `OsRng` (for example) in production software. 
- let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(test_rng().next_u64()); - - let setup_timer = start_timer!(|| "Groth16::setup"); - let (pk, vk) = Groth16::::setup(c, &mut rng).unwrap(); - end_timer!(setup_timer); - - let process_vk_timer = start_timer!(|| "Groth16::process_vk"); - let pvk = Groth16::::process_vk(&vk).unwrap(); - end_timer!(process_vk_timer); - - let r = rng.gen(); - let r_g2 = ::G2::generator() * &r; - - let c = DemoCircuit { - r: Some(r), - r_g2: Some(r_g2.into()), - _params: PhantomData, - }; - - let prove_timer = start_timer!(|| "Groth16::prove"); - let proof = Groth16::::prove(&pk, c, &mut rng).unwrap(); - end_timer!(prove_timer); - - let verify_timer = start_timer!(|| "Groth16::verify"); - let verify_result = Groth16::::verify_with_processed_vk(&pvk, &[], &proof); - end_timer!(verify_timer); - - assert!(verify_result.unwrap()); - } -} From ae0e9cf5194ccef49f6f38a753e0f8c9104ce592 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 20 Aug 2024 17:17:31 -0700 Subject: [PATCH 40/44] Cleanup warnings --- jolt-core/src/circuits/fields/fp2.rs | 7 +- .../groups/curves/short_weierstrass/mod.rs | 11 +-- jolt-core/src/circuits/offloaded/mod.rs | 26 +++--- .../poly/commitment/commitment_scheme.rs | 1 - .../src/circuits/poly/commitment/hyperkzg.rs | 81 ++++++++----------- jolt-core/src/circuits/transcript/mock.rs | 28 +++---- jolt-core/src/circuits/transcript/mod.rs | 18 ++--- jolt-core/src/snark/mod.rs | 7 +- 8 files changed, 72 insertions(+), 107 deletions(-) diff --git a/jolt-core/src/circuits/fields/fp2.rs b/jolt-core/src/circuits/fields/fp2.rs index 87050befe..619c6e9cf 100644 --- a/jolt-core/src/circuits/fields/fp2.rs +++ b/jolt-core/src/circuits/fields/fp2.rs @@ -1,7 +1,8 @@ use crate::circuits::fields::quadratic_extension::*; -use ark_ff::fields::{Fp2Config, Fp2ConfigWrapper, QuadExtConfig}; -use ark_ff::PrimeField; -use ark_r1cs_std::fields::fp::FpVar; +use ark_ff::{ + fields::{Fp2Config, Fp2ConfigWrapper, QuadExtConfig}, + PrimeField, +}; use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; /// A quadratic extension field constructed over a prime field. 
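For reference, the relation that the now-removed `test_pairing_check_circuit` enforced in-circuit is the same one the offloaded verifier checks natively: a multi-pairing whose product must be the identity in GT. A minimal out-of-circuit sketch, assuming the arkworks 0.4 API over BLS12-381 (illustrative only, not code from this patch):

```rust
use ark_bls12_381::{Bls12_381, Fr, G1Projective, G2Projective};
use ark_ec::{pairing::Pairing, Group};
use ark_std::{test_rng, UniformRand, Zero};

fn main() {
    let mut rng = test_rng();
    let r = Fr::rand(&mut rng);

    // Claim: r_g2 really is r·G2, encoded as a multi-pairing that must be
    // the identity: e(r·G1, G2) · e(-G1, r·G2) == 1.
    let r_g1 = G1Projective::generator() * r;
    let r_g2 = G2Projective::generator() * r;
    let acc = Bls12_381::multi_pairing(
        [r_g1, -G1Projective::generator()],
        [G2Projective::generator(), r_g2],
    );

    // PairingOutput is written additively, so the identity element is `zero()`.
    assert!(acc.is_zero());
}
```

The `-G1Projective::generator()` term plays the same role as the `minus_one_g1_prepared` constant in the removed test.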
diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs index b26f4e3ad..b3bfd59ef 100644 --- a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs @@ -5,17 +5,12 @@ use ark_ec::{ AffineRepr, CurveGroup, }; use ark_ff::{BigInteger, BitIteratorBE, Field, One, PrimeField, Zero}; -use ark_r1cs_std::impl_bounded_ops; +use ark_r1cs_std::{fields::fp::FpVar, impl_bounded_ops, prelude::*, ToConstraintFieldGadget}; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; -use ark_std::{borrow::Borrow, marker::PhantomData, ops::Mul}; +use ark_std::{borrow::Borrow, marker::PhantomData, ops::Mul, vec::Vec}; use derivative::Derivative; use non_zero_affine::NonZeroAffineVar; -use ark_r1cs_std::{fields::fp::FpVar, prelude::*, ToConstraintFieldGadget}; - -use ark_std::vec::Vec; -use binius_field::PackedField; - pub mod bls12_381; pub mod bn254; @@ -852,7 +847,7 @@ where let (mut ge, iter) = if cofactor_weight < modulus_minus_1_weight { let ge = Self::new_variable_omit_prime_order_check( ark_relations::ns!(cs, "Witness without subgroup check with cofactor mul"), - || f().map(|g| g.borrow().into_affine().mul_by_cofactor_inv().into()), + || f().map(|g| g.into_affine().mul_by_cofactor_inv().into()), mode, )?; ( diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index 4499a4240..a3c9770d4 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -1,18 +1,14 @@ -use crate::snark::{DeferredOpData, OffloadedData, OffloadedDataCircuit}; -use ark_ec::pairing::Pairing; -use ark_ec::{CurveGroup, VariableBaseMSM}; -use ark_ff::{One, PrimeField}; -use ark_r1cs_std::alloc::AllocVar; -use ark_r1cs_std::boolean::Boolean; -use ark_r1cs_std::eq::EqGadget; -use ark_r1cs_std::fields::fp::FpVar; -use ark_r1cs_std::fields::FieldVar; -use ark_r1cs_std::groups::CurveVar; -use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; -use ark_relations::ns; -use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; -use ark_serialize::Valid; -use ark_std::Zero; +use crate::snark::{DeferredOpData, OffloadedDataCircuit}; +use ark_ec::{pairing::Pairing, CurveGroup, VariableBaseMSM}; +use ark_r1cs_std::{ + alloc::AllocVar, eq::EqGadget, fields::fp::FpVar, fields::FieldVar, groups::CurveVar, R1CSVar, + ToConstraintFieldGadget, +}; +use ark_relations::{ + ns, + r1cs::{Namespace, SynthesisError}, +}; +use ark_std::{One, Zero}; use std::marker::PhantomData; pub trait MSMGadget diff --git a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs index e5d9061d0..e0a3df7c9 100644 --- a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs +++ b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs @@ -1,6 +1,5 @@ use crate::poly::commitment::commitment_scheme::CommitmentScheme; use ark_crypto_primitives::sponge::constraints::SpongeWithGadget; -use ark_ec::pairing::Pairing; use ark_ff::PrimeField; use ark_r1cs_std::fields::fp::FpVar; use ark_r1cs_std::prelude::*; diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 67b84c88c..7ecfdf1bd 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -1,30 +1,24 @@ -use crate::circuits::offloaded::{ - MSMGadget, 
OffloadedMSMGadget, OffloadedPairingGadget, PairingGadget, +use crate::{ + circuits::{ + offloaded::{MSMGadget, OffloadedMSMGadget, OffloadedPairingGadget, PairingGadget}, + poly::commitment::commitment_scheme::CommitmentVerifierGadget, + transcript::ImplAbsorb, + }, + field::JoltField, + poly::commitment::hyperkzg::{ + HyperKZG, HyperKZGCommitment, HyperKZGProof, HyperKZGProverKey, HyperKZGVerifierKey, + }, + snark::OffloadedDataCircuit, }; -use crate::circuits::poly::commitment::commitment_scheme::CommitmentVerifierGadget; -use crate::circuits::transcript::ImplAbsorb; -use crate::field::JoltField; -use crate::poly::commitment::hyperkzg::{ - HyperKZG, HyperKZGCommitment, HyperKZGProof, HyperKZGProverKey, HyperKZGVerifierKey, -}; -use crate::snark::OffloadedDataCircuit; -use ark_crypto_primitives::sponge::constraints::{ - AbsorbGadget, CryptographicSpongeVar, SpongeWithGadget, -}; -use ark_crypto_primitives::sponge::CryptographicSponge; +use ark_crypto_primitives::sponge::constraints::{CryptographicSpongeVar, SpongeWithGadget}; use ark_ec::pairing::Pairing; -use ark_ff::{Field, PrimeField}; -use ark_r1cs_std::boolean::Boolean; -use ark_r1cs_std::fields::fp::FpVar; -use ark_r1cs_std::pairing::PairingVar; -use ark_r1cs_std::prelude::*; -use ark_r1cs_std::ToConstraintFieldGadget; -use ark_relations::ns; -use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, Namespace, SynthesisError}; -use ark_std::borrow::Borrow; -use ark_std::iterable::Iterable; -use ark_std::marker::PhantomData; -use ark_std::One; +use ark_ff::PrimeField; +use ark_r1cs_std::{boolean::Boolean, fields::fp::FpVar, prelude::*, ToConstraintFieldGadget}; +use ark_relations::{ + ns, + r1cs::{ConstraintSynthesizer, ConstraintSystemRef, Namespace, SynthesisError}, +}; +use ark_std::{borrow::Borrow, iterable::Iterable, marker::PhantomData, One}; #[derive(Clone)] pub struct HyperKZGProofVar @@ -300,7 +294,7 @@ fn q_powers>( let q_powers = [vec![FpVar::Constant(E::ScalarField::one()), q.clone()], { let mut q_power = q.clone(); (2..ell) - .map(|i| { + .map(|_i| { q_power *= &q; q_power.clone() }) @@ -313,31 +307,24 @@ fn q_powers>( #[cfg(test)] mod tests { use super::*; - use crate::circuits::groups::curves::short_weierstrass::bn254::G1Var; - use crate::circuits::transcript::mock::{MockSponge, MockSpongeVar}; - use crate::poly::commitment::hyperkzg::{ - HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey, - }; - use crate::poly::commitment::kzg::KZGVerifierKey; - use crate::poly::dense_mlpoly::DensePolynomial; - use crate::snark::{ - DeferredFnsRef, OffloadedDataCircuit, OffloadedPairingDef, OffloadedSNARK, - OffloadedSNARKError, OffloadedSNARKVerifyingKey, + use crate::{ + circuits::{ + groups::curves::short_weierstrass::bn254::G1Var, + transcript::mock::{MockSponge, MockSpongeVar}, + }, + poly::{ + commitment::hyperkzg::{HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey}, + dense_mlpoly::DensePolynomial, + }, + snark::{DeferredFnsRef, OffloadedDataCircuit, OffloadedSNARK}, + utils::{errors::ProofVerifyError, transcript::ProofTranscript}, }; - use crate::utils::errors::ProofVerifyError; - use crate::utils::transcript::ProofTranscript; use ark_bn254::Bn254; - use ark_crypto_primitives::snark::{CircuitSpecificSetupSNARK, SNARK}; - use ark_crypto_primitives::sponge::constraints::CryptographicSpongeVar; - use ark_crypto_primitives::sponge::poseidon::constraints::PoseidonSpongeVar; - use ark_crypto_primitives::sponge::poseidon::{PoseidonConfig, PoseidonDefaultConfigField}; + use 
ark_crypto_primitives::{snark::SNARK, sponge::constraints::CryptographicSpongeVar}; use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; use ark_r1cs_std::ToConstraintFieldGadget; use ark_relations::ns; - use ark_serialize::SerializationError; - use ark_std::rand::Rng; - use ark_std::Zero; - use rand_core::{CryptoRng, RngCore, SeedableRng}; + use rand_core::{RngCore, SeedableRng}; #[derive(Clone)] struct HyperKZGVerifierCircuit @@ -539,7 +526,7 @@ mod tests { // Create a groth16 proof with our parameters. let proof = VerifierSNARK::prove(&cpk, verifier_circuit, &mut rng) - .map_err(|e| ProofVerifyError::InternalError)?; + .map_err(|_e| ProofVerifyError::InternalError)?; let result = VerifierSNARK::verify_with_processed_vk(&cvk, &instance, &proof); match result { diff --git a/jolt-core/src/circuits/transcript/mock.rs b/jolt-core/src/circuits/transcript/mock.rs index 6e4eda286..f63d4f802 100644 --- a/jolt-core/src/circuits/transcript/mock.rs +++ b/jolt-core/src/circuits/transcript/mock.rs @@ -1,18 +1,14 @@ use crate::circuits::transcript::SLICE; use crate::field::JoltField; use crate::utils::transcript::ProofTranscript; -use ark_crypto_primitives::sponge::constraints::{ - AbsorbGadget, CryptographicSpongeVar, SpongeWithGadget, +use ark_crypto_primitives::sponge::{ + constraints::{AbsorbGadget, CryptographicSpongeVar, SpongeWithGadget}, + Absorb, CryptographicSponge, }; -use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_r1cs_std::boolean::Boolean; -use ark_r1cs_std::fields::fp::FpVar; -use ark_r1cs_std::prelude::*; -use ark_r1cs_std::R1CSVar; +use ark_r1cs_std::{boolean::Boolean, fields::fp::FpVar, prelude::*, R1CSVar}; use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError}; -use ark_std::any::Any; -use std::marker::PhantomData; +use ark_std::marker::PhantomData; #[derive(Clone)] pub struct MockSponge @@ -28,21 +24,21 @@ where { type Config = (); - fn new(params: &Self::Config) -> Self { + fn new(_params: &Self::Config) -> Self { Self { _params: PhantomData, } } - fn absorb(&mut self, input: &impl Absorb) { + fn absorb(&mut self, _input: &impl Absorb) { todo!() } - fn squeeze_bytes(&mut self, num_bytes: usize) -> Vec { + fn squeeze_bytes(&mut self, _num_bytes: usize) -> Vec { todo!() } - fn squeeze_bits(&mut self, num_bits: usize) -> Vec { + fn squeeze_bits(&mut self, _num_bits: usize) -> Vec { todo!() } } @@ -68,7 +64,7 @@ impl CryptographicSpongeVar> where ConstraintF: PrimeField + JoltField, { - type Parameters = (&'static [u8]); + type Parameters = &'static [u8]; fn new(cs: ConstraintSystemRef, params: &Self::Parameters) -> Self { Self { @@ -112,14 +108,14 @@ where fn squeeze_bytes( &mut self, - num_bytes: usize, + _num_bytes: usize, ) -> Result>, SynthesisError> { todo!() } fn squeeze_bits( &mut self, - num_bits: usize, + _num_bits: usize, ) -> Result>, SynthesisError> { todo!() } diff --git a/jolt-core/src/circuits/transcript/mod.rs b/jolt-core/src/circuits/transcript/mod.rs index 12f66d886..a744b9b27 100644 --- a/jolt-core/src/circuits/transcript/mod.rs +++ b/jolt-core/src/circuits/transcript/mod.rs @@ -1,17 +1,9 @@ -use crate::utils::transcript::ProofTranscript; use ark_crypto_primitives::sponge::constraints::AbsorbGadget; use ark_ff::PrimeField; -use ark_r1cs_std::fields::fp::FpVar; -use ark_r1cs_std::prelude::*; -use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; -use ark_relations::ns; -use ark_relations::r1cs::SynthesisError; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use 
ark_std::iterable::Iterable; -use ark_std::Zero; -use std::cell::RefCell; -use std::fmt::Debug; -use std::marker::PhantomData; +use ark_r1cs_std::{fields::fp::FpVar, prelude::*, R1CSVar}; +use ark_relations::{ns, r1cs::SynthesisError}; +use ark_serialize::CanonicalSerialize; +use ark_std::{cell::RefCell, fmt::Debug, marker::PhantomData, Zero}; pub mod mock; @@ -49,7 +41,7 @@ where t_value .serialize_compressed(&mut buf) - .map_err(|e| SynthesisError::Unsatisfiable)?; + .map_err(|_e| SynthesisError::Unsatisfiable)?; buf.into_iter() .map(|b| UInt8::new_witness(ns!(self.0.cs(), "sponge_byte"), || Ok(b))) diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index 931e2880b..2ccab909f 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -2,7 +2,7 @@ use ark_crypto_primitives::snark::SNARK; use ark_ec::{ pairing::Pairing, short_weierstrass::{Affine, SWCurveConfig}, - AffineRepr, CurveGroup, VariableBaseMSM, + AffineRepr, VariableBaseMSM, }; use ark_ff::{PrimeField, Zero}; use ark_r1cs_std::{ @@ -14,10 +14,9 @@ use ark_r1cs_std::{ use ark_relations::r1cs; use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError, Valid}; -use ark_std::{cell::OnceCell, cell::RefCell, marker::PhantomData, rc::Rc}; +use ark_std::{cell::OnceCell, cell::RefCell, rc::Rc}; use itertools::Itertools; use rand_core::{CryptoRng, RngCore}; -use std::any::Any; /// Describes G1 elements to be used in a multi-pairing. /// The verifier is responsible for ensuring that the sum of the pairings is zero. @@ -112,7 +111,7 @@ where SynthesisError(#[from] SynthesisError), } -struct WrappedCircuit +pub struct WrappedCircuit where E: Pairing, C: ConstraintSynthesizer + OffloadedDataCircuit, From 6565391ab9aa14042ea5d76c40d94ca0de0d228c Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 20 Aug 2024 17:52:24 -0700 Subject: [PATCH 41/44] Cleanup lints --- .../src/circuits/fields/cubic_extension.rs | 2 +- jolt-core/src/circuits/fields/fp12.rs | 10 +++--- .../circuits/fields/quadratic_extension.rs | 2 +- .../groups/curves/short_weierstrass/mod.rs | 22 +++++------- jolt-core/src/circuits/offloaded/mod.rs | 1 - .../src/circuits/poly/commitment/hyperkzg.rs | 9 +++-- jolt-core/src/circuits/transcript/mod.rs | 2 +- jolt-core/src/snark/mod.rs | 34 +++++++++++-------- 8 files changed, 42 insertions(+), 40 deletions(-) diff --git a/jolt-core/src/circuits/fields/cubic_extension.rs b/jolt-core/src/circuits/fields/cubic_extension.rs index c66572c5e..8c1d86398 100644 --- a/jolt-core/src/circuits/fields/cubic_extension.rs +++ b/jolt-core/src/circuits/fields/cubic_extension.rs @@ -87,7 +87,7 @@ where /// Sets `self = self.mul_by_base_field_constant(fe)`. #[inline] pub fn mul_assign_by_base_field_constant(&mut self, fe: P::BaseField) { - *self = (&*self).mul_by_base_field_constant(fe); + *self = (*self).mul_by_base_field_constant(fe); } } diff --git a/jolt-core/src/circuits/fields/fp12.rs b/jolt-core/src/circuits/fields/fp12.rs index 017deefe1..f46397770 100644 --- a/jolt-core/src/circuits/fields/fp12.rs +++ b/jolt-core/src/circuits/fields/fp12.rs @@ -45,11 +45,11 @@ where c1: &Fp2Var, ConstraintF>, d1: &Fp2Var, ConstraintF>, ) -> Result { - let v0 = self.c0.mul_by_c0_c1_0(&c0, &c1)?; - let v1 = self.c1.mul_by_0_c1_0(&d1)?; + let v0 = self.c0.mul_by_c0_c1_0(c0, c1)?; + let v1 = self.c1.mul_by_0_c1_0(d1)?; let new_c0 = Self::mul_base_field_by_nonresidue(&v1)? 
+ &v0; - let new_c1 = (&self.c0 + &self.c1).mul_by_c0_c1_0(&c0, &(c1 + d1))? - &v0 - &v1; + let new_c1 = (&self.c0 + &self.c1).mul_by_c0_c1_0(c0, &(c1 + d1))? - &v0 - &v1; Ok(Self::new(new_c0, new_c1)) } @@ -66,11 +66,11 @@ where let a1 = &self.c0.c1 * c0; let a2 = &self.c0.c2 * c0; let a = Fp6Var::new(a0, a1, a2); - let b = self.c1.mul_by_c0_c1_0(&d0, &d1)?; + let b = self.c1.mul_by_c0_c1_0(d0, d1)?; let c0 = c0 + d0; let c1 = d1; - let e = (&self.c0 + &self.c1).mul_by_c0_c1_0(&c0, &c1)?; + let e = (&self.c0 + &self.c1).mul_by_c0_c1_0(&c0, c1)?; let new_c1 = e - (&a + &b); let new_c0 = Self::mul_base_field_by_nonresidue(&b)? + &a; diff --git a/jolt-core/src/circuits/fields/quadratic_extension.rs b/jolt-core/src/circuits/fields/quadratic_extension.rs index 84c4028cb..bcb1e3749 100644 --- a/jolt-core/src/circuits/fields/quadratic_extension.rs +++ b/jolt-core/src/circuits/fields/quadratic_extension.rs @@ -82,7 +82,7 @@ where /// Sets `self = self.mul_by_base_field_constant(fe)`. #[inline] pub fn mul_assign_by_base_field_constant(&mut self, fe: P::BaseField) { - *self = (&*self).mul_by_base_field_constant(fe); + *self = (*self).mul_by_base_field_constant(fe); } /// This is only to be used when the element is *known* to be in the diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs index b3bfd59ef..e0ae386cf 100644 --- a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs @@ -339,10 +339,10 @@ where for bit in affine_bits.iter().skip(1) { if bit.is_constant() { if *bit == &Boolean::TRUE { - accumulator = accumulator.add_unchecked(&multiple_of_power_of_two)?; + accumulator = accumulator.add_unchecked(multiple_of_power_of_two)?; } } else { - let temp = accumulator.add_unchecked(&multiple_of_power_of_two)?; + let temp = accumulator.add_unchecked(multiple_of_power_of_two)?; accumulator = bit.select(&temp, &accumulator)?; } multiple_of_power_of_two.double_in_place()?; @@ -365,7 +365,7 @@ where } } else { let temp = &*mul_result + &multiple_of_power_of_two.into_projective(); - *mul_result = bit.select(&temp, &mul_result)?; + *mul_result = bit.select(&temp, mul_result)?; } multiple_of_power_of_two.double_in_place()?; } @@ -507,10 +507,8 @@ where &self, bits: impl Iterator>, ) -> Result { - if self.is_constant() { - if self.value().unwrap().is_zero() { - return Ok(self.clone()); - } + if self.is_constant() && self.value().unwrap().is_zero() { + return Ok(self.clone()); } let self_affine = self.to_affine()?; let (x, y, infinity) = (self_affine.x, self_affine.y, self_affine.infinity); @@ -519,7 +517,7 @@ where let non_zero_self = NonZeroAffineVar::new(x, y); let mut bits = bits.collect::>(); - if bits.len() == 0 { + if bits.is_empty() { return Ok(Self::zero()); } // Remove unnecessary constant zeros in the most-significant positions. @@ -528,7 +526,7 @@ where // We iterate from the MSB down. .rev() // Skip leading zeros, if they are constants. - .skip_while(|b| b.is_constant() && (b.value().unwrap() == false)) + .skip_while(|b| b.is_constant() && (!b.value().unwrap())) .collect(); // After collecting we are in big-endian form; we have to reverse to get back to // little-endian. 
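For orientation, the `scalar_mul_le` gadget touched above follows ordinary little-endian double-and-add, and the constant leading zero bits it now skips are exactly the iterations that would add nothing. A plain-Rust sketch of that loop, assuming the arkworks 0.4 API (not part of the patch):

```rust
use ark_bls12_381::{Fr, G1Projective};
use ark_ec::Group;
use ark_ff::{BigInteger, PrimeField};
use ark_std::{test_rng, UniformRand, Zero};

// Double-and-add over the little-endian bits of the scalar.
fn scalar_mul_le(base: G1Projective, scalar: Fr) -> G1Projective {
    let mut acc = G1Projective::zero();
    let mut power = base; // base · 2^i at iteration i
    for bit in scalar.into_bigint().to_bits_le() {
        if bit {
            acc += power;
        }
        power.double_in_place();
    }
    acc
}

fn main() {
    let mut rng = test_rng();
    let s = Fr::rand(&mut rng);
    let g = G1Projective::rand(&mut rng);
    assert_eq!(scalar_mul_le(g, s), g * s);
}
```

Inside the gadget the same accumulation runs over `NonZeroAffineVar`s with explicit zero and constant handling, so the code above is a reference for the arithmetic only, not the constraint logic.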
@@ -791,11 +789,7 @@ where f: impl FnOnce() -> Result, mode: AllocationMode, ) -> Result { - Self::new_variable( - cs, - || f().map(|b| SWProjective::from((*b.borrow()).clone())), - mode, - ) + Self::new_variable(cs, || f().map(|b| SWProjective::from(*b.borrow())), mode) } } diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs index a3c9770d4..c44647897 100644 --- a/jolt-core/src/circuits/offloaded/mod.rs +++ b/jolt-core/src/circuits/offloaded/mod.rs @@ -213,7 +213,6 @@ where } { - let g1_values_opt = g1_values_opt; let g2_values = g2_values.to_vec(); let g1s = g1s.to_vec(); let ns = ns!(cs, "deferred_pairing"); diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 7ecfdf1bd..5b710b847 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -16,7 +16,7 @@ use ark_ff::PrimeField; use ark_r1cs_std::{boolean::Boolean, fields::fp::FpVar, prelude::*, ToConstraintFieldGadget}; use ark_relations::{ ns, - r1cs::{ConstraintSynthesizer, ConstraintSystemRef, Namespace, SynthesisError}, + r1cs::{Namespace, SynthesisError}, }; use ark_std::{borrow::Borrow, iterable::Iterable, marker::PhantomData, One}; @@ -185,7 +185,7 @@ where .next() .unwrap(); - let u = vec![r.clone(), r.negate()?, r.clone() * &r]; + let u = [r.clone(), r.negate()?, r.clone() * &r]; let com = [vec![c.clone()], com.clone()].concat(); @@ -323,7 +323,10 @@ mod tests { use ark_crypto_primitives::{snark::SNARK, sponge::constraints::CryptographicSpongeVar}; use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; use ark_r1cs_std::ToConstraintFieldGadget; - use ark_relations::ns; + use ark_relations::{ + ns, + r1cs::{ConstraintSynthesizer, ConstraintSystemRef}, + }; use rand_core::{RngCore, SeedableRng}; #[derive(Clone)] diff --git a/jolt-core/src/circuits/transcript/mod.rs b/jolt-core/src/circuits/transcript/mod.rs index a744b9b27..92f4e0a38 100644 --- a/jolt-core/src/circuits/transcript/mod.rs +++ b/jolt-core/src/circuits/transcript/mod.rs @@ -23,7 +23,7 @@ where } thread_local! { - static SLICE: RefCell> = RefCell::new(None); + static SLICE: RefCell> = const { RefCell::new(None) }; } impl<'a, T, F> AbsorbGadget for ImplAbsorb<'a, T, F> diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs index 2ccab909f..46621641b 100644 --- a/jolt-core/src/snark/mod.rs +++ b/jolt-core/src/snark/mod.rs @@ -61,7 +61,7 @@ pub struct ProofData { /// It's the verifiers responsibility to ensure that the sum is zero. /// The scalar at index `length-1` is, by convention, always `-1`, so /// we save one public input element per MSM. - msms: Vec<(Vec, Vec)>, + msms: Vec>, /// Blocks of G1 elements `Gᵢ` in the public input, used in multi-pairings with /// the corresponding G2 elements in the offloaded SNARK verification key. /// It's the verifiers responsibility to ensure that the sum is zero. 
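The `msms` convention documented in this hunk (the prover appends the MSM result with scalar `-1`, so every block sums to the identity and the result costs no extra scalar in the public input) can be exercised natively. A small sketch over BN254 with arkworks 0.4; the variable names are illustrative rather than part of this crate:

```rust
use ark_bn254::{Fr, G1Projective};
use ark_ec::{CurveGroup, VariableBaseMSM};
use ark_ff::One;
use ark_std::{test_rng, UniformRand, Zero};

fn main() {
    let mut rng = test_rng();
    let g1s: Vec<G1Projective> = (0..4).map(|_| G1Projective::rand(&mut rng)).collect();
    let scalars: Vec<Fr> = (0..4).map(|_| Fr::rand(&mut rng)).collect();
    let bases = G1Projective::normalize_batch(&g1s);

    // Prover side: compute R = Σ sᵢ·Gᵢ and append (R, -1) to the block.
    let r = G1Projective::msm_unchecked(&bases, &scalars);
    let block_g1s = [bases, vec![r.into_affine()]].concat();
    let block_scalars = [scalars, vec![-Fr::one()]].concat();

    // Verifier side: the whole block must sum to the identity.
    assert!(G1Projective::msm_unchecked(&block_g1s, &block_scalars).is_zero());
}
```

This is the same check `OffloadedSNARK::verify` performs with `E::G1::msm_unchecked` on each block in `proof.offloaded_data.msms`.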
@@ -77,10 +77,17 @@ pub struct OffloadedData { } pub enum DeferredOpData { - MSM(Option<(Vec, Vec)>), + MSM(Option>), Pairing(Option>, Vec), } +pub type MSMDef = ( + Vec<::G1Affine>, + Vec<::ScalarField>, +); + +pub type MultiPairingDef = (Vec<::G1>, Vec<::G2>); + pub type DeferredFn = dyn FnOnce() -> Result, SynthesisError>; pub type DeferredFnsRef = Rc>>>>; @@ -217,10 +224,10 @@ where { let offloaded_data_ref = circuit.offloaded_data_ref.clone(); - let (pk, snark_vk) = S::circuit_specific_setup(circuit, rng) - .map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let (pk, snark_vk) = + S::circuit_specific_setup(circuit, rng).map_err(OffloadedSNARKError::SNARKError)?; - let snark_pvk = S::process_vk(&snark_vk).map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let snark_pvk = S::process_vk(&snark_vk).map_err(OffloadedSNARKError::SNARKError)?; let setup_data = offloaded_data_ref.get().unwrap().clone().setup_data; @@ -252,8 +259,7 @@ where let offloaded_data_ref = circuit.offloaded_data_ref.clone(); - let proof = - S::prove(circuit_pk, circuit, rng).map_err(|e| OffloadedSNARKError::SNARKError(e))?; + let proof = S::prove(circuit_pk, circuit, rng).map_err(OffloadedSNARKError::SNARKError)?; let proof_data = match offloaded_data_ref.get().unwrap().clone().proof_data { Some(proof_data) => proof_data, @@ -282,14 +288,14 @@ where let public_input = build_public_input::(public_input, &proof.offloaded_data); let r = S::verify_with_processed_vk(&vk.snark_pvk, &public_input, &proof.snark_proof) - .map_err(|e| OffloadedSNARKError::SNARKError(e))?; + .map_err(OffloadedSNARKError::SNARKError)?; if !r { return Ok(false); } for (g1s, scalars) in &proof.offloaded_data.msms { assert_eq!(g1s.len(), scalars.len()); - let r = E::G1::msm_unchecked(&g1s, &scalars); + let r = E::G1::msm_unchecked(g1s, scalars); if !r.is_zero() { return Ok(false); } @@ -325,11 +331,11 @@ where fn pairing_inputs( vk: &OffloadedSNARKVerifyingKey, - g1_vectors: &Vec>, - ) -> Result, Vec)>, SerializationError> { + g1_vectors: &[Vec], + ) -> Result>, SerializationError> { Ok(g1_vectors - .into_iter() - .map(|g1_vec| g1_vec.into_iter().map(|&g1| g1.into()).collect()) + .iter() + .map(|g1_vec| g1_vec.iter().map(|&g1| g1.into()).collect()) .zip(Self::g2_elements(vk)) .collect()) } @@ -419,7 +425,7 @@ where [public_input.to_vec(), msm_data, pairing_data].concat() } -fn to_scalars(g1s: &Vec) -> Vec +fn to_scalars(g1s: &[E::G1Affine]) -> Vec where E: Pairing, G1Var: CurveVar + ToConstraintFieldGadget, From 2904552aa0a584203911bdbb2d80bdd07efb2ccb Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Tue, 20 Aug 2024 20:27:47 -0700 Subject: [PATCH 42/44] Remove commented out code --- jolt-core/src/circuits/poly/commitment/hyperkzg.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs index 5b710b847..3b4696fd2 100644 --- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -454,12 +454,6 @@ mod tests { G1Var: CurveVar + ToConstraintFieldGadget, { type Circuit = HyperKZGVerifierCircuit; - - // fn pairing_setup(circuit: Self::Circuit) -> Vec> { - // let KZGVerifierKey { g1, g2, beta_g2 } = circuit.pcs_pk_vk.1.kzg_vk; - // - // vec![vec![g2, beta_g2]] - // } } #[test] From c9aa0e98f951dbb023cbf96d4f53ca4529562f15 Mon Sep 17 00:00:00 2001 From: Ivan Mikushin Date: Wed, 21 Aug 2024 16:11:18 -0700 Subject: [PATCH 43/44] Port transcript changes --- 
 .../src/circuits/poly/commitment/hyperkzg.rs | 14 +++-
 jolt-core/src/circuits/transcript/mod.rs     | 83 +++++++++++++++++--
 jolt-core/src/snark/mod.rs                   | 14 +---
 3 files changed, 89 insertions(+), 22 deletions(-)

diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs
index 3b4696fd2..aa336e535 100644
--- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs
+++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs
@@ -1,8 +1,9 @@
+use crate::circuits::transcript::ImplAbsorbGVar;
 use crate::{
     circuits::{
         offloaded::{MSMGadget, OffloadedMSMGadget, OffloadedPairingGadget, PairingGadget},
         poly::commitment::commitment_scheme::CommitmentVerifierGadget,
-        transcript::ImplAbsorb,
+        transcript::ImplAbsorbFVar,
     },
     field::JoltField,
     poly::commitment::hyperkzg::{
@@ -175,7 +176,7 @@ where
 
         transcript.absorb(
             &com.iter()
-                .map(|com| ImplAbsorb::wrap(com))
+                .map(|com| ImplAbsorbGVar::wrap(com))
                 .collect::<Vec<_>>(),
         )?;
 
@@ -216,12 +217,16 @@ where
         transcript.absorb(
             &v.iter()
                 .flatten()
-                .map(|v_ij| ImplAbsorb::wrap(v_ij))
+                .map(|v_ij| ImplAbsorbFVar::wrap(v_ij))
                 .collect::<Vec<_>>(),
         )?;
 
         let q_powers = q_powers::(transcript, ell)?;
-        transcript.absorb(&w.iter().map(|g| ImplAbsorb::wrap(g)).collect::<Vec<_>>())?;
+        transcript.absorb(
+            &w.iter()
+                .map(|g| ImplAbsorbGVar::wrap(g))
+                .collect::<Vec<_>>(),
+        )?;
         let d = transcript
             .squeeze_field_elements(1)?
             .into_iter()
@@ -276,6 +281,7 @@ where
             &[l_g1, r_g1],
             self.g2_elements.as_slice(),
         )?;
+        dbg!();
 
         Ok(Boolean::TRUE)
     }
diff --git a/jolt-core/src/circuits/transcript/mod.rs b/jolt-core/src/circuits/transcript/mod.rs
index 92f4e0a38..7f4922bef 100644
--- a/jolt-core/src/circuits/transcript/mod.rs
+++ b/jolt-core/src/circuits/transcript/mod.rs
@@ -1,5 +1,6 @@
 use ark_crypto_primitives::sponge::constraints::AbsorbGadget;
-use ark_ff::PrimeField;
+use ark_ec::{AffineRepr, CurveGroup};
+use ark_ff::{Field, PrimeField};
 use ark_r1cs_std::{fields::fp::FpVar, prelude::*, R1CSVar};
 use ark_relations::{ns, r1cs::SynthesisError};
 use ark_serialize::CanonicalSerialize;
@@ -7,12 +8,12 @@ use ark_std::{cell::RefCell, fmt::Debug, marker::PhantomData, Zero};
 
 pub mod mock;
 
-pub struct ImplAbsorb<'a, T, F>(&'a T, PhantomData<F>)
+pub struct ImplAbsorbFVar<'a, T, F>(&'a T, PhantomData<F>)
 where
     T: R1CSVar<F>,
     F: PrimeField;
 
-impl<'a, T, F> ImplAbsorb<'a, T, F>
+impl<'a, T, F> ImplAbsorbFVar<'a, T, F>
 where
     T: R1CSVar<F>,
     F: PrimeField,
@@ -26,7 +27,7 @@ thread_local! {
     static SLICE: RefCell<Option<usize>> = const { RefCell::new(None) };
 }
 
-impl<'a, T, F> AbsorbGadget<F> for ImplAbsorb<'a, T, F>
+impl<'a, T, F> AbsorbGadget<F> for ImplAbsorbFVar<'a, T, F>
 where
     T: R1CSVar<F> + Debug,
     F: PrimeField,
@@ -40,10 +41,82 @@ where
         };
 
         t_value
-            .serialize_compressed(&mut buf)
+            .serialize_uncompressed(&mut buf)
             .map_err(|_e| SynthesisError::Unsatisfiable)?;
 
         buf.into_iter()
+            .rev()
+            .map(|b| UInt8::new_witness(ns!(self.0.cs(), "sponge_byte"), || Ok(b)))
+            .collect::<Result<Vec<_>, _>>()
+    }
+
+    fn batch_to_sponge_bytes(batch: &[Self]) -> Result<Vec<UInt8<F>>, SynthesisError>
+    where
+        Self: Sized,
+    {
+        SLICE.set(Some(batch.len()));
+        let mut result = Vec::new();
+        for item in batch {
+            result.append(&mut (item.to_sponge_bytes()?))
+        }
+        Ok(result)
+    }
+
+    fn to_sponge_field_elements(&self) -> Result<Vec<FpVar<F>>, SynthesisError> {
+        unimplemented!("should not be called")
+    }
+}
+
+pub struct ImplAbsorbGVar<'a, T, F, G>(&'a T, PhantomData<(F, G)>)
+where
+    T: CurveVar<G, F>,
+    F: PrimeField,
+    G: CurveGroup;
+
+impl<'a, T, F, G> ImplAbsorbGVar<'a, T, F, G>
+where
+    T: CurveVar<G, F>,
+    F: PrimeField,
+    G: CurveGroup,
+{
+    pub fn wrap(t: &'a T) -> Self {
+        Self(t, PhantomData)
+    }
+}
+
+impl<'a, T, F, G> AbsorbGadget<F> for ImplAbsorbGVar<'a, T, F, G>
+where
+    T: CurveVar<G, F> + Debug,
+    F: PrimeField,
+    G: CurveGroup,
+{
+    fn to_sponge_bytes(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
+        let g = match self.0.cs().is_in_setup_mode() {
+            true => T::Value::zero(),
+            false => self.0.value()?,
+        }
+        .into_affine();
+
+        if g.is_zero() {
+            return Ok(vec![UInt8::constant(0); 64]);
+        }
+
+        let (x, y) = g.xy().unwrap();
+
+        fn serialize<F: Field>(x: &F) -> Vec<u8> {
+            let mut buf = vec![];
+            x.serialize_compressed(&mut buf)
+                .expect("failed to serialize uncompressed");
+            buf.reverse();
+            buf
+        }
+
+        let x_buf = serialize(x);
+        let y_buf = serialize(y);
+
+        [x_buf, y_buf]
+            .iter()
+            .flatten()
             .map(|b| UInt8::new_witness(ns!(self.0.cs(), "sponge_byte"), || Ok(b)))
             .collect::<Result<Vec<_>, _>>()
     }
diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs
index 46621641b..6f248bc0c 100644
--- a/jolt-core/src/snark/mod.rs
+++ b/jolt-core/src/snark/mod.rs
@@ -18,20 +18,11 @@ use ark_std::{cell::OnceCell, cell::RefCell, rc::Rc};
 use itertools::Itertools;
 use rand_core::{CryptoRng, RngCore};
 
-/// Describes G1 elements to be used in a multi-pairing.
-/// The verifier is responsible for ensuring that the sum of the pairings is zero.
-/// The verifier needs to use appropriate G2 elements from the verification key or the proof
-/// (depending on the protocol).
 #[derive(Clone, CanonicalSerialize, CanonicalDeserialize)]
 pub struct OffloadedPairingDef<E>
 where
     E: Pairing,
 {
-    /// Offsets of the G1 elements in the public input. The G1 elements are stored as sequences of scalar field elements
-    /// encoding the compressed coordinates of the G1 points (which would natively be numbers in the base field).
-    /// The offsets are in the number of scalar field elements in the public input before the G1 element.
-    /// The last element, by convention, is always used in the multi-pairing computation with coefficient `-1`.
-    pub g1_offsets: Vec<usize>,
     pub g2_elements: Vec<E::G2Affine>,
 }
 
@@ -233,10 +224,7 @@ where
 
         let delayed_pairings = setup_data
             .into_iter()
-            .map(|g2| OffloadedPairingDef {
-                g1_offsets: vec![],
-                g2_elements: g2,
-            })
+            .map(|g2| OffloadedPairingDef { g2_elements: g2 })
             .collect();
 
         let vk = OffloadedSNARKVerifyingKey {

From 8bb1a29dddd5236b9c3846812f5c14e2b0c7d21c Mon Sep 17 00:00:00 2001
From: Ivan Mikushin
Date: Wed, 21 Aug 2024 16:57:44 -0700
Subject: [PATCH 44/44] Port transcript changes: make it work

---
 .../src/circuits/poly/commitment/hyperkzg.rs | 18 ++++++++++++++----
 jolt-core/src/circuits/transcript/mod.rs     | 19 ++++++++-----------
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs
index aa336e535..924c0ca4e 100644
--- a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs
+++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs
@@ -15,6 +15,7 @@ use ark_crypto_primitives::sponge::constraints::{CryptographicSpongeVar, SpongeW
 use ark_ec::pairing::Pairing;
 use ark_ff::PrimeField;
 use ark_r1cs_std::{boolean::Boolean, fields::fp::FpVar, prelude::*, ToConstraintFieldGadget};
+use ark_relations::r1cs::ConstraintSystemRef;
 use ark_relations::{
     ns,
     r1cs::{Namespace, SynthesisError},
@@ -126,6 +127,7 @@ where
 {
     _params: PhantomData<(E, S, G1Var)>,
     circuit: &'a Circuit,
+    cs: ConstraintSystemRef<E::ScalarField>,
     g2_elements: Vec<E::G2Affine>,
 }
 
@@ -136,10 +138,17 @@ where
     G1Var: CurveVar<E::G1, E::ScalarField> + ToConstraintFieldGadget<E::ScalarField>,
     Circuit: OffloadedDataCircuit<E>,
 {
-    pub fn new(circuit: &'a Circuit, g2_elements: Vec<E::G2Affine>) -> Self {
+    pub fn new(
+        circuit: &'a Circuit,
+        cs: impl Into<Namespace<E::ScalarField>>,
+        g2_elements: Vec<E::G2Affine>,
+    ) -> Self {
+        let ns = cs.into();
+        let cs: ConstraintSystemRef<E::ScalarField> = ns.cs();
         Self {
             _params: PhantomData,
             circuit,
+            cs,
             g2_elements,
         }
     }
@@ -268,16 +277,16 @@ where
             .concat();
         debug_assert_eq!(l_g1s.len(), l_scalars.len());
 
-        let l_g1 = msm_gadget.msm(ns!(transcript.cs(), "l_g1"), l_g1s, l_scalars)?;
+        let l_g1 = msm_gadget.msm(ns!(self.cs, "l_g1"), l_g1s, l_scalars)?;
 
         let r_g1s = w.as_slice();
         let r_scalars = &[FpVar::one().negate()?, d.negate()?, d_square.negate()?];
         debug_assert_eq!(r_g1s.len(), r_scalars.len());
 
-        let r_g1 = msm_gadget.msm(ns!(transcript.cs(), "r_g1"), r_g1s, r_scalars)?;
+        let r_g1 = msm_gadget.msm(ns!(self.cs, "r_g1"), r_g1s, r_scalars)?;
 
         pairing_gadget.multi_pairing_is_zero(
-            ns!(transcript.cs(), "multi_pairing"),
+            ns!(self.cs, "multi_pairing"),
             &[l_g1, r_g1],
             self.g2_elements.as_slice(),
         )?;
@@ -419,6 +428,7 @@ mod tests {
 
             let hyper_kzg = HyperKZGVerifierGadget::, G1Var, Self>::new(
                 &self,
+                ns!(cs, "hyperkzg"),
                 vec![kzg_vk.g2, kzg_vk.beta_g2],
             );
 
diff --git a/jolt-core/src/circuits/transcript/mod.rs b/jolt-core/src/circuits/transcript/mod.rs
index 7f4922bef..831420c36 100644
--- a/jolt-core/src/circuits/transcript/mod.rs
+++ b/jolt-core/src/circuits/transcript/mod.rs
@@ -97,12 +97,6 @@ where
         }
         .into_affine();
 
-        if g.is_zero() {
-            return Ok(vec![UInt8::constant(0); 64]);
-        }
-
-        let (x, y) = g.xy().unwrap();
-
         fn serialize<F: Field>(x: &F) -> Vec<u8> {
             let mut buf = vec![];
             x.serialize_compressed(&mut buf)
@@ -111,12 +105,15 @@ where
             buf
         }
 
-        let x_buf = serialize(x);
-        let y_buf = serialize(y);
+        let buf = match g.is_zero() {
+            true => vec![0u8; 64],
+            false => {
+                let (x, y) = g.xy().unwrap();
+                [serialize(x), serialize(y)].concat()
+            }
+        };
 
-        [x_buf, y_buf]
-            .iter()
-            .flatten()
+        buf.iter()
             .map(|b| UInt8::new_witness(ns!(self.0.cs(), "sponge_byte"), || Ok(b)))
             .collect::<Result<Vec<_>, _>>()
     }