Skip to content

Commit 171b68c

Browse files
committed
Further generalize Params trait to enable fusion on higher than binary expressions
1 parent 37ba8d9 commit 171b68c

File tree

1 file changed

+27
-29
lines changed

1 file changed

+27
-29
lines changed

src/lib.rs

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ pub trait Params: Send + Sync + Clone {
6363
/// This can depend on `self` in situations where the number of parameters depends on the data itself, e.g. the number of groups in a hierarchical model.
6464
fn dimension(&self) -> usize;
6565

66-
/// Compute new parameters by mapping the given closure `f` over all coordinate pairs
67-
#[must_use]
68-
fn map<F>(&self, other: &Self, f: F) -> Self
69-
where
70-
F: Fn(f64, f64) -> f64;
66+
/// Access the individual parameter values as an iterator
67+
fn values(&self) -> impl Iterator<Item = &f64>;
68+
69+
/// Collect new parameters from the given iterator
70+
fn collect(iter: impl Iterator<Item = f64>) -> Self;
7171
}
7272

7373
/// Model parameters stored as an array of length `N` considered as an element of the vector space `R^N`
@@ -76,14 +76,13 @@ impl<const N: usize> Params for [f64; N] {
7676
N
7777
}
7878

79-
fn map<F>(&self, other: &Self, f: F) -> Self
80-
where
81-
F: Fn(f64, f64) -> f64,
82-
{
79+
fn values(&self) -> impl Iterator<Item = &f64> {
80+
self.iter()
81+
}
82+
83+
fn collect(iter: impl Iterator<Item = f64>) -> Self {
8384
let mut new = [0.; N];
84-
for i in 0..N {
85-
new[i] = f(self[i], other[i]);
86-
}
85+
iter.enumerate().for_each(|(idx, value)| new[idx] = value);
8786
new
8887
}
8988
}
@@ -94,14 +93,12 @@ impl Params for Vec<f64> {
9493
self.len()
9594
}
9695

97-
fn map<F>(&self, other: &Self, f: F) -> Self
98-
where
99-
F: Fn(f64, f64) -> f64,
100-
{
96+
fn values(&self) -> impl Iterator<Item = &f64> {
10197
self.iter()
102-
.zip(other)
103-
.map(|(self_, other)| f(*self_, *other))
104-
.collect()
98+
}
99+
100+
fn collect(iter: impl Iterator<Item = f64>) -> Self {
101+
iter.collect()
105102
}
106103
}
107104

@@ -111,14 +108,12 @@ impl Params for Box<[f64]> {
111108
self.len()
112109
}
113110

114-
fn map<F>(&self, other: &Self, f: F) -> Self
115-
where
116-
F: Fn(f64, f64) -> f64,
117-
{
111+
fn values(&self) -> impl Iterator<Item = &f64> {
118112
self.iter()
119-
.zip(other.iter())
120-
.map(|(self_, other)| f(*self_, *other))
121-
.collect()
113+
}
114+
115+
fn collect(iter: impl Iterator<Item = f64>) -> Self {
116+
iter.collect()
122117
}
123118
}
124119

@@ -221,9 +216,12 @@ where
221216
fn move_(&mut self, model: &M, other: &Self) -> M::Params {
222217
let z = ((M::SCALE - 1.) * gen_unit(&mut self.rng) + 1.).powi(2) / M::SCALE;
223218

224-
let mut new_state = self
225-
.state
226-
.map(&other.state, |self_, other| other - z * (other - self_));
219+
let mut new_state = M::Params::collect(
220+
self.state
221+
.values()
222+
.zip(other.state.values())
223+
.map(|(self_, other)| other - z * (other - self_)),
224+
);
227225

228226
let new_log_prob = model.log_prob(&new_state);
229227

0 commit comments

Comments
 (0)