-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathrun_cusadi_function_test.py
77 lines (56 loc) · 2.73 KB
/
run_cusadi_function_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import argparse
import torch
from casadi import *
from src import *
def main(args):
    """Evaluate a saved CasADi function with Cusadi and compare it against CasADi itself.

    Loads ``<args.fn_name>.casadi`` from ``CUSADI_FUNCTION_DIR`` and builds one
    random double-precision CUDA batch of shape ``(args.n_envs, f.nnz_in(i))``
    per function input. It evaluates the batch once through a ``CusadiFunction``
    and prints the wall-clock time of that call. It then recomputes every
    environment serially with ``casadi.Function.call`` and prints, per output,
    the 2-norm of the (flattened) difference divided by ``args.n_envs``.

    Args:
        args: parsed CLI namespace; ``fn_name`` (str) and ``n_envs`` (int) are read.
    """
    import time

    fn_filepath = os.path.join(CUSADI_FUNCTION_DIR, f"{args.fn_name}.casadi")
    f = casadi.Function.load(fn_filepath)
    print("Evaluating function:", f.name())
    print("Function has %d arguments" % f.n_in())
    print("Function has %d outputs" % f.n_out())

    # One (n_envs, nnz_in(i)) random batch per function input.
    input_tensors = [
        torch.rand(args.n_envs, f.nnz_in(i), device='cuda', dtype=torch.double).contiguous()
        for i in range(f.n_in())
    ]
    fn_cusadi = CusadiFunction(f, args.n_envs)

    # CUDA work is queued asynchronously: drain everything issued so far
    # (e.g. the torch.rand kernels) so the timed window covers only evaluate().
    torch.cuda.synchronize()
    start = time.perf_counter_ns()
    fn_cusadi.evaluate(input_tensors)
    torch.cuda.synchronize()
    end = time.perf_counter_ns()
    print(f"Time taken to evaluate {args.n_envs} environments: {(end-start)/1e9:.6f} seconds")
    print("Time eval out: ", fn_cusadi.eval_time)

    # Reference results from CasADi, computed one environment at a time.
    # Inputs are copied to host once, rather than one device->host copy per row.
    inputs_host = [tensor.cpu().numpy() for tensor in input_tensors]
    output_numpy = [numpy.zeros((args.n_envs, f.nnz_out(i))) for i in range(f.n_out())]
    for n in range(args.n_envs):
        inputs_np = [host[n, :] for host in inputs_host]
        # Evaluate f once per environment and reuse the result for every output.
        outputs_ref = f.call(inputs_np)
        for i in range(f.n_out()):
            output_numpy[i][n, :] = outputs_ref[i].nonzeros()

    print(f"Evaluating with {args.n_envs} environments.")
    print("Average error for each environment:")
    for i in range(f.n_out()):
        error_norm = numpy.linalg.norm(fn_cusadi.outputs_sparse[i].cpu().numpy() - output_numpy[i])/args.n_envs
        print(f"Output {i} error norm:", error_norm)
def printParserArguments(parser, args):
    """Print every parser option as a table: flags, help text, default, current value.

    Args:
        parser: the ``argparse.ArgumentParser`` whose actions are listed
            (the built-in ``help`` action is skipped).
        args: the namespace returned by ``parser.parse_args()``; if an action's
            ``dest`` is absent from it, the default is shown as the current value.
    """
    # Print out all arguments, descriptions, and default values in a formatted manner
    print(f"\n{'Argument':<10} {'Description':<80} {'Default':<10} {'Current Value':<10}")
    print("=" * 120)
    # parser._actions is argparse-private; argparse offers no public accessor
    # for the list of registered actions.
    for action in parser._actions:
        if action.dest == 'help':
            continue
        # Positional arguments have no option strings; show their dest instead
        # of leaving the column blank.
        arg_strings = ', '.join(action.option_strings) or action.dest
        description = action.help or 'No description'
        default = action.default if action.default is not argparse.SUPPRESS else 'No default'
        current_value = getattr(args, action.dest, default)
        # str() first: a ':<10' spec raises TypeError for None/list values and
        # renders bools as 1/0.
        print(f"{arg_strings:<10} {description:<80} {str(default):<10} {str(current_value):<10}")
    print()
def setupParser():
    """Create the command-line parser for this evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` with two options:
        ``--fn`` -> ``fn_name`` (str, default ``'test'``) and
        ``--num_envs`` -> ``n_envs`` (int, default ``4000``).
    """
    cli = argparse.ArgumentParser(
        description='Script to evaluate Cusadi function and check error')
    option_table = (
        ('--fn', dict(
            type=str, dest='fn_name', default='test',
            help='Function name in cusadi/casadi_functions, defaults to "test"')),
        ('--num_envs', dict(
            type=int, dest='n_envs', default=4000,
            help='Number of instances to evaluate in parallel, default to 4000')),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
if __name__ == "__main__":
    # Script entry: parse CLI options, echo them as a table, then run the check.
    cli_parser = setupParser()
    cli_args = cli_parser.parse_args()
    printParserArguments(cli_parser, cli_args)
    main(cli_args)