Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NTT Evaluation #962

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ let
pytorch = casesSelf.callPackage ./pytorch { };
disp = casesSelf.callPackage ./disp { };
emurt-test = casesSelf.callPackage ./emurt/tests { };
eval = casesSelf.callPackage ./eval { };
}));

# remove non-case attributes in scope
scopeStripped = {
inherit (scope) mlir intrinsic asm perf codegen rvv_bench pytorch disp emurt-test;
inherit (scope) mlir intrinsic asm perf codegen rvv_bench pytorch disp emurt-test eval;
};

# This derivation is for internal CI use only.
Expand Down
50 changes: 50 additions & 0 deletions tests/eval/_ntt/default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# NTT evaluation cases.
#
# Produces one derivation per transform size for two kernels:
#   ntt_<size>      built from ./ntt.c
#   ntt_mem_<size>  built from ./ntt_mem.c with -DUSE_SCALAR
# Both are linked together with ./ntt_main.c and ${t1main}.
{ linkerScript
, makeBuilder
, python3
, t1main
}:

let
  evalBuilder = makeBuilder { casePrefix = "eval"; };

  # Build a single case.
  #   caseName  - must be consistent with the attribute name it is bound to
  #   mainSrc   - first C source handed to the compiler
  #   kernelSrc - second C source handed to the compiler
  #   caseArgs  - argument for gen_header.py, also exported as -DCASE=<caseArgs>
  #   extraFlag - extra compiler flag(s); may be the empty string
  mkNttCase = caseName: mainSrc: kernelSrc: caseArgs: extraFlag:
    evalBuilder {
      inherit caseName;

      src = ./.;

      passthru.featuresRequired = { };

      # gen_header.py runs first so that its output is available (via -I.)
      # when the C sources are compiled.
      buildPhase = ''
        runHook preBuild

        ${python3}/bin/python3 ./gen_header.py ${caseArgs}

        $CC -T${linkerScript} \
        -DCASE=${caseArgs} \
        ${extraFlag} \
        -I. \
        ${mainSrc} ${kernelSrc} \
        ${t1main} \
        -o $pname.elf

        runHook postBuild
      '';

      meta.description = "test case 'ntt'";
    };

  # Transform sizes for which cases are generated.
  sizes = [ "64" "128" "256" "512" "1024" "4096" ];

  # The two cases (vector kernel and ntt_mem kernel) for one size, as
  # name/value pairs; the attribute name and caseName come from one string so
  # they cannot drift apart.
  casesFor = size: [
    {
      name = "ntt_${size}";
      value = mkNttCase "ntt_${size}" ./ntt.c ./ntt_main.c "ntt_${size}" "";
    }
    {
      name = "ntt_mem_${size}";
      value = mkNttCase "ntt_mem_${size}" ./ntt_mem.c ./ntt_main.c "ntt_${size}" "-DUSE_SCALAR";
    }
  ];

in
builtins.listToAttrs (builtins.concatMap casesFor sizes)
118 changes: 118 additions & 0 deletions tests/eval/_ntt/gen_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import sys
import json
import random


def genRandomPoly(l, p):
    """Return a list of 2**l random integers, each drawn from range(p).

    Values come from the module-level `random` generator, one
    `random.randrange(p)` call per coefficient, in order.
    """
    size = 1 << l
    coeffs = []
    for _ in range(size):
        coeffs.append(random.randrange(p))
    return coeffs


def genGoldPoly(l, p, g, poly):
    """Compute the reference transform of `poly` by direct summation.

    For every output index i in range(2**l) the result is
    sum(poly[j] * g**(i*j) for j in range(2**l)) mod p.
    Only the first 2**l entries of `poly` are read.
    """
    size = 1 << l
    return [
        sum(poly[col] * pow(g, row * col, p) for col in range(size)) % p
        for row in range(size)
    ]


def genScalarTW(l, p, g):
    """Return the l successive squarings of g modulo p.

    Entry 0 is g itself (as passed in); entry i is the square of entry i-1
    reduced modulo p, i.e. g**(2**i) mod p.
    """
    factors = [g] * l
    for idx in range(1, l):
        prev = factors[idx - 1]
        factors[idx] = (prev * prev) % p
    return factors


def genVectorTW(l, p, g):
    """Return the per-layer twiddle table as one flat list.

    There is one layer for each m in 2, 4, ..., 2**l. A layer holds n/2
    values (n = 2**l): for j in range(m // 2) the value g**(j * n/m) mod p
    is repeated n/m times. Layers are concatenated in increasing m, so the
    result has l * n/2 entries.
    """
    n = 1 << l
    outTW = []

    m = 2
    while m <= n:
        repeat = n // m  # butterflies sharing one twiddle in this layer
        for j in range(m // 2):
            # The twiddle depends only on j, so compute it once and
            # replicate it instead of re-running pow() for every copy.
            outTW.extend([pow(g, j * repeat, p)] * repeat)
        m *= 2
    return outTW


def main(l, p, g):
    """Write the C header `ntt_<n>.h` for an n = 2**l point transform.

    The scalar and vector twiddle tables are generated here from (l, p, g).
    The input and expected-output vectors are read from the keys "input" and
    "output" of `ntt_<n>.json`. Both files are resolved relative to the
    current working directory.
    """
    n = 1 << l
    scalar_tw = genScalarTW(l, p, g)
    vector_tw = genVectorTW(l, p, g)

    json_name = "ntt_" + str(n) + ".json"
    with open(json_name, "r") as json_in:
        json_data = json.load(json_in)

    def as_csv(values):
        # Comma-separated rendering used as the body of a macro.
        return ",".join(str(e) for e in values)

    # Build every line before opening the output, so a missing JSON key
    # raises without creating or truncating the header file.
    header_lines = [
        "#define macroL " + str(l),
        "#define macroN " + str(n),
        "#define macroP " + str(p),
        "#define macroIn " + as_csv(json_data["input"]),
        "#define macroOut " + as_csv(json_data["output"]),
        "#define macroScalarTW " + as_csv(scalar_tw),
        "#define macroVectorTW " + as_csv(vector_tw),
    ]
    header_str = "".join(line + "\n" for line in header_lines)

    header_file = "ntt_" + str(n) + ".h"
    with open(header_file, "w") as header_out:
        header_out.write(header_str)


if __name__ == "__main__":
    # Usage: gen_header.py <case>, where <case> is one of the keys below.
    if len(sys.argv) != 2:
        raise Exception("No Enough Input Args")

    p = 12289
    # case name -> (l, g) passed to main(l, p, g).
    # An exact-match table is used on purpose: `x in ("ntt_64")` tests
    # against a plain string, so substrings such as "ntt" or "" would match.
    case_params = {
        "ntt_64": (6, 7311),
        "ntt_128": (7, 12149),
        "ntt_256": (8, 8340),
        "ntt_512": (9, 3400),
        "ntt_1024": (10, 10302),
        "ntt_4096": (12, 1331),
    }

    selected = case_params.get(sys.argv[1])
    if selected is None:
        raise Exception(f"Unknown Args {sys.argv[1]}")

    log_n, root = selected
    main(log_n, p, root)
132 changes: 132 additions & 0 deletions tests/eval/_ntt/ntt.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include <assert.h>
#include <stdio.h>

// Vectorized transform of n = 2^l 32-bit coefficients using RISC-V "V"
// inline assembly.
//
//   array   - n input coefficients (only read)
//   l       - log2 of the transform size; must satisfy l <= 16
//   twiddle - per-stage twiddle factors: stage k (0 <= k < l) reads the n/2
//             values starting at twiddle + k * (n / 2), i.e. l * (n / 2)
//             values in total
//   p       - modulus applied with vrem.vx
//   dst     - n output coefficients; also used as scratch between stages
//
// All n coefficients must fit in one LMUL=8 register group:
// 32bit * n <= VLEN * 8 => n <= VLEN / 4
//
// NOTE(review): butterflies use lazy reduction (vadd/vsub results are not
// reduced until the final vrem), so the sign/range of the values left in dst
// is whatever vrem.vx produces -- confirm against the caller's expectations.
//
// NOTE(review): v4 (and other vector registers) carry state from one asm
// statement to the next; this relies on the compiler not using vector
// registers for the C code in between.
void ntt(const int *array, int l, const int *twiddle, int p, int *dst) {
  // prepare an array of permutation indices
  assert(l <= 16);

  int n = 1 << l;

  // registers:
  // v8-v15: array
  // v16-v23: loaded elements (until vrgather)
  // v4-v7: permutation index (until vrgather)
  // v16-v23: coefficients
  int vlenb;
  asm("csrr %0, vlenb" : "=r"(vlenb));
  // an LMUL=8 group holds 8 * vlenb bytes = 2 * vlenb 32-bit elements
  int elements_in_vreg = vlenb * 2;
  assert(elements_in_vreg >= n);

  asm("vsetvli zero, %0, e16, m4, tu, mu\n"
      "vid.v v4\n"
      :
      : "r"(n));

  // prepare the bit-reversal permutation list:
  // each iteration swaps bit k with bit (l-k-1) of every index in v4
  for (int k = 0; 2 * k < l; k++) {
    asm("vand.vx v8, v4, %0\n"
        "vsub.vv v4, v4, v8\n"
        "vsll.vx v8, v8, %1\n" // get the k-th digit and shift left

        "vand.vx v12, v4, %2\n"
        "vsub.vv v4, v4, v12\n"
        "vsrl.vx v12, v12, %1\n" // get the (l-k-1)-th digit and shift right

        "vor.vv v4, v4, v8\n"
        "vor.vv v4, v4, v12\n"

        :
        : "r"(1 << k), "r"(l - 1 - 2 * k), "r"(1 << (l - k - 1)));
  }

  // perform bit-reversal for input coefficients
  // ("memory": the asm reads *array and writes *dst behind the compiler's back)
  asm("vsetvli zero, %0, e32, m8, tu, mu\n"
      "vle32.v v16, 0(%1)\n"
      "vrgatherei16.vv v8, v16, v4\n"
      "vse32.v v8, 0(%2)\n"

      :
      : "r"(n), "r"(array), "r"(dst)
      : "memory");

  // generate permutation list (0, 2, 4, ..., 1, 3, 5, ...)
  asm("vsetvli zero, %0, e16, m4, tu, mu\n"
      "vid.v v4\n"
      "vsrl.vx v0, v4, %1\n" // (0, 0, 0, 0, ..., 1, 1, 1, 1, ...)
      "vand.vx v4, v4, %2\n" // (0, 1, 2, 3, ..., 0, 1, 2, 3, ...)
      "vsll.vi v4, v4, 1\n"
      "vadd.vv v4, v4, v0\n"

      :
      : "r"(n), "r"(l - 1), "r"((n / 2 - 1)));

#ifdef DEBUG
  int tmp1[USERN]; // c
  int tmp2[USERN]; // c
  int tmp3[USERN]; // c
#endif

  for (int k = 0; k < l; k++) {
    asm(
        // "n" mode
        "vsetvli zero, %0, e32, m8, tu, mu\n"
        // load coefficients
        "vle32.v v16, 0(%4)\n"
        // perform permutation for coefficient
        "vrgatherei16.vv v8, v16, v4\n"
        // save coefficients
        "vse32.v v8, 0(%4)\n"

        // "n/2" mode
        "vsetvli zero, %1, e32, m4, tu, mu\n"
        // load twiddle factors
        "vle32.v v16, 0(%2)\n"
        // load half coefficients
        "vle32.v v8, 0(%4)\n"
        "vle32.v v12, 0(%5)\n"

#ifdef DEBUG
        "vse32.v v8, 0(%6)\n"
        "vse32.v v12, 0(%7)\n"
        "vse32.v v16, 0(%8)\n"
#endif

        // butterfly operation
        "vmul.vv v12, v12, v16\n"
        "vrem.vx v12, v12, %3\n"
        "vadd.vv v16, v8, v12\n" // NOTE: use lazy reduction here
        "vsub.vv v20, v8, v12\n"
        // save half coefficients
        "vse32.v v16, 0(%4)\n"
        "vse32.v v20, 0(%5)\n"
        :
        : /* %0 */ "r"(n),
          /* %1 */ "r"(n / 2),
          /* %2 */ "r"(twiddle + k * (n / 2)),
          /* %3 */ "r"(p),
          /* %4 */ "r"(dst),
          /* %5 */ "r"(dst + (n / 2))
#ifdef DEBUG
          ,
          /* %6 */ "r"(tmp1), /* %7 */ "r"(tmp2), /* %8 */ "r"(tmp3)
#endif
        // the asm reads twiddle/dst and writes dst (and tmp1..3 under DEBUG)
        : "memory");
#ifdef DEBUG
    for (int i = 0; i < USERN; i++) {
      printf("(%x, %x, %x)\n", tmp1[i], tmp2[i], tmp3[i]);
    }
#endif
  }
  // deal with modular
  asm("vsetvli zero, %0, e32, m8, tu, mu\n"
      "vle32.v v16, 0(%1)\n"
      "vrem.vx v8, v16, %2\n"
      "vse32.v v8, 0(%1)\n"

      :
      : "r"(n), "r"(dst), "r"(p)
      : "memory");
}
Loading
Loading