Skip to content

Commit

Permalink
Add a linearization transformer
Browse files Browse the repository at this point in the history
- Can extract symbols as well
  • Loading branch information
benruijl committed Jul 13, 2024
1 parent a52797a commit 9ed92f0
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 2 deletions.
35 changes: 35 additions & 0 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,41 @@ impl PythonPattern {
return append_transformer!(self, Transformer::ArgCount(only_for_arg_fun));
}

/// Create a transformer that linearizes a function, optionally extracting `symbols`
/// as well.
///
/// Examples
/// --------
/// >>> from symbolica import Expression, Transformer
/// >>> x, y, z, w, f, x__ = Expression.symbols('x', 'y', 'z', 'w', 'f', 'x__')
/// >>> e = f(x+y, 4*z*w+3).replace_all(f(x__), f(x__).transform().linearize([z]))
/// >>> print(e)
///
/// yields `f(x,3)+f(y,3)+4*z*f(x,w)+4*z*f(y,w)`.
pub fn linearize(&self, symbols: Option<Vec<PythonExpression>>) -> PyResult<PythonPattern> {
let mut c_symbols = vec![];
if let Some(symbols) = symbols {
for s in symbols {
if let AtomView::Var(v) = s.expr.as_view() {
c_symbols.push(v.get_symbol());
} else {
return Err(exceptions::PyValueError::new_err(
"Can only linearize in variables",
));
}
}
}

return append_transformer!(
self,
Transformer::Linearize(if c_symbols.is_empty() {
None
} else {
Some(c_symbols)
})
);
}

/// Create a transformer that sorts a list of arguments.
///
/// Examples
Expand Down
186 changes: 184 additions & 2 deletions src/transformer.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::time::Instant;

use crate::{
atom::{Atom, AtomView, Symbol},
atom::{representation::FunView, Atom, AtomView, Fun, Symbol},
coefficient::{Coefficient, CoefficientView},
combinatorics::{partitions, unique_permutations},
domains::rational::Rational,
id::{Condition, MatchSettings, Pattern, Replacement, WildcardAndRestriction},
printer::{AtomPrinter, PrintOptions},
state::{State, Workspace},
state::{RecycledAtom, State, Workspace},
};
use ahash::HashMap;
use colored::Colorize;
Expand Down Expand Up @@ -92,6 +92,8 @@ pub enum Transformer {
/// of arguments of `arg()` is returned and 1 is returned otherwise.
/// If the argument is `false`, 0 is returned for non-functions.
ArgCount(bool),
/// Linearize a function, optionally extracting `symbols` as well.
Linearize(Option<Vec<Symbol>>),
/// Map the rhs with a user-specified function.
Map(Box<dyn Map>),
/// Apply a transformation to each argument of the `arg()` function.
Expand Down Expand Up @@ -123,6 +125,7 @@ impl std::fmt::Debug for Transformer {
Transformer::Product => f.debug_tuple("Product").finish(),
Transformer::Sum => f.debug_tuple("Sum").finish(),
Transformer::ArgCount(p) => f.debug_tuple("ArgCount").field(p).finish(),
Transformer::Linearize(s) => f.debug_tuple("Linearize").field(s).finish(),
Transformer::Map(_) => f.debug_tuple("Map").finish(),
Transformer::ForEach(t) => f.debug_tuple("ForEach").field(t).finish(),
Transformer::Split => f.debug_tuple("Split").finish(),
Expand Down Expand Up @@ -150,6 +153,159 @@ impl std::fmt::Debug for Transformer {
}
}

impl FunView<'_> {
/// Linearize a function, optionally extracting `symbols` as well.
pub fn linearize(&self, symbols: Option<&[Symbol]>) -> Atom {
let mut out = Atom::new();
Workspace::get_local().with(|ws| {
self.linearize_impl(symbols, ws, &mut out);
});
out
}

fn linearize_impl(&self, symbols: Option<&[Symbol]>, workspace: &Workspace, out: &mut Atom) {
/// Add an argument `a` to `f` and flatten nested `arg`s.
#[inline(always)]
fn add_arg(f: &mut Fun, a: AtomView) {
if let AtomView::Fun(fa) = a {
if fa.get_symbol() == State::ARG {
// flatten f(arg(...)) = f(...)
for aa in fa.iter() {
f.add_arg(aa);
}

return;
}
}

f.add_arg(a);
}

/// Take Cartesian product of arguments
#[inline(always)]
fn cartesian_product<'b>(
workspace: &Workspace,
list: &[Vec<AtomView<'b>>],
fun_name: Symbol,
cur: &mut Vec<AtomView<'b>>,
acc: &mut Vec<RecycledAtom>,
) {
if list.is_empty() {
let mut h = workspace.new_atom();
let f = h.to_fun(fun_name);
for a in cur.iter() {
add_arg(f, *a);
}
acc.push(h);
return;
}

for a in &list[0] {
cur.push(*a);
cartesian_product(workspace, &list[1..], fun_name, cur, acc);
cur.pop();
}
}

if self.iter().any(|a| matches!(a, AtomView::Add(_))) {
let mut arg_buf = Vec::with_capacity(self.get_nargs());

for a in self.iter() {
let mut vec = vec![];
if let AtomView::Add(aa) = a {
for a in aa.iter() {
vec.push(a);
}
} else {
vec.push(a);
}
arg_buf.push(vec);
}

let mut acc = Vec::new();
cartesian_product(
workspace,
&arg_buf,
self.get_symbol(),
&mut vec![],
&mut acc,
);

let mut add_h = workspace.new_atom();
let add = add_h.to_add();

let mut h = workspace.new_atom();
for a in acc {
a.as_view().normalize(workspace, &mut h);

if let AtomView::Fun(ff) = h.as_view() {
let mut h2 = workspace.new_atom();
ff.linearize_impl(symbols, workspace, &mut h2);
add.extend(h2.as_view());
} else {
add.extend(h.as_view());
}
}

add_h.as_view().normalize(workspace, out);
return;
}

// linearize products
if self.iter().any(|a| {
symbols.is_some()
|| if let AtomView::Mul(m) = a {
m.has_coefficient()
} else {
false
}
}) {
let mut new_term = workspace.new_atom();
let t = new_term.to_mul();
let mut new_fun = workspace.new_atom();
let nf = new_fun.to_fun(self.get_symbol());

let mut coeff = workspace.new_atom();
let c = coeff.to_mul();
for a in self.iter() {
if let AtomView::Mul(m) = a {
if m.has_coefficient() || symbols.is_some() {
let mut stripped = workspace.new_atom();
let mul = stripped.to_mul();

for a in m {
if let AtomView::Num(_) = a {
c.extend(a);
} else if let AtomView::Var(v) = a {
let s = v.get_symbol();
if symbols.map(|x| x.contains(&s)).unwrap_or(false) {
c.extend(a);
} else {
mul.extend(a);
}
} else {
mul.extend(a);
}
}

nf.add_arg(stripped.as_view());
} else {
nf.add_arg(a);
}
} else {
nf.add_arg(a);
}
}

t.extend(new_fun.as_view());
t.extend(coeff.as_view());
t.as_view().normalize(workspace, out);
} else {
out.set_from_view(&self.as_view());
}
}
}

impl Transformer {
/// Create a new partition transformer that must exactly fit the input.
pub fn new_partition_exact(partitions: Vec<(Symbol, usize)>) -> Transformer {
Expand Down Expand Up @@ -294,6 +450,13 @@ impl Transformer {
out.to_num(Coefficient::zero());
}
}
Transformer::Linearize(symbols) => {
if let AtomView::Fun(f) = input {
f.linearize_impl(symbols.as_ref().map(|x| x.as_slice()), workspace, out);
} else {
out.set_from_view(&input);
}
}
Transformer::Split => match input {
AtomView::Mul(m) => {
let mut arg_h = workspace.new_atom();
Expand Down Expand Up @@ -679,4 +842,23 @@ mod test {
let r = Atom::parse("arg(0,0,0,0)").unwrap();
assert_eq!(out, r);
}

#[test]
fn linearize() {
let p = Atom::parse("f1(v1+v2,4*v3*v4+3)").unwrap();

let mut out = Atom::new();
Workspace::get_local().with(|ws| {
Transformer::execute(
p.as_view(),
&[Transformer::Linearize(Some(vec![State::get_symbol("v3")]))],
ws,
&mut out,
)
.unwrap()
});

let r = Atom::parse("f1(v1,3)+f1(v2,3)+4*v3*f1(v1,v4)+4*v3*f1(v2,v4)").unwrap();
assert_eq!(out, r);
}
}
15 changes: 15 additions & 0 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,21 @@ class Transformer:
>>> print(e)
"""

def linearize(self, symbols: Optional[Sequence[Expression]]) -> Transformer:
"""Create a transformer that linearizes a function, optionally extracting `symbols`
as well.
Examples
--------
>>> from symbolica import Expression, Transformer
>>> x, y, z, w, f, x__ = Expression.symbols('x', 'y', 'z', 'w', 'f', 'x__')
>>> e = f(x+y, 4*z*w+3).replace_all(f(x__), f(x__).transform().linearize([z]))
>>> print(e)
yields `f(x,3)+f(y,3)+4*z*f(x,w)+4*z*f(y,w)`.
"""

def partitions(
self,
bins: Sequence[Tuple[Transformer | Expression, int]],
Expand Down

0 comments on commit 9ed92f0

Please sign in to comment.