-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-Authored-By: Lucas-Wye <[email protected]>
- Loading branch information
Showing
13 changed files
with
1,281 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
{ linkerScript
, makeBuilder
, t1main
}:

let
  # makeBuilder specialised with casePrefix = "eval".
  mkEvalCase = makeBuilder { casePrefix = "eval"; };

  # Build one NTT case from two C sources, compiled in the order given.
  # `caseName` must be consistent with the attribute name it is bound to below.
  build_ntt = caseName: firstSrc: secondSrc:
    mkEvalCase {
      inherit caseName;

      src = ./.;

      passthru.featuresRequired = { };

      # Compile both sources together with t1main into a single ELF, using the
      # supplied linker script.
      buildPhase = ''
        runHook preBuild
        $CC -T${linkerScript} \
          ${firstSrc} ${secondSrc} \
          ${t1main} \
          -o $pname.elf
        runHook postBuild
      '';

      meta.description = "test case 'ntt'";
    };

in {
  # Cases built from ntt.c, one per transform size.
  ntt_64 = build_ntt "ntt_64" ./ntt.c ./ntt_64_main.c;
  ntt_128 = build_ntt "ntt_128" ./ntt.c ./ntt_128_main.c;
  ntt_256 = build_ntt "ntt_256" ./ntt.c ./ntt_256_main.c;
  ntt_512 = build_ntt "ntt_512" ./ntt.c ./ntt_512_main.c;
  ntt_1024 = build_ntt "ntt_1024" ./ntt.c ./ntt_1024_main.c;
  ntt_4096 = build_ntt "ntt_4096" ./ntt.c ./ntt_4096_main.c;

  # Same drivers, built against ntt_mem.c instead.
  ntt_mem_64 = build_ntt "ntt_mem_64" ./ntt_mem.c ./ntt_64_main.c;
  ntt_mem_128 = build_ntt "ntt_mem_128" ./ntt_mem.c ./ntt_128_main.c;
  ntt_mem_256 = build_ntt "ntt_mem_256" ./ntt_mem.c ./ntt_256_main.c;
  ntt_mem_512 = build_ntt "ntt_mem_512" ./ntt_mem.c ./ntt_512_main.c;
  ntt_mem_1024 = build_ntt "ntt_mem_1024" ./ntt_mem.c ./ntt_1024_main.c;
  ntt_mem_4096 = build_ntt "ntt_mem_4096" ./ntt_mem.c ./ntt_4096_main.c;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import random | ||
|
||
def main(l=12):
    """Print test parameters for an NTT of size ``n = 2**l`` modulo 12289.

    Three lines are written to stdout:

    1. ``w``, computed as ``g ** ((p - 1) // n) mod p`` with ``g = 11``;
    2. the list ``[w, w**2, w**4, ...]`` of ``l`` successive squarings mod p;
    3. ``n`` random coefficients drawn uniformly from ``range(p)``.

    Args:
        l: log2 of the transform size. Defaults to 12 (n = 4096).

    Raises:
        ValueError: if ``n`` does not divide ``p - 1``, in which case no
            n-th root of unity can be derived this way.
    """
    n = 1 << l
    p = 12289  # p is prime and n | p - 1
    g = 11  # primitive root of p
    if (p - 1) % n != 0:
        raise ValueError(f"n = {n} does not divide p - 1 = {p - 1}")

    # Now w^n == 1 mod p by Fermat's little theorem.  Three-argument pow()
    # reduces modulo p at every step instead of building the full power first.
    w = pow(g, (p - 1) // n, p)
    print(w)

    # twiddle_list[i] == w^(2^i) mod p
    twiddle_list = [pow(w, 1 << i, p) for i in range(l)]
    print(twiddle_list)

    a = [random.randrange(p) for _ in range(n)]
    print(a)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
def _twiddle_layers(l, w_one, prime_p):
    """Return one list of ``n // 2`` twiddle factors per butterfly layer."""
    n = 2 ** l
    layers = []
    m = 2
    while m <= n:
        stride = n // m  # butterflies sharing one twiddle in this layer
        layer = []
        for j in range(m // 2):
            # One modular exponentiation per distinct exponent, then repeat it.
            layer.extend([pow(w_one, j * stride, prime_p)] * stride)
        layers.append(layer)
        m *= 2
    return layers


def gen_tw_for_vector_ntt(l, w_one, prime_p):
    """Print the per-layer twiddle factor table for an NTT of size ``2**l``.

    Layers use butterfly spans ``m = 2, 4, ..., n``.  Within a layer the
    factors are listed for ``j = 0 .. m/2 - 1``, each value
    ``w_one ** (j * n/m) mod prime_p`` repeated ``n/m`` times, so every layer
    holds ``n/2`` entries.

    The output is a header line ``for ntt <n>`` followed, for each layer, by a
    ``// layer #<k>`` line and the factors separated by ``", "``.

    Args:
        l: log2 of the transform size ``n``.
        w_one: the root of unity whose powers are tabulated.
        prime_p: the modulus.
    """
    n = 2 ** l
    print(f"\nfor ntt {n}")
    for layer_index, layer in enumerate(_twiddle_layers(l, w_one, prime_p)):
        print(f"// layer #{layer_index}")
        # Every factor is followed by ", " and each layer by an empty line.
        print("".join(f"{value}, " for value in layer), end="\n\n")


if __name__ == '__main__':
    # (log2 of the NTT size, root of unity for that size), all modulo 12289.
    for log2_n, root in (
        (6, 7311),
        (7, 12149),
        (8, 8340),
        (9, 3400),
        (10, 10302),
        (12, 1331),
    ):
        gen_tw_for_vector_ntt(log2_n, root, 12289)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
#include <assert.h> | ||
#include <stdio.h> | ||
|
||
// #define USERN 32 | ||
// #define DEBUG | ||
|
||
// array is of length n=2^l, p is a prime number | ||
// roots is of length l, where g = roots[0] satisfies that | ||
// g^(2^l) == 1 mod p and g^(2^(l-1)) == -1 mod p | ||
// roots[i] = g^(2^i) (hence roots[l - 1] = -1) | ||
// | ||
// 32bit * n <= VLEN * 8 => n <= VLEN / 4 | ||
void ntt(const int *array, int l, const int *twiddle, int p, int *dst) { | ||
// prepare an array of permutation indices | ||
assert(l <= 16); | ||
|
||
int n = 1 << l; | ||
|
||
// registers: | ||
// v8-15: array | ||
// v16-24: loaded elements (until vrgather) | ||
// v4-7: permutation index (until vrgather) | ||
// v16-24: coefficients | ||
int vlenb; | ||
asm("csrr %0, vlenb" : "=r"(vlenb)); | ||
int elements_in_vreg = vlenb * 2; | ||
assert(elements_in_vreg >= n); | ||
|
||
asm("vsetvli zero, %0, e16, m4, tu, mu\n" | ||
"vid.v v4\n" | ||
: | ||
: "r"(n)); | ||
|
||
// prepare the bit-reversal permutation list | ||
for (int k = 0; 2 * k < l; k++) { | ||
asm("vand.vx v8, v4, %0\n" | ||
"vsub.vv v4, v4, v8\n" | ||
"vsll.vx v8, v8, %1\n" // get the k-th digit and shift left | ||
|
||
"vand.vx v12, v4, %2\n" | ||
"vsub.vv v4, v4, v12\n" | ||
"vsrl.vx v12, v12, %1\n" // get the (l-k-1)-th digit and shift right | ||
|
||
"vor.vv v4, v4, v8\n" | ||
"vor.vv v4, v4, v12\n" | ||
|
||
: | ||
: "r"(1 << k), "r"(l - 1 - 2 * k), "r"(1 << (l - k - 1))); | ||
} | ||
|
||
// perform bit-reversal for input coefficients | ||
asm("vsetvli zero, %0, e32, m8, tu, mu\n" | ||
"vle32.v v16, 0(%1)\n" | ||
"vrgatherei16.vv v8, v16, v4\n" | ||
"vse32.v v8, 0(%2)\n" | ||
|
||
: | ||
: "r"(n), "r"(array), "r"(dst)); | ||
|
||
// generate permutation list (0, 2, 4, ..., 1, 3, 5, ...) | ||
asm("vsetvli zero, %0, e16, m4, tu, mu\n" | ||
"vid.v v4\n" | ||
"vsrl.vx v0, v4, %1\n" // (0, 0, 0, 0, ..., 1, 1, 1, 1, ...) | ||
"vand.vx v4, v4, %2\n" // (0, 1, 2, 3, ..., 0, 1, 2, 3, ...) | ||
"vsll.vi v4, v4, 1\n" | ||
"vadd.vv v4, v4, v0\n" | ||
|
||
: | ||
: "r"(n), "r"(l-1), "r"((n / 2 - 1)), "r"(n / 2)); | ||
|
||
#ifdef DEBUG | ||
int tmp1[USERN];// c | ||
int tmp2[USERN];// c | ||
int tmp3[USERN];// c | ||
#endif | ||
|
||
for (int k = 0; k < l; k++) { | ||
asm( | ||
// "n" mode | ||
"vsetvli zero, %0, e32, m8, tu, mu\n" | ||
// load coefficients | ||
"vle32.v v16, 0(%4)\n" | ||
// perform permutation for coefficient | ||
"vrgatherei16.vv v8, v16, v4\n" | ||
// save coefficients | ||
"vse32.v v8, 0(%4)\n" | ||
|
||
// "n/2" mode | ||
"vsetvli zero, %1, e32, m4, tu, mu\n" | ||
// load twiddle factors | ||
"vle32.v v16, 0(%2)\n" | ||
// load half coefficients | ||
"vle32.v v8, 0(%4)\n" | ||
"vle32.v v12, 0(%5)\n" | ||
|
||
#ifdef DEBUG | ||
"vse32.v v8, 0(%6)\n"// c | ||
"vse32.v v12, 0(%7)\n"// c | ||
"vse32.v v16, 0(%8)\n"// c | ||
#endif | ||
|
||
// butterfly operation | ||
"vmul.vv v12, v12, v16\n" | ||
"vrem.vx v12, v12, %3\n" | ||
"vadd.vv v16, v8, v12\n" // TODO: will it overflow? | ||
"vsub.vv v20, v8, v12\n" | ||
// save half coefficients | ||
"vse32.v v16, 0(%4)\n" | ||
"vse32.v v20, 0(%5)\n" | ||
: | ||
: /* %0 */ "r"(n), | ||
/* %1 */ "r"(n / 2), | ||
/* %2 */ "r"(twiddle + k * (n / 2)), | ||
/* %3 */ "r"(p), | ||
"r"(dst), | ||
"r"(dst + (n / 2)) | ||
#ifdef DEBUG | ||
, "r"(tmp1), "r"(tmp2), "r"(tmp3) | ||
#endif | ||
); | ||
#ifdef DEBUG | ||
for(int k = 0; k < USERN; k++) { | ||
printf("(%x, %x, %x)\n", tmp1[k], tmp2[k], tmp3[k]); | ||
} | ||
#endif | ||
} | ||
// deal with modular | ||
asm("vsetvli zero, %0, e32, m8, tu, mu\n" | ||
"vle32.v v16, 0(%1)\n" | ||
"vrem.vx v8, v16, %2\n" | ||
"vse32.v v8, 0(%1)\n" | ||
|
||
: | ||
: "r"(n), "r"(dst), "r"(p)); | ||
} |
Oops, something went wrong.