diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs
index c466a0daa..2e54720d9 100644
--- a/src/algorithms/npag.rs
+++ b/src/algorithms/npag.rs
@@ -34,6 +34,7 @@ pub struct NPAG {
     ranges: Vec<(f64, f64)>,
     psi: Psi,
     theta: Theta,
+    theta_old: Option<Theta>, // Store previous theta for CHECKBIG calculation
     lambda: Weights,
     w: Weights,
     eps: f64,
@@ -57,6 +58,7 @@ impl Algorithms for NPAG {
             ranges: settings.parameters().ranges(),
             psi: Psi::new(),
             theta: Theta::new(),
+            theta_old: None, // Initialize as None (no previous theta yet)
             lambda: Weights::default(),
             w: Weights::default(),
             eps: 0.2,
@@ -162,11 +164,46 @@ impl Algorithms for NPAG {
             if self.eps <= THETA_E {
                 let pyl = psi * w.weights();
                 self.f1 = pyl.iter().map(|x| x.ln()).sum();
-                if (self.f1 - self.f0).abs() <= THETA_F {
-                    tracing::info!("The model converged after {} cycles", self.cycle,);
+
+                // Calculate CHECKBIG if we have a previous theta
+                let checkbig = if let Some(old_theta) = self.theta_old.as_ref() {
+                    Some(self.theta.max_relative_difference(old_theta)?)
+                } else {
+                    None
+                };
+
+                let f1_f0_diff = (self.f1 - self.f0).abs();
+
+                // Log convergence metrics for diagnostics
+                match checkbig {
+                    Some(cb) => tracing::debug!(
+                        "f1-f0={:.6e} (threshold={:.6e}), CHECKBIG={:.6e} (threshold={:.6e})",
+                        f1_f0_diff,
+                        THETA_F,
+                        cb,
+                        THETA_E
+                    ),
+                    None => tracing::debug!(
+                        "f1-f0={:.6e} (threshold={:.6e}), CHECKBIG=N/A (no previous theta)",
+                        f1_f0_diff,
+                        THETA_F
+                    ),
+                }
+
+                // Standard likelihood convergence check
+                if f1_f0_diff <= THETA_F {
+                    tracing::info!("The model converged according to the LIKELIHOOD criteria",);
                     self.set_status(Status::Stop(StopReason::Converged));
                     self.log_cycle_state();
                     return Ok(self.status().clone());
+                } else if matches!(checkbig, Some(cb) if cb <= THETA_E) {
+                    // Additional CHECKBIG convergence check. Folding the threshold into
+                    // the branch condition lets a failed CHECKBIG fall through to the
+                    // reset below instead of leaving `f0` and `eps` stale.
+                    tracing::info!("The model converged according to the CHECKBIG criteria",);
+                    self.set_status(Status::Stop(StopReason::Converged));
+                    self.log_cycle_state();
+                    return Ok(self.status().clone());
                 } else {
                     self.f0 = self.f1;
                     self.eps = 0.2;
@@ -174,6 +211,9 @@ impl Algorithms for NPAG {
             }
         }
 
+        // Save current theta for next cycle's CHECKBIG calculation
+        self.theta_old = Some(self.theta.clone());
+
         // Stop if we have reached maximum number of cycles
         if self.cycle >= self.settings.config().cycles {
             tracing::warn!("Maximum number of cycles reached");
diff --git a/src/structs/theta.rs b/src/structs/theta.rs
index b03a873b6..88d59e674 100644
--- a/src/structs/theta.rs
+++ b/src/structs/theta.rs
@@ -1,7 +1,7 @@
 use std::fmt::Debug;
 
 use anyhow::{bail, Result};
-use faer::Mat;
+use faer::{ColRef, Mat};
 use serde::{Deserialize, Serialize};
 
 use crate::prelude::Parameters;
@@ -201,6 +201,52 @@ impl Theta {
 
         Theta::from_parts(mat, parameters)
     }
+
+    /// Compute the maximum relative difference in medians across parameters between two Thetas
+    ///
+    /// For each parameter (column) the median of the support points is taken in both
+    /// Thetas, and the absolute difference is scaled by the larger absolute median
+    /// (floored at `1e-8`). The largest such ratio is returned.
+    ///
+    /// This is useful for assessing convergence between iterations
+    /// # Errors
+    /// Returns an error if the number of parameters (columns) do not match between
+    /// the two Thetas, or if a parameter has no support points to take a median of
+    pub fn max_relative_difference(&self, other: &Theta) -> Result<f64> {
+        if self.matrix.ncols() != other.matrix.ncols() {
+            bail!("Number of parameters (columns) do not match between Thetas");
+        }
+
+        // Median of a single column, or an error when the column is empty
+        fn median_col(col: ColRef<f64>) -> Result<f64> {
+            let mut vals: Vec<f64> = col.iter().copied().collect();
+            if vals.is_empty() {
+                // `vals[mid - 1]` below would underflow and panic on an empty column
+                bail!("Cannot compute the median of a parameter without support points");
+            }
+            // `total_cmp` is a total order, so NaN values cannot cause a panic here
+            vals.sort_by(f64::total_cmp);
+            let mid = vals.len() / 2;
+            if vals.len() % 2 == 0 {
+                Ok((vals[mid - 1] + vals[mid]) / 2.0)
+            } else {
+                Ok(vals[mid])
+            }
+        }
+
+        let mut max_rel_diff = 0.0;
+        for i in 0..self.matrix.ncols() {
+            let current_median = median_col(self.matrix.col(i))?;
+            let other_median = median_col(other.matrix.col(i))?;
+
+            let denom = current_median.abs().max(other_median.abs()).max(1e-8); // Avoid division by zero
+            let rel_diff = ((current_median - other_median).abs()) / denom;
+            if rel_diff > max_rel_diff {
+                max_rel_diff = rel_diff;
+            }
+        }
+        Ok(max_rel_diff)
+    }
 }
 
 impl Debug for Theta {
@@ -379,4 +425,55 @@ mod tests {
 
         assert_eq!(theta.matrix(), &new_matrix);
     }
+
+    #[test]
+    fn test_max_relative_difference() {
+        let matrix1 = mat![[2.0, 4.0], [6.0, 8.0]];
+        let matrix2 = mat![[2.0, 4.0], [8.0, 8.0]];
+        let parameters = Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0);
+        let theta1 = Theta::from_parts(matrix1, parameters.clone()).unwrap();
+        let theta2 = Theta::from_parts(matrix2, parameters).unwrap();
+        let max_rel_diff = theta1.max_relative_difference(&theta2).unwrap();
+        println!("Max relative difference: {}", max_rel_diff);
+        assert!((max_rel_diff - 0.2).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_max_relative_difference_same_theta() {
+        let matrix1 = mat![[1.0, 2.0], [3.0, 4.0]];
+        let parameters = Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0);
+        let theta1 = Theta::from_parts(matrix1, parameters.clone()).unwrap();
+        let theta2 = theta1.clone();
+        let max_rel_diff = theta1.max_relative_difference(&theta2).unwrap();
+        println!("Max relative difference: {}", max_rel_diff);
+        assert!((max_rel_diff - 0.0).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_max_relative_difference_shape_error() {
+        let matrix1 = mat![[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]];
+        let matrix2 = mat![[2.0, 4.0], [8.0, 8.0]];
+        let parameters1 = Parameters::new()
+            .add("A", 0.0, 10.0)
+            .add("B", 0.0, 10.0)
+            .add("C", 0.0, 10.0);
+        let parameters2 = Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0);
+        let theta1 = Theta::from_parts(matrix1, parameters1).unwrap();
+        let theta2 = Theta::from_parts(matrix2, parameters2).unwrap();
+        let result = theta1.max_relative_difference(&theta2);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_max_relative_difference_odd_length() {
+        let matrix1 = mat![[1.0, 2.0], [3.0, 6.0], [5.0, 10.0]];
+        let matrix2 = mat![[1.0, 2.0], [4.0, 6.0], [5.0, 10.0]];
+        let parameters = Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0);
+        let theta1 = Theta::from_parts(matrix1, parameters.clone()).unwrap();
+        let theta2 = Theta::from_parts(matrix2, parameters).unwrap();
+        let max_rel_diff = theta1.max_relative_difference(&theta2).unwrap();
+        println!("Max relative difference (odd length): {}", max_rel_diff);
+
+        assert!((max_rel_diff - 0.25).abs() < 1e-6);
+    }
 }
diff --git a/src/structs/weights.rs b/src/structs/weights.rs
index e84974443..1f515a94c 100644
--- a/src/structs/weights.rs
+++ b/src/structs/weights.rs
@@ -58,14 +58,25 @@ impl Weights {
         self.weights.nrows()
     }
 
+    /// Check if there are no weights.
+    pub fn is_empty(&self) -> bool {
+        self.weights.nrows() == 0
+    }
+
     /// Get a vector representation of the weights.
     pub fn to_vec(&self) -> Vec<f64> {
         self.weights.iter().cloned().collect()
     }
 
+    /// Get an iterator over the weights.
     pub fn iter(&self) -> impl Iterator<Item = f64> + '_ {
         self.weights.iter().cloned()
     }
+
+    /// Get a mutable iterator over the weights.
+    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut f64> + '_ {
+        self.weights.iter_mut()
+    }
 }
 
 impl Serialize for Weights {