Skip to content

Commit

Permalink
Redesign error propagating floats
Browse files Browse the repository at this point in the history
- Use absolute error so that zeros are supported
- Add i() function for floats
- Add automatic evaluation of complex i
- Move EXP, E, etc from State to Atom
  • Loading branch information
benruijl committed Dec 10, 2024
1 parent fe31c83 commit fea7a04
Show file tree
Hide file tree
Showing 14 changed files with 438 additions and 236 deletions.
16 changes: 8 additions & 8 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2656,57 +2656,57 @@ impl PythonExpression {
#[classattr]
#[pyo3(name = "E")]
pub fn e() -> PythonExpression {
Atom::new_var(State::E).into()
Atom::new_var(Atom::E).into()
}

/// The mathematical constant `π`.
#[classattr]
#[pyo3(name = "PI")]
pub fn pi() -> PythonExpression {
Atom::new_var(State::PI).into()
Atom::new_var(Atom::PI).into()
}

/// The mathematical constant `i`, where
/// `i^2 = -1`.
#[classattr]
#[pyo3(name = "I")]
pub fn i() -> PythonExpression {
Atom::new_var(State::I).into()
Atom::new_var(Atom::I).into()
}

/// The built-in function that converts a rational polynomial to a coefficient.
#[classattr]
#[pyo3(name = "COEFF")]
pub fn coeff() -> PythonExpression {
Atom::new_var(State::COEFF).into()
Atom::new_var(Atom::COEFF).into()
}

/// The built-in cosine function.
#[classattr]
#[pyo3(name = "COS")]
pub fn cos() -> PythonExpression {
Atom::new_var(State::COS).into()
Atom::new_var(Atom::COS).into()
}

/// The built-in sine function.
#[classattr]
#[pyo3(name = "SIN")]
pub fn sin() -> PythonExpression {
Atom::new_var(State::SIN).into()
Atom::new_var(Atom::SIN).into()
}

/// The built-in exponential function.
#[classattr]
#[pyo3(name = "EXP")]
pub fn exp() -> PythonExpression {
Atom::new_var(State::EXP).into()
Atom::new_var(Atom::EXP).into()
}

/// The built-in logarithm function.
#[classattr]
#[pyo3(name = "LOG")]
pub fn log() -> PythonExpression {
Atom::new_var(State::LOG).into()
Atom::new_var(Atom::LOG).into()
}

/// Return all defined symbol names (function names and variables).
Expand Down
52 changes: 51 additions & 1 deletion src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
coefficient::Coefficient,
parser::Token,
printer::{AtomPrinter, PrintOptions},
state::{RecycledAtom, Workspace},
state::{RecycledAtom, State, Workspace},
transformer::StatsOptions,
};
use std::{cmp::Ordering, hash::Hash, ops::DerefMut, str::FromStr};
Expand Down Expand Up @@ -511,6 +511,56 @@ pub enum Atom {
Zero,
}

impl Atom {
/// The built-in function represents a list of function arguments.
pub const ARG: Symbol = State::ARG;
/// The built-in function that converts a rational polynomial to a coefficient.
pub const COEFF: Symbol = State::COEFF;
/// The exponent function.
pub const EXP: Symbol = State::EXP;
/// The logarithm function.
pub const LOG: Symbol = State::LOG;
/// The sine function.
pub const SIN: Symbol = State::SIN;
/// The cosine function.
pub const COS: Symbol = State::COS;
/// The square root function.
pub const SQRT: Symbol = State::SQRT;
/// The built-in function that represents an abstract derivative.
pub const DERIVATIVE: Symbol = State::DERIVATIVE;
/// The constant e, the base of the natural logarithm.
pub const E: Symbol = State::E;
/// The constant i, the imaginary unit.
pub const I: Symbol = State::I;
/// The mathematical constant `π`.
pub const PI: Symbol = State::PI;

/// Exponentiate the atom.
pub fn exp(&self) -> Atom {
FunctionBuilder::new(Atom::EXP).add_arg(self).finish()
}

/// Take the logarithm of the atom.
pub fn log(&self) -> Atom {
FunctionBuilder::new(Atom::LOG).add_arg(self).finish()
}

/// Take the sine the atom.
pub fn sin(&self) -> Atom {
FunctionBuilder::new(Atom::SIN).add_arg(self).finish()
}

/// Take the cosine the atom.
pub fn cos(&self) -> Atom {
FunctionBuilder::new(Atom::COS).add_arg(self).finish()
}

/// Take the square root of the atom.
pub fn sqrt(&self) -> Atom {
FunctionBuilder::new(Atom::SQRT).add_arg(self).finish()
}
}

impl Default for Atom {
/// Create an atom that represents the number 0.
#[inline]
Expand Down
4 changes: 2 additions & 2 deletions src/coefficient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1365,13 +1365,13 @@ impl<'a> AtomView<'a> {
let s = v.get_symbol();

match s {
State::PI => {
Atom::PI => {
out.to_num(Coefficient::Float(Float::with_val(
binary_prec,
rug::float::Constant::Pi,
)));
}
State::E => {
Atom::E => {
out.to_num(Coefficient::Float(Float::with_val(binary_prec, 1).exp()));
}
_ => {
Expand Down
34 changes: 17 additions & 17 deletions src/derivative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
combinatorics::CombinationWithReplacementIterator,
domains::{atom::AtomField, integer::Integer, rational::Rational},
poly::{series::Series, Variable},
state::{State, Workspace},
state::Workspace,
};

impl Atom {
Expand Down Expand Up @@ -83,7 +83,7 @@ impl<'a> AtomView<'a> {
// detect if the function to derive is the derivative function itself
// if so, derive the last argument of the derivative function and set
// a flag to later accumulate previous derivatives
let (to_derive, f, is_der) = if f_orig.get_symbol() == State::DERIVATIVE {
let (to_derive, f, is_der) = if f_orig.get_symbol() == Atom::DERIVATIVE {
let to_derive = f_orig.iter().last().unwrap();
(
to_derive,
Expand Down Expand Up @@ -113,29 +113,29 @@ impl<'a> AtomView<'a> {

// derive special functions
if f.get_nargs() == 1
&& [State::EXP, State::LOG, State::SIN, State::COS].contains(&f.get_symbol())
&& [Atom::EXP, Atom::LOG, Atom::SIN, Atom::COS].contains(&f.get_symbol())
{
let mut fn_der = workspace.new_atom();
match f.get_symbol() {
State::EXP => {
Atom::EXP => {
fn_der.set_from_view(self);
}
State::LOG => {
Atom::LOG => {
let mut n = workspace.new_atom();
n.to_num((-1).into());

fn_der.to_pow(f.iter().next().unwrap(), n.as_view());
}
State::SIN => {
let p = fn_der.to_fun(State::COS);
Atom::SIN => {
let p = fn_der.to_fun(Atom::COS);
p.add_arg(f.iter().next().unwrap());
}
State::COS => {
Atom::COS => {
let mut n = workspace.new_atom();
n.to_num((-1).into());

let mut sin = workspace.new_atom();
let sin_fun = sin.to_fun(State::SIN);
let sin_fun = sin.to_fun(Atom::SIN);
sin_fun.add_arg(f.iter().next().unwrap());

let m = fn_der.to_mul();
Expand Down Expand Up @@ -167,7 +167,7 @@ impl<'a> AtomView<'a> {
let mut n = workspace.new_atom();
let mut mul = workspace.new_atom();
for (index, arg_der) in args_der {
let p = fn_der.to_fun(State::DERIVATIVE);
let p = fn_der.to_fun(Atom::DERIVATIVE);

if is_der {
for (i, x_orig) in f_orig.iter().take(f.get_nargs()).enumerate() {
Expand Down Expand Up @@ -218,7 +218,7 @@ impl<'a> AtomView<'a> {
if exp_der_non_zero {
// create log(base)
let mut log_base = workspace.new_atom();
let lb = log_base.to_fun(State::LOG);
let lb = log_base.to_fun(Atom::LOG);
lb.add_arg(base);

if let Atom::Mul(m) = exp_der.deref_mut() {
Expand Down Expand Up @@ -418,11 +418,11 @@ impl<'a> AtomView<'a> {
}

match f.get_symbol() {
State::COS => args_series[0].cos(),
State::SIN => args_series[0].sin(),
State::EXP => args_series[0].exp(),
State::LOG => args_series[0].log(),
State::SQRT => args_series[0].rpow((1, 2).into()),
Atom::COS => args_series[0].cos(),
Atom::SIN => args_series[0].sin(),
Atom::EXP => args_series[0].exp(),
Atom::LOG => args_series[0].log(),
Atom::SQRT => args_series[0].rpow((1, 2).into()),
_ => {
// TODO: also check for log(x)?
if args_series
Expand Down Expand Up @@ -461,7 +461,7 @@ impl<'a> AtomView<'a> {
CombinationWithReplacementIterator::new(args_series.len(), i);

while let Some(x) = it.next() {
let mut f_der = FunctionBuilder::new(State::DERIVATIVE);
let mut f_der = FunctionBuilder::new(Atom::DERIVATIVE);
let mut term = info.one();
for (arg, pow) in x.iter().enumerate() {
if *pow > 0 {
Expand Down
7 changes: 7 additions & 0 deletions src/domains/dual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,13 @@ macro_rules! create_hyperdual_from_components {
res
}

#[inline(always)]
fn i(&self) -> Option<Self> {
let mut res = self.zero();
res.values[0] = self.values[0].i()?;
Some(res)
}

#[inline(always)]
fn norm(&self) -> Self {
let n = self.values[0].norm();
Expand Down
Loading

0 comments on commit fea7a04

Please sign in to comment.