Skip to content

Commit

Permalink
Make to_eval_tree fallible
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jul 15, 2024
1 parent 347d9d5 commit 28c09c7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 45 deletions.
3 changes: 2 additions & 1 deletion examples/nested_evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ fn main() {
|r| r.clone(),
&fn_map,
&params,
);
)
.unwrap();

// optimize the tree using an occurrence-order Horner scheme
println!("Op original {:?}", tree.count_operations());
Expand Down
90 changes: 46 additions & 44 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,7 @@ impl<'a> AtomView<'a> {
coeff_map: F,
fn_map: &FunctionMap<'a, T>,
params: &[Atom],
) -> EvalTree<T> {
) -> Result<EvalTree<T>, String> {
Self::to_eval_tree_multiple(std::slice::from_ref(self), coeff_map, fn_map, params)
}

Expand All @@ -1699,20 +1699,20 @@ impl<'a> AtomView<'a> {
coeff_map: F,
fn_map: &FunctionMap<'a, T>,
params: &[Atom],
) -> EvalTree<T> {
) -> Result<EvalTree<T>, String> {
let mut funcs = vec![];
let tree = exprs
.iter()
.map(|t| t.to_eval_tree_impl(coeff_map, fn_map, params, &[], &mut funcs))
.collect();
.collect::<Result<_, _>>()?;

EvalTree {
Ok(EvalTree {
expressions: SplitExpression {
tree,
subexpressions: vec![],
},
functions: funcs,
}
})
}

fn to_eval_tree_impl<T: Clone + Default, F: Fn(&Rational) -> T + Copy>(
Expand All @@ -1722,27 +1722,27 @@ impl<'a> AtomView<'a> {
params: &[Atom],
args: &[Symbol],
funcs: &mut Vec<(String, Vec<Symbol>, SplitExpression<T>)>,
) -> Expression<T> {
) -> Result<Expression<T>, String> {
if let Some(p) = params.iter().position(|a| a.as_view() == *self) {
return Expression::Parameter(p);
return Ok(Expression::Parameter(p));
}

if let Some(c) = fn_map.get(*self) {
return match c {
ConstOrExpr::Const(c) => Expression::Const(c.clone()),
ConstOrExpr::Const(c) => Ok(Expression::Const(c.clone())),
ConstOrExpr::Expr(name, tag_len, args, v) => {
if args.len() != *tag_len {
panic!(
return Err(format!(
"Function {} called with wrong number of arguments: 0 vs {}",
self,
args.len()
);
));
}

if let Some(pos) = funcs.iter().position(|f| f.0 == *name) {
Expression::Eval(pos, vec![])
Ok(Expression::Eval(pos, vec![]))
} else {
let r = v.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs);
let r = v.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs)?;
funcs.push((
name.clone(),
args.clone(),
Expand All @@ -1751,7 +1751,7 @@ impl<'a> AtomView<'a> {
subexpressions: vec![],
},
));
Expression::Eval(funcs.len() - 1, vec![])
Ok(Expression::Eval(funcs.len() - 1, vec![]))
}
}
};
Expand All @@ -1760,61 +1760,62 @@ impl<'a> AtomView<'a> {
match self {
AtomView::Num(n) => match n.get_coeff_view() {
CoefficientView::Natural(n, d) => {
Expression::Const(coeff_map(&Rational::Natural(n, d)))
Ok(Expression::Const(coeff_map(&Rational::Natural(n, d))))
}
CoefficientView::Large(l) => {
Expression::Const(coeff_map(&Rational::Large(l.to_rat())))
Ok(Expression::Const(coeff_map(&Rational::Large(l.to_rat()))))
}
CoefficientView::Float(f) => {
// TODO: converting back to rational is slow
Expression::Const(coeff_map(&f.to_float().to_rational()))
Ok(Expression::Const(coeff_map(&f.to_float().to_rational())))
}
CoefficientView::FiniteField(_, _) => {
unimplemented!("Finite field not yet supported for evaluation")
}
CoefficientView::RationalPolynomial(_) => {
unimplemented!(
"Rational polynomial coefficient not yet supported for evaluation"
)
Err("Finite field not yet supported for evaluation".to_string())
}
CoefficientView::RationalPolynomial(_) => Err(
"Rational polynomial coefficient not yet supported for evaluation".to_string(),
),
},
AtomView::Var(v) => {
let name = v.get_symbol();

if let Some(p) = args.iter().position(|s| *s == name) {
return Expression::ReadArg(p);
return Ok(Expression::ReadArg(p));
}

panic!(
Err(format!(
"Variable {} not in constant map",
State::get_name(v.get_symbol())
);
))
}
AtomView::Fun(f) => {
let name = f.get_symbol();
if [State::EXP, State::LOG, State::SIN, State::COS, State::SQRT].contains(&name) {
assert!(f.get_nargs() == 1);
let arg = f.iter().next().unwrap();
let arg_eval = arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs);
let arg_eval = arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs)?;

return Expression::BuiltinFun(f.get_symbol(), Box::new(arg_eval));
return Ok(Expression::BuiltinFun(f.get_symbol(), Box::new(arg_eval)));
}

let symb = InlineVar::new(f.get_symbol());
let Some(fun) = fn_map.get(symb.as_view()) else {
panic!("Undefined function {}", State::get_name(f.get_symbol()));
return Err(format!(
"Undefined function {}",
State::get_name(f.get_symbol())
));
};

match fun {
ConstOrExpr::Const(t) => Expression::Const(t.clone()),
ConstOrExpr::Const(t) => Ok(Expression::Const(t.clone())),
ConstOrExpr::Expr(name, tag_len, arg_spec, e) => {
if f.get_nargs() != arg_spec.len() + *tag_len {
panic!(
return Err(format!(
"Function {} called with wrong number of arguments: {} vs {}",
f.get_symbol(),
f.get_nargs(),
arg_spec.len() + *tag_len
);
));
}

let eval_args = f
Expand All @@ -1823,12 +1824,13 @@ impl<'a> AtomView<'a> {
.map(|arg| {
arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs)
})
.collect();
.collect::<Result<_, _>>()?;

if let Some(pos) = funcs.iter().position(|f| f.0 == *name) {
Expression::Eval(pos, eval_args)
Ok(Expression::Eval(pos, eval_args))
} else {
let r = e.to_eval_tree_impl(coeff_map, fn_map, params, arg_spec, funcs);
let r =
e.to_eval_tree_impl(coeff_map, fn_map, params, arg_spec, funcs)?;
funcs.push((
name.clone(),
arg_spec.clone(),
Expand All @@ -1837,49 +1839,49 @@ impl<'a> AtomView<'a> {
subexpressions: vec![],
},
));
Expression::Eval(funcs.len() - 1, eval_args)
Ok(Expression::Eval(funcs.len() - 1, eval_args))
}
}
}
}
AtomView::Pow(p) => {
let (b, e) = p.get_base_exp();
let b_eval = b.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs);
let b_eval = b.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs)?;

if let AtomView::Num(n) = e {
if let CoefficientView::Natural(num, den) = n.get_coeff_view() {
if den == 1 {
if num > 1 {
return Expression::Mul(vec![b_eval.clone(); num as usize]);
return Ok(Expression::Mul(vec![b_eval.clone(); num as usize]));
}
return Expression::Pow(Box::new((b_eval, num)));
return Ok(Expression::Pow(Box::new((b_eval, num))));
}
}
}

let e_eval = e.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs);
Expression::Powf(Box::new((b_eval, e_eval)))
let e_eval = e.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs)?;
Ok(Expression::Powf(Box::new((b_eval, e_eval))))
}
AtomView::Mul(m) => {
let mut muls = vec![];
for arg in m.iter() {
let a = arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs);
let a = arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs)?;
if let Expression::Mul(m) = a {
muls.extend(m);
} else {
muls.push(a);
}
}

Expression::Mul(muls)
Ok(Expression::Mul(muls))
}
AtomView::Add(a) => {
let mut adds = vec![];
for arg in a.iter() {
adds.push(arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs));
adds.push(arg.to_eval_tree_impl(coeff_map, fn_map, params, args, funcs)?);
}

Expression::Add(adds)
Ok(Expression::Add(adds))
}
}
}
Expand Down

0 comments on commit 28c09c7

Please sign in to comment.