From e13325713abaf663f75ae5ddccff9ec3c11a9a37 Mon Sep 17 00:00:00 2001 From: Noah Citron Date: Mon, 24 Jun 2024 16:08:14 -0400 Subject: [PATCH 01/17] trace new instructions --- tracer/src/emulator/cpu.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tracer/src/emulator/cpu.rs b/tracer/src/emulator/cpu.rs index 2a950900a..c42dfae17 100644 --- a/tracer/src/emulator/cpu.rs +++ b/tracer/src/emulator/cpu.rs @@ -2459,7 +2459,7 @@ pub const INSTRUCTIONS: [Instruction; INSTRUCTION_NUM] = [ Ok(()) }, disassemble: dump_format_r, - trace: None, + trace: Some(trace_r), }, Instruction { mask: 0xfe00707f, @@ -2477,7 +2477,7 @@ pub const INSTRUCTIONS: [Instruction; INSTRUCTION_NUM] = [ Ok(()) }, disassemble: dump_format_r, - trace: None, + trace: Some(trace_r), }, Instruction { mask: 0xfe00707f, @@ -3130,7 +3130,7 @@ pub const INSTRUCTIONS: [Instruction; INSTRUCTION_NUM] = [ Ok(()) }, disassemble: dump_format_r, - trace: None, + trace: Some(trace_r), }, Instruction { mask: 0xfe00707f, @@ -3145,7 +3145,7 @@ pub const INSTRUCTIONS: [Instruction; INSTRUCTION_NUM] = [ Ok(()) }, disassemble: dump_format_r, - trace: None, + trace: Some(trace_r), }, Instruction { mask: 0xfe00707f, @@ -3165,7 +3165,7 @@ pub const INSTRUCTIONS: [Instruction; INSTRUCTION_NUM] = [ Ok(()) }, disassemble: dump_format_r, - trace: None, + trace: Some(trace_r), }, Instruction { mask: 0xfe00707f, @@ -3184,7 +3184,7 @@ pub const INSTRUCTIONS: [Instruction; INSTRUCTION_NUM] = [ Ok(()) }, disassemble: dump_format_r, - trace: None, + trace: Some(trace_r), }, Instruction { mask: 0xfe00707f, @@ -3273,7 +3273,7 @@ pub const INSTRUCTIONS: [Instruction; INSTRUCTION_NUM] = [ Ok(()) }, disassemble: dump_format_r, - trace: None, + trace: Some(trace_r), }, Instruction { mask: 0xfe00707f, @@ -3290,7 +3290,7 @@ pub const INSTRUCTIONS: [Instruction; INSTRUCTION_NUM] = [ Ok(()) }, disassemble: dump_format_r, - trace: None, + trace: Some(trace_r), }, Instruction { mask: 0xfe00707f, From 7aed02df00180bc076a6ff6f92de80a49e5e0fad Mon Sep 17 00:00:00 2001 From: Noah Citron Date: Mon, 24 Jun 2024 16:13:21 -0400 Subject: [PATCH 02/17] rename toolchain tag file --- .jolt.rust.toolchain-tag => guest-toolchain-tag | 0 jolt-core/src/host/toolchain.rs | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename .jolt.rust.toolchain-tag => guest-toolchain-tag (100%) diff --git a/.jolt.rust.toolchain-tag b/guest-toolchain-tag similarity index 100% rename from .jolt.rust.toolchain-tag rename to guest-toolchain-tag diff --git a/jolt-core/src/host/toolchain.rs b/jolt-core/src/host/toolchain.rs index 0c3b0163a..45ca7becc 100644 --- a/jolt-core/src/host/toolchain.rs +++ b/jolt-core/src/host/toolchain.rs @@ -11,7 +11,7 @@ use indicatif::{ProgressBar, ProgressStyle}; use reqwest::Client; use tokio::runtime::Runtime; -const TOOLCHAIN_TAG: &str = include_str!("../../../.jolt.rust.toolchain-tag"); +const TOOLCHAIN_TAG: &str = include_str!("../../../guest-toolchain-tag"); const DOWNLOAD_RETRIES: usize = 5; const DELAY_BASE_MS: u64 = 500; From fd39280f20195ffdf426fded0723d7516ac740eb Mon Sep 17 00:00:00 2001 From: Noah Citron Date: Mon, 24 Jun 2024 16:14:52 -0400 Subject: [PATCH 03/17] fix ci --- .github/workflows/rust.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 8b2f485bd..6fd2acb07 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -49,7 +49,7 @@ jobs: - name: Cache Jolt RISC-V Rust toolchain uses: actions/cache@v4 with: - key: jolt-rust-toolchain-${{hashFiles('.jolt.rust.toolchain-tag')}} + key: jolt-rust-toolchain-${{hashFiles('guest-toolchain-tag')}} path: ~/.jolt - name: Install Jolt RISC-V Rust toolchain run: cargo run install-toolchain From faaa91bbbfea70fad6a99174905ed7653d1ed472 Mon Sep 17 00:00:00 2001 From: Noah Citron Date: Tue, 16 Jul 2024 20:49:44 -0400 Subject: [PATCH 04/17] add m extension support to tracer --- Cargo.toml | 2 ++ examples/muldiv/Cargo.toml | 9 +++++++++ examples/muldiv/guest/Cargo.toml | 14 ++++++++++++++ examples/muldiv/guest/src/lib.rs | 7 +++++++ examples/muldiv/src/main.rs | 9 +++++++++ guest-toolchain-tag | 2 +- jolt-core/src/host/mod.rs | 2 +- jolt-core/src/host/toolchain.rs | 2 +- 8 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 examples/muldiv/Cargo.toml create mode 100644 examples/muldiv/guest/Cargo.toml create mode 100644 examples/muldiv/guest/src/lib.rs create mode 100644 examples/muldiv/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index a4cacf3ad..ab2270a50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,8 @@ members = [ "examples/alloc/guest", "examples/stdlib", "examples/stdlib/guest", + "examples/muldiv", + "examples/muldiv/guest", ] [features] diff --git a/examples/muldiv/Cargo.toml b/examples/muldiv/Cargo.toml new file mode 100644 index 000000000..6910f3f60 --- /dev/null +++ b/examples/muldiv/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "muldiv" +version = "0.1.0" +edition = "2021" + +[dependencies] +jolt-sdk = { path = "../../jolt-sdk", features = ["host"] } +guest = { package = "muldiv-guest", path = "./guest" } + diff --git a/examples/muldiv/guest/Cargo.toml b/examples/muldiv/guest/Cargo.toml new file mode 100644 index 000000000..81ee8a31e --- /dev/null +++ b/examples/muldiv/guest/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "muldiv-guest" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "guest" +path = "./src/lib.rs" + +[features] +guest = [] + +[dependencies] +jolt = { package = "jolt-sdk", path = "../../../jolt-sdk" } diff --git a/examples/muldiv/guest/src/lib.rs b/examples/muldiv/guest/src/lib.rs new file mode 100644 index 000000000..a08b3ef90 --- /dev/null +++ b/examples/muldiv/guest/src/lib.rs @@ -0,0 +1,7 @@ +#![cfg_attr(feature = "guest", no_std)] +#![no_main] + +#[jolt::provable] +fn muldiv(a: u32, b: u32, c: u32) -> u32 { + a * b / c +} diff --git a/examples/muldiv/src/main.rs b/examples/muldiv/src/main.rs new file mode 100644 index 000000000..488b02211 --- /dev/null +++ b/examples/muldiv/src/main.rs @@ -0,0 +1,9 @@ +pub fn main() { + let (prove, verify) = guest::build_muldiv(); + + let (output, proof) = prove(12031293, 17, 92); + let is_valid = verify(proof); + + println!("output: {}", output); + println!("valid: {}", is_valid); +} diff --git a/guest-toolchain-tag b/guest-toolchain-tag index ed254e28a..5f35dcf5b 100644 --- a/guest-toolchain-tag +++ b/guest-toolchain-tag @@ -1 +1 @@ -nightly-3c5f0ec3f4f98a2d211061a83bade8d62c6a6135 \ No newline at end of file +nightly-3cce1fd56f8c0f705f27f5dfb8f777583bc62e20 diff --git a/jolt-core/src/host/mod.rs b/jolt-core/src/host/mod.rs index 6bcd8a1e6..292b8c75c 100644 --- a/jolt-core/src/host/mod.rs +++ b/jolt-core/src/host/mod.rs @@ -109,7 +109,7 @@ impl Program { "panic=abort", ]; - let toolchain = "riscv32i-jolt-zkvm-elf"; + let toolchain = "riscv32im-jolt-zkvm-elf"; let mut envs = vec![ ("CARGO_ENCODED_RUSTFLAGS", rust_flags.join("\x1f")), ("RUSTUP_TOOLCHAIN", toolchain.to_string()), diff --git a/jolt-core/src/host/toolchain.rs b/jolt-core/src/host/toolchain.rs index 45ca7becc..94a9945d0 100644 --- a/jolt-core/src/host/toolchain.rs +++ b/jolt-core/src/host/toolchain.rs @@ -68,7 +68,7 @@ fn link_toolchain() -> Result<()> { .args([ "toolchain", "link", - "riscv32i-jolt-zkvm-elf", + "riscv32im-jolt-zkvm-elf", link_path.to_str().unwrap(), ]) .output()?; From 60ac0cfe6b35ea0866e6fe3179645a7fecbb04cd Mon Sep 17 00:00:00 2001 From: Noah Citron Date: Tue, 16 Jul 2024 21:08:57 -0400 Subject: [PATCH 05/17] surface error in toolchain download --- jolt-core/src/host/toolchain.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jolt-core/src/host/toolchain.rs b/jolt-core/src/host/toolchain.rs index 94a9945d0..23adae68b 100644 --- a/jolt-core/src/host/toolchain.rs +++ b/jolt-core/src/host/toolchain.rs @@ -125,7 +125,8 @@ async fn download_toolchain(client: &Client, url: &str) -> Result<()> { Ok(()) } else { - Err(eyre!("failed to download toolchain")) + let err = response.error_for_status().err().unwrap(); + Err(eyre!("failed to download toolchain: {}", err)) } } From 84ce384eb79795ac8a9f58457fa532372b923cc0 Mon Sep 17 00:00:00 2001 From: Noah Citron Date: Tue, 16 Jul 2024 21:14:00 -0400 Subject: [PATCH 06/17] remove unwrap in error handling --- jolt-core/src/host/toolchain.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/jolt-core/src/host/toolchain.rs b/jolt-core/src/host/toolchain.rs index 23adae68b..be9937a8f 100644 --- a/jolt-core/src/host/toolchain.rs +++ b/jolt-core/src/host/toolchain.rs @@ -125,8 +125,10 @@ async fn download_toolchain(client: &Client, url: &str) -> Result<()> { Ok(()) } else { - let err = response.error_for_status().err().unwrap(); - Err(eyre!("failed to download toolchain: {}", err)) + Err(match response.error_for_status() { + Ok(_) => eyre!("failed to download toolchain"), + Err(err) => eyre!("failed to download toolchain: {}", err), + }) } } From 64b2b2f3714b5c5480da6d7114d972d3248ea647 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Tue, 23 Jul 2024 11:26:41 -0400 Subject: [PATCH 07/17] Enable div and rem virtual sequences --- jolt-core/src/host/mod.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/jolt-core/src/host/mod.rs b/jolt-core/src/host/mod.rs index df78f0e4e..bfe808b63 100644 --- a/jolt-core/src/host/mod.rs +++ b/jolt-core/src/host/mod.rs @@ -25,7 +25,9 @@ use crate::{ field::JoltField, jolt::{ instruction::{ - mulh::MULHInstruction, mulhsu::MULHSUInstruction, VirtualInstructionSequence, + div::DIVInstruction, divu::DIVUInstruction, mulh::MULHInstruction, + mulhsu::MULHSUInstruction, rem::REMInstruction, remu::REMUInstruction, + VirtualInstructionSequence, }, vm::{bytecode::BytecodeRow, rv32i_vm::RV32I, JoltTraceStep}, }, @@ -178,10 +180,10 @@ impl Program { .flat_map(|row| match row.instruction.opcode { tracer::RV32IM::MULH => MULHInstruction::<32>::virtual_sequence(row), tracer::RV32IM::MULHSU => MULHSUInstruction::<32>::virtual_sequence(row), - tracer::RV32IM::DIV => todo!(), - tracer::RV32IM::DIVU => todo!(), - tracer::RV32IM::REM => todo!(), - tracer::RV32IM::REMU => todo!(), + tracer::RV32IM::DIV => DIVInstruction::<32>::virtual_sequence(row), + tracer::RV32IM::DIVU => DIVUInstruction::<32>::virtual_sequence(row), + tracer::RV32IM::REM => REMInstruction::<32>::virtual_sequence(row), + tracer::RV32IM::REMU => REMUInstruction::<32>::virtual_sequence(row), _ => vec![row], }) .map(|row| { From 2394c54bbd50c55f242d04f0ad0199a2bb69fdf0 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Tue, 23 Jul 2024 16:57:40 -0400 Subject: [PATCH 08/17] Fix bytecode preprocessing for virtual instructions --- jolt-core/src/field/ark.rs | 17 ++ jolt-core/src/field/mod.rs | 3 + jolt-core/src/host/mod.rs | 12 +- jolt-core/src/jolt/instruction/div.rs | 174 +++++++++++--------- jolt-core/src/jolt/instruction/divu.rs | 195 ++++++++++++----------- jolt-core/src/jolt/instruction/mod.rs | 3 +- jolt-core/src/jolt/instruction/mulh.rs | 174 +++++++++++--------- jolt-core/src/jolt/instruction/mulhsu.rs | 110 +++++++------ jolt-core/src/jolt/instruction/rem.rs | 153 ++++++++++-------- jolt-core/src/jolt/instruction/remu.rs | 175 +++++++++++--------- jolt-core/src/jolt/vm/bytecode.rs | 107 ++++++++++++- jolt-core/src/jolt/vm/mod.rs | 20 ++- jolt-core/src/jolt/vm/rv32i_vm.rs | 27 ++++ 13 files changed, 709 insertions(+), 461 deletions(-) diff --git a/jolt-core/src/field/ark.rs b/jolt-core/src/field/ark.rs index ab6eaced7..eecb9d7b9 100644 --- a/jolt-core/src/field/ark.rs +++ b/jolt-core/src/field/ark.rs @@ -26,6 +26,23 @@ impl JoltField for ark_bn254::Fr { } } + fn to_u64(&self) -> Option { + let bigint = self.into_bigint(); + let limbs: &[u64] = bigint.as_ref(); + let result = limbs[0]; + + match ::from_u64(result) { + None => None, + Some(x) => { + if x == *self { + Some(result) + } else { + None + } + } + } + } + fn square(&self) -> Self { ::square(self) } diff --git a/jolt-core/src/field/mod.rs b/jolt-core/src/field/mod.rs index 6a3b4823d..02ee610c9 100644 --- a/jolt-core/src/field/mod.rs +++ b/jolt-core/src/field/mod.rs @@ -44,6 +44,9 @@ pub trait JoltField: fn square(&self) -> Self; fn from_bytes(bytes: &[u8]) -> Self; fn inverse(&self) -> Option; + fn to_u64(&self) -> Option { + unimplemented!("conversion to u64 not implemented"); + } } pub trait OptimizedMul: Sized + Mul { diff --git a/jolt-core/src/host/mod.rs b/jolt-core/src/host/mod.rs index bfe808b63..17095f120 100644 --- a/jolt-core/src/host/mod.rs +++ b/jolt-core/src/host/mod.rs @@ -178,12 +178,12 @@ impl Program { let trace: Vec<_> = raw_trace .into_par_iter() .flat_map(|row| match row.instruction.opcode { - tracer::RV32IM::MULH => MULHInstruction::<32>::virtual_sequence(row), - tracer::RV32IM::MULHSU => MULHSUInstruction::<32>::virtual_sequence(row), - tracer::RV32IM::DIV => DIVInstruction::<32>::virtual_sequence(row), - tracer::RV32IM::DIVU => DIVUInstruction::<32>::virtual_sequence(row), - tracer::RV32IM::REM => REMInstruction::<32>::virtual_sequence(row), - tracer::RV32IM::REMU => REMUInstruction::<32>::virtual_sequence(row), + tracer::RV32IM::MULH => MULHInstruction::<32>::virtual_trace(row), + tracer::RV32IM::MULHSU => MULHSUInstruction::<32>::virtual_trace(row), + tracer::RV32IM::DIV => DIVInstruction::<32>::virtual_trace(row), + tracer::RV32IM::DIVU => DIVUInstruction::<32>::virtual_trace(row), + tracer::RV32IM::REM => REMInstruction::<32>::virtual_trace(row), + tracer::RV32IM::REMU => REMUInstruction::<32>::virtual_trace(row), _ => vec![row], }) .map(|row| { diff --git a/jolt-core/src/jolt/instruction/div.rs b/jolt-core/src/jolt/instruction/div.rs index 2b18328f1..1aeef59d4 100644 --- a/jolt-core/src/jolt/instruction/div.rs +++ b/jolt-core/src/jolt/instruction/div.rs @@ -11,14 +11,11 @@ use crate::jolt::instruction::{ pub struct DIVInstruction; impl VirtualInstructionSequence for DIVInstruction { - fn virtual_sequence(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::DIV); - // DIV operands - let x = trace_row.register_state.rs1_val.unwrap(); - let y = trace_row.register_state.rs2_val.unwrap(); + fn virtual_sequence(instruction: ELFInstruction) -> Vec { + assert_eq!(instruction.opcode, RV32IM::DIV); // DIV source registers - let r_x = trace_row.instruction.rs1; - let r_y = trace_row.instruction.rs2; + let r_x = instruction.rs1; + let r_y = instruction.rs2; // Virtual registers used in sequence let v_0 = Some(virtual_register_index(0)); let v_r: Option = Some(virtual_register_index(1)); @@ -26,6 +23,81 @@ impl VirtualInstructionSequence for DIVInstruction Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::DIV); + // DIV operands + let x = trace_row.register_state.rs1_val.unwrap(); + let y = trace_row.register_state.rs2_val.unwrap(); + + let virtual_instructions = Self::virtual_sequence(trace_row.instruction); + let mut virtual_trace = vec![]; + let (quotient, remainder) = match WORD_SIZE { 32 => { let mut quotient = x as i32 / y as i32; @@ -49,16 +121,8 @@ impl VirtualInstructionSequence for DIVInstruction(quotient).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: trace_row.instruction.rd, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -69,16 +133,8 @@ impl VirtualInstructionSequence for DIVInstruction(remainder).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_r, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -90,16 +146,8 @@ impl VirtualInstructionSequence for DIVInstruction(r, y).lookup_entry(); assert_eq!(is_valid, 1); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER, - rs1: v_r, - rs2: r_y, - rd: None, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(r), rs2_val: Some(y), @@ -111,16 +159,8 @@ impl VirtualInstructionSequence for DIVInstruction(y, q).lookup_entry(); assert_eq!(is_valid, 1); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_DIV0, - rs1: r_y, - rs2: trace_row.instruction.rd, - rd: None, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(y), rs2_val: Some(q), @@ -131,16 +171,8 @@ impl VirtualInstructionSequence for DIVInstruction(q, y).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::MUL, - rs1: trace_row.instruction.rd, - rs2: r_y, - rd: v_qy, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(q), rs2_val: Some(y), @@ -151,16 +183,8 @@ impl VirtualInstructionSequence for DIVInstruction(q_y, r).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::ADD, - rs1: v_qy, - rs2: v_r, - rd: v_0, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(q_y), rs2_val: Some(r), @@ -171,16 +195,8 @@ impl VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction::virtual_sequence(div_trace_row); + let virtual_sequence = DIVInstruction::<32>::virtual_trace(div_trace_row); let mut registers = vec![0u64; REGISTER_COUNT as usize]; registers[r_x as usize] = x; registers[r_y as usize] = y; diff --git a/jolt-core/src/jolt/instruction/divu.rs b/jolt-core/src/jolt/instruction/divu.rs index 481e27f84..394556f9b 100644 --- a/jolt-core/src/jolt/instruction/divu.rs +++ b/jolt-core/src/jolt/instruction/divu.rs @@ -13,35 +13,108 @@ use crate::jolt::instruction::{ pub struct DIVUInstruction; impl VirtualInstructionSequence for DIVUInstruction { - fn virtual_sequence(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::DIVU); - // DIVU operands - let x = trace_row.register_state.rs1_val.unwrap(); - let y = trace_row.register_state.rs2_val.unwrap(); + fn virtual_sequence(instruction: ELFInstruction) -> Vec { + assert_eq!(instruction.opcode, RV32IM::DIVU); // DIVU source registers - let r_x = trace_row.instruction.rs1; - let r_y = trace_row.instruction.rs2; + let r_x = instruction.rs1; + let r_y = instruction.rs2; // Virtual registers used in sequence let v_0 = Some(virtual_register_index(0)); let v_r: Option = Some(virtual_register_index(1)); let v_qy = Some(virtual_register_index(2)); let mut virtual_sequence = vec![]; + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: instruction.rd, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_r, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::MULU, + rs1: instruction.rd, + rs2: r_y, + rd: v_qy, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_VALID_UNSIGNED_REMAINDER, + rs1: v_r, + rs2: r_y, + rd: None, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_LTE, + rs1: v_qy, + rs2: r_x, + rd: None, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_VALID_DIV0, + rs1: r_y, + rs2: instruction.rd, + rd: None, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::ADD, + rs1: v_qy, + rs2: v_r, + rd: v_0, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_EQ, + rs1: v_0, + rs2: r_x, + rd: None, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + + virtual_sequence + } + + fn virtual_trace(trace_row: RVTraceRow) -> Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::DIVU); + // DIVU operands + let x = trace_row.register_state.rs1_val.unwrap(); + let y = trace_row.register_state.rs2_val.unwrap(); + + let virtual_instructions = Self::virtual_sequence(trace_row.instruction); + let mut virtual_trace = vec![]; let quotient = x / y; let remainder = x - quotient * y; let q = ADVICEInstruction::(quotient).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: trace_row.instruction.rd, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -52,16 +125,8 @@ impl VirtualInstructionSequence for DIVUInstruction(remainder).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_r, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -72,16 +137,8 @@ impl VirtualInstructionSequence for DIVUInstruction(q, y).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::MULU, - rs1: trace_row.instruction.rd, - rs2: r_y, - rd: v_qy, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(q), rs2_val: Some(y), @@ -93,16 +150,8 @@ impl VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction(y, q).lookup_entry(); assert_eq!(is_valid, 1); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_DIV0, - rs1: r_y, - rs2: trace_row.instruction.rd, - rd: None, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(y), rs2_val: Some(q), @@ -155,16 +188,8 @@ impl VirtualInstructionSequence for DIVUInstruction(q_y, r).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::ADD, - rs1: v_qy, - rs2: v_r, - rd: v_0, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(q_y), rs2_val: Some(r), @@ -175,16 +200,8 @@ impl VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction::virtual_sequence(divu_trace_row); + let virtual_sequence = DIVUInstruction::<32>::virtual_trace(divu_trace_row); let mut registers = vec![0u64; REGISTER_COUNT as usize]; registers[r_x as usize] = x; registers[r_y as usize] = y; diff --git a/jolt-core/src/jolt/instruction/mod.rs b/jolt-core/src/jolt/instruction/mod.rs index c38034b0e..c9e622eca 100644 --- a/jolt-core/src/jolt/instruction/mod.rs +++ b/jolt-core/src/jolt/instruction/mod.rs @@ -126,7 +126,8 @@ impl From> for SubtableIndices { } pub trait VirtualInstructionSequence { - fn virtual_sequence(trace_row: RVTraceRow) -> Vec; + fn virtual_sequence(instruction: ELFInstruction) -> Vec; + fn virtual_trace(trace_row: RVTraceRow) -> Vec; } pub mod add; diff --git a/jolt-core/src/jolt/instruction/mulh.rs b/jolt-core/src/jolt/instruction/mulh.rs index 9c96ae6e2..6e7ea6396 100644 --- a/jolt-core/src/jolt/instruction/mulh.rs +++ b/jolt-core/src/jolt/instruction/mulh.rs @@ -10,14 +10,11 @@ use crate::jolt::instruction::{ pub struct MULHInstruction; impl VirtualInstructionSequence for MULHInstruction { - fn virtual_sequence(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::MULH); - // MULH operands - let x = trace_row.register_state.rs1_val.unwrap(); - let y = trace_row.register_state.rs2_val.unwrap(); + fn virtual_sequence(instruction: ELFInstruction) -> Vec { + assert_eq!(instruction.opcode, RV32IM::MULH); // MULH source registers - let r_x = trace_row.instruction.rs1; - let r_y = trace_row.instruction.rs2; + let r_x = instruction.rs1; + let r_y = instruction.rs2; // Virtual registers used in sequence let v_sx = Some(virtual_register_index(0)); let v_sy = Some(virtual_register_index(1)); @@ -28,17 +25,84 @@ impl VirtualInstructionSequence for MULHInstruction Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::MULH); + // MULH operands + let x = trace_row.register_state.rs1_val.unwrap(); + let y = trace_row.register_state.rs2_val.unwrap(); + + let virtual_instructions = Self::virtual_sequence(trace_row.instruction); + let mut virtual_trace = vec![]; + let s_x = MOVSIGNInstruction::(x).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_MOVSIGN, - rs1: r_x, - rs2: None, - rd: v_sx, - imm: None, - virtual_sequence_index: Some(0), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(x), rs2_val: None, @@ -49,16 +113,8 @@ impl VirtualInstructionSequence for MULHInstruction(y).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_MOVSIGN, - rs1: r_y, - rs2: None, - rd: v_sy, - imm: None, - virtual_sequence_index: Some(1), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(y), rs2_val: None, @@ -69,16 +125,8 @@ impl VirtualInstructionSequence for MULHInstruction(x, y).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::MULHU, - rs1: r_x, - rs2: r_y, - rd: v_0, - imm: None, - virtual_sequence_index: Some(2), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(x), rs2_val: Some(y), @@ -89,16 +137,8 @@ impl VirtualInstructionSequence for MULHInstruction(s_x, y).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::MULU, - rs1: v_sx, - rs2: r_y, - rd: v_1, - imm: None, - virtual_sequence_index: Some(3), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(s_x), rs2_val: Some(y), @@ -109,16 +149,8 @@ impl VirtualInstructionSequence for MULHInstruction(s_y, x).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::MULU, - rs1: v_sy, - rs2: r_x, - rd: v_2, - imm: None, - virtual_sequence_index: Some(4), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(s_y), rs2_val: Some(x), @@ -129,16 +161,8 @@ impl VirtualInstructionSequence for MULHInstruction(xy_high_bits, sx_y_low_bits).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::ADD, - rs1: v_0, - rs2: v_1, - rd: v_3, - imm: None, - virtual_sequence_index: Some(5), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(xy_high_bits), rs2_val: Some(sx_y_low_bits), @@ -149,16 +173,8 @@ impl VirtualInstructionSequence for MULHInstruction(partial_sum, sy_x_low_bits).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::ADD, - rs1: v_3, - rs2: v_2, - rd: trace_row.instruction.rd, - imm: None, - virtual_sequence_index: Some(6), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(partial_sum), rs2_val: Some(sy_x_low_bits), @@ -167,7 +183,7 @@ impl VirtualInstructionSequence for MULHInstruction::virtual_sequence(mulh_trace_row); + let virtual_sequence = MULHInstruction::<32>::virtual_trace(mulh_trace_row); let mut registers = vec![0u64; REGISTER_COUNT as usize]; registers[r_x as usize] = x; registers[r_y as usize] = y; diff --git a/jolt-core/src/jolt/instruction/mulhsu.rs b/jolt-core/src/jolt/instruction/mulhsu.rs index 9a69250e2..64ade88c4 100644 --- a/jolt-core/src/jolt/instruction/mulhsu.rs +++ b/jolt-core/src/jolt/instruction/mulhsu.rs @@ -11,32 +11,68 @@ use crate::jolt::instruction::{ pub struct MULHSUInstruction; impl VirtualInstructionSequence for MULHSUInstruction { - fn virtual_sequence(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::MULHSU); - // MULHSU operands - let x = trace_row.register_state.rs1_val.unwrap(); - let y = trace_row.register_state.rs2_val.unwrap(); + fn virtual_sequence(instruction: ELFInstruction) -> Vec { + assert_eq!(instruction.opcode, RV32IM::MULHSU); // MULHSU source registers - let r_x = trace_row.instruction.rs1; - let r_y = trace_row.instruction.rs2; + let r_x = instruction.rs1; + let r_y = instruction.rs2; // Virtual registers used in sequence let v_sx = Some(virtual_register_index(0)); let v_1 = Some(virtual_register_index(1)); let v_2 = Some(virtual_register_index(2)); let mut virtual_sequence = vec![]; + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_MOVSIGN, + rs1: r_x, + rs2: None, + rd: v_sx, + imm: None, + virtual_sequence_index: Some(0), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::MULHU, + rs1: r_x, + rs2: r_y, + rd: v_1, + imm: None, + virtual_sequence_index: Some(1), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::MULU, + rs1: v_sx, + rs2: r_y, + rd: v_2, + imm: None, + virtual_sequence_index: Some(2), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::ADD, + rs1: v_1, + rs2: v_2, + rd: instruction.rd, + imm: None, + virtual_sequence_index: Some(3), + }); + virtual_sequence + } + + fn virtual_trace(trace_row: RVTraceRow) -> Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::MULHSU); + // MULHSU operands + let x = trace_row.register_state.rs1_val.unwrap(); + let y = trace_row.register_state.rs2_val.unwrap(); + + let virtual_instructions = Self::virtual_sequence(trace_row.instruction); + let mut virtual_trace = vec![]; let s_x = MOVSIGNInstruction::(x).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_MOVSIGN, - rs1: r_x, - rs2: None, - rd: v_sx, - imm: None, - virtual_sequence_index: Some(0), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(x), rs2_val: None, @@ -47,16 +83,8 @@ impl VirtualInstructionSequence for MULHSUInstruction(x, y).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::MULHU, - rs1: r_x, - rs2: r_y, - rd: v_1, - imm: None, - virtual_sequence_index: Some(1), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(x), rs2_val: Some(y), @@ -67,16 +95,8 @@ impl VirtualInstructionSequence for MULHSUInstruction(s_x, y).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::MULU, - rs1: v_sx, - rs2: r_y, - rd: v_2, - imm: None, - virtual_sequence_index: Some(2), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(s_x), rs2_val: Some(y), @@ -87,16 +107,8 @@ impl VirtualInstructionSequence for MULHSUInstruction(xy_high_bits, sx_y_low_bits).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::ADD, - rs1: v_1, - rs2: v_2, - rd: trace_row.instruction.rd, - imm: None, - virtual_sequence_index: Some(3), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(xy_high_bits), rs2_val: Some(sx_y_low_bits), @@ -105,7 +117,7 @@ impl VirtualInstructionSequence for MULHSUInstruction::virtual_sequence(mulhsu_trace_row); + let virtual_sequence = MULHSUInstruction::<32>::virtual_trace(mulhsu_trace_row); let mut registers = vec![0u64; REGISTER_COUNT as usize]; registers[r_x as usize] = x; registers[r_y as usize] = y; diff --git a/jolt-core/src/jolt/instruction/rem.rs b/jolt-core/src/jolt/instruction/rem.rs index 5d5687d66..2b02d800a 100644 --- a/jolt-core/src/jolt/instruction/rem.rs +++ b/jolt-core/src/jolt/instruction/rem.rs @@ -12,20 +12,83 @@ use crate::jolt::instruction::{ pub struct REMInstruction; impl VirtualInstructionSequence for REMInstruction { - fn virtual_sequence(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::REM); - // REM operands - let x = trace_row.register_state.rs1_val.unwrap(); - let y = trace_row.register_state.rs2_val.unwrap(); + fn virtual_sequence(instruction: ELFInstruction) -> Vec { + assert_eq!(instruction.opcode, RV32IM::REM); // REM source registers - let r_x = trace_row.instruction.rs1; - let r_y = trace_row.instruction.rs2; + let r_x = instruction.rs1; + let r_y = instruction.rs2; // Virtual registers used in sequence let v_0 = Some(virtual_register_index(0)); let v_q = Some(virtual_register_index(1)); let v_qy = Some(virtual_register_index(2)); let mut virtual_sequence = vec![]; + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_q, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: instruction.rd, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER, + rs1: instruction.rd, + rs2: r_y, + rd: None, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::MUL, + rs1: v_q, + rs2: r_y, + rd: v_qy, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::ADD, + rs1: v_qy, + rs2: instruction.rd, + rd: v_0, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + virtual_sequence.push(ELFInstruction { + address: instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_EQ, + rs1: v_0, + rs2: r_x, + rd: None, + imm: None, + virtual_sequence_index: Some(virtual_sequence.len()), + }); + + virtual_sequence + } + + fn virtual_trace(trace_row: RVTraceRow) -> Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::REM); + // REM operands + let x = trace_row.register_state.rs1_val.unwrap(); + let y = trace_row.register_state.rs2_val.unwrap(); + + let virtual_instructions = Self::virtual_sequence(trace_row.instruction); + let mut virtual_trace = vec![]; let (quotient, remainder) = match WORD_SIZE { 32 => { @@ -50,16 +113,8 @@ impl VirtualInstructionSequence for REMInstruction(quotient).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_q, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -70,16 +125,8 @@ impl VirtualInstructionSequence for REMInstruction(remainder).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: trace_row.instruction.rd, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -91,16 +138,8 @@ impl VirtualInstructionSequence for REMInstruction(r, y).lookup_entry(); assert_eq!(is_valid, 1); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER, - rs1: trace_row.instruction.rd, - rs2: r_y, - rd: None, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(r), rs2_val: Some(y), @@ -111,16 +150,8 @@ impl VirtualInstructionSequence for REMInstruction(q, y).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::MUL, - rs1: v_q, - rs2: r_y, - rd: v_qy, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(q), rs2_val: Some(y), @@ -131,16 +162,8 @@ impl VirtualInstructionSequence for REMInstruction(q_y, r).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::ADD, - rs1: v_qy, - rs2: trace_row.instruction.rd, - rd: v_0, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(q_y), rs2_val: Some(r), @@ -151,16 +174,8 @@ impl VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction::virtual_sequence(rem_trace_row); + let virtual_sequence = REMInstruction::<32>::virtual_trace(rem_trace_row); let mut registers = vec![0u64; REGISTER_COUNT as usize]; registers[r_x as usize] = x; registers[r_y as usize] = y; diff --git a/jolt-core/src/jolt/instruction/remu.rs b/jolt-core/src/jolt/instruction/remu.rs index e12171f5d..6e66c85e9 100644 --- a/jolt-core/src/jolt/instruction/remu.rs +++ b/jolt-core/src/jolt/instruction/remu.rs @@ -13,14 +13,11 @@ use crate::jolt::instruction::{ pub struct REMUInstruction; impl VirtualInstructionSequence for REMUInstruction { - fn virtual_sequence(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::REMU); - // REMU operands - let x = trace_row.register_state.rs1_val.unwrap(); - let y = trace_row.register_state.rs2_val.unwrap(); + fn virtual_sequence(instruction: ELFInstruction) -> Vec { + assert_eq!(instruction.opcode, RV32IM::REMU); // REMU source registers - let r_x = trace_row.instruction.rs1; - let r_y = trace_row.instruction.rs2; + let r_x = instruction.rs1; + let r_y = instruction.rs2; // Virtual registers used in sequence let v_0 = Some(virtual_register_index(0)); let v_q = Some(virtual_register_index(1)); @@ -28,20 +25,88 @@ impl VirtualInstructionSequence for REMUInstruction Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::REMU); + // REMU operands + let x = trace_row.register_state.rs1_val.unwrap(); + let y = trace_row.register_state.rs2_val.unwrap(); + + let virtual_instructions = Self::virtual_sequence(trace_row.instruction); + let mut virtual_trace = vec![]; + let quotient = x / y; let remainder = x - quotient * y; let q = ADVICEInstruction::(quotient).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_q, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -52,16 +117,8 @@ impl VirtualInstructionSequence for REMUInstruction(remainder).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: trace_row.instruction.rd, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -72,16 +129,8 @@ impl VirtualInstructionSequence for REMUInstruction(q, y).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::MULU, - rs1: v_q, - rs2: r_y, - rd: v_qy, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(q), rs2_val: Some(y), @@ -93,16 +142,8 @@ impl VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction(q_y, r).lookup_entry(); - virtual_sequence.push(RVTraceRow { - instruction: ELFInstruction { - address: trace_row.instruction.address, - opcode: RV32IM::ADD, - rs1: v_qy, - rs2: trace_row.instruction.rd, - rd: v_0, - imm: None, - virtual_sequence_index: Some(virtual_sequence.len()), - }, + virtual_trace.push(RVTraceRow { + instruction: virtual_instructions[virtual_trace.len()].clone(), register_state: RegisterState { rs1_val: Some(q_y), rs2_val: Some(r), @@ -153,16 +178,8 @@ impl VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction::virtual_sequence(remu_trace_row); + let virtual_sequence = REMUInstruction::<32>::virtual_trace(remu_trace_row); let mut registers = vec![0u64; REGISTER_COUNT as usize]; registers[r_x as usize] = x; registers[r_y as usize] = y; diff --git a/jolt-core/src/jolt/vm/bytecode.rs b/jolt-core/src/jolt/vm/bytecode.rs index 8f6cfbcf2..38330a8c3 100644 --- a/jolt-core/src/jolt/vm/bytecode.rs +++ b/jolt-core/src/jolt/vm/bytecode.rs @@ -2,6 +2,8 @@ use ark_ff::Zero; use rand::rngs::StdRng; use rand::RngCore; use serde::{Deserialize, Serialize}; +#[cfg(test)] +use std::collections::HashSet; use std::{collections::HashMap, marker::PhantomData}; use crate::field::JoltField; @@ -50,6 +52,9 @@ pub struct BytecodeRow { rs2: u64, /// "Immediate" value for this instruction (0 if unused). imm: u64, + /// If this instruction is part of a "virtual sequence" (see Section 6.2 of the + /// Jolt paper), then this contains the instruction's index within the sequence. + virtual_sequence_index: Option, } impl BytecodeRow { @@ -61,6 +66,7 @@ impl BytecodeRow { rs1, rs2, imm, + virtual_sequence_index: None, } } @@ -72,6 +78,7 @@ impl BytecodeRow { rs1: 0, rs2: 0, imm: 0, + virtual_sequence_index: None, } } @@ -83,6 +90,7 @@ impl BytecodeRow { rs1: rng.next_u64() % REGISTER_COUNT, rs2: rng.next_u64() % REGISTER_COUNT, imm: rng.next_u64() % (1 << 20), // U-format instructions have 20-bit imm values + virtual_sequence_index: None, } } @@ -125,6 +133,7 @@ impl BytecodeRow { rs1: instruction.rs1.unwrap_or(0), rs2: instruction.rs2.unwrap_or(0), imm: instruction.imm.unwrap_or(0) as u64, // imm is always cast to its 32-bit repr, signed or unsigned + virtual_sequence_index: instruction.virtual_sequence_index, } } } @@ -167,7 +176,8 @@ pub struct BytecodePreprocessing { /// Maps the memory address of each instruction in the bytecode to its "virtual" address. /// See Section 6.1 of the Jolt paper, "Reflecting the program counter". The virtual address /// is the one used to keep track of the next (potentially virtual) instruction to execute. - virtual_address_map: HashMap, + /// Key: (ELF address, virtual sequence index or 0) + virtual_address_map: HashMap<(usize, usize), usize>, } impl BytecodePreprocessing { @@ -182,7 +192,13 @@ impl BytecodePreprocessing { instruction.address = 1 + (instruction.address - RAM_START_ADDRESS as usize) / BYTES_PER_INSTRUCTION; assert_eq!( - virtual_address_map.insert(instruction.address, virtual_address), + virtual_address_map.insert( + ( + instruction.address, + instruction.virtual_sequence_index.unwrap_or(0) + ), + virtual_address + ), None ); virtual_address += 1; @@ -190,7 +206,7 @@ impl BytecodePreprocessing { // Bytecode: Prepend a single no-op instruction bytecode.insert(0, BytecodeRow::no_op(0)); - assert_eq!(virtual_address_map.insert(0, 0), None); + assert_eq!(virtual_address_map.insert((0, 0), 0), None); // Bytecode: Pad to nearest power of 2 let code_size = bytecode.len().next_power_of_two(); @@ -203,7 +219,7 @@ impl BytecodePreprocessing { let mut rs2 = vec![]; let mut imm = vec![]; - for instruction in bytecode { + for instruction in bytecode.clone() { address.push(F::from_u64(instruction.address as u64).unwrap()); bitflags.push(F::from_u64(instruction.bitflags).unwrap()); rd.push(F::from_u64(instruction.rd).unwrap()); @@ -253,7 +269,10 @@ impl> BytecodePolynomials { let virtual_address = preprocessing .virtual_address_map - .get(&step.bytecode_row.address) + .get(&( + step.bytecode_row.address, + step.bytecode_row.virtual_sequence_index.unwrap_or(0), + )) .unwrap(); a_read_write_usize[step_index] = *virtual_address; let counter = final_cts[*virtual_address]; @@ -287,8 +306,82 @@ impl> BytecodePolynomials { DensePolynomial::new(rs2), DensePolynomial::new(imm), ]; - let t_read = DensePolynomial::from_usize(&read_cts); - let t_final = DensePolynomial::from_usize(&final_cts); + let t_read: DensePolynomial = DensePolynomial::from_usize(&read_cts); + let t_final: DensePolynomial = DensePolynomial::from_usize(&final_cts); + + #[cfg(test)] + let mut init_tuples: HashSet<(u64, [u64; 6], u64)> = HashSet::new(); + #[cfg(test)] + let mut final_tuples: HashSet<(u64, [u64; 6], u64)> = HashSet::new(); + + #[cfg(test)] + for (a, t) in t_final.Z.iter().enumerate() { + init_tuples.insert(( + a as u64, + [ + preprocessing.v_init_final[0][a].to_u64().unwrap(), + preprocessing.v_init_final[1][a].to_u64().unwrap(), + preprocessing.v_init_final[2][a].to_u64().unwrap(), + preprocessing.v_init_final[3][a].to_u64().unwrap(), + preprocessing.v_init_final[4][a].to_u64().unwrap(), + preprocessing.v_init_final[5][a].to_u64().unwrap(), + ], + 0, + )); + final_tuples.insert(( + a as u64, + [ + preprocessing.v_init_final[0][a].to_u64().unwrap(), + preprocessing.v_init_final[1][a].to_u64().unwrap(), + preprocessing.v_init_final[2][a].to_u64().unwrap(), + preprocessing.v_init_final[3][a].to_u64().unwrap(), + preprocessing.v_init_final[4][a].to_u64().unwrap(), + preprocessing.v_init_final[5][a].to_u64().unwrap(), + ], + t.to_u64().unwrap(), + )); + } + + #[cfg(test)] + let mut read_tuples: HashSet<(u64, [u64; 6], u64)> = HashSet::new(); + #[cfg(test)] + let mut write_tuples: HashSet<(u64, [u64; 6], u64)> = HashSet::new(); + + #[cfg(test)] + for (i, a) in a_read_write_usize.iter().enumerate() { + read_tuples.insert(( + *a as u64, + [ + v_read_write[0][i].to_u64().unwrap(), + v_read_write[1][i].to_u64().unwrap(), + v_read_write[2][i].to_u64().unwrap(), + v_read_write[3][i].to_u64().unwrap(), + v_read_write[4][i].to_u64().unwrap(), + v_read_write[5][i].to_u64().unwrap(), + ], + t_read[i].to_u64().unwrap(), + )); + write_tuples.insert(( + *a as u64, + [ + v_read_write[0][i].to_u64().unwrap(), + v_read_write[1][i].to_u64().unwrap(), + v_read_write[2][i].to_u64().unwrap(), + v_read_write[3][i].to_u64().unwrap(), + v_read_write[4][i].to_u64().unwrap(), + v_read_write[5][i].to_u64().unwrap(), + ], + t_read[i].to_u64().unwrap() + 1, + )); + } + + #[cfg(test)] + { + let init_write: HashSet<_> = init_tuples.union(&write_tuples).collect(); + let read_final: HashSet<_> = read_tuples.union(&final_tuples).collect(); + let set_difference: Vec<_> = init_write.symmetric_difference(&read_final).collect(); + assert_eq!(set_difference.len(), 0); + } Self { _group: PhantomData, diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index ad0110c89..046c00ccb 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -13,7 +13,12 @@ use strum::EnumCount; use crate::jolt::vm::timestamp_range_check::RangeCheckPolynomials; use crate::jolt::{ - instruction::JoltInstruction, subtable::JoltSubtableSet, + instruction::{ + div::DIVInstruction, divu::DIVUInstruction, mulh::MULHInstruction, + mulhsu::MULHSUInstruction, rem::REMInstruction, remu::REMUInstruction, JoltInstruction, + VirtualInstructionSequence, + }, + subtable::JoltSubtableSet, vm::timestamp_range_check::TimestampValidityProof, }; use crate::lasso::memory_checking::{MemoryCheckingProver, MemoryCheckingVerifier}; @@ -281,8 +286,17 @@ pub trait Jolt, const C: usize, c let read_write_memory_preprocessing = ReadWriteMemoryPreprocessing::preprocess(memory_init); let bytecode_rows: Vec = bytecode - .iter() - .map(BytecodeRow::from_instruction::) + .into_iter() + .flat_map(|instruction| match instruction.opcode { + tracer::RV32IM::MULH => MULHInstruction::<32>::virtual_sequence(instruction), + tracer::RV32IM::MULHSU => MULHSUInstruction::<32>::virtual_sequence(instruction), + tracer::RV32IM::DIV => DIVInstruction::<32>::virtual_sequence(instruction), + tracer::RV32IM::DIVU => DIVUInstruction::<32>::virtual_sequence(instruction), + tracer::RV32IM::REM => REMInstruction::<32>::virtual_sequence(instruction), + tracer::RV32IM::REMU => REMUInstruction::<32>::virtual_sequence(instruction), + _ => vec![instruction], + }) + .map(|instruction| BytecodeRow::from_instruction::(&instruction)) .collect(); let bytecode_preprocessing = BytecodePreprocessing::::preprocess(bytecode_rows); diff --git a/jolt-core/src/jolt/vm/rv32i_vm.rs b/jolt-core/src/jolt/vm/rv32i_vm.rs index 6acd2930e..6e34ea8a3 100644 --- a/jolt-core/src/jolt/vm/rv32i_vm.rs +++ b/jolt-core/src/jolt/vm/rv32i_vm.rs @@ -320,6 +320,33 @@ mod tests { // fib_e2e::>(); // } + #[test] + fn muldiv_e2e_hyrax() { + let mut program = host::Program::new("muldiv-guest"); + program.set_input(&123u32); + program.set_input(&234u32); + program.set_input(&345u32); + let (bytecode, memory_init) = program.decode(); + let (io_device, trace, circuit_flags) = program.trace(); + + let preprocessing = + RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20); + let (jolt_proof, jolt_commitments) = + , C, M>>::prove( + io_device, + trace, + circuit_flags, + preprocessing.clone(), + ); + + let verification_result = RV32IJoltVM::verify(preprocessing, jolt_proof, jolt_commitments); + assert!( + verification_result.is_ok(), + "Verification failed with error: {:?}", + verification_result.err() + ); + } + #[test] fn sha3_e2e_hyrax() { let _guard = SHA3_FILE_LOCK.lock().unwrap(); From 3c0f4b601f1268cbefaae79696a27b460a3bd912 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Wed, 24 Jul 2024 14:30:10 -0400 Subject: [PATCH 09/17] Fix some constraints --- common/src/rv_trace.rs | 9 +++++++-- jolt-core/src/r1cs/jolt_constraints.rs | 14 +++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/common/src/rv_trace.rs b/common/src/rv_trace.rs index 0120d2cbd..05ec18573 100644 --- a/common/src/rv_trace.rs +++ b/common/src/rv_trace.rs @@ -287,7 +287,7 @@ impl ELFInstruction { RV32IM::BEQ | RV32IM::BNE | RV32IM::BLT | RV32IM::BGE | RV32IM::BLTU | RV32IM::BGEU, ); - // loads, stores, branches, jumps do not store the lookup output to rd (they may update rd in other ways) + // loads, stores, branches, jumps, and asserts do not store the lookup output to rd (they may update rd in other ways) flags[6] = !matches!( self.opcode, RV32IM::SB @@ -301,7 +301,12 @@ impl ELFInstruction { | RV32IM::BGEU | RV32IM::JAL | RV32IM::JALR - | RV32IM::LUI, + | RV32IM::LUI + | RV32IM::VIRTUAL_ASSERT_EQ + | RV32IM::VIRTUAL_ASSERT_LTE + | RV32IM::VIRTUAL_ASSERT_VALID_DIV0 + | RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER + | RV32IM::VIRTUAL_ASSERT_VALID_UNSIGNED_REMAINDER, ); let mask = 1u32 << 31; diff --git a/jolt-core/src/r1cs/jolt_constraints.rs b/jolt-core/src/r1cs/jolt_constraints.rs index b8837d119..4523974e5 100644 --- a/jolt-core/src/r1cs/jolt_constraints.rs +++ b/jolt-core/src/r1cs/jolt_constraints.rs @@ -17,9 +17,9 @@ pub fn construct_jolt_constraints( constraints.build_constraints(&mut uniform_builder); let non_uniform_constraint = OffsetEqConstraint::new( - (JoltIn::PcIn, true), + (JoltIn::Bytecode_ELFAddress, true), (Variable::Auxiliary(PC_BRANCH_AUX_INDEX), false), - (4 * JoltIn::PcIn + PC_START_ADDRESS, true), + (4 * JoltIn::Bytecode_ELFAddress + PC_START_ADDRESS, true), ); CombinedUniformBuilder::construct( @@ -100,7 +100,7 @@ pub enum JoltIn { OpFlags_SignImm, OpFlags_IsConcat, OpFlags_IsVirtualSequence, - OpFlags_IsVirtual, + OpFlags_IsAssert, // Instruction Flags // Should match JoltInstructionSet @@ -164,7 +164,7 @@ impl R1CSConstraintBuilder for UniformJoltConstraints { cs.constrain_pack_be(flags.to_vec(), JoltIn::Bytecode_Bitflags, 1); - let real_pc = 4i64 * JoltIn::PcIn + (PC_START_ADDRESS - PC_NOOP_SHIFT); + let real_pc = 4i64 * JoltIn::Bytecode_ELFAddress + (PC_START_ADDRESS - PC_NOOP_SHIFT); let x = cs.allocate_if_else(JoltIn::OpFlags_IsRs1Rs2, real_pc, JoltIn::RS1_Read); let y = cs.allocate_if_else( JoltIn::OpFlags_IsImm, @@ -262,7 +262,7 @@ impl R1CSConstraintBuilder for UniformJoltConstraints { JoltIn::LookupOutput, ); let rd_nonzero_and_jmp = cs.allocate_prod(JoltIn::Bytecode_RD, JoltIn::OpFlags_IsJmp); - let lhs = JoltIn::PcIn + (PC_START_ADDRESS - PC_NOOP_SHIFT); + let lhs = JoltIn::Bytecode_ELFAddress + (PC_START_ADDRESS - PC_NOOP_SHIFT); let rhs = JoltIn::RD_Write; cs.constrain_eq_conditional(rd_nonzero_and_jmp, lhs, rhs); @@ -271,12 +271,12 @@ impl R1CSConstraintBuilder for UniformJoltConstraints { let next_pc_jump = cs.allocate_if_else( JoltIn::OpFlags_IsJmp, JoltIn::LookupOutput + 4, - 4 * JoltIn::PcIn + PC_START_ADDRESS + 4, + 4 * JoltIn::Bytecode_ELFAddress + PC_START_ADDRESS + 4, ); let next_pc_jump_branch = cs.allocate_if_else( branch_and_lookup_output, - 4 * JoltIn::PcIn + PC_START_ADDRESS + imm_signed, + 4 * JoltIn::Bytecode_ELFAddress + PC_START_ADDRESS + imm_signed, next_pc_jump, ); assert_static_aux_index!(next_pc_jump_branch, PC_BRANCH_AUX_INDEX); From 1814e8575d2e3d663b73802b30e3b423cf808508 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Wed, 24 Jul 2024 15:02:07 -0400 Subject: [PATCH 10/17] Fix virtual assert lte --- .../jolt/instruction/virtual_assert_lte.rs | 53 ++++++------------- 1 file changed, 16 insertions(+), 37 deletions(-) diff --git a/jolt-core/src/jolt/instruction/virtual_assert_lte.rs b/jolt-core/src/jolt/instruction/virtual_assert_lte.rs index 9887a60e7..f8114c71c 100644 --- a/jolt-core/src/jolt/instruction/virtual_assert_lte.rs +++ b/jolt-core/src/jolt/instruction/virtual_assert_lte.rs @@ -5,10 +5,7 @@ use serde::{Deserialize, Serialize}; use super::{JoltInstruction, SubtableIndices}; use crate::{ field::JoltField, - jolt::subtable::{ - eq::EqSubtable, eq_abs::EqAbsSubtable, left_msb::LeftMSBSubtable, lt_abs::LtAbsSubtable, - ltu::LtuSubtable, right_msb::RightMSBSubtable, LassoSubtable, - }, + jolt::subtable::{eq::EqSubtable, ltu::LtuSubtable, LassoSubtable}, utils::instruction_utils::chunk_and_concatenate_operands, }; @@ -21,40 +18,26 @@ impl JoltInstruction for ASSERTLTEInstruction { } fn combine_lookups(&self, vals: &[F], C: usize, M: usize) -> F { - // LTS(x,y) let vals_by_subtable = self.slice_values(vals, C, M); + let ltu = vals_by_subtable[0]; + let eq = vals_by_subtable[1]; - let left_msb = vals_by_subtable[0]; - let right_msb = vals_by_subtable[1]; - let ltu = vals_by_subtable[2]; - let eq = vals_by_subtable[3]; - let lt_abs = vals_by_subtable[4]; - let eq_abs = vals_by_subtable[5]; - - // Accumulator for LTU(x_{(); + for i in 0..C { + ltu_sum += ltu[i] * eq_prod; + eq_prod *= eq[i]; + } - // LTS(x,y) || EQ(x,y) - lt + eq - lt * eq + // LTU(x,y) || EQ(x,y) + ltu_sum + eq_prod } fn g_poly_degree(&self, C: usize) -> usize { - C + 1 + C } fn subtables( @@ -63,12 +46,8 @@ impl JoltInstruction for ASSERTLTEInstruction { _: usize, ) -> Vec<(Box>, SubtableIndices)> { vec![ - (Box::new(LeftMSBSubtable::new()), SubtableIndices::from(0)), - (Box::new(RightMSBSubtable::new()), SubtableIndices::from(0)), - (Box::new(LtuSubtable::new()), SubtableIndices::from(1..C)), + (Box::new(LtuSubtable::new()), SubtableIndices::from(0..C)), (Box::new(EqSubtable::new()), SubtableIndices::from(0..C)), - (Box::new(LtAbsSubtable::new()), SubtableIndices::from(0)), - (Box::new(EqAbsSubtable::new()), SubtableIndices::from(0)), ] } @@ -77,7 +56,7 @@ impl JoltInstruction for ASSERTLTEInstruction { } fn lookup_entry(&self) -> u64 { - ((self.0 as i32) <= (self.1 as i32)).into() + (self.0 <= self.1).into() } fn random(&self, rng: &mut StdRng) -> Self { From fed35282ab7bbe1386aa1caf9df881d5d4a88675 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Thu, 25 Jul 2024 11:11:50 -0400 Subject: [PATCH 11/17] Working(?) non-uniform constraints --- common/src/rv_trace.rs | 36 ++++++++++----------- jolt-core/src/host/mod.rs | 2 +- jolt-core/src/jolt/instruction/div.rs | 22 +++++++------ jolt-core/src/jolt/instruction/divu.rs | 23 ++++++++------ jolt-core/src/jolt/instruction/mod.rs | 1 + jolt-core/src/jolt/instruction/mulh.rs | 22 +++++++------ jolt-core/src/jolt/instruction/mulhsu.rs | 16 ++++++---- jolt-core/src/jolt/instruction/rem.rs | 19 ++++++----- jolt-core/src/jolt/instruction/remu.rs | 21 +++++++------ jolt-core/src/jolt/vm/bytecode.rs | 19 ++++++----- jolt-core/src/r1cs/builder.rs | 20 ++++++++---- jolt-core/src/r1cs/jolt_constraints.rs | 40 +++++++++++++++++------- tracer/src/emulator/cpu.rs | 12 +++---- tracer/src/lib.rs | 2 +- 14 files changed, 153 insertions(+), 102 deletions(-) diff --git a/common/src/rv_trace.rs b/common/src/rv_trace.rs index 05ec18573..6d4c49321 100644 --- a/common/src/rv_trace.rs +++ b/common/src/rv_trace.rs @@ -222,11 +222,14 @@ pub struct ELFInstruction { pub rd: Option, pub imm: Option, /// If this instruction is part of a "virtual sequence" (see Section 6.2 of the - /// Jolt paper), then this contains the instruction's index within the sequence. - pub virtual_sequence_index: Option, + /// Jolt paper), then this contains the number of virtual instructions after this + /// one in the sequence. I.e. if this is the last instruction in the sequence, + /// `virtual_sequence_remaining` will be Some(0); if this is the penultimate instruction + /// in the sequence, `virtual_sequence_remaining` will be Some(1); etc. + pub virtual_sequence_remaining: Option, } -pub const NUM_CIRCUIT_FLAGS: usize = 11; +pub const NUM_CIRCUIT_FLAGS: usize = 12; impl ELFInstruction { #[rustfmt::skip] @@ -241,8 +244,9 @@ impl ELFInstruction { // 6: Instruction writes lookup output to rd // 7: Sign-bit of imm // 8: Is concat - // 9: Increment virtual PC + // 9: Virtual instruction // 10: Assert instruction + // 11: Don't update PC let mut flags = [false; NUM_CIRCUIT_FLAGS]; @@ -338,20 +342,8 @@ impl ELFInstruction { | RV32IM::BGEU, ); - // TODO(moodlezoup): Use these flags in R1CS constraints - flags[9] = match self.virtual_sequence_index { - // For virtual sequences, we set - // virtual PC := ProgARW (the bytecode `a` value) - // if it's the first instruction in the sequence. - // Otherwise, we increment the virtual PC: - // virtual PC += 1 - // This prevents a malicious prover from reordering or omitting - // instructions from the virtual sequence. - Some(i) => i > 0, - // For "real" instructions, we always set - // virtual PC := ProgARW (the bytecode `a` value) - None => false - }; + flags[9] = self.virtual_sequence_remaining.is_some(); + flags[10] = matches!(self.opcode, RV32IM::VIRTUAL_ASSERT_EQ | RV32IM::VIRTUAL_ASSERT_LTE | @@ -360,6 +352,14 @@ impl ELFInstruction { RV32IM::VIRTUAL_ASSERT_VALID_DIV0, ); + // All instructions in virtual sequence are mapped from the same + // ELF address. Thus if an instruction is virtual (and not the last one + // in its sequence), then we should *not* update the PC. + flags[11] = match self.virtual_sequence_remaining { + Some(i) => i != 0, + None => false + }; + flags } } diff --git a/jolt-core/src/host/mod.rs b/jolt-core/src/host/mod.rs index 17095f120..1a1ee029f 100644 --- a/jolt-core/src/host/mod.rs +++ b/jolt-core/src/host/mod.rs @@ -201,6 +201,7 @@ impl Program { } }) .collect(); + let padded_trace_len = trace.len().next_power_of_two(); let mut circuit_flag_trace = unsafe_allocate_zero_vec(padded_trace_len * NUM_CIRCUIT_FLAGS); @@ -216,7 +217,6 @@ impl Program { } }); }); - (io_device, trace, circuit_flag_trace) } diff --git a/jolt-core/src/jolt/instruction/div.rs b/jolt-core/src/jolt/instruction/div.rs index 1aeef59d4..904e4c032 100644 --- a/jolt-core/src/jolt/instruction/div.rs +++ b/jolt-core/src/jolt/instruction/div.rs @@ -11,6 +11,8 @@ use crate::jolt::instruction::{ pub struct DIVInstruction; impl VirtualInstructionSequence for DIVInstruction { + const SEQUENCE_LENGTH: usize = 7; + fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::DIV); // DIV source registers @@ -21,7 +23,7 @@ impl VirtualInstructionSequence for DIVInstruction = Some(virtual_register_index(1)); let v_qy = Some(virtual_register_index(2)); - let mut virtual_sequence = vec![]; + let mut virtual_sequence = Vec::with_capacity(Self::SEQUENCE_LENGTH); virtual_sequence.push(ELFInstruction { address: instruction.address, @@ -30,7 +32,7 @@ impl VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction; impl VirtualInstructionSequence for DIVUInstruction { + const SEQUENCE_LENGTH: usize = 8; + fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::DIVU); // DIVU source registers @@ -23,7 +25,7 @@ impl VirtualInstructionSequence for DIVUInstruction = Some(virtual_register_index(1)); let v_qy = Some(virtual_register_index(2)); - let mut virtual_sequence = vec![]; + let mut virtual_sequence = Vec::with_capacity(Self::SEQUENCE_LENGTH); virtual_sequence.push(ELFInstruction { address: instruction.address, opcode: RV32IM::VIRTUAL_ADVICE, @@ -31,7 +33,7 @@ impl VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction> for SubtableIndices { } pub trait VirtualInstructionSequence { + const SEQUENCE_LENGTH: usize; fn virtual_sequence(instruction: ELFInstruction) -> Vec; fn virtual_trace(trace_row: RVTraceRow) -> Vec; } diff --git a/jolt-core/src/jolt/instruction/mulh.rs b/jolt-core/src/jolt/instruction/mulh.rs index 6e7ea6396..457dd69d3 100644 --- a/jolt-core/src/jolt/instruction/mulh.rs +++ b/jolt-core/src/jolt/instruction/mulh.rs @@ -10,6 +10,8 @@ use crate::jolt::instruction::{ pub struct MULHInstruction; impl VirtualInstructionSequence for MULHInstruction { + const SEQUENCE_LENGTH: usize = 7; + fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::MULH); // MULH source registers @@ -23,7 +25,7 @@ impl VirtualInstructionSequence for MULHInstruction VirtualInstructionSequence for MULHInstruction VirtualInstructionSequence for MULHInstruction VirtualInstructionSequence for MULHInstruction VirtualInstructionSequence for MULHInstruction VirtualInstructionSequence for MULHInstruction VirtualInstructionSequence for MULHInstruction VirtualInstructionSequence for MULHInstruction; impl VirtualInstructionSequence for MULHSUInstruction { + const SEQUENCE_LENGTH: usize = 4; + fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::MULHSU); // MULHSU source registers @@ -21,7 +23,7 @@ impl VirtualInstructionSequence for MULHSUInstruction VirtualInstructionSequence for MULHSUInstruction VirtualInstructionSequence for MULHSUInstruction VirtualInstructionSequence for MULHSUInstruction VirtualInstructionSequence for MULHSUInstruction; impl VirtualInstructionSequence for REMInstruction { + const SEQUENCE_LENGTH: usize = 6; + fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::REM); // REM source registers @@ -22,7 +24,7 @@ impl VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction; impl VirtualInstructionSequence for REMUInstruction { + const SEQUENCE_LENGTH: usize = 7; + fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::REMU); // REMU source registers @@ -23,7 +25,7 @@ impl VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction, + /// Jolt paper), then this contains the number of virtual instructions after this + /// one in the sequence. I.e. if this is the last instruction in the sequence, + /// `virtual_sequence_remaining` will be Some(0); if this is the penultimate instruction + /// in the sequence, `virtual_sequence_remaining` will be Some(1); etc. + virtual_sequence_remaining: Option, } impl BytecodeRow { @@ -66,7 +69,7 @@ impl BytecodeRow { rs1, rs2, imm, - virtual_sequence_index: None, + virtual_sequence_remaining: None, } } @@ -78,7 +81,7 @@ impl BytecodeRow { rs1: 0, rs2: 0, imm: 0, - virtual_sequence_index: None, + virtual_sequence_remaining: None, } } @@ -90,7 +93,7 @@ impl BytecodeRow { rs1: rng.next_u64() % REGISTER_COUNT, rs2: rng.next_u64() % REGISTER_COUNT, imm: rng.next_u64() % (1 << 20), // U-format instructions have 20-bit imm values - virtual_sequence_index: None, + virtual_sequence_remaining: None, } } @@ -133,7 +136,7 @@ impl BytecodeRow { rs1: instruction.rs1.unwrap_or(0), rs2: instruction.rs2.unwrap_or(0), imm: instruction.imm.unwrap_or(0) as u64, // imm is always cast to its 32-bit repr, signed or unsigned - virtual_sequence_index: instruction.virtual_sequence_index, + virtual_sequence_remaining: instruction.virtual_sequence_remaining, } } } @@ -195,7 +198,7 @@ impl BytecodePreprocessing { virtual_address_map.insert( ( instruction.address, - instruction.virtual_sequence_index.unwrap_or(0) + instruction.virtual_sequence_remaining.unwrap_or(0) ), virtual_address ), @@ -271,7 +274,7 @@ impl> BytecodePolynomials { .virtual_address_map .get(&( step.bytecode_row.address, - step.bytecode_row.virtual_sequence_index.unwrap_or(0), + step.bytecode_row.virtual_sequence_remaining.unwrap_or(0), )) .unwrap(); a_read_write_usize[step_index] = *virtual_address; diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 7d16e1fcd..68fc25b5b 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -907,12 +907,20 @@ impl CombinedUniformBuilder { let uniform_constraint_index = constraint_index / self.uniform_repeat; if az[constraint_index] * bz[constraint_index] != cz[constraint_index] { let step_index = constraint_index % self.uniform_repeat; - panic!( - "Mismatch at global constraint {constraint_index} => {:?}\n\ - uniform constraint: {uniform_constraint_index}\n\ - step: {step_index}", - self.uniform_builder.constraints[uniform_constraint_index] - ); + if uniform_constraint_index >= self.uniform_builder.constraints.len() { + panic!( + "Mismatch at non-uniform constraint: {}\n\ + step: {step_index}", + uniform_constraint_index - self.uniform_builder.constraints.len() + ) + } else { + panic!( + "Mismatch at global constraint {constraint_index} => {:?}\n\ + uniform constraint: {uniform_constraint_index}\n\ + step: {step_index}", + self.uniform_builder.constraints[uniform_constraint_index] + ); + } } } } diff --git a/jolt-core/src/r1cs/jolt_constraints.rs b/jolt-core/src/r1cs/jolt_constraints.rs index 4523974e5..b9bbb8b29 100644 --- a/jolt-core/src/r1cs/jolt_constraints.rs +++ b/jolt-core/src/r1cs/jolt_constraints.rs @@ -16,16 +16,32 @@ pub fn construct_jolt_constraints( let constraints = UniformJoltConstraints::new(memory_start); constraints.build_constraints(&mut uniform_builder); - let non_uniform_constraint = OffsetEqConstraint::new( + // If the next instruction's ELF address is not zero (i.e. it's + // not padding), then check the PC update. + let pc_constraint = OffsetEqConstraint::new( (JoltIn::Bytecode_ELFAddress, true), - (Variable::Auxiliary(PC_BRANCH_AUX_INDEX), false), + (Variable::Auxiliary(NEXT_PC), false), (4 * JoltIn::Bytecode_ELFAddress + PC_START_ADDRESS, true), ); + // If the current instruction is virtual, check that the next instruction + // in the trace is the next instruction in bytecode. Virtual sequences + // do not involve jumps or branches, so this should always hold, + // EXCEPT if we encounter a virtual instruction followed by a padding + // instruction. But that should never happen because the execution + // trace should always end with some return handling, which shouldn't involve + // any virtual sequences. + + let virtual_sequence_constraint = OffsetEqConstraint::new( + (JoltIn::OpFlags_IsVirtualInstruction, false), + (JoltIn::Bytecode_A, true), + (JoltIn::Bytecode_A + 1, false), + ); + CombinedUniformBuilder::construct( uniform_builder, padded_trace_length, - vec![non_uniform_constraint], + vec![pc_constraint, virtual_sequence_constraint], ) } @@ -99,8 +115,9 @@ pub enum JoltIn { OpFlags_LookupOutToRd, OpFlags_SignImm, OpFlags_IsConcat, - OpFlags_IsVirtualSequence, + OpFlags_IsVirtualInstruction, OpFlags_IsAssert, + OpFlags_DoNotUpdatePC, // Instruction Flags // Should match JoltInstructionSet @@ -140,7 +157,7 @@ pub const PC_START_ADDRESS: i64 = 0x80000000; const PC_NOOP_SHIFT: i64 = 4; const LOG_M: usize = 16; const OPERAND_SIZE: usize = LOG_M / 2; -pub const PC_BRANCH_AUX_INDEX: usize = 15; +pub const NEXT_PC: usize = 15; pub struct UniformJoltConstraints { memory_start: u64, @@ -266,20 +283,21 @@ impl R1CSConstraintBuilder for UniformJoltConstraints { let rhs = JoltIn::RD_Write; cs.constrain_eq_conditional(rd_nonzero_and_jmp, lhs, rhs); - let branch_and_lookup_output = - cs.allocate_prod(JoltIn::OpFlags_IsBranch, JoltIn::LookupOutput); let next_pc_jump = cs.allocate_if_else( JoltIn::OpFlags_IsJmp, JoltIn::LookupOutput + 4, - 4 * JoltIn::Bytecode_ELFAddress + PC_START_ADDRESS + 4, + 4 * JoltIn::Bytecode_ELFAddress + PC_START_ADDRESS + 4 + - 4 * JoltIn::OpFlags_DoNotUpdatePC, ); - let next_pc_jump_branch = cs.allocate_if_else( - branch_and_lookup_output, + let should_branch = cs.allocate_prod(JoltIn::OpFlags_IsBranch, JoltIn::LookupOutput); + let next_pc = cs.allocate_if_else( + should_branch, 4 * JoltIn::Bytecode_ELFAddress + PC_START_ADDRESS + imm_signed, next_pc_jump, ); - assert_static_aux_index!(next_pc_jump_branch, PC_BRANCH_AUX_INDEX); + + assert_static_aux_index!(next_pc, NEXT_PC); } } diff --git a/tracer/src/emulator/cpu.rs b/tracer/src/emulator/cpu.rs index 10e646f5c..4398e29cd 100644 --- a/tracer/src/emulator/cpu.rs +++ b/tracer/src/emulator/cpu.rs @@ -1847,7 +1847,7 @@ fn trace_r(inst: &Instruction, xlen: &Xlen, word: u32, address: u64) -> ELFInstr rs1: Some(normalize_register(f.rs1)), rs2: Some(normalize_register(f.rs2)), rd: Some(normalize_register(f.rd)), - virtual_sequence_index: None, + virtual_sequence_remaining: None, } } @@ -1860,7 +1860,7 @@ fn trace_i(inst: &Instruction, xlen: &Xlen, word: u32, address: u64) -> ELFInstr rs1: Some(normalize_register(f.rs1)), rs2: None, rd: Some(normalize_register(f.rd)), - virtual_sequence_index: None, + virtual_sequence_remaining: None, } } @@ -1873,7 +1873,7 @@ fn trace_s(inst: &Instruction, xlen: &Xlen, word: u32, address: u64) -> ELFInstr rs1: Some(normalize_register(f.rs1)), rs2: Some(normalize_register(f.rs2)), rd: None, - virtual_sequence_index: None, + virtual_sequence_remaining: None, } } @@ -1886,7 +1886,7 @@ fn trace_b(inst: &Instruction, xlen: &Xlen, word: u32, address: u64) -> ELFInstr rs1: Some(normalize_register(f.rs1)), rs2: Some(normalize_register(f.rs2)), rd: None, - virtual_sequence_index: None, + virtual_sequence_remaining: None, } } @@ -1899,7 +1899,7 @@ fn trace_u(inst: &Instruction, xlen: &Xlen, word: u32, address: u64) -> ELFInstr rs1: None, rs2: None, rd: Some(normalize_register(f.rd)), - virtual_sequence_index: None, + virtual_sequence_remaining: None, } } @@ -1913,7 +1913,7 @@ fn trace_j(inst: &Instruction, xlen: &Xlen, word: u32, address: u64) -> ELFInstr rs1: None, rs2: None, rd: Some(normalize_register(f.rd)), - virtual_sequence_index: None, + virtual_sequence_remaining: None, } } diff --git a/tracer/src/lib.rs b/tracer/src/lib.rs index 8d5ac1bb1..d70a768cf 100644 --- a/tracer/src/lib.rs +++ b/tracer/src/lib.rs @@ -104,7 +104,7 @@ pub fn decode(elf: &[u8]) -> (Vec, Vec<(u64, u8)>) { rs2: None, rd: None, imm: None, - virtual_sequence_index: None, + virtual_sequence_remaining: None, }); } } From 0b3e369c4e31225cbd12ea3ee089f0c8466ec23a Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Thu, 25 Jul 2024 18:36:10 -0400 Subject: [PATCH 12/17] Working uniform constraints (and many bug fixes along the way) --- common/src/rv_trace.rs | 30 +++--- jolt-core/src/jolt/instruction/div.rs | 33 +++++-- jolt-core/src/jolt/instruction/divu.rs | 33 +++++-- jolt-core/src/jolt/instruction/lb.rs | 8 +- jolt-core/src/jolt/instruction/lh.rs | 7 +- jolt-core/src/jolt/instruction/mod.rs | 1 + jolt-core/src/jolt/instruction/rem.rs | 31 +++++- jolt-core/src/jolt/instruction/remu.rs | 31 +++++- jolt-core/src/jolt/instruction/sb.rs | 22 +++-- jolt-core/src/jolt/instruction/sh.rs | 18 ++-- .../src/jolt/instruction/virtual_move.rs | 94 +++++++++++++++++++ .../src/jolt/instruction/virtual_movsign.rs | 55 ++++++----- jolt-core/src/jolt/trace/rv.rs | 3 + jolt-core/src/jolt/vm/rv32i_vm.rs | 2 + jolt-core/src/r1cs/jolt_constraints.rs | 19 ++-- 15 files changed, 310 insertions(+), 77 deletions(-) create mode 100644 jolt-core/src/jolt/instruction/virtual_move.rs diff --git a/common/src/rv_trace.rs b/common/src/rv_trace.rs index 6d4c49321..6f209bacc 100644 --- a/common/src/rv_trace.rs +++ b/common/src/rv_trace.rs @@ -114,6 +114,7 @@ impl From<&RVTraceRow> for [MemoryOp; MEMORY_OPS_PER_INSTRUCTION] { | RV32IM::SLTI | RV32IM::SLTIU | RV32IM::JALR + | RV32IM::VIRTUAL_MOVE | RV32IM::VIRTUAL_MOVSIGN => [ rs1_read(), MemoryOp::noop_read(), @@ -339,7 +340,12 @@ impl ELFInstruction { | RV32IM::BLT | RV32IM::BGE | RV32IM::BLTU - | RV32IM::BGEU, + | RV32IM::BGEU + | RV32IM::VIRTUAL_ASSERT_EQ + | RV32IM::VIRTUAL_ASSERT_LTE + | RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER + | RV32IM::VIRTUAL_ASSERT_VALID_UNSIGNED_REMAINDER + | RV32IM::VIRTUAL_ASSERT_VALID_DIV0, ); flags[9] = self.virtual_sequence_remaining.is_some(); @@ -449,6 +455,7 @@ pub enum RV32IM { UNIMPL, // Virtual instructions VIRTUAL_MOVSIGN, + VIRTUAL_MOVE, VIRTUAL_ADVICE, VIRTUAL_ASSERT_LTE, VIRTUAL_ASSERT_VALID_UNSIGNED_REMAINDER, @@ -551,16 +558,17 @@ impl RV32IM { RV32IM::REM | RV32IM::REMU => RV32InstructionFormat::R, - RV32IM::ADDI | - RV32IM::XORI | - RV32IM::ORI | - RV32IM::ANDI | - RV32IM::SLLI | - RV32IM::SRLI | - RV32IM::SRAI | - RV32IM::SLTI | - RV32IM::FENCE | - RV32IM::SLTIU | + RV32IM::ADDI | + RV32IM::XORI | + RV32IM::ORI | + RV32IM::ANDI | + RV32IM::SLLI | + RV32IM::SRLI | + RV32IM::SRAI | + RV32IM::SLTI | + RV32IM::FENCE | + RV32IM::SLTIU | + RV32IM::VIRTUAL_MOVE | RV32IM::VIRTUAL_MOVSIGN=> RV32InstructionFormat::I, RV32IM::LB | diff --git a/jolt-core/src/jolt/instruction/div.rs b/jolt-core/src/jolt/instruction/div.rs index 904e4c032..f773c854b 100644 --- a/jolt-core/src/jolt/instruction/div.rs +++ b/jolt-core/src/jolt/instruction/div.rs @@ -11,7 +11,7 @@ use crate::jolt::instruction::{ pub struct DIVInstruction; impl VirtualInstructionSequence for DIVInstruction { - const SEQUENCE_LENGTH: usize = 7; + const SEQUENCE_LENGTH: usize = 8; fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::DIV); @@ -20,8 +20,9 @@ impl VirtualInstructionSequence for DIVInstruction = Some(virtual_register_index(1)); - let v_qy = Some(virtual_register_index(2)); + let v_q: Option = Some(virtual_register_index(1)); + let v_r: Option = Some(virtual_register_index(2)); + let v_qy = Some(virtual_register_index(3)); let mut virtual_sequence = Vec::with_capacity(Self::SEQUENCE_LENGTH); @@ -30,7 +31,7 @@ impl VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction; impl VirtualInstructionSequence for DIVUInstruction { - const SEQUENCE_LENGTH: usize = 8; + const SEQUENCE_LENGTH: usize = 9; fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::DIVU); @@ -22,8 +22,9 @@ impl VirtualInstructionSequence for DIVUInstruction = Some(virtual_register_index(1)); - let v_qy = Some(virtual_register_index(2)); + let v_q = Some(virtual_register_index(1)); + let v_r: Option = Some(virtual_register_index(2)); + let v_qy = Some(virtual_register_index(3)); let mut virtual_sequence = Vec::with_capacity(Self::SEQUENCE_LENGTH); virtual_sequence.push(ELFInstruction { @@ -31,7 +32,7 @@ impl VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction(&self, vals: &[F], C: usize, M: usize) -> F { assert!(M >= 1 << 8); - assert!(vals.len() == 2); let byte = vals[0]; let sign_extension = vals[1]; @@ -55,6 +55,12 @@ impl JoltInstruction for LBInstruction { Box::new(SignExtendSubtable::::new()), SubtableIndices::from(C - 1), ), + ( + // Not used for lookup, but this implicitly range-checks + // the remaining query chunks + Box::new(IdentitySubtable::::new()), + SubtableIndices::from(0..C - 1), + ), ] } diff --git a/jolt-core/src/jolt/instruction/lh.rs b/jolt-core/src/jolt/instruction/lh.rs index 87ebf64f9..37b9e0969 100644 --- a/jolt-core/src/jolt/instruction/lh.rs +++ b/jolt-core/src/jolt/instruction/lh.rs @@ -20,7 +20,6 @@ impl JoltInstruction for LHInstruction { fn combine_lookups(&self, vals: &[F], _C: usize, M: usize) -> F { // TODO(moodlezoup): make this work with different M assert!(M == 1 << 16); - assert!(vals.len() == 2); let half = vals[0]; let sign_extension = vals[1]; @@ -51,6 +50,12 @@ impl JoltInstruction for LHInstruction { Box::new(SignExtendSubtable::::new()), SubtableIndices::from(C - 1), ), + ( + // Not used for lookup, but this implicitly range-checks + // the remaining query chunks + Box::new(IdentitySubtable::::new()), + SubtableIndices::from(0..C - 1), + ), ] } diff --git a/jolt-core/src/jolt/instruction/mod.rs b/jolt-core/src/jolt/instruction/mod.rs index 16bb02141..a2b539d6d 100644 --- a/jolt-core/src/jolt/instruction/mod.rs +++ b/jolt-core/src/jolt/instruction/mod.rs @@ -163,6 +163,7 @@ pub mod virtual_assert_lte; pub mod virtual_assert_valid_div0; pub mod virtual_assert_valid_signed_remainder; pub mod virtual_assert_valid_unsigned_remainder; +pub mod virtual_move; pub mod virtual_movsign; pub mod xor; diff --git a/jolt-core/src/jolt/instruction/rem.rs b/jolt-core/src/jolt/instruction/rem.rs index cd8145d78..1a423ec7a 100644 --- a/jolt-core/src/jolt/instruction/rem.rs +++ b/jolt-core/src/jolt/instruction/rem.rs @@ -12,7 +12,7 @@ use crate::jolt::instruction::{ pub struct REMInstruction; impl VirtualInstructionSequence for REMInstruction { - const SEQUENCE_LENGTH: usize = 6; + const SEQUENCE_LENGTH: usize = 7; fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::REM); @@ -22,7 +22,8 @@ impl VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction; impl VirtualInstructionSequence for REMUInstruction { - const SEQUENCE_LENGTH: usize = 7; + const SEQUENCE_LENGTH: usize = 8; fn virtual_sequence(instruction: ELFInstruction) -> Vec { assert_eq!(instruction.opcode, RV32IM::REMU); @@ -23,7 +23,8 @@ impl VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction(&self, vals: &[F], _: usize, M: usize) -> F { assert!(M >= 1 << 8); - assert!(vals.len() == 1); vals[0] } @@ -32,12 +32,20 @@ impl JoltInstruction for SBInstruction { ) -> Vec<(Box>, SubtableIndices)> { // This assertion ensures that we only need one TruncateOverflowSubtable assert!(M >= 1 << 8); - vec![( - // Truncate all but the lowest eight bits of the last chunk, - // which contains the lower 8 bits of the rs2 value. - Box::new(TruncateOverflowSubtable::::new()), - SubtableIndices::from(C - 1), - )] + vec![ + ( + // Truncate all but the lowest eight bits of the last chunk, + // which contains the lower 8 bits of the rs2 value. + Box::new(TruncateOverflowSubtable::::new()), + SubtableIndices::from(C - 1), + ), + ( + // Not used for lookup, but this implicitly range-checks + // the remaining query chunks + Box::new(IdentitySubtable::::new()), + SubtableIndices::from(0..C - 1), + ), + ] } fn to_indices(&self, C: usize, log_M: usize) -> Vec { diff --git a/jolt-core/src/jolt/instruction/sh.rs b/jolt-core/src/jolt/instruction/sh.rs index a8b8b4266..c231af277 100644 --- a/jolt-core/src/jolt/instruction/sh.rs +++ b/jolt-core/src/jolt/instruction/sh.rs @@ -18,7 +18,6 @@ impl JoltInstruction for SHInstruction { fn combine_lookups(&self, vals: &[F], _: usize, M: usize) -> F { // TODO(moodlezoup): make this work with different M assert!(M == 1 << 16); - assert!(vals.len() == 1); vals[0] } @@ -31,13 +30,20 @@ impl JoltInstruction for SHInstruction { C: usize, M: usize, ) -> Vec<(Box>, SubtableIndices)> { - // This assertion ensures that we only need two TruncateOverflowSubtables // TODO(moodlezoup): make this work with different M assert!(M == 1 << 16); - vec![( - Box::new(IdentitySubtable::::new()), - SubtableIndices::from(C - 1), - )] + vec![ + ( + Box::new(IdentitySubtable::::new()), + SubtableIndices::from(C - 1), + ), + ( + // Not used for lookup, but this implicitly range-checks + // the remaining query chunks + Box::new(IdentitySubtable::::new()), + SubtableIndices::from(0..C - 1), + ), + ] } fn to_indices(&self, C: usize, log_M: usize) -> Vec { diff --git a/jolt-core/src/jolt/instruction/virtual_move.rs b/jolt-core/src/jolt/instruction/virtual_move.rs new file mode 100644 index 000000000..77be2ac9c --- /dev/null +++ b/jolt-core/src/jolt/instruction/virtual_move.rs @@ -0,0 +1,94 @@ +use ark_std::log2; +use rand::prelude::StdRng; +use rand::RngCore; +use serde::{Deserialize, Serialize}; + +use super::JoltInstruction; +use crate::{ + field::JoltField, + jolt::{ + instruction::SubtableIndices, + subtable::{identity::IdentitySubtable, LassoSubtable}, + }, + utils::instruction_utils::{chunk_operand_usize, concatenate_lookups}, +}; + +#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)] +pub struct MOVEInstruction(pub u64); + +impl JoltInstruction for MOVEInstruction { + fn operands(&self) -> (u64, u64) { + (self.0, 0) + } + + fn combine_lookups(&self, vals: &[F], C: usize, M: usize) -> F { + concatenate_lookups(vals, C, log2(M) as usize) + } + + fn g_poly_degree(&self, _: usize) -> usize { + 1 + } + + fn subtables( + &self, + C: usize, + M: usize, + ) -> Vec<(Box>, SubtableIndices)> { + assert!(M == 1 << 16); + vec![( + // Implicitly range-checks all query chunks + Box::new(IdentitySubtable::::new()), + SubtableIndices::from(0..C), + )] + } + + fn to_indices(&self, C: usize, log_M: usize) -> Vec { + chunk_operand_usize(self.0, C, log_M) + } + + fn lookup_entry(&self) -> u64 { + self.0 + } + + fn random(&self, rng: &mut StdRng) -> Self { + Self(rng.next_u32() as u64) + } +} + +#[cfg(test)] +mod test { + use ark_bn254::Fr; + use ark_std::test_rng; + use rand_chacha::rand_core::RngCore; + + use crate::{jolt::instruction::JoltInstruction, jolt_instruction_test}; + + use super::MOVEInstruction; + + #[test] + fn virtual_move_instruction_32_e2e() { + let mut rng = test_rng(); + const C: usize = 4; + const M: usize = 1 << 16; + + // Random + for _ in 0..256 { + let x = rng.next_u32() as u64; + let instruction = MOVEInstruction::<32>(x); + jolt_instruction_test!(instruction); + } + } + + #[test] + fn virtual_move_instruction_64_e2e() { + let mut rng = test_rng(); + const C: usize = 8; + const M: usize = 1 << 16; + + for _ in 0..256 { + let x = rng.next_u64(); + let instruction = MOVEInstruction::<64>(x); + jolt_instruction_test!(instruction); + } + } +} diff --git a/jolt-core/src/jolt/instruction/virtual_movsign.rs b/jolt-core/src/jolt/instruction/virtual_movsign.rs index 45405c05d..01a854b02 100644 --- a/jolt-core/src/jolt/instruction/virtual_movsign.rs +++ b/jolt-core/src/jolt/instruction/virtual_movsign.rs @@ -1,3 +1,4 @@ +use ark_std::log2; use rand::prelude::StdRng; use rand::RngCore; use serde::{Deserialize, Serialize}; @@ -7,18 +8,17 @@ use crate::{ field::JoltField, jolt::{ instruction::SubtableIndices, - subtable::{sign_extend::SignExtendSubtable, LassoSubtable}, + subtable::{identity::IdentitySubtable, sign_extend::SignExtendSubtable, LassoSubtable}, }, - utils::instruction_utils::chunk_operand_usize, + utils::instruction_utils::{chunk_operand_usize, concatenate_lookups}, }; #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)] pub struct MOVSIGNInstruction(pub u64); // Constants for 32-bit and 64-bit word sizes -const NEGATIVE_32: u64 = 0xFFFF_FFFF; -const NEGATIVE_64: u64 = 0xFFFF_FFFF_FFFF_FFFF; -const POSITIVE: u64 = 0; +const ALL_ONES_32: u64 = 0xFFFF_FFFF; +const ALL_ONES_64: u64 = 0xFFFF_FFFF_FFFF_FFFF; const SIGN_BIT_32: u64 = 0x8000_0000; const SIGN_BIT_64: u64 = 0x8000_0000_0000_0000; @@ -27,13 +27,12 @@ impl JoltInstruction for MOVSIGNInstruction { (self.0, 0) } - fn combine_lookups(&self, vals: &[F], _: usize, _: usize) -> F { + fn combine_lookups(&self, vals: &[F], _: usize, M: usize) -> F { + // TODO(moodlezoup): make this work with different M + assert!(M == 1 << 16); let val = vals[0]; - let mut result = F::zero(); - for i in 0..WORD_SIZE / 16 { - result += F::from_u64(1 << (16 * i)).unwrap() * val; - } - result + let repeat = WORD_SIZE / 16; + concatenate_lookups(&vec![val; repeat], repeat, log2(M) as usize) } fn g_poly_degree(&self, _: usize) -> usize { @@ -42,33 +41,43 @@ impl JoltInstruction for MOVSIGNInstruction { fn subtables( &self, - _: usize, - _: usize, + C: usize, + M: usize, ) -> Vec<(Box>, SubtableIndices)> { - vec![( - Box::new(SignExtendSubtable::::new()), - SubtableIndices::from(0), - )] + assert!(M == 1 << 16); + let msb_chunk_index = C - (WORD_SIZE / 16); + vec![ + ( + Box::new(SignExtendSubtable::::new()), + SubtableIndices::from(msb_chunk_index), + ), + ( + // Not used for lookup, but this implicitly range-checks + // the remaining query chunks + Box::new(IdentitySubtable::::new()), + SubtableIndices::from(0..C), + ), + ] } - fn to_indices(&self, _: usize, log_M: usize) -> Vec { - chunk_operand_usize(self.0, WORD_SIZE / 16, log_M) + fn to_indices(&self, C: usize, log_M: usize) -> Vec { + chunk_operand_usize(self.0, C, log_M) } fn lookup_entry(&self) -> u64 { match WORD_SIZE { 32 => { if self.0 & SIGN_BIT_32 != 0 { - NEGATIVE_32 + ALL_ONES_32 } else { - POSITIVE + 0 } } 64 => { if self.0 & SIGN_BIT_64 != 0 { - NEGATIVE_64 + ALL_ONES_64 } else { - POSITIVE + 0 } } _ => panic!("only implemented for u32 / u64"), diff --git a/jolt-core/src/jolt/trace/rv.rs b/jolt-core/src/jolt/trace/rv.rs index 90ee82166..bc02b056b 100644 --- a/jolt-core/src/jolt/trace/rv.rs +++ b/jolt-core/src/jolt/trace/rv.rs @@ -23,6 +23,7 @@ use crate::jolt::instruction::virtual_assert_lte::ASSERTLTEInstruction; use crate::jolt::instruction::virtual_assert_valid_div0::AssertValidDiv0Instruction; use crate::jolt::instruction::virtual_assert_valid_signed_remainder::AssertValidSignedRemainderInstruction; use crate::jolt::instruction::virtual_assert_valid_unsigned_remainder::AssertValidUnsignedRemainderInstruction; +use crate::jolt::instruction::virtual_move::MOVEInstruction; use crate::jolt::instruction::xor::XORInstruction; use crate::jolt::instruction::{add::ADDInstruction, virtual_movsign::MOVSIGNInstruction}; use crate::jolt::vm::rv32i_vm::RV32I; @@ -81,6 +82,7 @@ impl TryFrom<&ELFInstruction> for RV32I { RV32IM::MULHU => Ok(MULHUInstruction::default().into()), RV32IM::VIRTUAL_ADVICE => Ok(ADVICEInstruction::default().into()), + RV32IM::VIRTUAL_MOVE => Ok(MOVEInstruction::default().into()), RV32IM::VIRTUAL_MOVSIGN => Ok(MOVSIGNInstruction::default().into()), RV32IM::VIRTUAL_ASSERT_EQ => Ok(BEQInstruction::default().into()), RV32IM::VIRTUAL_ASSERT_LTE => Ok(ASSERTLTEInstruction::default().into()), @@ -146,6 +148,7 @@ impl TryFrom<&RVTraceRow> for RV32I { RV32IM::MULHU => Ok(MULHUInstruction(row.register_state.rs1_val.unwrap(), row.register_state.rs2_val.unwrap()).into()), RV32IM::VIRTUAL_ADVICE => Ok(ADVICEInstruction(row.advice_value.unwrap()).into()), + RV32IM::VIRTUAL_MOVE => Ok(MOVEInstruction(row.register_state.rs1_val.unwrap()).into()), RV32IM::VIRTUAL_MOVSIGN => Ok(MOVSIGNInstruction(row.register_state.rs1_val.unwrap()).into()), RV32IM::VIRTUAL_ASSERT_EQ => Ok(BEQInstruction(row.register_state.rs1_val.unwrap(), row.register_state.rs2_val.unwrap()).into()), RV32IM::VIRTUAL_ASSERT_LTE => Ok(ASSERTLTEInstruction(row.register_state.rs1_val.unwrap(), row.register_state.rs2_val.unwrap()).into()), diff --git a/jolt-core/src/jolt/vm/rv32i_vm.rs b/jolt-core/src/jolt/vm/rv32i_vm.rs index 6e34ea8a3..0fa2a9f96 100644 --- a/jolt-core/src/jolt/vm/rv32i_vm.rs +++ b/jolt-core/src/jolt/vm/rv32i_vm.rs @@ -1,6 +1,7 @@ use crate::field::JoltField; use crate::jolt::instruction::virtual_assert_valid_div0::AssertValidDiv0Instruction; use crate::jolt::instruction::virtual_assert_valid_unsigned_remainder::AssertValidUnsignedRemainderInstruction; +use crate::jolt::instruction::virtual_move::MOVEInstruction; use crate::jolt::subtable::div_by_zero::DivByZeroSubtable; use crate::jolt::subtable::right_is_zero::RightIsZeroSubtable; use crate::poly::commitment::hyrax::HyraxScheme; @@ -119,6 +120,7 @@ instruction_set!( MULU: MULUInstruction, MULHU: MULHUInstruction, VIRTUAL_ADVICE: ADVICEInstruction, + VIRTUAL_MOVE: MOVEInstruction, VIRTUAL_ASSERT_LTE: ASSERTLTEInstruction, VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER: AssertValidSignedRemainderInstruction, VIRTUAL_ASSERT_VALID_UNSIGNED_REMAINDER: AssertValidUnsignedRemainderInstruction, diff --git a/jolt-core/src/r1cs/jolt_constraints.rs b/jolt-core/src/r1cs/jolt_constraints.rs index b9bbb8b29..036be9a82 100644 --- a/jolt-core/src/r1cs/jolt_constraints.rs +++ b/jolt-core/src/r1cs/jolt_constraints.rs @@ -106,7 +106,7 @@ pub enum JoltIn { LookupOutput, // Should match rv_trace.to_circuit_flags() - OpFlags_IsRs1Rs2, + OpFlags_IsPC, OpFlags_IsImm, OpFlags_IsLoad, OpFlags_IsStore, @@ -144,7 +144,8 @@ pub enum JoltIn { IF_Mul, IF_MulU, IF_MulHu, - IF_Virt_Adv, + IF_Virt_Advice, + IF_Virt_Move, IF_Virt_Assert_LTE, IF_Virt_Assert_VALID_SIGNED_REMAINDER, IF_Virt_Assert_VALID_UNSIGNED_REMAINDER, @@ -157,7 +158,7 @@ pub const PC_START_ADDRESS: i64 = 0x80000000; const PC_NOOP_SHIFT: i64 = 4; const LOG_M: usize = 16; const OPERAND_SIZE: usize = LOG_M / 2; -pub const NEXT_PC: usize = 15; +pub const NEXT_PC: usize = 16; pub struct UniformJoltConstraints { memory_start: u64, @@ -172,7 +173,7 @@ impl UniformJoltConstraints { impl R1CSConstraintBuilder for UniformJoltConstraints { type Inputs = JoltIn; fn build_constraints(&self, cs: &mut R1CSBuilder) { - let flags = input_range!(JoltIn::OpFlags_IsRs1Rs2, JoltIn::IF_Virt_Assert_VALID_DIV0); + let flags = input_range!(JoltIn::OpFlags_IsPC, JoltIn::IF_Virt_Assert_VALID_DIV0); for flag in flags { cs.constrain_binary(flag); } @@ -182,7 +183,7 @@ impl R1CSConstraintBuilder for UniformJoltConstraints { cs.constrain_pack_be(flags.to_vec(), JoltIn::Bytecode_Bitflags, 1); let real_pc = 4i64 * JoltIn::Bytecode_ELFAddress + (PC_START_ADDRESS - PC_NOOP_SHIFT); - let x = cs.allocate_if_else(JoltIn::OpFlags_IsRs1Rs2, real_pc, JoltIn::RS1_Read); + let x = cs.allocate_if_else(JoltIn::OpFlags_IsPC, real_pc, JoltIn::RS1_Read); let y = cs.allocate_if_else( JoltIn::OpFlags_IsImm, JoltIn::Bytecode_Imm, @@ -239,9 +240,15 @@ impl R1CSConstraintBuilder for UniformJoltConstraints { cs.constrain_eq_conditional(JoltIn::IF_Add, packed_query, x + y); // Converts from unsigned to twos-complement representation cs.constrain_eq_conditional(JoltIn::IF_Sub, packed_query, x - y + (0xffffffffi64 + 1)); + let is_mul = JoltIn::IF_Mul + JoltIn::IF_MulU + JoltIn::IF_MulHu; + let product = cs.allocate_prod(x, y); + cs.constrain_eq_conditional(is_mul, packed_query, product); + cs.constrain_eq_conditional(JoltIn::IF_Movsign + JoltIn::IF_Virt_Move, packed_query, x); cs.constrain_eq_conditional(JoltIn::OpFlags_IsLoad, packed_query, packed_load_store); cs.constrain_eq_conditional(JoltIn::OpFlags_IsStore, packed_query, JoltIn::RS2_Read); + cs.constrain_eq_conditional(JoltIn::OpFlags_IsAssert, JoltIn::LookupOutput, 1); + // TODO(sragss): Uses 2 excess constraints for condition gating. Could make constrain_pack_be_conditional... Or make everything conditional... let chunked_x = cs.allocate_pack_be( input_range!(JoltIn::ChunksX_0, JoltIn::ChunksX_3).to_vec(), @@ -347,7 +354,7 @@ mod tests { // rv_trace::to_circuit_flags // all zero for ADD - inputs[JoltIn::OpFlags_IsRs1Rs2 as usize][0] = Fr::zero(); // first_operand = rs1 + inputs[JoltIn::OpFlags_IsPC as usize][0] = Fr::zero(); // first_operand = rs1 inputs[JoltIn::OpFlags_IsImm as usize][0] = Fr::zero(); // second_operand = rs2 => immediate let aux = combined_builder.compute_aux(&inputs); From 6553b5665935a6c0e791ceb6ca41110d2b530163 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Fri, 26 Jul 2024 15:33:09 -0400 Subject: [PATCH 13/17] Remove unnecessary packing aux variables and constraints --- jolt-core/src/r1cs/builder.rs | 118 +++++-------------------- jolt-core/src/r1cs/jolt_constraints.rs | 35 +++++--- 2 files changed, 43 insertions(+), 110 deletions(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 68fc25b5b..bbc9e8dd2 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -1,5 +1,5 @@ use crate::{ - field::{JoltField, OptimizedMul}, + field::JoltField, r1cs::key::{SparseConstraints, UniformR1CS}, utils::{ math::Math, @@ -329,6 +329,25 @@ impl R1CSBuilder { self.allocate_aux(symbolic_inputs, compute) } + pub fn pack_le(unpacked: Vec>, operand_bits: usize) -> LC { + let packed: Vec> = unpacked + .into_iter() + .enumerate() + .map(|(idx, unpacked)| Term(unpacked, 1 << (idx * operand_bits))) + .collect(); + packed.into() + } + + pub fn pack_be(unpacked: Vec>, operand_bits: usize) -> LC { + let packed: Vec> = unpacked + .into_iter() + .rev() + .enumerate() + .map(|(idx, unpacked)| Term(unpacked, 1 << (idx * operand_bits))) + .collect(); + packed.into() + } + pub fn constrain_pack_le( &mut self, unpacked: Vec>, @@ -345,33 +364,6 @@ impl R1CSBuilder { self.constrain_eq(packed, result); } - #[must_use] - pub fn allocate_pack_le( - &mut self, - unpacked: Vec>, - operand_bits: usize, - ) -> Variable { - let packed = self.aux_pack_le(&unpacked, operand_bits); - - self.constrain_pack_le(unpacked, packed, operand_bits); - packed - } - - fn aux_pack_le(&mut self, to_pack: &[Variable], operand_bits: usize) -> Variable { - let pack = move |values: &[F]| -> F { - values - .iter() - .enumerate() - .fold(F::zero(), |acc, (idx, &value)| { - acc + value.mul_01_optimized(F::from_u64(1 << (idx * operand_bits)).unwrap()) - }) - }; - - let symbolic_inputs = to_pack.iter().cloned().map(|sym| sym.into()).collect(); - let compute = Box::new(pack); - self.allocate_aux(symbolic_inputs, compute) - } - pub fn constrain_pack_be( &mut self, unpacked: Vec>, @@ -390,34 +382,6 @@ impl R1CSBuilder { self.constrain_eq(packed, result); } - #[must_use] - pub fn allocate_pack_be( - &mut self, - unpacked: Vec>, - operand_bits: usize, - ) -> Variable { - let packed = self.aux_pack_be(&unpacked, operand_bits); - - self.constrain_pack_be(unpacked, packed, operand_bits); - packed - } - - fn aux_pack_be(&mut self, to_pack: &[Variable], operand_bits: usize) -> Variable { - let pack = move |values: &[F]| -> F { - values - .iter() - .rev() - .enumerate() - .fold(F::zero(), |acc, (idx, &value)| { - acc + value.mul_01_optimized(F::from_u64(1 << (idx * operand_bits)).unwrap()) - }) - }; - - let symbolic_inputs = to_pack.iter().cloned().map(|sym| sym.into()).collect(); - let compute = Box::new(pack); - self.allocate_aux(symbolic_inputs, compute) - } - /// Constrain x * y == z pub fn constrain_prod( &mut self, @@ -1143,48 +1107,6 @@ mod tests { assert!(constraint.is_sat(&z)); } - #[test] - fn alloc_packing_le_builder() { - let mut builder = R1CSBuilder::::new(); - - // pack_le(OpFlags0, OpFlags1, OpFlags2, OpFlags3) == Aux(0) - struct TestConstraints(); - impl R1CSConstraintBuilder for TestConstraints { - type Inputs = TestInputs; - fn build_constraints(&self, builder: &mut R1CSBuilder) { - let unpacked: Vec> = vec![ - TestInputs::OpFlags0.into(), - TestInputs::OpFlags1.into(), - TestInputs::OpFlags2.into(), - TestInputs::OpFlags3.into(), - ]; - let _result = builder.allocate_pack_le(unpacked, 1); - } - } - - let concrete_constraints = TestConstraints(); - concrete_constraints.build_constraints(&mut builder); - assert_eq!(builder.constraints.len(), 1); - let constraint = &builder.constraints[0]; - - // 1101 == 13 - let mut z = vec![0i64; TestInputs::COUNT + 1]; - // (little endian) - z[TestInputs::OpFlags0 as usize] = 1; - z[TestInputs::OpFlags1 as usize] = 0; - z[TestInputs::OpFlags2 as usize] = 1; - z[TestInputs::OpFlags3 as usize] = 1; - - assert_eq!(builder.aux_computations.len(), 1); - let computed_aux = aux_compute_single( - &builder.aux_computations[0], - &[Fr::one(), Fr::zero(), Fr::one(), Fr::one()], - ); - assert_eq!(computed_aux, Fr::from(13)); - z[builder.witness_index(Variable::Auxiliary(0))] = 13; - assert!(constraint.is_sat(&z)); - } - #[test] fn packing_be_builder() { let mut builder = R1CSBuilder::::new(); diff --git a/jolt-core/src/r1cs/jolt_constraints.rs b/jolt-core/src/r1cs/jolt_constraints.rs index 036be9a82..5fb0d8711 100644 --- a/jolt-core/src/r1cs/jolt_constraints.rs +++ b/jolt-core/src/r1cs/jolt_constraints.rs @@ -158,7 +158,7 @@ pub const PC_START_ADDRESS: i64 = 0x80000000; const PC_NOOP_SHIFT: i64 = 4; const LOG_M: usize = 16; const OPERAND_SIZE: usize = LOG_M / 2; -pub const NEXT_PC: usize = 16; +pub const NEXT_PC: usize = 12; pub struct UniformJoltConstraints { memory_start: u64, @@ -225,36 +225,47 @@ impl R1CSConstraintBuilder for UniformJoltConstraints { ); let ram_writes = input_range!(JoltIn::RAM_Write_Byte0, JoltIn::RAM_Write_Byte3); - let packed_load_store = cs.allocate_pack_le(ram_writes.to_vec(), 8); + let packed_load_store = R1CSBuilder::::pack_le(ram_writes.to_vec(), 8); cs.constrain_eq_conditional( JoltIn::OpFlags_IsStore, - packed_load_store, + packed_load_store.clone(), JoltIn::LookupOutput, ); - let packed_query = cs.allocate_pack_be( + let packed_query = R1CSBuilder::::pack_be( input_range!(JoltIn::ChunksQ_0, JoltIn::ChunksQ_3).to_vec(), LOG_M, ); - cs.constrain_eq_conditional(JoltIn::IF_Add, packed_query, x + y); + cs.constrain_eq_conditional(JoltIn::IF_Add, packed_query.clone(), x + y); // Converts from unsigned to twos-complement representation - cs.constrain_eq_conditional(JoltIn::IF_Sub, packed_query, x - y + (0xffffffffi64 + 1)); + cs.constrain_eq_conditional( + JoltIn::IF_Sub, + packed_query.clone(), + x - y + (0xffffffffi64 + 1), + ); let is_mul = JoltIn::IF_Mul + JoltIn::IF_MulU + JoltIn::IF_MulHu; let product = cs.allocate_prod(x, y); - cs.constrain_eq_conditional(is_mul, packed_query, product); - cs.constrain_eq_conditional(JoltIn::IF_Movsign + JoltIn::IF_Virt_Move, packed_query, x); - cs.constrain_eq_conditional(JoltIn::OpFlags_IsLoad, packed_query, packed_load_store); + cs.constrain_eq_conditional(is_mul, packed_query.clone(), product); + cs.constrain_eq_conditional( + JoltIn::IF_Movsign + JoltIn::IF_Virt_Move, + packed_query.clone(), + x, + ); + cs.constrain_eq_conditional( + JoltIn::OpFlags_IsLoad, + packed_query.clone(), + packed_load_store, + ); cs.constrain_eq_conditional(JoltIn::OpFlags_IsStore, packed_query, JoltIn::RS2_Read); cs.constrain_eq_conditional(JoltIn::OpFlags_IsAssert, JoltIn::LookupOutput, 1); - // TODO(sragss): Uses 2 excess constraints for condition gating. Could make constrain_pack_be_conditional... Or make everything conditional... - let chunked_x = cs.allocate_pack_be( + let chunked_x = R1CSBuilder::::pack_be( input_range!(JoltIn::ChunksX_0, JoltIn::ChunksX_3).to_vec(), OPERAND_SIZE, ); - let chunked_y = cs.allocate_pack_be( + let chunked_y = R1CSBuilder::::pack_be( input_range!(JoltIn::ChunksY_0, JoltIn::ChunksY_3).to_vec(), OPERAND_SIZE, ); From 438d3b172bb367409bc1e01976fe066990caadcf Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Fri, 26 Jul 2024 16:07:47 -0400 Subject: [PATCH 14/17] Refactor virtual_sequence and virtual_trace --- jolt-core/src/jolt/instruction/div.rs | 203 +++++++++++------------ jolt-core/src/jolt/instruction/divu.rs | 203 +++++++++++------------ jolt-core/src/jolt/instruction/mod.rs | 19 ++- jolt-core/src/jolt/instruction/mulh.rs | 152 ++++++++--------- jolt-core/src/jolt/instruction/mulhsu.rs | 94 +++++------ jolt-core/src/jolt/instruction/rem.rs | 183 ++++++++++---------- jolt-core/src/jolt/instruction/remu.rs | 183 ++++++++++---------- 7 files changed, 489 insertions(+), 548 deletions(-) diff --git a/jolt-core/src/jolt/instruction/div.rs b/jolt-core/src/jolt/instruction/div.rs index f773c854b..11b372e3d 100644 --- a/jolt-core/src/jolt/instruction/div.rs +++ b/jolt-core/src/jolt/instruction/div.rs @@ -13,130 +13,63 @@ pub struct DIVInstruction; impl VirtualInstructionSequence for DIVInstruction { const SEQUENCE_LENGTH: usize = 8; - fn virtual_sequence(instruction: ELFInstruction) -> Vec { - assert_eq!(instruction.opcode, RV32IM::DIV); + fn virtual_trace(trace_row: RVTraceRow) -> Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::DIV); // DIV source registers - let r_x = instruction.rs1; - let r_y = instruction.rs2; + let r_x = trace_row.instruction.rs1; + let r_y = trace_row.instruction.rs2; // Virtual registers used in sequence let v_0 = Some(virtual_register_index(0)); let v_q: Option = Some(virtual_register_index(1)); let v_r: Option = Some(virtual_register_index(2)); let v_qy = Some(virtual_register_index(3)); - - let mut virtual_sequence = Vec::with_capacity(Self::SEQUENCE_LENGTH); - - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_q, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_r, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER, - rs1: v_r, - rs2: r_y, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_DIV0, - rs1: r_y, - rs2: v_q, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::MUL, - rs1: v_q, - rs2: r_y, - rd: v_qy, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::ADD, - rs1: v_qy, - rs2: v_r, - rd: v_0, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_EQ, - rs1: v_0, - rs2: r_x, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_MOVE, - rs1: v_q, - rs2: None, - rd: instruction.rd, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - - debug_assert_eq!(virtual_sequence.len(), Self::SEQUENCE_LENGTH); - virtual_sequence - } - - fn virtual_trace(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::DIV); // DIV operands let x = trace_row.register_state.rs1_val.unwrap(); let y = trace_row.register_state.rs2_val.unwrap(); - let virtual_instructions = Self::virtual_sequence(trace_row.instruction); let mut virtual_trace = vec![]; let (quotient, remainder) = match WORD_SIZE { 32 => { - let mut quotient = x as i32 / y as i32; - let mut remainder = x as i32 % y as i32; - if (remainder < 0 && (y as i32) > 0) || (remainder > 0 && (y as i32) < 0) { - remainder += y as i32; - quotient -= 1; + if y == 0 { + (u32::MAX as u64, x) + } else { + let mut quotient = x as i32 / y as i32; + let mut remainder = x as i32 % y as i32; + if (remainder < 0 && (y as i32) > 0) || (remainder > 0 && (y as i32) < 0) { + remainder += y as i32; + quotient -= 1; + } + (quotient as u32 as u64, remainder as u32 as u64) } - (quotient as u32 as u64, remainder as u32 as u64) } 64 => { - let mut quotient = x as i64 / y as i64; - let mut remainder = x as i64 % y as i64; - if (remainder < 0 && (y as i64) > 0) || (remainder > 0 && (y as i64) < 0) { - remainder += y as i64; - quotient -= 1; + if y == 0 { + (u64::MAX, x) + } else { + let mut quotient = x as i64 / y as i64; + let mut remainder = x as i64 % y as i64; + if (remainder < 0 && (y as i64) > 0) || (remainder > 0 && (y as i64) < 0) { + remainder += y as i64; + quotient -= 1; + } + (quotient as u64, remainder as u64) } - (quotient as u64, remainder as u64) } _ => panic!("Unsupported WORD_SIZE: {}", WORD_SIZE), }; let q = ADVICEInstruction::(quotient).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_q, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -148,7 +81,15 @@ impl VirtualInstructionSequence for DIVInstruction(remainder).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_r, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -161,7 +102,15 @@ impl VirtualInstructionSequence for DIVInstruction(r, y).lookup_entry(); assert_eq!(is_valid, 1); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER, + rs1: v_r, + rs2: r_y, + rd: None, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(r), rs2_val: Some(y), @@ -174,7 +123,15 @@ impl VirtualInstructionSequence for DIVInstruction(y, q).lookup_entry(); assert_eq!(is_valid, 1); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_VALID_DIV0, + rs1: r_y, + rs2: v_q, + rd: None, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(y), rs2_val: Some(q), @@ -186,7 +143,15 @@ impl VirtualInstructionSequence for DIVInstruction(q, y).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::MUL, + rs1: v_q, + rs2: r_y, + rd: v_qy, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(q), rs2_val: Some(y), @@ -198,7 +163,15 @@ impl VirtualInstructionSequence for DIVInstruction(q_y, r).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::ADD, + rs1: v_qy, + rs2: v_r, + rd: v_0, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(q_y), rs2_val: Some(r), @@ -210,7 +183,15 @@ impl VirtualInstructionSequence for DIVInstruction VirtualInstructionSequence for DIVInstruction; impl VirtualInstructionSequence for DIVUInstruction { const SEQUENCE_LENGTH: usize = 9; - fn virtual_sequence(instruction: ELFInstruction) -> Vec { - assert_eq!(instruction.opcode, RV32IM::DIVU); + fn virtual_trace(trace_row: RVTraceRow) -> Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::DIVU); // DIVU source registers - let r_x = instruction.rs1; - let r_y = instruction.rs2; + let r_x = trace_row.instruction.rs1; + let r_y = trace_row.instruction.rs2; // Virtual registers used in sequence let v_0 = Some(virtual_register_index(0)); let v_q = Some(virtual_register_index(1)); - let v_r: Option = Some(virtual_register_index(2)); + let v_r = Some(virtual_register_index(2)); let v_qy = Some(virtual_register_index(3)); - - let mut virtual_sequence = Vec::with_capacity(Self::SEQUENCE_LENGTH); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_q, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_r, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::MULU, - rs1: v_q, - rs2: r_y, - rd: v_qy, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_UNSIGNED_REMAINDER, - rs1: v_r, - rs2: r_y, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_LTE, - rs1: v_qy, - rs2: r_x, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_DIV0, - rs1: r_y, - rs2: v_q, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::ADD, - rs1: v_qy, - rs2: v_r, - rd: v_0, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_EQ, - rs1: v_0, - rs2: r_x, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_MOVE, - rs1: v_q, - rs2: None, - rd: instruction.rd, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - - debug_assert_eq!(virtual_sequence.len(), Self::SEQUENCE_LENGTH); - virtual_sequence - } - - fn virtual_trace(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::DIVU); // DIVU operands let x = trace_row.register_state.rs1_val.unwrap(); let y = trace_row.register_state.rs2_val.unwrap(); - let virtual_instructions = Self::virtual_sequence(trace_row.instruction); let mut virtual_trace = vec![]; - let quotient = x / y; - let remainder = x - quotient * y; + let quotient = if y == 0 { + match WORD_SIZE { + 32 => u32::MAX as u64, + 64 => u64::MAX, + _ => panic!("Unsupported WORD_SIZE: {}", WORD_SIZE), + } + } else { + x / y + }; + let remainder = if y == 0 { x } else { x - quotient * y }; let q = ADVICEInstruction::(quotient).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_q, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -139,7 +64,15 @@ impl VirtualInstructionSequence for DIVUInstruction(remainder).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_r, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -151,7 +84,15 @@ impl VirtualInstructionSequence for DIVUInstruction(q, y).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::MULU, + rs1: v_q, + rs2: r_y, + rd: v_qy, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(q), rs2_val: Some(y), @@ -164,7 +105,15 @@ impl VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction(y, q).lookup_entry(); assert_eq!(is_valid, 1); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_VALID_DIV0, + rs1: r_y, + rs2: v_q, + rd: None, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(y), rs2_val: Some(q), @@ -202,7 +167,15 @@ impl VirtualInstructionSequence for DIVUInstruction(q_y, r).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::ADD, + rs1: v_qy, + rs2: v_r, + rd: v_0, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(q_y), rs2_val: Some(r), @@ -214,7 +187,15 @@ impl VirtualInstructionSequence for DIVUInstruction VirtualInstructionSequence for DIVUInstruction> for SubtableIndices { pub trait VirtualInstructionSequence { const SEQUENCE_LENGTH: usize; - fn virtual_sequence(instruction: ELFInstruction) -> Vec; + fn virtual_sequence(instruction: ELFInstruction) -> Vec { + let dummy_trace_row = RVTraceRow { + instruction, + register_state: RegisterState { + rs1_val: Some(0), + rs2_val: Some(0), + rd_post_val: Some(0), + }, + memory_state: None, + advice_value: None, + }; + Self::virtual_trace(dummy_trace_row) + .into_iter() + .map(|trace_row| trace_row.instruction) + .collect() + } fn virtual_trace(trace_row: RVTraceRow) -> Vec; } diff --git a/jolt-core/src/jolt/instruction/mulh.rs b/jolt-core/src/jolt/instruction/mulh.rs index 457dd69d3..bd3bcc0f2 100644 --- a/jolt-core/src/jolt/instruction/mulh.rs +++ b/jolt-core/src/jolt/instruction/mulh.rs @@ -12,11 +12,11 @@ pub struct MULHInstruction; impl VirtualInstructionSequence for MULHInstruction { const SEQUENCE_LENGTH: usize = 7; - fn virtual_sequence(instruction: ELFInstruction) -> Vec { - assert_eq!(instruction.opcode, RV32IM::MULH); + fn virtual_trace(trace_row: RVTraceRow) -> Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::MULH); // MULH source registers - let r_x = instruction.rs1; - let r_y = instruction.rs2; + let r_x = trace_row.instruction.rs1; + let r_y = trace_row.instruction.rs2; // Virtual registers used in sequence let v_sx = Some(virtual_register_index(0)); let v_sy = Some(virtual_register_index(1)); @@ -24,89 +24,23 @@ impl VirtualInstructionSequence for MULHInstruction Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::MULH); // MULH operands let x = trace_row.register_state.rs1_val.unwrap(); let y = trace_row.register_state.rs2_val.unwrap(); - let virtual_instructions = Self::virtual_sequence(trace_row.instruction); let mut virtual_trace = vec![]; let s_x = MOVSIGNInstruction::(x).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_MOVSIGN, + rs1: r_x, + rs2: None, + rd: v_sx, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(x), rs2_val: None, @@ -118,7 +52,15 @@ impl VirtualInstructionSequence for MULHInstruction(y).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_MOVSIGN, + rs1: r_y, + rs2: None, + rd: v_sy, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(y), rs2_val: None, @@ -130,7 +72,15 @@ impl VirtualInstructionSequence for MULHInstruction(x, y).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::MULHU, + rs1: r_x, + rs2: r_y, + rd: v_0, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(x), rs2_val: Some(y), @@ -142,7 +92,15 @@ impl VirtualInstructionSequence for MULHInstruction(s_x, y).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::MULU, + rs1: v_sx, + rs2: r_y, + rd: v_1, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(s_x), rs2_val: Some(y), @@ -154,7 +112,15 @@ impl VirtualInstructionSequence for MULHInstruction(s_y, x).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::MULU, + rs1: v_sy, + rs2: r_x, + rd: v_2, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(s_y), rs2_val: Some(x), @@ -166,7 +132,15 @@ impl VirtualInstructionSequence for MULHInstruction(xy_high_bits, sx_y_low_bits).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::ADD, + rs1: v_0, + rs2: v_1, + rd: v_3, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(xy_high_bits), rs2_val: Some(sx_y_low_bits), @@ -178,7 +152,15 @@ impl VirtualInstructionSequence for MULHInstruction(partial_sum, sy_x_low_bits).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::ADD, + rs1: v_3, + rs2: v_2, + rd: trace_row.instruction.rd, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(partial_sum), rs2_val: Some(sy_x_low_bits), diff --git a/jolt-core/src/jolt/instruction/mulhsu.rs b/jolt-core/src/jolt/instruction/mulhsu.rs index 9565e20e1..b892f2747 100644 --- a/jolt-core/src/jolt/instruction/mulhsu.rs +++ b/jolt-core/src/jolt/instruction/mulhsu.rs @@ -13,70 +13,32 @@ pub struct MULHSUInstruction; impl VirtualInstructionSequence for MULHSUInstruction { const SEQUENCE_LENGTH: usize = 4; - fn virtual_sequence(instruction: ELFInstruction) -> Vec { - assert_eq!(instruction.opcode, RV32IM::MULHSU); + fn virtual_trace(trace_row: RVTraceRow) -> Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::MULHSU); // MULHSU source registers - let r_x = instruction.rs1; - let r_y = instruction.rs2; + let r_x = trace_row.instruction.rs1; + let r_y = trace_row.instruction.rs2; // Virtual registers used in sequence let v_sx = Some(virtual_register_index(0)); let v_1 = Some(virtual_register_index(1)); let v_2 = Some(virtual_register_index(2)); - - let mut virtual_sequence = Vec::with_capacity(Self::SEQUENCE_LENGTH); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_MOVSIGN, - rs1: r_x, - rs2: None, - rd: v_sx, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::MULHU, - rs1: r_x, - rs2: r_y, - rd: v_1, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::MULU, - rs1: v_sx, - rs2: r_y, - rd: v_2, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::ADD, - rs1: v_1, - rs2: v_2, - rd: instruction.rd, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - - debug_assert_eq!(virtual_sequence.len(), Self::SEQUENCE_LENGTH); - virtual_sequence - } - - fn virtual_trace(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::MULHSU); // MULHSU operands let x = trace_row.register_state.rs1_val.unwrap(); let y = trace_row.register_state.rs2_val.unwrap(); - let virtual_instructions = Self::virtual_sequence(trace_row.instruction); let mut virtual_trace = vec![]; let s_x = MOVSIGNInstruction::(x).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_MOVSIGN, + rs1: r_x, + rs2: None, + rd: v_sx, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(x), rs2_val: None, @@ -88,7 +50,15 @@ impl VirtualInstructionSequence for MULHSUInstruction(x, y).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::MULHU, + rs1: r_x, + rs2: r_y, + rd: v_1, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(x), rs2_val: Some(y), @@ -100,7 +70,15 @@ impl VirtualInstructionSequence for MULHSUInstruction(s_x, y).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::MULU, + rs1: v_sx, + rs2: r_y, + rd: v_2, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(s_x), rs2_val: Some(y), @@ -112,7 +90,15 @@ impl VirtualInstructionSequence for MULHSUInstruction(xy_high_bits, sx_y_low_bits).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::ADD, + rs1: v_1, + rs2: v_2, + rd: trace_row.instruction.rd, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(xy_high_bits), rs2_val: Some(sx_y_low_bits), diff --git a/jolt-core/src/jolt/instruction/rem.rs b/jolt-core/src/jolt/instruction/rem.rs index 1a423ec7a..92d1bb03d 100644 --- a/jolt-core/src/jolt/instruction/rem.rs +++ b/jolt-core/src/jolt/instruction/rem.rs @@ -14,120 +14,63 @@ pub struct REMInstruction; impl VirtualInstructionSequence for REMInstruction { const SEQUENCE_LENGTH: usize = 7; - fn virtual_sequence(instruction: ELFInstruction) -> Vec { - assert_eq!(instruction.opcode, RV32IM::REM); + fn virtual_trace(trace_row: RVTraceRow) -> Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::REM); // REM source registers - let r_x = instruction.rs1; - let r_y = instruction.rs2; + let r_x = trace_row.instruction.rs1; + let r_y = trace_row.instruction.rs2; // Virtual registers used in sequence let v_0 = Some(virtual_register_index(0)); let v_q = Some(virtual_register_index(1)); let v_r = Some(virtual_register_index(2)); let v_qy = Some(virtual_register_index(3)); - - let mut virtual_sequence = Vec::with_capacity(Self::SEQUENCE_LENGTH); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_q, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_r, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER, - rs1: v_r, - rs2: r_y, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::MUL, - rs1: v_q, - rs2: r_y, - rd: v_qy, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::ADD, - rs1: v_qy, - rs2: v_r, - rd: v_0, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_EQ, - rs1: v_0, - rs2: r_x, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_MOVE, - rs1: v_r, - rs2: None, - rd: instruction.rd, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - - debug_assert_eq!(virtual_sequence.len(), Self::SEQUENCE_LENGTH); - virtual_sequence - } - - fn virtual_trace(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::REM); // REM operands let x = trace_row.register_state.rs1_val.unwrap(); let y = trace_row.register_state.rs2_val.unwrap(); - let virtual_instructions = Self::virtual_sequence(trace_row.instruction); let mut virtual_trace = vec![]; let (quotient, remainder) = match WORD_SIZE { 32 => { - let mut quotient = x as i32 / y as i32; - let mut remainder = x as i32 % y as i32; - if (remainder < 0 && (y as i32) > 0) || (remainder > 0 && (y as i32) < 0) { - remainder += y as i32; - quotient -= 1; + if y == 0 { + (u32::MAX as u64, x) + } else { + let mut quotient = x as i32 / y as i32; + let mut remainder = x as i32 % y as i32; + if (remainder < 0 && (y as i32) > 0) || (remainder > 0 && (y as i32) < 0) { + remainder += y as i32; + quotient -= 1; + } + (quotient as u32 as u64, remainder as u32 as u64) } - (quotient as u32 as u64, remainder as u32 as u64) } 64 => { - let mut quotient = x as i64 / y as i64; - let mut remainder = x as i64 % y as i64; - if (remainder < 0 && (y as i64) > 0) || (remainder > 0 && (y as i64) < 0) { - remainder += y as i64; - quotient -= 1; + if y == 0 { + (u64::MAX, x) + } else { + let mut quotient = x as i64 / y as i64; + let mut remainder = x as i64 % y as i64; + if (remainder < 0 && (y as i64) > 0) || (remainder > 0 && (y as i64) < 0) { + remainder += y as i64; + quotient -= 1; + } + (quotient as u64, remainder as u64) } - (quotient as u64, remainder as u64) } _ => panic!("Unsupported WORD_SIZE: {}", WORD_SIZE), }; let q = ADVICEInstruction::(quotient).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_q, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -139,7 +82,15 @@ impl VirtualInstructionSequence for REMInstruction(remainder).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_r, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -152,7 +103,15 @@ impl VirtualInstructionSequence for REMInstruction(r, y).lookup_entry(); assert_eq!(is_valid, 1); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER, + rs1: v_r, + rs2: r_y, + rd: None, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(r), rs2_val: Some(y), @@ -164,7 +123,15 @@ impl VirtualInstructionSequence for REMInstruction(q, y).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::MUL, + rs1: v_q, + rs2: r_y, + rd: v_qy, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(q), rs2_val: Some(y), @@ -176,7 +143,15 @@ impl VirtualInstructionSequence for REMInstruction(q_y, r).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::ADD, + rs1: v_qy, + rs2: v_r, + rd: v_0, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(q_y), rs2_val: Some(r), @@ -188,7 +163,15 @@ impl VirtualInstructionSequence for REMInstruction VirtualInstructionSequence for REMInstruction; impl VirtualInstructionSequence for REMUInstruction { const SEQUENCE_LENGTH: usize = 8; - fn virtual_sequence(instruction: ELFInstruction) -> Vec { - assert_eq!(instruction.opcode, RV32IM::REMU); + fn virtual_trace(trace_row: RVTraceRow) -> Vec { + assert_eq!(trace_row.instruction.opcode, RV32IM::REMU); // REMU source registers - let r_x = instruction.rs1; - let r_y = instruction.rs2; + let r_x = trace_row.instruction.rs1; + let r_y = trace_row.instruction.rs2; // Virtual registers used in sequence let v_0 = Some(virtual_register_index(0)); let v_q = Some(virtual_register_index(1)); let v_r = Some(virtual_register_index(2)); let v_qy = Some(virtual_register_index(3)); - - let mut virtual_sequence = Vec::with_capacity(Self::SEQUENCE_LENGTH); - - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_q, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ADVICE, - rs1: None, - rs2: None, - rd: v_r, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::MULU, - rs1: v_q, - rs2: r_y, - rd: v_qy, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_VALID_UNSIGNED_REMAINDER, - rs1: v_r, - rs2: r_y, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_LTE, - rs1: v_qy, - rs2: r_x, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::ADD, - rs1: v_qy, - rs2: v_r, - rd: v_0, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_ASSERT_EQ, - rs1: v_0, - rs2: r_x, - rd: None, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - virtual_sequence.push(ELFInstruction { - address: instruction.address, - opcode: RV32IM::VIRTUAL_MOVE, - rs1: v_r, - rs2: None, - rd: instruction.rd, - imm: None, - virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_sequence.len() - 1), - }); - - debug_assert_eq!(virtual_sequence.len(), Self::SEQUENCE_LENGTH); - virtual_sequence - } - - fn virtual_trace(trace_row: RVTraceRow) -> Vec { - assert_eq!(trace_row.instruction.opcode, RV32IM::REMU); // REMU operands let x = trace_row.register_state.rs1_val.unwrap(); let y = trace_row.register_state.rs2_val.unwrap(); - let virtual_instructions = Self::virtual_sequence(trace_row.instruction); let mut virtual_trace = vec![]; - let quotient = x / y; - let remainder = x - quotient * y; + let quotient = if y == 0 { + match WORD_SIZE { + 32 => u32::MAX as u64, + 64 => u64::MAX, + _ => panic!("Unsupported WORD_SIZE: {}", WORD_SIZE), + } + } else { + x / y + }; + let remainder = if y == 0 { x } else { x - quotient * y }; let q = ADVICEInstruction::(quotient).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_q, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -131,7 +64,15 @@ impl VirtualInstructionSequence for REMUInstruction(remainder).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::VIRTUAL_ADVICE, + rs1: None, + rs2: None, + rd: v_r, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: None, rs2_val: None, @@ -143,7 +84,15 @@ impl VirtualInstructionSequence for REMUInstruction(q, y).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::MULU, + rs1: v_q, + rs2: r_y, + rd: v_qy, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(q), rs2_val: Some(y), @@ -156,7 +105,15 @@ impl VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction(q_y, r).lookup_entry(); virtual_trace.push(RVTraceRow { - instruction: virtual_instructions[virtual_trace.len()].clone(), + instruction: ELFInstruction { + address: trace_row.instruction.address, + opcode: RV32IM::ADD, + rs1: v_qy, + rs2: v_r, + rd: v_0, + imm: None, + virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1), + }, register_state: RegisterState { rs1_val: Some(q_y), rs2_val: Some(r), @@ -192,7 +165,15 @@ impl VirtualInstructionSequence for REMUInstruction VirtualInstructionSequence for REMUInstruction Date: Thu, 1 Aug 2024 14:24:42 -0400 Subject: [PATCH 15/17] Add M extension page --- book/src/SUMMARY.md | 3 +- book/src/how/{jolt.md => architecture.md} | 8 +- book/src/how/m-extension.md | 117 ++++++++++++++++++++++ book/src/how_it_works.md | 2 +- 4 files changed, 124 insertions(+), 6 deletions(-) rename book/src/how/{jolt.md => architecture.md} (95%) create mode 100644 book/src/how/m-extension.md diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index 6dc8b8f27..df7979bbc 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -13,11 +13,12 @@ - [Troubleshooting](./usage/troubleshooting.md) - [Contributors](./contributors.md) - [How it works](./how_it_works.md) - - [Jolt](./how/jolt.md) + - [Architecture overview](./how/architecture.md) - [Instruction lookups](./how/instruction_lookups.md) - [Read-write memory](./how/read_write_memory.md) - [Bytecode](./how/bytecode.md) - [R1CS constraints](./how/r1cs_constraints.md) + - [M extension](./how/m-extension.md) - [Background](./background.md) - [Sumcheck](./background/sumcheck.md) - [Multilinear Extensions](./background/multilinear-extensions.md) diff --git a/book/src/how/jolt.md b/book/src/how/architecture.md similarity index 95% rename from book/src/how/jolt.md rename to book/src/how/architecture.md index 30f5a1d46..81cb3061e 100644 --- a/book/src/how/jolt.md +++ b/book/src/how/architecture.md @@ -1,10 +1,10 @@ -# Jolt +# Architecture overview This section gives an overview of the core components of Jolt. ## Jolt's four components -A VM does two things: +A VM does two things: - Repeatedly execute the fetch-decode-execute logic of its instruction set architecture. - Perform reads and writes to Random Access Memory (RAM). @@ -25,13 +25,13 @@ To handle reads/writes to RAM (and registers) Jolt uses a memory checking argume ### R1CS constraints -To handle the "fetch" part of the fetch-decode-execute loop, there is a minimal R1CS instance (about 60 constraints per cycle of the RISC-V VM). These constraints handle program counter (PC) updates and serves as the "glue" enforcing consistency between polynomials used in the components below. Jolt uses [Spartan](https://eprint.iacr.org/2019/550), optimized for the highly-structured nature of the constraint system (e.g., the R1CS constraint matrices are block-diagonal with blocks of size only about 60 x 80). This is implemented in [jolt-core/src/r1cs](../../../jolt-core/src/r1cs/). +To handle the "fetch" part of the fetch-decode-execute loop, there is a minimal R1CS instance (about 60 constraints per cycle of the RISC-V VM). These constraints handle program counter (PC) updates and serves as the "glue" enforcing consistency between polynomials used in the components below. Jolt uses [Spartan](https://eprint.iacr.org/2019/550), optimized for the highly-structured nature of the constraint system (e.g., the R1CS constraint matrices are block-diagonal with blocks of size only about 60 x 80). This is implemented in [jolt-core/src/r1cs](../../../jolt-core/src/r1cs/). *For more details: [R1CS constraints](./r1cs_constraints.md)* ### Instruction lookups -To handle the "execute" part of the fetch-decode-execute loop, Jolt invokes the Lasso lookup argument. The lookup argument maps every instruction (including its operands) to its output. This is implemented in [instruction_lookups.rs](https://github.com/a16z/jolt/blob/main/jolt-core/src/jolt/vm/instruction_lookups.rs). +To handle the "execute" part of the fetch-decode-execute loop, Jolt invokes the Lasso lookup argument. The lookup argument maps every instruction (including its operands) to its output. This is implemented in [instruction_lookups.rs](https://github.com/a16z/jolt/blob/main/jolt-core/src/jolt/vm/instruction_lookups.rs). *For more details: [Instruction lookups](./instruction_lookups.md)* diff --git a/book/src/how/m-extension.md b/book/src/how/m-extension.md new file mode 100644 index 000000000..4ffd59caf --- /dev/null +++ b/book/src/how/m-extension.md @@ -0,0 +1,117 @@ +# M extension + +Jolt supports the RISC-V "M" extension for integer multiplication and division. +The instructions included in this extension are described [here](https://msyksphinz-self.github.io/riscv-isadoc/html/rvm.html). +For RV32, the M extension includes 8 instructions: `MUL`, `MULH`, `MULHSU`, `MULU`, `DIV`, `DIVU`, `REM`, and `REMU`. + +The [Jolt paper](https://eprint.iacr.org/2023/1217.pdf) describes how to handle the M extension instructions in Section 6, +but our implementation deviates from the paper in a couple ways (described below). + +## Virtual sequences + +Section 6.1 of the Jolt paper introduces virtual instructions and registers –– some of the M extension +instructions cannot be implemented as a single subtable decomposition, but rather must be split into +a sequence of instructions which together compute the output and places it in the destination register. +In our implementation, these sequences are captured by the `VirtualInstructionSequence` trait. + +The instructions that comprise such a sequence can be a combination of "real" RISC-V instructions and "virtual" +instructions which only appear in the context of virtual sequences. +We also introduce 32 virtual registers as "scratch space" where instructions in a virtual sequence +can write intermediate values. + +## Deviations from the Jolt paper + +There are three inconsistencies between the virtual sequences provided in Section 6.3 +of the Jolt paper, and the RISC-V specification. Namely: + +1. The Jolt prover (as described in the paper) would fail to produce a valid proof +if it encountered a division by zero; since the divisor `y` is 0, the `ASSERT_LTU`/`ASSERT_LT_ABS` would +always fail (for `DIVU` and `DIV`, respectively). +1. The MLE provided for `ASSERT_LT_ABS` in Section 6.1.1 doesn't account for two's complement. +1. The `ASSERT_EQ_SIGNS` instruction should always return true if the remainder is 0. + +To address these issues, our implementation of `DIVU`, `DIV`, `REMU`, and `REM` deviate from the +Jolt paper in the following ways. + +### `DIVU` virtual sequence + +1. `ADVICE` --, --, --, $v_q$ `// store non-deterministic advice` $q$ `into `$v_q$ +1. `ADVICE` --, --, --, $v_r$ `// store non-deterministic advice` $r$ `into `$v_r$ +1. `MUL` $v_q$, $r_y$, --, $v_{qy}$ `// compute q * y` +1. `ASSERT_VALID_UNSIGNED_REMAINDER` $v_r$, $r_y$, --, -- `// assert that y == 0 || r < y` +1. `ASSERT_LTE` $v_{qy}$, $r_x$, --, -- `// assert q * y <= x` +1. `ASSERT_VALID_DIV0` $r_y$, $v_q$, --, -- `// assert that y != 0 || q == 2 ** WORD_SIZE - 1` +1. `ADD` $v_{qy}$, $v_r$, --, $v_0$ `// compute q * y + r` +1. `ASSERT_EQ` $v_0$, $x$, --, -- +1. `MOVE` $v_q$, --, --, `rd` + +### `REMU` virtual sequence + +1. `ADVICE` --, --, --, $v_q$ `// store non-deterministic advice` $q$ `into `$v_q$ +1. `ADVICE` --, --, --, $v_r$ `// store non-deterministic advice` $r$ `into `$v_r$ +1. `MUL` $v_q$, $r_y$, --, $v_{qy}$ `// compute q * y` +1. `ASSERT_VALID_UNSIGNED_REMAINDER` $v_r$, $r_y$, --, -- `// assert that y == 0 || r < y` +1. `ASSERT_LTE` $v_{qy}$, $r_x$, --, -- `// assert q * y <= x` +1. `ADD` $v_{qy}$, $v_r$, --, $v_0$ `// compute q * y + r` +1. `ASSERT_EQ` $v_0$, $x$, --, -- +1. `MOVE` $v_r$, --, --, `rd` + +### `DIV` virtual sequence + +1. `ADVICE` --, --, --, $v_q$ `// store non-deterministic advice` $q$ `into `$v_q$ +1. `ADVICE` --, --, --, $v_r$ `// store non-deterministic advice` $r$ `into `$v_r$ +1. `ASSERT_VALID_SIGNED_REMAINDER` $v_r$, $r_y$, --, -- `// assert that r == 0 || y == 0 || (|r| < |y| && sign(r) == sign(y))` +1. `ASSERT_VALID_DIV0` $r_y$, $v_q$, --, -- `// assert that y != 0 || q == 2 ** WORD_SIZE - 1` +1. `MUL` $v_q$, $r_y$, --, $v_{qy}$ `// compute q * y` +1. `ADD` $v_{qy}$, $v_r$, --, $v_0$ `// compute q * y + r` +1. `ASSERT_EQ` $v_0$, $x$, --, -- +1. `MOVE` $v_q$, --, --, `rd` + +### `REM` virtual sequence + +1. `ADVICE` --, --, --, $v_q$ `// store non-deterministic advice` $q$ `into `$v_q$ +1. `ADVICE` --, --, --, $v_r$ `// store non-deterministic advice` $r$ `into `$v_r$ +1. `ASSERT_VALID_SIGNED_REMAINDER` $v_r$, $r_y$, --, -- `// assert that r == 0 || y == 0 || (|r| < |y| && sign(r) == sign(y))` +1. `MUL` $v_q$, $r_y$, --, $v_{qy}$ `// compute q * y` +1. `ADD` $v_{qy}$, $v_r$, --, $v_0$ `// compute q * y + r` +1. `ASSERT_EQ` $v_0$, $x$, --, -- +1. `MOVE` $v_r$, --, --, `rd` + +## R1CS constraints + +### Ciruict flags + +With the M extension we introduce the following circuit flags: + +1. `is_virtual`: Is this instruction part of a virtual sequence? +1. `is_assert`: Is this instruction an `ASSERT_*` instruction? +1. `do_not_update_pc`: If this instruction is virtual and *not the last one in its sequence*, +then we should *not* update the PC. +This is because all instructions in virtual sequences are mapped to the same ELF address. + +### Uniform constraints + +The following constraints are enforced for every step of the execution trace: + +1. If the instruction is a `MUL`, `MULU`, or `MULHU`, the lookup query is the product +of the two operands `x * y` (field multiplication of two 32-bit values). +1. If the instruction is a `MOV` or `MOVSIGN`, the lookup query is a single operand `x` +(read from the first source register `rs1`). +1. If the instruction is an assert, the lookup output must be true. + +### Program counter constraints + +Each instruction in the preprocessed [bytecode](./bytecode.md) contains its (compressed) +memory address as given by the ELF file. +This is used to compute the expected program counter for each step in the program trace. + +If the `do_not_update_pc` flag is set, we constrain the next PC value to be equal to the current one. +This handles the fact that all instructions in virtual sequences are mapped to the same ELF address. + +This also means we need some other mechanism to ensure that virtual sequences are executed in *order* and in *full*. +If the current instruction is virtual, we can constrain the next instruction in the trace to be the +next instruction in the bytecode. +We observe that the virtual sequences used in the M extension don't involve jumps or branches, +so this should always hold, *except* if we encounter a virtual instruction followed by a padding instruction. +But that should never happend because an execution trace should always end with some return handling, +which shouldn't involve a virtual sequence. diff --git a/book/src/how_it_works.md b/book/src/how_it_works.md index 0508d92c6..d083acfae 100644 --- a/book/src/how_it_works.md +++ b/book/src/how_it_works.md @@ -1,5 +1,5 @@ # How it works -- [Jolt](./how/jolt.md) +- [Architecture overview](./how/architecture.md) - [Instruction lookups](./how/instruction_lookups.md) - [Read-write memory](./how/read_write_memory.md) - [R1CS constraints](./how/r1cs_constraints.md) From fad60f6a8c20481b2e83f713b11b8dec272fe314 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Tue, 6 Aug 2024 15:23:54 -0400 Subject: [PATCH 16/17] remove clone --- jolt-core/src/jolt/vm/bytecode.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jolt-core/src/jolt/vm/bytecode.rs b/jolt-core/src/jolt/vm/bytecode.rs index 9e24e2071..007581ac0 100644 --- a/jolt-core/src/jolt/vm/bytecode.rs +++ b/jolt-core/src/jolt/vm/bytecode.rs @@ -222,7 +222,7 @@ impl BytecodePreprocessing { let mut rs2 = vec![]; let mut imm = vec![]; - for instruction in bytecode.clone() { + for instruction in bytecode { address.push(F::from_u64(instruction.address as u64).unwrap()); bitflags.push(F::from_u64(instruction.bitflags).unwrap()); rd.push(F::from_u64(instruction.rd).unwrap()); From 660ebd1dcd87274e40be45242a3cdbd4ddc9bb74 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Fri, 9 Aug 2024 12:58:37 -0400 Subject: [PATCH 17/17] Bump toolchain --- guest-toolchain-tag | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guest-toolchain-tag b/guest-toolchain-tag index 5f35dcf5b..954ec1a91 100644 --- a/guest-toolchain-tag +++ b/guest-toolchain-tag @@ -1 +1 @@ -nightly-3cce1fd56f8c0f705f27f5dfb8f777583bc62e20 +nightly-8af9d45d5e09a04832cc9b2e1df993fd1ce49d02