Skip to content

Commit

Permalink
Merge pull request #58 from GiacomoPope/docstrings_and_cleanup
Browse files Browse the repository at this point in the history
Some clean up of the Kyber code inspired by the cleaner ML-KEM code as well as small docstring changes
  • Loading branch information
GiacomoPope authored Jul 23, 2024
2 parents 86ac860 + b5009c0 commit 77d6e48
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 109 deletions.
228 changes: 137 additions & 91 deletions src/kyber_py/kyber/kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@

class Kyber:
def __init__(self, parameter_set, seed=None):
"""
Initialise Kyber with specified lattice parameters.
:param dict parameter_set: the lattice parameters
:param bytes seed: the optional seed for a DRBG, must be unique and
unpredictable
"""
self.k = parameter_set["k"]
self.eta_1 = parameter_set["eta_1"]
self.eta_2 = parameter_set["eta_2"]
Expand All @@ -24,10 +31,18 @@ def __init__(self, parameter_set, seed=None):

def set_drbg_seed(self, seed):
"""
Setting the seed switches the entropy source
from os.urandom to AES256 CTR DRBG
Change entropy source to a DRBG and seed it with provided value.
Setting the seed switches the entropy source from :func:`os.urandom()`
to an AES256 CTR DRBG.
Used for both deterministic versions of Kyber as well as testing
alignment with the KAT vectors
Note: currently requires pycryptodome for AES impl.
Note:
currently requires pycryptodome for AES impl.
:param bytes seed: random bytes to seed the DRBG with
"""
try:
from ..drbg.aes256_ctr_drbg import AES256_CTR_DRBG
Expand All @@ -44,7 +59,10 @@ def reseed_drbg(self, seed):
"""
Reseeds the DRBG, errors if a DRBG is not set.
Note: currently requires pycryptodome for AES impl.
Note:
currently requires pycryptodome for AES impl.
:param bytes seed: random bytes to use as a new seed of the DRBG
"""
if self._drbg is None:
raise Warning(
Expand All @@ -59,14 +77,14 @@ def _xof(bytes32, i, j):
XOF: B^* x B x B -> B*
NOTE:
We use hashlib's `shake_128` implementation, which does not support an
easy XOF interface, so we take the "easy" option and request a fixed
number of 840 bytes (5 invocations of Keccak), rather than creating a
byte stream.
If your code crashes because of too few bytes, you can get dinner at:
Casa de Chá da Boa Nova
https://cryptojedi.org/papers/terminate-20230516.pdf
We use hashlib's ``shake_128`` implementation, which does not support
an easy XOF interface, so we take the "easy" option and request a
fixed number of 840 bytes (5 invocations of Keccak), rather than
creating a byte stream.
If your code crashes because of too few bytes, you can get dinner at:
Casa de Chá da Boa Nova
https://cryptojedi.org/papers/terminate-20230516.pdf
"""
input_bytes = bytes32 + i + j
if len(input_bytes) != 34:
Expand Down Expand Up @@ -109,114 +127,122 @@ def _kdf(input_bytes, length):
"""
return shake_256(input_bytes).digest(length)

def _generate_error_vector(self, sigma, eta, N, is_ntt=False):
def _generate_error_vector(self, sigma, eta, N):
"""
Helper function which generates an element of the
module from the Centered Binomial Distribution.
"""
elements = [0 for _ in range(self.k)]
for i in range(self.k):
input_bytes = self._prf(sigma, bytes([N]), 64 * eta)
elements[i] = self.R.cbd(input_bytes, eta, is_ntt=is_ntt)
elements[i] = self.R.cbd(input_bytes, eta)
N += 1
v = self.M.vector(elements)
return v, N

def _generate_matrix_from_seed(self, rho, transpose=False, is_ntt=False):
def _generate_polynomial(self, sigma, eta, N):
"""
Helper function which generates a element in the
polynomial ring from the Centered Binomial Distribution.
"""
prf_output = self._prf(sigma, bytes([N]), 64 * eta)
p = self.R.cbd(prf_output, eta)
return p, N + 1

def _generate_matrix_from_seed(self, rho, transpose=False):
"""
Helper function which generates a element of size
k x k from a seed `rho`.
Helper function which generates a matrix of size k x k from a seed `rho`
whose coefficients are polynomials in the NTT domain
When `transpose` is set to True, the matrix A is
built as the transpose.
When `transpose` is set to True, the matrix A is built as the transpose.
"""
A_data = [[0 for _ in range(self.k)] for _ in range(self.k)]
for i in range(self.k):
for j in range(self.k):
input_bytes = self._xof(rho, bytes([j]), bytes([i]))
A_data[i][j] = self.R.parse(input_bytes, is_ntt=is_ntt)
A_data[i][j] = self.R.parse(input_bytes, is_ntt=True)
A_hat = self.M(A_data, transpose=transpose)
return A_hat

def _cpapke_keygen(self):
"""
Generate a public key and private key.
Algorithm 4 (Key Generation)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Input:
None
Output:
Secret Key (12*k*n) / 8 bytes
Public Key (12*k*n) / 8 + 32 bytes
:return: Tuple with public key and private key.
:rtype: tuple(bytes, bytes)
"""
# Generate random value, hash and split
d = self.random_bytes(32)
rho, sigma = self._g(d)

# Generate the matrix A ∈ R^kxk
A_hat = self._generate_matrix_from_seed(rho)

# Set counter for PRF
N = 0

# Generate the matrix A ∈ R^kxk
A = self._generate_matrix_from_seed(rho, is_ntt=True)

# Generate the error vector s ∈ R^k
s, N = self._generate_error_vector(sigma, self.eta_1, N)
s = s.to_ntt()
s_hat = s.to_ntt()

# Generate the error vector e ∈ R^k
e, N = self._generate_error_vector(sigma, self.eta_1, N)
e = e.to_ntt()
e_hat = e.to_ntt()

# Construct the public key
t = (A @ s) + e
t_hat = (A_hat @ s_hat) + e_hat

# Reduce vectors mod^+ q
t.reduce_coefficients()
s.reduce_coefficients()
t_hat.reduce_coefficients()
s_hat.reduce_coefficients()

# Encode elements to bytes and return
pk = t.encode(12) + rho
sk = s.encode(12)
pk = t_hat.encode(12) + rho
sk = s_hat.encode(12)
return pk, sk

def _cpapke_enc(self, pk, m, coins):
"""
Algorithm 5 (Encryption)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Input:
pk: public key
m: message ∈ B^32
coins: random coins ∈ B^32
Output:
c: ciphertext
:param bytes pk: byte-encoded public key
:param bytes m: a 32-byte message
:param bytes coins: a 32-byte random value
:return: the ciphertext c
:rtype: bytes
"""
N = 0
rho = pk[-32:]
# Unpack the public key
t_hat_bytes, rho = pk[:-32], pk[-32:]

# Decode t vector from public key
t = self.M.decode_vector(pk, self.k, 12, is_ntt=True)
# Decode t_hat vector from public key
t_hat = self.M.decode_vector(t_hat_bytes, self.k, 12, is_ntt=True)

# Encode message as polynomial
m_poly = self.R.decode(m, 1).decompress(1)

# Generate the matrix A^T ∈ R^(kxk)
At = self._generate_matrix_from_seed(rho, transpose=True, is_ntt=True)
A_hat_T = self._generate_matrix_from_seed(rho, transpose=True)

# Set counter for PRF
N = 0

# Generate the error vector r ∈ R^k
r, N = self._generate_error_vector(coins, self.eta_1, N)
r = r.to_ntt()
r_hat = r.to_ntt()

# Generate the error vector e1 ∈ R^k
e1, N = self._generate_error_vector(coins, self.eta_2, N)

# Generate the error polynomial e2 ∈ R
input_bytes = self._prf(coins, bytes([N]), 64 * self.eta_2)
e2 = self.R.cbd(input_bytes, self.eta_2)
e2, N = self._generate_polynomial(coins, self.eta_2, N)

# Module/Polynomial arithmetic
u = (At @ r).from_ntt() + e1
v = t.dot(r).from_ntt()
u = (A_hat_T @ r_hat).from_ntt() + e1
v = t_hat.dot(r_hat).from_ntt()
v = v + e2 + m_poly

# Ciphertext to bytes
Expand All @@ -230,42 +256,41 @@ def _cpapke_dec(self, sk, c):
Algorithm 6 (Decryption)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Input:
sk: public key
c: message ∈ B^32
Output:
m: message ∈ B^32
:param bytes sk: byte-encoded secret key
:param bytes c: byte-encoded ciphertext
:return: the message m
:rtype: bytes
"""
# Split ciphertext to vectors
index = self.du * self.k * self.R.n // 8
c2 = c[index:]
c1, c2 = c[:index], c[index:]

# Recover the vector u and convert to NTT form
u = self.M.decode_vector(c, self.k, self.du).decompress(self.du)
u = u.to_ntt()
u = self.M.decode_vector(c1, self.k, self.du).decompress(self.du)
u_hat = u.to_ntt()

# Recover the polynomial v
v = self.R.decode(c2, self.dv).decompress(self.dv)

# s_transpose (already in NTT form)
s = self.M.decode_vector(sk, self.k, 12, is_ntt=True)
s_hat = self.M.decode_vector(sk, self.k, 12, is_ntt=True)

# Recover message as polynomial
m = s.dot(u).from_ntt()
m = (s_hat.dot(u_hat)).from_ntt()
m = v - m

# Return message as bytes
return m.compress(1).encode(1)

def keygen(self):
"""
Generate a public key and a secret key.
Algorithm 7 (CCA KEM KeyGen)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Output:
pk: Public key
sk: Secret key
:return: Tuple with public key and secret key.
:rtype: tuple(bytes, bytes)
"""
# Note, although the paper gens z then
# pk, sk, the implementation does it this
Expand All @@ -280,18 +305,19 @@ def keygen(self):

def encaps(self, pk, key_length=32):
"""
Generate a random key, encapsulate it, return both it and ciphertext.
Algorithm 8 (CCA KEM Encapsulation)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Input:
pk: Public Key
Output:
K: Shared key
c: Ciphertext
NOTE:
We switch the order of the output (c, K) as (K, c) to align encaps
output with FIPS 203.
output with FIPS 203-ipd.
:param bytes pk: byte-encoded public key
:param int key_length: length in bytes of the shared key, default value 32
:return: a random key and an encapsulation of it
:rtype: tuple(bytes, bytes)
"""
# Compute random message
m = self.random_bytes(32)
Expand All @@ -300,39 +326,59 @@ def encaps(self, pk, key_length=32):
m_hash = self._h(m)

# Compute the pre-key and the random coins r used for encryption
Kbar, r = self._g(m_hash + self._h(pk))
K_bar, r = self._g(m_hash + self._h(pk))

# Perform the underlying pke encryption
c = self._cpapke_enc(pk, m_hash, r)
K = self._kdf(Kbar + self._h(c), key_length)

# Derive a key from the ciphertext
K = self._kdf(K_bar + self._h(c), key_length)

return K, c

def _unpack_secret_key(self, sk):
"""
Extract values from byte encoded secret key:
sk = _sk || pk || H(pk) || z
"""
index = 12 * self.k * self.R.n // 8

sk_pke = sk[:index]
pk_pke = sk[index:-64]
pk_hash = sk[-64:-32]
z = sk[-32:]

return sk_pke, pk_pke, pk_hash, z

def decaps(self, c, sk, key_length=32):
"""
Decapsulate a key from a ciphertext using a secret key.
Algorithm 9 (CCA KEM Decapsulation)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Input:
c: ciphertext
sk: Secret Key
Output:
K: Shared key
:param bytes c: ciphertext with an encapsulated key
:param bytes sk: secret key
:param int key_length: length in bytes of the shared key, default value 32
:return: shared key
:rtype: bytes
"""
# Extract values from `sk`
# sk = _sk || pk || H(pk) || z
index = 12 * self.k * self.R.n // 8
_sk = sk[:index]
pk = sk[index:-64]
hpk = sk[-64:-32]
z = sk[-32:]
sk_pke, pk_pke, pk_hash, z = self._unpack_secret_key(sk)

# Decrypt the ciphertext
_m = self._cpapke_dec(_sk, c)
m = self._cpapke_dec(sk_pke, c)

# Decapsulation
_Kbar, _r = self._g(_m + hpk)
_c = self._cpapke_enc(pk, _m, _r)
K_bar, r = self._g(m + pk_hash)
c_prime = self._cpapke_enc(pk_pke, m, r)

# if decapsulation was successful return K
key = self._kdf(_Kbar + self._h(c), key_length)
key = self._kdf(K_bar + self._h(c), key_length)
garbage = self._kdf(z + self._h(c), key_length)

return select_bytes(garbage, key, c == _c)
# If c != c_prime, return garbage instead of the key
# WARNING: for proper implementations, it is absolutely
# vital that the selection between the key and garbage is
# performed in constant time
return select_bytes(garbage, key, c == c_prime)
Loading

0 comments on commit 77d6e48

Please sign in to comment.