diff --git a/src/dilithium_py/utilities/utils.py b/src/dilithium_py/utilities/utils.py index 28b384b..5486c85 100644 --- a/src/dilithium_py/utilities/utils.py +++ b/src/dilithium_py/utilities/utils.py @@ -1,19 +1,17 @@ -def reduce_mod_pm(x, n): +def reduce_mod_pm(r, n): """ - Takes an integer 0 < x < n and represents - it as an integer in the range + Takes an integer 0 < r < n and computes the value + r' = r mod^± n, defined to be integer in the range - r = x % n - - for n odd: - -(n-1)/2 < r <= (n-1)/2 for n even: - - n / 2 <= r <= n / 2 + -(n / 2) < r' <= (n / 2) + for n odd: + -(n - 1) / 2 <= r' <= (n - 1) / 2 """ - x = x % n - if x > (n >> 1): - x -= n - return x + r = r % n + if r > (n >> 1): + r -= n + return r def decompose(r, a, q): diff --git a/tests/test_utils.py b/tests/test_utils.py index ef596d3..0c5d71a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,17 +5,25 @@ class TestUtils(unittest.TestCase): def test_reduce_mod_pm_even(self): + """ + For odd modulus the reduced value should be bounded by + -(n-1)/2 <= x <= (n-1)/2 + """ for _ in range(100): modulus = 2 * randint(0, 100) for i in range(modulus): x = reduce_mod_pm(i, modulus) self.assertTrue(x <= modulus // 2) - self.assertTrue(x > -modulus // 2) + self.assertTrue(-modulus // 2 < x) def test_reduce_mod_pm_odd(self): + """ + For even modulus the reduced value should be bounded by + -n/2 < x <= n/2 + """ for _ in range(100): modulus = 2 * randint(0, 100) + 1 for i in range(modulus): x = reduce_mod_pm(i, modulus) self.assertTrue(x <= (modulus - 1) // 2) - self.assertTrue(x >= -(modulus - 1) // 2) + self.assertTrue(-(modulus - 1) // 2 <= x)