Skip to content

Commit

Permalink
Add together and apart for expressions
Browse files Browse the repository at this point in the history
- Improvements to apart function
- Fix Hermite reduction
- Handle constants better in rational integration algorithm
- Fix printing of Mersennse prime field element
  • Loading branch information
benruijl committed Apr 15, 2024
1 parent 26b7100 commit d06075c
Show file tree
Hide file tree
Showing 8 changed files with 497 additions and 79 deletions.
118 changes: 113 additions & 5 deletions src/api/cpp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,12 +422,16 @@ unsafe extern "C" fn drop(symbolica: *mut Symbolica) {
mod test {
use std::ffi::{c_char, CStr};

use super::{drop, init};
use crate::domains::finite_field::Mersenne64;

use super::{drop, init, set_options};

#[test]
fn simplify() {
let symbolica = unsafe { init() };

unsafe { set_options(symbolica, true, false) };

unsafe { super::set_vars(symbolica, b"d,y\0".as_ptr() as *const c_char) };

let input = "-(4096-4096*y^2)/(-3072+1024*d)*(1536-512*d)-(-8192+8192*y^2)/(2)*((-6+d)/2)-(-8192+8192*y^2)/(-2)*((-13+3*d)/2)-(-8192+8192*y^2)/(-4)*(-8+2*d)\0";
Expand All @@ -436,11 +440,15 @@ mod test {

assert_eq!(result, "[32768-32768*y^2-8192*d+8192*d*y^2]");

unsafe { set_options(symbolica, true, true) };

let result = unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, 0, false) };
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(result, "32768-32768*y^2-8192*d+8192*d*y^2");

unsafe { set_options(symbolica, false, false) };

let result =
unsafe { super::simplify_factorized(symbolica, input.as_ptr() as *const i8, 0, true) };
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();
Expand All @@ -461,32 +469,132 @@ mod test {

unsafe { super::set_vars(symbolica, b"d,y\0".as_ptr() as *const c_char) };

let prime = 4293491017;

let input = "-(4096-4096*y^2)/(-3072+1024*d)*(1536-512*d)-(-8192+8192*y^2)/(2)*((-6+d)/2)-(-8192+8192*y^2)/(-2)*((-13+3*d)/2)-(-8192+8192*y^2)/(-4)*(-8+2*d)\0";
let result =
unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, 4293491017, true) };
unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, prime, true) };
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(result, "[32768+4293458249*y^2+4293482825*d+8192*d*y^2]");

let result =
unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, 4293491017, false) };
unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, prime, false) };
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(result, "32768+4293458249*y^2+4293482825*d+8192*d*y^2");

let result = unsafe {
super::simplify_factorized(symbolica, input.as_ptr() as *const i8, 4293491017, true)
super::simplify_factorized(symbolica, input.as_ptr() as *const i8, prime, true)
};
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(result, "[32768+4293458249*y^2+4293482825*d+8192*d*y^2]");

let result = unsafe {
super::simplify_factorized(symbolica, input.as_ptr() as *const i8, 4293491017, false)
super::simplify_factorized(symbolica, input.as_ptr() as *const i8, prime, false)
};
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

unsafe { drop(symbolica) };
assert_eq!(result, "32768+4293458249*y^2+4293482825*d+8192*d*y^2");
}

#[test]
fn simplify_mersenne() {
let symbolica = unsafe { init() };

unsafe { super::set_vars(symbolica, b"d,y\0".as_ptr() as *const c_char) };

let prime = Mersenne64::PRIME;

let input = "-(4096-4096*y^2)/(-3072+1024*d)*(1536-512*d)-(-8192+8192*y^2)/(2)*((-6+d)/2)-(-8192+8192*y^2)/(-2)*((-13+3*d)/2)-(-8192+8192*y^2)/(-4)*(-8+2*d)\0";
let result =
unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, prime, true) };
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(
result,
"[32768+2305843009213661183*y^2+2305843009213685759*d+8192*d*y^2]"
);

let result =
unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, prime, false) };
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(
result,
"32768+2305843009213661183*y^2+2305843009213685759*d+8192*d*y^2"
);

let result = unsafe {
super::simplify_factorized(symbolica, input.as_ptr() as *const i8, prime, true)
};
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(
result,
"[32768+2305843009213661183*y^2+2305843009213685759*d+8192*d*y^2]"
);

let result = unsafe {
super::simplify_factorized(symbolica, input.as_ptr() as *const i8, prime, false)
};
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

unsafe { drop(symbolica) };
assert_eq!(
result,
"32768+2305843009213661183*y^2+2305843009213685759*d+8192*d*y^2"
);
}

#[test]
fn simplify_u64_prime() {
let symbolica = unsafe { init() };

unsafe { super::set_vars(symbolica, b"d,y\0".as_ptr() as *const c_char) };

let prime = 18446744073709551163;

let input = "-(4096-4096*y^2)/(-3072+1024*d)*(1536-512*d)-(-8192+8192*y^2)/(2)*((-6+d)/2)-(-8192+8192*y^2)/(-2)*((-13+3*d)/2)-(-8192+8192*y^2)/(-4)*(-8+2*d)\0";
let result =
unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, prime, true) };
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(
result,
"[32768+18446744073709518395*y^2+18446744073709542971*d+8192*d*y^2]"
);

let result =
unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, prime, false) };
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(
result,
"32768+18446744073709518395*y^2+18446744073709542971*d+8192*d*y^2"
);

let result = unsafe {
super::simplify_factorized(symbolica, input.as_ptr() as *const i8, prime, true)
};
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

assert_eq!(
result,
"[32768+18446744073709518395*y^2+18446744073709542971*d+8192*d*y^2]"
);

let result = unsafe {
super::simplify_factorized(symbolica, input.as_ptr() as *const i8, prime, false)
};
let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned();

unsafe { drop(symbolica) };
assert_eq!(
result,
"32768+18446744073709518395*y^2+18446744073709542971*d+8192*d*y^2"
);
}
}
55 changes: 55 additions & 0 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2491,6 +2491,61 @@ impl PythonExpression {
Ok(PythonExpression { expr: Arc::new(b) })
}

/// Compute the partial fraction decomposition in `x`.
///
/// Examples
/// --------
///
/// >>> from symbolica import Expression
/// >>> x = Expression.var('x')
/// >>> p = Expression.parse('1/((x+y)*(x^2+x*y+1)(x+1))')
/// >>> print(p.apart(x))
pub fn apart(&self, x: PythonExpression) -> PyResult<PythonExpression> {
let poly = self.expr.to_rational_polynomial::<_, _, u32>(&Q, &Z, None);
let x = poly
.get_variables()
.iter()
.position(|v| match (v, x.expr.as_view()) {
(Variable::Symbol(y), AtomView::Var(vv)) => *y == vv.get_symbol(),
(Variable::Function(_, f) | Variable::Other(f), a) => f.as_view() == a,
_ => false,
})
.ok_or(exceptions::PyValueError::new_err(format!(
"Variable {} not found in polynomial",
x.__str__()?
)))?;

let fs = poly.apart(x);

let mut rn = Atom::new();
Workspace::get_local().with(|ws| {
let mut res = ws.new_atom();
let a = res.to_add();
for f in fs {
a.extend(f.to_expression().as_view());
}

res.as_view().normalize(ws, &mut rn);
});

Ok(PythonExpression { expr: Arc::new(rn) })
}

/// Write the expression over a common denominator.
///
/// Examples
/// --------
///
/// >>> from symbolica import Expression
/// >>> p = Expression.parse('v1^2/2+v1^3/v4*v2+v3/(1+v4)')
/// >>> print(p.together())
pub fn together(&self) -> PyResult<PythonExpression> {
let poly = self.expr.to_rational_polynomial::<_, _, u32>(&Q, &Z, None);
Ok(PythonExpression {
expr: Arc::new(poly.to_expression()),
})
}

/// Convert the expression to a polynomial, optionally, with the variables and the ordering specified in `vars`.
/// All non-polynomial elements will be converted to new independent variables.
pub fn to_polynomial(&self, vars: Option<Vec<PythonExpression>>) -> PyResult<PythonPolynomial> {
Expand Down
78 changes: 78 additions & 0 deletions src/collect.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use ahash::HashMap;

use crate::{
domains::{integer::Z, rational::Q},
representations::{Add, AsAtomView, Atom, AtomView, Symbol},
state::Workspace,
};
Expand Down Expand Up @@ -51,6 +52,16 @@ impl Atom {
pub fn coefficient<'a, T: AsAtomView<'a>>(&self, x: T) -> Atom {
Workspace::get_local().with(|ws| self.as_view().coefficient_with_ws(x.as_atom_view(), ws))
}

/// Write the expression over a common denominator.
pub fn together(&self) -> Atom {
self.as_view().together()
}

/// Write the expression as a sum of terms with minimal denominators.
pub fn apart(&self, x: Symbol) -> Atom {
self.as_view().apart(x)
}
}

impl<'a> AtomView<'a> {
Expand Down Expand Up @@ -341,6 +352,49 @@ impl<'a> AtomView<'a> {
}
}
}

/// Write the expression over a common denominator.
pub fn together(&self) -> Atom {
let mut out = Atom::new();
self.together_into(&mut out);
out
}

/// Write the expression over a common denominator.
pub fn together_into(&self, out: &mut Atom) {
self.to_rational_polynomial::<_, _, u32>(&Q, &Z, None)
.to_expression_into(out);
}

/// Write the expression as a sum of terms with minimal denominators.
pub fn apart(&self, x: Symbol) -> Atom {
let mut out = Atom::new();

Workspace::get_local().with(|ws| {
self.apart_with_ws_into(x, ws, &mut out);
});

out
}

/// Write the expression as a sum of terms with minimal denominators.
pub fn apart_with_ws_into(&self, x: Symbol, ws: &Workspace, out: &mut Atom) {
let poly = self.to_rational_polynomial::<_, _, u32>(&Q, &Z, None);
if let Some(v) = poly.get_variables().iter().position(|v| v == &x.into()) {
let mut a = ws.new_atom();
let add = a.to_add();

let mut a = ws.new_atom();
for x in poly.apart(v) {
x.to_expression_into(&mut a);
add.extend(a.as_view());
}

add.as_view().normalize(ws, out);
} else {
out.set_from_view(self);
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -414,4 +468,28 @@ mod test {

assert_eq!(out, ref_out);
}

#[test]
fn together() {
let input = Atom::parse("v1^2/2+v1^3/v4*v2+v3/(1+v4)").unwrap();
let out = input.together();

let ref_out =
Atom::parse("(2*v4+2*v4^2)^-1*(2*v3*v4+v1^2*v4+v1^2*v4^2+2*v1^3*v2+2*v1^3*v2*v4)")
.unwrap();

assert_eq!(out, ref_out);
}

#[test]
fn apart() {
let input =
Atom::parse("(2*v4+2*v4^2)^-1*(2*v3*v4+v1^2*v4+v1^2*v4^2+2*v1^3*v2+2*v1^3*v2*v4)")
.unwrap();
let out = input.apart(State::get_symbol("v4"));

let ref_out = Atom::parse("1/2*v1^2+v3*(v4+1)^-1+v1^3*v2*v4^-1").unwrap();

assert_eq!(out, ref_out);
}
}
4 changes: 2 additions & 2 deletions src/domains/finite_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -685,13 +685,13 @@ impl Mersenne64 {

impl std::fmt::Debug for Mersenne64 {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}", Self::PRIME))
std::fmt::Debug::fmt(&self.0, f)
}
}

impl Display for Mersenne64 {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}", Self::PRIME))
self.0.fmt(f)
}
}

Expand Down
Loading

0 comments on commit d06075c

Please sign in to comment.