Skip to content

Commit

Permalink
use hashToCurve and remove unsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
klkvr committed Oct 1, 2024
1 parent a1288ae commit 3482bc3
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 120 deletions.
145 changes: 29 additions & 116 deletions examples/rust-multisig/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use alloy::{primitives::U256, providers::ProviderBuilder, sol, sol_types::SolValue};
use blst::{
blst_bendian_from_fp, blst_fp, blst_fp2, blst_fp_from_bendian, blst_keygen, blst_p1,
blst_p1_affine, blst_p1_to_affine, blst_p2, blst_p2_add_or_double, blst_p2_affine,
blst_p2_from_affine, blst_p2_to_affine, blst_scalar, blst_sign_pk_in_g1, blst_sk_to_pk_in_g1,
};
use blst::min_pk::{AggregateSignature, SecretKey, Signature};
use rand::RngCore;
use BLS::G2Point;

Expand All @@ -14,112 +10,41 @@ sol! {
"../../out/BLSMultisig.sol/BLSMultisig.json"
}

impl From<BLS::Fp> for blst_fp {
fn from(value: BLS::Fp) -> Self {
let data = value.abi_encode();
impl From<[u8; 96]> for BLS::G1Point {
fn from(value: [u8; 96]) -> Self {
let mut data = [0u8; 128];
data[16..64].copy_from_slice(&value[0..48]);
data[80..128].copy_from_slice(&value[48..96]);

let mut val = blst_fp::default();
unsafe { blst_fp_from_bendian(&mut val, data[16..].as_ptr()) };

val
}
}

impl From<blst_fp> for BLS::Fp {
fn from(value: blst_fp) -> Self {
let mut data = [0u8; 48];
unsafe { blst_bendian_from_fp(data.as_mut_ptr(), &value) };

Self {
a: U256::from_be_slice(&data[..16]),
b: U256::from_be_slice(&data[16..]),
}
}
}

impl From<BLS::Fp2> for blst_fp2 {
fn from(value: BLS::Fp2) -> Self {
Self {
fp: [value.c0.into(), value.c1.into()],
}
}
}

impl From<blst_fp2> for BLS::Fp2 {
fn from(value: blst_fp2) -> Self {
Self {
c0: value.fp[0].into(),
c1: value.fp[1].into(),
}
}
}

impl From<BLS::G2Point> for blst_p2 {
fn from(value: BLS::G2Point) -> Self {
let b_aff = blst_p2_affine {
x: value.x.into(),
y: value.y.into(),
};

let mut b = blst_p2::default();
unsafe { blst_p2_from_affine(&mut b, &b_aff) };

b
}
}

impl From<blst_p2> for BLS::G2Point {
fn from(value: blst_p2) -> Self {
let mut affine = blst_p2_affine::default();
unsafe { blst_p2_to_affine(&mut affine, &value) };

BLS::G2Point {
x: affine.x.into(),
y: affine.y.into(),
}
BLS::G1Point::abi_decode(&data, false).unwrap()
}
}

impl From<blst_p1> for BLS::G1Point {
fn from(value: blst_p1) -> Self {
let mut affine = blst_p1_affine::default();
unsafe { blst_p1_to_affine(&mut affine, &value) };
impl From<[u8; 192]> for BLS::G2Point {
fn from(value: [u8; 192]) -> Self {
let mut data = [0u8; 256];
data[16..64].copy_from_slice(&value[48..96]);
data[80..128].copy_from_slice(&value[0..48]);
data[144..192].copy_from_slice(&value[144..192]);
data[208..256].copy_from_slice(&value[96..144]);

BLS::G1Point {
x: affine.x.into(),
y: affine.y.into(),
}
BLS::G2Point::abi_decode(&data, false).unwrap()
}
}

/// Generates `num` BLS keys and returns them as a tuple of secret keys and public keys, sorted by public key.
fn generate_keys(num: usize) -> (Vec<blst_scalar>, Vec<BLS::G1Point>) {
fn generate_keys(num: usize) -> (Vec<SecretKey>, Vec<BLS::G1Point>) {
let mut rng = rand::thread_rng();
let mut keys = Vec::with_capacity(num);

for _ in 0..num {
let mut ikm = [0u8; 32];
rng.fill_bytes(&mut ikm);

let key_info: &[u8] = &[];

// secret key
let mut sk = blst_scalar::default();
unsafe {
blst_keygen(
&mut sk,
ikm.as_ptr(),
ikm.len(),
key_info.as_ptr(),
key_info.len(),
)
};

// public key
let mut pk = blst_p1::default();
unsafe { blst_sk_to_pk_in_g1(&mut pk, &sk) }

keys.push((sk, BLS::G1Point::from(pk)));
let sk = SecretKey::key_gen(&ikm, &[]).unwrap();
let pk: BLS::G1Point = sk.sk_to_pk().serialize().into();

keys.push((sk, pk));
}

keys.sort_by(|(_, pk1), (_, pk2)| pk1.cmp(pk2));
Expand All @@ -128,24 +53,20 @@ fn generate_keys(num: usize) -> (Vec<blst_scalar>, Vec<BLS::G1Point>) {
}

/// Signs a message with the provided keys and returns the aggregated signature.
fn sign_message(keys: &[blst_scalar], message: blst_p2) -> G2Point {
let mut signatures = Vec::new();
fn sign_message(keys: &[SecretKey], msg: &[u8]) -> G2Point {
let mut sigs = Vec::new();

// create individual signatures
for key in keys {
let mut sig = blst_p2::default();
unsafe { blst_sign_pk_in_g1(&mut sig, &message, key) };

signatures.push(sig);
let sig = key.sign(msg, b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_", &[]);
sigs.push(sig);
}

// aggregate signatures by adding them
let mut agg_sig = signatures.swap_remove(0);
for sig in signatures {
unsafe { blst_p2_add_or_double(&mut agg_sig, &agg_sig, &sig) };
}
let agg_sig = Signature::from_aggregate(
&AggregateSignature::aggregate(sigs.iter().collect::<Vec<_>>().as_slice(), false).unwrap(),
);

agg_sig.into()
agg_sig.serialize().into()
}

#[tokio::main]
Expand All @@ -160,15 +81,7 @@ pub async fn main() {

let operation = BLSMultisig::Operation::default();

let point: blst_p2 = multisig
.getOperationPoint(operation.clone())
.call()
.await
.unwrap()
._0
.into();

let signature = sign_message(&keys, point);
let signature = sign_message(&keys, &operation.abi_encode());

let receipt = multisig
.verifyAndExecute(BLSMultisig::SignedOperation {
Expand Down
8 changes: 4 additions & 4 deletions src/BLSMultisig.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ pragma solidity ^0.8.23;

import {BLS} from "./sign/BLS.sol";

/// @notice BLS-powered multisignature wallet, demonstrating the use of
/// @notice BLS-powered multisignature wallet, demonstrating the use of
/// aggregated BLS signatures for verification
/// @dev This is for demonstration purposes only, do not use in production. This contract does
/// @dev This is for demonstration purposes only, do not use in production. This contract does
/// not include protection from rogue public-key attacks. You
contract BLSMultisig {
/// @notice Public keys of signers. This may contain a pre-aggregated
/// @notice Public keys of signers. This may contain a pre-aggregated
/// public keys for common sets of signers as well.
mapping(bytes32 => bool) public signers;

Expand Down Expand Up @@ -52,7 +52,7 @@ contract BLSMultisig {

/// @notice Maps an operation to a point on G2 which needs to be signed.
function getOperationPoint(Operation memory op) public view returns (BLS.G2Point memory) {
return BLS.MapFp2ToG2(BLS.Fp2(BLS.Fp(0, 0), BLS.Fp(0, uint256(keccak256(abi.encode(op))))));
return BLS.hashToCurveG2(abi.encode(op));
}

/// @notice Accepts an operation signed by a subset of the signers and executes it
Expand Down
133 changes: 133 additions & 0 deletions src/sign/BLS.sol
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,137 @@ library BLS {
require(success, "MAP_FP2_TO_G2 failed");
return abi.decode(output, (G2Point));
}

/// @notice Computes a point in G2 from a message
/// @dev Uses the eip-2537 precompiles
/// @param message Arbitrarylength byte string to be hashed
/// @return A point in G2
function hashToCurveG2(bytes memory message) internal view returns (G2Point memory) {
// 1. u = hash_to_field(msg, 2)
Fp2[2] memory u = hashToFieldFp2(message, bytes("BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_"));
// 2. Q0 = map_to_curve(u[0])
G2Point memory q0 = MapFp2ToG2(u[0]);
// 3. Q1 = map_to_curve(u[1])
G2Point memory q1 = MapFp2ToG2(u[1]);
// 4. R = Q0 + Q1
return G2Add(q0, q1);
}

/// @notice Computes a field point from a message
/// @dev Follows https://datatracker.ietf.org/doc/html/rfc9380#section-5.2
/// @param message Arbitrarylength byte string to be hashed
/// @param dst The domain separation tag
/// @return Two field points
function hashToFieldFp2(bytes memory message, bytes memory dst) private view returns (Fp2[2] memory) {
// 1. len_in_bytes = count * m * L
// so always 2 * 2 * 64 = 256
uint16 lenInBytes = 256;
// 2. uniform_bytes = expand_message(msg, DST, len_in_bytes)
bytes32[] memory pseudoRandomBytes = expandMsgXmd(message, dst, lenInBytes);
Fp2[2] memory u;
// No loop here saves 800 gas hardcoding offset an additional 300
// 3. for i in (0, ..., count - 1):
// 4. for j in (0, ..., m - 1):
// 5. elm_offset = L * (j + i * m)
// 6. tv = substr(uniform_bytes, elm_offset, HTF_L)
// uint8 HTF_L = 64;
// bytes memory tv = new bytes(64);
// 7. e_j = OS2IP(tv) mod p
// 8. u_i = (e_0, ..., e_(m - 1))
// tv = bytes.concat(pseudo_random_bytes[0], pseudo_random_bytes[1]);
u[0].c0 = _modfield(pseudoRandomBytes[0], pseudoRandomBytes[1]);
u[0].c1 = _modfield(pseudoRandomBytes[2], pseudoRandomBytes[3]);
u[1].c0 = _modfield(pseudoRandomBytes[4], pseudoRandomBytes[5]);
u[1].c1 = _modfield(pseudoRandomBytes[6], pseudoRandomBytes[7]);
// 9. return (u_0, ..., u_(count - 1))
return u;
}

/// @notice Computes a field point from a message
/// @dev Follows https://datatracker.ietf.org/doc/html/rfc9380#section-5.3
/// @dev bytes32[] because len_in_bytes is always a multiple of 32 in our case even 128
/// @param message Arbitrarylength byte string to be hashed
/// @param dst The domain separation tag of at most 255 bytes
/// @param lenInBytes The length of the requested output in bytes
/// @return A field point
function expandMsgXmd(bytes memory message, bytes memory dst, uint16 lenInBytes)
private
pure
returns (bytes32[] memory)
{
// 1. ell = ceil(len_in_bytes / b_in_bytes)
// b_in_bytes seems to be 32 for sha256
// ceil the division
uint256 ell = (lenInBytes - 1) / 32 + 1;

// 2. ABORT if ell > 255 or len_in_bytes > 65535 or len(DST) > 255
require(ell <= 255, "len_in_bytes too large for sha256");
// Not really needed because of parameter type
// require(lenInBytes <= 65535, "len_in_bytes too large");
// no length normalizing via hashing
require(dst.length <= 255, "dst too long");

bytes memory dstPrime = bytes.concat(dst, bytes1(uint8(dst.length)));

// 4. Z_pad = I2OSP(0, s_in_bytes)
// this should be sha256 blocksize so 64 bytes
bytes memory zPad = new bytes(64);

// 5. l_i_b_str = I2OSP(len_in_bytes, 2)
// length in byte string?
bytes2 libStr = bytes2(lenInBytes);

// 6. msg_prime = Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime
bytes memory msgPrime = bytes.concat(zPad, message, libStr, hex"00", dstPrime);

// 7. b_0 = H(msg_prime)
bytes32 b_0 = sha256(msgPrime);

bytes32[] memory b = new bytes32[](ell);

// 8. b_1 = H(b_0 || I2OSP(1, 1) || DST_prime)
b[0] = sha256(bytes.concat(b_0, hex"01", dstPrime));

// 9. for i in (2, ..., ell):
for (uint8 i = 2; i <= ell; i++) {
// 10. b_i = H(strxor(b_0, b_(i - 1)) || I2OSP(i, 1) || DST_prime)
bytes memory tmp = abi.encodePacked(b_0 ^ b[i - 2], i, dstPrime);
b[i - 1] = sha256(tmp);
}
// 11. uniform_bytes = b_1 || ... || b_ell
// 12. return substr(uniform_bytes, 0, len_in_bytes)
// Here we don't need the uniform_bytes because b is already properly formed
return b;
}

// passing two bytes32 instead of bytes memory saves approx 700 gas per call
// Computes the mod against the bls12-381 field modulus
function _modfield(bytes32 _b1, bytes32 _b2) private view returns (Fp memory r) {
(bool success, bytes memory output) = address(0x5).staticcall(
abi.encode(
// arg[0] = base.length
0x40,
// arg[1] = exp.length
0x20,
// arg[2] = mod.length
0x40,
// arg[3] = base.bits @ + 0x60
// places the first 32 bytes of _b1 and the last 32 bytes of _b2
_b1,
_b2,
// arg[4] = exp
// exponent always 1
1,
// arg[5] = mod
// this field_modulus as hex 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787
// we add the 0 prefix so that the result will be exactly 64 bytes
// saves 300 gas per call instead of sending it along every time
// places the first 32 bytes and the last 32 bytes of the field modulus
0x000000000000000000000000000000001a0111ea397fe69a4b1ba7b6434bacd7, // arg[5] = mod
0x64774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab //
)
);
require(success, "MODEXP failed");
return abi.decode(output, (Fp));
}
}

0 comments on commit 3482bc3

Please sign in to comment.