Skip to content

Commit

Permalink
Expand term streamer API
Browse files Browse the repository at this point in the history
- Add matches function to Python API
  • Loading branch information
benruijl committed May 27, 2024
1 parent f923031 commit df53b42
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 16 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
- name: Generate code coverage
run: cargo llvm-cov --workspace --codecov --output-path codecov.json
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
token: ${{ secrets.CODECOV_TOKEN }}
files: codecov.json
fail_ci_if_error: true
84 changes: 80 additions & 4 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2791,7 +2791,7 @@ impl PythonExpression {
}

/// Return an iterator over the pattern `self` matching to `lhs`.
/// Restrictions on pattern can be supplied through `cond`.
/// Restrictions on the pattern can be supplied through `cond`.
///
/// Examples
/// --------
Expand Down Expand Up @@ -2831,6 +2831,39 @@ impl PythonExpression {
))
}

/// Test whether the pattern `lhs` is found in the expression `self`.
/// Restrictions on the pattern can be supplied through `cond`.
///
/// The optional `level_range` restricts the `[min, max]` level at which
/// the pattern may match (defaults to `(0, None)`, i.e. unrestricted) and
/// `level_is_tree_depth` selects tree depth as the level metric
/// (defaults to `false`).
///
/// Examples
/// --------
///
/// >>> f = Expression.symbol('f')
/// >>> if f(1).matches(f(2)):
/// >>> print('match')
pub fn matches(
    &self,
    lhs: ConvertibleToPattern,
    cond: Option<PythonPatternRestriction>,
    level_range: Option<(usize, Option<usize>)>,
    level_is_tree_depth: Option<bool>,
) -> PyResult<bool> {
    let pat = lhs.to_pattern()?.expr;
    // No restriction given means an empty (always-satisfied) condition set.
    let conditions = cond.map(|r| r.condition.clone()).unwrap_or_default();
    let settings = MatchSettings {
        level_range: level_range.unwrap_or((0, None)),
        level_is_tree_depth: level_is_tree_depth.unwrap_or(false),
        ..MatchSettings::default()
    };

    // One successful step of the match iterator proves a match exists;
    // no need to enumerate all of them.
    Ok(
        PatternAtomTreeIterator::new(&pat, self.expr.as_view(), &conditions, &settings)
            .next()
            .is_some(),
    )
}

/// Return an iterator over the replacement of the pattern `self` on `lhs` by `rhs`.
/// Restrictions on pattern can be supplied through `cond`.
///
Expand Down Expand Up @@ -3497,16 +3530,27 @@ impl PythonTermStreamer {
self.stream += &mut rhs.stream;
}

/// Get the total number of bytes of the stream.
/// Delegates to the underlying `TermStreamer`.
pub fn get_byte_size(&self) -> usize {
    self.stream.get_byte_size()
}

/// Return true iff the stream fits in memory.
pub fn fits_in_memory(&self) -> bool {
    self.stream.fits_in_memory()
}

/// Get the current number of terms in the stream.
pub fn get_num_terms(&self) -> usize {
    self.stream.get_num_terms()
}

/// Add an expression to the term stream.
/// A sum is split into its individual terms by `TermStreamer::push`.
pub fn push(&mut self, expr: PythonExpression) {
    self.stream.push(expr.expr.clone());
}

/// Sort and fuse all terms in the stream.
pub fn normalize(&mut self) {
    self.stream.normalize();
}
Expand All @@ -3516,7 +3560,7 @@ impl PythonTermStreamer {
self.stream.to_expression().into()
}

/// Map the transformations to every term in the streamer.
/// Map the transformations to every term in the stream.
pub fn map(&mut self, op: PythonPattern, py: Python) -> PyResult<Self> {
let t = match &op.expr {
Pattern::Transformer(t) => {
Expand Down Expand Up @@ -3553,6 +3597,38 @@ impl PythonTermStreamer {
})
.map(|x| PythonTermStreamer { stream: x })
}

/// Map the transformations to every term in the stream using a single thread.
///
/// The operation must be an unbound transformer; a transformer that is
/// already bound to an expression, or any other pattern, is rejected with
/// a `ValueError`.
pub fn map_single_thread(&mut self, op: PythonPattern) -> PyResult<Self> {
    let t = match &op.expr {
        Pattern::Transformer(t) => {
            // A bound transformer carries its own input expression and
            // cannot be applied to each stream term.
            if t.0.is_some() {
                return Err(exceptions::PyValueError::new_err(
                    "Transformer is bound to expression. Use Transformer() instead."
                        .to_string(),
                ));
            }
            &t.1
        }
        _ => {
            return Err(exceptions::PyValueError::new_err(
                // fixed grammar: was "Operation must of a transformer"
                "Operation must be a transformer".to_string(),
            ));
        }
    };

    // Map every term in the stream; `t` is already a reference, so no
    // extra borrow is needed.
    let s = self.stream.map_single_thread(|x| {
        let mut out = Atom::default();
        Workspace::get_local().with(|ws| {
            Transformer::execute(x.as_view(), t, ws, &mut out)
                .unwrap_or_else(|e| panic!("Transformer failed during execution: {:?}", e));
        });
        out
    });

    Ok(PythonTermStreamer { stream: s })
}
}

self_cell!(
Expand Down
5 changes: 2 additions & 3 deletions src/poly/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1143,9 +1143,8 @@ impl InstructionList {
if pairs > 0 {
a.retain(|x| x != &idx1);

// add back removed indices when the count is odd
if count > 2 * pairs {
a.extend(std::iter::repeat(idx1).take(count - 2 * pairs));
if count % 2 == 1 {
a.push(idx1);
}

a.extend(std::iter::repeat(insert_index).take(pairs));
Expand Down
21 changes: 19 additions & 2 deletions src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ impl<'a, R: ReadableNamedStream> Iterator for TermInputStream<'a, R> {
pub struct TermStreamer<W: WriteableNamedStream> {
mem_buf: Vec<Atom>,
mem_size: usize,
num_terms: usize,
total_size: usize,
file_buf: Vec<W>,
config: TermStreamerConfig,
Expand Down Expand Up @@ -178,6 +179,7 @@ impl<W: WriteableNamedStream> TermStreamer<W> {
Self {
mem_buf: vec![],
mem_size: 0,
num_terms: 0,
total_size: 0,
file_buf: vec![],
filename,
Expand All @@ -196,6 +198,7 @@ impl<W: WriteableNamedStream> TermStreamer<W> {
Self {
mem_buf: vec![],
mem_size: 0,
num_terms: 0,
total_size: 0,
file_buf: vec![],
filename: self.filename.clone(),
Expand All @@ -205,6 +208,16 @@ impl<W: WriteableNamedStream> TermStreamer<W> {
}
}

/// Returns true iff the stream fits in memory,
/// i.e. no part of it has spilled to file buffers.
pub fn fits_in_memory(&self) -> bool {
    self.file_buf.is_empty()
}

/// Get the number of terms in the stream.
/// The counter is maintained by `push` and adjusted during normalization.
pub fn get_num_terms(&self) -> usize {
    self.num_terms
}

/// Add terms to the buffer.
pub fn push(&mut self, a: Atom) {
if let AtomView::Add(aa) = a.as_view() {
Expand All @@ -220,6 +233,7 @@ impl<W: WriteableNamedStream> TermStreamer<W> {
let size = a.as_view().get_byte_size();
self.mem_buf.push(a);
self.mem_size += size;
self.num_terms += 1;
self.total_size += size;

if self.mem_size >= self.config.max_mem_bytes {
Expand Down Expand Up @@ -248,6 +262,7 @@ impl<W: WriteableNamedStream> TermStreamer<W> {
.par_sort_by(|a, b| a.as_view().cmp_terms(&b.as_view()));

let mut out = Vec::with_capacity(self.mem_buf.len());
let old_size = self.mem_buf.len();
let mut new_size = 0;

if !self.mem_buf.is_empty() {
Expand Down Expand Up @@ -286,6 +301,8 @@ impl<W: WriteableNamedStream> TermStreamer<W> {
}

self.mem_buf = out;
self.num_terms += self.mem_buf.len();
self.num_terms -= old_size;
self.total_size += new_size;
self.total_size -= self.mem_size;
self.mem_size = new_size;
Expand Down Expand Up @@ -442,9 +459,9 @@ impl<W: WriteableNamedStream> TermStreamer<W> {
out_wrap.into_inner().unwrap()
}

/// Map every term in the stream using the function `f` using a single core. The resulting terms
/// Map every term in the stream using the function `f` using a single thread. The resulting terms
/// are a stream as well, which is returned by this function.
pub fn map_single_core(mut self, f: impl Fn(Atom) -> Atom) -> Self {
pub fn map_single_thread(&mut self, f: impl Fn(Atom) -> Atom) -> Self {
let mut new_out = self.next_generation();

let reader = self.reader();
Expand Down
38 changes: 33 additions & 5 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ class Expression:
) -> MatchIterator:
"""
Return an iterator over the pattern `self` matching to `lhs`.
Restrictions on pattern can be supplied through `cond`.
Restrictions on the pattern can be supplied through `cond`.
The `level_range` specifies the `[min,max]` level at which the pattern is allowed to match.
The first level is 0 and the level is increased when going into a function or one level deeper in the expression tree,
Expand All @@ -845,6 +845,25 @@ class Expression:
>>> print(map[0],'=', map[1])
"""

def matches(
    self,
    lhs: Transformer | Expression | int,
    cond: Optional[PatternRestriction] = None,
    level_range: Optional[Tuple[int, Optional[int]]] = None,
    level_is_tree_depth: Optional[bool] = False,
) -> bool:
    """
    Test whether the pattern `lhs` is found in the expression `self`.
    Restrictions on the pattern can be supplied through `cond`.

    The `level_range` specifies the `[min, max]` level at which the pattern
    is allowed to match; `level_is_tree_depth` selects tree depth as the
    level metric instead of function depth.

    Examples
    --------
    >>> f = Expression.symbol('f')
    >>> if f(1).matches(f(2)):
    >>>     print('match')
    """

def replace(
self,
lhs: Transformer | Expression | int,
Expand Down Expand Up @@ -1540,19 +1559,28 @@ class TermStreamer:
"""Add another term streamer to this one."""

def get_byte_size(self) -> int:
    """Get the byte size of the term stream."""

def get_num_terms(self) -> int:
    """Get the current number of terms in the stream."""

def fits_in_memory(self) -> bool:
    """Check whether the term stream fits in memory."""

def push(self, expr: Expression) -> None:
    """Push an expression to the term stream."""

def normalize(self) -> None:
    """Sort and fuse all terms in the stream."""

def to_expression(self) -> Expression:
    """Convert the term stream into a single expression.
    Note that the result may exceed the available memory."""

def map(self, f: Transformer) -> TermStreamer:
    """Apply a transformer to all terms in the stream.
    See `map_single_thread` for a single-threaded variant."""

def map_single_thread(self, f: Transformer) -> TermStreamer:
    """Apply a transformer to all terms in the stream using a single thread."""


class MatchIterator:
Expand Down

0 comments on commit df53b42

Please sign in to comment.