Skip to content

Commit

Permalink
[evel] use python to generate header
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas-Wye committed Feb 7, 2025
1 parent 804da11 commit 54cb9c3
Show file tree
Hide file tree
Showing 19 changed files with 274 additions and 1,045 deletions.
34 changes: 20 additions & 14 deletions tests/eval/_ntt/default.nix
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
{ linkerScript
, makeBuilder
, python3
, t1main
}:

let
builder = makeBuilder { casePrefix = "eval"; };
build_ntt = caseName /* must be consistent with attr name */ : main_src: kernel_src:
build_ntt = caseName /* must be consistent with attr name */ : main_src: kernel_src: caseArgs: extra_flag:
builder {
caseName = caseName;

Expand All @@ -16,7 +17,12 @@ let
buildPhase = ''
runHook preBuild
${python3}/bin/python3 ./gen_header.py ${caseArgs}
$CC -T${linkerScript} \
-DCASE=${caseArgs} \
${extra_flag} \
-I. \
${main_src} ${kernel_src} \
${t1main} \
-o $pname.elf
Expand All @@ -28,17 +34,17 @@ let
};

in {
ntt_64 = build_ntt "ntt_64" ./ntt.c ./ntt_64_main.c;
ntt_128 = build_ntt "ntt_128" ./ntt.c ./ntt_128_main.c;
ntt_256 = build_ntt "ntt_256" ./ntt.c ./ntt_256_main.c;
ntt_512 = build_ntt "ntt_512" ./ntt.c ./ntt_512_main.c;
ntt_1024 = build_ntt "ntt_1024" ./ntt.c ./ntt_1024_main.c;
ntt_4096 = build_ntt "ntt_4096" ./ntt.c ./ntt_4096_main.c;

ntt_mem_64 = build_ntt "ntt_mem_64" ./ntt_mem.c ./ntt_64_main.c;
ntt_mem_128 = build_ntt "ntt_mem_128" ./ntt_mem.c ./ntt_128_main.c;
ntt_mem_256 = build_ntt "ntt_mem_256" ./ntt_mem.c ./ntt_256_main.c;
ntt_mem_512 = build_ntt "ntt_mem_512" ./ntt_mem.c ./ntt_512_main.c;
ntt_mem_1024 = build_ntt "ntt_mem_1024" ./ntt_mem.c ./ntt_1024_main.c;
ntt_mem_4096 = build_ntt "ntt_mem_4096" ./ntt_mem.c ./ntt_4096_main.c;
ntt_64 = build_ntt "ntt_64" ./ntt.c ./ntt_main.c "ntt_64" "";
ntt_128 = build_ntt "ntt_128" ./ntt.c ./ntt_main.c "ntt_128" "";
ntt_256 = build_ntt "ntt_256" ./ntt.c ./ntt_main.c "ntt_256" "";
ntt_512 = build_ntt "ntt_512" ./ntt.c ./ntt_main.c "ntt_512" "";
ntt_1024 = build_ntt "ntt_1024" ./ntt.c ./ntt_main.c "ntt_1024" "";
ntt_4096 = build_ntt "ntt_4096" ./ntt.c ./ntt_main.c "ntt_4096" "";

ntt_mem_64 = build_ntt "ntt_mem_64" ./ntt_mem.c ./ntt_main.c "ntt_64" "-DUSE_SCALAR";
ntt_mem_128 = build_ntt "ntt_mem_128" ./ntt_mem.c ./ntt_main.c "ntt_128" "-DUSE_SCALAR";
ntt_mem_256 = build_ntt "ntt_mem_256" ./ntt_mem.c ./ntt_main.c "ntt_256" "-DUSE_SCALAR";
ntt_mem_512 = build_ntt "ntt_mem_512" ./ntt_mem.c ./ntt_main.c "ntt_512" "-DUSE_SCALAR";
ntt_mem_1024 = build_ntt "ntt_mem_1024" ./ntt_mem.c ./ntt_main.c "ntt_1024" "-DUSE_SCALAR";
ntt_mem_4096 = build_ntt "ntt_mem_4096" ./ntt_mem.c ./ntt_main.c "ntt_4096" "-DUSE_SCALAR";
}
24 changes: 0 additions & 24 deletions tests/eval/_ntt/gen_data.py

This file was deleted.

118 changes: 118 additions & 0 deletions tests/eval/_ntt/gen_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import sys
import json
import random


def genRandomPoly(l, p):
n = 1 << l
a = [random.randrange(p) for _ in range(n)]
return a


def genGoldPoly(l, p, g, poly):
n = 1 << l
poly_out = []
for i in range(n):
tmp = 0
for j in range(n):
tmp += poly[j] * pow(g, i * j, p)
tmp = tmp % p
poly_out.append(tmp)
return poly_out


def genScalarTW(l, p, g):
w = g

twiddle_list = []
for _ in range(l):
twiddle_list.append(w)
w = (w * w) % p

return twiddle_list


def genVectorTW(l, p, g):
n = 1 << l
m = 2
layerIndex = 0

outTW = []
while m <= n:
# print(f"// layer #{layerIndex}")
layerIndex += 1
wPower = 0

for j in range(m // 2):
k = 0
while k < n:
currentW = pow(g, wPower, p)
k += m
outTW.append(currentW)
# print(currentW, end =", ")
wPower += n // m
m *= 2
# print("\n")
return outTW


def main(l, p, g):
# poly_in = genRandomPoly(l, p)
# poly_out = genGoldPoly(l, p, g, poly_in)
scalar_tw = genScalarTW(l, p, g)
vector_tw = genVectorTW(l, p, g)
n = 1 << l
data = {
"l": l,
"n": n,
"p": p,
# "input": poly_in,
# "output": poly_out,
"vector_tw": vector_tw,
"scalar_tw": scalar_tw,
}

json_name = "ntt_" + str(n) + ".json"
with open(json_name, "r") as json_in:
json_data = json.load(json_in)

header_str = "#define macroL " + str(data["l"]) + "\n"
header_str += "#define macroN " + str(data["n"]) + "\n"
header_str += "#define macroP " + str(data["p"]) + "\n"
header_str += (
"#define macroIn " + ",".join(str(e) for e in json_data["input"]) + "\n"
)
header_str += (
"#define macroOut " + ",".join(str(e) for e in json_data["output"]) + "\n"
)
header_str += (
"#define macroScalarTW " + ",".join(str(e) for e in data["scalar_tw"]) + "\n"
)
header_str += (
"#define macroVectorTW " + ",".join(str(e) for e in data["vector_tw"]) + "\n"
)

header_file = "ntt_" + str(n) + ".h"
with open(header_file, "w") as header_out:
header_out.write(header_str)


if __name__ == "__main__":
if len(sys.argv) != 2:
raise Exception("No Enough Input Args")

p = 12289
if sys.argv[1] in ("ntt_64"):
main(6, p, 7311)
elif sys.argv[1] in ("ntt_128"):
main(7, p, 12149)
elif sys.argv[1] in ("ntt_256"):
main(8, p, 8340)
elif sys.argv[1] in ("ntt_512"):
main(9, p, 3400)
elif sys.argv[1] in ("ntt_1024"):
main(10, p, 10302)
elif sys.argv[1] in ("ntt_4096"):
main(12, p, 1331)
else:
raise Exception(f"Unknown Args {sys.argv[1]}")
53 changes: 0 additions & 53 deletions tests/eval/_ntt/gen_vector_ntt_tw.py

This file was deleted.

11 changes: 4 additions & 7 deletions tests/eval/_ntt/ntt.c
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
#include <assert.h>
#include <stdio.h>

// #define USERN 32
// #define DEBUG

// array is of length n=2^l, p is a prime number
// roots is of length l, where g = roots[0] satisfies that
// g^(2^l) == 1 mod p and g^(2^(l-1)) == -1 mod p
Expand Down Expand Up @@ -94,15 +91,15 @@ void ntt(const int *array, int l, const int *twiddle, int p, int *dst) {
"vle32.v v12, 0(%5)\n"

#ifdef DEBUG
"vse32.v v8, 0(%6)\n"// c
"vse32.v v12, 0(%7)\n"// c
"vse32.v v16, 0(%8)\n"// c
"vse32.v v8, 0(%6)\n"
"vse32.v v12, 0(%7)\n"
"vse32.v v16, 0(%8)\n"
#endif

// butterfly operation
"vmul.vv v12, v12, v16\n"
"vrem.vx v12, v12, %3\n"
"vadd.vv v16, v8, v12\n" // TODO: will it overflow?
"vadd.vv v16, v8, v12\n" // NOTE: use lazy reduction here
"vsub.vv v20, v8, v12\n"
// save half coefficients
"vse32.v v16, 0(%4)\n"
Expand Down
Loading

0 comments on commit 54cb9c3

Please sign in to comment.