Skip to content

Commit

Permalink
Merge pull request #54 from GiacomoPope/test_polynomial_generic
Browse files Browse the repository at this point in the history
test generic polynomials
  • Loading branch information
GiacomoPope authored Jul 23, 2024
2 parents edfba35 + 7a66d63 commit 4426a08
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/kyber_py/polynomials/polynomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __call__(self, coefficients, is_ntt=False):
class PolynomialKyber(Polynomial):
def __init__(self, parent, coefficients):
self.parent = parent
self.coeffs = self.parse_coefficients(coefficients)
self.coeffs = self._parse_coefficients(coefficients)

def encode(self, d):
"""
Expand Down Expand Up @@ -188,7 +188,7 @@ def from_ntt(self):
class PolynomialKyberNTT(PolynomialKyber):
def __init__(self, parent, coefficients):
self.parent = parent
self.coeffs = self.parse_coefficients(coefficients)
self.coeffs = self._parse_coefficients(coefficients)

def to_ntt(self):
raise TypeError(f"Polynomial is of type: {type(self)}")
Expand Down
22 changes: 12 additions & 10 deletions src/kyber_py/polynomials/polynomials_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __repr__(self):
class Polynomial:
def __init__(self, parent, coefficients):
self.parent = parent
self.coeffs = self.parse_coefficients(coefficients)
self.coeffs = self._parse_coefficients(coefficients)

def is_zero(self):
"""
Expand All @@ -50,7 +50,7 @@ def is_constant(self):
"""
return all(c == 0 for c in self.coeffs[1:])

def parse_coefficients(self, coefficients):
def _parse_coefficients(self, coefficients):
"""
Helper function which right pads with zeros
to allow polynomial construction as
Expand All @@ -72,19 +72,19 @@ def reduce_coefficients(self):
self.coeffs = [c % self.parent.q for c in self.coeffs]
return self

def add_mod_q(self, x, y):
def _add_mod_q(self, x, y):
"""
add two coefficients modulo q
"""
return (x + y) % self.parent.q

def sub_mod_q(self, x, y):
def _sub_mod_q(self, x, y):
"""
sub two coefficients modulo q
"""
return (x - y) % self.parent.q

def schoolbook_multiplication(self, other):
def _schoolbook_multiplication(self, other):
"""
Naive implementation of polynomial multiplication
suitible for all R_q = F_1[X]/(X^n + 1)
Expand All @@ -111,11 +111,12 @@ def __neg__(self):
def _add_(self, other):
if isinstance(other, type(self)):
new_coeffs = [
self.add_mod_q(x, y) for x, y in zip(self.coeffs, other.coeffs)
self._add_mod_q(x, y)
for x, y in zip(self.coeffs, other.coeffs)
]
elif isinstance(other, int):
new_coeffs = self.coeffs.copy()
new_coeffs[0] = self.add_mod_q(new_coeffs[0], other)
new_coeffs[0] = self._add_mod_q(new_coeffs[0], other)
else:
raise NotImplementedError(
"Polynomials can only be added to each other"
Expand All @@ -136,11 +137,12 @@ def __iadd__(self, other):
def _sub_(self, other):
if isinstance(other, type(self)):
new_coeffs = [
self.sub_mod_q(x, y) for x, y in zip(self.coeffs, other.coeffs)
self._sub_mod_q(x, y)
for x, y in zip(self.coeffs, other.coeffs)
]
elif isinstance(other, int):
new_coeffs = self.coeffs.copy()
new_coeffs[0] = self.sub_mod_q(new_coeffs[0], other)
new_coeffs[0] = self._sub_mod_q(new_coeffs[0], other)
else:
raise NotImplementedError(
"Polynomials can only be subtracted from each other"
Expand All @@ -160,7 +162,7 @@ def __isub__(self, other):

def __mul__(self, other):
if isinstance(other, type(self)):
new_coeffs = self.schoolbook_multiplication(other)
new_coeffs = self._schoolbook_multiplication(other)
elif isinstance(other, int):
new_coeffs = [(c * other) % self.parent.q for c in self.coeffs]
else:
Expand Down
90 changes: 90 additions & 0 deletions tests/test_polynomial_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import unittest
from random import randint
from kyber_py.polynomials.polynomials_generic import PolynomialRing


class TestPolynomialRing(unittest.TestCase):
R = PolynomialRing(11, 5)

def test_gen(self):
self.assertTrue(self.R.gen() == self.R([0, 1]))

def test_random_element(self):
for _ in range(100):
f = self.R.random_element()
self.assertEqual(type(f), self.R.element)
self.assertEqual(len(f.coeffs), self.R.n)
self.assertTrue(all([c < self.R.q for c in f.coeffs]))


class TestPolynomial(unittest.TestCase):
R = PolynomialRing(11, 5)

def test_is_zero(self):
self.assertTrue(self.R(0).is_zero())
self.assertFalse(self.R(1).is_zero())

def test_is_constant(self):
self.assertTrue(self.R(0).is_constant())
self.assertTrue(self.R(1).is_constant())
self.assertFalse(self.R.gen().is_constant())

def test_reduce_coefficents(self):
for _ in range(100):
# Create non-canonical coefficients
coeffs = [
randint(-2 * self.R.q, 3 * self.R.q) for _ in range(self.R.n)
]
f = self.R(coeffs).reduce_coefficients()
self.assertTrue(all([c < self.R.q for c in f.coeffs]))

def test_add_polynomials(self):
zero = self.R(0)
for _ in range(100):
f1 = self.R.random_element()
f2 = self.R.random_element()
f3 = self.R.random_element()

self.assertEqual(f1 + zero, f1)
self.assertEqual(f1 + f2, f2 + f1)
self.assertEqual(f1 + (f2 + f3), (f1 + f2) + f3)

def test_sub_polynomials(self):
zero = self.R(0)
for _ in range(100):
f1 = self.R.random_element()
f2 = self.R.random_element()
f3 = self.R.random_element()

self.assertEqual(f1 - zero, f1)
self.assertEqual(f3 - f3, zero)
self.assertEqual(f1 - f2, -(f2 - f1))
self.assertEqual(f1 - (f2 - f3), (f1 - f2) + f3)

def test_mul_polynomials(self):
zero = self.R(0)
one = self.R(1)
for _ in range(100):
f1 = self.R.random_element()
f2 = self.R.random_element()
f3 = self.R.random_element()

self.assertEqual(f1 * zero, zero)
self.assertEqual(f1 * one, f1)
self.assertEqual(f1 * f2, f2 * f1)
self.assertEqual(f1 * (f2 * f3), (f1 * f2) * f3)

def test_pow_polynomials(self):
one = self.R(1)
for _ in range(100):
f1 = self.R.random_element()

self.assertEqual(one, f1**0)
self.assertEqual(f1, f1**1)
self.assertEqual(f1 * f1, f1**2)
self.assertEqual(f1 * f1 * f1, f1**3)
self.assertRaises(ValueError, lambda: f1 ** (-1))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4426a08

Please sign in to comment.