diff --git a/src/lib.rs b/src/lib.rs index 4159a63e..9bdb1791 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,9 @@ pub mod routines; // Structures pub mod structs; +// MMopt +pub mod mmopt; + // Re-export commonly used items pub use anyhow::Result; pub use std::collections::HashMap; @@ -42,6 +45,9 @@ pub mod prelude { pub use crate::routines::settings::*; pub use crate::structs::*; + + pub use crate::mmopt::*; + pub mod simulator { pub use pharmsol::prelude::simulator::*; } diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs new file mode 100644 index 00000000..2b070479 --- /dev/null +++ b/src/mmopt/mod.rs @@ -0,0 +1,204 @@ +use anyhow::Result; +use faer::Mat; +use pharmsol::{ + prelude::simulator::SubjectPredictions, Data, Equation, ErrorModel, Predictions, Subject, +}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use serde_json::error; +use std::fmt::Error; + +use crate::structs::theta::Theta; + +pub struct PredictionsContainer { + pub matrix: Mat, + pub times: Vec, + pub probs: Vec, +} + +impl PredictionsContainer { + fn matrix(&self) -> &Mat { + &self.matrix + } + + fn nsub(&self) -> usize { + self.matrix.ncols() + } + fn nout(&self) -> usize { + self.matrix.nrows() + } +} + +struct CostMatrix { + matrix: Option>, + auc: f64, + cmax: f64, + cmin: f64, +} + +impl CostMatrix { + pub fn new(auc: f64, cmax: f64, cmin: f64) -> Self { + !unimplemented!() + } +} + +/// The results of a multiple-model optimization +/// +/// +#[derive(Debug)] +pub struct MmoptResult { + // Optimal sample times + pub times: Vec, + // Bayes risk + pub risk: f64, +} + +pub fn mmopt( + theta: &Theta, + subject: &Subject, + equation: impl Equation, + errormodel: ErrorModel, + nsamp: usize, +) -> Result { + // Check that subject contains only one Occasion + if subject.occasions().len() != 1 { + return Err(anyhow::anyhow!("Subject must contain only one Occasion")); + } + + // Generate predictions + let predictions = theta + .matrix() + .row_iter() + .map(|theta_row| { + let support_point: Vec = theta_row.iter().cloned().collect(); + let predictions = equation + .estimate_predictions(&subject, &support_point) + .get_predictions(); + predictions + }) + .collect::>(); + + // Times vector + let times = predictions[0].iter().map(|p| p.time()).collect::>(); + + // Generate prediction matrix + let pred_matrix = Mat::from_fn(predictions[0].len(), theta.nspp(), |i, j| { + predictions[j][i].prediction().to_owned() + }); + + // Generate sample candidate indices + let candidate_indices = generate_combinations(times.len(), nsamp); + + let (best_combo, min_risk) = candidate_indices + .par_iter() + .map(|combo| { + let mut risk = 0.0; + // Compare the i-th and the j-th subject predictions + for i in 0..theta.nspp() { + for j in 0..theta.nspp() { + if i != j { + let i_obs: Vec = pred_matrix + .col(i) + .iter() + .enumerate() + .filter_map(|(k, &x)| if combo.contains(&k) { Some(x) } else { None }) + .collect(); + + let j_obs: Vec = pred_matrix + .col(j) + .iter() + .enumerate() + .filter_map(|(k, &x)| if combo.contains(&k) { Some(x) } else { None }) + .collect(); + + let i_var: Vec = + i_obs.iter().map(|&x| errormodel.variance(x)).collect(); + let j_var: Vec = + j_obs.iter().map(|&x| errorpoly.variance(x)).collect(); + + let sum_k_ijn: f64 = i_obs + .iter() + .zip(j_obs.iter()) + .zip(i_var.iter()) + .zip(j_var.iter()) + .map(|(((y_i, y_j), i_var), j_var)| { + let denominator = i_var + j_var; + let term1 = (y_i - y_j).powi(2) / (4.0 * denominator); + let term2 = 0.5 * ((i_var + j_var) / 2.0).ln(); + let term3 = -0.25 * (i_var * j_var).ln(); + term1 + term2 + term3 + }) + .collect::>() + .iter() + .sum::(); + + let prob_i = predictions.probs[i]; + let prob_j = predictions.probs[j]; + let cost = cost_matrix.matrix[(i, j)]; + let risk_component = prob_i * prob_j * (-sum_k_ijn).exp() * cost; + risk += risk_component; + } + } + } + + (combo.clone(), risk) + }) + .min_by(|(_, risk_a), (_, risk_b)| risk_a.partial_cmp(risk_b).unwrap()) + .unwrap(); + + let times = best_combo.iter().map(|&i| times[i]).collect::>(); + let res = MmoptResult { + times: times, + risk: min_risk, + }; + + Ok(res) +} + +fn generate_combinations(m: usize, n: usize) -> Vec> { + fn backtrack( + m: usize, + n: usize, + start: usize, + current: &mut Vec, + results: &mut Vec>, + ) { + if current.len() == n { + results.push(current.clone()); + return; + } + + for i in start..m { + current.push(i); + backtrack(m, n, i + 1, current, results); + current.pop(); + } + } + + let mut results = Vec::new(); + let mut current = Vec::new(); + backtrack(m, n, 0, &mut current, &mut results); + results +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_combinations() { + let m = 5; + let n = 3; + let combinations = generate_combinations(m, n); + assert_eq!(combinations.len(), 10); + assert_eq!(combinations[0], vec![0, 1, 2]); + assert_eq!(combinations[1], vec![0, 1, 3]); + assert_eq!(combinations[2], vec![0, 1, 4]); + assert_eq!(combinations[3], vec![0, 2, 3]); + assert_eq!(combinations[4], vec![0, 2, 4]); + assert_eq!(combinations[5], vec![0, 3, 4]); + assert_eq!(combinations[6], vec![1, 2, 3]); + assert_eq!(combinations[7], vec![1, 2, 4]); + assert_eq!(combinations[8], vec![1, 3, 4]); + assert_eq!(combinations[9], vec![2, 3, 4]); + } +}