diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index d4aa2f3ce..1d2a4c9b1 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -20,15 +20,20 @@ license-file = "LICENSE" keywords = ["SNARK", "cryptography", "proofs"] [dependencies] +ark-bls12-381 = "0.4.0" ark-bn254 = "0.4.0" +ark-crypto-primitives = { version = "0.4.0", default-features = false, features = ["snark", "sponge", "r1cs"] } ark-ec = { version = "0.4.2", default-features = false } ark-ff = { version = "0.4.2", default-features = false } +ark-r1cs-std = { version = "0.4.0" } +ark-relations = { version = "0.4.0", default-features = false } ark-serialize = { version = "0.4.2", default-features = false, features = [ "derive", ] } ark-std = { version = "0.4.0" } binius-field = { git = "https://gitlab.com/UlvetannaOSS/binius", package = "binius_field"} clap = { version = "4.3.10", features = ["derive"] } +derivative = { version = "2" } enum_dispatch = "0.3.12" fixedbitset = "0.5.0" itertools = "0.10.0" @@ -75,8 +80,11 @@ alloy-sol-macro = "0.7.6" alloy-sol-types = "0.7.6" [dev-dependencies] +ark-groth16 = { version = "0.4.0" } +ark-test-curves = { version = "0.4.0", default-features = false, features = ["bls12_381_curve", "mnt6_753"] } criterion = { version = "0.5.1", features = ["html_reports"] } iai-callgrind = "0.10.2" +#sigma0-polymath = { git = "https://github.com/sigma0-xyz/polymath" } [build-dependencies] common = { path = "../common" } @@ -99,6 +107,7 @@ default = [ "rayon", ] host = ["dep:reqwest", "dep:tokio"] +print-trace = [ "ark-std/print-trace" ] [target.'cfg(not(target_arch = "wasm32"))'.dependencies] memory-stats = "1.0.0" diff --git a/jolt-core/src/circuits/fields/cubic_extension.rs b/jolt-core/src/circuits/fields/cubic_extension.rs new file mode 100644 index 000000000..8c1d86398 --- /dev/null +++ b/jolt-core/src/circuits/fields/cubic_extension.rs @@ -0,0 +1,604 @@ +use ark_ff::{ + fields::{CubicExtField, Field}, + CubicExtConfig, PrimeField, Zero, +}; +use ark_r1cs_std::{ + fields::{fp::FpVar, FieldOpsBounds, FieldVar}, + impl_bounded_ops, + prelude::*, + ToConstraintFieldGadget, +}; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; +use ark_std::vec::Vec; +use core::{borrow::Borrow, marker::PhantomData}; +use derivative::Derivative; + +/// This struct is the `R1CS` equivalent of the cubic extension field type +/// in `ark-ff`, i.e. `ark_ff::CubicExtField`. +#[derive(Derivative)] +#[derivative(Debug(bound = "BF: core::fmt::Debug"), Clone(bound = "BF: Clone"))] +#[must_use] +pub struct CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + /// The zero-th coefficient of this field element. + pub c0: BF, + /// The first coefficient of this field element. + pub c1: BF, + /// The second coefficient of this field element. + pub c2: BF, + #[derivative(Debug = "ignore")] + _params: PhantomData<(P, ConstraintF)>, +} + +/// This trait describes parameters that are used to implement arithmetic for +/// `CubicExtVar`. +pub trait CubicExtVarConfig: CubicExtConfig +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'a> &'a BF: FieldOpsBounds<'a, Self::BaseField, BF>, +{ + /// Multiply the base field of the `CubicExtVar` by the appropriate + /// Frobenius coefficient. This is equivalent to + /// `Self::mul_base_field_by_frob_coeff(c1, c2, power)`. 
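+    ///
+    /// As a sketch (mirroring the `Fp6ConfigWrapper` implementation added later in
+    /// this diff), a config whose Frobenius coefficients live in constant tables can
+    /// simply scale the coefficient variables in place:
+    ///
+    /// ```ignore
+    /// fn mul_base_field_vars_by_frob_coeff(c1: &mut BF, c2: &mut BF, power: usize) {
+    ///     *c1 *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD];
+    ///     *c2 *= Self::FROBENIUS_COEFF_C2[power % Self::DEGREE_OVER_BASE_PRIME_FIELD];
+    /// }
+    /// ```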
+ fn mul_base_field_vars_by_frob_coeff(c1: &mut BF, c2: &mut BF, power: usize); +} + +impl CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + /// Constructs a `CubicExtVar` from the underlying coefficients. + #[inline] + pub fn new(c0: BF, c1: BF, c2: BF) -> Self { + let _params = PhantomData; + Self { + c0, + c1, + c2, + _params, + } + } + + /// Multiplies a variable of the base field by the cubic nonresidue + /// `P::NONRESIDUE` that is used to construct the extension field. + #[inline] + pub fn mul_base_field_by_nonresidue(fe: &BF) -> Result { + Ok(fe * P::NONRESIDUE) + } + + /// Multiplies `self` by a constant from the base field. + #[inline] + pub fn mul_by_base_field_constant(&self, fe: P::BaseField) -> Self { + let c0 = &self.c0 * fe; + let c1 = &self.c1 * fe; + let c2 = &self.c2 * fe; + Self::new(c0, c1, c2) + } + + /// Sets `self = self.mul_by_base_field_constant(fe)`. + #[inline] + pub fn mul_assign_by_base_field_constant(&mut self, fe: P::BaseField) { + *self = (*self).mul_by_base_field_constant(fe); + } +} + +impl R1CSVar for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, + P: CubicExtVarConfig, +{ + type Value = CubicExtField
<P>
; + + fn cs(&self) -> ConstraintSystemRef { + [&self.c0, &self.c1, &self.c2].cs() + } + + #[inline] + fn value(&self) -> Result { + match (self.c0.value(), self.c1.value(), self.c2.value()) { + (Ok(c0), Ok(c1), Ok(c2)) => Ok(CubicExtField::new(c0, c1, c2)), + (..) => Err(SynthesisError::AssignmentMissing), + } + } +} + +impl From> for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn from(other: Boolean) -> Self { + let c0 = BF::from(other); + let c1 = BF::zero(); + let c2 = BF::zero(); + Self::new(c0, c1, c2) + } +} + +impl<'a, BF, ConstraintF, P> FieldOpsBounds<'a, CubicExtField
<P>
, CubicExtVar> + for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ +} +impl<'a, BF, ConstraintF, P> FieldOpsBounds<'a, CubicExtField
<P>
, CubicExtVar> + for &'a CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ +} + +impl FieldVar, ConstraintF> for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn zero() -> Self { + let c0 = BF::zero(); + let c1 = BF::zero(); + let c2 = BF::zero(); + Self::new(c0, c1, c2) + } + + fn one() -> Self { + let c0 = BF::one(); + let c1 = BF::zero(); + let c2 = BF::zero(); + Self::new(c0, c1, c2) + } + + fn constant(other: CubicExtField
<P>
) -> Self { + let c0 = BF::constant(other.c0); + let c1 = BF::constant(other.c1); + let c2 = BF::constant(other.c2); + Self::new(c0, c1, c2) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn double(&self) -> Result { + let c0 = self.c0.double()?; + let c1 = self.c1.double()?; + let c2 = self.c2.double()?; + Ok(Self::new(c0, c1, c2)) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn negate(&self) -> Result { + let mut result = self.clone(); + result.c0.negate_in_place()?; + result.c1.negate_in_place()?; + result.c2.negate_in_place()?; + Ok(result) + } + + /// Use the Chung-Hasan asymmetric squaring formula. + /// + /// (Devegili OhEig Scott Dahab --- Multiplication and Squaring on + /// Abstract Pairing-Friendly + /// Fields.pdf; Section 4 (CH-SQR2)) + #[inline] + #[tracing::instrument(target = "r1cs")] + fn square(&self) -> Result { + let a = self.c0.clone(); + let b = self.c1.clone(); + let c = self.c2.clone(); + + let s0 = a.square()?; + let ab = &a * &b; + let s1 = ab.double()?; + let s2 = (&a - &b + &c).square()?; + let s3 = (&b * &c).double()?; + let s4 = c.square()?; + + let c0 = Self::mul_base_field_by_nonresidue(&s3)? + &s0; + let c1 = Self::mul_base_field_by_nonresidue(&s4)? + &s1; + let c2 = s1 + &s2 + &s3 - &s0 - &s4; + + Ok(Self::new(c0, c1, c2)) + } + + #[tracing::instrument(target = "r1cs")] + fn mul_equals(&self, other: &Self, result: &Self) -> Result<(), SynthesisError> { + // Karatsuba multiplication for cubic extensions: + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // v2 = A.c2 * B.c2 + // result.c0 = v0 + β((a1 + a2)(b1 + b2) − v1 − v2) + // result.c1 = (a0 + a1)(b0 + b1) − v0 − v1 + βv2 + // result.c2 = (a0 + a2)(b0 + b2) − v0 + v1 − v2, + // We enforce this with six constraints: + // + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // v2 = A.c2 * B.c2 + // + // result.c0 - v0 + \beta*(v1 + v2) = β(a1 + a2)(b1 + b2)) + // result.c1 + v0 + v1 - βv2 = (a0 + a1)(b0 + b1) + // result.c2 + v0 - v1 + v2 = (a0 + a2)(b0 + b2) + // Reference: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + // + // This implementation adapted from + // https://github.com/ZencashOfficial/ginger-lib/blob/development/r1cs/gadgets/std/src/fields/fp3.rs + let v0 = &self.c0 * &other.c0; + let v1 = &self.c1 * &other.c1; + let v2 = &self.c2 * &other.c2; + + // Check c0 + let nr_a1_plus_a2 = (&self.c1 + &self.c2) * P::NONRESIDUE; + let b1_plus_b2 = &other.c1 + &other.c2; + let nr_v1 = &v1 * P::NONRESIDUE; + let nr_v2 = &v2 * P::NONRESIDUE; + let to_check = &result.c0 - &v0 + &nr_v1 + &nr_v2; + nr_a1_plus_a2.mul_equals(&b1_plus_b2, &to_check)?; + + // Check c1 + let a0_plus_a1 = &self.c0 + &self.c1; + let b0_plus_b1 = &other.c0 + &other.c1; + let to_check = &result.c1 - &nr_v2 + &v0 + &v1; + a0_plus_a1.mul_equals(&b0_plus_b1, &to_check)?; + + // Check c2 + let a0_plus_a2 = &self.c0 + &self.c2; + let b0_plus_b2 = &other.c0 + &other.c2; + let to_check = &result.c2 + &v0 - &v1 + &v2; + a0_plus_a2.mul_equals(&b0_plus_b2, &to_check)?; + Ok(()) + } + + #[tracing::instrument(target = "r1cs")] + fn inverse(&self) -> Result { + let mode = if self.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let inverse = Self::new_variable( + self.cs(), + || { + self.value() + .map(|f| f.inverse().unwrap_or_else(CubicExtField::zero)) + }, + mode, + )?; + self.mul_equals(&inverse, &Self::one())?; + Ok(inverse) + } + + #[tracing::instrument(target = "r1cs")] + fn frobenius_map(&self, power: usize) -> Result 
{ + let mut result = self.clone(); + result.c0.frobenius_map_in_place(power)?; + result.c1.frobenius_map_in_place(power)?; + result.c2.frobenius_map_in_place(power)?; + + P::mul_base_field_vars_by_frob_coeff(&mut result.c1, &mut result.c2, power); + Ok(result) + } +} + +impl_bounded_ops!( + CubicExtVar, + CubicExtField
<P>
, + Add, + add, + AddAssign, + add_assign, + |this: &'a CubicExtVar, other: &'a CubicExtVar| { + let c0 = &this.c0 + &other.c0; + let c1 = &this.c1 + &other.c1; + let c2 = &this.c2 + &other.c2; + CubicExtVar::new(c0, c1, c2) + }, + |this: &'a CubicExtVar, other: CubicExtField
<P>
| { + this + CubicExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: CubicExtVarConfig), + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +); +impl_bounded_ops!( + CubicExtVar, + CubicExtField
<P>
, + Sub, + sub, + SubAssign, + sub_assign, + |this: &'a CubicExtVar, other: &'a CubicExtVar| { + let c0 = &this.c0 - &other.c0; + let c1 = &this.c1 - &other.c1; + let c2 = &this.c2 - &other.c2; + CubicExtVar::new(c0, c1, c2) + }, + |this: &'a CubicExtVar, other: CubicExtField
<P>
| { + this - CubicExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: CubicExtVarConfig), + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +); +impl_bounded_ops!( + CubicExtVar, + CubicExtField
<P>
, + Mul, + mul, + MulAssign, + mul_assign, + |this: &'a CubicExtVar, other: &'a CubicExtVar| { + // Karatsuba multiplication for cubic extensions: + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // v2 = A.c2 * B.c2 + // result.c0 = v0 + β((a1 + a2)(b1 + b2) − v1 − v2) + // result.c1 = (a0 + a1)(b0 + b1) − v0 − v1 + βv2 + // result.c2 = (a0 + a2)(b0 + b2) − v0 + v1 − v2, + // + // Reference: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + let v0 = &this.c0 * &other.c0; + let v1 = &this.c1 * &other.c1; + let v2 = &this.c2 * &other.c2; + let c0 = + (((&this.c1 + &this.c2) * (&other.c1 + &other.c2) - &v1 - &v2) * P::NONRESIDUE) + &v0 ; + let c1 = + (&this.c0 + &this.c1) * (&other.c0 + &other.c1) - &v0 - &v1 + (&v2 * P::NONRESIDUE); + let c2 = + (&this.c0 + &this.c2) * (&other.c0 + &other.c2) - &v0 + &v1 - &v2; + + CubicExtVar::new(c0, c1, c2) + }, + |this: &'a CubicExtVar, other: CubicExtField
<P>
| { + this * CubicExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: CubicExtVarConfig), + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +); + +impl EqGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let b0 = self.c0.is_eq(&other.c0)?; + let b1 = self.c1.is_eq(&other.c1)?; + let b2 = self.c2.is_eq(&other.c2)?; + b0.and(&b1)?.and(&b2) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.c0.conditional_enforce_equal(&other.c0, condition)?; + self.c1.conditional_enforce_equal(&other.c1, condition)?; + self.c2.conditional_enforce_equal(&other.c2, condition)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let is_equal = self.is_eq(other)?; + is_equal + .and(condition)? + .enforce_equal(&Boolean::Constant(false)) + } +} + +impl ToBitsGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bits_le(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_bits_le()?; + let mut c1 = self.c1.to_bits_le()?; + let mut c2 = self.c2.to_bits_le()?; + c0.append(&mut c1); + c0.append(&mut c2); + Ok(c0) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_non_unique_bits_le()?; + let mut c1 = self.c1.to_non_unique_bits_le()?; + let mut c2 = self.c2.to_non_unique_bits_le()?; + c0.append(&mut c1); + c0.append(&mut c2); + Ok(c0) + } +} + +impl ToBytesGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_bytes()?; + let mut c1 = self.c1.to_bytes()?; + let mut c2 = self.c2.to_bytes()?; + c0.append(&mut c1); + c0.append(&mut c2); + + Ok(c0) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_non_unique_bytes()?; + let mut c1 = self.c1.to_non_unique_bytes()?; + let mut c2 = self.c2.to_non_unique_bytes()?; + + c0.append(&mut c1); + c0.append(&mut c2); + + Ok(c0) + } +} + +impl ToConstraintFieldGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + BF: ToConstraintFieldGadget, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let mut res = Vec::new(); + + res.extend_from_slice(&self.c0.to_constraint_field()?); + res.extend_from_slice(&self.c1.to_constraint_field()?); + res.extend_from_slice(&self.c2.to_constraint_field()?); + + Ok(res) + } +} + +impl CondSelectGadget for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + 
true_value: &Self, + false_value: &Self, + ) -> Result { + let c0 = BF::conditionally_select(cond, &true_value.c0, &false_value.c0)?; + let c1 = BF::conditionally_select(cond, &true_value.c1, &false_value.c1)?; + let c2 = BF::conditionally_select(cond, &true_value.c2, &false_value.c2)?; + Ok(Self::new(c0, c1, c2)) + } +} + +impl TwoBitLookupGadget for CubicExtVar +where + BF: FieldVar + + TwoBitLookupGadget, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + type TableConstant = CubicExtField
<P>
; + + #[tracing::instrument(target = "r1cs")] + fn two_bit_lookup( + b: &[Boolean], + c: &[Self::TableConstant], + ) -> Result { + let c0s = c.iter().map(|f| f.c0).collect::>(); + let c1s = c.iter().map(|f| f.c1).collect::>(); + let c2s = c.iter().map(|f| f.c2).collect::>(); + let c0 = BF::two_bit_lookup(b, &c0s)?; + let c1 = BF::two_bit_lookup(b, &c1s)?; + let c2 = BF::two_bit_lookup(b, &c2s)?; + Ok(Self::new(c0, c1, c2)) + } +} + +impl ThreeBitCondNegLookupGadget + for CubicExtVar +where + BF: FieldVar + + ThreeBitCondNegLookupGadget, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + type TableConstant = CubicExtField
<P>
; + + #[tracing::instrument(target = "r1cs")] + fn three_bit_cond_neg_lookup( + b: &[Boolean], + b0b1: &Boolean, + c: &[Self::TableConstant], + ) -> Result { + let c0s = c.iter().map(|f| f.c0).collect::>(); + let c1s = c.iter().map(|f| f.c1).collect::>(); + let c2s = c.iter().map(|f| f.c2).collect::>(); + let c0 = BF::three_bit_cond_neg_lookup(b, b0b1, &c0s)?; + let c1 = BF::three_bit_cond_neg_lookup(b, b0b1, &c1s)?; + let c2 = BF::three_bit_cond_neg_lookup(b, b0b1, &c2s)?; + Ok(Self::new(c0, c1, c2)) + } +} + +impl AllocVar, ConstraintF> for CubicExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: CubicExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + + use SynthesisError::*; + let (c0, c1, c2) = match f() { + Ok(fe) => (Ok(fe.borrow().c0), Ok(fe.borrow().c1), Ok(fe.borrow().c2)), + Err(_) => ( + Err(AssignmentMissing), + Err(AssignmentMissing), + Err(AssignmentMissing), + ), + }; + + let c0 = BF::new_variable(ark_relations::ns!(cs, "c0"), || c0, mode)?; + let c1 = BF::new_variable(ark_relations::ns!(cs, "c1"), || c1, mode)?; + let c2 = BF::new_variable(ark_relations::ns!(cs, "c2"), || c2, mode)?; + Ok(Self::new(c0, c1, c2)) + } +} diff --git a/jolt-core/src/circuits/fields/fp12.rs b/jolt-core/src/circuits/fields/fp12.rs new file mode 100644 index 000000000..f46397770 --- /dev/null +++ b/jolt-core/src/circuits/fields/fp12.rs @@ -0,0 +1,189 @@ +use crate::circuits::fields::{fp2::Fp2Var, fp6_3over2::Fp6Var, quadratic_extension::*}; +use ark_ff::{ + fields::{fp12_2over3over2::*, Field}, + fp6_3over2::Fp6Config, + PrimeField, QuadExtConfig, +}; +use ark_r1cs_std::fields::FieldVar; +use ark_relations::r1cs::SynthesisError; + +/// A degree-12 extension field constructed as the tower of a +/// quadratic extension over a cubic extension over a quadratic extension field. +/// This is the R1CS equivalent of `ark_ff::fp12_2over3over2::Fp12
<P>
`. +pub type Fp12Var = QuadExtVar< + Fp6Var<
<P as Fp12Config>
::Fp6Config, ConstraintF>, + ConstraintF, + Fp12ConfigWrapper
<P>
, +>; + +type Fp2Config
<P>
= <
<P as Fp12Config>
::Fp6Config as Fp6Config>::Fp2Config; + +impl QuadExtVarConfig, ConstraintF> + for Fp12ConfigWrapper
<P>
+where + P: Fp12Config, + ConstraintF: PrimeField, +{ + fn mul_base_field_var_by_frob_coeff(fe: &mut Fp6Var, power: usize) { + fe.c0 *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + fe.c1 *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + fe.c2 *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + } +} + +impl Fp12Var +where + P: Fp12Config, + ConstraintF: PrimeField, +{ + /// Multiplies by a sparse element of the form `(c0 = (c0, c1, 0), c1 = (0, + /// d1, 0))`. + #[inline] + pub fn mul_by_014( + &self, + c0: &Fp2Var, ConstraintF>, + c1: &Fp2Var, ConstraintF>, + d1: &Fp2Var, ConstraintF>, + ) -> Result { + let v0 = self.c0.mul_by_c0_c1_0(c0, c1)?; + let v1 = self.c1.mul_by_0_c1_0(d1)?; + let new_c0 = Self::mul_base_field_by_nonresidue(&v1)? + &v0; + + let new_c1 = (&self.c0 + &self.c1).mul_by_c0_c1_0(c0, &(c1 + d1))? - &v0 - &v1; + Ok(Self::new(new_c0, new_c1)) + } + + /// Multiplies by a sparse element of the form `(c0 = (c0, 0, 0), c1 = (d0, + /// d1, 0))`. + #[inline] + pub fn mul_by_034( + &self, + c0: &Fp2Var, ConstraintF>, + d0: &Fp2Var, ConstraintF>, + d1: &Fp2Var, ConstraintF>, + ) -> Result { + let a0 = &self.c0.c0 * c0; + let a1 = &self.c0.c1 * c0; + let a2 = &self.c0.c2 * c0; + let a = Fp6Var::new(a0, a1, a2); + let b = self.c1.mul_by_c0_c1_0(d0, d1)?; + + let c0 = c0 + d0; + let c1 = d1; + let e = (&self.c0 + &self.c1).mul_by_c0_c1_0(&c0, c1)?; + let new_c1 = e - (&a + &b); + let new_c0 = Self::mul_base_field_by_nonresidue(&b)? + &a; + + Ok(Self::new(new_c0, new_c1)) + } + + /// Squares `self` when `self` is in the cyclotomic subgroup. + pub fn cyclotomic_square(&self) -> Result { + if characteristic_square_mod_6_is_one(Fp12::
<P>
::characteristic()) { + let fp2_nr = ::NONRESIDUE; + + let z0 = &self.c0.c0; + let z4 = &self.c0.c1; + let z3 = &self.c0.c2; + let z2 = &self.c1.c0; + let z1 = &self.c1.c1; + let z5 = &self.c1.c2; + + // t0 + t1*y = (z0 + z1*y)^2 = a^2 + let tmp = z0 * z1; + let t0 = { + let tmp1 = z0 + z1; + let tmp2 = z1 * fp2_nr + z0; + let tmp4 = &tmp * fp2_nr + &tmp; + tmp1 * tmp2 - tmp4 + }; + let t1 = tmp.double()?; + + // t2 + t3*y = (z2 + z3*y)^2 = b^2 + let tmp = z2 * z3; + let t2 = { + // (z2 + &z3) * &(z2 + &(fp2_nr * &z3)) - &tmp - &(tmp * &fp2_nr); + let tmp1 = z2 + z3; + let tmp2 = z3 * fp2_nr + z2; + let tmp4 = &tmp * fp2_nr + &tmp; + tmp1 * tmp2 - tmp4 + }; + let t3 = tmp.double()?; + + // t4 + t5*y = (z4 + z5*y)^2 = c^2 + let tmp = z4 * z5; + let t4 = { + // (z4 + &z5) * &(z4 + &(fp2_nr * &z5)) - &tmp - &(tmp * &fp2_nr); + let tmp1 = z4 + z5; + let tmp2 = (z5 * fp2_nr) + z4; + let tmp4 = (&tmp * fp2_nr) + &tmp; + (tmp1 * tmp2) - tmp4 + }; + let t5 = tmp.double()?; + + // for A + + // z0 = 3 * t0 - 2 * z0 + let c0_c0 = (&t0 - z0).double()? + &t0; + + // z1 = 3 * t1 + 2 * z1 + let c1_c1 = (&t1 + z1).double()? + &t1; + + // for B + + // z2 = 3 * (xi * t5) + 2 * z2 + let c1_c0 = { + let tmp = &t5 * fp2_nr; + (z2 + &tmp).double()? + &tmp + }; + + // z3 = 3 * t4 - 2 * z3 + let c0_c2 = (&t4 - z3).double()? + &t4; + + // for C + + // z4 = 3 * t2 - 2 * z4 + let c0_c1 = (&t2 - z4).double()? + &t2; + + // z5 = 3 * t3 + 2 * z5 + let c1_c2 = (&t3 + z5).double()? + &t3; + let c0 = Fp6Var::new(c0_c0, c0_c1, c0_c2); + let c1 = Fp6Var::new(c1_c0, c1_c1, c1_c2); + + Ok(Self::new(c0, c1)) + } else { + self.square() + } + } + + /// Like `Self::cyclotomic_exp`, but additionally uses cyclotomic squaring. + pub fn optimized_cyclotomic_exp( + &self, + exponent: impl AsRef<[u64]>, + ) -> Result { + use ark_ff::biginteger::arithmetic::find_naf; + let mut res = Self::one(); + let self_inverse = self.unitary_inverse()?; + + let mut found_nonzero = false; + let naf = find_naf(exponent.as_ref()); + + for &value in naf.iter().rev() { + if found_nonzero { + res = res.cyclotomic_square()?; + } + + if value != 0 { + found_nonzero = true; + + if value > 0 { + res *= self; + } else { + res *= &self_inverse; + } + } + } + + Ok(res) + } +} diff --git a/jolt-core/src/circuits/fields/fp2.rs b/jolt-core/src/circuits/fields/fp2.rs new file mode 100644 index 000000000..619c6e9cf --- /dev/null +++ b/jolt-core/src/circuits/fields/fp2.rs @@ -0,0 +1,28 @@ +use crate::circuits::fields::quadratic_extension::*; +use ark_ff::{ + fields::{Fp2Config, Fp2ConfigWrapper, QuadExtConfig}, + PrimeField, +}; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; + +/// A quadratic extension field constructed over a prime field. +/// This is the R1CS equivalent of `ark_ff::Fp2
<P>
`. +pub type Fp2Var = QuadExtVar< + NonNativeFieldVar<
<P as Fp2Config>
::Fp, ConstraintF>, + ConstraintF, + Fp2ConfigWrapper
<P>
, +>; + +impl QuadExtVarConfig, ConstraintF> + for Fp2ConfigWrapper
<P>
+where + P: Fp2Config, + ConstraintF: PrimeField, +{ + fn mul_base_field_var_by_frob_coeff( + fe: &mut NonNativeFieldVar, + power: usize, + ) { + *fe *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + } +} diff --git a/jolt-core/src/circuits/fields/fp6_3over2.rs b/jolt-core/src/circuits/fields/fp6_3over2.rs new file mode 100644 index 000000000..5d600c93e --- /dev/null +++ b/jolt-core/src/circuits/fields/fp6_3over2.rs @@ -0,0 +1,105 @@ +use crate::circuits::fields::{cubic_extension::*, fp2::*}; +use ark_ff::{ + fields::{fp6_3over2::*, Fp2}, + CubicExtConfig, PrimeField, +}; +use ark_relations::r1cs::SynthesisError; +use ark_std::ops::MulAssign; + +/// A sextic extension field constructed as the tower of a +/// cubic extension over a quadratic extension field. +/// This is the R1CS equivalent of `ark_ff::fp6_3over3::Fp6
<P>
`. +pub type Fp6Var = + CubicExtVar::Fp2Config, ConstraintF>, ConstraintF, Fp6ConfigWrapper
<P>
>; + +impl CubicExtVarConfig, ConstraintF> + for Fp6ConfigWrapper
<P>
+where + P: Fp6Config, + ConstraintF: PrimeField, +{ + fn mul_base_field_vars_by_frob_coeff( + c1: &mut Fp2Var, + c2: &mut Fp2Var, + power: usize, + ) { + *c1 *= Self::FROBENIUS_COEFF_C1[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + *c2 *= Self::FROBENIUS_COEFF_C2[power % Self::DEGREE_OVER_BASE_PRIME_FIELD]; + } +} + +impl Fp6Var +where + P: Fp6Config, + ConstraintF: PrimeField, +{ + /// Multiplies `self` by a sparse element which has `c0 == c2 == zero`. + pub fn mul_by_0_c1_0( + &self, + c1: &Fp2Var, + ) -> Result { + // Karatsuba multiplication + // v0 = a0 * b0 = 0 + + // v1 = a1 * b1 + let v1 = &self.c1 * c1; + + // v2 = a2 * b2 = 0 + + let a1_plus_a2 = &self.c1 + &self.c2; + let b1_plus_b2 = c1.clone(); + + let a0_plus_a1 = &self.c0 + &self.c1; + + // c0 = (NONRESIDUE * ((a1 + a2)*(b1 + b2) - v1 - v2)) + v0 + // = NONRESIDUE * ((a1 + a2) * b1 - v1) + let c0 = &(a1_plus_a2 * &b1_plus_b2 - &v1) * P::NONRESIDUE; + + // c1 = (a0 + a1) * (b0 + b1) - v0 - v1 + NONRESIDUE * v2 + // = (a0 + a1) * b1 - v1 + let c1 = a0_plus_a1 * c1 - &v1; + // c2 = (a0 + a2) * (b0 + b2) - v0 - v2 + v1 + // = v1 + let c2 = v1; + Ok(Self::new(c0, c1, c2)) + } + + /// Multiplies `self` by a sparse element which has `c2 == zero`. + pub fn mul_by_c0_c1_0( + &self, + c0: &Fp2Var, + c1: &Fp2Var, + ) -> Result { + let v0 = &self.c0 * c0; + let v1 = &self.c1 * c1; + // v2 = 0. + + let a1_plus_a2 = &self.c1 + &self.c2; + let a0_plus_a1 = &self.c0 + &self.c1; + let a0_plus_a2 = &self.c0 + &self.c2; + + let b1_plus_b2 = c1.clone(); + let b0_plus_b1 = c0 + c1; + let b0_plus_b2 = c0.clone(); + + let c0 = (&a1_plus_a2 * &b1_plus_b2 - &v1) * P::NONRESIDUE + &v0; + + let c1 = a0_plus_a1 * &b0_plus_b1 - &v0 - &v1; + + let c2 = a0_plus_a2 * &b0_plus_b2 - &v0 + &v1; + + Ok(Self::new(c0, c1, c2)) + } +} + +impl MulAssign> for Fp6Var +where + P: Fp6Config, + ConstraintF: PrimeField, +{ + fn mul_assign(&mut self, other: Fp2) { + self.c0 *= other; + self.c1 *= other; + self.c2 *= other; + } +} diff --git a/jolt-core/src/circuits/fields/mod.rs b/jolt-core/src/circuits/fields/mod.rs new file mode 100644 index 000000000..2e88f53b8 --- /dev/null +++ b/jolt-core/src/circuits/fields/mod.rs @@ -0,0 +1,5 @@ +mod cubic_extension; +pub mod fp12; +pub mod fp2; +mod fp6_3over2; +pub mod quadratic_extension; diff --git a/jolt-core/src/circuits/fields/quadratic_extension.rs b/jolt-core/src/circuits/fields/quadratic_extension.rs new file mode 100644 index 000000000..bcb1e3749 --- /dev/null +++ b/jolt-core/src/circuits/fields/quadratic_extension.rs @@ -0,0 +1,586 @@ +use ark_ff::{ + fields::{Field, QuadExtConfig, QuadExtField}, + PrimeField, Zero, +}; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; +use ark_std::vec::Vec; +use core::{borrow::Borrow, marker::PhantomData}; + +use ark_r1cs_std::{ + fields::{fp::FpVar, FieldOpsBounds, FieldVar}, + impl_bounded_ops, + prelude::*, + ToConstraintFieldGadget, +}; +use derivative::Derivative; + +/// This struct is the `R1CS` equivalent of the quadratic extension field type +/// in `ark-ff`, i.e. `ark_ff::QuadExtField`. +#[derive(Derivative)] +#[derivative(Debug(bound = "BF: core::fmt::Debug"), Clone(bound = "BF: Clone"))] +#[must_use] +pub struct QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + /// The zero-th coefficient of this field element. + pub c0: BF, + /// The first coefficient of this field element. 
+ pub c1: BF, + #[derivative(Debug = "ignore")] + _params: PhantomData<(P, ConstraintF)>, +} + +/// This trait describes parameters that are used to implement arithmetic for +/// `QuadExtVar`. +pub trait QuadExtVarConfig: QuadExtConfig +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'a> &'a BF: FieldOpsBounds<'a, Self::BaseField, BF>, +{ + /// Multiply the base field of the `QuadExtVar` by the appropriate Frobenius + /// coefficient. This is equivalent to + /// `Self::mul_base_field_by_frob_coeff(power)`. + fn mul_base_field_var_by_frob_coeff(fe: &mut BF, power: usize); +} + +impl QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + /// Constructs a `QuadExtVar` from the underlying coefficients. + pub fn new(c0: BF, c1: BF) -> Self { + Self { + c0, + c1, + _params: PhantomData, + } + } + + /// Multiplies a variable of the base field by the quadratic nonresidue + /// `P::NONRESIDUE` that is used to construct the extension field. + #[inline] + pub fn mul_base_field_by_nonresidue(fe: &BF) -> Result { + Ok(fe * P::NONRESIDUE) + } + + /// Multiplies `self` by a constant from the base field. + #[inline] + pub fn mul_by_base_field_constant(&self, fe: P::BaseField) -> Self { + let c0 = self.c0.clone() * fe; + let c1 = self.c1.clone() * fe; + QuadExtVar::new(c0, c1) + } + + /// Sets `self = self.mul_by_base_field_constant(fe)`. + #[inline] + pub fn mul_assign_by_base_field_constant(&mut self, fe: P::BaseField) { + *self = (*self).mul_by_base_field_constant(fe); + } + + /// This is only to be used when the element is *known* to be in the + /// cyclotomic subgroup. + #[inline] + pub fn unitary_inverse(&self) -> Result { + Ok(Self::new(self.c0.clone(), self.c1.negate()?)) + } + + /// This is only to be used when the element is *known* to be in the + /// cyclotomic subgroup. + #[inline] + #[tracing::instrument(target = "r1cs", skip(exponent))] + pub fn cyclotomic_exp(&self, exponent: impl AsRef<[u64]>) -> Result + where + Self: FieldVar, ConstraintF>, + { + let mut res = Self::one(); + let self_inverse = self.unitary_inverse()?; + + let mut found_nonzero = false; + let naf = ark_ff::biginteger::arithmetic::find_naf(exponent.as_ref()); + + for &value in naf.iter().rev() { + if found_nonzero { + res.square_in_place()?; + } + + if value != 0 { + found_nonzero = true; + + if value > 0 { + res *= self; + } else { + res *= &self_inverse; + } + } + } + + Ok(res) + } +} + +impl R1CSVar for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + type Value = QuadExtField
<P>
; + + fn cs(&self) -> ConstraintSystemRef { + [&self.c0, &self.c1].cs() + } + + #[inline] + fn value(&self) -> Result { + match (self.c0.value(), self.c1.value()) { + (Ok(c0), Ok(c1)) => Ok(QuadExtField::new(c0, c1)), + (..) => Err(SynthesisError::AssignmentMissing), + } + } +} + +impl From> for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn from(other: Boolean) -> Self { + let c0 = BF::from(other); + let c1 = BF::zero(); + Self::new(c0, c1) + } +} + +impl<'a, BF, ConstraintF, P> FieldOpsBounds<'a, QuadExtField
<P>
, QuadExtVar> + for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ +} +impl<'a, BF, ConstraintF, P> FieldOpsBounds<'a, QuadExtField
<P>
, QuadExtVar> + for &'a QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, + P: QuadExtVarConfig, +{ +} + +impl FieldVar, ConstraintF> for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + fn zero() -> Self { + let c0 = BF::zero(); + let c1 = BF::zero(); + Self::new(c0, c1) + } + + fn one() -> Self { + let c0 = BF::one(); + let c1 = BF::zero(); + Self::new(c0, c1) + } + + fn constant(other: QuadExtField
<P>
) -> Self { + let c0 = BF::constant(other.c0); + let c1 = BF::constant(other.c1); + Self::new(c0, c1) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn double(&self) -> Result { + let c0 = self.c0.double()?; + let c1 = self.c1.double()?; + Ok(Self::new(c0, c1)) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn negate(&self) -> Result { + let mut result = self.clone(); + result.c0.negate_in_place()?; + result.c1.negate_in_place()?; + Ok(result) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn square(&self) -> Result { + // From Libsnark/fp2_gadget.tcc + // Complex multiplication for Fp2: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + + // v0 = c0 - c1 + let mut v0 = &self.c0 - &self.c1; + // v3 = c0 - beta * c1 + let v3 = &self.c0 - &Self::mul_base_field_by_nonresidue(&self.c1)?; + // v2 = c0 * c1 + let v2 = &self.c0 * &self.c1; + + // v0 = (v0 * v3) + v2 + v0 *= &v3; + v0 += &v2; + + let c0 = &v0 + &Self::mul_base_field_by_nonresidue(&v2)?; + let c1 = v2.double()?; + + Ok(Self::new(c0, c1)) + } + + #[tracing::instrument(target = "r1cs")] + fn mul_equals(&self, other: &Self, result: &Self) -> Result<(), SynthesisError> { + // Karatsuba multiplication for Fp2: + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // result.c0 = v0 + non_residue * v1 + // result.c1 = (A.c0 + A.c1) * (B.c0 + B.c1) - v0 - v1 + // Enforced with 3 constraints: + // A.c1 * B.c1 = v1 + // A.c0 * B.c0 = result.c0 - non_residue * v1 + // (A.c0+A.c1)*(B.c0+B.c1) = result.c1 + result.c0 + (1 - non_residue) * v1 + // Reference: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + // Compute v1 + let v1 = &self.c1 * &other.c1; + + // Perform second check + let non_residue_times_v1 = Self::mul_base_field_by_nonresidue(&v1)?; + let rhs = &result.c0 - &non_residue_times_v1; + self.c0.mul_equals(&other.c0, &rhs)?; + + // Last check + let a0_plus_a1 = &self.c0 + &self.c1; + let b0_plus_b1 = &other.c0 + &other.c1; + let one_minus_non_residue_v1 = &v1 - &non_residue_times_v1; + + let tmp = &(&result.c1 + &result.c0) + &one_minus_non_residue_v1; + a0_plus_a1.mul_equals(&b0_plus_b1, &tmp)?; + + Ok(()) + } + + #[tracing::instrument(target = "r1cs")] + fn inverse(&self) -> Result { + let mode = if self.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let inverse = Self::new_variable( + self.cs(), + || { + self.value() + .map(|f| f.inverse().unwrap_or_else(QuadExtField::zero)) + }, + mode, + )?; + self.mul_equals(&inverse, &Self::one())?; + Ok(inverse) + } + + #[tracing::instrument(target = "r1cs")] + fn frobenius_map(&self, power: usize) -> Result { + let mut result = self.clone(); + result.c0.frobenius_map_in_place(power)?; + result.c1.frobenius_map_in_place(power)?; + P::mul_base_field_var_by_frob_coeff(&mut result.c1, power); + Ok(result) + } +} + +impl_bounded_ops!( + QuadExtVar, + QuadExtField
<P>
, + Add, + add, + AddAssign, + add_assign, + |this: &'a QuadExtVar, other: &'a QuadExtVar| { + let c0 = &this.c0 + &other.c0; + let c1 = &this.c1 + &other.c1; + QuadExtVar::new(c0, c1) + }, + |this: &'a QuadExtVar, other: QuadExtField
<P>
| { + this + QuadExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: QuadExtVarConfig), + for <'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF> +); +impl_bounded_ops!( + QuadExtVar, + QuadExtField
<P>
, + Sub, + sub, + SubAssign, + sub_assign, + |this: &'a QuadExtVar, other: &'a QuadExtVar| { + let c0 = &this.c0 - &other.c0; + let c1 = &this.c1 - &other.c1; + QuadExtVar::new(c0, c1) + }, + |this: &'a QuadExtVar, other: QuadExtField
<P>
| { + this - QuadExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: QuadExtVarConfig), + for <'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF> +); +impl_bounded_ops!( + QuadExtVar, + QuadExtField
<P>
, + Mul, + mul, + MulAssign, + mul_assign, + |this: &'a QuadExtVar, other: &'a QuadExtVar| { + // Karatsuba multiplication for Fp2: + // v0 = A.c0 * B.c0 + // v1 = A.c1 * B.c1 + // result.c0 = v0 + non_residue * v1 + // result.c1 = (A.c0 + A.c1) * (B.c0 + B.c1) - v0 - v1 + // Enforced with 3 constraints: + // A.c1 * B.c1 = v1 + // A.c0 * B.c0 = result.c0 - non_residue * v1 + // (A.c0+A.c1)*(B.c0+B.c1) = result.c1 + result.c0 + (1 - non_residue) * v1 + // Reference: + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Devegili, OhEigeartaigh, Scott, Dahab + let mut result = this.clone(); + let v0 = &this.c0 * &other.c0; + let v1 = &this.c1 * &other.c1; + + result.c1 += &this.c0; + result.c1 *= &other.c0 + &other.c1; + result.c1 -= &v0; + result.c1 -= &v1; + result.c0 = v0 + &QuadExtVar::::mul_base_field_by_nonresidue(&v1).unwrap(); + result + }, + |this: &'a QuadExtVar, other: QuadExtField
<P>
| { + this * QuadExtVar::constant(other) + }, + (BF: FieldVar, ConstraintF: PrimeField, P: QuadExtVarConfig), + for <'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF> +); + +impl EqGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let b0 = self.c0.is_eq(&other.c0)?; + let b1 = self.c1.is_eq(&other.c1)?; + b0.and(&b1) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.c0.conditional_enforce_equal(&other.c0, condition)?; + self.c1.conditional_enforce_equal(&other.c1, condition)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let is_equal = self.is_eq(other)?; + is_equal + .and(condition)? + .enforce_equal(&Boolean::Constant(false)) + } +} + +impl ToBitsGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bits_le(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_bits_le()?; + let mut c1 = self.c1.to_bits_le()?; + c0.append(&mut c1); + Ok(c0) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_non_unique_bits_le()?; + let mut c1 = self.c1.to_non_unique_bits_le()?; + c0.append(&mut c1); + Ok(c0) + } +} + +impl ToBytesGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_bytes()?; + let mut c1 = self.c1.to_bytes()?; + c0.append(&mut c1); + Ok(c0) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { + let mut c0 = self.c0.to_non_unique_bytes()?; + let mut c1 = self.c1.to_non_unique_bytes()?; + c0.append(&mut c1); + Ok(c0) + } +} + +impl ToConstraintFieldGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + BF: ToConstraintFieldGadget, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let mut res = Vec::new(); + + res.extend_from_slice(&self.c0.to_constraint_field()?); + res.extend_from_slice(&self.c1.to_constraint_field()?); + + Ok(res) + } +} + +impl CondSelectGadget for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, + P: QuadExtVarConfig, +{ + #[inline] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let c0 = BF::conditionally_select(cond, &true_value.c0, &false_value.c0)?; + let c1 = BF::conditionally_select(cond, &true_value.c1, &false_value.c1)?; + Ok(Self::new(c0, c1)) + } +} + +impl TwoBitLookupGadget for QuadExtVar +where + BF: FieldVar + + TwoBitLookupGadget, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + type TableConstant = QuadExtField
<P>
; + + #[tracing::instrument(target = "r1cs")] + fn two_bit_lookup( + b: &[Boolean], + c: &[Self::TableConstant], + ) -> Result { + let c0s = c.iter().map(|f| f.c0).collect::>(); + let c1s = c.iter().map(|f| f.c1).collect::>(); + let c0 = BF::two_bit_lookup(b, &c0s)?; + let c1 = BF::two_bit_lookup(b, &c1s)?; + Ok(Self::new(c0, c1)) + } +} + +impl ThreeBitCondNegLookupGadget for QuadExtVar +where + BF: FieldVar + + ThreeBitCondNegLookupGadget, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + type TableConstant = QuadExtField
<P>
; + + #[tracing::instrument(target = "r1cs")] + fn three_bit_cond_neg_lookup( + b: &[Boolean], + b0b1: &Boolean, + c: &[Self::TableConstant], + ) -> Result { + let c0s = c.iter().map(|f| f.c0).collect::>(); + let c1s = c.iter().map(|f| f.c1).collect::>(); + let c0 = BF::three_bit_cond_neg_lookup(b, b0b1, &c0s)?; + let c1 = BF::three_bit_cond_neg_lookup(b, b0b1, &c1s)?; + Ok(Self::new(c0, c1)) + } +} + +impl AllocVar, ConstraintF> for QuadExtVar +where + BF: FieldVar, + ConstraintF: PrimeField, + P: QuadExtVarConfig, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let (c0, c1) = match f() { + Ok(fe) => (Ok(fe.borrow().c0), Ok(fe.borrow().c1)), + Err(_) => ( + Err(SynthesisError::AssignmentMissing), + Err(SynthesisError::AssignmentMissing), + ), + }; + + let c0 = BF::new_variable(ark_relations::ns!(cs, "c0"), || c0, mode)?; + let c1 = BF::new_variable(ark_relations::ns!(cs, "c1"), || c1, mode)?; + Ok(Self::new(c0, c1)) + } +} diff --git a/jolt-core/src/circuits/groups/curves/mod.rs b/jolt-core/src/circuits/groups/curves/mod.rs new file mode 100644 index 000000000..dbd7e00e0 --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/mod.rs @@ -0,0 +1 @@ +pub mod short_weierstrass; diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12_381.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12_381.rs new file mode 100644 index 000000000..ad77f2e22 --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/bls12_381.rs @@ -0,0 +1,7 @@ +use crate::circuits::groups::curves::short_weierstrass::ProjectiveVar; +use ark_bls12_381::{g1, Fq, Fr}; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; + +pub type FBaseVar = NonNativeFieldVar; + +pub type G1Var = ProjectiveVar; diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs new file mode 100644 index 000000000..0a28be59b --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/bn254.rs @@ -0,0 +1,7 @@ +use crate::circuits::groups::curves::short_weierstrass::ProjectiveVar; +use ark_bn254::{g1, Fq, Fr}; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; + +pub type FBaseVar = NonNativeFieldVar; + +pub type G1Var = ProjectiveVar; diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs new file mode 100644 index 000000000..e0ae386cf --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/mod.rs @@ -0,0 +1,964 @@ +use ark_ec::{ + short_weierstrass::{ + Affine as SWAffine, Projective as SWProjective, SWCurveConfig as SWModelParameters, + }, + AffineRepr, CurveGroup, +}; +use ark_ff::{BigInteger, BitIteratorBE, Field, One, PrimeField, Zero}; +use ark_r1cs_std::{fields::fp::FpVar, impl_bounded_ops, prelude::*, ToConstraintFieldGadget}; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; +use ark_std::{borrow::Borrow, marker::PhantomData, ops::Mul, vec::Vec}; +use derivative::Derivative; +use non_zero_affine::NonZeroAffineVar; + +pub mod bls12_381; +pub mod bn254; + +/// This module provides a generic implementation of elliptic curve operations +/// for points on short-weierstrass curves in affine coordinates that **are +/// not** equal to zero. 
+/// +/// Note: this module is **unsafe** in general: it can synthesize unsatisfiable +/// or underconstrained constraint systems when a represented point _is_ equal +/// to zero. The [ProjectiveVar] gadget is the recommended way of working with +/// elliptic curve points. +pub mod non_zero_affine; + +/// An implementation of arithmetic for Short Weierstrass curves that relies on +/// the complete formulae derived in the paper of +/// [[Renes, Costello, Batina 2015]](). +#[derive(Derivative)] +#[derivative(Debug, Clone)] +#[must_use] +pub struct ProjectiveVar< + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +> where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + /// The x-coordinate. + pub x: F, + /// The y-coordinate. + pub y: F, + /// The z-coordinate. + pub z: F, + #[derivative(Debug = "ignore")] + _params: PhantomData
<P>
, + #[derivative(Debug = "ignore")] + _constraint_f: PhantomData, +} + +/// An affine representation of a curve point. +#[derive(Derivative)] +#[derivative(Debug, Clone)] +#[must_use] +pub struct AffineVar< + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +> where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + /// The x-coordinate. + pub x: F, + /// The y-coordinate. + pub y: F, + /// Is `self` the point at infinity. + pub infinity: Boolean, + #[derivative(Debug = "ignore")] + _params: PhantomData
<P>
, + #[derivative(Debug = "ignore")] + _constraint_f: PhantomData, +} + +impl AffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + fn new(x: F, y: F, infinity: Boolean) -> Self { + Self { + x, + y, + infinity, + _params: PhantomData, + _constraint_f: PhantomData, + } + } + + /// Returns the value assigned to `self` in the underlying + /// constraint system. + pub fn value(&self) -> Result, SynthesisError> { + Ok(match self.infinity.value()? { + true => SWAffine::identity(), + false => SWAffine::new(self.x.value()?, self.y.value()?), + }) + } +} + +impl ToConstraintFieldGadget for AffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, + F: ToConstraintFieldGadget, +{ + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let mut res = Vec::>::new(); + + res.extend_from_slice(&self.x.to_constraint_field()?); + res.extend_from_slice(&self.y.to_constraint_field()?); + res.extend_from_slice(&self.infinity.to_constraint_field()?); + + Ok(res) + } +} + +impl R1CSVar for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + type Value = SWProjective
<P>
; + + fn cs(&self) -> ConstraintSystemRef { + self.x.cs().or(self.y.cs()).or(self.z.cs()) + } + + fn value(&self) -> Result { + let (x, y, z) = (self.x.value()?, self.y.value()?, self.z.value()?); + let result = if let Some(z_inv) = z.inverse() { + SWAffine::new(x * &z_inv, y * &z_inv) + } else { + SWAffine::identity() + }; + Ok(result.into()) + } +} + +impl ProjectiveVar +where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +{ + /// Constructs `Self` from an `(x, y, z)` coordinate triple. + pub fn new(x: F, y: F, z: F) -> Self { + Self { + x, + y, + z, + _params: PhantomData, + _constraint_f: PhantomData, + } + } + + /// Convert this point into affine form. + #[tracing::instrument(target = "r1cs")] + pub fn to_affine(&self) -> Result, SynthesisError> { + if self.is_constant() { + let point = self.value()?.into_affine(); + let x = F::new_constant(ConstraintSystemRef::None, point.x)?; + let y = F::new_constant(ConstraintSystemRef::None, point.y)?; + let infinity = Boolean::constant(point.infinity); + Ok(AffineVar::new(x, y, infinity)) + } else { + let cs = self.cs(); + let infinity = self.is_zero()?; + let zero_x = F::zero(); + let zero_y = F::one(); + // Allocate a variable whose value is either `self.z.inverse()` if the inverse + // exists, and is zero otherwise. + let z_inv = F::new_witness(ark_relations::ns!(cs, "z_inverse"), || { + Ok(self.z.value()?.inverse().unwrap_or_else(P::BaseField::zero)) + })?; + // The inverse exists if `!self.is_zero()`. + // This means that `z_inv * self.z = 1` if `self.is_not_zero()`, and + // `z_inv * self.z = 0` if `self.is_zero()`. + // + // Thus, `z_inv * self.z = !self.is_zero()`. + z_inv.mul_equals(&self.z, &F::from(infinity.not()))?; + + let non_zero_x = &self.x * &z_inv; + let non_zero_y = &self.y * &z_inv; + + let x = infinity.select(&zero_x, &non_zero_x)?; + let y = infinity.select(&zero_y, &non_zero_y)?; + + Ok(AffineVar::new(x, y, infinity)) + } + } + + /// Allocates a new variable without performing an on-curve check, which is + /// useful if the variable is known to be on the curve (eg., if the point + /// is a constant or is a public input). + #[tracing::instrument(target = "r1cs", skip(cs, f))] + pub fn new_variable_omit_on_curve_check( + cs: impl Into>, + f: impl FnOnce() -> Result, SynthesisError>, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + + let (x, y, z) = match f() { + Ok(ge) => { + let ge = ge.into_affine(); + if ge.is_zero() { + ( + Ok(P::BaseField::zero()), + Ok(P::BaseField::one()), + Ok(P::BaseField::zero()), + ) + } else { + (Ok(ge.x), Ok(ge.y), Ok(P::BaseField::one())) + } + } + _ => ( + Err(SynthesisError::AssignmentMissing), + Err(SynthesisError::AssignmentMissing), + Err(SynthesisError::AssignmentMissing), + ), + }; + + let x = F::new_variable(ark_relations::ns!(cs, "x"), || x, mode)?; + let y = F::new_variable(ark_relations::ns!(cs, "y"), || y, mode)?; + let z = F::new_variable(ark_relations::ns!(cs, "z"), || z, mode)?; + + Ok(Self::new(x, y, z)) + } + + /// Mixed addition, which is useful when `other = (x2, y2)` is known to have + /// z = 1. + #[tracing::instrument(target = "r1cs", skip(self, other))] + pub(crate) fn add_mixed( + &self, + other: &NonZeroAffineVar, + ) -> Result { + // Complete mixed addition formula from Renes-Costello-Batina 2015 + // Algorithm 2 + // (https://eprint.iacr.org/2015/1060). 
+ // Below, comments at the end of a line denote the corresponding + // step(s) of the algorithm + // + // Adapted from code in + // https://github.com/RustCrypto/elliptic-curves/blob/master/p256/src/arithmetic/projective.rs + let three_b = P::COEFF_B.double() + &P::COEFF_B; + let (x1, y1, z1) = (&self.x, &self.y, &self.z); + let (x2, y2) = (&other.x, &other.y); + + let xx = x1 * x2; // 1 + let yy = y1 * y2; // 2 + let xy_pairs = ((x1 + y1) * &(x2 + y2)) - (&xx + &yy); // 4, 5, 6, 7, 8 + let xz_pairs = (x2 * z1) + x1; // 8, 9 + let yz_pairs = (y2 * z1) + y1; // 10, 11 + + let axz = mul_by_coeff_a::(&xz_pairs); // 12 + + let bz3_part = &axz + z1 * three_b; // 13, 14 + + let yy_m_bz3 = &yy - &bz3_part; // 15 + let yy_p_bz3 = &yy + &bz3_part; // 16 + + let azz = mul_by_coeff_a::(z1); // 20 + let xx3_p_azz = xx.double().unwrap() + &xx + &azz; // 18, 19, 22 + + let bxz3 = &xz_pairs * three_b; // 21 + let b3_xz_pairs = mul_by_coeff_a::(&(&xx - &azz)) + &bxz3; // 23, 24, 25 + + let x = (&yy_m_bz3 * &xy_pairs) - &yz_pairs * &b3_xz_pairs; // 28,29, 30 + let y = (&yy_p_bz3 * &yy_m_bz3) + &xx3_p_azz * b3_xz_pairs; // 17, 26, 27 + let z = (&yy_p_bz3 * &yz_pairs) + xy_pairs * xx3_p_azz; // 31, 32, 33 + + Ok(ProjectiveVar::new(x, y, z)) + } + + /// Computes a scalar multiplication with a little-endian scalar of size + /// `P::ScalarField::MODULUS_BITS`. + #[tracing::instrument( + target = "r1cs", + skip(self, mul_result, multiple_of_power_of_two, bits) + )] + fn fixed_scalar_mul_le( + &self, + mul_result: &mut Self, + multiple_of_power_of_two: &mut NonZeroAffineVar, + bits: &[&Boolean], + ) -> Result<(), SynthesisError> { + let scalar_modulus_bits = ::MODULUS_BIT_SIZE as usize; + + assert!(scalar_modulus_bits >= bits.len()); + let split_len = ark_std::cmp::min(scalar_modulus_bits - 2, bits.len()); + let (affine_bits, proj_bits) = bits.split_at(split_len); + // Computes the standard little-endian double-and-add algorithm + // (Algorithm 3.26, Guide to Elliptic Curve Cryptography) + // + // We rely on *incomplete* affine formulae for partially computing this. + // However, we avoid exceptional edge cases because we partition the scalar + // into two chunks: one guaranteed to be less than p - 2, and the rest. + // We only use incomplete formulae for the first chunk, which means we avoid + // exceptions: + // + // `add_unchecked(a, b)` is incomplete when either `b.is_zero()`, or when + // `b = ±a`. During scalar multiplication, we don't hit either case: + // * `b = ±a`: `b = accumulator = k * a`, where `2 <= k < p - 1`. This implies + // that `k != p ± 1`, and so `b != (p ± 1) * a`. Because the group is finite, + // this in turn means that `b != ±a`, as required. + // * `a` or `b` is zero: for `a`, we handle the zero case after the loop; for + // `b`, notice that it is monotonically increasing, and furthermore, equals `k + // * a`, where `k != p = 0 mod p`. + + // Unlike normal double-and-add, here we start off with a non-zero + // `accumulator`, because `NonZeroAffineVar::add_unchecked` doesn't + // support addition with `zero`. In more detail, we initialize + // `accumulator` to be the initial value of `multiple_of_power_of_two`. + // This ensures that all unchecked additions of `accumulator` with later + // values of `multiple_of_power_of_two` are safe. However, to do this + // correctly, we need to perform two steps: + // * We must skip the LSB, and instead proceed assuming that it was 1. 
Later, we + // will conditionally subtract the initial value of `accumulator`: if LSB == + // 0: subtract initial_acc_value; else, subtract 0. + // * Because we are assuming the first bit, we must double + // `multiple_of_power_of_two`. + + let mut accumulator = multiple_of_power_of_two.clone(); + let initial_acc_value = accumulator.into_projective(); + + // The powers start at 2 (instead of 1) because we're skipping the first bit. + multiple_of_power_of_two.double_in_place()?; + + // As mentioned, we will skip the LSB, and will later handle it via a + // conditional subtraction. + for bit in affine_bits.iter().skip(1) { + if bit.is_constant() { + if *bit == &Boolean::TRUE { + accumulator = accumulator.add_unchecked(multiple_of_power_of_two)?; + } + } else { + let temp = accumulator.add_unchecked(multiple_of_power_of_two)?; + accumulator = bit.select(&temp, &accumulator)?; + } + multiple_of_power_of_two.double_in_place()?; + } + // Perform conditional subtraction: + + // We can convert to projective safely because the result is guaranteed to be + // non-zero by the condition on `affine_bits.len()`, and by the fact + // that `accumulator` is non-zero + let result = accumulator.into_projective(); + // If bits[0] is 0, then we have to subtract `self`; else, we subtract zero. + let subtrahend = bits[0].select(&Self::zero(), &initial_acc_value)?; + *mul_result += result - subtrahend; + + // Now, let's finish off the rest of the bits using our complete formulae + for bit in proj_bits { + if bit.is_constant() { + if *bit == &Boolean::TRUE { + *mul_result += &multiple_of_power_of_two.into_projective(); + } + } else { + let temp = &*mul_result + &multiple_of_power_of_two.into_projective(); + *mul_result = bit.select(&temp, mul_result)?; + } + multiple_of_power_of_two.double_in_place()?; + } + Ok(()) + } +} + +impl CurveVar, ConstraintF> for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + fn zero() -> Self { + Self::new(F::zero(), F::one(), F::zero()) + } + + fn is_zero(&self) -> Result, SynthesisError> { + self.z.is_zero() + } + + fn constant(g: SWProjective
<P>
) -> Self { + let cs = ConstraintSystemRef::None; + Self::new_variable_omit_on_curve_check(cs, || Ok(g), AllocationMode::Constant).unwrap() + } + + #[tracing::instrument(target = "r1cs", skip(cs, f))] + fn new_variable_omit_prime_order_check( + cs: impl Into>, + f: impl FnOnce() -> Result, SynthesisError>, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + // Curve equation in projective form: + // E: Y² * Z = X³ + aX * Z² + bZ³ + // + // This can be re-written as + // E: Y² * Z - bZ³ = X³ + aX * Z² + // E: Z * (Y² - bZ²) = X * (X² + aZ²) + // so, compute X², Y², Z², + // compute temp = X * (X² + aZ²) + // check Z.mul_equals((Y² - bZ²), temp) + // + // A total of 5 multiplications + + let g = Self::new_variable_omit_on_curve_check(cs, f, mode)?; + + if mode != AllocationMode::Constant { + // Perform on-curve check. + let b = P::COEFF_B; + let a = P::COEFF_A; + + let x2 = g.x.square()?; + let y2 = g.y.square()?; + let z2 = g.z.square()?; + let t = &g.x * (x2 + &z2 * a); + + g.z.mul_equals(&(y2 - z2 * b), &t)?; + } + Ok(g) + } + + /// Enforce that `self` is in the prime-order subgroup. + /// + /// Does so by multiplying by the prime order, and checking that the result + /// is unchanged. + // TODO: at the moment this doesn't work, because the addition and doubling + // formulae are incomplete for even-order points. + #[tracing::instrument(target = "r1cs")] + fn enforce_prime_order(&self) -> Result<(), SynthesisError> { + unimplemented!("cannot enforce prime order"); + // let r_minus_1 = (-P::ScalarField::one()).into_bigint(); + + // let mut result = Self::zero(); + // for b in BitIteratorBE::without_leading_zeros(r_minus_1) { + // result.double_in_place()?; + + // if b { + // result += self; + // } + // } + // self.negate()?.enforce_equal(&result)?; + // Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn double_in_place(&mut self) -> Result<(), SynthesisError> { + // Complete doubling formula from Renes-Costello-Batina 2015 + // Algorithm 3 + // (https://eprint.iacr.org/2015/1060). + // Below, comments at the end of a line denote the corresponding + // step(s) of the algorithm + // + // Adapted from code in + // https://github.com/RustCrypto/elliptic-curves/blob/master/p256/src/arithmetic/projective.rs + let three_b = P::COEFF_B.double() + &P::COEFF_B; + + let xx = self.x.square()?; // 1 + let yy = self.y.square()?; // 2 + let zz = self.z.square()?; // 3 + let xy2 = (&self.x * &self.y).double()?; // 4, 5 + let xz2 = (&self.x * &self.z).double()?; // 6, 7 + + let axz2 = mul_by_coeff_a::(&xz2); // 8 + + let bzz3_part = &axz2 + &zz * three_b; // 9, 10 + let yy_m_bzz3 = &yy - &bzz3_part; // 11 + let yy_p_bzz3 = &yy + &bzz3_part; // 12 + let y_frag = yy_p_bzz3 * &yy_m_bzz3; // 13 + let x_frag = yy_m_bzz3 * &xy2; // 14 + + let bxz3 = xz2 * three_b; // 15 + let azz = mul_by_coeff_a::(&zz); // 16 + let b3_xz_pairs = mul_by_coeff_a::(&(&xx - &azz)) + &bxz3; // 15, 16, 17, 18, 19 + let xx3_p_azz = (xx.double()? 
+ &xx + &azz) * &b3_xz_pairs; // 23, 24, 25 + + let y = y_frag + &xx3_p_azz; // 26, 27 + let yz2 = (&self.y * &self.z).double()?; // 28, 29 + let x = x_frag - &(b3_xz_pairs * &yz2); // 30, 31 + let z = (yz2 * &yy).double()?.double()?; // 32, 33, 34 + self.x = x; + self.y = y; + self.z = z; + Ok(()) + } + + #[tracing::instrument(target = "r1cs")] + fn negate(&self) -> Result { + Ok(Self::new(self.x.clone(), self.y.negate()?, self.z.clone())) + } + + /// Computes `bits * self`, where `bits` is a little-endian + /// `Boolean` representation of a scalar. + #[tracing::instrument(target = "r1cs", skip(bits))] + fn scalar_mul_le<'a>( + &self, + bits: impl Iterator>, + ) -> Result { + if self.is_constant() && self.value().unwrap().is_zero() { + return Ok(self.clone()); + } + let self_affine = self.to_affine()?; + let (x, y, infinity) = (self_affine.x, self_affine.y, self_affine.infinity); + // We first handle the non-zero case, and then later + // will conditionally select zero if `self` was zero. + let non_zero_self = NonZeroAffineVar::new(x, y); + + let mut bits = bits.collect::>(); + if bits.is_empty() { + return Ok(Self::zero()); + } + // Remove unnecessary constant zeros in the most-significant positions. + bits = bits + .into_iter() + // We iterate from the MSB down. + .rev() + // Skip leading zeros, if they are constants. + .skip_while(|b| b.is_constant() && (!b.value().unwrap())) + .collect(); + // After collecting we are in big-endian form; we have to reverse to get back to + // little-endian. + bits.reverse(); + + let scalar_modulus_bits = ::MODULUS_BIT_SIZE; + let mut mul_result = Self::zero(); + let mut power_of_two_times_self = non_zero_self; + // We chunk up `bits` into `p`-sized chunks. + for bits in bits.chunks(scalar_modulus_bits as usize) { + self.fixed_scalar_mul_le(&mut mul_result, &mut power_of_two_times_self, bits)?; + } + + // The foregoing algorithm relies on incomplete addition, and so does not + // work when the input (`self`) is zero. We hence have to perform + // a check to ensure that if the input is zero, then so is the output. + // The cost of this check should be less than the benefit of using + // mixed addition in almost all cases. + infinity.select(&Self::zero(), &mul_result) + } + + #[tracing::instrument(target = "r1cs", skip(scalar_bits_with_bases))] + fn precomputed_base_scalar_mul_le<'a, I, B>( + &mut self, + scalar_bits_with_bases: I, + ) -> Result<(), SynthesisError> + where + I: Iterator)>, + B: Borrow>, + { + // We just ignore the provided bases and use the faster scalar multiplication. + let (bits, bases): (Vec<_>, Vec<_>) = scalar_bits_with_bases + .map(|(b, c)| (b.borrow().clone(), *c)) + .unzip(); + let base = bases[0]; + *self = Self::constant(base).scalar_mul_le(bits.iter())?; + Ok(()) + } +} + +impl ToConstraintFieldGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, + F: ToConstraintFieldGadget, +{ + fn to_constraint_field(&self) -> Result>, SynthesisError> { + self.to_affine()?.to_constraint_field() + } +} + +fn mul_by_coeff_a< + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +>( + f: &F, +) -> F +where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + if !P::COEFF_A.is_zero() { + f * P::COEFF_A + } else { + F::zero() + } +} + +impl_bounded_ops!( + ProjectiveVar, + SWProjective
<P>
, + Add, + add, + AddAssign, + add_assign, + |mut this: &'a ProjectiveVar, mut other: &'a ProjectiveVar| { + // Implement complete addition for Short Weierstrass curves, following + // the complete addition formula from Renes-Costello-Batina 2015 + // (https://eprint.iacr.org/2015/1060). + // + // We special case handling of constants to get better constraint weight. + if this.is_constant() { + // we'll just act like `other` is constant. + core::mem::swap(&mut this, &mut other); + } + + if other.is_constant() { + // The value should exist because `other` is a constant. + let other = other.value().unwrap(); + if other.is_zero() { + // this + 0 = this + this.clone() + } else { + // We'll use mixed addition to add non-zero constants. + let x = F::constant(other.x); + let y = F::constant(other.y); + this.add_mixed(&NonZeroAffineVar::new(x, y)).unwrap() + } + } else { + // Complete addition formula from Renes-Costello-Batina 2015 + // Algorithm 1 + // (https://eprint.iacr.org/2015/1060). + // Below, comments at the end of a line denote the corresponding + // step(s) of the algorithm + // + // Adapted from code in + // https://github.com/RustCrypto/elliptic-curves/blob/master/p256/src/arithmetic/projective.rs + let three_b = P::COEFF_B.double() + &P::COEFF_B; + let (x1, y1, z1) = (&this.x, &this.y, &this.z); + let (x2, y2, z2) = (&other.x, &other.y, &other.z); + + let xx = x1 * x2; // 1 + let yy = y1 * y2; // 2 + let zz = z1 * z2; // 3 + let xy_pairs = ((x1 + y1) * &(x2 + y2)) - (&xx + &yy); // 4, 5, 6, 7, 8 + let xz_pairs = ((x1 + z1) * &(x2 + z2)) - (&xx + &zz); // 9, 10, 11, 12, 13 + let yz_pairs = ((y1 + z1) * &(y2 + z2)) - (&yy + &zz); // 14, 15, 16, 17, 18 + + let axz = mul_by_coeff_a::(&xz_pairs); // 19 + + let bzz3_part = &axz + &zz * three_b; // 20, 21 + + let yy_m_bzz3 = &yy - &bzz3_part; // 22 + let yy_p_bzz3 = &yy + &bzz3_part; // 23 + + let azz = mul_by_coeff_a::(&zz); + let xx3_p_azz = xx.double().unwrap() + &xx + &azz; // 25, 26, 27, 29 + + let bxz3 = &xz_pairs * three_b; // 28 + let b3_xz_pairs = mul_by_coeff_a::(&(&xx - &azz)) + &bxz3; // 30, 31, 32 + + let x = (&yy_m_bzz3 * &xy_pairs) - &yz_pairs * &b3_xz_pairs; // 35, 39, 40 + let y = (&yy_p_bzz3 * &yy_m_bzz3) + &xx3_p_azz * b3_xz_pairs; // 24, 36, 37, 38 + let z = (&yy_p_bzz3 * &yz_pairs) + xy_pairs * xx3_p_azz; // 41, 42, 43 + + ProjectiveVar::new(x, y, z) + } + + }, + |this: &'a ProjectiveVar, other: SWProjective
<P>
| { + this + ProjectiveVar::constant(other) + }, + (ConstraintF: PrimeField, F: FieldVar, P: SWModelParameters), + for <'b> &'b F: FieldOpsBounds<'b, P::BaseField, F>, +); + +impl_bounded_ops!( + ProjectiveVar, + SWProjective
<P>
, + Sub, + sub, + SubAssign, + sub_assign, + |this: &'a ProjectiveVar, other: &'a ProjectiveVar| this + other.negate().unwrap(), + |this: &'a ProjectiveVar, other: SWProjective
<P>
| this - ProjectiveVar::constant(other), + (ConstraintF: PrimeField, F: FieldVar, P: SWModelParameters), + for <'b> &'b F: FieldOpsBounds<'b, P::BaseField, F> +); + +impl<'a, P, ConstraintF, F> GroupOpsBounds<'a, SWProjective
<P>
, ProjectiveVar> + for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'b> &'b F: FieldOpsBounds<'b, P::BaseField, F>, +{ +} + +impl<'a, P, ConstraintF, F> GroupOpsBounds<'a, SWProjective
<P>
, ProjectiveVar> + for &'a ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'b> &'b F: FieldOpsBounds<'b, P::BaseField, F>, +{ +} + +impl CondSelectGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let x = cond.select(&true_value.x, &false_value.x)?; + let y = cond.select(&true_value.y, &false_value.y)?; + let z = cond.select(&true_value.z, &false_value.z)?; + + Ok(Self::new(x, y, z)) + } +} + +impl EqGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let x_equal = (&self.x * &other.z).is_eq(&(&other.x * &self.z))?; + let y_equal = (&self.y * &other.z).is_eq(&(&other.y * &self.z))?; + let coordinates_equal = x_equal.and(&y_equal)?; + let both_are_zero = self.is_zero()?.and(&other.is_zero()?)?; + both_are_zero.or(&coordinates_equal) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let x_equal = (&self.x * &other.z).is_eq(&(&other.x * &self.z))?; + let y_equal = (&self.y * &other.z).is_eq(&(&other.y * &self.z))?; + let coordinates_equal = x_equal.and(&y_equal)?; + let both_are_zero = self.is_zero()?.and(&other.is_zero()?)?; + both_are_zero + .or(&coordinates_equal)? + .conditional_enforce_equal(&Boolean::Constant(true), condition)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let is_equal = self.is_eq(other)?; + is_equal + .and(condition)? + .enforce_equal(&Boolean::Constant(false)) + } +} + +impl AllocVar, ConstraintF> for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + Self::new_variable(cs, || f().map(|b| SWProjective::from(*b.borrow())), mode) + } +} + +impl AllocVar, ConstraintF> for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let f = || Ok(*f()?.borrow()); + match mode { + AllocationMode::Constant => Self::new_variable_omit_prime_order_check(cs, f, mode), + AllocationMode::Input => Self::new_variable_omit_prime_order_check(cs, f, mode), + AllocationMode::Witness => { + // if cofactor.is_even(): + // divide until you've removed all even factors + // else: + // just directly use double and add. 
+ let mut power_of_2: u32 = 0; + let mut cofactor = P::COFACTOR.to_vec(); + while cofactor[0] % 2 == 0 { + div2(&mut cofactor); + power_of_2 += 1; + } + + let cofactor_weight = BitIteratorBE::new(cofactor.as_slice()) + .filter(|b| *b) + .count(); + let modulus_minus_1 = (-P::ScalarField::one()).into_bigint(); // r - 1 + let modulus_minus_1_weight = + BitIteratorBE::new(modulus_minus_1).filter(|b| *b).count(); + + // We pick the most efficient method of performing the prime order check: + // If the cofactor has lower hamming weight than the scalar field's modulus, + // we first multiply by the inverse of the cofactor, and then, after allocating, + // multiply by the cofactor. This ensures the resulting point has no cofactors + // + // Else, we multiply by the scalar field's modulus and ensure that the result + // equals the identity. + + let (mut ge, iter) = if cofactor_weight < modulus_minus_1_weight { + let ge = Self::new_variable_omit_prime_order_check( + ark_relations::ns!(cs, "Witness without subgroup check with cofactor mul"), + || f().map(|g| g.into_affine().mul_by_cofactor_inv().into()), + mode, + )?; + ( + ge, + BitIteratorBE::without_leading_zeros(cofactor.as_slice()), + ) + } else { + let ge = Self::new_variable_omit_prime_order_check( + ark_relations::ns!(cs, "Witness without subgroup check with `r` check"), + || { + f().map(|g| { + let g = g.into_affine(); + let mut power_of_two = P::ScalarField::one().into_bigint(); + power_of_two.muln(power_of_2); + let power_of_two_inv = P::ScalarField::from_bigint(power_of_two) + .and_then(|n| n.inverse()) + .unwrap(); + g.mul(power_of_two_inv) + }) + }, + mode, + )?; + + ( + ge, + BitIteratorBE::without_leading_zeros(modulus_minus_1.as_ref()), + ) + }; + // Remove the even part of the cofactor + for _ in 0..power_of_2 { + ge.double_in_place()?; + } + + let mut result = Self::zero(); + for b in iter { + result.double_in_place()?; + + if b { + result += &ge + } + } + if cofactor_weight < modulus_minus_1_weight { + Ok(result) + } else { + ge.enforce_equal(&ge)?; + Ok(ge) + } + } + } + } +} + +#[inline] +fn div2(limbs: &mut [u64]) { + let mut t = 0; + for i in limbs.iter_mut().rev() { + let t2 = *i << 63; + *i >>= 1; + *i |= t; + t = t2; + } +} + +impl ToBitsGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bits_le(&self) -> Result>, SynthesisError> { + let g = self.to_affine()?; + let mut bits = g.x.to_bits_le()?; + let y_bits = g.y.to_bits_le()?; + bits.extend_from_slice(&y_bits); + bits.push(g.infinity); + Ok(bits) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { + let g = self.to_affine()?; + let mut bits = g.x.to_non_unique_bits_le()?; + let y_bits = g.y.to_non_unique_bits_le()?; + bits.extend_from_slice(&y_bits); + bits.push(g.infinity); + Ok(bits) + } +} + +impl ToBytesGadget for ProjectiveVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let g = self.to_affine()?; + let mut bytes = g.x.to_bytes()?; + let y_bytes = g.y.to_bytes()?; + let inf_bytes = g.infinity.to_bytes()?; + bytes.extend_from_slice(&y_bytes); + bytes.extend_from_slice(&inf_bytes); + Ok(bytes) + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bytes(&self) -> 
Result>, SynthesisError> { + let g = self.to_affine()?; + let mut bytes = g.x.to_non_unique_bytes()?; + let y_bytes = g.y.to_non_unique_bytes()?; + let inf_bytes = g.infinity.to_non_unique_bytes()?; + bytes.extend_from_slice(&y_bytes); + bytes.extend_from_slice(&inf_bytes); + Ok(bytes) + } +} diff --git a/jolt-core/src/circuits/groups/curves/short_weierstrass/non_zero_affine.rs b/jolt-core/src/circuits/groups/curves/short_weierstrass/non_zero_affine.rs new file mode 100644 index 000000000..120760fb2 --- /dev/null +++ b/jolt-core/src/circuits/groups/curves/short_weierstrass/non_zero_affine.rs @@ -0,0 +1,399 @@ +use super::*; +use ark_ec::Group; +use ark_std::ops::Add; +use derivative::Derivative; + +/// An affine representation of a prime order curve point that is guaranteed +/// to *not* be the point at infinity. +#[derive(Derivative)] +#[derivative(Debug, Clone)] +#[must_use] +pub struct NonZeroAffineVar< + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, +> where + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + /// The x-coordinate. + pub x: F, + /// The y-coordinate. + pub y: F, + #[derivative(Debug = "ignore")] + _params: PhantomData
<P>
, + #[derivative(Debug = "ignore")] + _constraint_f: PhantomData, +} + +impl NonZeroAffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + pub fn new(x: F, y: F) -> Self { + Self { + x, + y, + _params: PhantomData, + _constraint_f: PhantomData, + } + } + + /// Converts self into a non-zero projective point. + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn into_projective(&self) -> ProjectiveVar { + ProjectiveVar::new(self.x.clone(), self.y.clone(), F::one()) + } + + /// Performs an addition without checking that other != ±self. + #[tracing::instrument(target = "r1cs", skip(self, other))] + pub fn add_unchecked(&self, other: &Self) -> Result { + if [self, other].is_constant() { + let result = self.value()?.add(other.value()?).into_affine(); + Ok(Self::new(F::constant(result.x), F::constant(result.y))) + } else { + let (x1, y1) = (&self.x, &self.y); + let (x2, y2) = (&other.x, &other.y); + // Then, + // slope lambda := (y2 - y1)/(x2 - x1); + // x3 = lambda^2 - x1 - x2; + // y3 = lambda * (x1 - x3) - y1 + let numerator = y2 - y1; + let denominator = x2 - x1; + // It's okay to use `unchecked` here, because the precondition of + // `add_unchecked` is that self != ±other, which means that + // `numerator` and `denominator` are both non-zero. + let lambda = numerator.mul_by_inverse_unchecked(&denominator)?; + let x3 = lambda.square()? - x1 - x2; + let y3 = lambda * &(x1 - &x3) - y1; + Ok(Self::new(x3, y3)) + } + } + + /// Doubles `self`. As this is a prime order curve point, + /// the output is guaranteed to not be the point at infinity. + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn double(&self) -> Result { + if [self].is_constant() { + let result = SWProjective::
<P>
::from(self.value()?) + .double() + .into_affine(); + // Panic if the result is zero. + assert!(!result.is_zero()); + Ok(Self::new(F::constant(result.x), F::constant(result.y))) + } else { + let (x1, y1) = (&self.x, &self.y); + let x1_sqr = x1.square()?; + // Then, + // tangent lambda := (3 * x1^2 + a) / (2 * y1); + // x3 = lambda^2 - 2x1 + // y3 = lambda * (x1 - x3) - y1 + let numerator = x1_sqr.double()? + &x1_sqr + P::COEFF_A; + let denominator = y1.double()?; + // It's okay to use `unchecked` here, because the precondition of `double` is + // that self != zero. + let lambda = numerator.mul_by_inverse_unchecked(&denominator)?; + let x3 = lambda.square()? - x1.double()?; + let y3 = lambda * &(x1 - &x3) - y1; + Ok(Self::new(x3, y3)) + } + } + + /// Computes `(self + other) + self`. This method requires only 5 + /// constraints, less than the 7 required when computing via + /// `self.double() + other`. + /// + /// This follows the formulae from [\[ELM03\]](https://arxiv.org/abs/math/0208038). + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn double_and_add_unchecked(&self, other: &Self) -> Result { + if [self].is_constant() || other.is_constant() { + self.double()?.add_unchecked(other) + } else { + // It's okay to use `unchecked` the precondition is that `self != ±other` (i.e. + // same logic as in `add_unchecked`) + let (x1, y1) = (&self.x, &self.y); + let (x2, y2) = (&other.x, &other.y); + + // Calculate self + other: + // slope lambda := (y2 - y1)/(x2 - x1); + // x3 = lambda^2 - x1 - x2; + // y3 = lambda * (x1 - x3) - y1 + let numerator = y2 - y1; + let denominator = x2 - x1; + let lambda_1 = numerator.mul_by_inverse_unchecked(&denominator)?; + + let x3 = lambda_1.square()? - x1 - x2; + + // Calculate final addition slope: + let lambda_2 = + (lambda_1 + y1.double()?.mul_by_inverse_unchecked(&(&x3 - x1))?).negate()?; + + let x4 = lambda_2.square()? - x1 - x3; + let y4 = lambda_2 * &(x1 - &x4) - y1; + Ok(Self::new(x4, y4)) + } + } + + /// Doubles `self` in place. + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn double_in_place(&mut self) -> Result<(), SynthesisError> { + *self = self.double()?; + Ok(()) + } +} + +impl R1CSVar for NonZeroAffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + type Value = SWAffine
<P>
; + + fn cs(&self) -> ConstraintSystemRef { + self.x.cs().or(self.y.cs()) + } + + fn value(&self) -> Result, SynthesisError> { + Ok(SWAffine::new(self.x.value()?, self.y.value()?)) + } +} + +impl CondSelectGadget for NonZeroAffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let x = cond.select(&true_value.x, &false_value.x)?; + let y = cond.select(&true_value.y, &false_value.y)?; + + Ok(Self::new(x, y)) + } +} + +impl EqGadget for NonZeroAffineVar +where + P: SWModelParameters, + ConstraintF: PrimeField, + F: FieldVar, + for<'a> &'a F: FieldOpsBounds<'a, P::BaseField, F>, +{ + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let x_equal = self.x.is_eq(&other.x)?; + let y_equal = self.y.is_eq(&other.y)?; + x_equal.and(&y_equal) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let x_equal = self.x.is_eq(&other.x)?; + let y_equal = self.y.is_eq(&other.y)?; + let coordinates_equal = x_equal.and(&y_equal)?; + coordinates_equal.conditional_enforce_equal(&Boolean::Constant(true), condition)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn enforce_equal(&self, other: &Self) -> Result<(), SynthesisError> { + self.x.enforce_equal(&other.x)?; + self.y.enforce_equal(&other.y)?; + Ok(()) + } + + #[inline] + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let is_equal = self.is_eq(other)?; + is_equal + .and(condition)? + .enforce_equal(&Boolean::Constant(false)) + } +} + +#[cfg(test)] +mod test_non_zero_affine { + use ark_ec::{models::short_weierstrass::SWCurveConfig, CurveGroup}; + use ark_r1cs_std::{ + alloc::AllocVar, + eq::EqGadget, + fields::fp::{AllocatedFp, FpVar}, + groups::{ + curves::short_weierstrass::{non_zero_affine::NonZeroAffineVar, ProjectiveVar}, + CurveVar, + }, + R1CSVar, + }; + use ark_relations::r1cs::ConstraintSystem; + use ark_std::{vec::Vec, One}; + use ark_test_curves::bls12_381::{g1::Config as G1Config, Fq}; + + #[test] + fn correctness_test_1() { + let cs = ConstraintSystem::::new_ref(); + + let x = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.x)).unwrap(), + ); + let y = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.y)).unwrap(), + ); + + // The following code uses `double` and `add` (`add_unchecked`) to compute + // (1 + 2 + ... 
+ 2^9) G + + let sum_a = { + let mut a = ProjectiveVar::>::new( + x.clone(), + y.clone(), + FpVar::Constant(Fq::one()), + ); + + let mut double_sequence = Vec::new(); + double_sequence.push(a.clone()); + + for _ in 1..10 { + a = a.double().unwrap(); + double_sequence.push(a.clone()); + } + + let mut sum = double_sequence[0].clone(); + for elem in double_sequence.iter().skip(1) { + sum = sum + elem; + } + + let sum = sum.value().unwrap().into_affine(); + (sum.x, sum.y) + }; + + let sum_b = { + let mut a = NonZeroAffineVar::>::new(x, y); + + let mut double_sequence = Vec::new(); + double_sequence.push(a.clone()); + + for _ in 1..10 { + a = a.double().unwrap(); + double_sequence.push(a.clone()); + } + + let mut sum = double_sequence[0].clone(); + for elem in double_sequence.iter().skip(1) { + sum = sum.add_unchecked(&elem).unwrap(); + } + + (sum.x.value().unwrap(), sum.y.value().unwrap()) + }; + + assert_eq!(sum_a.0, sum_b.0); + assert_eq!(sum_a.1, sum_b.1); + } + + #[test] + fn correctness_test_2() { + let cs = ConstraintSystem::::new_ref(); + + let x = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.x)).unwrap(), + ); + let y = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.y)).unwrap(), + ); + + // The following code tests `double_and_add`. + let sum_a = { + let a = ProjectiveVar::>::new( + x.clone(), + y.clone(), + FpVar::Constant(Fq::one()), + ); + + let mut cur = a.clone(); + cur.double_in_place().unwrap(); + for _ in 1..10 { + cur.double_in_place().unwrap(); + cur = cur + &a; + } + + let sum = cur.value().unwrap().into_affine(); + (sum.x, sum.y) + }; + + let sum_b = { + let a = NonZeroAffineVar::>::new(x, y); + + let mut cur = a.double().unwrap(); + for _ in 1..10 { + cur = cur.double_and_add_unchecked(&a).unwrap(); + } + + (cur.x.value().unwrap(), cur.y.value().unwrap()) + }; + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(sum_a.0, sum_b.0); + assert_eq!(sum_a.1, sum_b.1); + } + + #[test] + fn correctness_test_eq() { + let cs = ConstraintSystem::::new_ref(); + + let x = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.x)).unwrap(), + ); + let y = FpVar::Var( + AllocatedFp::::new_witness(cs.clone(), || Ok(G1Config::GENERATOR.y)).unwrap(), + ); + + let a = NonZeroAffineVar::>::new(x, y); + + let n = 10; + + let a_multiples: Vec>> = + std::iter::successors(Some(a.clone()), |acc| Some(acc.add_unchecked(&a).unwrap())) + .take(n) + .collect(); + + let all_equal: Vec>> = (0..n / 2) + .map(|i| { + a_multiples[i] + .add_unchecked(&a_multiples[n - i - 1]) + .unwrap() + }) + .collect(); + + for i in 0..n - 1 { + a_multiples[i] + .enforce_not_equal(&a_multiples[i + 1]) + .unwrap(); + } + for i in 0..all_equal.len() - 1 { + all_equal[i].enforce_equal(&all_equal[i + 1]).unwrap(); + } + + assert!(cs.is_satisfied().unwrap()); + } +} diff --git a/jolt-core/src/circuits/groups/mod.rs b/jolt-core/src/circuits/groups/mod.rs new file mode 100644 index 000000000..26b097205 --- /dev/null +++ b/jolt-core/src/circuits/groups/mod.rs @@ -0,0 +1 @@ +pub mod curves; diff --git a/jolt-core/src/circuits/mod.rs b/jolt-core/src/circuits/mod.rs new file mode 100644 index 000000000..fd74b94db --- /dev/null +++ b/jolt-core/src/circuits/mod.rs @@ -0,0 +1,5 @@ +pub mod fields; +pub mod groups; +pub mod offloaded; +pub mod poly; +pub mod transcript; diff --git a/jolt-core/src/circuits/offloaded/mod.rs b/jolt-core/src/circuits/offloaded/mod.rs new file mode 100644 index 000000000..c44647897 --- /dev/null +++ 
b/jolt-core/src/circuits/offloaded/mod.rs @@ -0,0 +1,241 @@ +use crate::snark::{DeferredOpData, OffloadedDataCircuit}; +use ark_ec::{pairing::Pairing, CurveGroup, VariableBaseMSM}; +use ark_r1cs_std::{ + alloc::AllocVar, eq::EqGadget, fields::fp::FpVar, fields::FieldVar, groups::CurveVar, R1CSVar, + ToConstraintFieldGadget, +}; +use ark_relations::{ + ns, + r1cs::{Namespace, SynthesisError}, +}; +use ark_std::{One, Zero}; +use std::marker::PhantomData; + +pub trait MSMGadget +where + FVar: FieldVar + ToConstraintFieldGadget, + G: CurveGroup, + GVar: CurveVar + ToConstraintFieldGadget, +{ + fn msm( + &self, + cs: impl Into>, + g1s: &[GVar], + scalars: &[FVar], + ) -> Result; +} + +pub struct OffloadedMSMGadget<'a, FVar, E, GVar, Circuit> +where + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + E: Pairing, + GVar: CurveVar + ToConstraintFieldGadget, +{ + _params: PhantomData<(FVar, E, GVar)>, + circuit: &'a Circuit, +} + +impl<'a, FVar, E, GVar, Circuit> OffloadedMSMGadget<'a, FVar, E, GVar, Circuit> +where + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + E: Pairing, + GVar: CurveVar + ToConstraintFieldGadget, +{ + pub fn new(circuit: &'a Circuit) -> Self { + Self { + _params: PhantomData, + circuit, + } + } +} + +impl<'a, FVar, E, GVar, Circuit> MSMGadget + for OffloadedMSMGadget<'a, FVar, E, GVar, Circuit> +where + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + E: Pairing, + GVar: CurveVar + ToConstraintFieldGadget, +{ + fn msm( + &self, + cs: impl Into>, + g1s: &[GVar], + scalars: &[FVar], + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + + let g1_values = g1s + .iter() + .map(|g1| g1.value().ok().map(|g1| g1.into_affine())) + .collect::>>(); + + let scalar_values = scalars + .iter() + .map(|s| s.value().ok()) + .collect::>>(); + + let (full_msm_value, msm_g1_value) = g1_values + .zip(scalar_values) + .map(|(g1s, scalars)| { + let r_g1 = E::G1::msm_unchecked(&g1s, &scalars); + let minus_one = -E::ScalarField::one(); + ( + ( + [g1s, vec![r_g1.into()]].concat(), + [scalars, vec![minus_one]].concat(), + ), + r_g1, + ) + }) + .unzip(); + + let msm_g1_var = GVar::new_witness(ns!(cs, "msm_g1"), || { + msm_g1_value.ok_or(SynthesisError::AssignmentMissing) + })?; + + { + let g1s = g1s.to_vec(); + let scalars = scalars.to_vec(); + let msm_g1_var = msm_g1_var.clone(); + let ns = ns!(cs, "deferred_msm"); + let cs = ns.cs(); + + self.circuit.defer_op(move || { + // write scalars to public_input + for x in scalars { + let scalar_input = FVar::new_input(ns!(cs, "scalar"), || x.value())?; + scalar_input.enforce_equal(&x)?; + } + + // write g1s to public_input + for g1 in g1s { + let f_vec = g1.to_constraint_field()?; + + for f in f_vec.iter() { + let f_input = FpVar::new_input(ns!(cs, "g1s"), || f.value())?; + f_input.enforce_equal(f)?; + } + } + + // write msm_g1 to public_input + { + dbg!(cs.num_instance_variables() - 1); + let f_vec = msm_g1_var.to_constraint_field()?; + + for f in f_vec.iter() { + let f_input = FpVar::new_input(ns!(cs, "msm_g1"), || f.value())?; + f_input.enforce_equal(f)?; + } + } + dbg!(cs.num_constraints()); + dbg!(cs.num_instance_variables()); + + Ok(DeferredOpData::MSM(full_msm_value)) + }) + }; + dbg!(cs.num_constraints()); + + Ok(msm_g1_var) + } +} + +pub trait PairingGadget +where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, +{ + fn multi_pairing_is_zero( + &self, + cs: impl Into>, + g1s: &[G1Var], + g2s: &[E::G2Affine], + ) -> Result<(), SynthesisError>; +} + 
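+// A minimal, native sketch (not part of the gadget API) of the MSM convention used by
+// `OffloadedMSMGadget` and documented on `ProofData::msms`: the in-circuit MSM result `r`
+// is appended to the bases with scalar -1, so the full MSM re-checked by the verifier must
+// sum to zero. The bn254 curve choice and this test scaffolding are illustrative
+// assumptions only, not something this module requires.
+#[cfg(test)]
+mod offloaded_msm_convention_sketch {
+    use ark_bn254::{Fr, G1Affine, G1Projective};
+    use ark_ec::{CurveGroup, VariableBaseMSM};
+    use ark_std::{test_rng, UniformRand, Zero};
+
+    #[test]
+    fn appended_minus_one_makes_msm_zero() {
+        let mut rng = test_rng();
+        let bases: Vec<G1Affine> = (0..4)
+            .map(|_| G1Projective::rand(&mut rng).into_affine())
+            .collect();
+        let scalars: Vec<Fr> = (0..4).map(|_| Fr::rand(&mut rng)).collect();
+
+        // The value witnessed in-circuit: r = Σᵢ sᵢ·Gᵢ.
+        let r = G1Projective::msm_unchecked(&bases, &scalars);
+
+        // The deferred data handed to the verifier: (bases ++ [r], scalars ++ [-1]).
+        let full_bases = [bases, vec![r.into_affine()]].concat();
+        let full_scalars = [scalars, vec![-Fr::from(1u64)]].concat();
+
+        // The verifier's check: the full MSM must be the identity element.
+        assert!(G1Projective::msm_unchecked(&full_bases, &full_scalars).is_zero());
+    }
+}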
+pub struct OffloadedPairingGadget<'a, E, FVar, GVar, Circuit> +where + E: Pairing, + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + GVar: CurveVar + ToConstraintFieldGadget, +{ + _params: PhantomData<(E, FVar, GVar)>, + circuit: &'a Circuit, +} + +impl<'a, E, FVar, GVar, Circuit> OffloadedPairingGadget<'a, E, FVar, GVar, Circuit> +where + E: Pairing, + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + GVar: CurveVar + ToConstraintFieldGadget, +{ + pub(crate) fn new(circuit: &'a Circuit) -> Self { + Self { + _params: PhantomData, + circuit, + } + } +} + +impl<'a, E, FVar, GVar, Circuit> PairingGadget + for OffloadedPairingGadget<'a, E, FVar, GVar, Circuit> +where + E: Pairing, + Circuit: OffloadedDataCircuit, + FVar: FieldVar + ToConstraintFieldGadget, + GVar: CurveVar + ToConstraintFieldGadget, +{ + fn multi_pairing_is_zero( + &self, + cs: impl Into>, + g1s: &[GVar], + g2s: &[E::G2Affine], + ) -> Result<(), SynthesisError> { + let ns = cs.into(); + let cs = ns.cs(); + + let g1_values_opt = g1s + .iter() + .map(|g1| g1.value().ok().map(|g1| g1.into_affine())) + .collect::>>(); + + let g2_values = g2s; + + for g1_values in g1_values_opt.iter() { + if !E::multi_pairing(g1_values, g2_values).is_zero() { + return Err(SynthesisError::Unsatisfiable); + } + } + + { + let g2_values = g2_values.to_vec(); + let g1s = g1s.to_vec(); + let ns = ns!(cs, "deferred_pairing"); + let cs = ns.cs(); + + self.circuit.defer_op(move || { + // write g1s to public_input + for g1 in g1s { + let f_vec = g1.to_constraint_field()?; + + for f in f_vec.iter() { + let f_input = FpVar::new_input(ns!(cs, "g1s"), || f.value())?; + f_input.enforce_equal(f)?; + } + } + + dbg!(cs.num_constraints()); + dbg!(cs.num_instance_variables()); + + Ok(DeferredOpData::Pairing(g1_values_opt, g2_values)) + }) + } + + Ok(()) + } +} diff --git a/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs new file mode 100644 index 000000000..e0a3df7c9 --- /dev/null +++ b/jolt-core/src/circuits/poly/commitment/commitment_scheme.rs @@ -0,0 +1,27 @@ +use crate::poly::commitment::commitment_scheme::CommitmentScheme; +use ark_crypto_primitives::sponge::constraints::SpongeWithGadget; +use ark_ff::PrimeField; +use ark_r1cs_std::fields::fp::FpVar; +use ark_r1cs_std::prelude::*; +use ark_relations::r1cs::SynthesisError; + +pub trait CommitmentVerifierGadget +where + ConstraintF: PrimeField, + CS: CommitmentScheme, + S: SpongeWithGadget, +{ + type VerifyingKeyVar: AllocVar + Clone; + type ProofVar: AllocVar + Clone; + type CommitmentVar: AllocVar + Clone; + + fn verify( + &self, + proof: &Self::ProofVar, + vk: &Self::VerifyingKeyVar, + transcript: &mut S::Var, + opening_point: &[FpVar], + opening: &FpVar, + commitment: &Self::CommitmentVar, + ) -> Result, SynthesisError>; +} diff --git a/jolt-core/src/circuits/poly/commitment/hyperkzg.rs b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs new file mode 100644 index 000000000..924c0ca4e --- /dev/null +++ b/jolt-core/src/circuits/poly/commitment/hyperkzg.rs @@ -0,0 +1,583 @@ +use crate::circuits::transcript::ImplAbsorbGVar; +use crate::{ + circuits::{ + offloaded::{MSMGadget, OffloadedMSMGadget, OffloadedPairingGadget, PairingGadget}, + poly::commitment::commitment_scheme::CommitmentVerifierGadget, + transcript::ImplAbsorbFVar, + }, + field::JoltField, + poly::commitment::hyperkzg::{ + HyperKZG, HyperKZGCommitment, HyperKZGProof, HyperKZGProverKey, HyperKZGVerifierKey, + }, + 
snark::OffloadedDataCircuit, +}; +use ark_crypto_primitives::sponge::constraints::{CryptographicSpongeVar, SpongeWithGadget}; +use ark_ec::pairing::Pairing; +use ark_ff::PrimeField; +use ark_r1cs_std::{boolean::Boolean, fields::fp::FpVar, prelude::*, ToConstraintFieldGadget}; +use ark_relations::r1cs::ConstraintSystemRef; +use ark_relations::{ + ns, + r1cs::{Namespace, SynthesisError}, +}; +use ark_std::{borrow::Borrow, iterable::Iterable, marker::PhantomData, One}; + +#[derive(Clone)] +pub struct HyperKZGProofVar +where + E: Pairing, +{ + pub com: Vec, + pub w: Vec, + pub v: Vec>>, +} + +impl AllocVar, E::ScalarField> for HyperKZGProofVar +where + E: Pairing, + G1Var: CurveVar, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + + let proof_hold = f()?; + let proof = proof_hold.borrow(); + + let com = proof + .com + .iter() + .map(|&x| G1Var::new_variable(ns!(cs, "com").clone(), || Ok(x), mode)) + .collect::, _>>()?; + let w = proof + .w + .iter() + .map(|&x| G1Var::new_variable(ns!(cs, "w").clone(), || Ok(x), mode)) + .collect::, _>>()?; + let v = proof + .v + .iter() + .map(|v_i| { + v_i.iter() + .map(|&v_ij| FpVar::new_variable(ns!(cs, "v_ij"), || Ok(v_ij), mode)) + .collect::, _>>() + }) + .collect::, _>>()?; + + Ok(Self { com, w, v }) + } +} + +#[derive(Clone, Debug)] +pub struct HyperKZGCommitmentVar { + pub c: G1Var, +} + +impl AllocVar, E::ScalarField> for HyperKZGCommitmentVar +where + E: Pairing, + G1Var: CurveVar, +{ + fn new_variable>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + Ok(Self { + c: G1Var::new_variable(cs, || Ok(f()?.borrow().0), mode)?, + }) + } +} + +#[derive(Clone, Debug)] +pub struct HyperKZGVerifierKeyVar { + pub g1: G1Var, + // pub g2: G2Var, + // pub beta_g2: G2Var, +} + +impl AllocVar<(HyperKZGProverKey, HyperKZGVerifierKey), E::ScalarField> + for HyperKZGVerifierKeyVar +where + E: Pairing, + G1Var: CurveVar, +{ + fn new_variable, HyperKZGVerifierKey)>>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + // TODO implement + Ok(Self { + g1: G1Var::new_variable(cs, || Ok(f()?.borrow().1.kzg_vk.g1), mode)?, + }) + } +} + +pub struct HyperKZGVerifierGadget<'a, E, S, G1Var, Circuit> +where + E: Pairing, + S: SpongeWithGadget, + G1Var: CurveVar + ToConstraintFieldGadget, + Circuit: OffloadedDataCircuit, +{ + _params: PhantomData<(E, S, G1Var)>, + circuit: &'a Circuit, + cs: ConstraintSystemRef, + g2_elements: Vec, +} + +impl<'a, E, S, G1Var, Circuit> HyperKZGVerifierGadget<'a, E, S, G1Var, Circuit> +where + E: Pairing, + S: SpongeWithGadget, + G1Var: CurveVar + ToConstraintFieldGadget, + Circuit: OffloadedDataCircuit, +{ + pub fn new( + circuit: &'a Circuit, + cs: impl Into>, + g2_elements: Vec, + ) -> Self { + let ns = cs.into(); + let cs: ConstraintSystemRef = ns.cs(); + Self { + _params: PhantomData, + circuit, + cs, + g2_elements, + } + } +} + +impl<'a, E, S, F, G1Var, Circuit> CommitmentVerifierGadget, S> + for HyperKZGVerifierGadget<'a, E, S, G1Var, Circuit> +where + F: PrimeField + JoltField, + E: Pairing, + S: SpongeWithGadget, + G1Var: CurveVar + ToConstraintFieldGadget, + Circuit: OffloadedDataCircuit, +{ + type VerifyingKeyVar = HyperKZGVerifierKeyVar; + type ProofVar = HyperKZGProofVar; + type CommitmentVar = HyperKZGCommitmentVar; + + fn verify( + &self, + proof: &Self::ProofVar, + vk: &Self::VerifyingKeyVar, + transcript: &mut S::Var, + opening_point: &[FpVar], + 
opening: &FpVar, + commitment: &Self::CommitmentVar, + ) -> Result, SynthesisError> { + let ell = opening_point.len(); + assert!(ell >= 2); + + let HyperKZGProofVar { com, w, v } = proof; + let HyperKZGCommitmentVar { c } = commitment; + let HyperKZGVerifierKeyVar { g1 } = vk; + + transcript.absorb( + &com.iter() + .map(|com| ImplAbsorbGVar::wrap(com)) + .collect::>(), + )?; + + let r = transcript + .squeeze_field_elements(1)? + .into_iter() + .next() + .unwrap(); + + let u = [r.clone(), r.negate()?, r.clone() * &r]; + + let com = [vec![c.clone()], com.clone()].concat(); + + if v.len() != 3 { + return Err(SynthesisError::Unsatisfiable); + } + if w.len() != 3 { + return Err(SynthesisError::Unsatisfiable); + } + if ell != v[0].len() || ell != v[1].len() || ell != v[2].len() || ell != com.len() { + return Err(SynthesisError::Unsatisfiable); + } + + let x = opening_point; + let y = [v[2].clone(), vec![opening.clone()]].concat(); + + let one = FpVar::one(); + let two = FpVar::Constant(F::from(2u128)); + for i in 0..ell { + (&two * &r * &y[i + 1]).enforce_equal( + &(&r * (&one - &x[ell - i - 1]) * (&v[0][i] + &v[1][i]) + + &x[ell - i - 1] * (&v[0][i] - &v[1][i])), + )?; + } + + // kzg_verify_batch + + transcript.absorb( + &v.iter() + .flatten() + .map(|v_ij| ImplAbsorbFVar::wrap(v_ij)) + .collect::>(), + )?; + let q_powers = q_powers::(transcript, ell)?; + + transcript.absorb( + &w.iter() + .map(|g| ImplAbsorbGVar::wrap(g)) + .collect::>(), + )?; + let d = transcript + .squeeze_field_elements(1)? + .into_iter() + .next() + .unwrap(); + + let d_square = d.square()?; + let q_power_multiplier = one + &d + &d_square; + let q_powers_multiplied = q_powers + .iter() + .map(|q_i| q_i * &q_power_multiplier) + .collect::>(); + + let b_u = v + .iter() + .map(|v_i| { + let mut b_u_i = v_i[0].clone(); + for i in 1..ell { + b_u_i += &q_powers[i] * &v_i[i]; + } + b_u_i + }) + .collect::>(); + + let msm_gadget = OffloadedMSMGadget::, E, G1Var, Circuit>::new(self.circuit); + let pairing_gadget = + OffloadedPairingGadget::, G1Var, Circuit>::new(self.circuit); + + let l_g1s = &[com.as_slice(), w.as_slice(), &[g1.clone()]].concat(); + let l_scalars = &[ + q_powers_multiplied.as_slice(), + &[ + u[0].clone(), + &u[1] * &d, + &u[2] * &d_square, + (&b_u[0] + &d * &b_u[1] + &d_square * &b_u[2]).negate()?, + ], + ] + .concat(); + debug_assert_eq!(l_g1s.len(), l_scalars.len()); + + let l_g1 = msm_gadget.msm(ns!(self.cs, "l_g1"), l_g1s, l_scalars)?; + + let r_g1s = w.as_slice(); + let r_scalars = &[FpVar::one().negate()?, d.negate()?, d_square.negate()?]; + debug_assert_eq!(r_g1s.len(), r_scalars.len()); + + let r_g1 = msm_gadget.msm(ns!(self.cs, "r_g1"), r_g1s, r_scalars)?; + + pairing_gadget.multi_pairing_is_zero( + ns!(self.cs, "multi_pairing"), + &[l_g1, r_g1], + self.g2_elements.as_slice(), + )?; + dbg!(); + + Ok(Boolean::TRUE) + } +} + +fn q_powers>( + transcript: &mut S::Var, + ell: usize, +) -> Result>, SynthesisError> { + let q = transcript + .squeeze_field_elements(1)? 
+ .into_iter() + .next() + .unwrap(); + + let q_powers = [vec![FpVar::Constant(E::ScalarField::one()), q.clone()], { + let mut q_power = q.clone(); + (2..ell) + .map(|_i| { + q_power *= &q; + q_power.clone() + }) + .collect() + }] + .concat(); + Ok(q_powers) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuits::{ + groups::curves::short_weierstrass::bn254::G1Var, + transcript::mock::{MockSponge, MockSpongeVar}, + }, + poly::{ + commitment::hyperkzg::{HyperKZG, HyperKZGProverKey, HyperKZGSRS, HyperKZGVerifierKey}, + dense_mlpoly::DensePolynomial, + }, + snark::{DeferredFnsRef, OffloadedDataCircuit, OffloadedSNARK}, + utils::{errors::ProofVerifyError, transcript::ProofTranscript}, + }; + use ark_bn254::Bn254; + use ark_crypto_primitives::{snark::SNARK, sponge::constraints::CryptographicSpongeVar}; + use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; + use ark_r1cs_std::ToConstraintFieldGadget; + use ark_relations::{ + ns, + r1cs::{ConstraintSynthesizer, ConstraintSystemRef}, + }; + use rand_core::{RngCore, SeedableRng}; + + #[derive(Clone)] + struct HyperKZGVerifierCircuit + where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, + { + _params: PhantomData, + deferred_fns_ref: DeferredFnsRef, + pcs_pk_vk: (HyperKZGProverKey, HyperKZGVerifierKey), + commitment: Option>, + point: Vec>, + eval: Option, + pcs_proof: HyperKZGProof, + expected_result: Option, + } + + impl OffloadedDataCircuit for HyperKZGVerifierCircuit + where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, + { + fn deferred_fns_ref(&self) -> &DeferredFnsRef { + &self.deferred_fns_ref + } + } + + impl HyperKZGVerifierCircuit + where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, + { + pub(crate) fn public_inputs(&self) -> Vec { + Boolean::::constant(self.expected_result.unwrap()) // panics if None + .to_constraint_field() + .unwrap() + .iter() + .map(|x| x.value().unwrap()) + .collect::>() + } + } + + impl ConstraintSynthesizer for HyperKZGVerifierCircuit + where + E: Pairing, + G1Var: CurveVar + ToConstraintFieldGadget, + { + fn generate_constraints( + self, + cs: ConstraintSystemRef, + ) -> Result<(), SynthesisError> { + let vk_var = HyperKZGVerifierKeyVar::::new_witness(ns!(cs, "vk"), || { + Ok(self.pcs_pk_vk.clone()) + })?; + + let commitment_var = + HyperKZGCommitmentVar::::new_witness(ns!(cs, "commitment"), || { + self.commitment + .clone() + .ok_or(SynthesisError::AssignmentMissing) + })?; + + let point_var = self + .point + .iter() + .map(|&x| { + FpVar::new_witness(ns!(cs, ""), || x.ok_or(SynthesisError::AssignmentMissing)) + }) + .collect::, _>>()?; + + let eval_var = FpVar::::new_witness(ns!(cs, "eval"), || { + self.eval.ok_or(SynthesisError::AssignmentMissing) + })?; + + let proof_var = HyperKZGProofVar::::new_witness(ns!(cs, "proof"), || { + Ok(self.pcs_proof.clone()) + })?; + + let mut transcript_var = + MockSpongeVar::new(ns!(cs, "transcript").cs(), &(b"TestEval".as_slice())); + + let kzg_vk = self.pcs_pk_vk.1.kzg_vk; + let hyper_kzg = + HyperKZGVerifierGadget::, G1Var, Self>::new( + &self, + ns!(cs, "hyperkzg"), + vec![kzg_vk.g2, kzg_vk.beta_g2], + ); + + let r = hyper_kzg.verify( + &proof_var, + &vk_var, + &mut transcript_var, + &point_var, + &eval_var, + &commitment_var, + )?; + + let r_input = Boolean::new_input(ns!(cs, "verification_result"), || { + self.expected_result + .ok_or(SynthesisError::AssignmentMissing) + })?; + r.enforce_equal(&r_input)?; + + dbg!(cs.num_constraints()); + + Ok(()) + } + } + + struct HyperKZGVerifier + where + E: Pairing, + S: 
SNARK, + G1Var: CurveVar, + { + _params: PhantomData<(E, S, G1Var)>, + } + + impl OffloadedSNARK for HyperKZGVerifier + where + E: Pairing, BaseField = P::BaseField, ScalarField = P::ScalarField>, + P: SWCurveConfig, + S: SNARK, + G1Var: CurveVar + ToConstraintFieldGadget, + { + type Circuit = HyperKZGVerifierCircuit; + } + + #[test] + fn test_hyperkzg_eval() { + type Groth16 = ark_groth16::Groth16; + type VerifierSNARK = HyperKZGVerifier; + + // Test with poly(X1, X2) = 1 + X1 + X2 + X1*X2 + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(0); + let srs = HyperKZGSRS::setup(&mut rng, 3); + let (pcs_pk, pcs_vk): (HyperKZGProverKey, HyperKZGVerifierKey) = srs.trim(3); + + // poly is in eval. representation; evaluated at [(0,0), (0,1), (1,0), (1,1)] + let poly = DensePolynomial::new(vec![ + ark_bn254::Fr::from(1), + ark_bn254::Fr::from(2), + ark_bn254::Fr::from(2), + ark_bn254::Fr::from(4), + ]); + + let size = 2usize; + let (cpk, cvk) = { + let circuit = HyperKZGVerifierCircuit:: { + _params: PhantomData, + deferred_fns_ref: Default::default(), + pcs_pk_vk: (pcs_pk.clone(), pcs_vk.clone()), + commitment: None, + point: vec![None; size], + eval: None, + pcs_proof: HyperKZGProof::empty(size), + expected_result: None, + }; + + VerifierSNARK::setup(circuit, &mut rng).unwrap() + }; + + let C = HyperKZG::commit(&pcs_pk, &poly).unwrap(); + + let test_inner = + |point: Vec, eval: ark_bn254::Fr| -> Result<(), ProofVerifyError> { + let mut tr = ProofTranscript::new(b"TestEval"); + let hkzg_proof = HyperKZG::open(&pcs_pk, &poly, &point, &eval, &mut tr).unwrap(); + + println!("Verifying natively..."); + + let mut tr = ProofTranscript::new(b"TestEval"); + HyperKZG::verify(&pcs_vk, &C, &point, &eval, &hkzg_proof, &mut tr)?; + + // Create an instance of our circuit (with the + // witness) + let verifier_circuit = HyperKZGVerifierCircuit:: { + _params: PhantomData, + deferred_fns_ref: Default::default(), + pcs_pk_vk: (pcs_pk.clone(), pcs_vk.clone()), + commitment: Some(C.clone()), + point: point.into_iter().map(|x| Some(x)).collect(), + eval: Some(eval), + pcs_proof: hkzg_proof, + expected_result: Some(true), + }; + let instance = verifier_circuit.public_inputs(); + + let mut rng = + ark_std::rand::rngs::StdRng::seed_from_u64(ark_std::test_rng().next_u64()); + + println!("Verifying in-circuit..."); + + // Create a groth16 proof with our parameters. + let proof = VerifierSNARK::prove(&cpk, verifier_circuit, &mut rng) + .map_err(|_e| ProofVerifyError::InternalError)?; + + let result = VerifierSNARK::verify_with_processed_vk(&cvk, &instance, &proof); + match result { + Ok(true) => Ok(()), + Ok(false) => Err(ProofVerifyError::InternalError), + Err(_) => Err(ProofVerifyError::InternalError), + } + }; + + // Call the prover with a (point, eval) pair. 
+ // The prover does not recompute so it may produce a proof, but it should not verify + let point = vec![ark_bn254::Fr::from(0), ark_bn254::Fr::from(0)]; + let eval = ark_bn254::Fr::from(1); + assert!(test_inner(point, eval).is_ok()); + + let point = vec![ark_bn254::Fr::from(0), ark_bn254::Fr::from(1)]; + let eval = ark_bn254::Fr::from(2); + assert!(test_inner(point, eval).is_ok()); + + let point = vec![ark_bn254::Fr::from(1), ark_bn254::Fr::from(1)]; + let eval = ark_bn254::Fr::from(4); + assert!(test_inner(point, eval).is_ok()); + + let point = vec![ark_bn254::Fr::from(0), ark_bn254::Fr::from(2)]; + let eval = ark_bn254::Fr::from(3); + assert!(test_inner(point, eval).is_ok()); + + let point = vec![ark_bn254::Fr::from(2), ark_bn254::Fr::from(2)]; + let eval = ark_bn254::Fr::from(9); + assert!(test_inner(point, eval).is_ok()); + + // Try a couple incorrect evaluations and expect failure + let point = vec![ark_bn254::Fr::from(2), ark_bn254::Fr::from(2)]; + let eval = ark_bn254::Fr::from(50); + assert!(test_inner(point, eval).is_err()); + + let point = vec![ark_bn254::Fr::from(0), ark_bn254::Fr::from(2)]; + let eval = ark_bn254::Fr::from(4); + assert!(test_inner(point, eval).is_err()); + } +} diff --git a/jolt-core/src/circuits/poly/commitment/mod.rs b/jolt-core/src/circuits/poly/commitment/mod.rs new file mode 100644 index 000000000..0b1c33650 --- /dev/null +++ b/jolt-core/src/circuits/poly/commitment/mod.rs @@ -0,0 +1,2 @@ +mod commitment_scheme; +pub mod hyperkzg; diff --git a/jolt-core/src/circuits/poly/mod.rs b/jolt-core/src/circuits/poly/mod.rs new file mode 100644 index 000000000..1f9123814 --- /dev/null +++ b/jolt-core/src/circuits/poly/mod.rs @@ -0,0 +1 @@ +pub mod commitment; diff --git a/jolt-core/src/circuits/transcript/mock.rs b/jolt-core/src/circuits/transcript/mock.rs new file mode 100644 index 000000000..f63d4f802 --- /dev/null +++ b/jolt-core/src/circuits/transcript/mock.rs @@ -0,0 +1,133 @@ +use crate::circuits::transcript::SLICE; +use crate::field::JoltField; +use crate::utils::transcript::ProofTranscript; +use ark_crypto_primitives::sponge::{ + constraints::{AbsorbGadget, CryptographicSpongeVar, SpongeWithGadget}, + Absorb, CryptographicSponge, +}; +use ark_ff::PrimeField; +use ark_r1cs_std::{boolean::Boolean, fields::fp::FpVar, prelude::*, R1CSVar}; +use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError}; +use ark_std::marker::PhantomData; + +#[derive(Clone)] +pub struct MockSponge +where + ConstraintF: PrimeField + JoltField, +{ + _params: PhantomData, +} + +impl CryptographicSponge for MockSponge +where + ConstraintF: PrimeField + JoltField, +{ + type Config = (); + + fn new(_params: &Self::Config) -> Self { + Self { + _params: PhantomData, + } + } + + fn absorb(&mut self, _input: &impl Absorb) { + todo!() + } + + fn squeeze_bytes(&mut self, _num_bytes: usize) -> Vec { + todo!() + } + + fn squeeze_bits(&mut self, _num_bits: usize) -> Vec { + todo!() + } +} + +impl SpongeWithGadget for MockSponge +where + ConstraintF: PrimeField + JoltField, +{ + type Var = MockSpongeVar; +} + +#[derive(Clone)] +pub struct MockSpongeVar +where + ConstraintF: PrimeField, +{ + cs: ConstraintSystemRef, + pub transcript: ProofTranscript, +} + +impl CryptographicSpongeVar> + for MockSpongeVar +where + ConstraintF: PrimeField + JoltField, +{ + type Parameters = &'static [u8]; + + fn new(cs: ConstraintSystemRef, params: &Self::Parameters) -> Self { + Self { + cs, + transcript: ProofTranscript::new(params), + } + } + + fn cs(&self) -> ConstraintSystemRef { + self.cs.clone() + } + + fn 
absorb(&mut self, input: &impl AbsorbGadget) -> Result<(), SynthesisError> { + let bytes = input.to_sponge_bytes()?; + let bs = bytes + .iter() + .map(|f| match self.cs.is_in_setup_mode() { + true => Ok(0u8), + false => f.value(), + }) + .collect::, _>>()?; + + let slice_opt = SLICE.take(); + match slice_opt { + Some(slice_len) => { + self.transcript.append_message(b"begin_append_vector"); + if slice_len != 0 { + for chunk in bs.chunks(bs.len() / slice_len) { + self.transcript.append_bytes(chunk); + } + } + self.transcript.append_message(b"end_append_vector"); + } + None => { + self.transcript.append_bytes(&bs); + } + } + + Ok(()) + } + + fn squeeze_bytes( + &mut self, + _num_bytes: usize, + ) -> Result>, SynthesisError> { + todo!() + } + + fn squeeze_bits( + &mut self, + _num_bits: usize, + ) -> Result>, SynthesisError> { + todo!() + } + + fn squeeze_field_elements( + &mut self, + num_elements: usize, + ) -> Result>, SynthesisError> { + self.transcript + .challenge_vector::(num_elements) + .iter() + .map(|&f| FpVar::new_witness(self.cs(), || Ok(f))) + .collect() + } +} diff --git a/jolt-core/src/circuits/transcript/mod.rs b/jolt-core/src/circuits/transcript/mod.rs new file mode 100644 index 000000000..831420c36 --- /dev/null +++ b/jolt-core/src/circuits/transcript/mod.rs @@ -0,0 +1,136 @@ +use ark_crypto_primitives::sponge::constraints::AbsorbGadget; +use ark_ec::{AffineRepr, CurveGroup}; +use ark_ff::{Field, PrimeField}; +use ark_r1cs_std::{fields::fp::FpVar, prelude::*, R1CSVar}; +use ark_relations::{ns, r1cs::SynthesisError}; +use ark_serialize::CanonicalSerialize; +use ark_std::{cell::RefCell, fmt::Debug, marker::PhantomData, Zero}; + +pub mod mock; + +pub struct ImplAbsorbFVar<'a, T, F>(&'a T, PhantomData) +where + T: R1CSVar, + F: PrimeField; + +impl<'a, T, F> ImplAbsorbFVar<'a, T, F> +where + T: R1CSVar, + F: PrimeField, +{ + pub fn wrap(t: &'a T) -> Self { + Self(t, PhantomData) + } +} + +thread_local! 
{ + static SLICE: RefCell> = const { RefCell::new(None) }; +} + +impl<'a, T, F> AbsorbGadget for ImplAbsorbFVar<'a, T, F> +where + T: R1CSVar + Debug, + F: PrimeField, +{ + fn to_sponge_bytes(&self) -> Result>, SynthesisError> { + let mut buf = vec![]; + + let t_value = match self.0.cs().is_in_setup_mode() { + true => T::Value::zero(), + false => self.0.value()?, + }; + + t_value + .serialize_uncompressed(&mut buf) + .map_err(|_e| SynthesisError::Unsatisfiable)?; + + buf.into_iter() + .rev() + .map(|b| UInt8::new_witness(ns!(self.0.cs(), "sponge_byte"), || Ok(b))) + .collect::, _>>() + } + + fn batch_to_sponge_bytes(batch: &[Self]) -> Result>, SynthesisError> + where + Self: Sized, + { + SLICE.set(Some(batch.len())); + let mut result = Vec::new(); + for item in batch { + result.append(&mut (item.to_sponge_bytes()?)) + } + Ok(result) + } + + fn to_sponge_field_elements(&self) -> Result>, SynthesisError> { + unimplemented!("should not be called") + } +} + +pub struct ImplAbsorbGVar<'a, T, F, G>(&'a T, PhantomData<(F, G)>) +where + T: CurveVar, + F: PrimeField, + G: CurveGroup; + +impl<'a, T, F, G> ImplAbsorbGVar<'a, T, F, G> +where + T: CurveVar, + F: PrimeField, + G: CurveGroup, +{ + pub fn wrap(t: &'a T) -> Self { + Self(t, PhantomData) + } +} + +impl<'a, T, F, G> AbsorbGadget for ImplAbsorbGVar<'a, T, F, G> +where + T: CurveVar + Debug, + F: PrimeField, + G: CurveGroup, +{ + fn to_sponge_bytes(&self) -> Result>, SynthesisError> { + let g = match self.0.cs().is_in_setup_mode() { + true => T::Value::zero(), + false => self.0.value()?, + } + .into_affine(); + + fn serialize(x: &F) -> Vec { + let mut buf = vec![]; + x.serialize_compressed(&mut buf) + .expect("failed to serialize uncompressed"); + buf.reverse(); + buf + } + + let buf = match g.is_zero() { + true => vec![0u8; 64], + false => { + let (x, y) = g.xy().unwrap(); + [serialize(x), serialize(y)].concat() + } + }; + + buf.iter() + .map(|b| UInt8::new_witness(ns!(self.0.cs(), "sponge_byte"), || Ok(b))) + .collect::, _>>() + } + + fn batch_to_sponge_bytes(batch: &[Self]) -> Result>, SynthesisError> + where + Self: Sized, + { + SLICE.set(Some(batch.len())); + let mut result = Vec::new(); + for item in batch { + result.append(&mut (item.to_sponge_bytes()?)) + } + Ok(result) + } + + fn to_sponge_field_elements(&self) -> Result>, SynthesisError> { + unimplemented!("should not be called") + } +} diff --git a/jolt-core/src/lib.rs b/jolt-core/src/lib.rs index 965f9f090..21c7a36cf 100644 --- a/jolt-core/src/lib.rs +++ b/jolt-core/src/lib.rs @@ -16,11 +16,13 @@ pub mod benches; #[cfg(feature = "host")] pub mod host; +pub mod circuits; pub mod field; pub mod jolt; pub mod lasso; pub mod msm; pub mod poly; pub mod r1cs; +pub mod snark; pub mod subprotocols; pub mod utils; diff --git a/jolt-core/src/poly/commitment/hyperkzg.rs b/jolt-core/src/poly/commitment/hyperkzg.rs index dea0d9906..d55240219 100644 --- a/jolt-core/src/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/poly/commitment/hyperkzg.rs @@ -58,7 +58,7 @@ pub struct HyperKZGVerifierKey { pub kzg_vk: KZGVerifierKey
<P>
, } -#[derive(Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] pub struct HyperKZGCommitment(pub P::G1Affine); impl AppendToTranscript for HyperKZGCommitment
<P>
{ @@ -74,6 +74,16 @@ pub struct HyperKZGProof { pub v: Vec>, } +impl HyperKZGProof
<P>
{ + pub fn empty(size: usize) -> Self { + Self { + com: vec![P::G1Affine::zero(); size - 1], + w: vec![P::G1Affine::zero(); 3], + v: vec![vec![P::ScalarField::zero(); size]; 3], + } + } +} + // On input f(x) and u compute the witness polynomial used to prove // that f(u) = v. The main part of this is to compute the // division (f(x) - f(u)) / (x - u), but we don't use a general diff --git a/jolt-core/src/snark/mod.rs b/jolt-core/src/snark/mod.rs new file mode 100644 index 000000000..6f248bc0c --- /dev/null +++ b/jolt-core/src/snark/mod.rs @@ -0,0 +1,433 @@ +use ark_crypto_primitives::snark::SNARK; +use ark_ec::{ + pairing::Pairing, + short_weierstrass::{Affine, SWCurveConfig}, + AffineRepr, VariableBaseMSM, +}; +use ark_ff::{PrimeField, Zero}; +use ark_r1cs_std::{ + fields::nonnative::params::{get_params, OptimizationType}, + fields::nonnative::AllocatedNonNativeFieldVar, + groups::CurveVar, + R1CSVar, ToConstraintFieldGadget, +}; +use ark_relations::r1cs; +use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisError}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError, Valid}; +use ark_std::{cell::OnceCell, cell::RefCell, rc::Rc}; +use itertools::Itertools; +use rand_core::{CryptoRng, RngCore}; + +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct OffloadedPairingDef +where + E: Pairing, +{ + pub g2_elements: Vec, +} + +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct OffloadedSNARKVerifyingKey +where + E: Pairing, + S: SNARK, +{ + pub snark_pvk: S::ProcessedVerifyingKey, + pub delayed_pairings: Vec>, +} + +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct OffloadedSNARKProof +where + E: Pairing, + S: SNARK, +{ + pub snark_proof: S::Proof, + pub offloaded_data: ProofData, +} + +#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] +pub struct ProofData { + /// Blocks of G1 elements `Gᵢ` and scalars in `sᵢ` the public input, such that `∑ᵢ sᵢ·Gᵢ == 0`. + /// It's the verifiers responsibility to ensure that the sum is zero. + /// The scalar at index `length-1` is, by convention, always `-1`, so + /// we save one public input element per MSM. + msms: Vec>, + /// Blocks of G1 elements `Gᵢ` in the public input, used in multi-pairings with + /// the corresponding G2 elements in the offloaded SNARK verification key. + /// It's the verifiers responsibility to ensure that the sum is zero. + /// The scalar at index `length-1` is, by convention, always `-1`, so + /// we save one public input element per MSM. + pairing_g1s: Vec>, +} + +#[derive(Clone, Debug)] +pub struct OffloadedData { + proof_data: Option>, + setup_data: Vec>, +} + +pub enum DeferredOpData { + MSM(Option>), + Pairing(Option>, Vec), +} + +pub type MSMDef = ( + Vec<::G1Affine>, + Vec<::ScalarField>, +); + +pub type MultiPairingDef = (Vec<::G1>, Vec<::G2>); + +pub type DeferredFn = dyn FnOnce() -> Result, SynthesisError>; + +pub type DeferredFnsRef = Rc>>>>; + +pub trait OffloadedDataCircuit: Clone +where + E: Pairing, +{ + fn deferred_fns_ref(&self) -> &DeferredFnsRef; + + fn defer_op(&self, f: impl FnOnce() -> Result, SynthesisError> + 'static) { + self.deferred_fns_ref().borrow_mut().push(Box::new(f)); + } +} + +#[derive(thiserror::Error, Debug)] +pub enum OffloadedSNARKError +where + Err: 'static + ark_std::error::Error, +{ + /// Wraps `Err`. + #[error(transparent)] + SNARKError(Err), + /// Wraps `SerializationError`. 
+
+#[derive(thiserror::Error, Debug)]
+pub enum OffloadedSNARKError<Err>
+where
+    Err: 'static + ark_std::error::Error,
+{
+    /// Wraps `Err`.
+    #[error(transparent)]
+    SNARKError(Err),
+    /// Wraps `SerializationError`.
+    #[error(transparent)]
+    SerializationError(#[from] SerializationError),
+    #[error(transparent)]
+    SynthesisError(#[from] SynthesisError),
+}
+
+pub struct WrappedCircuit<E, C>
+where
+    E: Pairing,
+    C: ConstraintSynthesizer<E::ScalarField> + OffloadedDataCircuit<E>,
+{
+    circuit: C,
+    offloaded_data_ref: Rc<OnceCell<OffloadedData<E>>>,
+}
+
+/// This is run both at setup and at proving time.
+/// At setup time we only need to get G2 elements: we need them to form the verifying key.
+/// At proving time we need to get G1 elements as well.
+fn run_deferred<E: Pairing>(
+    deferred_fns: Vec<Box<DeferredFn<E>>>,
+) -> Result<OffloadedData<E>, SynthesisError> {
+    let op_data = deferred_fns
+        .into_iter()
+        .map(|f| f())
+        .collect::<Result<Vec<_>, _>>()?;
+
+    let op_data_by_type = op_data
+        .into_iter()
+        .into_grouping_map_by(|d| match d {
+            DeferredOpData::MSM(..) => 0,
+            DeferredOpData::Pairing(..) => 1,
+        })
+        .collect::<Vec<_>>();
+
+    let msms = op_data_by_type
+        .get(&0)
+        .into_iter()
+        .flatten()
+        .map(|d| match d {
+            DeferredOpData::MSM(msm_opt) => msm_opt.clone(),
+            _ => unreachable!(),
+        })
+        .collect::<Option<Vec<_>>>();
+
+    let (p_g1s, p_g2s): (Vec<_>, Vec<_>) = op_data_by_type
+        .get(&1)
+        .into_iter()
+        .flatten()
+        .map(|d| match d {
+            DeferredOpData::Pairing(g1s_opt, g2s) => (g1s_opt.clone(), g2s.clone()),
+            _ => unreachable!(),
+        })
+        .unzip();
+    let pairing_g1s = p_g1s.into_iter().collect::<Option<Vec<_>>>();
+
+    Ok(OffloadedData {
+        proof_data: msms
+            .zip(pairing_g1s)
+            .map(|(msms, pairing_g1s)| ProofData { msms, pairing_g1s }),
+        setup_data: p_g2s,
+    })
+}
+
+impl<E, C> ConstraintSynthesizer<E::ScalarField> for WrappedCircuit<E, C>
+where
+    E: Pairing,
+    C: ConstraintSynthesizer<E::ScalarField> + OffloadedDataCircuit<E>,
+{
+    fn generate_constraints(self, cs: ConstraintSystemRef<E::ScalarField>) -> r1cs::Result<()> {
+        // `self.circuit` will be consumed by `self.circuit.generate_constraints(cs)`
+        // so we need to clone the reference to the deferred functions
+        let deferred_fns_ref = self.circuit.deferred_fns_ref().clone();
+
+        let offloaded_data_ref = self.offloaded_data_ref.clone();
+
+        self.circuit.generate_constraints(cs)?;
+        let offloaded_data = run_deferred::<E>(deferred_fns_ref.take())?;
+
+        offloaded_data_ref.set(offloaded_data).unwrap();
+
+        Ok(())
+    }
+}
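// Editor's sketch, not part of the patch: `run_deferred` above leans on two collection
// patterns. `into_grouping_map_by` buckets the deferred ops by kind, and collecting
// `Option`s fails closed, which is why at setup time (when no witness G1 values exist and
// every op carries `None`) the aggregated `proof_data` is `None`, while at proving time it
// becomes `Some(..)`. The toy values below are illustrative only.
use itertools::Itertools;

fn main() {
    // Bucket values by a discriminant, as `run_deferred` does for MSM (0) vs. pairing (1) ops.
    let by_kind = vec![(0, "msm"), (1, "pairing"), (0, "msm")]
        .into_iter()
        .into_grouping_map_by(|(kind, _)| *kind)
        .collect::<Vec<_>>();
    assert_eq!(by_kind[&0].len(), 2);

    // Setup time: witness values are absent, so the aggregate collapses to `None`.
    let at_setup: Vec<Option<u32>> = vec![None, None];
    assert_eq!(at_setup.into_iter().collect::<Option<Vec<_>>>(), None);

    // Proving time: every deferred op yields a value, so the aggregate is `Some`.
    let at_proving = vec![Some(1u32), Some(2)];
    assert_eq!(at_proving.into_iter().collect::<Option<Vec<_>>>(), Some(vec![1, 2]));
}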
+
+pub trait OffloadedSNARK<E, P, S, G1Var>
+where
+    E: Pairing<G1Affine = Affine<P>, BaseField = P::BaseField, ScalarField = P::ScalarField>,
+    P: SWCurveConfig,
+    S: SNARK<E::ScalarField>,
+    G1Var: CurveVar<E::G1, E::ScalarField> + ToConstraintFieldGadget<E::ScalarField>,
+{
+    type Circuit: ConstraintSynthesizer<E::ScalarField> + OffloadedDataCircuit<E>;
+
+    fn setup<R: RngCore + CryptoRng>(
+        circuit: Self::Circuit,
+        rng: &mut R,
+    ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey<E, S>), OffloadedSNARKError<S::Error>>
+    {
+        let circuit: WrappedCircuit<E, Self::Circuit> = WrappedCircuit {
+            circuit,
+            offloaded_data_ref: Default::default(),
+        };
+        Self::circuit_specific_setup(circuit, rng)
+    }
+
+    fn circuit_specific_setup<R: RngCore + CryptoRng>(
+        circuit: WrappedCircuit<E, Self::Circuit>,
+        rng: &mut R,
+    ) -> Result<(S::ProvingKey, OffloadedSNARKVerifyingKey<E, S>), OffloadedSNARKError<S::Error>>
+    {
+        let offloaded_data_ref = circuit.offloaded_data_ref.clone();
+
+        let (pk, snark_vk) =
+            S::circuit_specific_setup(circuit, rng).map_err(OffloadedSNARKError::SNARKError)?;
+
+        let snark_pvk = S::process_vk(&snark_vk).map_err(OffloadedSNARKError::SNARKError)?;
+
+        let setup_data = offloaded_data_ref.get().unwrap().clone().setup_data;
+
+        let delayed_pairings = setup_data
+            .into_iter()
+            .map(|g2| OffloadedPairingDef { g2_elements: g2 })
+            .collect();
+
+        let vk = OffloadedSNARKVerifyingKey {
+            snark_pvk,
+            delayed_pairings,
+        };
+
+        Ok((pk, vk))
+    }
+
+    fn prove<R: RngCore + CryptoRng>(
+        circuit_pk: &S::ProvingKey,
+        circuit: Self::Circuit,
+        rng: &mut R,
+    ) -> Result<OffloadedSNARKProof<E, S>, OffloadedSNARKError<S::Error>> {
+        let circuit: WrappedCircuit<E, Self::Circuit> = WrappedCircuit {
+            circuit,
+            offloaded_data_ref: Default::default(),
+        };
+
+        let offloaded_data_ref = circuit.offloaded_data_ref.clone();
+
+        let proof = S::prove(circuit_pk, circuit, rng).map_err(OffloadedSNARKError::SNARKError)?;
+
+        let proof_data = match offloaded_data_ref.get().unwrap().clone().proof_data {
+            Some(proof_data) => proof_data,
+            _ => unreachable!(),
+        };
+
+        Ok(OffloadedSNARKProof {
+            snark_proof: proof,
+            offloaded_data: proof_data,
+        })
+    }
+
+    fn verify(
+        vk: &OffloadedSNARKVerifyingKey<E, S>,
+        public_input: &[E::ScalarField],
+        proof: &OffloadedSNARKProof<E, S>,
+    ) -> Result<bool, OffloadedSNARKError<S::Error>> {
+        Self::verify_with_processed_vk(vk, public_input, proof)
+    }
+
+    fn verify_with_processed_vk(
+        vk: &OffloadedSNARKVerifyingKey<E, S>,
+        public_input: &[E::ScalarField],
+        proof: &OffloadedSNARKProof<E, S>,
+    ) -> Result<bool, OffloadedSNARKError<S::Error>> {
+        let public_input = build_public_input::<E, G1Var>(public_input, &proof.offloaded_data);
+
+        let r = S::verify_with_processed_vk(&vk.snark_pvk, &public_input, &proof.snark_proof)
+            .map_err(OffloadedSNARKError::SNARKError)?;
+        if !r {
+            return Ok(false);
+        }
+
+        for (g1s, scalars) in &proof.offloaded_data.msms {
+            assert_eq!(g1s.len(), scalars.len());
+            let r = E::G1::msm_unchecked(g1s, scalars);
+            if !r.is_zero() {
+                return Ok(false);
+            }
+        }
+
+        let pairings = Self::pairing_inputs(vk, &proof.offloaded_data.pairing_g1s)?;
+        for (g1s, g2s) in pairings {
+            assert_eq!(g1s.len(), g2s.len());
+            let r = E::multi_pairing(&g1s, &g2s);
+            if !r.is_zero() {
+                return Ok(false);
+            }
+        }
+
+        Ok(true)
+    }
+
+    fn g1_elements(
+        public_input: &[E::ScalarField],
+        g1_offset: usize,
+        length: usize,
+    ) -> Result<Vec<E::G1Affine>, SerializationError> {
+        let g1_element_size = g1_affine_size_in_scalar_field_elements::<E>();
+        if public_input.len() < g1_offset + length * g1_element_size {
+            return Err(SerializationError::InvalidData);
+        };
+
+        public_input[g1_offset..g1_offset + length * g1_element_size]
+            .chunks(g1_element_size)
+            .map(|chunk| g1_affine_from_scalar_field::<E, P>(chunk))
+            .collect()
+    }
+
+    fn pairing_inputs(
+        vk: &OffloadedSNARKVerifyingKey<E, S>,
+        g1_vectors: &[Vec<E::G1Affine>],
+    ) -> Result<Vec<MultiPairingDef<E>>, SerializationError> {
+        Ok(g1_vectors
+            .iter()
+            .map(|g1_vec| g1_vec.iter().map(|&g1| g1.into()).collect())
+            .zip(Self::g2_elements(vk))
+            .collect())
+    }
+
+    fn g2_elements(vk: &OffloadedSNARKVerifyingKey<E, S>) -> Vec<Vec<E::G2>> {
+        vk.delayed_pairings
+            .iter()
+            .map(|pairing_def| {
+                pairing_def
+                    .g2_elements
+                    .iter()
+                    .map(|g2| g2.into_group())
+                    .collect::<Vec<_>>()
+            })
+            .collect::<Vec<Vec<_>>>()
+    }
+}
+
+fn g1_affine_size_in_scalar_field_elements<E: Pairing>() -> usize {
+    let params = get_params(
+        E::BaseField::MODULUS_BIT_SIZE as usize,
+        E::ScalarField::MODULUS_BIT_SIZE as usize,
+        OptimizationType::Weight,
+    );
+    params.num_limbs * 2 + 1
+}
+
+fn g1_affine_from_scalar_field<E, P>(
+    s: &[E::ScalarField],
+) -> Result<E::G1Affine, SerializationError>
+where
+    E: Pairing<G1Affine = Affine<P>, BaseField = P::BaseField, ScalarField = P::ScalarField>,
+    P: SWCurveConfig,
+{
+    let infinity = !s[s.len() - 1].is_zero();
+    if infinity {
+        return Ok(E::G1Affine::zero());
+    }
+
+    let base_field_size_in_limbs = (s.len() - 1) / 2;
+    let x = AllocatedNonNativeFieldVar::<E::BaseField, E::ScalarField>::limbs_to_value(
+        s[..base_field_size_in_limbs].to_vec(),
+        OptimizationType::Weight,
+    );
+    let y = AllocatedNonNativeFieldVar::<E::BaseField, E::ScalarField>::limbs_to_value(
+        s[base_field_size_in_limbs..s.len() - 1].to_vec(),
+        OptimizationType::Weight,
+    );
+
+    let affine = Affine {
+        x,
+        y,
+        infinity: false,
+    };
+    affine.check()?;
+    Ok(affine)
+}
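// Editor's sketch, not part of the patch: where the `num_limbs * 2 + 1` count used by
// `g1_affine_size_in_scalar_field_elements` comes from, specialized to BN254. Consistent
// with `g1_affine_from_scalar_field` above, a G1 affine point is encoded as `num_limbs`
// scalars for x, `num_limbs` for y, plus one scalar acting as the infinity flag. The
// printed number is illustrative; the helper name is local to this example.
use ark_bn254::{Fq, Fr};
use ark_ff::PrimeField;
use ark_r1cs_std::fields::nonnative::params::{get_params, OptimizationType};

fn main() {
    let params = get_params(
        Fq::MODULUS_BIT_SIZE as usize, // non-native base field being encoded
        Fr::MODULUS_BIT_SIZE as usize, // constraint (scalar) field carrying the limbs
        OptimizationType::Weight,
    );
    println!(
        "a BN254 G1 affine point occupies {} scalar field elements",
        params.num_limbs * 2 + 1
    );
}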
+
+fn build_public_input<E, G1Var>(
+    public_input: &[E::ScalarField],
+    data: &ProofData<E>,
+) -> Vec<E::ScalarField>
+where
+    E: Pairing,
+    G1Var: CurveVar<E::G1, E::ScalarField> + ToConstraintFieldGadget<E::ScalarField>,
+{
+    let msm_data = data
+        .msms
+        .iter()
+        .map(|msm| {
+            let scalars = &msm.1;
+            let scalar_vec = scalars[..scalars.len() - 1].to_vec(); // remove the last element (always `-1`)
+
+            let g1s = &msm.0;
+            let msm_g1_vec = to_scalars::<E, G1Var>(g1s);
+
+            [scalar_vec, msm_g1_vec].concat()
+        })
+        .concat();
+
+    let pairing_data = data
+        .pairing_g1s
+        .iter()
+        .map(|g1s| to_scalars::<E, G1Var>(g1s))
+        .concat();
+
+    [public_input.to_vec(), msm_data, pairing_data].concat()
+}
+
+fn to_scalars<E, G1Var>(g1s: &[E::G1Affine]) -> Vec<E::ScalarField>
+where
+    E: Pairing,
+    G1Var: CurveVar<E::G1, E::ScalarField> + ToConstraintFieldGadget<E::ScalarField>,
+{
+    let msm_g1_vec = g1s
+        .iter()
+        .map(|&g1| {
+            G1Var::constant(g1.into())
+                .to_constraint_field()
+                .unwrap()
+                .iter()
+                .map(|x| x.value().unwrap())
+                .collect::<Vec<_>>()
+        })
+        .concat();
+    msm_g1_vec
+}
diff --git a/jolt-core/src/utils/transcript.rs b/jolt-core/src/utils/transcript.rs
index e0d2589a0..8f88e21d0 100644
--- a/jolt-core/src/utils/transcript.rs
+++ b/jolt-core/src/utils/transcript.rs
@@ -8,7 +8,7 @@ pub struct ProofTranscript {
     // Ethereum compatible 256 bit running state
     pub state: [u8; 32],
     // We append an ordinal to each invoke of the hash
-    n_rounds: u32,
+    pub n_rounds: u32,
 }
 
 impl ProofTranscript {