Skip to content

Commit

Permalink
Add export option for linearized evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jul 16, 2024
1 parent 28c09c7 commit 13d6bff
Showing 1 changed file with 101 additions and 3 deletions.
104 changes: 101 additions & 3 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ pub enum Expression<T> {

pub struct ExpressionEvaluator<T> {
stack: Vec<T>,
param_count: usize,
reserved_indices: usize,
instructions: Vec<Instr>,
result_indices: Vec<usize>,
Expand Down Expand Up @@ -323,6 +324,101 @@ impl<T> ExpressionEvaluator<T> {
}
}

impl<T: std::fmt::Display> ExpressionEvaluator<T> {
pub fn export_cpp(&self) -> String {
let mut res = String::new();
res += &format!("\ntemplate<typename T>\nvoid eval(T* params, T* out) {{\n");

res += &format!(
"\tT {};\n",
(0..self.stack.len())
.map(|x| format!("Z{}", x))
.collect::<Vec<_>>()
.join(", ")
);

for i in 0..self.param_count {
res += &format!("\tZ{} = params[{}];\n", i, i);
}

for i in self.param_count..self.reserved_indices {
res += &format!("\tZ{} = {};\n", i, self.stack[i]);
}

Self::export_cpp_impl(&self.instructions, &mut res);

for (i, r) in &mut self.result_indices.iter().enumerate() {
res += &format!("\tout[{}] = Z{};\n", i, r);
}

res += "\treturn;\n}\n";

res += "\nextern \"C\" {\n\tvoid eval_double(double* params, double* out) {\n\t\t eval(params, out);\n\t}\n}\n";
res += "\nextern \"C\" {\n\tvoid eval_complex(std::complex<double>* params, std::complex<double>* out) {\n\t\t eval(params, out);\n\t}\n}\n";

let header = "#include <iostream>\n#include <complex>\n#include <cmath>\n\n";

header.to_string() + res.as_str()
}

fn export_cpp_impl(instr: &[Instr], out: &mut String) {
for ins in instr {
match ins {
Instr::Add(o, a) => {
let args = a
.iter()
.map(|x| format!("Z{}", x))
.collect::<Vec<_>>()
.join("+");

*out += format!("\tZ{} = {};\n", o, args).as_str();
}
Instr::Mul(o, a) => {
let args = a
.iter()
.map(|x| format!("Z{}", x))
.collect::<Vec<_>>()
.join("*");

*out += format!("\tZ{} = {};\n", o, args).as_str();
}
Instr::Pow(o, b, e) => {
let base = format!("Z{}", b);
*out += format!("\tZ{} = pow({}, {});\n", o, base, e).as_str();
}
Instr::Powf(o, b, e) => {
let base = format!("Z{}", b);
let exp = format!("Z{}", e);
*out += format!("\tZ{} = pow({}, {});\n", o, base, exp).as_str();
}
Instr::BuiltinFun(o, s, a) => match *s {
State::EXP => {
let arg = format!("Z{}", a);
*out += format!("\tZ{} = exp({});\n", o, arg).as_str();
}
State::LOG => {
let arg = format!("Z{}", a);
*out += format!("\tZ{} = log({});\n", o, arg).as_str();
}
State::SIN => {
let arg = format!("Z{}", a);
*out += format!("\tZ{} = sin({});\n", o, arg).as_str();
}
State::COS => {
let arg = format!("Z{}", a);
*out += format!("\tZ{} = cos({});\n", o, arg).as_str();
}
State::SQRT => {
let arg = format!("Z{}", a);
*out += format!("\tZ{} = sqrt({});\n", o, arg).as_str();
}
_ => unreachable!(),
},
}
}
}
}

#[derive(Debug)]
enum Instr {
Add(usize, Vec<usize>),
Expand Down Expand Up @@ -434,11 +530,11 @@ impl<T: Clone + Default + PartialEq> EvalTree<T> {
}

/// Create a linear version of the tree that can be evaluated more efficiently.
pub fn linearize(mut self, param_len: usize) -> ExpressionEvaluator<T> {
let mut stack = vec![T::default(); param_len];
pub fn linearize(mut self, param_count: usize) -> ExpressionEvaluator<T> {
let mut stack = vec![T::default(); param_count];

// strip every constant and move them into the stack after the params
self.strip_constants(&mut stack, param_len);
self.strip_constants(&mut stack, param_count);
let reserved_indices = stack.len();

let mut sub_expr_pos = HashMap::default();
Expand All @@ -460,6 +556,7 @@ impl<T: Clone + Default + PartialEq> EvalTree<T> {

let mut e = ExpressionEvaluator {
stack,
param_count,
reserved_indices,
instructions,
result_indices,
Expand Down Expand Up @@ -1478,6 +1575,7 @@ pub struct CompileOptions {
}

impl Default for CompileOptions {
/// Default compile options: `g++ -O3 -ffast-math -funsafe-math-optimizations`.
fn default() -> Self {
CompileOptions {
optimization_level: 3,
Expand Down

0 comments on commit 13d6bff

Please sign in to comment.