Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add parallel g1/g2 msm gnark-crypto impl #217

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 115 additions & 10 deletions gnark/gnark-jni/gnark-eip-2537.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"math/big"
"reflect"
"unsafe"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fp"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
)

const (
Expand Down Expand Up @@ -167,6 +169,54 @@ func eip2537blsG1MultiExp(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cIn
return nonMontgomeryMarshalG1(result, javaOutputBuf, errorBuf)
}

// eip2537blsG1MultiExpParallel computes a G1 multi-scalar multiplication over
// the (point, scalar) records held in javaInputBuf, using gnark-crypto's
// MultiExp.
//
// The input must be a non-empty whole number of records, each made of
// EIP2537PreallocateForG1 bytes of encoded G1 point followed by
// EIP2537PreallocateForScalar bytes of scalar. Every point is decoded with
// g1AffineDecodeInSubGroupVal, so a decoding failure aborts the call.
//
// nbTasks is forwarded unchanged as ecc.MultiExpConfig.NbTasks; the Java
// caller documents zero as "default parallelism".
// NOTE(review): confirm how gnark-crypto treats non-positive NbTasks.
//
// On failure a message is copied into javaErrorBuf (capacity cErrorLen) and 1
// is returned. On success the value returned by nonMontgomeryMarshalG1 is
// returned, which writes the resulting point to javaOutputBuf.
//
//export eip2537blsG1MultiExpParallel
func eip2537blsG1MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int, nbTasks C.int) C.int {
	inputLen := int(cInputLen)
	// The error buffer is sized by cErrorLen; using cOutputLen here would
	// bound error messages by the (unrelated) output buffer size.
	errorLen := int(cErrorLen)

	// Convert error C pointers to Go slices
	errorBuf := castBuffer(javaErrorBuf, errorLen)

	if inputLen == 0 {
		copy(errorBuf, "invalid input parameters, invalid number of pairs\x00")
		return 1
	}

	// Size of one (point, scalar) record; also used as the decoding stride so
	// the length check and the slicing below cannot drift apart.
	pairLen := EIP2537PreallocateForG1 + EIP2537PreallocateForScalar

	if inputLen%pairLen != 0 {
		copy(errorBuf, "invalid input parameters, invalid input length for G1 multiplication\x00")
		return 1
	}

	// Convert input C pointers to Go slice
	input := castBufferToSlice(unsafe.Pointer(javaInputBuf), inputLen)

	exprCount := inputLen / pairLen

	g1Points := make([]bls12381.G1Affine, exprCount)
	scalars := make([]fr.Element, exprCount)

	for i := 0; i < exprCount; i++ {
		pointStart := i * pairLen
		scalarStart := pointStart + EIP2537PreallocateForG1

		// Decode straight into the preallocated slice element.
		if _, err := g1AffineDecodeInSubGroupVal(&g1Points[i], input[pointStart:scalarStart]); err != nil {
			copy(errorBuf, err.Error())
			return 1
		}

		scalars[i].SetBytes(input[scalarStart : pointStart+pairLen])
	}

	var affineResult bls12381.G1Affine
	// The caller-supplied nbTasks is passed through as the parallelism setting.
	if _, err := affineResult.MultiExp(g1Points, scalars, ecc.MultiExpConfig{NbTasks: int(nbTasks)}); err != nil {
		copy(errorBuf, err.Error())
		return 1
	}

	// marshal the resulting point and encode directly to the output buffer
	return nonMontgomeryMarshalG1(&affineResult, javaOutputBuf, errorBuf)
}

//export eip2537blsG2Add
func eip2537blsG2Add(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int {
inputLen := int(cInputLen)
Expand Down Expand Up @@ -289,6 +339,58 @@ func eip2537blsG2MultiExp(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cIn
return nonMontgomeryMarshalG2(result, javaOutputBuf, errorBuf)
}

// eip2537blsG2MultiExpParallel computes a G2 multi-scalar multiplication over
// the (point, scalar) records held in javaInputBuf, using gnark-crypto's
// MultiExp.
//
// The input must be a non-empty whole number of records, each made of
// EIP2537PreallocateForG2 bytes of encoded G2 point followed by
// EIP2537PreallocateForScalar bytes of scalar. Every point is decoded with
// g2AffineDecodeInSubGroupVal, so a decoding failure aborts the call.
//
// nbTasks is forwarded unchanged as ecc.MultiExpConfig.NbTasks; the Java
// caller documents zero as "default parallelism".
// NOTE(review): confirm how gnark-crypto treats non-positive NbTasks.
//
// On failure a message is copied into javaErrorBuf (capacity cErrorLen) and 1
// is returned. On success the value returned by nonMontgomeryMarshalG2 is
// returned, which writes the resulting point to javaOutputBuf.
//
//export eip2537blsG2MultiExpParallel
func eip2537blsG2MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int, nbTasks C.int) C.int {
	inputLen := int(cInputLen)
	// The error buffer is sized by cErrorLen; using cOutputLen here would
	// bound error messages by the (unrelated) output buffer size.
	errorLen := int(cErrorLen)

	// Convert error C pointers to Go slices
	errorBuf := castBuffer(javaErrorBuf, errorLen)

	if inputLen == 0 {
		copy(errorBuf, "invalid input parameters, invalid number of pairs\x00")
		return 1
	}

	// Size of one (point, scalar) record; also used as the decoding stride so
	// the length check and the slicing below cannot drift apart.
	pairLen := EIP2537PreallocateForG2 + EIP2537PreallocateForScalar

	if inputLen%pairLen != 0 {
		copy(errorBuf, "invalid input parameters, invalid input length for G2 multiplication\x00")
		return 1
	}

	// Convert input C pointers to Go slice
	input := castBufferToSlice(unsafe.Pointer(javaInputBuf), inputLen)

	exprCount := inputLen / pairLen

	g2Points := make([]bls12381.G2Affine, exprCount)
	scalars := make([]fr.Element, exprCount)

	for i := 0; i < exprCount; i++ {
		pointStart := i * pairLen
		scalarStart := pointStart + EIP2537PreallocateForG2

		// Decode straight into the preallocated slice element.
		if _, err := g2AffineDecodeInSubGroupVal(&g2Points[i], input[pointStart:scalarStart]); err != nil {
			copy(errorBuf, err.Error())
			return 1
		}

		scalars[i].SetBytes(input[scalarStart : pointStart+pairLen])
	}

	var affineResult bls12381.G2Affine
	// The caller-supplied nbTasks is passed through as the parallelism setting.
	if _, err := affineResult.MultiExp(g2Points, scalars, ecc.MultiExpConfig{NbTasks: int(nbTasks)}); err != nil {
		copy(errorBuf, err.Error())
		return 1
	}

	// marshal the resulting point and encode directly to the output buffer
	return nonMontgomeryMarshalG2(&affineResult, javaOutputBuf, errorBuf)
}





//export eip2537blsPairing
func eip2537blsPairing(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int {
inputLen := int(cInputLen)
Expand Down Expand Up @@ -439,25 +541,24 @@ func hasWrongG1Padding(input []byte) bool {
// hasWrongG2Padding reports whether any of the four 16-byte windows starting
// at offsets 0, 64, 128 and 192 of input fails the isZero check. Windows are
// inspected in ascending order and the scan stops at the first failure.
func hasWrongG2Padding(input []byte) bool {
	for _, start := range [...]int{0, 64, 128, 192} {
		if !isZero(input[start : start+16]) {
			return true
		}
	}
	return false
}


// g1AffineDecodeInSubGroup allocates a fresh G1Affine and hands it, together
// with input, to g1AffineDecodeInSubGroupVal, returning that call's results
// unchanged.
func g1AffineDecodeInSubGroup(input []byte) (*bls12381.G1Affine, error) {
	point := new(bls12381.G1Affine)
	return g1AffineDecodeInSubGroupVal(point, input)
}

func g1AffineDecodeInSubGroupVal(g1 *bls12381.G1Affine, input []byte) (*bls12381.G1Affine, error) {
if hasWrongG1Padding(input) {
return nil, ErrMalformedPointPadding
}
var g1x, g1y fp.Element
err := g1x.SetBytesCanonical(input[16:64])
err := g1.X.SetBytesCanonical(input[16:64])
if err != nil {
return nil, err
}
err = g1y.SetBytesCanonical(input[80:128])
err = g1.Y.SetBytesCanonical(input[80:128])
if err != nil {
return nil, err
}

// construct g1affine directly rather than unmarshalling
g1 := &bls12381.G1Affine{X: g1x, Y: g1y}

// do explicit subgroup check
if (!g1.IsInSubGroup()) {
if (!g1.IsOnCurve()) {
Expand Down Expand Up @@ -493,11 +594,15 @@ func g1AffineDecodeOnCurve(input []byte) (*bls12381.G1Affine, error) {
}

// g2AffineDecodeInSubGroup allocates a fresh G2Affine and hands it, together
// with input, to g2AffineDecodeInSubGroupVal, returning that call's results
// unchanged.
func g2AffineDecodeInSubGroup(input []byte) (*bls12381.G2Affine, error) {
	point := new(bls12381.G2Affine)
	return g2AffineDecodeInSubGroupVal(point, input)
}

func g2AffineDecodeInSubGroupVal(g2 *bls12381.G2Affine, input []byte) (*bls12381.G2Affine, error) {
if hasWrongG2Padding(input) {
return nil, ErrMalformedPointPadding
}

var g2 bls12381.G2Affine
err := g2.X.A0.SetBytesCanonical(input[16:64])
if err != nil {
return nil, err
Expand All @@ -522,7 +627,7 @@ func g2AffineDecodeInSubGroup(input []byte) (*bls12381.G2Affine, error) {
if (!g2.IsInSubGroup()) {
return nil, ErrSubgroupCheckFailed
}
return &g2, nil;
return g2, nil;
}

func g2AffineDecodeOnCurve(input []byte) (*bls12381.G2Affine, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public class LibGnarkEIP2537 implements Library {
@SuppressWarnings("WeakerAccess")
public static final boolean ENABLED;

// zero implies 'default' degree of parallelism, which is the number of cpu cores available
private static int degreeOfMSMParallelism = 0;

static {
boolean enabled;
try {
Expand Down Expand Up @@ -61,9 +64,10 @@ public static int eip2537_perform_operation(
o_len.setValue(128);
break;
case BLS12_G1MULTIEXP_OPERATION_SHIM_VALUE:
ret = eip2537blsG1MultiExp(i, output, err, i_len,
ret = eip2537blsG1MultiExpParallel(i, output, err, i_len,
EIP2537_PREALLOCATE_FOR_RESULT_BYTES,
EIP2537_PREALLOCATE_FOR_ERROR_BYTES);
EIP2537_PREALLOCATE_FOR_ERROR_BYTES,
degreeOfMSMParallelism);
o_len.setValue(128);
break;
case BLS12_G2ADD_OPERATION_SHIM_VALUE:
Expand All @@ -79,9 +83,10 @@ public static int eip2537_perform_operation(
o_len.setValue(256);
break;
case BLS12_G2MULTIEXP_OPERATION_SHIM_VALUE:
ret = eip2537blsG2MultiExp(i, output, err, i_len,
ret = eip2537blsG2MultiExpParallel(i, output, err, i_len,
EIP2537_PREALLOCATE_FOR_RESULT_BYTES,
EIP2537_PREALLOCATE_FOR_ERROR_BYTES);
EIP2537_PREALLOCATE_FOR_ERROR_BYTES,
degreeOfMSMParallelism);
o_len.setValue(256);
break;
case BLS12_PAIR_OPERATION_SHIM_VALUE:
Expand Down Expand Up @@ -134,6 +139,13 @@ public static native int eip2537blsG1MultiExp(
byte[] error,
int inputSize, int output_len, int err_len);

/**
 * Native declaration for the G1 multi-exp entry point that takes an explicit
 * task count. The Go source exports a function with this same name.
 *
 * @param input buffer holding the encoded operation input
 * @param output buffer the native side writes the result into
 * @param error buffer the native side writes an error message into
 * @param inputSize number of bytes of {@code input} to process
 * @param output_len length passed for {@code output}
 * @param err_len length passed for {@code error}
 * @param nbTasks degree of parallelism forwarded to the native multi-exp;
 *     the field feeding it documents zero as the default
 * @return status code from the native call; the Go side returns 1 on its
 *     error paths (NOTE(review): success value comes from the marshal helper
 *     and is not visible here)
 */
public static native int eip2537blsG1MultiExpParallel(
byte[] input,
byte[] output,
byte[] error,
int inputSize, int output_len, int err_len,
int nbTasks);

public static native int eip2537blsG2Add(
byte[] input,
byte[] output,
Expand All @@ -152,6 +164,13 @@ public static native int eip2537blsG2MultiExp(
byte[] error,
int inputSize, int output_len, int err_len);

/**
 * Native declaration for the G2 multi-exp entry point that takes an explicit
 * task count. The Go source exports a function with this same name.
 *
 * @param input buffer holding the encoded operation input
 * @param output buffer the native side writes the result into
 * @param error buffer the native side writes an error message into
 * @param inputSize number of bytes of {@code input} to process
 * @param output_len length passed for {@code output}
 * @param err_len length passed for {@code error}
 * @param nbTasks degree of parallelism forwarded to the native multi-exp;
 *     the field feeding it documents zero as the default
 * @return status code from the native call; the Go side returns 1 on its
 *     error paths (NOTE(review): success value comes from the marshal helper
 *     and is not visible here)
 */
public static native int eip2537blsG2MultiExpParallel(
byte[] input,
byte[] output,
byte[] error,
int inputSize, int output_len, int err_len,
int nbTasks);

public static native int eip2537blsPairing(
byte[] input,
byte[] output,
Expand All @@ -170,4 +189,7 @@ public static native int eip2537blsMapFp2ToG2(
byte[] error,
int inputSize, int output_len, int err_len);

/**
 * Stores the task count that subsequent G1/G2 multi-exp operations pass to
 * the native library.
 *
 * @param nbTasks requested degree of parallelism; zero is documented on the
 *     backing field as the default (number of cpu cores available)
 */
public static void setDegreeOfMSMParallelism(final int nbTasks) {
  LibGnarkEIP2537.degreeOfMSMParallelism = nbTasks;
}
}
Loading