-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpolynomial.py
65 lines (53 loc) · 2.11 KB
/
polynomial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Polynomial evaluation and Lagrange interpolation
#
# unoptimized
from dumb25519 import Scalar, Point, ScalarVector, PointVector
import dumb25519
def powers(x: Scalar, degree: int) -> ScalarVector:
    """Return the successive powers of ``x`` from exponent 0 up to ``degree``.

    :param x: base Scalar
    :param degree: highest exponent to include; a negative value yields an
        empty vector (``range(degree + 1)`` is then empty)
    :return: ScalarVector ``[x ** 0, x ** 1, ..., x ** degree]``
    """
    return ScalarVector([x ** exponent for exponent in range(degree + 1)])
def poly_eval(x: Scalar, coeff: ScalarVector) -> Scalar:
    """Evaluate the polynomial described by ``coeff`` at the point ``x``.

    ``coeff[k]`` is paired with ``x ** k`` (lowest degree first). The value
    is the ``**`` product of ``[x ** 0, ..., x ** d]`` with ``coeff``, where
    ``d = len(coeff) - 1``. This relies on ScalarVector's ``**`` operator
    acting as an inner product (dumb25519 convention, not visible here).

    :param x: evaluation point
    :param coeff: ScalarVector of coefficients, lowest degree first
    :return: the polynomial's value at ``x``
    """
    x_powers = powers(x, len(coeff) - 1)
    return x_powers ** coeff
def poly_mul(poly_a: ScalarVector, poly_b: ScalarVector) -> ScalarVector:
    """Multiply two polynomials given as coefficient vectors.

    Coefficients are indexed by degree (index k multiplies x**k), in the
    inputs and in the result. This is the schoolbook convolution: each pair
    ``(poly_a[i], poly_b[j])`` adds its product to output slot ``i + j``.

    :param poly_a: coefficients of the first polynomial
    :param poly_b: coefficients of the second polynomial
    :return: ScalarVector of the product's coefficients. For non-empty
        inputs its length is ``len(poly_a) + len(poly_b) - 1``.
    """
    size = len(poly_a) + len(poly_b) - 1
    # one distinct Scalar(0) per slot, so no two slots share an object
    result = [Scalar(0) for _ in range(size)]
    for i, a_i in enumerate(poly_a):
        for j, b_j in enumerate(poly_b):
            result[i + j] += a_i * b_j
    return ScalarVector(result)
def lagrange(coords: list) -> ScalarVector:
    """Interpolate the polynomial passing through every point in ``coords``.

    Uses the Lagrange basis form. For each point i the numerator
    ``prod_{j != i} (X - x_j)`` is built by repeated ``poly_mul``, and the
    matching denominator ``prod_{j != i} (x_i - x_j)`` is accumulated as a
    single Scalar. The numerator is then scaled by ``y_i / denominator`` and
    added to the result.

    :param coords: list of ``(x, y)`` pairs of Scalars (indexed as
        ``point[0]`` / ``point[1]``); the x values must be pairwise distinct
    :return: ScalarVector of ``len(coords)`` coefficients, lowest degree
        first (index k multiplies X**k), in the form ``poly_eval`` expects
    :raises ZeroDivisionError: if two points share the same x-coordinate
    """
    poly = ScalarVector([Scalar(0) for _ in coords])
    for i, point in enumerate(coords):
        x_i, y_i = point[0], point[1]
        numerator = ScalarVector([Scalar(1)])
        denominator = Scalar(1)
        for j, other in enumerate(coords):
            if j == i:
                continue
            x_j = other[0]
            # multiply the numerator by the linear factor (X - x_j)
            numerator = poly_mul(numerator, ScalarVector([-x_j, Scalar(1)]))
            denominator = denominator * (x_i - x_j)
        # A zero denominator means a repeated x-coordinate. The basis
        # polynomial is then undefined, so fail loudly instead of inverting 0.
        if denominator == Scalar(0):
            raise ZeroDivisionError('lagrange: duplicate x-coordinate in coords')
        # one field inversion per point instead of one per (i, j) pair
        poly += numerator * (y_i * denominator.invert())
    return poly
if __name__ == '__main__':
    # Interpolate random y-values sampled at x = -1, 0, 1, then check that the
    # resulting polynomial passes back through every sample point.
    sample_xs = (Scalar(-1), Scalar(0), Scalar(1))
    my_points = [(x, dumb25519.random_scalar()) for x in sample_xs]
    my_coeffs = lagrange(my_points)

    # evaluate every point (no short-circuit) before combining the results
    checks = [poly_eval(x, my_coeffs) == y for x, y in my_points]
    if all(checks):
        print('The implementation of Lagrange interpolation works!')
    else:
        print('There\'s a problem in the implementation of Lagrange interpolation.')