Skip to content

Commit

Permalink
refactor: Improve mlkem readability
Browse files Browse the repository at this point in the history
  • Loading branch information
lubux committed Sep 24, 2024
1 parent 2e3a702 commit 8be8c23
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions openpgp/mlkem_ecdh/mlkem_ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package mlkem_ecdh

import (
goerrors "errors"
"fmt"
"io"

"github.com/ProtonMail/go-crypto/openpgp/internal/encoding"
Expand All @@ -15,6 +16,11 @@ import (
"github.com/cloudflare/circl/kem"
)

const (
maxSessionKeyLength = 64
kdfContext = "OpenPGPCompositeKDFv1"
)

type PublicKey struct {
AlgId uint8
Curve ecc.ECDHCurve
Expand Down Expand Up @@ -43,9 +49,8 @@ func GenerateKey(rand io.Reader, algId uint8, c ecc.ECDHCurve, k kem.Scheme) (pr
return nil, err
}

kyberSeed := make([]byte, k.SeedSize())

if _, err = rand.Read(kyberSeed); err != nil {
kyberSeed, err := generateRandomSeed(rand, k.SeedSize())
if err != nil {
return nil, err
}

Expand All @@ -56,7 +61,7 @@ func GenerateKey(rand io.Reader, algId uint8, c ecc.ECDHCurve, k kem.Scheme) (pr
// Encrypt implements ML-KEM + ECC encryption as specified in
// https://www.ietf.org/archive/id/draft-ietf-openpgp-pqc-04.html#name-encryption-procedure
func Encrypt(rand io.Reader, pub *PublicKey, msg []byte) (kEphemeral, ecEphemeral, ciphertext []byte, err error) {
if len(msg) > 64 {
if len(msg) > maxSessionKeyLength {
return nil, nil, nil, goerrors.New("mlkem_ecdh: session key too long")
}

Expand All @@ -71,8 +76,7 @@ func Encrypt(rand io.Reader, pub *PublicKey, msg []byte) (kEphemeral, ecEphemera
}

// ML-KEM shared secret derivation
kyberSeed := make([]byte, pub.Mlkem.EncapsulationSeedSize())
_, err = rand.Read(kyberSeed)
kyberSeed, err := generateRandomSeed(rand, pub.Mlkem.EncapsulationSeedSize())
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -82,12 +86,12 @@ func Encrypt(rand io.Reader, pub *PublicKey, msg []byte) (kEphemeral, ecEphemera
return nil, nil, nil, err
}

kek, err := buildKey(pub, ecSS, ecEphemeral, pub.PublicPoint, kSS, kEphemeral, pub.PublicMlkem)
keyEncryptionKey, err := buildKey(pub, ecSS, ecEphemeral, pub.PublicPoint, kSS, kEphemeral, pub.PublicMlkem)
if err != nil {
return nil, nil, nil, err
}

if ciphertext, err = keywrap.Wrap(kek, msg); err != nil {
if ciphertext, err = keywrap.Wrap(keyEncryptionKey, msg); err != nil {
return nil, nil, nil, err
}

Expand Down Expand Up @@ -136,20 +140,20 @@ func buildKey(pub *PublicKey, eccSecretPoint, eccEphemeral, eccPublicKey, mlkemK
// eccData = eccKeyShare || eccCipherText
// mlkemData = mlkemKeyShare || mlkemCipherText
// encData = counter || eccData || mlkemData || fixedInfo
k := sha3.New256()
h.Reset()

// SHA3 never returns error
_, _ = k.Write([]byte{0x00, 0x00, 0x00, 0x01})
_, _ = k.Write(eccKeyShare)
_, _ = k.Write(eccEphemeral)
_, _ = k.Write(eccPublicKey)
_, _ = k.Write(mlkemKeyShare)
_, _ = k.Write(mlkemEphemeral)
_, _ = k.Write(serializedMlkemKey)
_, _ = k.Write([]byte{pub.AlgId})
_, _ = k.Write([]byte("OpenPGPCompositeKDFv1"))

return k.Sum(nil), nil
_, _ = h.Write([]byte{0x00, 0x00, 0x00, 0x01})
_, _ = h.Write(eccKeyShare)
_, _ = h.Write(eccEphemeral)
_, _ = h.Write(eccPublicKey)
_, _ = h.Write(mlkemKeyShare)
_, _ = h.Write(mlkemEphemeral)
_, _ = h.Write(serializedMlkemKey)
_, _ = h.Write([]byte{pub.AlgId})
_, _ = h.Write([]byte(kdfContext))

return h.Sum(nil), nil
}

// Validate checks that the public key corresponds to the private key
Expand Down Expand Up @@ -237,3 +241,11 @@ func DecodeFields(r io.Reader, lenEcc, lenMlkem int, v6 bool) (encryptedMPI1, en

return
}

func generateRandomSeed(rand io.Reader, size int) ([]byte, error) {
randomBytes := make([]byte, size)
if _, err := rand.Read(randomBytes); err != nil {
return nil, fmt.Errorf("failed to generate random bytes: %w", err)
}
return randomBytes, nil
}

0 comments on commit 8be8c23

Please sign in to comment.