Skip to content

Commit 488a120

Browse files
authored
refactor: Refactor algorithm trait (#196)
* refactor: Updating the algorithm trait To be more idiomatic Rust * Refactoring * Documentation * Update mod.rs * Clean up code * Remove logs from trait * Refactor Status and StopReason A little more idiomatic and a little less idiotic * Fix tests And by fix I mean remove half of them
1 parent 30d7197 commit 488a120

File tree

17 files changed

+250
-324
lines changed

17 files changed

+250
-324
lines changed

examples/iov/main.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ fn main() -> Result<()> {
5252
let data = data::read_pmetrics("examples/iov/test.csv").unwrap();
5353
let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap();
5454
algorithm.initialize().unwrap();
55-
while !algorithm.next_cycle().unwrap() {}
56-
let result = algorithm.into_npresult();
55+
let result = algorithm.fit().unwrap();
5756
result.write_outputs().unwrap();
5857

5958
Ok(())

examples/meta/main.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ fn main() {
6464
let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap();
6565
// let result = algorithm.fit().unwrap();
6666
algorithm.initialize().unwrap();
67-
while !algorithm.next_cycle().unwrap() {}
68-
let result = algorithm.into_npresult();
67+
let result = algorithm.fit().unwrap();
6968
result.write_outputs().unwrap();
7069
}

examples/new_iov/main.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ fn main() {
5555
let data = data::read_pmetrics("examples/new_iov/data.csv").unwrap();
5656
let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap();
5757
algorithm.initialize().unwrap();
58-
while !algorithm.next_cycle().unwrap() {}
59-
let result = algorithm.into_npresult();
58+
let result = algorithm.fit().unwrap();
6059
result.write_outputs().unwrap();
6160
}

examples/theophylline/main.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ fn main() {
5353
let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap();
5454
// let result = algorithm.fit().unwrap();
5555
algorithm.initialize().unwrap();
56-
while !algorithm.next_cycle().unwrap() {}
57-
let result = algorithm.into_npresult();
56+
let result = algorithm.fit().unwrap();
5857
result.write_outputs().unwrap();
5958
}

examples/vanco_sde/main.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ fn main() {
7878

7979
let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap();
8080
algorithm.initialize().unwrap();
81-
while !algorithm.next_cycle().unwrap() {}
82-
let result = algorithm.into_npresult();
81+
let result = algorithm.fit().unwrap();
8382
result.write_outputs().unwrap();
8483
}

src/algorithms/mod.rs

Lines changed: 82 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
7575
.collect::<Vec<_>>();
7676

7777
if !indices.is_empty() {
78-
let subject: Vec<&Subject> = self.get_data().subjects();
78+
let subject: Vec<&Subject> = self.data().subjects();
7979
let zero_probability_subjects: Vec<&String> =
8080
indices.iter().map(|&i| subject[i].id()).collect();
8181

@@ -89,7 +89,7 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
8989
for index in &indices {
9090
tracing::debug!("Subject with zero probability: {}", subject[*index].id());
9191

92-
let error_model = self.get_settings().errormodels().clone();
92+
let error_model = self.settings().errormodels().clone();
9393

9494
// Simulate all support points in parallel
9595
let spp_results: Vec<_> = self
@@ -207,54 +207,103 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
207207

208208
Ok(())
209209
}
210-
fn get_settings(&self) -> &Settings;
210+
211+
fn settings(&self) -> &Settings;
212+
/// Get the equation used in the algorithm
211213
fn equation(&self) -> &E;
212-
fn get_data(&self) -> &Data;
214+
/// Get the data used in the algorithm
215+
fn data(&self) -> &Data;
213216
fn get_prior(&self) -> Theta;
214-
fn inc_cycle(&mut self) -> usize;
215-
fn get_cycle(&self) -> usize;
217+
/// Increment the cycle counter and return the new value
218+
fn increment_cycle(&mut self) -> usize;
219+
/// Get the current cycle number
220+
fn cycle(&self) -> usize;
221+
/// Set the current [Theta]
216222
fn set_theta(&mut self, theta: Theta);
223+
/// Get the current [Theta]
217224
fn theta(&self) -> &Theta;
225+
/// Get the current [Psi]
218226
fn psi(&self) -> &Psi;
227+
/// Get the current likelihood
219228
fn likelihood(&self) -> f64;
229+
/// Get the current negative two log-likelihood
220230
fn n2ll(&self) -> f64 {
221231
-2.0 * self.likelihood()
222232
}
233+
/// Get the current [Status] of the algorithm
223234
fn status(&self) -> &Status;
235+
/// Set the current [Status] of the algorithm
224236
fn set_status(&mut self, status: Status);
225-
fn convergence_evaluation(&mut self);
226-
fn converged(&self) -> bool;
237+
/// Evaluate convergence criteria and update status
238+
fn evaluation(&mut self) -> Result<Status>;
239+
240+
/// Create and log a cycle state with the current algorithm state
241+
fn log_cycle_state(&mut self);
242+
243+
/// Initialize the algorithm, setting up initial [Theta] and [Status]
227244
fn initialize(&mut self) -> Result<()> {
228245
// If a stop file exists in the current directory, remove it
229246
if Path::new("stop").exists() {
230247
tracing::info!("Removing existing stop file prior to run");
231248
fs::remove_file("stop").context("Unable to remove previous stop file")?;
232249
}
233-
self.set_status(Status::InProgress);
250+
self.set_status(Status::Continue);
234251
self.set_theta(self.get_prior());
235252
Ok(())
236253
}
237-
fn evaluation(&mut self) -> Result<()>;
254+
fn estimation(&mut self) -> Result<()>;
255+
/// Performs condensation of [Theta] and updates [Psi]
256+
///
257+
/// This step reduces the number of support points in [Theta] based on the current weights,
258+
/// and updates the [Psi] matrix accordingly to reflect the new set of support points.
259+
/// It is typically performed after the estimation step in each cycle of the algorithm.
238260
fn condensation(&mut self) -> Result<()>;
261+
262+
/// Performs optimizations on the current [ErrorModels] and updates [Psi] accordingly
263+
///
264+
/// This step refines the error model parameters to better fit the data,
265+
/// and subsequently updates the [Psi] matrix to reflect these changes.
239266
fn optimizations(&mut self) -> Result<()>;
240-
fn logs(&self);
267+
268+
/// Performs expansion of [Theta]
269+
///
270+
/// This step increases the number of support points in [Theta] based on the current distribution,
271+
/// allowing for exploration of the parameter space.
241272
fn expansion(&mut self) -> Result<()>;
242-
fn next_cycle(&mut self) -> Result<bool> {
243-
if self.inc_cycle() > 1 {
273+
274+
/// Proceed to the next cycle of the algorithm
275+
///
276+
/// This method increments the cycle counter, performs expansion if necessary,
277+
/// and then runs the estimation, condensation, optimization, logging, and evaluation steps
278+
/// in sequence. It returns the current [Status] of the algorithm after completing these steps.
279+
fn next_cycle(&mut self) -> Result<Status> {
280+
let cycle = self.increment_cycle();
281+
282+
if cycle > 1 {
244283
self.expansion()?;
245284
}
246-
let span = tracing::info_span!("", "{}", format!("Cycle {}", self.get_cycle()));
285+
286+
let span = tracing::info_span!("", "{}", format!("Cycle {}", self.cycle()));
247287
let _enter = span.enter();
248-
self.evaluation()?;
288+
self.estimation()?;
249289
self.condensation()?;
250290
self.optimizations()?;
251-
self.logs();
252-
self.convergence_evaluation();
253-
Ok(self.converged())
291+
self.evaluation()
254292
}
293+
294+
/// Fit the model until convergence or stopping criteria are met
295+
///
296+
/// This method runs the full fitting process, starting with initialization,
297+
/// followed by iterative cycles of estimation, condensation, optimization, and evaluation
298+
/// until the algorithm converges or meets a stopping criteria.
255299
fn fit(&mut self) -> Result<NPResult<E>> {
256300
self.initialize().unwrap();
257-
while !self.next_cycle()? {}
301+
loop {
302+
match self.next_cycle()? {
303+
Status::Continue => continue,
304+
Status::Stop(_) => break,
305+
}
306+
}
258307
Ok(self.into_npresult())
259308
}
260309

@@ -274,32 +323,27 @@ pub fn dispatch_algorithm<E: Equation + Send + 'static>(
274323
}
275324
}
276325

277-
/// Represents the status of the algorithm
326+
/// Represents the status/result of the algorithm
278327
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
279328
pub enum Status {
280-
/// Algorithm is starting up
281-
Starting,
282-
/// Algorithm has converged to a solution
283-
Converged,
284-
/// Algorithm stopped due to reaching maximum cycles
285-
MaxCycles,
286-
/// Algorithm is currently running
287-
InProgress,
288-
/// Algorithm was manually stopped by user
289-
ManualStop,
290-
/// Other status with custom message
291-
Other(String),
329+
Continue,
330+
Stop(StopReason),
292331
}
293332

294333
impl std::fmt::Display for Status {
295334
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296335
match self {
297-
Status::Starting => write!(f, "Starting"),
298-
Status::Converged => write!(f, "Converged"),
299-
Status::MaxCycles => write!(f, "Maximum cycles reached"),
300-
Status::InProgress => write!(f, "In progress"),
301-
Status::ManualStop => write!(f, "Manual stop requested"),
302-
Status::Other(msg) => write!(f, "{}", msg),
336+
Status::Continue => write!(f, "Continue"),
337+
Status::Stop(s) => write!(f, "Stop: {:?}", s),
303338
}
304339
}
305340
}
341+
342+
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
343+
344+
pub enum StopReason {
345+
Converged,
346+
MaxCycles,
347+
Stopped,
348+
Completed,
349+
}

0 commit comments

Comments
 (0)