Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ tracing-subscriber = { version = "0.3.19", features = [
] }
faer = "0.23.1"
faer-ext = { version = "0.7.1", features = ["nalgebra", "ndarray"] }
pharmsol = "=0.21.0"
pharmsol = "=0.22.0"
rand = "0.9.0"
anyhow = "1.0.100"
rayon = "1.10.0"
Expand Down
47 changes: 38 additions & 9 deletions src/algorithms/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fs;
use std::path::Path;

use crate::routines::math::logsumexp_rows;
use crate::routines::output::NPResult;
use crate::routines::settings::Settings;
use crate::structs::psi::Psi;
Expand Down Expand Up @@ -36,35 +37,63 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
// Count problematic values in psi
let mut nan_count = 0;
let mut inf_count = 0;
let is_log_space = match self.psi().space() {
crate::structs::psi::Space::Linear => false,
crate::structs::psi::Space::Log => true,
};

let psi = self.psi().matrix().as_ref().into_ndarray();
// First coerce all NaN and infinite in psi to 0.0
// First coerce all NaN and infinite in psi to 0.0 (or NEG_INFINITY for log-space)
for i in 0..psi.nrows() {
for j in 0..self.psi().matrix().ncols() {
let val = psi.get((i, j)).unwrap();
if val.is_nan() {
nan_count += 1;
// *val = 0.0;
} else if val.is_infinite() {
inf_count += 1;
// *val = 0.0;
// In log-space, NEG_INFINITY is valid (represents zero probability)
// Only count positive infinity as problematic
if !is_log_space || val.is_sign_positive() {
inf_count += 1;
}
}
}
}

if nan_count + inf_count > 0 {
tracing::warn!(
"Psi matrix contains {} NaN, {} Infinite values of {} total values",
"Psi matrix contains {} NaN, {} problematic Infinite values of {} total values",
nan_count,
inf_count,
psi.ncols() * psi.nrows()
);
}

let (_, col) = psi.dim();
let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
let plam = psi.dot(&ecol);
let w = 1. / &plam;
// Calculate row sums: for regular space: sum; for log-space: logsumexp
let plam: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
// For log-space, use logsumexp for each row
Array::from_vec(logsumexp_rows(psi.nrows(), psi.ncols(), |i, j| psi[(i, j)]))
} else {
// For regular space, sum each row
let (_, col) = psi.dim();
let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
psi.dot(&ecol)
};

// Check for subjects with zero probability
// In log-space: -inf means zero probability
// In regular space: 0 means zero probability
let w: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
// For log-space, check if logsumexp result is -inf
Array::from_shape_fn(plam.len(), |i| {
if plam[i].is_infinite() && plam[i].is_sign_negative() {
f64::INFINITY // Will be flagged as problematic
} else {
1.0 // Valid
}
})
} else {
1. / &plam
};

// Get the index of each element in `w` that is NaN or infinite
let indices: Vec<usize> = w
Expand Down
68 changes: 35 additions & 33 deletions src/algorithms/npag.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::algorithms::{Status, StopReason};
use crate::prelude::algorithms::Algorithms;

pub use crate::routines::estimation::ipm::burke;
pub use crate::routines::estimation::ipm::{burke, burke_ipm, burke_log};
pub use crate::routines::estimation::qr;
use crate::routines::math::logsumexp;
use crate::routines::settings::Settings;

use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
Expand Down Expand Up @@ -160,8 +161,24 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
self.eps /= 2.;
if self.eps <= THETA_E {
let pyl = psi * w.weights();
self.f1 = pyl.iter().map(|x| x.ln()).sum();
// Compute f1 = sum(log(pyl)) where pyl = psi * w
self.f1 = if self.psi.space() == crate::structs::psi::Space::Log {
// For log-space: f1 = sum_i(logsumexp(log_psi[i,:] + log(w)))
let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
(0..psi.nrows())
.map(|i| {
Comment on lines +167 to +169
Copy link

Copilot AI Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If any weight in `w` is zero or negative, `.ln()` will produce `NEG_INFINITY` (for a zero weight) or `NaN` (for a negative weight), potentially causing numerical issues. While weights from the IPM should be positive, consider adding validation or documenting this assumption.

Suggested change
let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
(0..psi.nrows())
.map(|i| {
if w.weights().iter().any(|&x| x <= 0.0) {
bail!("All weights must be positive before taking logarithm, found zero or negative weight.");
}
let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
(0..psi.nrows())

Copilot uses AI. Check for mistakes.
let combined: Vec<f64> = (0..psi.ncols())
.map(|j| *psi.get(i, j) + log_w[j])
.collect();
logsumexp(&combined)
})
.sum()
} else {
// For regular space: f1 = sum(log(psi * w))
let pyl = psi * w.weights();
pyl.iter().map(|x| x.ln()).sum()
};

if (self.f1 - self.f0).abs() <= THETA_F {
tracing::info!("The model converged after {} cycles", self.cycle,);
self.set_status(Status::Stop(StopReason::Converged));
Expand Down Expand Up @@ -204,24 +221,20 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
&self.error_models,
self.cycle == 1 && self.settings.config().progress,
self.cycle != 1,
self.settings.advanced().space,
)?;

if let Err(err) = self.validate_psi() {
bail!(err);
}

(self.lambda, _) = match burke(&self.psi) {
Ok((lambda, objf)) => (lambda, objf),
Err(err) => {
bail!("Error in IPM during estimation: {:?}", err);
}
};
(self.lambda, _) = burke_ipm(&self.psi)
.map_err(|err| anyhow::anyhow!("Error in IPM during estimation: {:?}", err))?;
Ok(())
}

fn condensation(&mut self) -> Result<()> {
// Filter out the support points with lambda < max(lambda)/1000

let max_lambda = self
.lambda
.iter()
Expand Down Expand Up @@ -273,15 +286,9 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
self.psi.filter_column_indices(keep.as_slice());

self.validate_psi()?;
(self.lambda, self.objf) = match burke(&self.psi) {
Ok((lambda, objf)) => (lambda, objf),
Err(err) => {
return Err(anyhow::anyhow!(
"Error in IPM during condensation: {:?}",
err
));
}
};

(self.lambda, self.objf) = burke_ipm(&self.psi)
.map_err(|err| anyhow::anyhow!("Error in IPM during condensation: {:?}", err))?;
self.w = self.lambda.clone();
Ok(())
}
Expand All @@ -298,8 +305,6 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
}
})
.try_for_each(|(outeq, em)| -> Result<()> {
// OPTIMIZATION

let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);

Expand All @@ -316,28 +321,25 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
&error_model_up,
false,
true,
self.settings.advanced().space,
)?;

let psi_down = calculate_psi(
&self.equation,
&self.data,
&self.theta,
&error_model_down,
false,
true,
self.settings.advanced().space,
)?;

let (lambda_up, objf_up) = match burke(&psi_up) {
Ok((lambda, objf)) => (lambda, objf),
Err(err) => {
bail!("Error in IPM during optim: {:?}", err);
}
};
let (lambda_down, objf_down) = match burke(&psi_down) {
Ok((lambda, objf)) => (lambda, objf),
Err(err) => {
bail!("Error in IPM during optim: {:?}", err);
}
};
let (lambda_up, objf_up) = burke_ipm(&psi_up)
.map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;

let (lambda_down, objf_down) = burke_ipm(&psi_down)
.map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;

if objf_up > self.objf {
self.error_models.set_factor(outeq, gamma_up)?;
self.objf = objf_up;
Expand Down
Loading