Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add documentation to subtable implementation #445

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions jolt-core/src/jolt/instruction/sll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ impl<const WORD_SIZE: usize> JoltInstruction for SLLInstruction<WORD_SIZE> {
C: usize,
_: usize,
) -> Vec<(Box<dyn LassoSubtable<F>>, SubtableIndices)> {
// We have to pre-define subtables in this way because `CHUNK_INDEX` needs to be a constant,
// i.e. known at compile time (so we cannot do a `map` over the range of `C`,
// which only happens at runtime).
let mut subtables: Vec<Box<dyn LassoSubtable<F>>> = vec![
Box::new(SllSubtable::<F, 0, WORD_SIZE>::new()),
Box::new(SllSubtable::<F, 1, WORD_SIZE>::new()),
Expand Down
3 changes: 2 additions & 1 deletion jolt-core/src/jolt/subtable/and.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl<F: JoltField> AndSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for AndSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = x & y
let mut entries: Vec<F> = Vec::with_capacity(M);
let bits_per_operand = (log2(M) / 2) as usize;

Expand All @@ -33,7 +34,7 @@ impl<F: JoltField> LassoSubtable<F> for AndSubtable<F> {
}

fn evaluate_mle(&self, point: &[F]) -> F {
// x * y
// \sum_i 2^i * x_{b - i - 1} * y_{b - i - 1}
debug_assert!(point.len() % 2 == 0);
let b = point.len() / 2;
let (x, y) = point.split_at(b);
Expand Down
1 change: 1 addition & 0 deletions jolt-core/src/jolt/subtable/div_by_zero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ impl<F: JoltField> DivByZeroSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for DivByZeroSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = (x == 0) && (y == 2^b - 1)
let mut entries: Vec<F> = vec![F::zero(); M];
let bits_per_operand = (log2(M) / 2) as usize;

Expand Down
6 changes: 3 additions & 3 deletions jolt-core/src/jolt/subtable/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ impl<F: JoltField> EqSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for EqSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// Materialize table entries in order where (x | y) ranges 0..M
// Below is the optimized loop for the condition:
// table[x | y] = (x == y)
let mut entries: Vec<F> = vec![F::zero(); M];
let bits_per_operand = (log2(M) / 2) as usize;

// Materialize table entries in order where (x | y) ranges 0..M
// Below is the optimized loop for the condition:
// table[x | y] = x == y
for idx in 0..(1 << bits_per_operand) {
let concat_idx = idx | (idx << bits_per_operand);
entries[concat_idx] = F::one();
Expand Down
8 changes: 4 additions & 4 deletions jolt-core/src/jolt/subtable/eq_abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ impl<F: JoltField> EqAbsSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for EqAbsSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
let mut entries: Vec<F> = vec![F::zero(); M];
let bits_per_operand = (log2(M) / 2) as usize;

// Materialize table entries in order where (x | y) ranges 0..M
// Below is the optimized loop for the condition:
// lower_bits_mask = 0b01111...11
// table[x | y] == (x & lower_bits_mask) == (y & lower_bits_mask)
let mut entries: Vec<F> = vec![F::zero(); M];
let bits_per_operand = (log2(M) / 2) as usize;

for idx in 0..(1 << (bits_per_operand)) {
// we set the bit in the table where x == y
// e.g. 01010011 | 01010011 = 1
Expand All @@ -41,7 +41,7 @@ impl<F: JoltField> LassoSubtable<F> for EqAbsSubtable<F> {
}

fn evaluate_mle(&self, point: &[F]) -> F {
// \prod_i x_i * y_i + (1 - x_i) * (1 - y_i)
// \prod_i x_i * y_i + (1 - x_i) * (1 - y_i) for i > 0
debug_assert!(point.len() % 2 == 0);
let b = point.len() / 2;
let (x, y) = point.split_at(b);
Expand Down
2 changes: 2 additions & 0 deletions jolt-core/src/jolt/subtable/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ impl<F: JoltField> IdentitySubtable<F> {

impl<F: JoltField> LassoSubtable<F> for IdentitySubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x] = x
(0..M).map(|i| F::from_u64(i as u64).unwrap()).collect()
}

fn evaluate_mle(&self, point: &[F]) -> F {
// \sum_i 2^i * x_{b - i - 1}
let mut result = F::zero();
for i in 0..point.len() {
result += F::from_u64(1u64 << i).unwrap() * point[point.len() - 1 - i];
Expand Down
1 change: 1 addition & 0 deletions jolt-core/src/jolt/subtable/left_is_zero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ impl<F: JoltField> LeftIsZeroSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for LeftIsZeroSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = (x == 0)
let mut entries: Vec<F> = vec![F::zero(); M];

for idx in 0..(1 << (log2(M) / 2)) {
Expand Down
2 changes: 2 additions & 0 deletions jolt-core/src/jolt/subtable/left_msb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl<F: JoltField> LeftMSBSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for LeftMSBSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = (x & 0b100..0) = msb(x)
let mut entries: Vec<F> = Vec::with_capacity(M);
let bits_per_operand = (log2(M) / 2) as usize;
let high_bit = 1usize << (bits_per_operand - 1);
Expand All @@ -37,6 +38,7 @@ impl<F: JoltField> LassoSubtable<F> for LeftMSBSubtable<F> {
}

fn evaluate_mle(&self, point: &[F]) -> F {
// x_0
debug_assert!(point.len() % 2 == 0);
let b = point.len() / 2;
let (x, _) = point.split_at(b);
Expand Down
3 changes: 2 additions & 1 deletion jolt-core/src/jolt/subtable/lt_abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl<F: JoltField> LtAbsSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for LtAbsSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = (x & 0b011..11) < (y & 0b011..11)
let mut entries: Vec<F> = Vec::with_capacity(M);
let bits_per_operand = (log2(M) / 2) as usize;
// 0b01111...11
Expand All @@ -39,7 +40,7 @@ impl<F: JoltField> LassoSubtable<F> for LtAbsSubtable<F> {
}

fn evaluate_mle(&self, point: &[F]) -> F {
// \prod_i x_i * y_i + (1 - x_i) * (1 - y_i)
// \sum_{i > 0} (1 - x_i) * y_i * \prod_{j < i} ((1 - x_j) * (1 - y_j) + x_j * y_j)
debug_assert!(point.len() % 2 == 0);
let b = point.len() / 2;
let (x, y) = point.split_at(b);
Expand Down
3 changes: 2 additions & 1 deletion jolt-core/src/jolt/subtable/ltu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl<F: JoltField> LtuSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for LtuSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = (x < y)
let mut entries: Vec<F> = Vec::with_capacity(M);
let bits_per_operand = (log2(M) / 2) as usize;

Expand All @@ -33,7 +34,7 @@ impl<F: JoltField> LassoSubtable<F> for LtuSubtable<F> {
}

fn evaluate_mle(&self, point: &[F]) -> F {
// \prod_i x_i * y_i + (1 - x_i) * (1 - y_i)
// \sum_i (1 - x_i) * y_i * \prod_{j < i} ((1 - x_j) * (1 - y_j) + x_j * y_j)
debug_assert!(point.len() % 2 == 0);
let b = point.len() / 2;
let (x, y) = point.split_at(b);
Expand Down
3 changes: 2 additions & 1 deletion jolt-core/src/jolt/subtable/or.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl<F: JoltField> OrSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for OrSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = x | y (bit-wise OR)
let mut entries: Vec<F> = Vec::with_capacity(M);
let bits_per_operand = (log2(M) / 2) as usize;

Expand All @@ -33,7 +34,7 @@ impl<F: JoltField> LassoSubtable<F> for OrSubtable<F> {
}

fn evaluate_mle(&self, point: &[F]) -> F {
// x + y - x * y
// \sum_i 2^i * (x_{b - i - 1} + y_{b - i - 1} - x_{b - i - 1} * y_{b - i - 1})
debug_assert!(point.len() % 2 == 0);
let b = point.len() / 2;
let (x, y) = point.split_at(b);
Expand Down
1 change: 1 addition & 0 deletions jolt-core/src/jolt/subtable/right_is_zero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ impl<F: JoltField> RightIsZeroSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for RightIsZeroSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = (y == 0)
let mut entries: Vec<F> = vec![F::zero(); M];
let right_operand_bits = (1 << (log2(M) / 2)) - 1;

Expand Down
2 changes: 2 additions & 0 deletions jolt-core/src/jolt/subtable/right_msb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl<F: JoltField> RightMSBSubtable<F> {

impl<F: JoltField> LassoSubtable<F> for RightMSBSubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = (y & 0b100..0) = msb(y)
let mut entries: Vec<F> = Vec::with_capacity(M);
let bits_per_operand = (log2(M) / 2) as usize;
let high_bit = 1usize << (bits_per_operand - 1);
Expand All @@ -37,6 +38,7 @@ impl<F: JoltField> LassoSubtable<F> for RightMSBSubtable<F> {
}

fn evaluate_mle(&self, point: &[F]) -> F {
// y_0
debug_assert!(point.len() % 2 == 0);
let b = point.len() / 2;
let (_, y) = point.split_at(b);
Expand Down
10 changes: 7 additions & 3 deletions jolt-core/src/jolt/subtable/sign_extend.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::field::JoltField;
use ark_std::log2;
use std::marker::PhantomData;

use super::LassoSubtable;
Expand All @@ -18,8 +19,10 @@ impl<F: JoltField, const WIDTH: usize> SignExtendSubtable<F, WIDTH> {

impl<F: JoltField, const WIDTH: usize> LassoSubtable<F> for SignExtendSubtable<F, WIDTH> {
fn materialize(&self, M: usize) -> Vec<F> {
// TODO(moodlezoup): This subtable currently only works for M = 2^16
assert_eq!(M, 1 << 16);
// table[x] = x[b - WIDTH] * (2^{WIDTH} - 1)
// Take the WIDTH-th bit of the input (counting from the LSB), then multiply by (2^{WIDTH} - 1)
// Requires `log2(M) >= WIDTH`
debug_assert!(WIDTH <= log2(M) as usize);
let mut entries: Vec<F> = Vec::with_capacity(M);

// The sign-extension will be the same width as the value being extended
Expand All @@ -36,7 +39,8 @@ impl<F: JoltField, const WIDTH: usize> LassoSubtable<F> for SignExtendSubtable<F
}

fn evaluate_mle(&self, point: &[F]) -> F {
assert_eq!(point.len(), 16);
// 2 ^ {WIDTH - 1} * x_{b - WIDTH}
debug_assert!(point.len() >= WIDTH);

let sign_bit = point[point.len() - WIDTH];
let ones: u64 = (1 << WIDTH) - 1;
Expand Down
26 changes: 16 additions & 10 deletions jolt-core/src/jolt/subtable/sll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,32 @@ impl<F: JoltField, const CHUNK_INDEX: usize, const WORD_SIZE: usize> LassoSubtab
for SllSubtable<F, CHUNK_INDEX, WORD_SIZE>
{
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = (x << (y % WORD_SIZE)) % (1 << (WORD_SIZE - suffix_length))
// where `suffix_length = operand_chunk_width * CHUNK_INDEX`
let mut entries: Vec<F> = Vec::with_capacity(M);

let operand_chunk_width: usize = (log2(M) / 2) as usize;
let suffix_length = operand_chunk_width * CHUNK_INDEX;

for idx in 0..M {
let (x, y) = split_bits(idx, operand_chunk_width);
let x = x as u64;

let row = x
.checked_shl((y % WORD_SIZE + suffix_length) as u32)
let row = (x as u64)
.checked_shl((y % WORD_SIZE) as u32)
.unwrap_or(0)
.rem_euclid(1 << WORD_SIZE)
.checked_shr(suffix_length as u32)
.unwrap_or(0);
.rem_euclid(1 << (WORD_SIZE - suffix_length));

entries.push(F::from_u64(row as u64).unwrap());
entries.push(F::from_u64(row).unwrap());
}
entries
}

fn evaluate_mle(&self, point: &[F]) -> F {
// first half is chunk X_i
// and second half is always chunk Y_0
// \sum_{k = 0}^{2^b - 1} eq(y, bin(k)) * (\sum_{j = 0}^{m'-1} 2^{k + j} * x_{b - j - 1}),
// where m = min(b, max( 0, (k + b * (CHUNK_INDEX + 1)) - WORD_SIZE))
// and m' = b - m

// We assume the first half is chunk(X_i) and the second half is always chunk(Y_0)
debug_assert!(point.len() % 2 == 0);

let log_WORD_SIZE = log2(WORD_SIZE) as usize;
Expand All @@ -61,12 +63,14 @@ impl<F: JoltField, const CHUNK_INDEX: usize, const WORD_SIZE: usize> LassoSubtab

// min with 1 << b is included for test cases with subtables of bit-length smaller than 6
for k in 0..min(WORD_SIZE, 1 << b) {
// bit-decompose k
let k_bits = k
.get_bits(log_WORD_SIZE)
.iter()
.map(|bit| F::from_u64(*bit as u64).unwrap())
.collect::<Vec<F>>(); // big-endian

// Compute eq(y, bin(k))
let mut eq_term = F::one();
// again, min with b is included when subtables of bit-length less than 6 are used
for i in 0..min(log_WORD_SIZE, b) {
Expand All @@ -79,8 +83,9 @@ impl<F: JoltField, const CHUNK_INDEX: usize, const WORD_SIZE: usize> LassoSubtab
} else {
0
};

let m_prime = b - m;

// Compute \sum_{j = 0}^{m'-1} 2^{k + j} * x_{b - j - 1}
let shift_x_by_k = (0..m_prime)
.enumerate()
.map(|(j, _)| F::from_u64(1_u64 << (j + k)).unwrap() * x[b - 1 - j])
Expand Down Expand Up @@ -108,6 +113,7 @@ mod test {
subtable_materialize_mle_parity_test!(sll_materialize_mle_parity2, SllSubtable<Fr, 2, 32>, Fr, 1 << 10);
subtable_materialize_mle_parity_test!(sll_materialize_mle_parity3, SllSubtable<Fr, 3, 32>, Fr, 1 << 10);

// This test is noticeably slow
subtable_materialize_mle_parity_test!(
sll_binius_materialize_mle_parity3,
SllSubtable<BiniusField<BinaryField128b>, 3, 32>,
Expand Down
23 changes: 16 additions & 7 deletions jolt-core/src/jolt/subtable/sra_sign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ impl<F: JoltField, const WORD_SIZE: usize> SraSignSubtable<F, WORD_SIZE> {

impl<F: JoltField, const WORD_SIZE: usize> LassoSubtable<F> for SraSignSubtable<F, WORD_SIZE> {
fn materialize(&self, M: usize) -> Vec<F> {
// table[x | y] = (x_sign == 0) ? 0 : 0b11..100..0,
// where x_sign = (x >> ((WORD_SIZE - 1) & (log2(M) / 2))) & 1,
// `0b11..100..0` has `WORD_SIZE` bits and `y % WORD_SIZE` ones
let mut entries: Vec<F> = Vec::with_capacity(M);

let operand_chunk_width: usize = (log2(M) / 2) as usize;
Expand All @@ -31,19 +34,23 @@ impl<F: JoltField, const WORD_SIZE: usize> LassoSubtable<F> for SraSignSubtable<
for idx in 0..M {
let (x, y) = split_bits(idx, operand_chunk_width);

let x_sign = F::from_u64(((x >> sign_bit_index) & 1) as u64).unwrap();
let x_sign = (x >> sign_bit_index) & 1;

let row = (0..(y % WORD_SIZE) as u32).fold(F::zero(), |acc, i: u32| {
acc + F::from_u64(1_u64 << (WORD_SIZE as u32 - 1 - i)).unwrap() * x_sign
});

entries.push(row);
if x_sign == 0 {
entries.push(F::zero());
} else {
let row = (0..(y % WORD_SIZE)).fold(0, |acc, i| acc + (1 << (WORD_SIZE - 1 - i)));
entries.push(F::from_u64(row).unwrap());
}
}
entries
}

fn evaluate_mle(&self, point: &[F]) -> F {
// first half is chunk X_i
// \sum_{k = 0}^{WORD_SIZE - 1} eq(y, bin(k)) * x_sign * \prod_{i = 0}^{k-1} 2^{WORD_SIZE - 1 - k},
// where x_sign = x_{b - 1 - (WORD_SIZE - 1) % b}

// first half is chunk X_last
// and second half is always chunk Y_0
debug_assert!(point.len() % 2 == 0);

Expand All @@ -59,12 +66,14 @@ impl<F: JoltField, const WORD_SIZE: usize> LassoSubtable<F> for SraSignSubtable<

// min with 1 << b is included for test cases with subtables of bit-length smaller than 6
for k in 0..std::cmp::min(WORD_SIZE, 1 << b) {
// bit-decompose k
let k_bits = k
.get_bits(log_WORD_SIZE)
.iter()
.map(|bit| if *bit { F::one() } else { F::zero() })
.collect::<Vec<F>>(); // big-endian

// Compute eq(y, bin(k))
let mut eq_term = F::one();
// again, min with b is included when subtables of bit-length less than 6 are used
for i in 0..std::cmp::min(log_WORD_SIZE, b) {
Expand Down
Loading
Loading