From fea7a041fdf82c883de2d78674b95be494cd8adb Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Tue, 10 Dec 2024 17:00:42 +0100 Subject: [PATCH] Redesign error propagating floats - Use absolute error so that zeros are supported - Add i() function for floats - Add automatic evaluation of complex i - Move EXP, E, etc from State to Atom --- src/api/python.rs | 16 +- src/atom.rs | 52 ++++++- src/coefficient.rs | 4 +- src/derivative.rs | 34 ++--- src/domains/dual.rs | 7 + src/domains/float.rs | 354 ++++++++++++++++++++++++++++++------------- src/evaluate.rs | 81 +++++----- src/id.rs | 6 +- src/normalize.rs | 38 ++--- src/poly/evaluate.rs | 5 +- src/poly/series.rs | 17 +-- src/printer.rs | 6 +- src/state.rs | 24 +-- src/transformer.rs | 30 ++-- 14 files changed, 438 insertions(+), 236 deletions(-) diff --git a/src/api/python.rs b/src/api/python.rs index 920ff00b..816043a5 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -2656,14 +2656,14 @@ impl PythonExpression { #[classattr] #[pyo3(name = "E")] pub fn e() -> PythonExpression { - Atom::new_var(State::E).into() + Atom::new_var(Atom::E).into() } /// The mathematical constant `π`. #[classattr] #[pyo3(name = "PI")] pub fn pi() -> PythonExpression { - Atom::new_var(State::PI).into() + Atom::new_var(Atom::PI).into() } /// The mathematical constant `i`, where @@ -2671,42 +2671,42 @@ impl PythonExpression { #[classattr] #[pyo3(name = "I")] pub fn i() -> PythonExpression { - Atom::new_var(State::I).into() + Atom::new_var(Atom::I).into() } /// The built-in function that converts a rational polynomial to a coefficient. #[classattr] #[pyo3(name = "COEFF")] pub fn coeff() -> PythonExpression { - Atom::new_var(State::COEFF).into() + Atom::new_var(Atom::COEFF).into() } /// The built-in cosine function. #[classattr] #[pyo3(name = "COS")] pub fn cos() -> PythonExpression { - Atom::new_var(State::COS).into() + Atom::new_var(Atom::COS).into() } /// The built-in sine function. #[classattr] #[pyo3(name = "SIN")] pub fn sin() -> PythonExpression { - Atom::new_var(State::SIN).into() + Atom::new_var(Atom::SIN).into() } /// The built-in exponential function. #[classattr] #[pyo3(name = "EXP")] pub fn exp() -> PythonExpression { - Atom::new_var(State::EXP).into() + Atom::new_var(Atom::EXP).into() } /// The built-in logarithm function. #[classattr] #[pyo3(name = "LOG")] pub fn log() -> PythonExpression { - Atom::new_var(State::LOG).into() + Atom::new_var(Atom::LOG).into() } /// Return all defined symbol names (function names and variables). diff --git a/src/atom.rs b/src/atom.rs index 40cc9542..10bb9ebe 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -7,7 +7,7 @@ use crate::{ coefficient::Coefficient, parser::Token, printer::{AtomPrinter, PrintOptions}, - state::{RecycledAtom, Workspace}, + state::{RecycledAtom, State, Workspace}, transformer::StatsOptions, }; use std::{cmp::Ordering, hash::Hash, ops::DerefMut, str::FromStr}; @@ -511,6 +511,56 @@ pub enum Atom { Zero, } +impl Atom { + /// The built-in function represents a list of function arguments. + pub const ARG: Symbol = State::ARG; + /// The built-in function that converts a rational polynomial to a coefficient. + pub const COEFF: Symbol = State::COEFF; + /// The exponent function. + pub const EXP: Symbol = State::EXP; + /// The logarithm function. + pub const LOG: Symbol = State::LOG; + /// The sine function. + pub const SIN: Symbol = State::SIN; + /// The cosine function. + pub const COS: Symbol = State::COS; + /// The square root function. + pub const SQRT: Symbol = State::SQRT; + /// The built-in function that represents an abstract derivative. + pub const DERIVATIVE: Symbol = State::DERIVATIVE; + /// The constant e, the base of the natural logarithm. + pub const E: Symbol = State::E; + /// The constant i, the imaginary unit. + pub const I: Symbol = State::I; + /// The mathematical constant `π`. + pub const PI: Symbol = State::PI; + + /// Exponentiate the atom. + pub fn exp(&self) -> Atom { + FunctionBuilder::new(Atom::EXP).add_arg(self).finish() + } + + /// Take the logarithm of the atom. + pub fn log(&self) -> Atom { + FunctionBuilder::new(Atom::LOG).add_arg(self).finish() + } + + /// Take the sine the atom. + pub fn sin(&self) -> Atom { + FunctionBuilder::new(Atom::SIN).add_arg(self).finish() + } + + /// Take the cosine the atom. + pub fn cos(&self) -> Atom { + FunctionBuilder::new(Atom::COS).add_arg(self).finish() + } + + /// Take the square root of the atom. + pub fn sqrt(&self) -> Atom { + FunctionBuilder::new(Atom::SQRT).add_arg(self).finish() + } +} + impl Default for Atom { /// Create an atom that represents the number 0. #[inline] diff --git a/src/coefficient.rs b/src/coefficient.rs index efea5884..cfb2a331 100644 --- a/src/coefficient.rs +++ b/src/coefficient.rs @@ -1365,13 +1365,13 @@ impl<'a> AtomView<'a> { let s = v.get_symbol(); match s { - State::PI => { + Atom::PI => { out.to_num(Coefficient::Float(Float::with_val( binary_prec, rug::float::Constant::Pi, ))); } - State::E => { + Atom::E => { out.to_num(Coefficient::Float(Float::with_val(binary_prec, 1).exp())); } _ => { diff --git a/src/derivative.rs b/src/derivative.rs index 9080066e..6d8df1b1 100644 --- a/src/derivative.rs +++ b/src/derivative.rs @@ -9,7 +9,7 @@ use crate::{ combinatorics::CombinationWithReplacementIterator, domains::{atom::AtomField, integer::Integer, rational::Rational}, poly::{series::Series, Variable}, - state::{State, Workspace}, + state::Workspace, }; impl Atom { @@ -83,7 +83,7 @@ impl<'a> AtomView<'a> { // detect if the function to derive is the derivative function itself // if so, derive the last argument of the derivative function and set // a flag to later accumulate previous derivatives - let (to_derive, f, is_der) = if f_orig.get_symbol() == State::DERIVATIVE { + let (to_derive, f, is_der) = if f_orig.get_symbol() == Atom::DERIVATIVE { let to_derive = f_orig.iter().last().unwrap(); ( to_derive, @@ -113,29 +113,29 @@ impl<'a> AtomView<'a> { // derive special functions if f.get_nargs() == 1 - && [State::EXP, State::LOG, State::SIN, State::COS].contains(&f.get_symbol()) + && [Atom::EXP, Atom::LOG, Atom::SIN, Atom::COS].contains(&f.get_symbol()) { let mut fn_der = workspace.new_atom(); match f.get_symbol() { - State::EXP => { + Atom::EXP => { fn_der.set_from_view(self); } - State::LOG => { + Atom::LOG => { let mut n = workspace.new_atom(); n.to_num((-1).into()); fn_der.to_pow(f.iter().next().unwrap(), n.as_view()); } - State::SIN => { - let p = fn_der.to_fun(State::COS); + Atom::SIN => { + let p = fn_der.to_fun(Atom::COS); p.add_arg(f.iter().next().unwrap()); } - State::COS => { + Atom::COS => { let mut n = workspace.new_atom(); n.to_num((-1).into()); let mut sin = workspace.new_atom(); - let sin_fun = sin.to_fun(State::SIN); + let sin_fun = sin.to_fun(Atom::SIN); sin_fun.add_arg(f.iter().next().unwrap()); let m = fn_der.to_mul(); @@ -167,7 +167,7 @@ impl<'a> AtomView<'a> { let mut n = workspace.new_atom(); let mut mul = workspace.new_atom(); for (index, arg_der) in args_der { - let p = fn_der.to_fun(State::DERIVATIVE); + let p = fn_der.to_fun(Atom::DERIVATIVE); if is_der { for (i, x_orig) in f_orig.iter().take(f.get_nargs()).enumerate() { @@ -218,7 +218,7 @@ impl<'a> AtomView<'a> { if exp_der_non_zero { // create log(base) let mut log_base = workspace.new_atom(); - let lb = log_base.to_fun(State::LOG); + let lb = log_base.to_fun(Atom::LOG); lb.add_arg(base); if let Atom::Mul(m) = exp_der.deref_mut() { @@ -418,11 +418,11 @@ impl<'a> AtomView<'a> { } match f.get_symbol() { - State::COS => args_series[0].cos(), - State::SIN => args_series[0].sin(), - State::EXP => args_series[0].exp(), - State::LOG => args_series[0].log(), - State::SQRT => args_series[0].rpow((1, 2).into()), + Atom::COS => args_series[0].cos(), + Atom::SIN => args_series[0].sin(), + Atom::EXP => args_series[0].exp(), + Atom::LOG => args_series[0].log(), + Atom::SQRT => args_series[0].rpow((1, 2).into()), _ => { // TODO: also check for log(x)? if args_series @@ -461,7 +461,7 @@ impl<'a> AtomView<'a> { CombinationWithReplacementIterator::new(args_series.len(), i); while let Some(x) = it.next() { - let mut f_der = FunctionBuilder::new(State::DERIVATIVE); + let mut f_der = FunctionBuilder::new(Atom::DERIVATIVE); let mut term = info.one(); for (arg, pow) in x.iter().enumerate() { if *pow > 0 { diff --git a/src/domains/dual.rs b/src/domains/dual.rs index 15491dfe..f023f157 100644 --- a/src/domains/dual.rs +++ b/src/domains/dual.rs @@ -759,6 +759,13 @@ macro_rules! create_hyperdual_from_components { res } + #[inline(always)] + fn i(&self) -> Option { + let mut res = self.zero(); + res.values[0] = self.values[0].i()?; + Some(res) + } + #[inline(always)] fn norm(&self) -> Self { let n = self.values[0].norm(); diff --git a/src/domains/float.rs b/src/domains/float.rs index 5efb8c55..d7fb1e67 100644 --- a/src/domains/float.rs +++ b/src/domains/float.rs @@ -363,6 +363,8 @@ pub trait Real: NumericalFloatLike { fn euler(&self) -> Self; /// The golden ratio, 1.6180339887... fn phi(&self) -> Self; + /// The imaginary unit, if it exists. + fn i(&self) -> Option; fn norm(&self) -> Self; fn sqrt(&self) -> Self; @@ -535,6 +537,11 @@ impl Real for f64 { 1.6180339887498948 } + #[inline(always)] + fn i(&self) -> Option { + None + } + #[inline(always)] fn norm(&self) -> Self { f64::abs(*self) @@ -928,6 +935,11 @@ impl Real for F64 { 1.6180339887498948.into() } + #[inline(always)] + fn i(&self) -> Option { + None + } + #[inline(always)] fn norm(&self) -> Self { self.0.norm().into() @@ -1821,6 +1833,11 @@ impl Real for Float { (self.one() + self.from_i64(5).sqrt()) / 2 } + #[inline(always)] + fn i(&self) -> Option { + None + } + #[inline(always)] fn norm(&self) -> Self { self.0.clone().abs().into() @@ -1973,7 +1990,7 @@ impl Rational { #[derive(Copy, Clone)] pub struct ErrorPropagatingFloat { value: T, - prec: f64, + abs_err: f64, } impl Neg for ErrorPropagatingFloat { @@ -1983,7 +2000,7 @@ impl Neg for ErrorPropagatingFloat { fn neg(self) -> Self::Output { ErrorPropagatingFloat { value: -self.value, - prec: self.prec, + abs_err: self.abs_err, } } } @@ -1993,13 +2010,9 @@ impl Add<&ErrorPropagatingFloat> for ErrorPropagatingFloat #[inline] fn add(self, rhs: &Self) -> Self::Output { - // TODO: handle r = 0 - let r = self.value.clone() + &rhs.value; ErrorPropagatingFloat { - prec: (self.get_num().to_f64().abs() * self.prec - + rhs.get_num().to_f64().abs() * rhs.prec) - / r.clone().to_f64().abs(), - value: r, + abs_err: self.abs_err + rhs.abs_err, + value: self.value + &rhs.value, } } } @@ -2036,9 +2049,20 @@ impl Mul<&ErrorPropagatingFloat> for ErrorPropagatingFloat #[inline] fn mul(self, rhs: &Self) -> Self::Output { - ErrorPropagatingFloat { - value: self.value.clone() * &rhs.value, - prec: self.prec + rhs.prec, + let value = self.value.clone() * &rhs.value; + let r = rhs.value.to_f64().abs(); + let s = self.value.to_f64().abs(); + + if s == 0. && r == 0. { + return ErrorPropagatingFloat { + value, + abs_err: self.abs_err * rhs.abs_err, + }; + } else { + ErrorPropagatingFloat { + value, + abs_err: self.abs_err * r + rhs.abs_err * s, + } } } } @@ -2048,10 +2072,10 @@ impl> Add for ErrorPropa #[inline] fn add(self, rhs: Rational) -> Self::Output { - let v = self.value.to_f64(); - let prec = self.prec * v.abs() / (v + rhs.to_f64()).abs(); - let r = self.value + rhs; - ErrorPropagatingFloat { prec, value: r }.truncate() + ErrorPropagatingFloat { + abs_err: self.abs_err, + value: self.value + rhs, + } } } @@ -2070,9 +2094,10 @@ impl> Mul for ErrorPropa #[inline] fn mul(self, rhs: Rational) -> Self::Output { ErrorPropagatingFloat { + abs_err: self.abs_err * rhs.to_f64().abs(), value: self.value * rhs, - prec: self.prec, } + .truncate() } } @@ -2082,9 +2107,10 @@ impl> Div for ErrorPropa #[inline] fn div(self, rhs: Rational) -> Self::Output { ErrorPropagatingFloat { + abs_err: self.abs_err * rhs.inv().to_f64().abs(), value: self.value.clone() / rhs, - prec: self.prec, } + .truncate() } } @@ -2102,10 +2128,7 @@ impl Div<&ErrorPropagatingFloat> for ErrorPropagatingFloat #[inline] fn div(self, rhs: &Self) -> Self::Output { - ErrorPropagatingFloat { - value: self.value.clone() / &rhs.value, - prec: self.prec + rhs.prec, - } + self * rhs.inv() } } @@ -2178,26 +2201,51 @@ impl DivAssign> for ErrorPropagating } } -impl ErrorPropagatingFloat { +impl ErrorPropagatingFloat { /// Create a new precision tracking float with a number of precise decimal digits `prec`. /// The `prec` must be smaller than the precision of the underlying float. + /// + /// If the value provided is 0, the precision argument is interpreted as an accuracy ( + /// the number of digits of the absolute error). pub fn new(value: T, prec: f64) -> Self { - ErrorPropagatingFloat { - value, - prec: 10f64.pow(-prec), + let r = value.to_f64().abs(); + + if r == 0. { + ErrorPropagatingFloat { + abs_err: 10f64.pow(-prec), + value, + } + } else { + ErrorPropagatingFloat { + abs_err: 10f64.pow(-prec) * r, + value, + } } } - /// Get the number. - #[inline(always)] - pub fn get_num(&self) -> &T { - &self.value + pub fn get_absolute_error(&self) -> f64 { + self.abs_err + } + + pub fn get_relative_error(&self) -> f64 { + self.abs_err / self.value.to_f64().abs() } /// Get the precision in number of decimal digits. #[inline(always)] - pub fn get_precision(&self) -> f64 { - -self.prec.log10() + pub fn get_precision(&self) -> Option { + let r = self.value.to_f64().abs(); + if r == 0. { + return None; + } else { + Some(-(self.abs_err / r).log10()) + } + } + + /// Get the accuracy in number of decimal digits. + #[inline(always)] + pub fn get_accuracy(&self) -> f64 { + -self.abs_err.log10() } /// Truncate the precision to the maximal number of stable decimal digits @@ -2205,36 +2253,56 @@ impl ErrorPropagatingFloat { #[inline(always)] pub fn truncate(mut self) -> Self { if self.value.fixed_precision() { - self.prec = self.prec.max(self.value.get_epsilon()); + self.abs_err = self + .abs_err + .max(self.value.get_epsilon() * self.value.to_f64()); } self } } -impl fmt::Display for ErrorPropagatingFloat { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let p = self.get_precision() as usize; +impl ErrorPropagatingFloat { + pub fn new_with_accuracy(value: T, acc: f64) -> Self { + ErrorPropagatingFloat { + value, + abs_err: 10f64.pow(-acc), + } + } - if p == 0 { - f.write_char('0') + /// Get the number. + #[inline(always)] + pub fn get_num(&self) -> &T { + &self.value + } +} + +impl fmt::Display for ErrorPropagatingFloat { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if let Some(p) = self.get_precision() { + if p < 0. { + return f.write_char('0'); + } else { + f.write_fmt(format_args!("{0:.1$}", self.value, p as usize)) + } } else { - f.write_fmt(format_args!( - "{0:.1$e}", - self.value, - self.get_precision() as usize - )) + f.write_char('0') } } } -impl Debug for ErrorPropagatingFloat { +impl Debug for ErrorPropagatingFloat { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { Debug::fmt(&self.value, f)?; - f.write_fmt(format_args!("`{}", self.get_precision())) + + if let Some(p) = self.get_precision() { + f.write_fmt(format_args!("`{:.2}", p)) + } else { + f.write_fmt(format_args!("``{:.2}", -self.abs_err.log10())) + } } } -impl LowerExp for ErrorPropagatingFloat { +impl LowerExp for ErrorPropagatingFloat { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { Display::fmt(self, f) } @@ -2265,49 +2333,81 @@ impl NumericalFloatLike for ErrorPropagatingFloat { fn zero(&self) -> Self { ErrorPropagatingFloat { value: self.value.zero(), - prec: 2f64.pow(-(self.value.get_precision() as f64)), + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), } } fn new_zero() -> Self { ErrorPropagatingFloat { value: T::new_zero(), - prec: 2f64.powi(-53), + abs_err: 2f64.powi(-53), } } fn one(&self) -> Self { ErrorPropagatingFloat { value: self.value.one(), - prec: 2f64.pow(-(self.value.get_precision() as f64)), + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), } } fn pow(&self, e: u64) -> Self { + let i = self.to_f64().abs(); + + if i == 0. { + return ErrorPropagatingFloat { + value: self.value.pow(e), + abs_err: self.abs_err.pow(e as f64), + }; + } + + let r = self.value.pow(e); ErrorPropagatingFloat { - value: self.value.pow(e), - prec: self.prec * e as f64, + abs_err: self.abs_err * e as f64 * r.to_f64().abs() / i, + value: r, } } fn inv(&self) -> Self { + let r = self.value.inv(); + let rr = r.to_f64().abs(); ErrorPropagatingFloat { - value: self.value.inv(), - prec: self.prec, + abs_err: self.abs_err * rr * rr, + value: r, } } + /// Convert from a `usize`. fn from_usize(&self, a: usize) -> Self { - ErrorPropagatingFloat { - value: self.value.from_usize(a), - prec: self.prec, + let v = self.value.from_usize(a); + let r = v.to_f64().abs(); + if r == 0. { + ErrorPropagatingFloat { + value: v, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), + } + } else { + ErrorPropagatingFloat { + value: v, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * r, + } } } + /// Convert from a `i64`. fn from_i64(&self, a: i64) -> Self { - ErrorPropagatingFloat { - value: self.value.from_i64(a), - prec: self.prec, + let v = self.value.from_i64(a); + let r = v.to_f64().abs(); + if r == 0. { + ErrorPropagatingFloat { + value: v, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), + } + } else { + ErrorPropagatingFloat { + value: v, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * r, + } } } @@ -2318,7 +2418,7 @@ impl NumericalFloatLike for ErrorPropagatingFloat { } fn get_epsilon(&self) -> f64 { - 2.0f64.powi(-(self.get_precision() as i32)) + 2.0f64.powi(-(self.value.get_precision() as i32)) } #[inline(always)] @@ -2327,9 +2427,10 @@ impl NumericalFloatLike for ErrorPropagatingFloat { } fn sample_unit(&self, rng: &mut R) -> Self { + let v = self.value.sample_unit(rng); ErrorPropagatingFloat { - value: self.value.sample_unit(rng), - prec: self.prec, + abs_err: self.abs_err * v.to_f64().abs(), + value: v, } } } @@ -2348,9 +2449,16 @@ impl SingleFloat for ErrorPropagatingFloat { } fn from_rational(&self, rat: &Rational) -> Self { - ErrorPropagatingFloat { - value: self.value.from_rational(rat), - prec: self.prec, + if rat.is_zero() { + ErrorPropagatingFloat { + value: self.value.from_rational(rat), + abs_err: self.abs_err, + } + } else { + ErrorPropagatingFloat { + value: self.value.from_rational(rat), + abs_err: self.abs_err * rat.to_f64(), + } } } } @@ -2372,45 +2480,59 @@ impl RealNumberLike for ErrorPropagatingFloat { impl Real for ErrorPropagatingFloat { fn pi(&self) -> Self { + let v = self.value.pi(); ErrorPropagatingFloat { - value: self.value.pi(), - prec: self.prec, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * v.to_f64(), + value: v, } } fn e(&self) -> Self { + let v = self.value.e(); ErrorPropagatingFloat { - value: self.value.e(), - prec: self.prec, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * v.to_f64(), + value: v, } } fn euler(&self) -> Self { + let v = self.value.euler(); ErrorPropagatingFloat { - value: self.value.euler(), - prec: self.prec, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * v.to_f64(), + value: v, } } fn phi(&self) -> Self { + let v = self.value.phi(); ErrorPropagatingFloat { - value: self.value.phi(), - prec: self.prec, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * v.to_f64(), + value: v, } } + #[inline(always)] + fn i(&self) -> Option { + Some(ErrorPropagatingFloat { + value: self.value.i()?, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), + }) + } + fn norm(&self) -> Self { ErrorPropagatingFloat { + abs_err: self.abs_err, value: self.value.norm(), - prec: self.prec, } - .truncate() } fn sqrt(&self) -> Self { + let v = self.value.sqrt(); + let r = v.to_f64().abs(); + ErrorPropagatingFloat { - value: self.value.sqrt(), - prec: self.prec / 2., + abs_err: self.abs_err / (2. * r), + value: v, } .truncate() } @@ -2418,23 +2540,24 @@ impl Real for ErrorPropagatingFloat { fn log(&self) -> Self { let r = self.value.log(); ErrorPropagatingFloat { - prec: self.prec / r.clone().to_f64().abs(), + abs_err: self.abs_err / self.value.to_f64().abs(), value: r, } .truncate() } fn exp(&self) -> Self { + let v = self.value.exp(); ErrorPropagatingFloat { - prec: self.value.to_f64().abs() * self.prec, - value: self.value.exp(), + abs_err: v.to_f64().abs() * self.abs_err, + value: v, } .truncate() } fn sin(&self) -> Self { ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() / self.value.tan().to_f64().abs(), + abs_err: self.abs_err * self.value.to_f64().cos().abs(), value: self.value.sin(), } .truncate() @@ -2442,7 +2565,7 @@ impl Real for ErrorPropagatingFloat { fn cos(&self) -> Self { ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() * self.value.tan().to_f64().abs(), + abs_err: self.abs_err * self.value.to_f64().sin().abs(), value: self.value.cos(), } .truncate() @@ -2451,8 +2574,9 @@ impl Real for ErrorPropagatingFloat { fn tan(&self) -> Self { let t = self.value.tan(); let tt = t.to_f64().abs(); + ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() * (tt.inv() + tt), + abs_err: self.abs_err * (1. + tt * tt), value: t, } .truncate() @@ -2461,9 +2585,9 @@ impl Real for ErrorPropagatingFloat { fn asin(&self) -> Self { let v = self.value.to_f64(); let t = self.value.asin(); - let tt = (1. - v * v).sqrt() * t.to_f64().abs(); + let tt = (1. - v * v).sqrt(); ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() @@ -2472,9 +2596,9 @@ impl Real for ErrorPropagatingFloat { fn acos(&self) -> Self { let v = self.value.to_f64(); let t = self.value.acos(); - let tt = (1. - v * v).sqrt() * t.to_f64().abs(); + let tt = (1. - v * v).sqrt(); ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() @@ -2485,9 +2609,9 @@ impl Real for ErrorPropagatingFloat { let r = self.clone() / x; let r2 = r.value.to_f64().abs(); - let tt = (1. + r2 * r2) * t.clone().to_f64().abs(); + let tt = 1. + r2 * r2; ErrorPropagatingFloat { - prec: r.prec * r2 / tt, + abs_err: r.abs_err / tt, value: t, } .truncate() @@ -2495,7 +2619,7 @@ impl Real for ErrorPropagatingFloat { fn sinh(&self) -> Self { ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() / self.value.tanh().to_f64().abs(), + abs_err: self.abs_err * self.value.cosh().to_f64().abs(), value: self.value.sinh(), } .truncate() @@ -2503,7 +2627,7 @@ impl Real for ErrorPropagatingFloat { fn cosh(&self) -> Self { ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() * self.value.tanh().to_f64().abs(), + abs_err: self.abs_err * self.value.sinh().to_f64().abs(), value: self.value.cosh(), } .truncate() @@ -2513,7 +2637,7 @@ impl Real for ErrorPropagatingFloat { let t = self.value.tanh(); let tt = t.clone().to_f64().abs(); ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() * (tt.inv() - tt), + abs_err: self.abs_err * (1. - tt * tt), value: t, } .truncate() @@ -2522,9 +2646,9 @@ impl Real for ErrorPropagatingFloat { fn asinh(&self) -> Self { let v = self.value.to_f64(); let t = self.value.asinh(); - let tt = (1. + v * v).sqrt() * t.to_f64().abs(); + let tt = (1. + v * v).sqrt(); ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() @@ -2533,9 +2657,9 @@ impl Real for ErrorPropagatingFloat { fn acosh(&self) -> Self { let v = self.value.to_f64(); let t = self.value.acosh(); - let tt = (v * v - 1.).sqrt() * t.to_f64().abs(); + let tt = (v * v - 1.).sqrt(); ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() @@ -2544,19 +2668,30 @@ impl Real for ErrorPropagatingFloat { fn atanh(&self) -> Self { let v = self.value.to_f64(); let t = self.value.atanh(); - let tt = (1. - v * v) * t.to_f64().abs(); + let tt = 1. - v * v; ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() } fn powf(&self, e: &Self) -> Self { - let v = self.value.to_f64().abs(); + let i = self.to_f64().abs(); + + if i == 0. { + return ErrorPropagatingFloat { + value: self.value.powf(&e.value), + abs_err: 0., + }; + } + + let r = self.value.powf(&e.value); ErrorPropagatingFloat { - value: self.value.powf(&e.value), - prec: (self.prec + e.prec * v.ln().abs()) * e.value.clone().to_f64().abs(), + abs_err: (self.abs_err * e.value.to_f64() + i * e.abs_err * i.ln().abs()) + * r.to_f64().abs() + / i, + value: r, } .truncate() } @@ -2653,6 +2788,11 @@ macro_rules! simd_impl { 1.6180339887498948.into() } + #[inline(always)] + fn i(&self) -> Option { + None + } + #[inline(always)] fn norm(&self) -> Self { (*self).abs() @@ -3603,6 +3743,11 @@ impl Real for Complex { Complex::new(self.re.phi(), self.im.zero()) } + #[inline(always)] + fn i(&self) -> Option { + Some(self.i()) + } + #[inline] fn norm(&self) -> Self { Complex::new(self.norm_squared().sqrt(), self.im.zero()) @@ -3784,7 +3929,7 @@ mod test { + b.powf(&a); assert_eq!(r.value, 17293.219725825093); // error is 14.836811363436391 when the f64 could have theoretically grown in between - assert_eq!(r.get_precision(), 14.836795991431746); + assert_eq!(r.get_precision(), Some(14.836795991431746)); } #[test] @@ -3792,16 +3937,15 @@ mod test { let a = ErrorPropagatingFloat::new(0.0000000123456789, 9.) .exp() .log(); - assert_eq!(a.get_precision(), 8.046104745509947); + assert_eq!(a.get_precision(), Some(8.046104745509947)); } #[test] fn large_cancellation() { let a = ErrorPropagatingFloat::new(Float::with_val(200, 1e-50), 60.); let r = (a.exp() - a.one()) / a; - println!("{}", r.value.prec()); - assert_eq!(format!("{}", r), "1.000000000e0"); - assert_eq!(r.get_precision(), 10.205999132807323); + assert_eq!(format!("{}", r), "1.000000000"); + assert_eq!(r.get_precision(), Some(10.205999132796238)); } #[test] diff --git a/src/evaluate.rs b/src/evaluate.rs index a2687f71..ba176cc2 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -718,11 +718,11 @@ impl ExpressionEvaluator { self.stack[*r] = self.stack[*b].powf(&self.stack[*e]); } Instr::BuiltinFun(r, s, arg) => match s.0 { - State::EXP => self.stack[*r] = self.stack[*arg].exp(), - State::LOG => self.stack[*r] = self.stack[*arg].log(), - State::SIN => self.stack[*r] = self.stack[*arg].sin(), - State::COS => self.stack[*r] = self.stack[*arg].cos(), - State::SQRT => self.stack[*r] = self.stack[*arg].sqrt(), + Atom::EXP => self.stack[*r] = self.stack[*arg].exp(), + Atom::LOG => self.stack[*r] = self.stack[*arg].log(), + Atom::SIN => self.stack[*r] = self.stack[*arg].sin(), + Atom::COS => self.stack[*r] = self.stack[*arg].cos(), + Atom::SQRT => self.stack[*r] = self.stack[*arg].sqrt(), _ => unreachable!(), }, } @@ -1359,23 +1359,23 @@ impl ExpressionEvaluator { *out += format!("\tZ{} = pow({}, {});\n", o, base, exp).as_str(); } Instr::BuiltinFun(o, s, a) => match s.0 { - State::EXP => { + Atom::EXP => { let arg = format!("Z{}", a); *out += format!("\tZ{} = exp({});\n", o, arg).as_str(); } - State::LOG => { + Atom::LOG => { let arg = format!("Z{}", a); *out += format!("\tZ{} = log({});\n", o, arg).as_str(); } - State::SIN => { + Atom::SIN => { let arg = format!("Z{}", a); *out += format!("\tZ{} = sin({});\n", o, arg).as_str(); } - State::COS => { + Atom::COS => { let arg = format!("Z{}", a); *out += format!("\tZ{} = cos({});\n", o, arg).as_str(); } - State::SQRT => { + Atom::SQRT => { let arg = format!("Z{}", a); *out += format!("\tZ{} = sqrt({});\n", o, arg).as_str(); } @@ -1931,19 +1931,19 @@ impl ExpressionEvaluator { let arg = get_input!(*a); match s.0 { - State::EXP => { + Atom::EXP => { *out += format!("\tZ[{}] = exp({});\n", o, arg).as_str(); } - State::LOG => { + Atom::LOG => { *out += format!("\tZ[{}] = log({});\n", o, arg).as_str(); } - State::SIN => { + Atom::SIN => { *out += format!("\tZ[{}] = sin({});\n", o, arg).as_str(); } - State::COS => { + Atom::COS => { *out += format!("\tZ[{}] = cos({});\n", o, arg).as_str(); } - State::SQRT => { + Atom::SQRT => { *out += format!("\tZ[{}] = sqrt({});\n", o, arg).as_str(); } _ => unreachable!(), @@ -2122,19 +2122,19 @@ impl ExpressionEvaluator { let arg = get_input!(*a); match s.0 { - State::EXP => { + Atom::EXP => { *out += format!("\tZ[{}] = exp({});\n", o, arg).as_str(); } - State::LOG => { + Atom::LOG => { *out += format!("\tZ[{}] = log({});\n", o, arg).as_str(); } - State::SIN => { + Atom::SIN => { *out += format!("\tZ[{}] = sin({});\n", o, arg).as_str(); } - State::COS => { + Atom::COS => { *out += format!("\tZ[{}] = cos({});\n", o, arg).as_str(); } - State::SQRT => { + Atom::SQRT => { *out += format!("\tZ[{}] = sqrt({});\n", o, arg).as_str(); } _ => unreachable!(), @@ -3383,11 +3383,11 @@ impl EvalTree { Expression::BuiltinFun(s, a) => { let arg = self.evaluate_impl(a, subexpressions, params, args); match s.0 { - State::EXP => arg.exp(), - State::LOG => arg.log(), - State::SIN => arg.sin(), - State::COS => arg.cos(), - State::SQRT => arg.sqrt(), + Atom::EXP => arg.exp(), + Atom::LOG => arg.log(), + Atom::SIN => arg.sin(), + Atom::COS => arg.cos(), + Atom::SQRT => arg.sqrt(), _ => unreachable!(), } } @@ -3815,31 +3815,31 @@ impl EvalTree { } Expression::ReadArg(s) => args[*s].to_string(), Expression::BuiltinFun(s, a) => match s.0 { - State::EXP => { + Atom::EXP => { let mut r = "exp(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); r } - State::LOG => { + Atom::LOG => { let mut r = "log(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); r } - State::SIN => { + Atom::SIN => { let mut r = "sin(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); r } - State::COS => { + Atom::COS => { let mut r = "cos(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); r } - State::SQRT => { + Atom::SQRT => { let mut r = "sqrt(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); @@ -3930,7 +3930,7 @@ impl<'a> AtomView<'a> { } AtomView::Fun(f) => { let name = f.get_symbol(); - if [State::EXP, State::LOG, State::SIN, State::COS, State::SQRT].contains(&name) { + if [Atom::EXP, Atom::LOG, Atom::SIN, Atom::COS, Atom::SQRT].contains(&name) { assert!(f.get_nargs() == 1); let arg = f.iter().next().unwrap(); let arg_eval = arg.to_eval_tree_impl(fn_map, params, args, funcs)?; @@ -4065,8 +4065,11 @@ impl<'a> AtomView<'a> { ), }, AtomView::Var(v) => match v.get_symbol() { - State::E => Ok(coeff_map(&1.into()).e()), - State::PI => Ok(coeff_map(&1.into()).pi()), + Atom::E => Ok(coeff_map(&1.into()).e()), + Atom::PI => Ok(coeff_map(&1.into()).pi()), + Atom::I => coeff_map(&1.into()) + .i() + .ok_or_else(|| "Numerical type does not support imaginary unit".to_string()), _ => Err(format!( "Variable {} not in constant map", State::get_name(v.get_symbol()) @@ -4074,17 +4077,17 @@ impl<'a> AtomView<'a> { }, AtomView::Fun(f) => { let name = f.get_symbol(); - if [State::EXP, State::LOG, State::SIN, State::COS, State::SQRT].contains(&name) { + if [Atom::EXP, Atom::LOG, Atom::SIN, Atom::COS, Atom::SQRT].contains(&name) { assert!(f.get_nargs() == 1); let arg = f.iter().next().unwrap(); let arg_eval = arg.evaluate(coeff_map, const_map, function_map, cache)?; return Ok(match f.get_symbol() { - State::EXP => arg_eval.exp(), - State::LOG => arg_eval.log(), - State::SIN => arg_eval.sin(), - State::COS => arg_eval.cos(), - State::SQRT => arg_eval.sqrt(), + Atom::EXP => arg_eval.exp(), + Atom::LOG => arg_eval.log(), + Atom::SIN => arg_eval.sin(), + Atom::COS => arg_eval.cos(), + Atom::SQRT => arg_eval.sqrt(), _ => unreachable!(), }); } diff --git a/src/id.rs b/src/id.rs index c22edbd7..b00795f0 100644 --- a/src/id.rs +++ b/src/id.rs @@ -8,7 +8,7 @@ use crate::{ representation::{InlineVar, ListSlice}, AsAtomView, Atom, AtomType, AtomView, Num, SliceType, Symbol, }, - state::{State, Workspace}, + state::Workspace, transformer::{Transformer, TransformerError}, }; @@ -2109,7 +2109,7 @@ impl<'a> Match<'a> { // to update the coefficient flag } SliceType::Arg => { - let fun = out.to_fun(State::ARG); + let fun = out.to_fun(Atom::ARG); for arg in wargs { fun.add_arg(*arg); } @@ -2124,7 +2124,7 @@ impl<'a> Match<'a> { out.set_from_view(&wargs[0]); } SliceType::Empty => { - let f = out.to_fun(State::ARG); + let f = out.to_fun(Atom::ARG); f.set_normalized(true); } }, diff --git a/src/normalize.rs b/src/normalize.rs index 4961981a..cfb2653c 100644 --- a/src/normalize.rs +++ b/src/normalize.rs @@ -275,7 +275,7 @@ impl<'a> AtomView<'a> { /// Simplify logs in the argument of the exponential function. fn simplify_exp_log(&self, ws: &Workspace, out: &mut Atom) -> bool { if let AtomView::Fun(f) = self { - if f.get_symbol() == State::LOG && f.get_nargs() == 1 { + if f.get_symbol() == Atom::LOG && f.get_nargs() == 1 { out.set_from_view(&f.iter().next().unwrap()); return true; } @@ -333,7 +333,7 @@ impl<'a> AtomView<'a> { if changed { let mut new_exp = ws.new_atom(); // TODO: change to e^() - new_exp.to_fun(State::EXP).add_arg(aa.as_view()); + new_exp.to_fun(Atom::EXP).add_arg(aa.as_view()); m.extend(new_exp.as_view()); @@ -489,7 +489,7 @@ impl Atom { // x * x => x^2 if self.as_view() == other.as_view() { if let AtomView::Var(v) = self.as_view() { - if v.get_symbol() == State::I { + if v.get_symbol() == Atom::I { self.to_num((-1).into()); return true; } @@ -821,7 +821,7 @@ impl<'a> AtomView<'a> { #[inline(always)] fn add_arg(f: &mut Fun, a: AtomView) { if let AtomView::Fun(fa) = a { - if fa.get_symbol() == State::ARG { + if fa.get_symbol() == Atom::ARG { // flatten f(arg(...)) = f(...) for aa in fa.iter() { f.add_arg(aa); @@ -872,16 +872,16 @@ impl<'a> AtomView<'a> { out_f.set_normalized(true); - if [State::COS, State::SIN, State::EXP, State::LOG].contains(&id) + if [Atom::COS, Atom::SIN, Atom::EXP, Atom::LOG].contains(&id) && out_f.to_fun_view().get_nargs() == 1 { let arg = out_f.to_fun_view().iter().next().unwrap(); if let AtomView::Num(n) = arg { - if n.is_zero() && id != State::LOG || n.is_one() && id == State::LOG { - if id == State::COS || id == State::EXP { + if n.is_zero() && id != Atom::LOG || n.is_one() && id == Atom::LOG { + if id == Atom::COS || id == Atom::EXP { out.to_num(Coefficient::one()); return; - } else if id == State::SIN || id == State::LOG { + } else if id == Atom::SIN || id == Atom::LOG { out.to_num(Coefficient::zero()); return; } @@ -889,22 +889,22 @@ impl<'a> AtomView<'a> { if let CoefficientView::Float(f) = n.get_coeff_view() { match id { - State::COS => { + Atom::COS => { let r = f.to_float().cos(); out.to_num(Coefficient::Float(r)); return; } - State::SIN => { + Atom::SIN => { let r = f.to_float().sin(); out.to_num(Coefficient::Float(r)); return; } - State::EXP => { + Atom::EXP => { let r = f.to_float().exp(); out.to_num(Coefficient::Float(r)); return; } - State::LOG => { + Atom::LOG => { let r = f.to_float().log(); out.to_num(Coefficient::Float(r)); return; @@ -915,10 +915,10 @@ impl<'a> AtomView<'a> { } } - if id == State::EXP && out_f.to_fun_view().get_nargs() == 1 { + if id == Atom::EXP && out_f.to_fun_view().get_nargs() == 1 { let arg = out_f.to_fun_view().iter().next().unwrap(); // simplify logs inside exp - if arg.contains_symbol(State::LOG) { + if arg.contains_symbol(Atom::LOG) { let mut buffer = workspace.new_atom(); if arg.simplify_exp_log(workspace, &mut buffer) { out.set_from_view(&buffer.as_view()); @@ -928,7 +928,7 @@ impl<'a> AtomView<'a> { } // try to turn the argument into a number - if id == State::COEFF && out_f.to_fun_view().get_nargs() == 1 { + if id == Atom::COEFF && out_f.to_fun_view().get_nargs() == 1 { let arg = out_f.to_fun_view().iter().next().unwrap(); if let AtomView::Num(_) = arg { let mut buffer = workspace.new_atom(); @@ -1197,7 +1197,7 @@ impl<'a> AtomView<'a> { base_handle.to_num(new_base_num); exp_handle.to_num(new_exp_num); } else if let AtomView::Var(v) = base_handle.as_view() { - if v.get_symbol() == State::I { + if v.get_symbol() == Atom::I { if let CoefficientView::Natural(n, d) = exp_num { let mut new_base = workspace.new_atom(); @@ -1533,7 +1533,7 @@ impl<'a> AtomView<'a> { #[cfg(test)] mod test { - use crate::{atom::Atom, state::State}; + use crate::atom::Atom; #[test] fn pow_apart() { @@ -1564,8 +1564,8 @@ mod test { #[test] fn mul_complex_i() { - let res = Atom::new_var(State::I) * &Atom::new_var(State::E) * &Atom::new_var(State::I); - let refr = -Atom::new_var(State::E); + let res = Atom::new_var(Atom::I) * &Atom::new_var(Atom::E) * &Atom::new_var(Atom::I); + let refr = -Atom::new_var(Atom::E); assert_eq!(res, refr); } diff --git a/src/poly/evaluate.rs b/src/poly/evaluate.rs index dd7f15d4..e9388656 100644 --- a/src/poly/evaluate.rs +++ b/src/poly/evaluate.rs @@ -20,7 +20,6 @@ use crate::{ atom::{Atom, AtomView}, domains::{float::Real, Ring}, evaluate::EvaluationFn, - state::State, }; use super::{polynomial::MultivariatePolynomial, PositiveExponent}; @@ -1678,7 +1677,7 @@ impl<'a> std::fmt::Display for InstructionSetPrinter<'a> { None } } else if let super::Variable::Symbol(i) = x { - if [State::E, State::I, State::PI].contains(i) { + if [Atom::E, Atom::I, Atom::PI].contains(i) { None } else { Some(format!("T {}", x.to_string())) @@ -1855,7 +1854,7 @@ impl ExpressionEvaluator { None } } else if let super::Variable::Symbol(i) = x { - if [State::E, State::I, State::PI].contains(i) { + if [Atom::E, Atom::I, Atom::PI].contains(i) { None } else { Some(x.clone()) diff --git a/src/poly/series.rs b/src/poly/series.rs index f23f17de..b161ca4a 100644 --- a/src/poly/series.rs +++ b/src/poly/series.rs @@ -15,7 +15,6 @@ use crate::{ EuclideanDomain, InternalOrdering, Ring, SelfRing, }, printer::{PrintOptions, PrintState}, - state::State, }; use super::Variable; @@ -923,7 +922,7 @@ impl Series { }; // construct the constant term, log(x) in the argument will be turned into x - let e = FunctionBuilder::new(State::EXP).add_arg(&c).finish(); + let e = FunctionBuilder::new(Atom::EXP).add_arg(&c).finish(); // split the true constant part and the x-dependent part let var = self.variable.to_atom() - &self.expansion_point; @@ -961,7 +960,7 @@ impl Series { .mul_exp_units(-self.shift) - self.one(); - let mut e = self.constant(FunctionBuilder::new(State::LOG).add_arg(&c).finish()); + let mut e = self.constant(FunctionBuilder::new(Atom::LOG).add_arg(&c).finish()); let mut sp = p.clone(); for i in 1..=self.order { let s = sp.clone().div_coeff(&Atom::new_num(i as i64)); @@ -1007,13 +1006,13 @@ impl Series { let p = self.clone().remove_constant(); - let mut e = self.constant(FunctionBuilder::new(State::SIN).add_arg(&c).finish()); + let mut e = self.constant(FunctionBuilder::new(Atom::SIN).add_arg(&c).finish()); let mut sp = p.clone(); for i in 1..=self.order { let mut b = if i % 2 == 1 { - FunctionBuilder::new(State::COS).add_arg(&c).finish() + FunctionBuilder::new(Atom::COS).add_arg(&c).finish() } else { - FunctionBuilder::new(State::SIN).add_arg(&c).finish() + FunctionBuilder::new(Atom::SIN).add_arg(&c).finish() }; if i % 4 >= 2 { @@ -1063,13 +1062,13 @@ impl Series { let p = self.clone().remove_constant(); - let mut e = self.constant(FunctionBuilder::new(State::COS).add_arg(&c).finish()); + let mut e = self.constant(FunctionBuilder::new(Atom::COS).add_arg(&c).finish()); let mut sp = p.clone(); for i in 1..=self.order { let mut b = if i % 2 == 1 { - FunctionBuilder::new(State::SIN).add_arg(&c).finish() + FunctionBuilder::new(Atom::SIN).add_arg(&c).finish() } else { - -FunctionBuilder::new(State::COS).add_arg(&c).finish() + -FunctionBuilder::new(Atom::COS).add_arg(&c).finish() }; if i % 4 < 2 { diff --git a/src/printer.rs b/src/printer.rs index ff4e788e..42c34a00 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -459,9 +459,9 @@ impl<'a> FormattedPrintVar for VarView<'a> { if opts.latex { match id { - State::E => f.write_char('e'), - State::PI => f.write_str("\\pi"), - State::I => f.write_char('i'), + Atom::E => f.write_char('e'), + Atom::PI => f.write_str("\\pi"), + Atom::I => f.write_char('i'), _ => f.write_str(name), } } else if opts.color_builtin_symbols && name.ends_with('_') { diff --git a/src/state.rs b/src/state.rs index 547483b0..0ac508f1 100644 --- a/src/state.rs +++ b/src/state.rs @@ -88,17 +88,17 @@ impl Default for State { } impl State { - pub const ARG: Symbol = Symbol::init_fn(0, 0, false, false, false, false); - pub const COEFF: Symbol = Symbol::init_fn(1, 0, false, false, false, false); - pub const EXP: Symbol = Symbol::init_fn(2, 0, false, false, false, false); - pub const LOG: Symbol = Symbol::init_fn(3, 0, false, false, false, false); - pub const SIN: Symbol = Symbol::init_fn(4, 0, false, false, false, false); - pub const COS: Symbol = Symbol::init_fn(5, 0, false, false, false, false); - pub const SQRT: Symbol = Symbol::init_fn(6, 0, false, false, false, false); - pub const DERIVATIVE: Symbol = Symbol::init_fn(7, 0, false, false, false, false); - pub const E: Symbol = Symbol::init_var(8, 0); - pub const I: Symbol = Symbol::init_var(9, 0); - pub const PI: Symbol = Symbol::init_var(10, 0); + pub(crate) const ARG: Symbol = Symbol::init_fn(0, 0, false, false, false, false); + pub(crate) const COEFF: Symbol = Symbol::init_fn(1, 0, false, false, false, false); + pub(crate) const EXP: Symbol = Symbol::init_fn(2, 0, false, false, false, false); + pub(crate) const LOG: Symbol = Symbol::init_fn(3, 0, false, false, false, false); + pub(crate) const SIN: Symbol = Symbol::init_fn(4, 0, false, false, false, false); + pub(crate) const COS: Symbol = Symbol::init_fn(5, 0, false, false, false, false); + pub(crate) const SQRT: Symbol = Symbol::init_fn(6, 0, false, false, false, false); + pub(crate) const DERIVATIVE: Symbol = Symbol::init_fn(7, 0, false, false, false, false); + pub(crate) const E: Symbol = Symbol::init_var(8, 0); + pub(crate) const I: Symbol = Symbol::init_var(9, 0); + pub(crate) const PI: Symbol = Symbol::init_var(10, 0); pub const BUILTIN_VAR_LIST: [&'static str; 11] = [ "arg", "coeff", "exp", "log", "sin", "cos", "sqrt", "der", "𝑒", "𝑖", "𝜋", @@ -888,7 +888,7 @@ mod tests { if f.get_nargs() == 1 { let arg = f.iter().next().unwrap(); if let AtomView::Fun(f2) = arg { - if f2.get_symbol() == State::EXP { + if f2.get_symbol() == Atom::EXP { if f2.get_nargs() == 1 { out.set_from_view(&f2.iter().next().unwrap()); return true; diff --git a/src/transformer.rs b/src/transformer.rs index e6c29797..f9ec3204 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -10,7 +10,7 @@ use crate::{ Replacement, }, printer::{AtomPrinter, PrintOptions}, - state::{RecycledAtom, State, Workspace}, + state::{RecycledAtom, Workspace}, }; use ahash::HashMap; use colored::Colorize; @@ -238,7 +238,7 @@ impl FunView<'_> { #[inline(always)] fn add_arg(f: &mut Fun, a: AtomView) { if let AtomView::Fun(fa) = a { - if fa.get_symbol() == State::ARG { + if fa.get_symbol() == Atom::ARG { // flatten f(arg(...)) = f(...) for aa in fa.iter() { f.add_arg(aa); @@ -502,9 +502,9 @@ impl Transformer { } Transformer::ForEach(t) => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let mut ff = workspace.new_atom(); - let ff = ff.to_fun(State::ARG); + let ff = ff.to_fun(Atom::ARG); let mut a = workspace.new_atom(); for arg in f { @@ -598,7 +598,7 @@ impl Transformer { } Transformer::Product => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let mut mul_h = workspace.new_atom(); let mul = mul_h.to_mul(); @@ -615,7 +615,7 @@ impl Transformer { } Transformer::Sum => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let mut add_h = workspace.new_atom(); let add = add_h.to_add(); @@ -632,7 +632,7 @@ impl Transformer { } Transformer::ArgCount(only_for_arg_fun) => { if let AtomView::Fun(f) = cur_input { - if !*only_for_arg_fun || f.get_symbol() == State::ARG { + if !*only_for_arg_fun || f.get_symbol() == Atom::ARG { let n_args = f.get_nargs(); out.to_num((n_args as i64).into()); } else { @@ -654,7 +654,7 @@ impl Transformer { Transformer::Split => match cur_input { AtomView::Mul(m) => { let mut arg_h = workspace.new_atom(); - let arg = arg_h.to_fun(State::ARG); + let arg = arg_h.to_fun(Atom::ARG); for factor in m { arg.add_arg(factor); @@ -664,7 +664,7 @@ impl Transformer { } AtomView::Add(a) => { let mut arg_h = workspace.new_atom(); - let arg = arg_h.to_fun(State::ARG); + let arg = arg_h.to_fun(Atom::ARG); for summand in a { arg.add_arg(summand); @@ -678,7 +678,7 @@ impl Transformer { }, Transformer::Partition(bins, fill_last, repeat) => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let args: Vec<_> = f.iter().collect(); let mut sum_h = workspace.new_atom(); @@ -721,12 +721,12 @@ impl Transformer { } Transformer::Sort => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let mut args: Vec<_> = f.iter().collect(); args.sort(); let mut fun_h = workspace.new_atom(); - let fun = fun_h.to_fun(State::ARG); + let fun = fun_h.to_fun(Atom::ARG); for arg in args { fun.add_arg(arg); @@ -774,7 +774,7 @@ impl Transformer { } Transformer::Deduplicate => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let args: Vec<_> = f.iter().collect(); let mut args_dedup: Vec<_> = Vec::with_capacity(args.len()); @@ -786,7 +786,7 @@ impl Transformer { } let mut fun_h = workspace.new_atom(); - let fun = fun_h.to_fun(State::ARG); + let fun = fun_h.to_fun(Atom::ARG); for arg in args_dedup { fun.add_arg(arg); @@ -801,7 +801,7 @@ impl Transformer { } Transformer::Permutations(f_name) => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let args: Vec<_> = f.iter().collect(); let mut sum_h = workspace.new_atom();