From 3900ca7406a498bb1a25dc31193bc4ef09451172 Mon Sep 17 00:00:00 2001 From: Danno Ferrin Date: Tue, 16 May 2023 11:35:29 -0600 Subject: [PATCH 1/4] Change coordinates from native to nativelib Migrate all maven coordinates that were `.native.` to `.nativelib.`. The old coordinate name presented problems in some auto-module systems for the JPMS. Signed-off-by: Danno Ferrin --- altbn128/build.gradle | 2 +- arithmetic/build.gradle | 2 +- blake2bf/build.gradle | 2 +- bls12-381/build.gradle | 2 +- gnark/build.gradle | 2 +- ipa-multipoint/build.gradle | 2 +- secp256k1/build.gradle | 2 +- secp256r1/build.gradle | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/altbn128/build.gradle b/altbn128/build.gradle index c6256759..452f6ba3 100644 --- a/altbn128/build.gradle +++ b/altbn128/build.gradle @@ -56,7 +56,7 @@ jar { 'Specification-Version': project.version, 'Implementation-Title': archiveBaseName, 'Implementation-Version': project.version, - 'Automatic-Module-Name': 'org.hyperledger.besu.native.altbn128' + 'Automatic-Module-Name': 'org.hyperledger.besu.nativelib.altbn128' ) } } diff --git a/arithmetic/build.gradle b/arithmetic/build.gradle index 8148ab48..bd7d3bf4 100644 --- a/arithmetic/build.gradle +++ b/arithmetic/build.gradle @@ -63,7 +63,7 @@ jar { 'Specification-Version': project.version, 'Implementation-Title': archiveBaseName, 'Implementation-Version': project.version, - 'Automatic-Module-Name': 'org.hyperledger.besu.native.arithmetic' + 'Automatic-Module-Name': 'org.hyperledger.besu.nativelib.arithmetic' ) } } diff --git a/blake2bf/build.gradle b/blake2bf/build.gradle index 6973fae5..14795ce9 100644 --- a/blake2bf/build.gradle +++ b/blake2bf/build.gradle @@ -60,7 +60,7 @@ jar { 'Specification-Version': project.version, 'Implementation-Title': archiveBaseName, 'Implementation-Version': project.version, - 'Automatic-Module-Name': 'org.hyperledger.besu.native.blake2bf' + 'Automatic-Module-Name': 'org.hyperledger.besu.nativelib.blake2bf' ) } } diff --git a/bls12-381/build.gradle b/bls12-381/build.gradle index 58543bc7..49b68016 100644 --- a/bls12-381/build.gradle +++ b/bls12-381/build.gradle @@ -60,7 +60,7 @@ jar { 'Specification-Version': project.version, 'Implementation-Title': archiveBaseName, 'Implementation-Version': project.version, - 'Automatic-Module-Name': 'org.hyperledger.besu.native.bls12_381' + 'Automatic-Module-Name': 'org.hyperledger.besu.nativelib.bls12_381' ) } } diff --git a/gnark/build.gradle b/gnark/build.gradle index 9afb4a1d..af6a887a 100644 --- a/gnark/build.gradle +++ b/gnark/build.gradle @@ -61,7 +61,7 @@ jar { 'Specification-Version': project.version, 'Implementation-Title': archiveBaseName, 'Implementation-Version': project.version, - 'Automatic-Module-Name': 'org.hyperledger.besu.native.gnark' + 'Automatic-Module-Name': 'org.hyperledger.besu.nativelib.gnark' ) } } diff --git a/ipa-multipoint/build.gradle b/ipa-multipoint/build.gradle index fd5f1946..18147441 100644 --- a/ipa-multipoint/build.gradle +++ b/ipa-multipoint/build.gradle @@ -60,7 +60,7 @@ jar { 'Specification-Version': project.version, 'Implementation-Title': archiveBaseName, 'Implementation-Version': project.version, - 'Automatic-Module-Name': 'org.hyperledger.besu.native.ipa.multipoint' + 'Automatic-Module-Name': 'org.hyperledger.besu.nativelib.ipa.multipoint' ) } } diff --git a/secp256k1/build.gradle b/secp256k1/build.gradle index d2d21a39..365df19f 100644 --- a/secp256k1/build.gradle +++ b/secp256k1/build.gradle @@ -56,7 +56,7 @@ jar { 'Specification-Version': project.version, 'Implementation-Title': archiveBaseName, 'Implementation-Version': project.version, - 'Automatic-Module-Name': 'org.hyperledger.besu.native.secp256k1' + 'Automatic-Module-Name': 'org.hyperledger.besu.nativelib.secp256k1' ) } } diff --git a/secp256r1/build.gradle b/secp256r1/build.gradle index 4de82988..93cf5236 100644 --- a/secp256r1/build.gradle +++ b/secp256r1/build.gradle @@ -65,7 +65,7 @@ jar { 'Specification-Version': project.version, 'Implementation-Title': archiveBaseName, 'Implementation-Version': project.version, - 'Automatic-Module-Name': 'org.hyperledger.besu.native.secp256r1' + 'Automatic-Module-Name': 'org.hyperledger.besu.nativelib.secp256r1' ) } } From 45683a130c1dd8ea6ed0b46072a3d4a9e21ef3ec Mon Sep 17 00:00:00 2001 From: Danno Ferrin Date: Tue, 16 May 2023 11:36:57 -0600 Subject: [PATCH 2/4] bump major version because coordinates are changing Signed-off-by: Danno Ferrin --- gradle.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradle.properties b/gradle.properties index 7dab762c..65a82a47 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1 +1 @@ -version=0.7.2-SNAPSHOT +version=0.8.0-SNAPSHOT From a74afc7396c307c8c1634336face94f61ad9be72 Mon Sep 17 00:00:00 2001 From: Danno Ferrin Date: Wed, 17 May 2023 13:59:21 -0600 Subject: [PATCH 3/4] aurora modexp Signed-off-by: Danno Ferrin --- arithmetic/arithmetic/Cargo.toml | 10 +- arithmetic/arithmetic/README.md | 34 + arithmetic/arithmetic/src/arith.rs | 706 +++++++++++++++++ arithmetic/arithmetic/src/lib.rs | 127 +-- arithmetic/arithmetic/src/mpnat.rs | 729 ++++++++++++++++++ .../nativelib/arithmetic/LibArithmetic.java | 32 +- .../arithmetic/TestLibArithmetic.java | 61 -- 7 files changed, 1529 insertions(+), 170 deletions(-) create mode 100644 arithmetic/arithmetic/README.md create mode 100644 arithmetic/arithmetic/src/arith.rs create mode 100644 arithmetic/arithmetic/src/mpnat.rs diff --git a/arithmetic/arithmetic/Cargo.toml b/arithmetic/arithmetic/Cargo.toml index f8deaf01..0c612e32 100644 --- a/arithmetic/arithmetic/Cargo.toml +++ b/arithmetic/arithmetic/Cargo.toml @@ -1,15 +1,17 @@ [package] name = "besu-native-arithmetic" version = "0.11.0" -description = "Native arithemetic for EVM." +description = """Native arithemetic for EVM. +Derived from aurora - https://github.com/aurora-is-near/aurora-engine/tree/d1af9f8c42ac37d22a770adf43e2b793dd20a345/engine-modexp - originally CC0-1.0 license.""" license = "Apache-2.0" -authors = ["Danno Ferrin "] +authors = ["Aurora Labs ", "Danno Ferrin "] repository = "https://github.com/hyperledger/besu-native" edition = "2021" [dependencies] -num-bigint = "0.4.3" -num-traits = "0.2.15" +ibig = { version = "0.3.6", default-features = false, features = ["num-traits"], optional = true } +num = { version = "0.4.0", default-features = false, features = ["alloc"] } +hex = { version = "0.4", default-features = false, features = ["alloc"] } libc = "0.2" [lib] diff --git a/arithmetic/arithmetic/README.md b/arithmetic/arithmetic/README.md new file mode 100644 index 00000000..5a3fa210 --- /dev/null +++ b/arithmetic/arithmetic/README.md @@ -0,0 +1,34 @@ +# Besu native `modexp` + +Originally from Aurora `modexp` [implementation](https://github.com/aurora-is-near/aurora-engine/tree/d1af9f8c42ac37d22a770adf43e2b793dd20a345/engine-modexp) + +## What this crate is + +This crate is an efficient implementation of the EVM `modexp` precompile. +This crate exposes a single public function + +```rust +pub fn modexp(base: &[u8], exp: &[u8], modulus: &[u8]) -> Vec +``` + +This function takes the base, exponent and modulus as big-endian encoded bytes and returns the result in big-endian as well. + +This crate is meant to be an efficient implementation, using as little memory as possible (for example, it does not copy the exponent slice). +The exponentiation is done using the ["binary method"](https://en.wikipedia.org/wiki/Exponentiation_by_squaring). +The multiplication steps within the exponentiation use ["Montgomery multiplication"](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication). +In the case of even modulus, Montgomery multiplication does not apply directly. +However we can reduce the problem to one involving an odd modulus and one where the modulus is a power of two. +These two sub-problems can be solved efficiently (the former using Montgomery multiplication, the latter the modular arithmetic is trivial on a binary computer), +then the results are combined using the [Chinese remainder theorem](https://en.wikipedia.org/wiki/Chinese_remainder_theorem). + +The primary academic references for this implementation are: + +1. [Analyzing and Comparing Montgomery Multiplication Algorithms](https://www.microsoft.com/en-us/research/wp-content/uploads/1996/01/j37acmon.pdf) +2. [Montgomery Reduction with Even Modulus](http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf) +3. [A Cryptographic Library for the Motorola DSP56000](https://link.springer.com/content/pdf/10.1007/3-540-46877-3_21.pdf) +4. [The Art of Computer Programming Volume 2](https://www-cs-faculty.stanford.edu/~knuth/taocp.html) + +## What this crate is NOT + +This crate is not a general purpose big integer library. +If you need anything other than `modexp`, then you should use something like [num-bigint](https://crates.io/crates/num-bigint) or [ibig](https://crates.io/crates/ibig). diff --git a/arithmetic/arithmetic/src/arith.rs b/arithmetic/arithmetic/src/arith.rs new file mode 100644 index 00000000..bc841c7f --- /dev/null +++ b/arithmetic/arithmetic/src/arith.rs @@ -0,0 +1,706 @@ +use crate::{ + mpnat::{DoubleWord, MPNat, Word, BASE, WORD_BITS}, +}; +pub use std::{vec, vec::Vec}; + +// Computes the "Montgomery Product" of two numbers. +// See Coarsely Integrated Operand Scanning (CIOS) Method in +// https://www.microsoft.com/en-us/research/wp-content/uploads/1996/01/j37acmon.pdf +// In short, computes `xy (r^-1) mod n`, where `r = 2^(8*4*s)` and `s` is the number of +// digits needs to represent `n`. `n_prime` has the property that `r(r^(-1)) - nn' = 1`. +// Note: This algorithm only works if `xy < rn` (generally we will either have both `x < n`, `y < n` +// or we will have `x < r`, `y < n`). +pub fn monpro(x: &MPNat, y: &MPNat, n: &MPNat, n_prime: Word, out: &mut [Word]) { + debug_assert!( + n.is_odd(), + "Montgomery multiplication only makes sense with odd modulus" + ); + debug_assert!( + out.len() >= n.digits.len() + 2, + "Output needs 2 extra words over the size needed to represent n" + ); + let s = out.len() - 2; + // Using a range loop as opposed to `out.iter_mut().enumerate().take(s)` + // does make a meaningful performance difference in this case. + #[allow(clippy::needless_range_loop)] + for i in 0..s { + let mut c = 0; + for j in 0..s { + let (prod, carry) = shifted_carrying_mul( + out[j], + x.digits.get(j).copied().unwrap_or(0), + y.digits.get(i).copied().unwrap_or(0), + c, + ); + out[j] = prod; + c = carry; + } + let (sum, carry) = carrying_add(out[s], c, false); + out[s] = sum; + out[s + 1] = carry as Word; + let m = out[0].wrapping_mul(n_prime); + let (_, carry) = shifted_carrying_mul(out[0], m, n.digits.first().copied().unwrap_or(0), 0); + c = carry; + for j in 1..s { + let (prod, carry) = + shifted_carrying_mul(out[j], m, n.digits.get(j).copied().unwrap_or(0), c); + out[j - 1] = prod; + c = carry; + } + let (sum, carry) = carrying_add(out[s], c, false); + out[s - 1] = sum; + out[s] = out[s + 1] + (carry as Word); // overflow impossible at this stage + } + // Result is only in the first s + 1 words of the output. + out[s + 1] = 0; + + // Check if we need to do the final subtraction + for i in (0..=s).rev() { + match out[i].cmp(n.digits.get(i).unwrap_or(&0)) { + core::cmp::Ordering::Less => return, // No subtraction needed + core::cmp::Ordering::Greater => break, + core::cmp::Ordering::Equal => (), + } + } + + let mut b = false; + for (i, out_digit) in out.iter_mut().enumerate().take(s) { + let (diff, borrow) = borrowing_sub(*out_digit, n.digits.get(i).copied().unwrap_or(0), b); + *out_digit = diff; + b = borrow; + } + let (diff, borrow) = borrowing_sub(out[s], 0, b); + out[s] = diff; + + debug_assert!(!borrow, "No borrow needed since out < n"); +} + +// Equivalent to `monpro(x, x, n, n_prime, out)`, but more efficient. +pub fn monsq(x: &MPNat, n: &MPNat, n_prime: Word, out: &mut [Word]) { + debug_assert!( + n.is_odd(), + "Montgomery multiplication only makes sense with odd modulus" + ); + debug_assert!( + x.digits.len() <= n.digits.len(), + "x cannot be larger than n" + ); + debug_assert!( + out.len() > 2 * n.digits.len(), + "Output needs double the digits to hold the value of x^2 plus an extra word" + ); + let s = n.digits.len(); + + big_sq(x, out); + for i in 0..s { + let mut c: Word = 0; + let m = out[i].wrapping_mul(n_prime); + for j in 0..s { + let (prod, carry) = + shifted_carrying_mul(out[i + j], m, n.digits.get(j).copied().unwrap_or(0), c); + out[i + j] = prod; + c = carry; + } + let mut j = i + s; + while c > 0 { + let (sum, carry) = carrying_add(out[j], c, false); + out[j] = sum; + c = carry as Word; + j += 1; + } + } + // Only keep the last `s + 1` digits in `out`. + for i in 0..(s + 1) { + out[i] = out[i + s]; + } + out[(s + 1)..].fill(0); + + // Check if we need to do the final subtraction + for i in (0..=s).rev() { + match out[i].cmp(n.digits.get(i).unwrap_or(&0)) { + core::cmp::Ordering::Less => return, + core::cmp::Ordering::Greater => break, + core::cmp::Ordering::Equal => (), + } + } + + let mut b = false; + for (i, out_digit) in out.iter_mut().enumerate().take(s) { + let (diff, borrow) = borrowing_sub(*out_digit, n.digits.get(i).copied().unwrap_or(0), b); + *out_digit = diff; + b = borrow; + } + let (diff, borrow) = borrowing_sub(out[s], 0, b); + out[s] = diff; + + debug_assert!(!borrow, "No borrow needed since out < n"); +} + +// Given x odd, computes `x^(-1) mod 2^32`. +// See `MODULAR-INVERSE` in https://link.springer.com/content/pdf/10.1007/3-540-46877-3_21.pdf +pub fn mod_inv(x: Word) -> Word { + debug_assert_eq!(x & 1, 1, "Algorithm only valid for odd n"); + + let mut y = 1; + for i in 2..WORD_BITS { + let mask = (1 << i) - 1; + let xy = x.wrapping_mul(y) & mask; + let q = 1 << (i - 1); + if xy >= q { + y += q; + } + } + let xy = x.wrapping_mul(y); + let q = 1 << (WORD_BITS - 1); + if xy >= q { + y += q; + } + y +} + +// Given x odd, computes `x^(-1) mod 2^(WORD_BYTES*out.digits.len())`. +// See `MODULAR-INVERSE` in https://link.springer.com/content/pdf/10.1007/3-540-46877-3_21.pdf +pub fn big_mod_inv(x: &MPNat, out: &mut MPNat, scratch: &mut [Word]) { + let s = out.digits.len(); + out.digits[0] = mod_inv(x.digits[0]); + + for digit_index in 1..s { + for i in 1..WORD_BITS { + let mask = (1 << i) - 1; + big_wrapping_mul(x, out, scratch); + scratch[digit_index] &= mask; + let q = 1 << (i - 1); + if scratch[digit_index] >= q { + out.digits[digit_index] += q; + } + scratch.fill(0); + } + big_wrapping_mul(x, out, scratch); + let q = 1 << (WORD_BITS - 1); + if scratch[digit_index] >= q { + out.digits[digit_index] += q; + } + scratch.fill(0); + } +} + +/// Computes R mod n, where R = 2^(WORD_BITS*k) and k = n.digits.len() +/// Note that if R = qn + r, q must be smaller than 2^WORD_BITS since `2^(WORD_BITS) * n > R` +/// (adding a whole additional word to n is too much). +/// Uses the two most significant digits of n to approximate the quotient, +/// then computes the difference to get the remainder. It is possible that this +/// quotient is too big by 1; we can catch that case by looking for overflow +/// in the subtraction. +pub fn compute_r_mod_n(n: &MPNat, out: &mut [Word]) { + let k = n.digits.len(); + + if k == 1 { + let r = BASE; + let result = r % (n.digits[0] as DoubleWord); + out[0] = result as Word; + return; + } + + debug_assert!(n.is_odd(), "This algorithm only works for odd numbers"); + debug_assert!( + out.len() >= k, + "Output must be able to hold numbers of the same size as n" + ); + + let approx_n = join_as_double(n.digits[k - 1], n.digits[k - 2]); + let approx_q = DoubleWord::MAX / approx_n; + debug_assert!( + approx_q <= (Word::MAX as DoubleWord), + "quotient must fit in a single digit" + ); + let mut approx_q = approx_q as Word; + + loop { + let mut c = 0; + let mut b = false; + for (n_digit, out_digit) in n.digits.iter().zip(out.iter_mut()) { + let (prod, carry) = carrying_mul(approx_q, *n_digit, c); + c = carry; + let (diff, borrow) = borrowing_sub(0, prod, b); + b = borrow; + *out_digit = diff; + } + let (diff, borrow) = borrowing_sub(1, c, b); + if borrow { + // approx_q was too large so `R - approx_q*n` overflowed. + // try again with approx_q -= 1 + approx_q -= 1; + } else { + debug_assert_eq!( + diff, 0, + "R - qn must be smaller than n, hence fit in k digits" + ); + break; + } + } +} + +/// Computes `base ^ exp`, ignoring overflow. +pub fn big_wrapping_pow(base: &MPNat, exp: &[u8], scratch_space: &mut [Word]) -> MPNat { + // Compute result via the "binary method", see Knuth The Art of Computer Programming + let mut result = MPNat { + digits: vec![0; scratch_space.len()], + }; + result.digits[0] = 1; + for &b in exp { + let mut mask: u8 = 1 << 7; + while mask > 0 { + big_wrapping_mul(&result, &result, scratch_space); + result.digits.copy_from_slice(scratch_space); + scratch_space.fill(0); // zero-out the scratch space + if b & mask != 0 { + big_wrapping_mul(&result, base, scratch_space); + result.digits.copy_from_slice(scratch_space); + scratch_space.fill(0); // zero-out the scratch space + } + mask >>= 1; + } + } + result +} + +/// Computes `(x * y) mod 2^(WORD_BITS*out.len())`. +pub fn big_wrapping_mul(x: &MPNat, y: &MPNat, out: &mut [Word]) { + let s = out.len(); + for i in 0..s { + let mut c: Word = 0; + for j in 0..(s - i) { + let (prod, carry) = shifted_carrying_mul( + out[i + j], + x.digits.get(j).copied().unwrap_or(0), + y.digits.get(i).copied().unwrap_or(0), + c, + ); + c = carry; + out[i + j] = prod; + } + } +} + +/// Computes `x^2`, storing the result in `out`. +pub fn big_sq(x: &MPNat, out: &mut [Word]) { + debug_assert!( + out.len() > 2 * x.digits.len(), + "Output needs double the digits to hold the value of x^2" + ); + let s = x.digits.len(); + for i in 0..s { + let (product, carry) = shifted_carrying_mul(out[i + i], x.digits[i], x.digits[i], 0); + out[i + i] = product; + let mut c = carry as DoubleWord; + for j in (i + 1)..s { + let product = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord); + let (product, overflow) = product.overflowing_add(product); + let sum = (out[i + j] as DoubleWord) + product + c; + out[i + j] = sum as Word; + c = (sum >> WORD_BITS) as DoubleWord; + if overflow { + c += BASE; + } + } + let (sum, carry) = carrying_add(out[i + s], c as Word, false); + out[i + s] = sum; + out[i + s + 1] = ((c >> WORD_BITS) as Word) + (carry as Word); + } +} + +// Performs `a <<= shift`, returning the overflow +pub fn in_place_shl(a: &mut [Word], shift: u32) -> Word { + let mut c: Word = 0; + let carry_shift = (WORD_BITS as u32) - shift; + for a_digit in a.iter_mut() { + let carry = a_digit.overflowing_shr(carry_shift).0; + *a_digit = a_digit.overflowing_shl(shift).0 | c; + c = carry; + } + c +} + +// Performs `a >>= shift`, returning the overflow +pub fn in_place_shr(a: &mut [Word], shift: u32) -> Word { + let mut b: Word = 0; + let borrow_shift = (WORD_BITS as u32) - shift; + for a_digit in a.iter_mut().rev() { + let borrow = a_digit.overflowing_shl(borrow_shift).0; + *a_digit = a_digit.overflowing_shr(shift).0 | b; + b = borrow; + } + b +} + +// Performs a += b, returning if there was overflow +pub fn in_place_add(a: &mut [Word], b: &[Word]) -> bool { + debug_assert!(a.len() == b.len()); + + let mut c = false; + for (a_digit, b_digit) in a.iter_mut().zip(b) { + let (sum, carry) = carrying_add(*a_digit, *b_digit, c); + *a_digit = sum; + c = carry; + } + + c +} + +// Performs `a -= xy`, returning the "borrow". +pub fn in_place_mul_sub(a: &mut [Word], x: &[Word], y: Word) -> Word { + debug_assert!(a.len() == x.len()); + + // carry is between -big_digit::MAX and 0, so to avoid overflow we store + // offset_carry = carry + big_digit::MAX + let mut offset_carry = Word::MAX; + + for (a_digit, x_digit) in a.iter_mut().zip(x) { + // We want to calculate sum = x - y * c + carry. + // sum >= -(big_digit::MAX * big_digit::MAX) - big_digit::MAX + // sum <= big_digit::MAX + // Offsetting sum by (big_digit::MAX << big_digit::BITS) puts it in DoubleBigDigit range. + let offset_sum = join_as_double(Word::MAX, *a_digit) - Word::MAX as DoubleWord + + offset_carry as DoubleWord + - ((*x_digit as DoubleWord) * (y as DoubleWord)); + + let new_offset_carry = (offset_sum >> WORD_BITS) as Word; + let new_x = offset_sum as Word; + offset_carry = new_offset_carry; + *a_digit = new_x; + } + + // Return the borrow. + Word::MAX - offset_carry +} + +/// Computes `a + xy + c` where any overflow is captured as the "carry", +/// the second part of the output. The arithmetic in this function is +/// guaranteed to never overflow because even when all 4 variables are +/// equal to `Word::MAX` the output is smaller than `DoubleWord::MAX`. +pub fn shifted_carrying_mul(a: Word, x: Word, y: Word, c: Word) -> (Word, Word) { + let wide = { (a as DoubleWord) + ((x as DoubleWord) * (y as DoubleWord)) + (c as DoubleWord) }; + (wide as Word, (wide >> WORD_BITS) as Word) +} + +/// Computes `xy + c` where any overflow is captured as the "carry", +/// the second part of the output. The arithmetic in this function is +/// guaranteed to never overflow because even when all 3 variables are +/// equal to `Word::MAX` the output is smaller than `DoubleWord::MAX`. +pub fn carrying_mul(x: Word, y: Word, c: Word) -> (Word, Word) { + let wide = { ((x as DoubleWord) * (y as DoubleWord)) + (c as DoubleWord) }; + (wide as Word, (wide >> WORD_BITS) as Word) +} + +// Computes `x + y` with "carry the 1" semantics +pub fn carrying_add(x: Word, y: Word, carry: bool) -> (Word, bool) { + let (a, b) = x.overflowing_add(y); + let (c, d) = a.overflowing_add(carry as Word); + (c, b | d) +} + +// Computes `x - y` with "borrow from your neighbour" semantics +pub fn borrowing_sub(x: Word, y: Word, borrow: bool) -> (Word, bool) { + let (a, b) = x.overflowing_sub(y); + let (c, d) = a.overflowing_sub(borrow as Word); + (c, b | d) +} + +pub fn join_as_double(hi: Word, lo: Word) -> DoubleWord { + DoubleWord::from(lo) | (DoubleWord::from(hi) << WORD_BITS) +} + +#[test] +fn test_monsq() { + check_monsq(1, 31); + check_monsq(6, 31); + // This example is intentionally chosen because 5 * 5 = 25 = 0 mod 25, + // therefore it requires the final subtraction step in the algorithm. + check_monsq(5, 25); + check_monsq(0x1FFF_FFFF_FFFF_FFF0, 0x1FFF_FFFF_FFFF_FFF1); + check_monsq(0x16FF_221F_CB7D, 0x011E_842B_6BAA_5017_EBF2_8293); + check_monsq(0x0A2D_63F5_CFF9, 0x1F3B_3BD9_43EF); + check_monsq( + 0xa6b0ce71a380dea7c83435bc, + 0xc4550871a1cfc67af3e77eceb2ecfce5, + ); + + fn check_monsq(x: u128, n: u128) { + let a = MPNat::from_big_endian(&x.to_be_bytes()); + let m = MPNat::from_big_endian(&n.to_be_bytes()); + let n_prime = Word::MAX - mod_inv(m.digits[0]) + 1; + + let mut output = vec![0; 2 * m.digits.len() + 1]; + monsq(&a, &m, n_prime, &mut output); + let result = MPNat { digits: output }; + + let mut output = vec![0; m.digits.len() + 2]; + monpro(&a, &a, &m, n_prime, &mut output); + let expected = MPNat { digits: output }; + + assert_eq!( + num::BigUint::from_bytes_be(&result.to_big_endian()), + num::BigUint::from_bytes_be(&expected.to_big_endian()), + "{x}^2 failed monsq check" + ); + } +} + +#[test] +fn test_monpro() { + use num::Integer; + + check_monpro(1, 1, 31); + check_monpro(6, 7, 31); + // This example is intentionally chosen because 5 * 7 = 35 = 0 mod 35, + // therefore it requires the final subtraction step in the algorithm. + check_monpro(5, 7, 35); + check_monpro(0x1FFF_FFFF_FFFF_FFF0, 0x1234, 0x1FFF_FFFF_FFFF_FFF1); + check_monpro( + 0x16FF_221F_CB7D, + 0x0C75_8535_434F, + 0x011E_842B_6BAA_5017_EBF2_8293, + ); + check_monpro(0x0A2D_63F5_CFF9, 0x1B21_FF3C_FA8E, 0x1F3B_3BD9_43EF); + + fn check_monpro(x: u128, y: u128, n: u128) { + let a = MPNat::from_big_endian(&x.to_be_bytes()); + let b = MPNat::from_big_endian(&y.to_be_bytes()); + let m = MPNat::from_big_endian(&n.to_be_bytes()); + let n_prime = Word::MAX - mod_inv(m.digits[0]) + 1; + + let mut output = vec![0; m.digits.len() + 2]; + monpro(&a, &b, &m, n_prime, &mut output); + let result = MPNat { digits: output }; + + let r = num::BigInt::from(2).pow((WORD_BITS * m.digits.len()) as u32); + let r_inv = r.extended_gcd(&num::BigInt::from(n as i128)).x; + let r_inv: u128 = r_inv.try_into().unwrap(); + + let expected = (((x * y) % n) * r_inv) % n; + let actual = mp_nat_to_u128(&result); + assert_eq!(actual, expected, "{x}*{y} failed monpro check"); + } +} + +#[test] +fn test_r_mod_n() { + check_r_mod_n(0x01_00_00_00_01); + check_r_mod_n(0x80_00_00_00_01); + check_r_mod_n(0xFFFF_FFFF_FFFF_FFFF); + check_r_mod_n(0x0001_0000_0000_0000_0001); + check_r_mod_n(0x8000_0000_0000_0000_0001); + check_r_mod_n(0xbf2d_c9a3_82c5_6e85_b033_7651); + check_r_mod_n(0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF); + + fn check_r_mod_n(n: u128) { + let x = MPNat::from_big_endian(&n.to_be_bytes()); + let mut out = vec![0; x.digits.len()]; + compute_r_mod_n(&x, &mut out); + let result = mp_nat_to_u128(&MPNat { digits: out }); + let expected = num::BigUint::from(2_u32).pow((WORD_BITS * x.digits.len()) as u32) + % num::BigUint::from(n); + assert_eq!(num::BigUint::from(result), expected); + } +} + +#[test] +fn test_big_mod_inv() { + check_big_mod_inv(0x02_FF_FF_FF); + check_big_mod_inv(0x1234_0000_DDDD_FFFF); + check_big_mod_inv(0x52DA_9A91_F82D_6E17_FDF8_6743_2B58_7917); + + fn check_big_mod_inv(n: u128) { + let x = MPNat::from_big_endian(&n.to_be_bytes()); + let s = x.digits.len(); + let mut result = MPNat { digits: vec![0; s] }; + let mut scratch = vec![0; s]; + big_mod_inv(&x, &mut result, &mut scratch); + let n_inv = mp_nat_to_u128(&result); + if WORD_BITS * s < u128::BITS as usize { + assert_eq!( + n.wrapping_mul(n_inv) % (1 << (WORD_BITS * s)), + 1, + "{n} failed big_mod_inv check" + ); + } else { + assert_eq!(n.wrapping_mul(n_inv), 1, "{n} failed big_mod_inv check"); + } + } +} + +#[test] +fn test_in_place_shl() { + check_in_place_shl(0, 0); + check_in_place_shl(1, 10); + check_in_place_shl(Word::MAX as u128, 5); + check_in_place_shl(u128::MAX, 16); + + fn check_in_place_shl(n: u128, shift: u32) { + let mut x = MPNat::from_big_endian(&n.to_be_bytes()); + in_place_shl(&mut x.digits, shift); + let result = mp_nat_to_u128(&x); + let mask = BASE + .overflowing_pow(x.digits.len() as u32) + .0 + .wrapping_sub(1); + assert_eq!(result, n.overflowing_shl(shift).0 & mask); + } +} + +#[test] +fn test_in_place_shr() { + check_in_place_shr(0, 0); + check_in_place_shr(1, 10); + check_in_place_shr(0x1234_5678, 10); + check_in_place_shr(Word::MAX as u128, 5); + check_in_place_shr(u128::MAX, 16); + + fn check_in_place_shr(n: u128, shift: u32) { + let mut x = MPNat::from_big_endian(&n.to_be_bytes()); + in_place_shr(&mut x.digits, shift); + let result = mp_nat_to_u128(&x); + assert_eq!(result, n.overflowing_shr(shift).0); + } +} + +#[test] +fn test_mod_inv() { + for i in 1..1025 { + check_mod_inv(2 * i - 1); + } + for i in 0..1025 { + check_mod_inv(0xFF_FF_FF_FF - 2 * i); + } + + fn check_mod_inv(n: Word) { + let n_inv = mod_inv(n); + assert_eq!(n.wrapping_mul(n_inv), 1, "{n} failed mod_inv check"); + } +} + +#[test] +fn test_big_wrapping_pow() { + check_big_wrapping_pow(1, 1); + check_big_wrapping_pow(10, 2); + check_big_wrapping_pow(2, 32); + check_big_wrapping_pow(2, 64); + check_big_wrapping_pow(2766, 844); + + fn check_big_wrapping_pow(a: u128, b: u32) { + let expected = num::BigUint::from(a).pow(b); + let x = MPNat::from_big_endian(&a.to_be_bytes()); + let y = b.to_be_bytes(); + let mut scratch = vec![0; 1 + (expected.to_bytes_be().len() / crate::mpnat::WORD_BYTES)]; + let result = big_wrapping_pow(&x, &y, &mut scratch); + let result = { + let result = result.to_big_endian(); + num::BigUint::from_bytes_be(&result) + }; + assert_eq!(result, expected, "{a} ^ {b} != {expected}"); + } +} + +#[test] +fn test_big_wrapping_mul() { + check_big_wrapping_mul(0, 0, 1); + check_big_wrapping_mul(1, 1, 1); + check_big_wrapping_mul(7, 6, 1); + check_big_wrapping_mul(Word::MAX.into(), Word::MAX.into(), 2); + check_big_wrapping_mul(Word::MAX.into(), Word::MAX.into(), 1); + check_big_wrapping_mul(DoubleWord::MAX - 5, DoubleWord::MAX - 6, 2); + check_big_wrapping_mul(0xa945_aa5e_429a_6d1a, 0x4072_d45d_3355_237b, 3); + check_big_wrapping_mul( + 0x8ae1_5515_fc92_b1c0_b473_8ce8_6bbf_7218, + 0x43e9_8b77_1f7c_aa93_6c4c_85e9_7fd0_504f, + 3, + ); + + fn check_big_wrapping_mul(a: u128, b: u128, output_digits: usize) { + let expected = (num::BigUint::from(a) * num::BigUint::from(b)) + % num::BigUint::from(2_u32).pow((output_digits * WORD_BITS) as u32); + let x = MPNat::from_big_endian(&a.to_be_bytes()); + let y = MPNat::from_big_endian(&b.to_be_bytes()); + let mut out = vec![0; output_digits]; + big_wrapping_mul(&x, &y, &mut out); + let result = { + let result = MPNat { digits: out }.to_big_endian(); + num::BigUint::from_bytes_be(&result) + }; + assert_eq!(result, expected, "{a}*{b} != {expected}"); + } +} + +#[test] +fn test_big_sq() { + check_big_sq(0); + check_big_sq(1); + check_big_sq(Word::MAX.into()); + check_big_sq(2 * (Word::MAX as u128)); + check_big_sq(0x8e67904953db9a2bf6da64bf8bda866d); + check_big_sq(0x9f8dc1c3fc0bf50fe75ac3bbc03124c9); + check_big_sq(0x9c9a17378f3d064e5eaa80eeb3850cd7); + check_big_sq(0xc7f03fbb1c186c05e54b3ee19106baa4); + check_big_sq(0xcf2025cee03025d247ad190e9366d926); + check_big_sq(u128::MAX); + + fn check_big_sq(a: u128) { + let expected = num::BigUint::from(a).pow(2_u32); + let x = MPNat::from_big_endian(&a.to_be_bytes()); + let mut out = vec![0; 2 * x.digits.len() + 1]; + big_sq(&x, &mut out); + let result = { + let result = MPNat { digits: out }.to_big_endian(); + num::BigUint::from_bytes_be(&result) + }; + assert_eq!(result, expected, "{a}^2 != {expected}"); + } +} + +#[test] +fn test_borrowing_sub() { + assert_eq!(borrowing_sub(0, 0, false), (0, false)); + assert_eq!(borrowing_sub(1, 0, false), (1, false)); + assert_eq!(borrowing_sub(47, 5, false), (42, false)); + assert_eq!(borrowing_sub(101, 7, true), (93, false)); + assert_eq!( + borrowing_sub(0x00_00_01_00, 0x00_00_02_00, false), + (Word::MAX - 0xFF, true) + ); + assert_eq!( + borrowing_sub(0x00_00_01_00, 0x00_00_10_00, true), + (Word::MAX - 0x0F_00, true) + ); +} + +// These examples are correctly stated +#[allow(clippy::mistyped_literal_suffixes)] +#[test] +fn test_shifted_carrying_mul() { + assert_eq!(shifted_carrying_mul(0, 0, 0, 0), (0, 0)); + assert_eq!(shifted_carrying_mul(0, 6, 7, 0), (42, 0)); + assert_eq!(shifted_carrying_mul(0, 6, 7, 8), (50, 0)); + assert_eq!(shifted_carrying_mul(5, 6, 7, 8), (55, 0)); + assert_eq!( + shifted_carrying_mul( + Word::MAX - 0x11, + Word::MAX - 0x1234, + Word::MAX - 0xABCD, + Word::MAX - 0xFF + ), + (0x0C_38_0C_94, Word::MAX - 0xBE00) + ); + assert_eq!( + shifted_carrying_mul(Word::MAX, Word::MAX, Word::MAX, Word::MAX), + (Word::MAX, Word::MAX) + ); +} + +#[cfg(test)] +pub fn mp_nat_to_u128(x: &MPNat) -> u128 { + let mut buf = [0u8; 16]; + let result = x.to_big_endian(); + let k = result.len(); + buf[(16 - k)..].copy_from_slice(&result); + u128::from_be_bytes(buf) +} diff --git a/arithmetic/arithmetic/src/lib.rs b/arithmetic/arithmetic/src/lib.rs index dd367fa1..9fa96330 100644 --- a/arithmetic/arithmetic/src/lib.rs +++ b/arithmetic/arithmetic/src/lib.rs @@ -1,6 +1,9 @@ // Copyright contributors to Hyperledger Besu // SPDX-License-Identifier: Apache-2.0 +mod arith; +mod mpnat; + use std::rc::Rc; use std::io::Write; @@ -9,9 +12,6 @@ use core::{ mem::size_of, }; -use num_bigint::{BigUint}; -use num_traits::{Zero, One}; - #[derive(Debug, Clone)] pub enum RuntimeError { /// Input was a bad format. @@ -31,21 +31,15 @@ pub extern "C" fn modexp_precompiled( let raw_out_i8: &mut [libc::c_char] = unsafe { std::slice::from_raw_parts_mut(o, o_len as usize) }; let mut raw_out: &mut [u8] = unsafe { std::mem::transmute(raw_out_i8) }; + let answer = modexp_precompiled_impl(input); - return match modexp_precompiled_impl(input) { - Ok(result) => { - let written = raw_out.write(result.as_ref()); - if let Ok(bytes_written) = written { - unsafe { *o_len = bytes_written as u32 }; - 0u32 - } else { - 1u32 - } - } - Err(_error) => { - 1u32 - } - }; + let written = raw_out.write(answer.as_slice()); + if let Ok(bytes_written) = written { + unsafe { *o_len = bytes_written as u32 }; + 0u32 + } else { + 1u32 + } } @@ -68,93 +62,56 @@ macro_rules! read_u64_with_overflow { } /// from revm - https://github.com/bluealloy/revm/blob/main/crates/revm_precompiles/src/modexp.rs -fn modexp_precompiled_impl(input: &[u8]) -> Result>, RuntimeError> { - let len = input.len(); +fn modexp_precompiled_impl(input: &[u8]) -> Rc> { let (base_len, base_overflow) = read_u64_with_overflow!(input, 0, 32, u32::MAX as usize); let (exp_len, exp_overflow) = read_u64_with_overflow!(input, 32, 64, u32::MAX as usize); let (mod_len, mod_overflow) = read_u64_with_overflow!(input, 64, 96, u32::MAX as usize); if base_overflow || mod_overflow { - return Ok(Rc::new(Vec::new())); + return Rc::new(Vec::new()); } - let r = if base_len == 0 && mod_len == 0 { - BigUint::zero() - } else { - // set limit for exp overflow - if exp_overflow { - return Ok(Rc::new(Vec::new())); - } - let base_start = 96; - let base_end = base_start + base_len; - let exp_end = base_end + exp_len; - let mod_end = exp_end + mod_len; - - let read_big = |from: usize, to: usize| { - let mut out = vec![0; to - from]; - let from = min(from, len); - let to = min(to, len); - out[..to - from].copy_from_slice(&input[from..to]); - BigUint::from_bytes_be(&out) - }; - - let base = read_big(base_start, base_end); - let exponent = read_big(base_end, exp_end); - let modulus = read_big(exp_end, mod_end); - - if modulus.is_zero() || modulus.is_one() { - BigUint::zero() - } else { - base.modpow(&exponent, &modulus) - } - }; + if base_len == 0 && mod_len == 0 { + return Rc::new(Vec::new()); + } + // set limit for exp overflow + if exp_overflow { + return Rc::new(Vec::new()); + } + let base_start = 96; + let base_end = base_start + base_len; + let exp_end = base_end + exp_len; + let mod_end = exp_end + mod_len; + + let base = &input[base_start..base_end]; + let exponent = &input[base_end..exp_end]; + let modulus = &input[exp_end..mod_end]; + let bytes = modexp(base, exponent, modulus); // write output to given memory, left padded and same length as the modulus. - let bytes = r.to_bytes_be(); // always true except in the case of zero-length modulus, which leads to // output of length and value 1. match bytes.len().cmp(&mod_len) { - Ordering::Equal => Ok(Rc::new(bytes.to_vec())), + Ordering::Equal => Rc::new(bytes.to_vec()), Ordering::Less => { let mut ret = Vec::with_capacity(mod_len); ret.extend(core::iter::repeat(0).take(mod_len - bytes.len())); ret.extend_from_slice(&bytes[..]); - Ok(Rc::new(ret.to_vec())) + Rc::new(ret.to_vec()) } - Ordering::Greater => Ok(Rc::new(Vec::new())), + Ordering::Greater => Rc::new(Vec::new()), } } - -/// Big Integer multiplication only returning the lower 32 bytes of the result. -#[no_mangle] -pub extern "C" fn mul_operation( - term1: *const std::os::raw::c_char, - term1_len: u32, - term2: *const std::os::raw::c_char, - term2_len: u32, - o: *mut std::os::raw::c_char, - o_len: *mut u32, -) -> u32 { - let a_i8: &[libc::c_char] = unsafe { std::slice::from_raw_parts(term1, term1_len as usize) }; - let a: &[u8] = unsafe { std::mem::transmute(a_i8) }; - let b_i8: &[libc::c_char] = unsafe { std::slice::from_raw_parts(term2, term2_len as usize) }; - let b: &[u8] = unsafe { std::mem::transmute(b_i8) }; - - let raw_out_i8: &mut [libc::c_char] = unsafe { std::slice::from_raw_parts_mut(o, o_len as usize) }; - let mut raw_out: &mut [u8] = unsafe { std::mem::transmute(raw_out_i8) }; - - let c = (BigUint::from_bytes_be(a) * BigUint::from_bytes_be(b)).to_bytes_be(); - let result = if c.len() > 32 { c[(c.len() - 32)..].to_vec() } else { c }; - - let written = raw_out.write(result.as_ref()); - if let Ok(bytes_written) = written { - unsafe { *o_len = bytes_written as u32 }; - 0u32 - } else { - 1u32 +// from aurora +/// Computes `(base ^ exp) % modulus`, where all values are given as big-endian +/// encoded bytes. +pub fn modexp(base: &[u8], exp: &[u8], modulus: &[u8]) -> Vec { + let mut x = mpnat::MPNat::from_big_endian(base); + let m = mpnat::MPNat::from_big_endian(modulus); + if m.digits.len() == 1 && m.digits[0] == 0 { + return Vec::new(); } + let result = x.modpow(exp, &m); + result.to_big_endian() } - - - diff --git a/arithmetic/arithmetic/src/mpnat.rs b/arithmetic/arithmetic/src/mpnat.rs new file mode 100644 index 00000000..802d5321 --- /dev/null +++ b/arithmetic/arithmetic/src/mpnat.rs @@ -0,0 +1,729 @@ +use crate::{ + arith::{ + big_mod_inv, big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add, + compute_r_mod_n, in_place_add, in_place_mul_sub, in_place_shl, in_place_shr, + join_as_double, mod_inv, monpro, monsq, + }, +}; + +pub type Word = u64; +pub type DoubleWord = u128; +pub const WORD_BYTES: usize = core::mem::size_of::(); +pub const WORD_BITS: usize = Word::BITS as usize; +pub const BASE: DoubleWord = (Word::MAX as DoubleWord) + 1; + +/// Multi-precision natural number, represented in base `Word::MAX + 1 = 2^WORD_BITS`. +/// The digits are stored in little-endian order, i.e. digits[0] is the least +/// significant digit. +#[derive(Debug)] +pub struct MPNat { + pub digits: Vec, +} + +impl MPNat { + pub fn from_big_endian(bytes: &[u8]) -> Self { + if bytes.is_empty() { + return Self { digits: vec![0] }; + } + // Remainder on division by WORD_BYTES + let r = bytes.len() & (WORD_BYTES - 1); + let n_digits = if r == 0 { + bytes.len() / WORD_BYTES + } else { + // Need an extra digit for the remainder + (bytes.len() / WORD_BYTES) + 1 + }; + let mut digits = vec![0; n_digits]; + // buffer to hold Word-sized slices of the input bytes + let mut buf = [0u8; WORD_BYTES]; + let mut i = n_digits - 1; + if r != 0 { + buf[(WORD_BYTES - r)..].copy_from_slice(&bytes[0..r]); + digits[i] = Word::from_be_bytes(buf); + if i == 0 { + // Special case where there is just one digit + return Self { digits }; + } + i -= 1; + } + let mut j = r; + loop { + let next_j = j + WORD_BYTES; + buf.copy_from_slice(&bytes[j..next_j]); + digits[i] = Word::from_be_bytes(buf); + if i == 0 { + break; + } else { + i -= 1; + j = next_j; + } + } + // throw away leading zeros + while digits.len() > 1 && digits[digits.len() - 1] == 0 { + digits.pop(); + } + Self { digits } + } + + pub fn is_power_of_two(&self) -> bool { + // A multi-precision number is a power of 2 iff exactly one digit + // is a power of 2 and all others are zero. + let mut found_power_of_two = false; + for &d in self.digits.iter() { + let is_p2 = d.is_power_of_two(); + if (!is_p2 && d != 0) || (is_p2 && found_power_of_two) { + return false; + } else if is_p2 { + found_power_of_two = true; + } + } + found_power_of_two + } + + pub fn is_odd(&self) -> bool { + // A binary number is odd iff its lowest order bit is set. + self.digits[0] & 1 == 1 + } + + /// Computes `self ^ exp mod modulus`. `exp` must be given as big-endian bytes. + pub fn modpow(&mut self, exp: &[u8], modulus: &Self) -> Self { + if exp.len() <= core::mem::size_of::() { + let exp_as_number = { + let mut tmp: usize = 0; + for d in exp { + tmp *= 256; + tmp += (*d) as usize; + } + tmp + }; + + if let Some(max_output_digits) = self.digits.len().checked_mul(exp_as_number) { + if modulus.digits.len() > max_output_digits { + // Special case: modulus is larger than `base ^ exp`, so division is not relevant + let mut scratch_space = vec![0; max_output_digits]; + return big_wrapping_pow(self, exp, &mut scratch_space); + } + } + } + + if modulus.is_power_of_two() { + return self.modpow_with_power_of_two(exp, modulus); + } else if modulus.is_odd() { + return self.modpow_montgomery(exp, modulus); + } + + // If the modulus is not a power of two and not an odd number then + // it is a product of some power of two with an odd number. In this + // case we will use the Chinese remainder theorem to get the result. + // See http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf + + let trailing_zeros = modulus.digits.iter().take_while(|x| x == &&0).count(); + let additional_zero_bits = modulus.digits[trailing_zeros].trailing_zeros() as usize; + let power_of_two = { + let mut tmp = MPNat { + digits: vec![0; trailing_zeros + 1], + }; + tmp.digits[trailing_zeros] = 1 << additional_zero_bits; + tmp + }; + let power_of_two_mask = *power_of_two.digits.last().unwrap() - 1; + let odd = { + let num_digits = modulus.digits.len() - trailing_zeros; + let mut tmp = MPNat { + digits: vec![0; num_digits], + }; + if additional_zero_bits > 0 { + tmp.digits[0] = modulus.digits[trailing_zeros] >> additional_zero_bits; + for i in 1..num_digits { + let d = modulus.digits[trailing_zeros + i]; + tmp.digits[i - 1] += + (d & power_of_two_mask) << (WORD_BITS - additional_zero_bits); + tmp.digits[i] = d >> additional_zero_bits; + } + } else { + tmp.digits + .copy_from_slice(&modulus.digits[trailing_zeros..]); + } + while tmp.digits.last() == Some(&0) { + tmp.digits.pop(); + } + tmp + }; + debug_assert!(power_of_two.is_power_of_two(), "Factored out power of two"); + debug_assert!( + odd.is_odd(), + "Remaining number is odd after factoring out powers of two" + ); + debug_assert!( + { + let mut tmp = vec![0; modulus.digits.len()]; + big_wrapping_mul(&power_of_two, &odd, &mut tmp); + tmp == modulus.digits + }, + "modulus is factored" + ); + + let mut base_copy = MPNat { + digits: self.digits.clone(), + }; + let x1 = base_copy.modpow_montgomery(exp, &odd); + let x2 = self.modpow_with_power_of_two(exp, &power_of_two); + + let s = power_of_two.digits.len(); + let mut scratch = vec![0; s]; + let odd_inv = { + let mut tmp = MPNat { digits: vec![0; s] }; + big_mod_inv(&odd, &mut tmp, &mut scratch); + *tmp.digits.last_mut().unwrap() &= power_of_two_mask; + tmp + }; + let diff = { + scratch.fill(0); + let mut b = false; + for (i, scratch_digit) in scratch.iter_mut().enumerate().take(s) { + let (diff, borrow) = borrowing_sub( + x2.digits.get(i).copied().unwrap_or(0), + x1.digits.get(i).copied().unwrap_or(0), + b, + ); + *scratch_digit = diff; + b = borrow; + } + MPNat { digits: scratch } + }; + let y = { + let mut out = vec![0; s]; + big_wrapping_mul(&diff, &odd_inv, &mut out); + *out.last_mut().unwrap() &= power_of_two_mask; + MPNat { digits: out } + }; + + // Re-use allocation for efficiency + let mut digits = diff.digits; + let s = modulus.digits.len(); + digits.fill(0); + digits.resize(s, 0); + big_wrapping_mul(&odd, &y, &mut digits); + let mut c = false; + for (i, out_digit) in digits.iter_mut().enumerate() { + let (sum, carry) = carrying_add(x1.digits.get(i).copied().unwrap_or(0), *out_digit, c); + c = carry; + *out_digit = sum; + } + MPNat { digits } + } + + // Computes `self ^ exp mod modulus` using Montgomery multiplication. + // See https://www.microsoft.com/en-us/research/wp-content/uploads/1996/01/j37acmon.pdf + fn modpow_montgomery(&mut self, exp: &[u8], modulus: &Self) -> Self { + // The montgomery method only works with odd modulus. + debug_assert!(modulus.is_odd()); + + // n_prime satisfies `r * (r^(-1)) - modulus * n' = 1`, where + // `r = 2^(WORD_BITS*modulus.digits.len())`. + let n_prime = Word::MAX - mod_inv(modulus.digits[0]) + 1; + let s = modulus.digits.len(); + + let mut x_bar = MPNat { digits: vec![0; s] }; + // Initialize result as `r mod modulus` (Montgomery form of 1) + compute_r_mod_n(modulus, &mut x_bar.digits); + + // Reduce base mod modulus + self.sub_to_same_size(modulus); + + // Need to compute a_bar = base * r mod modulus; + // First directly multiply base * r to get a 2s-digit number, + // then reduce mod modulus. + let a_bar = { + let mut tmp = MPNat { + digits: vec![0; 2 * s], + }; + big_wrapping_mul(self, &x_bar, &mut tmp.digits); + tmp.sub_to_same_size(modulus); + tmp + }; + + // scratch space for monpro algorithm + let mut scratch = vec![0; 2 * s + 1]; + let monpro_len = s + 2; + + // Use binary method for computing exp, but with monpro as the multiplication + for &b in exp { + let mut mask: u8 = 1 << 7; + while mask > 0 { + monsq(&x_bar, modulus, n_prime, &mut scratch); + x_bar.digits.copy_from_slice(&scratch[0..s]); + scratch.fill(0); + if b & mask != 0 { + monpro( + &x_bar, + &a_bar, + modulus, + n_prime, + &mut scratch[0..monpro_len], + ); + x_bar.digits.copy_from_slice(&scratch[0..s]); + scratch.fill(0); + } + mask >>= 1; + } + } + + // Convert out of Montgomery form by computing monpro with 1 + let one = { + // We'll reuse the memory space from a_bar for efficiency. + let mut digits = a_bar.digits; + digits.fill(0); + digits[0] = 1; + MPNat { digits } + }; + monpro(&x_bar, &one, modulus, n_prime, &mut scratch[0..monpro_len]); + scratch.resize(s, 0); + MPNat { digits: scratch } + } + + fn modpow_with_power_of_two(&mut self, exp: &[u8], modulus: &Self) -> Self { + debug_assert!(modulus.is_power_of_two()); + // We know `modulus` is a power of 2. So reducing is as easy as bit shifting. + // We also know the modulus is non-zero because 0 is not a power of 2. + + // First reduce self to be the same size as the modulus + self.force_same_size(modulus); + // The modulus is a power of 2 but that power may not be a multiple of a whole word. + // We can clear out any higher order bits to fix this. + let modulus_mask = *modulus.digits.last().unwrap() - 1; + *self.digits.last_mut().unwrap() &= modulus_mask; + + // We know that `totient(2^k) = 2^(k-1)`, therefore by Euler's theorem + // we can also reduce the exponent mod `2^(k-1)`. Effectively this means + // throwing away bytes to make `exp` shorter. Note: Euler's theorem only applies + // if the base and modulus are coprime (which in this case means the base is odd). + let exp = if self.is_odd() && (exp.len() > WORD_BYTES * modulus.digits.len()) { + &exp[(exp.len() - WORD_BYTES * modulus.digits.len())..] + } else { + exp + }; + + let mut scratch_space = vec![0; modulus.digits.len()]; + let mut result = big_wrapping_pow(self, exp, &mut scratch_space); + + // The modulus is a power of 2 but that power may not be a multiple of a whole word. + // We can clear out any higher order bits to fix this. + *result.digits.last_mut().unwrap() &= modulus_mask; + + result + } + + /// Makes `self` have the same number of digits as `other` by + /// pushing 0s or dropping higher order digits as needed. + /// This is equivalent to reducing `self` modulo `2^(WORD_BITS*k)` where + /// `k` is the number of digits in `other`. + fn force_same_size(&mut self, other: &Self) { + self.digits.resize(other.digits.len(), 0); + + // This is here to just drive the point home about what the + // invariant is after calling this function. + debug_assert_eq!(self.digits.len(), other.digits.len()); + } + + /// Assumes `self` has more digits than `other`. + /// Makes `self` have the same number of digits as `other` by subtracting off multiples + /// of `other`. This is a partial reduction of `self` modulo `other`, but rather + /// than doing the full division, the goal is simply to make the two numbers have the + /// same number of digits. + fn sub_to_same_size(&mut self, other: &Self) { + // Remove leading zeros before starting + while self.digits.len() > 1 && self.digits.last() == Some(&0) { + self.digits.pop(); + } + + let n = other.digits.len(); + let m = self.digits.len().saturating_sub(n); + if m == 0 { + return; + } + + let other_most_sig = *other.digits.last().unwrap(); + + if self.digits.len() == 2 { + // This is the smallest case since `n >= 1` and `m > 0` + // implies that `self.digits.len() >= 2`. + // In this case we can use DoubleWord-sized arithmetic + // to get the answer directly. + let self_most_sig = self.digits.pop().unwrap(); + let a = join_as_double(self_most_sig, self.digits[0]); + let b = other_most_sig as DoubleWord; + self.digits[0] = (a % b) as Word; + return; + } + + if n == 1 { + // The divisor is only 1 digit, so the long-division + // algorithm is easy. + let k = self.digits.len() - 1; + for j in (0..k).rev() { + let self_most_sig = self.digits.pop().unwrap(); + let self_second_sig = self.digits[j]; + let r = + join_as_double(self_most_sig, self_second_sig) % (other_most_sig as DoubleWord); + self.digits[j] = r as Word; + } + return; + } + + // At this stage we know that `n >= 2` and `self.digits.len() >= 3`. + // The smaller cases are covered in the if-statements above. + + // The algorithm below only works well when the divisor's + // most significant digit is at least `BASE / 2`. + // If it is too small then we "normalize" by multiplying + // both numerator and denominator by a common factor + // and run the algorithm on those numbers. + // See Knuth The Art of Computer Programming vol. 2 section 4.3 for details. + let shift = other_most_sig.leading_zeros(); + if shift > 0 { + // Normalize self + let overflow = in_place_shl(&mut self.digits, shift); + self.digits.push(overflow); + + // Normalize other + let mut normalized = other.digits.clone(); + let overflow = in_place_shl(&mut normalized, shift); + debug_assert_eq!(overflow, 0, "Normalizing modulus cannot overflow"); + debug_assert_eq!( + normalized[n - 1].leading_zeros(), + 0, + "Most significant bit is set" + ); + + // Run algorithm on normalized values + self.sub_to_same_size(&MPNat { digits: normalized }); + + // need to de-normalize to get the correct result + in_place_shr(&mut self.digits, shift); + + return; + } + + let other_second_sig = other.digits[n - 2]; + let mut self_most_sig: Word = 0; + for j in (0..=m).rev() { + let self_second_sig = *self.digits.last().unwrap(); + let self_third_sig = self.digits[self.digits.len() - 2]; + + let (mut q_hat, mut r_hat) = { + let a = join_as_double(self_most_sig, self_second_sig); + let mut q_hat = a / (other_most_sig as DoubleWord); + let mut r_hat = a % (other_most_sig as DoubleWord); + + if q_hat == BASE { + q_hat -= 1; + r_hat += other_most_sig as DoubleWord; + } + + (q_hat as Word, r_hat) + }; + + while r_hat < BASE + && join_as_double(r_hat as Word, self_third_sig) + < (q_hat as DoubleWord) * (other_second_sig as DoubleWord) + { + q_hat -= 1; + r_hat += other_most_sig as DoubleWord; + } + + let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat); + if borrow > self_most_sig { + // q_hat was too large, add back one multiple of the modulus + let carry = in_place_add(&mut self.digits[j..], &other.digits); + debug_assert!( + carry, + "Adding back should cause overflow to cancel the borrow" + ); + borrow -= 1; + } + // Most significant digit of self has been cancelled out + debug_assert_eq!(borrow, self_most_sig); + self_most_sig = self.digits.pop().unwrap(); + } + + self.digits.push(self_most_sig); + debug_assert!(self.digits.len() <= n); + } + + pub fn to_big_endian(&self) -> Vec { + if self.digits.iter().all(|x| x == &0) { + return vec![0]; + } + + // Safety: unwrap is safe since `self.digits` is always non-empty. + let most_sig_bytes: [u8; WORD_BYTES] = self.digits.last().unwrap().to_be_bytes(); + // The most significant digit may not need 4 bytes. + // Only include as many bytes as needed in the output. + let be_initial_bytes = { + let mut tmp: &[u8] = &most_sig_bytes; + while !tmp.is_empty() && tmp[0] == 0 { + tmp = &tmp[1..]; + } + tmp + }; + + let mut result = vec![0u8; (self.digits.len() - 1) * WORD_BYTES + be_initial_bytes.len()]; + result[0..be_initial_bytes.len()].copy_from_slice(be_initial_bytes); + for (i, d) in self.digits.iter().take(self.digits.len() - 1).enumerate() { + let bytes = d.to_be_bytes(); + let j = result.len() - WORD_BYTES * i; + result[(j - WORD_BYTES)..j].copy_from_slice(&bytes); + } + result + } +} + +#[test] +fn test_modpow_even() { + check_modpow_even(3, 5, 500, 243); + check_modpow_even(3, 5, 20, 3); + + check_modpow_even( + 0x2ff4f4df4c518867207c84b57a77aa50, + 0xca83c2925d17c577c9a03598b6f360, + 0xf863d4f17a5405d84814f54c92f803c8, + 0x8d216c9a1fb275ed18eb340ed43cacc0, + ); + check_modpow_even( + 0x13881e1614244c56d15ac01096b070e7, + 0x336df5b4567cbe4c093271dc151e6c72, + 0x7540f399a0b6c220f1fc60d2451a1ff0, + 0x1251d64c552e8f831f5b841d2811f9c1, + ); + check_modpow_even( + 0x774d5b2494a449d8f22b22ea542d4ddf, + 0xd2f602e1688f271853e7794503c2837e, + 0xa80d20ebf75f92192159197b60f36e8e, + 0x3fbbba42489b27fc271fb39f54aae2e1, + ); + check_modpow_even( + 0x756e409cc3583a6b68ae27ccd9eb3d50, + 0x16dafb38a334288954d038bedbddc970, + 0x1f9b2237f09413d1fc44edf9bd02b8bc, + 0x9347445ac61536a402723cd07a3f5a4, + ); + check_modpow_even( + 0x6dcb8405e2cc4dcebee3e2b14861b47d, + 0xe6c1e5251d6d5deb8dddd0198481d671, + 0xe34a31d814536e8b9ff6cc5300000000, + 0xaa86af638386880334694967564d0c3d, + ); + check_modpow_even( + 0x9c12fe4a1a97d17c1e4573247a43b0e5, + 0x466f3e0a2e8846b8c48ecbf612b96412, + 0x710d7b9d5718acff0000000000000000, + 0x569bf65929e71cd10a553a8623bdfc99, + ); + check_modpow_even( + 0x6d018fdeaa408222cb10ff2c36124dcf, + 0x8e35fc05d490bb138f73c2bc284a67a7, + 0x6c237160750d78400000000000000000, + 0x3fe14e11392c6c6be8efe956c965d5af, + ); + + let base: Vec = vec![ + 0x36, 0xAB, 0xD4, 0x52, 0x4E, 0x89, 0xA3, 0x4C, 0x89, 0xC4, 0x20, 0x94, 0x25, 0x47, 0xE1, + 0x2C, 0x7B, 0xE1, + ]; + let exponent: Vec = vec![0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0x17, 0xEA, 0x78]; + let modulus: Vec = vec![ + 0x02, 0xF0, 0x75, 0x8C, 0x6A, 0x04, 0x20, 0x09, 0x55, 0xB6, 0x49, 0xC3, 0x57, 0x22, 0xB8, + 0x00, 0x00, 0x00, 0x00, + ]; + let result = crate::modexp(&base, &exponent, &modulus); + assert_eq!( + result, + vec![2, 63, 79, 118, 41, 54, 235, 9, 115, 212, 107, 110, 173, 181, 157, 104, 208, 97, 1] + ); + + let base = hex::decode("36abd4524e89a34c89c420942547e12c7be1").unwrap(); + let exponent = hex::decode("01000000000517ea78").unwrap(); + let modulus = hex::decode("02f0758c6a04200955b649c35722b800000000").unwrap(); + let result = crate::modexp(&base, &exponent, &modulus); + assert_eq!( + hex::encode(result), + "023f4f762936eb0973d46b6eadb59d68d06101" + ); + + fn check_modpow_even(base: u128, exp: u128, modulus: u128, expected: u128) { + let mut x = MPNat::from_big_endian(&base.to_be_bytes()); + let m = MPNat::from_big_endian(&modulus.to_be_bytes()); + let result = x.modpow(&exp.to_be_bytes(), &m); + let result = crate::arith::mp_nat_to_u128(&result); + assert_eq!(result, expected); + } +} + +#[test] +fn test_modpow_montgomery() { + check_modpow_montgomery(3, 5, 0x9346_9d50_1f74_d1c1, 243); + check_modpow_montgomery(3, 5, 19, 15); + check_modpow_montgomery( + 0x5c4b74ec760dfb021499f5c5e3c69222, + 0x62b2a34b21cf4cc036e880b3fb59fe09, + 0x7b799c4502cd69bde8bb12601ce3ff15, + 0x10c9d9071d0b86d6a59264d2f461200, + ); + check_modpow_montgomery( + 0xadb5ce8589030e3a9112123f4558f69c, + 0xb002827068f05b84a87431a70fb763ab, + 0xc4550871a1cfc67af3e77eceb2ecfce5, + 0x7cb78c0e1c1b43f6412e9d1155ea96d2, + ); + check_modpow_montgomery( + 0x26eb51a5d9bf15a536b6e3c67867b492, + 0xddf007944a79bf55806003220a58cc6, + 0xc96275a80c694a62330872b2690f8773, + 0x23b75090ead913def3a1e0bde863eda7, + ); + check_modpow_montgomery( + 0xb93fa81979e597f548c78f2ecb6800f3, + 0x5fad650044963a271898d644984cb9f0, + 0xbeb60d6bd0439ea39d447214a4f8d3ab, + 0x354e63e6a5e007014acd3e5ea88dc3ad, + ); + check_modpow_montgomery( + 0x1993163e4f578869d04949bc005c878f, + 0x8cb960f846475690259514af46868cf5, + 0x52e104dc72423b534d8e49d878f29e3b, + 0x2aa756846258d5cfa6a3f8b9b181a11c, + ); + + fn check_modpow_montgomery(base: u128, exp: u128, modulus: u128, expected: u128) { + let mut x = MPNat::from_big_endian(&base.to_be_bytes()); + let m = MPNat::from_big_endian(&modulus.to_be_bytes()); + let result = x.modpow_montgomery(&exp.to_be_bytes(), &m); + let result = crate::arith::mp_nat_to_u128(&result); + assert_eq!( + result, expected, + "({base} ^ {exp}) % {modulus} failed check_modpow_montgomery" + ); + } +} + +#[test] +fn test_modpow_with_power_of_two() { + check_modpow_with_power_of_two(3, 2, 1 << 30, 9); + check_modpow_with_power_of_two(3, 5, 1 << 30, 243); + check_modpow_with_power_of_two(3, 1_000_000, 1 << 30, 641836289); + check_modpow_with_power_of_two(3, 1_000_000, 1 << 31, 1715578113); + check_modpow_with_power_of_two(3, 1_000_000, 1 << 32, 3863061761); + check_modpow_with_power_of_two( + 0xabcd_ef01_2345_6789_1111, + 0x1234_5678_90ab_cdef, + 1 << 5, + 17, + ); + check_modpow_with_power_of_two( + 0x3f47_9dc0_d5b9_6003, + 0xa180_e045_e314_8581, + 1 << 118, + 0x0028_3d19_e6cc_b8a0_e050_6abb_b9b1_1a03, + ); + + fn check_modpow_with_power_of_two(base: u128, exp: u128, modulus: u128, expected: u128) { + let mut x = MPNat::from_big_endian(&base.to_be_bytes()); + let m = MPNat::from_big_endian(&modulus.to_be_bytes()); + let result = x.modpow_with_power_of_two(&exp.to_be_bytes(), &m); + let result = crate::arith::mp_nat_to_u128(&result); + assert_eq!(result, expected); + } +} + +#[test] +fn test_sub_to_same_size() { + check_sub_to_same_size(0x10_00_00_00_00, 0xFF_00_00_00); + check_sub_to_same_size(0x10_00_00_00_00, 0x01_00_00_00); + check_sub_to_same_size(0x35_00_00_00_00, 0x01_00_00_00); + check_sub_to_same_size(0xEF_00_00_00_00_00_00, 0x02_FF_FF_FF); + + let n = 10; + let a = 57 + 2 * n + 0x1234_0000_0000 * n + 0x000b_0000_0000_0000_0000 * n; + check_sub_to_same_size(a, n); + + fn check_sub_to_same_size(a: u128, n: u128) { + let mut x = MPNat::from_big_endian(&a.to_be_bytes()); + let y = MPNat::from_big_endian(&n.to_be_bytes()); + x.sub_to_same_size(&y); + assert!(x.digits.len() <= y.digits.len()); + let result = crate::arith::mp_nat_to_u128(&x); + assert_eq!(result % n, a % n, "{a} % {n} failed sub_to_same_size check"); + } +} + +#[test] +fn test_mp_nat_is_odd() { + for n in 0..1025 { + check_is_odd(n); + } + for n in 0xFF_FF_FF_FF_00_00_00_00..0xFF_FF_FF_FF_00_00_04_01 { + check_is_odd(n); + } + + fn check_is_odd(n: u128) { + let mp = MPNat::from_big_endian(&n.to_be_bytes()); + assert_eq!(mp.is_odd(), n % 2 == 1, "{n} failed is_odd test"); + } +} + +#[test] +fn test_mp_nat_is_power_of_two() { + check_is_p2(0, false); + check_is_p2(1, true); + check_is_p2(1327, false); + check_is_p2((1 << 1) + (1 << 35), false); + check_is_p2(1 << 1, true); + check_is_p2(1 << 2, true); + check_is_p2(1 << 3, true); + check_is_p2(1 << 4, true); + check_is_p2(1 << 5, true); + check_is_p2(1 << 31, true); + check_is_p2(1 << 32, true); + check_is_p2(1 << 64, true); + check_is_p2(1 << 65, true); + check_is_p2(1 << 127, true); + + fn check_is_p2(n: u128, expected_result: bool) { + let mp = MPNat::from_big_endian(&n.to_be_bytes()); + assert_eq!( + mp.is_power_of_two(), + expected_result, + "{n} failed is_power_of_two test" + ); + } +} + +#[test] +fn test_mp_nat_be() { + be_round_trip(""); + be_round_trip("00"); + be_round_trip("77"); + be_round_trip("abcd"); + be_round_trip("00000000abcd"); + be_round_trip("abcdef"); + be_round_trip("abcdef00"); + be_round_trip("abcdef0011"); + be_round_trip("abcdef001122"); + be_round_trip("abcdef00112233"); + be_round_trip("abcdef0011223344"); + be_round_trip("abcdef001122334455"); + be_round_trip("abcdef01234567891011121314151617181920"); + + fn be_round_trip(hex_input: &str) { + let bytes = hex::decode(hex_input).unwrap(); + let mp = MPNat::from_big_endian(&bytes); + let output = mp.to_big_endian(); + let hex_output = hex::encode(output); + let trimmed = match hex_input.trim_start_matches('0') { + x if x.is_empty() => "00", + x => x, + }; + assert_eq!(hex_output, trimmed) + } +} diff --git a/arithmetic/src/main/java/org/hyperledger/besu/nativelib/arithmetic/LibArithmetic.java b/arithmetic/src/main/java/org/hyperledger/besu/nativelib/arithmetic/LibArithmetic.java index 61a68586..aefe3099 100644 --- a/arithmetic/src/main/java/org/hyperledger/besu/nativelib/arithmetic/LibArithmetic.java +++ b/arithmetic/src/main/java/org/hyperledger/besu/nativelib/arithmetic/LibArithmetic.java @@ -21,28 +21,20 @@ public class LibArithmetic implements Library { - private LibArithmetic() { + private LibArithmetic() {} - } - - public static final boolean ENABLED; + public static final boolean ENABLED; - static { - boolean enabled; - try { - Native.register(LibArithmetic.class, "eth_arithmetic"); - enabled = true; - } catch (final Exception t) { - enabled = false; - } - ENABLED = enabled; + static { + boolean enabled; + try { + Native.register(LibArithmetic.class, "eth_arithmetic"); + enabled = true; + } catch (final Exception t) { + enabled = false; } + ENABLED = enabled; + } - public static native int mul_operation( - byte[] a, int a_len, byte[] b, int b_len, byte[] o, IntByReference o_len); - - public static native int modexp_precompiled( - byte[] i, int i_len, byte[] o, IntByReference o_len); + public static native int modexp_precompiled(byte[] i, int i_len, byte[] o, IntByReference o_len); } - - diff --git a/arithmetic/src/test/java/org/hyperledger/besu/nativelib/arithmetic/TestLibArithmetic.java b/arithmetic/src/test/java/org/hyperledger/besu/nativelib/arithmetic/TestLibArithmetic.java index 02e09600..3dba00ed 100644 --- a/arithmetic/src/test/java/org/hyperledger/besu/nativelib/arithmetic/TestLibArithmetic.java +++ b/arithmetic/src/test/java/org/hyperledger/besu/nativelib/arithmetic/TestLibArithmetic.java @@ -101,65 +101,4 @@ void testModExp(String inputString, String outputString) { assertThat(result).isEqualTo(output); } - public static Object[][] mulParameters() { - return new Object[][] { - {"03", "02", "06"}, - { - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - "0000000000000000000000000000000000000000000000000000000000000001" - }, - {"17", "00", "00"}, - {"17", "", "00"}, - {"01", "17", "17"}, - { - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - "8000000000000000000000000000000000000000000000000000000000000000", - "8000000000000000000000000000000000000000000000000000000000000000" - }, - { - "8000000000000000000000000000000000000000000000000000000000000000", - "8000000000000000000000000000000000000000000000000000000000000000", - "0000000000000000000000000000000000000000000000000000000000000000" - }, - { - "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - "0000000000000000000000000000000000000000000000000000000000000001" - }, - { - "01234567890abcdef0fedcba0987654321", - "01234567890abcdef0fedcba0987654321", - "4b66dc328828bca88b5309b760ec6bf947034577db029a3acefea12cd7a44a41" - }, - { - "4b66dc328828bca88b5309b760ec6bf947034577db029a3acefea12cd7a44a41", - "7001234567890abcdef0fedcba0987654321", - "b634f14f64e16c7ab623657a91b05f9d7d9a6fb4c2c1c442d000107a5e419561" - } - }; - } - - @MethodSource("mulParameters") - @ParameterizedTest - void testMul(String term1String, String term2String, String outputString) { - Bytes term1 = Bytes.fromHexString(term1String); - Bytes term2 = Bytes.fromHexString(term2String); - Bytes output = Bytes.fromHexString(outputString); - - // byte[] resultArray = new byte[output.size() * 2]; - byte[] resultArray = new byte[64]; - IntByReference resultSize = new IntByReference(resultArray.length); - LibArithmetic.mul_operation( - term1.toArrayUnsafe(), - term1.size(), - term2.toArrayUnsafe(), - term2.size(), - resultArray, - resultSize); - - Bytes result = Bytes.wrap(resultArray, 0, resultSize.getValue()); - - assertThat(result).isEqualTo(output); - } } From 937487cb516441680bca7763a521e558e80addaa Mon Sep 17 00:00:00 2001 From: Danno Ferrin Date: Wed, 2 Aug 2023 17:47:32 -0600 Subject: [PATCH 4/4] aurora modexp impl Port over aurora's modexp implementaiton. Signed-off-by: Danno Ferrin --- README.md | 2 +- arithmetic/arithmetic/src/arith.rs | 94 ++++++++---------- arithmetic/arithmetic/src/mpnat.rs | 147 ++++++++++++++++++++++------- build.sh | 5 +- 4 files changed, 155 insertions(+), 93 deletions(-) diff --git a/README.md b/README.md index b32ba76e..743af57c 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ TBD ### Rust -Rust needs to be installed to compile the altbn128 library. The default way to install it on Linux or OS X is: +Rust needs to be installed to compile the arithmetic and bls12-381 libraries. The default way to install it on Linux or OS X is: ``` curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh diff --git a/arithmetic/arithmetic/src/arith.rs b/arithmetic/arithmetic/src/arith.rs index bc841c7f..5fa3d87e 100644 --- a/arithmetic/arithmetic/src/arith.rs +++ b/arithmetic/arithmetic/src/arith.rs @@ -158,32 +158,6 @@ pub fn mod_inv(x: Word) -> Word { y } -// Given x odd, computes `x^(-1) mod 2^(WORD_BYTES*out.digits.len())`. -// See `MODULAR-INVERSE` in https://link.springer.com/content/pdf/10.1007/3-540-46877-3_21.pdf -pub fn big_mod_inv(x: &MPNat, out: &mut MPNat, scratch: &mut [Word]) { - let s = out.digits.len(); - out.digits[0] = mod_inv(x.digits[0]); - - for digit_index in 1..s { - for i in 1..WORD_BITS { - let mask = (1 << i) - 1; - big_wrapping_mul(x, out, scratch); - scratch[digit_index] &= mask; - let q = 1 << (i - 1); - if scratch[digit_index] >= q { - out.digits[digit_index] += q; - } - scratch.fill(0); - } - big_wrapping_mul(x, out, scratch); - let q = 1 << (WORD_BITS - 1); - if scratch[digit_index] >= q { - out.digits[digit_index] += q; - } - scratch.fill(0); - } -} - /// Computes R mod n, where R = 2^(WORD_BITS*k) and k = n.digits.len() /// Note that if R = qn + r, q must be smaller than 2^WORD_BITS since `2^(WORD_BITS) * n > R` /// (adding a whole additional word to n is too much). @@ -294,14 +268,22 @@ pub fn big_sq(x: &MPNat, out: &mut [Word]) { out[i + i] = product; let mut c = carry as DoubleWord; for j in (i + 1)..s { - let product = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord); - let (product, overflow) = product.overflowing_add(product); - let sum = (out[i + j] as DoubleWord) + product + c; - out[i + j] = sum as Word; - c = (sum >> WORD_BITS) as DoubleWord; + let mut new_c: DoubleWord = 0; + let res = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord); + let (res, overflow) = res.overflowing_add(res); if overflow { - c += BASE; + new_c += BASE; } + let (res, overflow) = (out[i + j] as DoubleWord).overflowing_add(res); + if overflow { + new_c += BASE; + } + let (res, overflow) = res.overflowing_add(c); + if overflow { + new_c += BASE; + } + out[i + j] = res as Word; + c = new_c + ((res >> WORD_BITS) as DoubleWord); } let (sum, carry) = carrying_add(out[i + s], c as Word, false); out[i + s] = sum; @@ -351,6 +333,11 @@ pub fn in_place_add(a: &mut [Word], b: &[Word]) -> bool { pub fn in_place_mul_sub(a: &mut [Word], x: &[Word], y: Word) -> Word { debug_assert!(a.len() == x.len()); + // a -= x*0 leaves a unchanged, so return early + if y == 0 { + return 0; + } + // carry is between -big_digit::MAX and 0, so to avoid overflow we store // offset_carry = carry + big_digit::MAX let mut offset_carry = Word::MAX; @@ -504,31 +491,6 @@ fn test_r_mod_n() { } } -#[test] -fn test_big_mod_inv() { - check_big_mod_inv(0x02_FF_FF_FF); - check_big_mod_inv(0x1234_0000_DDDD_FFFF); - check_big_mod_inv(0x52DA_9A91_F82D_6E17_FDF8_6743_2B58_7917); - - fn check_big_mod_inv(n: u128) { - let x = MPNat::from_big_endian(&n.to_be_bytes()); - let s = x.digits.len(); - let mut result = MPNat { digits: vec![0; s] }; - let mut scratch = vec![0; s]; - big_mod_inv(&x, &mut result, &mut scratch); - let n_inv = mp_nat_to_u128(&result); - if WORD_BITS * s < u128::BITS as usize { - assert_eq!( - n.wrapping_mul(n_inv) % (1 << (WORD_BITS * s)), - 1, - "{n} failed big_mod_inv check" - ); - } else { - assert_eq!(n.wrapping_mul(n_inv), 1, "{n} failed big_mod_inv check"); - } - } -} - #[test] fn test_in_place_shl() { check_in_place_shl(0, 0); @@ -655,6 +617,24 @@ fn test_big_sq() { }; assert_eq!(result, expected, "{a}^2 != {expected}"); } + + /* Test for addition overflows in the big_sq inner loop */ + { + let x = MPNat::from_big_endian(&[ + 0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x40, 0x00, + 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00, + ]); + let mut out = vec![0; 2 * x.digits.len() + 1]; + big_sq(&x, &mut out); + let result = MPNat { digits: out }.to_big_endian(); + let expected = vec![ + 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0xff, 0xff, 0xff, 0xfe, 0x40, 0x00, 0x00, 0x01, 0x90, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xbf, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + assert_eq!(result, expected); + } } #[test] diff --git a/arithmetic/arithmetic/src/mpnat.rs b/arithmetic/arithmetic/src/mpnat.rs index 802d5321..66618a11 100644 --- a/arithmetic/arithmetic/src/mpnat.rs +++ b/arithmetic/arithmetic/src/mpnat.rs @@ -1,10 +1,11 @@ use crate::{ arith::{ - big_mod_inv, big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add, - compute_r_mod_n, in_place_add, in_place_mul_sub, in_place_shl, in_place_shr, - join_as_double, mod_inv, monpro, monsq, + big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add, compute_r_mod_n, + in_place_add, in_place_mul_sub, in_place_shl, in_place_shr, join_as_double, mod_inv, + monpro, monsq, }, }; +pub use std::{vec, vec::Vec}; pub type Word = u64; pub type DoubleWord = u128; @@ -21,6 +22,61 @@ pub struct MPNat { } impl MPNat { + // KoƧ's algorithm for inversion mod 2^k + // https://eprint.iacr.org/2017/411.pdf + fn koc_2017_inverse(aa: &Self, k: usize) -> Self { + debug_assert!(aa.is_odd()); + + let length = k / WORD_BITS; + let mut b = MPNat { + digits: vec![0; length + 1], + }; + b.digits[0] = 1; + + let mut a = MPNat { + digits: aa.digits.clone(), + }; + a.digits.resize(length + 1, 0); + + let mut neg: bool = false; + + let mut res = MPNat { + digits: vec![0; length + 1], + }; + + let (mut wordpos, mut bitpos) = (0, 0); + + for _ in 0..k { + let x = b.digits[0] & 1; + if x != 0 { + if !neg { + // b = a - b + let mut tmp = MPNat { + digits: a.digits.clone(), + }; + in_place_mul_sub(&mut tmp.digits, &b.digits, 1); + b = tmp; + neg = true; + } else { + // b = b - a + in_place_add(&mut b.digits, &a.digits); + } + } + + in_place_shr(&mut b.digits, 1); + + res.digits[wordpos] |= x << bitpos; + + bitpos += 1; + if bitpos == WORD_BITS { + bitpos = 0; + wordpos += 1; + } + } + + res + } + pub fn from_big_endian(bytes: &[u8]) -> Self { if bytes.is_empty() { return Self { digits: vec![0] }; @@ -87,6 +143,10 @@ impl MPNat { /// Computes `self ^ exp mod modulus`. `exp` must be given as big-endian bytes. pub fn modpow(&mut self, exp: &[u8], modulus: &Self) -> Self { + if exp.iter().all(|x| x == &0) { + return Self { digits: vec![1] }; + } + if exp.len() <= core::mem::size_of::() { let exp_as_number = { let mut tmp: usize = 0; @@ -169,14 +229,11 @@ impl MPNat { let x1 = base_copy.modpow_montgomery(exp, &odd); let x2 = self.modpow_with_power_of_two(exp, &power_of_two); + let odd_inv = + Self::koc_2017_inverse(&odd, trailing_zeros * WORD_BITS + additional_zero_bits); + let s = power_of_two.digits.len(); let mut scratch = vec![0; s]; - let odd_inv = { - let mut tmp = MPNat { digits: vec![0; s] }; - big_mod_inv(&odd, &mut tmp, &mut scratch); - *tmp.digits.last_mut().unwrap() &= power_of_two_mask; - tmp - }; let diff = { scratch.fill(0); let mut b = false; @@ -343,7 +400,7 @@ impl MPNat { return; } - let other_most_sig = *other.digits.last().unwrap(); + let other_most_sig = *other.digits.last().unwrap() as DoubleWord; if self.digits.len() == 2 { // This is the smallest case since `n >= 1` and `m > 0` @@ -352,7 +409,7 @@ impl MPNat { // to get the answer directly. let self_most_sig = self.digits.pop().unwrap(); let a = join_as_double(self_most_sig, self.digits[0]); - let b = other_most_sig as DoubleWord; + let b = other_most_sig; self.digits[0] = (a % b) as Word; return; } @@ -364,8 +421,7 @@ impl MPNat { for j in (0..k).rev() { let self_most_sig = self.digits.pop().unwrap(); let self_second_sig = self.digits[j]; - let r = - join_as_double(self_most_sig, self_second_sig) % (other_most_sig as DoubleWord); + let r = join_as_double(self_most_sig, self_second_sig) % other_most_sig; self.digits[j] = r as Word; } return; @@ -380,7 +436,7 @@ impl MPNat { // both numerator and denominator by a common factor // and run the algorithm on those numbers. // See Knuth The Art of Computer Programming vol. 2 section 4.3 for details. - let shift = other_most_sig.leading_zeros(); + let shift = (other_most_sig as Word).leading_zeros(); if shift > 0 { // Normalize self let overflow = in_place_shl(&mut self.digits, shift); @@ -405,34 +461,31 @@ impl MPNat { return; } - let other_second_sig = other.digits[n - 2]; + let other_second_sig = other.digits[n - 2] as DoubleWord; let mut self_most_sig: Word = 0; for j in (0..=m).rev() { let self_second_sig = *self.digits.last().unwrap(); let self_third_sig = self.digits[self.digits.len() - 2]; - let (mut q_hat, mut r_hat) = { - let a = join_as_double(self_most_sig, self_second_sig); - let mut q_hat = a / (other_most_sig as DoubleWord); - let mut r_hat = a % (other_most_sig as DoubleWord); + let a = join_as_double(self_most_sig, self_second_sig); + let mut q_hat = a / other_most_sig; + let mut r_hat = a % other_most_sig; - if q_hat == BASE { + loop { + let a = q_hat * other_second_sig; + let b = join_as_double(r_hat as Word, self_third_sig); + if q_hat >= BASE || a > b { q_hat -= 1; - r_hat += other_most_sig as DoubleWord; + r_hat += other_most_sig; + if BASE <= r_hat { + break; + } + } else { + break; } - - (q_hat as Word, r_hat) - }; - - while r_hat < BASE - && join_as_double(r_hat as Word, self_third_sig) - < (q_hat as DoubleWord) * (other_second_sig as DoubleWord) - { - q_hat -= 1; - r_hat += other_most_sig as DoubleWord; } - let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat); + let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat as Word); if borrow > self_most_sig { // q_hat was too large, add back one multiple of the modulus let carry = in_place_add(&mut self.digits[j..], &other.digits); @@ -655,6 +708,34 @@ fn test_sub_to_same_size() { let result = crate::arith::mp_nat_to_u128(&x); assert_eq!(result % n, a % n, "{a} % {n} failed sub_to_same_size check"); } + + /* Test that borrow equals self_most_sig at end of sub_to_same_size */ + { + let mut x = MPNat::from_big_endian(&[ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02, + 0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x02, 0xc9, 0xd9, 0x12, 0x61, 0x8e, 0xf5, 0x02, 0x2c, + 0xa0, 0x00, 0x00, 0x00, + ]); + let y = MPNat::from_big_endian(&[ + 0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02, 0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x0f, 0x70, 0x00, + 0x00, 0x00, + ]); + x.sub_to_same_size(&y); + } + + /* Additional test for sub_to_same_size q_hat/r_hat adjustment logic */ + { + let mut x = MPNat::from_big_endian(&[ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, + 0xff, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + ]); + let y = MPNat::from_big_endian(&[ + 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, + 0x00, 0x00, + ]); + x.sub_to_same_size(&y); + } } #[test] @@ -726,4 +807,4 @@ fn test_mp_nat_be() { }; assert_eq!(hex_output, trimmed) } -} +} \ No newline at end of file diff --git a/build.sh b/build.sh index 4d8c02e3..9421d093 100755 --- a/build.sh +++ b/build.sh @@ -244,13 +244,14 @@ EOF SWITCHES=$2 # build both architectures - cargo build --lib $SWITCHES --release --target=x86_64-apple-darwin +# cargo build --lib $SWITCHES --release --target=x86_64-apple-darwin cargo build --lib $SWITCHES --release --target=aarch64-apple-darwin lipo -create \ -output target/release/$1.dylib \ - -arch x86_64 target/x86_64-apple-darwin/release/$1.dylib \ -arch arm64 target/aarch64-apple-darwin/release/$1.dylib +# -arch x86_64 target/x86_64-apple-darwin/release/$1.dylib \ + } build_secp256r1() {