Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ tracing-subscriber = { version = "0.3.19", features = [
] }
faer = "0.23.1"
faer-ext = { version = "0.7.1", features = ["nalgebra", "ndarray"] }
pharmsol = "=0.21.0"
pharmsol = "=0.22.0"
rand = "0.9.0"
anyhow = "1.0.100"
rayon = "1.10.0"
Expand Down
44 changes: 35 additions & 9 deletions src/algorithms/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fs;
use std::path::Path;

use crate::routines::math::logsumexp_rows;
use crate::routines::output::NPResult;
use crate::routines::settings::Settings;
use crate::structs::psi::Psi;
Expand Down Expand Up @@ -36,35 +37,60 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
// Count problematic values in psi
let mut nan_count = 0;
let mut inf_count = 0;
let is_log_space = self.psi().is_log_space();

let psi = self.psi().matrix().as_ref().into_ndarray();
// First coerce all NaN and infinite in psi to 0.0
// First coerce all NaN and infinite in psi to 0.0 (or NEG_INFINITY for log-space)
for i in 0..psi.nrows() {
for j in 0..self.psi().matrix().ncols() {
let val = psi.get((i, j)).unwrap();
if val.is_nan() {
nan_count += 1;
// *val = 0.0;
} else if val.is_infinite() {
inf_count += 1;
// *val = 0.0;
// In log-space, NEG_INFINITY is valid (represents zero probability)
// Only count positive infinity as problematic
if !is_log_space || val.is_sign_positive() {
inf_count += 1;
}
}
}
}

if nan_count + inf_count > 0 {
tracing::warn!(
"Psi matrix contains {} NaN, {} Infinite values of {} total values",
"Psi matrix contains {} NaN, {} problematic Infinite values of {} total values",
nan_count,
inf_count,
psi.ncols() * psi.nrows()
);
}

let (_, col) = psi.dim();
let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
let plam = psi.dot(&ecol);
let w = 1. / &plam;
// Calculate row sums: for regular space: sum; for log-space: logsumexp
let plam: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
// For log-space, use logsumexp for each row
Array::from_vec(logsumexp_rows(psi.nrows(), psi.ncols(), |i, j| psi[(i, j)]))
} else {
// For regular space, sum each row
let (_, col) = psi.dim();
let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
psi.dot(&ecol)
};

// Check for subjects with zero probability
// In log-space: -inf means zero probability
// In regular space: 0 means zero probability
let w: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
// For log-space, check if logsumexp result is -inf
Array::from_shape_fn(plam.len(), |i| {
if plam[i].is_infinite() && plam[i].is_sign_negative() {
f64::INFINITY // Will be flagged as problematic
} else {
1.0 // Valid
}
})
} else {
1. / &plam
};

// Get the index of each element in `w` that is NaN or infinite
let indices: Vec<usize> = w
Expand Down
80 changes: 43 additions & 37 deletions src/algorithms/npag.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::algorithms::{Status, StopReason};
use crate::prelude::algorithms::Algorithms;

pub use crate::routines::estimation::ipm::burke;
pub use crate::routines::estimation::ipm::{burke, burke_ipm, burke_log};
pub use crate::routines::estimation::qr;
use crate::routines::math::logsumexp;
use crate::routines::settings::Settings;

use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
use crate::structs::psi::{calculate_psi, Psi};
use crate::structs::psi::{calculate_psi_dispatch, Psi};
use crate::structs::theta::Theta;
use crate::structs::weights::Weights;

Expand Down Expand Up @@ -160,8 +161,24 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
self.eps /= 2.;
if self.eps <= THETA_E {
let pyl = psi * w.weights();
self.f1 = pyl.iter().map(|x| x.ln()).sum();
// Compute f1 = sum(log(pyl)) where pyl = psi * w
self.f1 = if self.psi.is_log_space() {
// For log-space: f1 = sum_i(logsumexp(log_psi[i,:] + log(w)))
let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
(0..psi.nrows())
.map(|i| {
Comment on lines +167 to +169
Copy link

Copilot AI Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If any weight in w is zero or negative, .ln() will produce NEG_INFINITY or NaN, potentially causing numerical issues. While weights from the IPM should be positive, consider adding validation or documenting this assumption.

Suggested change
let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
(0..psi.nrows())
.map(|i| {
if w.weights().iter().any(|&x| x <= 0.0) {
bail!("All weights must be positive before taking logarithm, found zero or negative weight.");
}
let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
(0..psi.nrows())

Copilot uses AI. Check for mistakes.
let combined: Vec<f64> = (0..psi.ncols())
.map(|j| *psi.get(i, j) + log_w[j])
.collect();
logsumexp(&combined)
})
.sum()
} else {
// For regular space: f1 = sum(log(psi * w))
let pyl = psi * w.weights();
pyl.iter().map(|x| x.ln()).sum()
};

if (self.f1 - self.f0).abs() <= THETA_F {
tracing::info!("The model converged after {} cycles", self.cycle,);
self.set_status(Status::Stop(StopReason::Converged));
Expand Down Expand Up @@ -197,31 +214,29 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
}

fn estimation(&mut self) -> Result<()> {
self.psi = calculate_psi(
let use_log_space = self.settings.advanced().log_space;

self.psi = calculate_psi_dispatch(
&self.equation,
&self.data,
&self.theta,
&self.error_models,
self.cycle == 1 && self.settings.config().progress,
self.cycle != 1,
use_log_space,
)?;

if let Err(err) = self.validate_psi() {
bail!(err);
}

(self.lambda, _) = match burke(&self.psi) {
Ok((lambda, objf)) => (lambda, objf),
Err(err) => {
bail!("Error in IPM during estimation: {:?}", err);
}
};
(self.lambda, _) = burke_ipm(&self.psi)
.map_err(|err| anyhow::anyhow!("Error in IPM during estimation: {:?}", err))?;
Ok(())
}

fn condensation(&mut self) -> Result<()> {
// Filter out the support points with lambda < max(lambda)/1000

let max_lambda = self
.lambda
.iter()
Expand Down Expand Up @@ -273,20 +288,16 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
self.psi.filter_column_indices(keep.as_slice());

self.validate_psi()?;
(self.lambda, self.objf) = match burke(&self.psi) {
Ok((lambda, objf)) => (lambda, objf),
Err(err) => {
return Err(anyhow::anyhow!(
"Error in IPM during condensation: {:?}",
err
));
}
};

(self.lambda, self.objf) = burke_ipm(&self.psi)
.map_err(|err| anyhow::anyhow!("Error in IPM during condensation: {:?}", err))?;
self.w = self.lambda.clone();
Ok(())
}

fn optimizations(&mut self) -> Result<()> {
let use_log_space = self.settings.advanced().log_space;

self.error_models
.clone()
.iter_mut()
Expand All @@ -298,8 +309,6 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
}
})
.try_for_each(|(outeq, em)| -> Result<()> {
// OPTIMIZATION

let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);

Expand All @@ -309,35 +318,32 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
let mut error_model_down = self.error_models.clone();
error_model_down.set_factor(outeq, gamma_down)?;

let psi_up = calculate_psi(
let psi_up = calculate_psi_dispatch(
&self.equation,
&self.data,
&self.theta,
&error_model_up,
false,
true,
use_log_space,
)?;
let psi_down = calculate_psi(

let psi_down = calculate_psi_dispatch(
&self.equation,
&self.data,
&self.theta,
&error_model_down,
false,
true,
use_log_space,
)?;

let (lambda_up, objf_up) = match burke(&psi_up) {
Ok((lambda, objf)) => (lambda, objf),
Err(err) => {
bail!("Error in IPM during optim: {:?}", err);
}
};
let (lambda_down, objf_down) = match burke(&psi_down) {
Ok((lambda, objf)) => (lambda, objf),
Err(err) => {
bail!("Error in IPM during optim: {:?}", err);
}
};
let (lambda_up, objf_up) = burke_ipm(&psi_up)
.map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;

let (lambda_down, objf_down) = burke_ipm(&psi_down)
.map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;

if objf_up > self.objf {
self.error_models.set_factor(outeq, gamma_up)?;
self.objf = objf_up;
Expand Down
Loading
Loading