Skip to content

Commit

Permalink
Add coefficient conversion functions
Browse files Browse the repository at this point in the history
- Allow evaluation with floating point coefficients
- Set minimal parsing precision of float to 53 bits
  • Loading branch information
benruijl committed Jun 7, 2024
1 parent 7b82f78 commit 271f855
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 3 deletions.
25 changes: 25 additions & 0 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,21 @@ impl<'a> FromPyObject<'a> for Variable {
}
}

impl<'a> FromPyObject<'a> for Integer {
fn extract(ob: &'a pyo3::PyAny) -> PyResult<Self> {
if let Ok(num) = ob.extract::<i64>() {
Ok(num.into())
} else if let Ok(num) = ob.extract::<&PyLong>() {
let a = format!("{}", num);
Ok(Integer::from_large(
rug::Integer::parse(&a).unwrap().complete(),
))
} else {
Err(exceptions::PyValueError::new_err("Not a valid integer"))
}
}
}

pub struct ConvertibleToExpression(PythonExpression);

impl ConvertibleToExpression {
Expand Down Expand Up @@ -1975,6 +1990,16 @@ impl PythonExpression {
}
}

/// Convert all coefficients to floats, with a given decimal precision.
pub fn to_float(&self, decimal_prec: u32) -> PythonExpression {
self.expr.to_float(decimal_prec).into()
}

/// Convert all floating point coefficients to rationals, with a given maximal denominator.
pub fn float_to_rat(&self, max_denominator: Integer) -> PythonExpression {
self.expr.float_to_rat(&max_denominator).into()
}

/// Create a pattern restriction based on the wildcard length before downcasting.
pub fn req_len(
&self,
Expand Down
122 changes: 122 additions & 0 deletions src/coefficient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ impl From<i32> for Coefficient {
}
}

impl From<f64> for Coefficient {
fn from(value: f64) -> Self {
Coefficient::Float(Float::with_val(53, value))
}
}

impl From<(i64, i64)> for Coefficient {
#[inline]
fn from(r: (i64, i64)) -> Self {
Expand Down Expand Up @@ -863,6 +869,11 @@ impl Atom {
pub fn to_float_into(&self, decimal_prec: u32, out: &mut Atom) {
self.as_view().to_float_into(decimal_prec, out);
}

/// Map all floating point coefficients to rational numbers, using a given maximum denominator.
pub fn float_to_rat(&self, max_denominator: &Integer) -> Atom {
self.as_view().float_to_rat(max_denominator)
}
}

impl<'a> AtomView<'a> {
Expand Down Expand Up @@ -1163,6 +1174,109 @@ impl<'a> AtomView<'a> {
}
}
}

/// Map all floating point coefficients to rational numbers, using a given maximum denominator.
pub fn float_to_rat(&self, max_denominator: &Integer) -> Atom {
let mut a = Atom::new();
self.map_coefficient_into(
|c| match c {
CoefficientView::Float(f) => {
let r = f.to_float().to_rational().unwrap();
Rational::from_large(r)
.truncate_denominator(max_denominator)
.into()
}
_ => c.to_owned(),
},
&mut a,
);
a
}

/// Map all coefficients using a given function.
pub fn map_coefficient<F: Fn(CoefficientView) -> Coefficient + Copy>(&self, f: F) -> Atom {
let mut a = Atom::new();
self.map_coefficient_into(f, &mut a);
a
}

/// Map all coefficients using a given function.
pub fn map_coefficient_into<F: Fn(CoefficientView) -> Coefficient + Copy>(
&self,
f: F,
out: &mut Atom,
) {
Workspace::get_local().with(|ws| self.map_coefficient_impl(f, true, ws, out))
}

fn map_coefficient_impl<F: Fn(CoefficientView) -> Coefficient + Copy>(
&self,
coeff_map: F,
enter_function: bool,
ws: &Workspace,
out: &mut Atom,
) {
match self {
AtomView::Num(n) => {
out.to_num(coeff_map(n.get_coeff_view()));
}
AtomView::Var(_) => out.set_from_view(self),
AtomView::Fun(f) => {
if enter_function {
let mut o = ws.new_atom();
let ff = o.to_fun(f.get_symbol());

let mut na = ws.new_atom();
for a in f.iter() {
a.map_coefficient_impl(coeff_map, enter_function, ws, &mut na);
ff.add_arg(na.as_view());
}

o.as_view().normalize(ws, out);
} else {
out.set_from_view(self);
}
}
AtomView::Pow(p) => {
let (base, exp) = p.get_base_exp();

let mut nb = ws.new_atom();
base.map_coefficient_impl(coeff_map, enter_function, ws, &mut nb);

let mut ne = ws.new_atom();
exp.map_coefficient_impl(coeff_map, enter_function, ws, &mut ne);

let mut o = ws.new_atom();
o.to_pow(nb.as_view(), ne.as_view());

o.as_view().normalize(ws, out);
}
AtomView::Mul(m) => {
let mut o = ws.new_atom();
let mm = o.to_mul();

let mut na = ws.new_atom();
for a in m.iter() {
a.map_coefficient_impl(coeff_map, enter_function, ws, &mut na);
mm.extend(na.as_view());
}

o.as_view().normalize(ws, out);
}
AtomView::Add(a) => {
let mut o = ws.new_atom();
let aa = o.to_add();

let mut na = ws.new_atom();
for a in a.iter() {
a.map_coefficient_impl(coeff_map, enter_function, ws, &mut na);
aa.extend(na.as_view());
}

o.as_view().normalize(ws, out);
}
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -1253,4 +1367,12 @@ mod test {
);
assert_eq!(r, "5.0000000000000000000000000000000000000000000000000000000000000e-1*x+6.8164061370918581635917066956651198726148569775622233288512875e-1");
}

#[test]
fn float_to_rat() {
let expr = Atom::parse("1/2 x + 238947/128903718927 + sin(3/4)").unwrap();
let expr = expr.to_float(60);
let expr = expr.float_to_rat(&1000.into());
assert_eq!(expr, Atom::parse("1/2*x+349/512").unwrap());
}
}
5 changes: 3 additions & 2 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ impl<'a> AtomView<'a> {
AtomView::Num(n) => match n.get_coeff_view() {
CoefficientView::Natural(n, d) => coeff_map(&Rational::Natural(n, d)),
CoefficientView::Large(l) => coeff_map(&Rational::Large(l.to_rat())),
CoefficientView::Float(_) => {
unimplemented!("Float not yet supported for evaluation")
CoefficientView::Float(f) => {
// TODO: converting back to rational is slow
coeff_map(&Rational::from_large(f.to_float().to_rational().unwrap()))
}
CoefficientView::FiniteField(_, _) => {
unimplemented!("Finite field not yet supported for evaluation")
Expand Down
2 changes: 1 addition & 1 deletion src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ impl Token {
Ok(f) => {
// derive precision from string length, should be overestimate
out.to_num(Coefficient::Float(
f.complete((n.len() as f64 * LOG2_10).ceil() as u32),
f.complete(((n.len() as f64 * LOG2_10).ceil() as u32).max(53)),
));
}
Err(e) => Err(format!("Error parsing number: {}", e))?,
Expand Down
6 changes: 6 additions & 0 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,12 @@ class Expression:
transformations can be applied.
"""

def to_float(self, decimal_prec: int) -> Expression:
"""Convert all coefficients to floats, with a given decimal precision."""

def float_to_rat(self, max_denominator: int) -> Expression:
"""Convert all floating point coefficients to rationals, with a given maximal denominator."""

def req_len(self, min_length: int, max_length: int | None) -> PatternRestriction:
"""
Create a pattern restriction based on the wildcard length before downcasting.
Expand Down

0 comments on commit 271f855

Please sign in to comment.