|
| 1 | +## Modular inverse |
| 2 | + |
| 3 | +import std/[options, math] |
| 4 | + |
| 5 | + |
| 6 | +func euclidHalfIteration(inA, inB: Positive): tuple[gcd: Natural, coeff: int] = |
| 7 | + var (a, b) = (inA.Natural, inB.Natural) |
| 8 | + var (x0, x1) = (1, 0) |
| 9 | + while b != 0: |
| 10 | + (x0, x1) = (x1, x0 - (a div b) * x1) |
| 11 | + (a, b) = (b, math.floorMod(a, b)) |
| 12 | + |
| 13 | + (a, x0) |
| 14 | + |
| 15 | + |
| 16 | +func modularInverse*(inA: int, modulus: Positive): Option[Positive] = |
| 17 | + ## For a given integer `a` and a natural number `modulus` it |
| 18 | + ## computes the inverse of `a` modulo `modulus`, i.e. |
| 19 | + ## it finds an integer `0 < inv < modulus` such that |
| 20 | + ## `(a * inv) mod modulus == 1`. |
| 21 | + let a = math.floorMod(inA, modulus) |
| 22 | + if a == 0: |
| 23 | + return none(Positive) |
| 24 | + let (gcd, x) = euclidHalfIteration(a, modulus) |
| 25 | + if gcd == 1: |
| 26 | + return some(math.floorMod(x, modulus).Positive) |
| 27 | + none(Positive) |
| 28 | + |
| 29 | + |
| 30 | +when isMainModule: |
| 31 | + import std/[unittest, sequtils, strformat, random] |
| 32 | + suite "modularInverse": |
| 33 | + const testCases = [ |
| 34 | + (3, 7, 5), |
| 35 | + (-1, 5, 4), # Inverse of a negative |
| 36 | + (-7, 5, 2), # Inverse of a negative lower than modulus |
| 37 | + (-7, 4, 1), # Inverse of a negative with non-prime modulus |
| 38 | + (4, 5, 4), |
| 39 | + (9, 5, 4), |
| 40 | + (5, 21, 17), |
| 41 | + (2, 21, 11), |
| 42 | + (4, 21, 16), |
| 43 | + (55, 372, 115), |
| 44 | + (1, 100, 1), |
| 45 | + ].mapIt: |
| 46 | + let tc = (id: fmt"a={it[0]}, modulus={it[1]}", a: it[0], modulus: it[1], |
| 47 | + inv: it[2]) |
| 48 | + assert 0 < tc.inv |
| 49 | + assert tc.inv < tc.modulus |
| 50 | + assert math.floorMod(tc.a * tc.inv, tc.modulus) == 1 |
| 51 | + tc |
| 52 | + |
| 53 | + for tc in testCases: |
| 54 | + test tc.id: |
| 55 | + checkpoint("returns expected result") |
| 56 | + check modularInverse(tc.a, tc.modulus).get() == tc.inv |
| 57 | + |
| 58 | + test "No inverse when modulus is 1": |
| 59 | + check modularInverse(0, 1).is_none() |
| 60 | + check modularInverse(1, 1).is_none() |
| 61 | + check modularInverse(-1, 1).is_none() |
| 62 | + |
| 63 | + test "No inverse when inputs are not co-prime": |
| 64 | + check modularInverse(2, 4).is_none() |
| 65 | + check modularInverse(-5, 25).is_none() |
| 66 | + check modularInverse(0, 17).is_none() |
| 67 | + check modularInverse(17, 17).is_none() |
| 68 | + |
| 69 | + randomize() |
| 70 | + const randomTestSize = 1000 |
| 71 | + for testNum in 0..randomTestSize: |
| 72 | + let a = rand(-10000000..10000000) |
| 73 | + let modulus = rand(1..1000000) |
| 74 | + test fmt"(random test) a={a}, modulus={modulus}": |
| 75 | + let inv = modularInverse(a, modulus) |
| 76 | + if inv.isSome(): |
| 77 | + check 0 < inv.get() |
| 78 | + check inv.get() < modulus |
| 79 | + check math.floorMod(a * inv.get(), modulus) == 1 |
| 80 | + else: |
| 81 | + check math.gcd(a, modulus) != 1 |
0 commit comments