Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions src/dilithium_py/utilities/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
def reduce_mod_pm(x, n):
def reduce_mod_pm(r, n):
"""
Takes an integer 0 < x < n and represents
it as an integer in the range
Takes an integer 0 < r < n and computes the value
r' = r mod^± n, defined to be integer in the range

r = x % n

for n odd:
-(n-1)/2 < r <= (n-1)/2
for n even:
- n / 2 <= r <= n / 2
-(n / 2) < r' <= (n / 2)
for n odd:
-(n - 1) / 2 <= r' <= (n - 1) / 2
"""
x = x % n
if x > (n >> 1):
x -= n
return x
r = r % n
if r > (n >> 1):
r -= n
return r


def decompose(r, a, q):
Expand Down
12 changes: 10 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,25 @@

class TestUtils(unittest.TestCase):
def test_reduce_mod_pm_even(self):
"""
For odd modulus the reduced value should be bounded by
-(n-1)/2 <= x <= (n-1)/2
"""
for _ in range(100):
modulus = 2 * randint(0, 100)
for i in range(modulus):
x = reduce_mod_pm(i, modulus)
self.assertTrue(x <= modulus // 2)
self.assertTrue(x > -modulus // 2)
self.assertTrue(-modulus // 2 < x)

def test_reduce_mod_pm_odd(self):
"""
For even modulus the reduced value should be bounded by
-n/2 < x <= n/2
"""
for _ in range(100):
modulus = 2 * randint(0, 100) + 1
for i in range(modulus):
x = reduce_mod_pm(i, modulus)
self.assertTrue(x <= (modulus - 1) // 2)
self.assertTrue(x >= -(modulus - 1) // 2)
self.assertTrue(-(modulus - 1) // 2 <= x)