From b5009c0f6bfad4599a9d7233d907dc062f2ecfd5 Mon Sep 17 00:00:00 2001 From: Giacomo Pope Date: Tue, 23 Jul 2024 16:41:54 +0100 Subject: [PATCH] update some docstrings and clean up kyber code --- src/kyber_py/kyber/kyber.py | 228 ++++++++++++++++++++-------------- src/kyber_py/ml_kem/ml_kem.py | 55 +++++--- 2 files changed, 174 insertions(+), 109 deletions(-) diff --git a/src/kyber_py/kyber/kyber.py b/src/kyber_py/kyber/kyber.py index c23fb02..1a7ebec 100644 --- a/src/kyber_py/kyber/kyber.py +++ b/src/kyber_py/kyber/kyber.py @@ -6,6 +6,13 @@ class Kyber: def __init__(self, parameter_set, seed=None): + """ + Initialise Kyber with specified lattice parameters. + + :param dict params: the lattice parameters + :param bytes seed: the optional seed for a DRBG, must be unique and + unpredictable + """ self.k = parameter_set["k"] self.eta_1 = parameter_set["eta_1"] self.eta_2 = parameter_set["eta_2"] @@ -24,10 +31,18 @@ def __init__(self, parameter_set, seed=None): def set_drbg_seed(self, seed): """ - Setting the seed switches the entropy source - from os.urandom to AES256 CTR DRBG + Change entropy source to a DRBG and seed it with provided value. + + Setting the seed switches the entropy source from :func:`os.urandom()` + to an AES256 CTR DRBG. + + Used for both deterministic versions of Kyber as well as testing + alignment with the KAT vectors - Note: currently requires pycryptodome for AES impl. + Note: + currently requires pycryptodome for AES impl. + + :param bytes seed: random bytes to seed the DRBG with """ try: from ..drbg.aes256_ctr_drbg import AES256_CTR_DRBG @@ -44,7 +59,10 @@ def reseed_drbg(self, seed): """ Reseeds the DRBG, errors if a DRBG is not set. - Note: currently requires pycryptodome for AES impl. + Note: + currently requires pycryptodome for AES impl. + + :param bytes seed: random bytes to use as a new seed of the DRBG """ if self._drbg is None: raise Warning( @@ -59,14 +77,14 @@ def _xof(bytes32, i, j): XOF: B^* x B x B -> B* NOTE: - We use hashlib's `shake_128` implementation, which does not support an - easy XOF interface, so we take the "easy" option and request a fixed - number of 840 bytes (5 invocations of Keccak), rather than creating a - byte stream. - - If your code crashes because of too few bytes, you can get dinner at: - Casa de Chá da Boa Nova - https://cryptojedi.org/papers/terminate-20230516.pdf + We use hashlib's ``shake_128`` implementation, which does not support + an easy XOF interface, so we take the "easy" option and request a + fixed number of 840 bytes (5 invocations of Keccak), rather than + creating a byte stream. + + If your code crashes because of too few bytes, you can get dinner at: + Casa de Chá da Boa Nova + https://cryptojedi.org/papers/terminate-20230516.pdf """ input_bytes = bytes32 + i + j if len(input_bytes) != 34: @@ -109,7 +127,7 @@ def _kdf(input_bytes, length): """ return shake_256(input_bytes).digest(length) - def _generate_error_vector(self, sigma, eta, N, is_ntt=False): + def _generate_error_vector(self, sigma, eta, N): """ Helper function which generates a element in the module from the Centered Binomial Distribution. @@ -117,66 +135,73 @@ def _generate_error_vector(self, sigma, eta, N, is_ntt=False): elements = [0 for _ in range(self.k)] for i in range(self.k): input_bytes = self._prf(sigma, bytes([N]), 64 * eta) - elements[i] = self.R.cbd(input_bytes, eta, is_ntt=is_ntt) + elements[i] = self.R.cbd(input_bytes, eta) N += 1 v = self.M.vector(elements) return v, N - def _generate_matrix_from_seed(self, rho, transpose=False, is_ntt=False): + def _generate_polynomial(self, sigma, eta, N): + """ + Helper function which generates a element in the + polynomial ring from the Centered Binomial Distribution. + """ + prf_output = self._prf(sigma, bytes([N]), 64 * eta) + p = self.R.cbd(prf_output, eta) + return p, N + 1 + + def _generate_matrix_from_seed(self, rho, transpose=False): """ - Helper function which generates a element of size - k x k from a seed `rho`. + Helper function which generates a matrix of size k x k from a seed `rho` + whose coefficients are polynomials in the NTT domain - When `transpose` is set to True, the matrix A is - built as the transpose. + When `transpose` is set to True, the matrix A is built as the transpose. """ A_data = [[0 for _ in range(self.k)] for _ in range(self.k)] for i in range(self.k): for j in range(self.k): input_bytes = self._xof(rho, bytes([j]), bytes([i])) - A_data[i][j] = self.R.parse(input_bytes, is_ntt=is_ntt) + A_data[i][j] = self.R.parse(input_bytes, is_ntt=True) A_hat = self.M(A_data, transpose=transpose) return A_hat def _cpapke_keygen(self): """ + Generate a public key and private key. + Algorithm 4 (Key Generation) https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf - Input: - None - Output: - Secret Key (12*k*n) / 8 bytes - Public Key (12*k*n) / 8 + 32 bytes + :return: Tuple with public key and private key. + :rtype: tuple(bytes, bytes) """ # Generate random value, hash and split d = self.random_bytes(32) rho, sigma = self._g(d) + # Generate the matrix A ∈ R^kxk + A_hat = self._generate_matrix_from_seed(rho) + # Set counter for PRF N = 0 - # Generate the matrix A ∈ R^kxk - A = self._generate_matrix_from_seed(rho, is_ntt=True) - # Generate the error vector s ∈ R^k s, N = self._generate_error_vector(sigma, self.eta_1, N) - s = s.to_ntt() + s_hat = s.to_ntt() # Generate the error vector e ∈ R^k e, N = self._generate_error_vector(sigma, self.eta_1, N) - e = e.to_ntt() + e_hat = e.to_ntt() # Construct the public key - t = (A @ s) + e + t_hat = (A_hat @ s_hat) + e_hat # Reduce vectors mod^+ q - t.reduce_coefficients() - s.reduce_coefficients() + t_hat.reduce_coefficients() + s_hat.reduce_coefficients() # Encode elements to bytes and return - pk = t.encode(12) + rho - sk = s.encode(12) + pk = t_hat.encode(12) + rho + sk = s_hat.encode(12) return pk, sk def _cpapke_enc(self, pk, m, coins): @@ -184,39 +209,40 @@ def _cpapke_enc(self, pk, m, coins): Algorithm 5 (Encryption) https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf - Input: - pk: public key - m: message ∈ B^32 - coins: random coins ∈ B^32 - Output: - c: ciphertext + :param bytes pk: byte-encoded public key + :param bytes m: a 32-byte message + :param bytes coins: a 32-byte random value + :return: the ciphertext c + :rtype: bytes """ - N = 0 - rho = pk[-32:] + # Unpack the public key + t_hat_bytes, rho = pk[:-32], pk[-32:] - # Decode t vector from public key - t = self.M.decode_vector(pk, self.k, 12, is_ntt=True) + # Decode t_hat vector from public key + t_hat = self.M.decode_vector(t_hat_bytes, self.k, 12, is_ntt=True) # Encode message as polynomial m_poly = self.R.decode(m, 1).decompress(1) # Generate the matrix A^T ∈ R^(kxk) - At = self._generate_matrix_from_seed(rho, transpose=True, is_ntt=True) + A_hat_T = self._generate_matrix_from_seed(rho, transpose=True) + + # Set counter for PRF + N = 0 # Generate the error vector r ∈ R^k r, N = self._generate_error_vector(coins, self.eta_1, N) - r = r.to_ntt() + r_hat = r.to_ntt() # Generate the error vector e1 ∈ R^k e1, N = self._generate_error_vector(coins, self.eta_2, N) # Generate the error polynomial e2 ∈ R - input_bytes = self._prf(coins, bytes([N]), 64 * self.eta_2) - e2 = self.R.cbd(input_bytes, self.eta_2) + e2, N = self._generate_polynomial(coins, self.eta_2, N) # Module/Polynomial arithmetic - u = (At @ r).from_ntt() + e1 - v = t.dot(r).from_ntt() + u = (A_hat_T @ r_hat).from_ntt() + e1 + v = t_hat.dot(r_hat).from_ntt() v = v + e2 + m_poly # Ciphertext to bytes @@ -230,28 +256,27 @@ def _cpapke_dec(self, sk, c): Algorithm 6 (Decryption) https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf - Input: - sk: public key - c: message ∈ B^32 - Output: - m: message ∈ B^32 + :param bytes sk: byte-encoded secret key + :param bytes c: a 32-byte ciphertext + :return: the message m + :rtype: bytes """ # Split ciphertext to vectors index = self.du * self.k * self.R.n // 8 - c2 = c[index:] + c1, c2 = c[:index], c[index:] # Recover the vector u and convert to NTT form - u = self.M.decode_vector(c, self.k, self.du).decompress(self.du) - u = u.to_ntt() + u = self.M.decode_vector(c1, self.k, self.du).decompress(self.du) + u_hat = u.to_ntt() # Recover the polynomial v v = self.R.decode(c2, self.dv).decompress(self.dv) # s_transpose (already in NTT form) - s = self.M.decode_vector(sk, self.k, 12, is_ntt=True) + s_hat = self.M.decode_vector(sk, self.k, 12, is_ntt=True) # Recover message as polynomial - m = s.dot(u).from_ntt() + m = (s_hat.dot(u_hat)).from_ntt() m = v - m # Return message as bytes @@ -259,13 +284,13 @@ def _cpapke_dec(self, sk, c): def keygen(self): """ + Generate a public public key and private secret key. + Algorithm 7 (CCA KEM KeyGen) https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf - Output: - pk: Public key - sk: Secret key - + :return: Tuple with public key and secret key. + :rtype: tuple(bytes, bytes) """ # Note, although the paper gens z then # pk, sk, the implementation does it this @@ -280,18 +305,19 @@ def keygen(self): def encaps(self, pk, key_length=32): """ + Generate a random key, encapsulate it, return both it and ciphertext. + Algorithm 8 (CCA KEM Encapsulation) https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf - Input: - pk: Public Key - Output: - K: Shared key - c: Ciphertext - NOTE: We switch the order of the output (c, K) as (K, c) to align encaps - output with FIPS 203. + output with FIPS 203-ipd. + + :param bytes pk: byte-encoded public key + :param int key_length: length of secret key, default value 32 + :return: a random key and an public of it + :rtype: tuple(bytes, bytes) """ # Compute random message m = self.random_bytes(32) @@ -300,39 +326,59 @@ def encaps(self, pk, key_length=32): m_hash = self._h(m) # Compute key K and challenge c - Kbar, r = self._g(m_hash + self._h(pk)) + K_bar, r = self._g(m_hash + self._h(pk)) + + # Perform the underlying pke encryption c = self._cpapke_enc(pk, m_hash, r) - K = self._kdf(Kbar + self._h(c), key_length) + + # Derive a key from the ciphertext + K = self._kdf(K_bar + self._h(c), key_length) + return K, c + def _unpack_secret_key(self, sk): + """ + Extract values from byte encoded secret key: + + sk = _sk || pk || H(pk) || z + """ + index = 12 * self.k * self.R.n // 8 + + sk_pke = sk[:index] + pk_pke = sk[index:-64] + pk_hash = sk[-64:-32] + z = sk[-32:] + + return sk_pke, pk_pke, pk_hash, z + def decaps(self, c, sk, key_length=32): """ + Decapsulate a key from a ciphertext using a secret key. + Algorithm 9 (CCA KEM Decapsulation) https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf - Input: - c: ciphertext - sk: Secret Key - Output: - K: Shared key + :param bytes c: ciphertext with an encapsulated key + :param bytes sk: secret key + :param int key_length: length of secret key, default value 32 + :return: shared key + :rtype: bytes """ - # Extract values from `sk` - # sk = _sk || pk || H(pk) || z - index = 12 * self.k * self.R.n // 8 - _sk = sk[:index] - pk = sk[index:-64] - hpk = sk[-64:-32] - z = sk[-32:] + sk_pke, pk_pke, pk_hash, z = self._unpack_secret_key(sk) # Decrypt the ciphertext - _m = self._cpapke_dec(_sk, c) + m = self._cpapke_dec(sk_pke, c) # Decapsulation - _Kbar, _r = self._g(_m + hpk) - _c = self._cpapke_enc(pk, _m, _r) + K_bar, r = self._g(m + pk_hash) + c_prime = self._cpapke_enc(pk_pke, m, r) # if decapsulation was successful return K - key = self._kdf(_Kbar + self._h(c), key_length) + key = self._kdf(K_bar + self._h(c), key_length) garbage = self._kdf(z + self._h(c), key_length) - return select_bytes(garbage, key, c == _c) + # If c != c_prime, return garbage instead of the key + # WARNING: for proper implementations, it is absolutely + # vital that the selection between the key and garbage is + # performed in constant time + return select_bytes(garbage, key, c == c_prime) diff --git a/src/kyber_py/ml_kem/ml_kem.py b/src/kyber_py/ml_kem/ml_kem.py index e2ea2a2..28c334f 100644 --- a/src/kyber_py/ml_kem/ml_kem.py +++ b/src/kyber_py/ml_kem/ml_kem.py @@ -38,10 +38,11 @@ def set_drbg_seed(self, seed): """ Change entropy source to a DRBG and seed it with provided value. - Setting the seed switches the entropy source - from :func:`os.urandom()` to an AES256 CTR DRBG. + Setting the seed switches the entropy source from :func:`os.urandom()` + to an AES256 CTR DRBG. - Not recommended, exists mostly for testing against official KATs. + Used for both deterministic versions of ML-KEM as well as testing + alignment with the KAT vectors Note: currently requires pycryptodome for AES impl. @@ -97,10 +98,12 @@ def _xof(bytes32, i, j): ) return shake_128(input_bytes).digest(840) - # Pseudorandom function described between lines - # 726 - 731 @staticmethod def _prf(eta, s, b): + """ + Pseudorandom function described between lines 726 - 731 of in FIPS + 203-ipd + """ input_bytes = s + b if len(input_bytes) != 33: raise ValueError( @@ -109,7 +112,7 @@ def _prf(eta, s, b): return shake_256(input_bytes).digest(eta * 64) # Three hash functions described between lines - # 741 - 750 + # 741 - 750 in FIPS 203-ipd @staticmethod def _H(s): return sha3_256(s).digest() @@ -123,7 +126,14 @@ def _G(s): h = sha3_512(s).digest() return h[:32], h[32:] - def _generate_matrix(self, rho, transpose=False): + def _generate_matrix_from_seed(self, rho, transpose=False): + """ + Helper function which generates a element of size + k x k from a seed `rho`. + + When `transpose` is set to True, the matrix A is + built as the transpose. + """ A_data = [[0 for _ in range(self.k)] for _ in range(self.k)] for i in range(self.k): for j in range(self.k): @@ -132,7 +142,11 @@ def _generate_matrix(self, rho, transpose=False): A_hat = self.M(A_data, transpose=transpose) return A_hat - def _generate_vector(self, sigma, eta, N): + def _generate_error_vector(self, sigma, eta, N): + """ + Helper function which generates a element in the + module from the Centered Binomial Distribution. + """ elements = [0 for _ in range(self.k)] for i in range(self.k): prf_output = self._prf(eta, sigma, bytes([N])) @@ -155,15 +169,17 @@ def _pke_keygen(self): rho, sigma = self._G(d) # Generate A_hat from seed rho - A_hat = self._generate_matrix(rho) + A_hat = self._generate_matrix_from_seed(rho) + # Set counter for PRF N = 0 - s, N = self._generate_vector(sigma, self.eta_1, N) - e, N = self._generate_vector(sigma, self.eta_1, N) - # TODO: we could convert to ntt form as we create the data - # and skip this call to compute a new Matrix objects + # Generate the error vector s ∈ R^k + s, N = self._generate_error_vector(sigma, self.eta_1, N) s_hat = s.to_ntt() + + # Generate the error vector e ∈ R^k + e, N = self._generate_error_vector(sigma, self.eta_1, N) e_hat = e.to_ntt() # Compute public value (in NTT form) @@ -198,16 +214,16 @@ def _pke_encrypt(self, ek_pke, m, r): ), "Modulus check failed, t_hat does not encode correctly" # Generate A_hat^T from seed rho - A_hat = self._generate_matrix(rho, transpose=True) + A_hat_T = self._generate_matrix_from_seed(rho, transpose=True) N = 0 - r_vec, N = self._generate_vector(r, self.eta_1, N) - e1, N = self._generate_vector(r, self.eta_2, N) + r_vec, N = self._generate_error_vector(r, self.eta_1, N) + e1, N = self._generate_error_vector(r, self.eta_2, N) e2, N = self._generate_polynomial(r, self.eta_2, N) r_hat = r_vec.to_ntt() - u = (A_hat @ r_hat).from_ntt() + e1 + u = (A_hat_T @ r_hat).from_ntt() + e1 mu = self.R.decode(m, 1).decompress(1) v = t_hat.dot(r_hat).from_ntt() + e2 + mu @@ -237,7 +253,7 @@ def _pke_decrypt(self, dk_pke, c): def keygen(self): """ - Generate a pair or encapsulation key and decapsulation keys. + Generate a public encapsulation key and private decapsulation key. Algorithm 15 in FIPS 203-ipd @@ -317,4 +333,7 @@ def decaps(self, c, dk): c_prime = self._pke_encrypt(ek_pke, m_prime, r_prime) # If c != c_prime, return K_bar as garbage + # WARNING: for proper implementations, it is absolutely + # vital that the selection between the key and garbage is + # performed in constant time return select_bytes(K_bar, K_prime, c == c_prime)