From cafd7038ee7e7f1aa6b7b01164dc02b01e32d8d2 Mon Sep 17 00:00:00 2001 From: Robin Salen Date: Thu, 6 Jul 2023 11:19:35 -0400 Subject: [PATCH] Reduce number of lookup accesses --- plonky2/src/gates/lookup.rs | 30 +++++++++++++++-------------- plonky2/src/gates/lookup_table.rs | 11 +++-------- plonky2/src/plonk/prover.rs | 7 +++---- plonky2/src/plonk/vanishing_poly.rs | 20 +++++++------------ 4 files changed, 29 insertions(+), 39 deletions(-) diff --git a/plonky2/src/gates/lookup.rs b/plonky2/src/gates/lookup.rs index 36dab8bf91..5c7519897e 100644 --- a/plonky2/src/gates/lookup.rs +++ b/plonky2/src/gates/lookup.rs @@ -170,23 +170,25 @@ impl, const D: usize> SimpleGenerator for Loo let get_wire = |wire: usize| -> F { witness.get_target(Target::wire(self.row, wire)) }; let input_val = get_wire(LookupGate::wire_ith_looking_inp(self.slot_nb)); - let output_val = if input_val - == F::from_canonical_u16(self.lut[input_val.to_canonical_u64() as usize].0) - { - F::from_canonical_u16(self.lut[input_val.to_canonical_u64() as usize].1) + let (input, output) = self.lut[input_val.to_canonical_u64() as usize]; + if input_val == F::from_canonical_u16(input) { + let output_val = F::from_canonical_u16(output); + + let out_wire = Target::wire(self.row, LookupGate::wire_ith_looking_out(self.slot_nb)); + out_buffer.set_target(out_wire, output_val); } else { - let mut cur_idx = 0; - while input_val != F::from_canonical_u16(self.lut[cur_idx].0) - && cur_idx < self.lut.len() - { - cur_idx += 1; + for (input, output) in self.lut.iter() { + if input_val == F::from_canonical_u16(*input) { + let output_val = F::from_canonical_u16(*output); + + let out_wire = + Target::wire(self.row, LookupGate::wire_ith_looking_out(self.slot_nb)); + out_buffer.set_target(out_wire, output_val); + return; + } } - assert!(cur_idx < self.lut.len(), "Incorrect input value provided"); - F::from_canonical_u16(self.lut[cur_idx].1) + panic!("Incorrect input value provided"); }; - - let out_wire = Target::wire(self.row, LookupGate::wire_ith_looking_out(self.slot_nb)); - out_buffer.set_target(out_wire, output_val); } fn serialize(&self, dst: &mut Vec, cd: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gates/lookup_table.rs b/plonky2/src/gates/lookup_table.rs index 99b8f8e137..99109f0434 100644 --- a/plonky2/src/gates/lookup_table.rs +++ b/plonky2/src/gates/lookup_table.rs @@ -193,14 +193,9 @@ impl, const D: usize> SimpleGenerator for Loo Target::wire(self.row, LookupTableGate::wire_ith_looked_out(self.slot_nb)); if slot < self.lut.len() { - out_buffer.set_target( - slot_input_target, - F::from_canonical_usize(self.lut[slot].0 as usize), - ); - out_buffer.set_target( - slot_output_target, - F::from_canonical_usize(self.lut[slot].1.into()), - ); + let (input, output) = self.lut[slot]; + out_buffer.set_target(slot_input_target, F::from_canonical_usize(input as usize)); + out_buffer.set_target(slot_output_target, F::from_canonical_usize(output as usize)); } else { // Pad with zeros. out_buffer.set_target(slot_input_target, F::ZERO); diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index df2249dfcd..8c70a19a61 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -71,15 +71,14 @@ pub fn set_lookup_wires< let remaining_slots = (num_entries - (prover_data.lut_to_lookups[lut_index].len() % num_entries)) % num_entries; - let first_inp_value = F::from_canonical_u16(common_data.luts[lut_index][0].0); - let first_out_value = F::from_canonical_u16(common_data.luts[lut_index][0].1); + let (first_inp_value, first_out_value) = common_data.luts[lut_index][0]; for slot in (num_entries - remaining_slots)..num_entries { let inp_target = Target::wire(last_lut_gate - 1, LookupGate::wire_ith_looking_inp(slot)); let out_target = Target::wire(last_lut_gate - 1, LookupGate::wire_ith_looking_out(slot)); - pw.set_target(inp_target, first_inp_value); - pw.set_target(out_target, first_out_value); + pw.set_target(inp_target, F::from_canonical_u16(first_inp_value)); + pw.set_target(out_target, F::from_canonical_u16(first_out_value)); multiplicities[0] += 1; } diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index adc76654e2..a56b4b7ef7 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -37,11 +37,8 @@ pub(crate) fn get_lut_poly, const D: usize>( let b = deltas[LookupChallenges::ChallengeB as usize]; let mut coeffs = Vec::new(); let n = common_data.luts[lut_index].len(); - for i in 0..n { - coeffs.push( - F::from_canonical_u16(common_data.luts[lut_index][i].0) - + b * F::from_canonical_u16(common_data.luts[lut_index][i].1), - ); + for (input, output) in common_data.luts[lut_index].iter() { + coeffs.push(F::from_canonical_u16(*input) + b * F::from_canonical_u16(*output)); } coeffs.append(&mut vec![F::ZERO; degree - n]); coeffs.reverse(); @@ -767,14 +764,11 @@ pub(crate) fn get_lut_poly_circuit, const D: usize> let b = deltas[LookupChallenges::ChallengeB as usize]; let delta = deltas[LookupChallenges::ChallengeDelta as usize]; let n = common_data.luts[lut_index].len(); - let mut coeffs: Vec = (0..n) - .map(|i| { - let temp = - builder.mul_const(F::from_canonical_u16(common_data.luts[lut_index][i].1), b); - builder.add_const( - temp, - F::from_canonical_u16(common_data.luts[lut_index][i].0), - ) + let mut coeffs: Vec = common_data.luts[lut_index] + .iter() + .map(|(input, output)| { + let temp = builder.mul_const(F::from_canonical_u16(*output), b); + builder.add_const(temp, F::from_canonical_u16(*input)) }) .collect(); for _ in n..degree {