-
Notifications
You must be signed in to change notification settings - Fork 3
feat: Add methods to Psi to calculate D-optimality #239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,6 +11,7 @@ use pharmsol::ErrorModels; | |||||||||||||||||||||||||||||||
| use serde::{Deserialize, Serialize}; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| use super::theta::Theta; | ||||||||||||||||||||||||||||||||
| use super::weights::Weights; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| /// [Psi] is a structure that holds the likelihood for each subject (row), for each support point (column) | ||||||||||||||||||||||||||||||||
| #[derive(Debug, Clone, PartialEq)] | ||||||||||||||||||||||||||||||||
|
|
@@ -103,6 +104,71 @@ impl Psi { | |||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Ok(Psi { matrix: mat }) | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| /// Compute the maximum D-optimality value across all support points | ||||||||||||||||||||||||||||||||
| /// | ||||||||||||||||||||||||||||||||
| /// The D-optimality criterion measures convergence of the NPML/NPOD algorithm. | ||||||||||||||||||||||||||||||||
| /// At optimality, this value should be close to 0, meaning no support point | ||||||||||||||||||||||||||||||||
| /// can further improve the likelihood. | ||||||||||||||||||||||||||||||||
| /// | ||||||||||||||||||||||||||||||||
| /// # Interpretation | ||||||||||||||||||||||||||||||||
| /// - **≈ 0**: Solution is optimal | ||||||||||||||||||||||||||||||||
| /// - **> 0**: Not converged; some support points could still improve the objective | ||||||||||||||||||||||||||||||||
| /// - **Larger values**: Further from convergence | ||||||||||||||||||||||||||||||||
| pub fn d_optimality(&self, weights: &Weights) -> Result<f64> { | ||||||||||||||||||||||||||||||||
| let d_values = self.d_optimality_spp(weights)?; | ||||||||||||||||||||||||||||||||
| Ok(d_values.into_iter().fold(f64::NEG_INFINITY, f64::max)) | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| /// Compute D-optimality values for each support point | ||||||||||||||||||||||||||||||||
| /// | ||||||||||||||||||||||||||||||||
| /// Returns the D-value for each support point in the current solution. | ||||||||||||||||||||||||||||||||
| /// At convergence, all values should be close to 0. | ||||||||||||||||||||||||||||||||
| /// | ||||||||||||||||||||||||||||||||
| /// The D-optimality value for support point $j$ is: | ||||||||||||||||||||||||||||||||
| /// $$D(\theta_j) = \sum_{i=1}^{n} \frac{\psi_{ij}}{p_\lambda(y_i)} - n$$ | ||||||||||||||||||||||||||||||||
| pub(crate) fn d_optimality_spp(&self, weights: &Weights) -> Result<Vec<f64>> { | ||||||||||||||||||||||||||||||||
| let psi_mat = self.matrix(); | ||||||||||||||||||||||||||||||||
| let nsub = psi_mat.nrows(); | ||||||||||||||||||||||||||||||||
| let nspp = psi_mat.ncols(); | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if nspp != weights.len() { | ||||||||||||||||||||||||||||||||
| bail!( | ||||||||||||||||||||||||||||||||
| "Psi has {} columns but weights has {} elements", | ||||||||||||||||||||||||||||||||
| nspp, | ||||||||||||||||||||||||||||||||
| weights.len() | ||||||||||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| // Compute pyl = psi * w (weighted probability for each subject) | ||||||||||||||||||||||||||||||||
| let mut pyl = vec![0.0; nsub]; | ||||||||||||||||||||||||||||||||
| for i in 0..nsub { | ||||||||||||||||||||||||||||||||
| for (j, w_j) in weights.iter().enumerate() { | ||||||||||||||||||||||||||||||||
| pyl[i] += psi_mat.get(i, j) * w_j; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
Comment on lines
+144
to
+148
|
||||||||||||||||||||||||||||||||
| let mut pyl = vec![0.0; nsub]; | |
| for i in 0..nsub { | |
| for (j, w_j) in weights.iter().enumerate() { | |
| pyl[i] += psi_mat.get(i, j) * w_j; | |
| } | |
| // Build a column vector (nspp x 1) of weights in column-major order. | |
| let weights_vec: Vec<f64> = weights.iter().copied().collect(); | |
| // Treat the weights as a (nspp x 1) matrix so we can use faer matrix multiplication. | |
| let weights_mat = Mat::from_column_major_slice(nspp, 1, &weights_vec); | |
| // psi_mat has shape (nsub x nspp), so the product has shape (nsub x 1). | |
| let pyl_mat = &psi_mat * &weights_mat; | |
| let mut pyl = Vec::with_capacity(nsub); | |
| for i in 0..nsub { | |
| // Extract the single column of the result matrix. | |
| pyl.push(*pyl_mat.get(i, 0)); |
Copilot
AI
Dec 28, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Direct floating-point comparison with zero may be unreliable due to floating-point arithmetic precision issues. Consider using a small epsilon threshold instead, such as checking if the absolute value is less than a small tolerance (e.g., 1e-12 or f64::EPSILON).
| // Check for zero probabilities | |
| for (i, &p) in pyl.iter().enumerate() { | |
| if p == 0.0 { | |
| // Check for (effectively) zero probabilities using a small tolerance | |
| let eps = 1e-12_f64; | |
| for (i, &p) in pyl.iter().enumerate() { | |
| if p.abs() < eps { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mathematical notation using dollar signs (e.g.,
$j$,$$D(\theta_j) = ...$$) is LaTeX/MathJax syntax that does not render correctly in rustdoc. Rustdoc does not natively support LaTeX math expressions. Consider using plain text mathematical notation or code formatting instead, such as "D(theta_j) = sum(psi_ij / p_lambda(y_i)) - n" in a code block or plain description.