Skip to content

Commit 3115445

Browse files
committed
log-likelihood
1 parent d4dadb3 commit 3115445

File tree

15 files changed

+1003
-187
lines changed

15 files changed

+1003
-187
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ tracing-subscriber = { version = "0.3.19", features = [
2929
] }
3030
faer = "0.23.1"
3131
faer-ext = { version = "0.7.1", features = ["nalgebra", "ndarray"] }
32-
pharmsol = "=0.21.0"
32+
pharmsol = "=0.22.0"
3333
rand = "0.9.0"
3434
anyhow = "1.0.100"
3535
rayon = "1.10.0"

src/algorithms/mod.rs

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::fs;
22
use std::path::Path;
33

4+
use crate::routines::math::logsumexp_rows;
45
use crate::routines::output::NPResult;
56
use crate::routines::settings::Settings;
67
use crate::structs::psi::Psi;
@@ -36,35 +37,60 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
3637
// Count problematic values in psi
3738
let mut nan_count = 0;
3839
let mut inf_count = 0;
40+
let is_log_space = self.psi().is_log_space();
3941

4042
let psi = self.psi().matrix().as_ref().into_ndarray();
41-
// First coerce all NaN and infinite in psi to 0.0
43+
// Count NaN and problematic infinite values in psi; nothing is coerced here (in log-space, NEG_INFINITY is a valid entry)
4244
for i in 0..psi.nrows() {
4345
for j in 0..self.psi().matrix().ncols() {
4446
let val = psi.get((i, j)).unwrap();
4547
if val.is_nan() {
4648
nan_count += 1;
47-
// *val = 0.0;
4849
} else if val.is_infinite() {
49-
inf_count += 1;
50-
// *val = 0.0;
50+
// In log-space, NEG_INFINITY is valid (represents zero probability)
51+
// Only count positive infinity as problematic
52+
if !is_log_space || val.is_sign_positive() {
53+
inf_count += 1;
54+
}
5155
}
5256
}
5357
}
5458

5559
if nan_count + inf_count > 0 {
5660
tracing::warn!(
57-
"Psi matrix contains {} NaN, {} Infinite values of {} total values",
61+
"Psi matrix contains {} NaN, {} problematic Infinite values of {} total values",
5862
nan_count,
5963
inf_count,
6064
psi.ncols() * psi.nrows()
6165
);
6266
}
6367

64-
let (_, col) = psi.dim();
65-
let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
66-
let plam = psi.dot(&ecol);
67-
let w = 1. / &plam;
68+
// Calculate row sums: for regular space: sum; for log-space: logsumexp
69+
let plam: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
70+
// For log-space, use logsumexp for each row
71+
Array::from_vec(logsumexp_rows(psi.nrows(), psi.ncols(), |i, j| psi[(i, j)]))
72+
} else {
73+
// For regular space, sum each row
74+
let (_, col) = psi.dim();
75+
let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
76+
psi.dot(&ecol)
77+
};
78+
79+
// Check for subjects with zero probability
80+
// In log-space: -inf means zero probability
81+
// In regular space: 0 means zero probability
82+
let w: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
83+
// For log-space, check if logsumexp result is -inf
84+
Array::from_shape_fn(plam.len(), |i| {
85+
if plam[i].is_infinite() && plam[i].is_sign_negative() {
86+
f64::INFINITY // Will be flagged as problematic
87+
} else {
88+
1.0 // Valid
89+
}
90+
})
91+
} else {
92+
1. / &plam
93+
};
6894

6995
// Get the index of each element in `w` that is NaN or infinite
7096
let indices: Vec<usize> = w

src/algorithms/npag.rs

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use crate::algorithms::{Status, StopReason};
22
use crate::prelude::algorithms::Algorithms;
33

4-
pub use crate::routines::estimation::ipm::burke;
4+
pub use crate::routines::estimation::ipm::{burke, burke_ipm, burke_log};
55
pub use crate::routines::estimation::qr;
6+
use crate::routines::math::logsumexp;
67
use crate::routines::settings::Settings;
78

89
use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
9-
use crate::structs::psi::{calculate_psi, Psi};
10+
use crate::structs::psi::{calculate_psi_dispatch, Psi};
1011
use crate::structs::theta::Theta;
1112
use crate::structs::weights::Weights;
1213

@@ -160,8 +161,24 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
160161
if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
161162
self.eps /= 2.;
162163
if self.eps <= THETA_E {
163-
let pyl = psi * w.weights();
164-
self.f1 = pyl.iter().map(|x| x.ln()).sum();
164+
// Compute f1 = sum(log(pyl)) where pyl = psi * w
165+
self.f1 = if self.psi.is_log_space() {
166+
// For log-space: f1 = sum_i(logsumexp(log_psi[i,:] + log(w)))
167+
let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
168+
(0..psi.nrows())
169+
.map(|i| {
170+
let combined: Vec<f64> = (0..psi.ncols())
171+
.map(|j| *psi.get(i, j) + log_w[j])
172+
.collect();
173+
logsumexp(&combined)
174+
})
175+
.sum()
176+
} else {
177+
// For regular space: f1 = sum(log(psi * w))
178+
let pyl = psi * w.weights();
179+
pyl.iter().map(|x| x.ln()).sum()
180+
};
181+
165182
if (self.f1 - self.f0).abs() <= THETA_F {
166183
tracing::info!("The model converged after {} cycles", self.cycle,);
167184
self.set_status(Status::Stop(StopReason::Converged));
@@ -197,31 +214,29 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
197214
}
198215

199216
fn estimation(&mut self) -> Result<()> {
200-
self.psi = calculate_psi(
217+
let use_log_space = self.settings.advanced().log_space;
218+
219+
self.psi = calculate_psi_dispatch(
201220
&self.equation,
202221
&self.data,
203222
&self.theta,
204223
&self.error_models,
205224
self.cycle == 1 && self.settings.config().progress,
206225
self.cycle != 1,
226+
use_log_space,
207227
)?;
208228

209229
if let Err(err) = self.validate_psi() {
210230
bail!(err);
211231
}
212232

213-
(self.lambda, _) = match burke(&self.psi) {
214-
Ok((lambda, objf)) => (lambda, objf),
215-
Err(err) => {
216-
bail!("Error in IPM during estimation: {:?}", err);
217-
}
218-
};
233+
(self.lambda, _) = burke_ipm(&self.psi)
234+
.map_err(|err| anyhow::anyhow!("Error in IPM during estimation: {:?}", err))?;
219235
Ok(())
220236
}
221237

222238
fn condensation(&mut self) -> Result<()> {
223239
// Filter out the support points with lambda < max(lambda)/1000
224-
225240
let max_lambda = self
226241
.lambda
227242
.iter()
@@ -273,20 +288,16 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
273288
self.psi.filter_column_indices(keep.as_slice());
274289

275290
self.validate_psi()?;
276-
(self.lambda, self.objf) = match burke(&self.psi) {
277-
Ok((lambda, objf)) => (lambda, objf),
278-
Err(err) => {
279-
return Err(anyhow::anyhow!(
280-
"Error in IPM during condensation: {:?}",
281-
err
282-
));
283-
}
284-
};
291+
292+
(self.lambda, self.objf) = burke_ipm(&self.psi)
293+
.map_err(|err| anyhow::anyhow!("Error in IPM during condensation: {:?}", err))?;
285294
self.w = self.lambda.clone();
286295
Ok(())
287296
}
288297

289298
fn optimizations(&mut self) -> Result<()> {
299+
let use_log_space = self.settings.advanced().log_space;
300+
290301
self.error_models
291302
.clone()
292303
.iter_mut()
@@ -298,8 +309,6 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
298309
}
299310
})
300311
.try_for_each(|(outeq, em)| -> Result<()> {
301-
// OPTIMIZATION
302-
303312
let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
304313
let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
305314

@@ -309,35 +318,32 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
309318
let mut error_model_down = self.error_models.clone();
310319
error_model_down.set_factor(outeq, gamma_down)?;
311320

312-
let psi_up = calculate_psi(
321+
let psi_up = calculate_psi_dispatch(
313322
&self.equation,
314323
&self.data,
315324
&self.theta,
316325
&error_model_up,
317326
false,
318327
true,
328+
use_log_space,
319329
)?;
320-
let psi_down = calculate_psi(
330+
331+
let psi_down = calculate_psi_dispatch(
321332
&self.equation,
322333
&self.data,
323334
&self.theta,
324335
&error_model_down,
325336
false,
326337
true,
338+
use_log_space,
327339
)?;
328340

329-
let (lambda_up, objf_up) = match burke(&psi_up) {
330-
Ok((lambda, objf)) => (lambda, objf),
331-
Err(err) => {
332-
bail!("Error in IPM during optim: {:?}", err);
333-
}
334-
};
335-
let (lambda_down, objf_down) = match burke(&psi_down) {
336-
Ok((lambda, objf)) => (lambda, objf),
337-
Err(err) => {
338-
bail!("Error in IPM during optim: {:?}", err);
339-
}
340-
};
341+
let (lambda_up, objf_up) = burke_ipm(&psi_up)
342+
.map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;
343+
344+
let (lambda_down, objf_down) = burke_ipm(&psi_down)
345+
.map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;
346+
341347
if objf_up > self.objf {
342348
self.error_models.set_factor(outeq, gamma_up)?;
343349
self.objf = objf_up;

0 commit comments

Comments
 (0)