diff --git a/benches/performance.rs b/benches/performance.rs index 62ac8cf..cb0f73a 100644 --- a/benches/performance.rs +++ b/benches/performance.rs @@ -38,8 +38,9 @@ fn readme(n: usize) { }, (2, 1), ); + let spp = support_point!["ka" => 0.3, "ke" => 0.5, "tlag" => 0.1, "v" => 70.0]; for _ in 0..n { - let op = ode.estimate_predictions(&subject, &vec![0.3, 0.5, 0.1, 70.0]); + let op = ode.estimate_predictions(&subject, &spp); assert_eq!( op.flat_predictions(), vec![ diff --git a/examples/bke.rs b/examples/bke.rs index 5741876..ff3a70c 100644 --- a/examples/bke.rs +++ b/examples/bke.rs @@ -19,7 +19,7 @@ fn main() { |_p| fa! {}, |_p, _t, _cov, _x| {}, |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); + let v = p.get("v").unwrap(); y[0] = x[0] / v; }, (1, 1), @@ -28,39 +28,31 @@ fn main() { let ode = equation::ODE::new( |x, p, _t, dx, rateiv, _cov| { // fetch_cov!(cov, t, wt); - fetch_params!(p, ke, _v); + let ke = p.get("ke").unwrap(); dx[0] = -ke * x[0] + rateiv[0]; }, |_p| lag! {}, |_p| fa! {}, |_p, _t, _cov, _x| {}, |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); + let v = p.get("v").unwrap(); y[0] = x[0] / v; }, (1, 1), ); - let em = ErrorModel::new((0.0, 0.05, 0.0, 0.0), 0.0, &ErrorType::Add); - let ll = an.estimate_likelihood( - &subject, - &vec![1.02282724609375, 194.51904296875], - &em, - false, - ); - let op = an.estimate_predictions(&subject, &vec![1.02282724609375, 194.51904296875]); + let em = ErrorModel::new((0.0, 0.05, 0.0, 0.0), 0.0, &ErrorType::Additive); + // let spp = SupportPoint::new(vec![("ke", 1.02282724609375), ("v", 194.51904296875)]); + let spp = support_point!("ke" => 1.02282724609375, "v" => 194.51904296875); + let ll = an.estimate_likelihood(&subject, &spp, &em, false); + let op = an.estimate_predictions(&subject, &spp); println!( "Analytical: \n-2ll:{:#?}\n{:#?}", -2.0 * ll, op.flat_predictions() ); - let ll = ode.estimate_likelihood( - &subject, - &vec![1.02282724609375, 194.51904296875], - &em, - false, - ); - let op = ode.estimate_predictions(&subject, &vec![1.02282724609375, 194.51904296875]); + let ll = ode.estimate_likelihood(&subject, &spp, &em, false); + let op = ode.estimate_predictions(&subject, &spp); println!("ODE: \n-2ll:{:#?}\n{:#?}", -2.0 * ll, op.flat_predictions()); } diff --git a/examples/error.rs b/examples/error.rs index d8dd473..f1b3e9e 100644 --- a/examples/error.rs +++ b/examples/error.rs @@ -26,7 +26,7 @@ fn main() { |_p, _t, _cov, _x| {}, |x, p, t, cov, y| { fetch_cov!(cov, t, WT); - fetch_params!(p, CL0, V0, Vp0, Q0); + let V0 = p.get("v0").unwrap(); let V = V0 / (WT / 85.0); y[0] = x[0] / V; }, @@ -39,6 +39,9 @@ fn main() { .repeat(120, 0.1) .build(); - let op = ode.estimate_predictions(&subject, &vec![0.1, 0.1, 0.1, 0.1, 70.0]); + //spp = vec![0.1, 0.1, 0.1, 0.1, 70.0] + let spp = support_point!("cl0" => 0.1, "v0" => 0.1, "vp0" => 0.1, "q0" => 0.1, "v0" => 70.0); + + let op = ode.estimate_predictions(&subject, &spp); dbg!(op); } diff --git a/examples/gendata.rs b/examples/gendata.rs index 3c55bb1..e32f084 100644 --- a/examples/gendata.rs +++ b/examples/gendata.rs @@ -23,8 +23,7 @@ fn main() { // user defined dx[0] = -ke * x[0]; }, - |p, d| { - fetch_params!(p, _ke0); + |_p, d| { d[1] = 0.1; }, |_p| lag! 
{}, @@ -33,8 +32,7 @@ fn main() { fetch_params!(p, ke0); x[1] = ke0; }, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke0); + |x, _p, _t, _cov, y| { y[0] = x[0] / 50.0; }, (2, 1), @@ -62,6 +60,7 @@ fn main() { let mut data = vec![]; for (i, spp) in support_points.iter().enumerate() { + let spp = SupportPoint::from_vec(spp.clone(), vec!["ke", "ske", "v"]); let trajectories = sde.estimate_predictions(&subject, &spp); let trajectory = trajectories.row(0); // dbg!(&trajectory); diff --git a/examples/nsim.rs b/examples/nsim.rs index 6d3e188..3b1dd2d 100644 --- a/examples/nsim.rs +++ b/examples/nsim.rs @@ -1,5 +1,6 @@ fn main() { use pharmsol::*; + let subject = Subject::builder("id1") .bolus(0.0, 100.0, 0) .repeat(2, 0.5) @@ -15,7 +16,7 @@ fn main() { let ode = equation::ODE::new( |x, p, t, dx, _rateiv, cov| { fetch_cov!(cov, t, wt, age); - fetch_params!(p, ka, ke, _tlag, _v); + fetch_params!(p, ka, ke); // Secondary Eqs let ke = ke * wt.powf(0.75) * (age / 25.0).powf(0.5); @@ -25,13 +26,13 @@ fn main() { dx[1] = ka * x[0] - ke * x[1]; }, |p| { - fetch_params!(p, _ka, _ke, tlag, _v); + fetch_params!(p, tlag); lag! {0=>tlag} }, |_p| fa! {}, |_p, _t, _cov, _x| {}, |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _tlag, v); + fetch_params!(p, v); y[0] = x[1] / v; }, (2, 1), @@ -61,6 +62,13 @@ fn main() { // (2, 1), // ); - let op = ode.estimate_predictions(&subject, &vec![0.3, 0.5, 0.1, 70.0]); + let spp = support_point!( + "ka" => 0.3, + "ke" => 0.5, + "tlag" => 0.1, + "v" => 70.0 + ); + + let op = ode.estimate_predictions(&subject, &spp); println!("{op:#?}"); } diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index 05ebc13..82f18dc 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -17,24 +17,26 @@ fn main() { let ode = equation::ODE::new( |x, p, _t, dx, _rateiv, _cov| { // fetch_cov!(cov, t,); - fetch_params!(p, ka, ke, _tlag, _v); + fetch_params!(p, ka, ke); //Struct dx[0] = -ka * x[0]; dx[1] = ka * x[0] - ke * x[1]; }, |p| { - fetch_params!(p, _ka, _ke, tlag, _v); + fetch_params!(p, tlag); lag! {0=>tlag} }, |_p| fa! 
{},
         |_p, _t, _cov, _x| {},
         |x, p, _t, _cov, y| {
-            fetch_params!(p, _ka, _ke, _tlag, v);
+            fetch_params!(p, v);
             y[0] = x[1] / v;
         },
         (2, 1),
     );
-    let op = ode.estimate_predictions(&subject, &vec![0.3, 0.5, 0.1, 70.0]);
+    let spp = support_point!["ka" => 0.3, "ke" => 0.5, "tlag" => 0.1, "v" => 70.0];
+
+    let op = ode.estimate_predictions(&subject, &spp);
     println!("{:#?}", op.flat_predictions());
 }
diff --git a/examples/pf.rs b/examples/pf.rs
index bb3f477..5b9adbd 100644
--- a/examples/pf.rs
+++ b/examples/pf.rs
@@ -12,8 +12,9 @@ fn main() {
 
     let sde = equation::SDE::new(
         |x, p, _t, dx, _rateiv, _cov| {
+            fetch_params!(p, ke0);
             dx[0] = -x[0] * x[1]; // ke *x[0]
-            dx[1] = -x[1] + p[0]; // mean reverting
+            dx[1] = -x[1] + ke0; // mean reverting
         },
         |_p, d| {
             d[0] = 1.0;
@@ -29,10 +30,12 @@ fn main() {
         10000,
     );
-    let em = ErrorModel::new((0.5, 0.0, 0.0, 0.0), 0.0, &error_model::ErrorType::Add); // sigma = 0.5
+    let em = ErrorModel::new((0.5, 0.0, 0.0, 0.0), 0.0, &error_model::ErrorType::Additive); // sigma = 0.5
 
-    let ll = sde.estimate_likelihood(&subject, &vec![1.0], &em, false);
+    let spp = support_point!["ke0" => 1.0];
 
-    dbg!(sde.estimate_likelihood(&subject, &vec![1.0], &em, false));
+    let ll = sde.estimate_likelihood(&subject, &spp, &em, false);
+
+    dbg!(sde.estimate_likelihood(&subject, &spp, &em, false));
 
     println!("{ll:#?}");
 }
diff --git a/examples/sim_theta.rs b/examples/sim_theta.rs
new file mode 100644
index 0000000..62a0acf
--- /dev/null
+++ b/examples/sim_theta.rs
@@ -0,0 +1,59 @@
+use csv::ReaderBuilder;
+use pharmsol::*;
+use prelude::{data::read_pmetrics, simulator::Prediction};
+
+fn main() {
+    // path to theta
+    let path = "../PMcore/outputs/theta.csv";
+    // read theta into an Array2
+    let mut rdr = ReaderBuilder::new().from_path(path).unwrap();
+
+    let mut spps = vec![];
+    for result in rdr.records() {
+        let record = result.unwrap();
+        let mut row = vec![];
+        for field in record.iter() {
+            row.push(field.parse::<f64>().unwrap());
+        }
+        spps.push(row);
+    }
+
+    let sde = equation::SDE::new(
+        |x, p, _t, dx, _rateiv, _cov| {
+            // automatically defined
+            fetch_params!(p, ke0);
+            // let ke0 = 1.2;
+            dx[1] = -x[1] + ke0;
+            let ke = x[1];
+            // user defined
+            dx[0] = -ke * x[0];
+        },
+        |_p, d| {
+            d[1] = 0.1;
+        },
+        |_p| lag! {},
+        |_p| fa! {},
+        |p, _t, _cov, x| {
+            fetch_params!(p, ke0);
+            x[1] = ke0;
+        },
+        |x, _p, _t, _cov, y| {
+            y[0] = x[0] / 50.0;
+        },
+        (2, 1),
+        1,
+    );
+
+    let data = read_pmetrics("../PMcore/examples/iov/test.csv").unwrap();
+
+    for (i, spp) in spps.iter().enumerate() {
+        for (j, subject) in data.get_subjects().iter().enumerate() {
+            let spp = SupportPoint::from_vec(spp.clone(), vec!["ke0"]);
+            let trajectories: ndarray::Array2<Prediction> =
+                sde.estimate_predictions(&subject, &spp);
+            let trajectory = trajectories.row(0);
+            println!("{}, {}", i, j);
+            dbg!(trajectory);
+        }
+    }
+}
diff --git a/src/data/covariate.rs b/src/data/covariate.rs
index 576dc90..4b200bc 100644
--- a/src/data/covariate.rs
+++ b/src/data/covariate.rs
@@ -1,14 +1,14 @@
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use std::{collections::HashMap, fmt};
 
-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub enum InterpolationMethod {
     Linear { slope: f64, intercept: f64 },
     CarryForward { value: f64 },
 }
 
 /// A [CovariateSegment] is a segment of the piece-wise interpolation of a [Covariate]
-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct CovariateSegment {
     from: f64,
     to: f64,
@@ -46,7 +46,7 @@ impl CovariateSegment {
 }
 
 /// A [Covariate] is a collection of [CovariateSegment]s, which allows for interpolation of covariate values
-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct Covariate {
     name: String,
     segments: Vec<CovariateSegment>,
@@ -113,7 +113,7 @@ impl fmt::Display for Covariate {
 }
 
 /// [Covariates] is a collection of [Covariate]
-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct Covariates {
     // Mapping from covariate name to its segments
     // FIXME: this hashmap have a key, covariate also has a name field, we are never
diff --git a/src/data/error_model.rs b/src/data/error_model.rs
index 8ad7e58..71a97dd 100644
--- a/src/data/error_model.rs
+++ b/src/data/error_model.rs
@@ -1,3 +1,5 @@
+use serde::{Deserialize, Serialize};
+
 use crate::simulator::likelihood::Prediction;
 
 #[derive(Debug, Clone)]
@@ -16,10 +18,10 @@ impl<'a> ErrorModel<'a> {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Serialize, Deserialize, Copy)]
 pub enum ErrorType {
-    Add,
-    Prop,
+    Additive,
+    Proportional,
 }
 #[allow(clippy::extra_unused_lifetimes)]
 impl<'a> ErrorModel<'_> {
@@ -34,8 +36,8 @@ impl<'a> ErrorModel<'_> {
             + c3 * prediction.observation().powi(3);
 
         let res = match self.e_type {
-            ErrorType::Add => (alpha.powi(2) + self.gl.powi(2)).sqrt(),
-            ErrorType::Prop => self.gl * alpha,
+            ErrorType::Additive => (alpha.powi(2) + self.gl.powi(2)).sqrt(),
+            ErrorType::Proportional => self.gl * alpha,
         };
 
         if res.is_nan() || res < 0.0 {
diff --git a/src/data/event.rs b/src/data/event.rs
index 33ddaa6..7ae3489 100644
--- a/src/data/event.rs
+++ b/src/data/event.rs
@@ -1,9 +1,9 @@
 use std::fmt;
 
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
 /// An Event can be a Bolus, Infusion, or Observation
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
 pub enum Event {
     Bolus(Bolus),
     Infusion(Infusion),
@@ -29,7 +29,7 @@ impl Event {
 /// An instantaenous input of drug
 ///
 /// An amount of drug is added to a compartment at a specific time
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct Bolus {
     time: f64,
     amount: f64,
@@ -62,7 +62,7 @@ impl Bolus {
 }
 
 /// A continuous dose of drug
-#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct Infusion { time: f64, amount: f64, @@ -99,7 +99,7 @@ impl Infusion { } /// An observation of drug concentration or covariates -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct Observation { time: f64, value: f64, diff --git a/src/data/structs.rs b/src/data/structs.rs index fb088e2..12331b5 100644 --- a/src/data/structs.rs +++ b/src/data/structs.rs @@ -1,13 +1,13 @@ use crate::data::*; use csv::WriterBuilder; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::{collections::HashMap, fmt}; /// [Data] is a collection of [Subject]s, which are collections of [Occasion]s, which are collections of [Event]s /// /// This is the main data structure used to store the data, and is used to pass data to the model /// [Data] implements the [DataTrait], which provides methods to access the data -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct Data { subjects: Vec, } @@ -238,7 +238,7 @@ impl Data { } /// [Subject] is a collection of blocks for one individual -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, Serialize)] pub struct Subject { id: String, occasions: Vec, @@ -267,7 +267,7 @@ impl Subject { } /// An [Occasion] is a collection of events, for a given [Subject], that are from a specific occasion -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, Serialize)] pub struct Occasion { events: Vec, covariates: Covariates, diff --git a/src/lib.rs b/src/lib.rs index a8a9d73..8e7d02e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,9 @@ pub use crate::data::InterpolationMethod::*; pub use crate::data::*; pub use crate::equation::*; pub use crate::simulator::equation::{self, ODE}; +pub use crate::simulator::types::*; pub use nalgebra::dmatrix; +pub use std::collections::BTreeMap; pub use std::collections::HashMap; pub mod prelude { @@ -25,6 +27,7 @@ pub mod prelude { equation, equation::Equation, likelihood::{psi, PopulationPredictions, Prediction, SubjectPredictions}, + types::SupportPoint, }; } pub mod models { @@ -41,17 +44,17 @@ pub mod prelude { //traits pub use crate::simulator::fitting::{EstimateTheta, OptimalSupportPoint}; - #[macro_export] - macro_rules! fetch_params { - ($p:expr, $($name:ident),*) => { - let p = $p; - let mut idx = 0; - $( - let $name = p[idx]; - idx += 1; - )* - }; - } + // #[macro_export] + // macro_rules! fetch_params { + // ($p:expr, $($name:ident),*) => { + // let p = $p; + // let mut idx = 0; + // $( + // let $name = p[idx]; + // idx += 1; + // )* + // }; + // } #[macro_export] macro_rules! 
fetch_cov {
        ($cov:expr, $t:expr, $($name:ident),*) => {
diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs
index f19bf6b..4b4ecc5 100644
--- a/src/simulator/equation/analytical/mod.rs
+++ b/src/simulator/equation/analytical/mod.rs
@@ -1,6 +1,7 @@
 use crate::{
     data::Covariates, simulator::*, Equation, EquationPriv, EquationTypes, Observation, Subject,
 };
+use anyhow::Context;
 use cached::proc_macro::cached;
 use cached::UnboundCache;
 use nalgebra::{DVector, Matrix2, Vector2};
@@ -54,13 +55,13 @@ impl EquationPriv for Analytical {
     // }
 
     #[inline(always)]
-    fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
-        Some((self.lag)(&V::from_vec(spp.to_owned())))
+    fn get_lag(&self, spp: &SupportPoint) -> Option<HashMap<usize, f64>> {
+        Some((self.lag)(spp))
     }
 
     #[inline(always)]
-    fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
-        Some((self.fa)(&V::from_vec(spp.to_owned())))
+    fn get_fa(&self, spp: &SupportPoint) -> Option<HashMap<usize, f64>> {
+        Some((self.fa)(spp))
     }
 
     #[inline(always)]
@@ -76,7 +77,7 @@ impl EquationPriv for Analytical {
     fn solve(
         &self,
         x: &mut Self::S,
-        support_point: &Vec<f64>,
+        support_point: &SupportPoint,
         covariates: &Covariates,
         infusions: &Vec<Infusion>,
         ti: f64,
@@ -85,7 +86,7 @@
         if ti == tf {
             return;
         }
-        let mut support_point = V::from_vec(support_point.to_owned());
+        let mut support_point = support_point.clone();
         let mut rateiv = V::from_vec(vec![0.0, 0.0, 0.0]); //TODO: This should be pre-calculated
         for infusion in infusions {
@@ -99,7 +100,7 @@
     #[inline(always)]
     fn process_observation(
        &self,
-        support_point: &Vec<f64>,
+        support_point: &SupportPoint,
         observation: &Observation,
         error_model: Option<&ErrorModel>,
         _time: f64,
@@ -110,13 +111,7 @@
     ) {
         let mut y = V::zeros(self.get_nouteqs());
         let out = &self.out;
-        (out)(
-            x,
-            &V::from_vec(support_point.clone()),
-            observation.time(),
-            covariates,
-            &mut y,
-        );
+        (out)(x, support_point, observation.time(), covariates, &mut y);
         let pred = y[observation.outeq()];
         let pred = observation.to_obs_pred(pred, x.as_slice().to_vec());
         if let Some(error_model) = error_model {
@@ -125,11 +120,16 @@
         output.add_prediction(pred);
     }
     #[inline(always)]
-    fn initial_state(&self, spp: &Vec<f64>, covariates: &Covariates, occasion_index: usize) -> V {
+    fn initial_state(
+        &self,
+        spp: &SupportPoint,
+        covariates: &Covariates,
+        occasion_index: usize,
+    ) -> V {
         let init = &self.init;
         let mut x = V::zeros(self.get_nstates());
         if occasion_index == 0 {
-            (init)(&V::from_vec(spp.to_vec()), 0.0, covariates, &mut x);
+            (init)(spp, 0.0, covariates, &mut x);
         }
         x
     }
@@ -139,26 +139,26 @@ impl Equation for Analytical {
     fn estimate_likelihood(
         &self,
         subject: &Subject,
-        support_point: &Vec<f64>,
+        support_point: &SupportPoint,
         error_model: &ErrorModel,
         cache: bool,
     ) -> f64 {
         _estimate_likelihood(self, subject, support_point, error_model, cache)
     }
 }
-fn spphash(spp: &[f64]) -> u64 {
-    spp.iter().fold(0, |acc, x| acc + x.to_bits())
-}
+// fn spphash(spp: &[f64]) -> u64 {
+//     spp.iter().fold(0, |acc, x| acc + x.to_bits())
+// }
 #[inline(always)]
 #[cached(
     ty = "UnboundCache<String, SubjectPredictions>",
     create = "{ UnboundCache::with_capacity(100_000) }",
-    convert = r#"{ format!("{}{}", subject.id(), spphash(support_point)) }"#
+    convert = r#"{ format!("{}{}", subject.id(), support_point.hash()) }"#
 )]
 fn _subject_predictions(
     ode: &Analytical,
     subject: &Subject,
-    support_point: &Vec<f64>,
+    support_point: &SupportPoint,
 ) -> SubjectPredictions {
     ode.simulate_subject(subject, support_point, None).0
 }
@@ -166,7 +166,7 @@
 fn _estimate_likelihood(
     ode: &Analytical,
     subject: &Subject,
-    support_point: &Vec<f64>,
+    support_point: &SupportPoint,
     error_model: &ErrorModel,
     cache: bool,
 ) -> f64 {
@@ -187,9 +187,12 @@
 /// - covariates are not used
 ///
-pub fn one_compartment(x: &V, p: &V, t: T, rateiv: V, _cov: &Covariates) -> V {
+pub fn one_compartment(x: &V, p: &SupportPoint, t: T, rateiv: V, _cov: &Covariates) -> V {
     let mut xout = x.clone();
-    let ke = p[0];
+    let ke = p
+        .get("ke")
+        .context("ke not found in support point")
+        .unwrap();
 
     xout[0] = x[0] * (-ke * t).exp() + rateiv[0] / ke * (1.0 - (-ke * t).exp());
     // dbg!(t, &rateiv, x, &xout);
@@ -205,10 +208,24 @@
 /// - covariates are not used
 ///
-pub fn one_compartment_with_absorption(x: &V, p: &V, t: T, rateiv: V, _cov: &Covariates) -> V {
+pub fn one_compartment_with_absorption(
+    x: &V,
+    p: &SupportPoint,
+    t: T,
+    rateiv: V,
+    _cov: &Covariates,
+) -> V {
     let mut xout = x.clone();
-    let ka = p[0];
-    let ke = p[1];
+    // let ka = p[0];
+    // let ke = p[1];
+    let ka = p
+        .get("ka")
+        .context("ka not found in support point")
+        .unwrap();
+    let ke = p
+        .get("ke")
+        .context("ke not found in support point")
+        .unwrap();
 
     xout[0] = x[0] * (-ka * t).exp();
diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs
index a2c716c..68893af 100644
--- a/src/simulator/equation/mod.rs
+++ b/src/simulator/equation/mod.rs
@@ -10,7 +10,7 @@ pub use sde::*;
 
 use crate::{error_model::ErrorModel, Covariates, Event, Infusion, Observation, Subject};
 
-use super::likelihood::Prediction;
+use super::{likelihood::Prediction, SupportPoint};
 
 pub trait State {
     fn add_bolus(&mut self, input: usize, amount: f64);
@@ -32,14 +32,14 @@ pub trait EquationTypes {
 pub(crate) trait EquationPriv: EquationTypes {
     // fn get_init(&self) -> &Init;
     // fn get_out(&self) -> &Out;
-    fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>>;
-    fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>>;
+    fn get_lag(&self, spp: &SupportPoint) -> Option<HashMap<usize, f64>>;
+    fn get_fa(&self, spp: &SupportPoint) -> Option<HashMap<usize, f64>>;
     fn get_nstates(&self) -> usize;
     fn get_nouteqs(&self) -> usize;
     fn solve(
         &self,
         state: &mut Self::S,
-        support_point: &Vec<f64>,
+        support_point: &SupportPoint,
         covariates: &Covariates,
         infusions: &Vec<Infusion>,
         start_time: f64,
@@ -54,7 +54,7 @@
     }
     fn process_observation(
         &self,
-        support_point: &Vec<f64>,
+        support_point: &SupportPoint,
         observation: &Observation,
         error_model: Option<&ErrorModel>,
         time: f64,
@@ -66,14 +66,14 @@
 
     fn initial_state(
         &self,
-        support_point: &Vec<f64>,
+        support_point: &SupportPoint,
         covariates: &Covariates,
         occasion_index: usize,
     ) -> Self::S;
 
     fn simulate_event(
         &self,
-        support_point: &Vec<f64>,
+        support_point: &SupportPoint,
         event: &Event,
         next_event: Option<&Event>,
         error_model: Option<&ErrorModel>,
@@ -125,19 +125,19 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync {
     fn estimate_likelihood(
         &self,
         subject: &Subject,
-        support_point: &Vec<f64>,
+        support_point: &SupportPoint,
         error_model: &ErrorModel,
         cache: bool,
     ) -> f64;
 
-    fn estimate_predictions(&self, subject: &Subject, support_point: &Vec<f64>) -> Self::P {
+    fn estimate_predictions(&self, subject: &Subject, support_point: &SupportPoint) -> Self::P {
         self.simulate_subject(subject, support_point, None).0
     }
 
     fn simulate_subject(
         &self,
         subject: &Subject,
-        support_point: &Vec<f64>,
+        support_point:
&SupportPoint, error_model: Option<&ErrorModel>, ) -> (Self::P, Option) { let lag = self.get_lag(support_point); diff --git a/src/simulator/equation/ode/closure.rs b/src/simulator/equation/ode/closure.rs index 4266a2a..3ac4e9f 100644 --- a/src/simulator/equation/ode/closure.rs +++ b/src/simulator/equation/ode/closure.rs @@ -10,16 +10,17 @@ use diffsol::{ use std::{cell::RefCell, rc::Rc}; use crate::data::{Covariates, Infusion}; +use crate::simulator::SupportPoint; pub(crate) struct PMClosure where M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V, M::V, &Covariates), + F: Fn(&M::V, &SupportPoint, M::T, &mut M::V, M::V, &Covariates), { func: F, nstates: usize, nout: usize, nparams: usize, - p: Rc, + p: Rc, sparsity: Option, statistics: RefCell, covariates: Covariates, @@ -29,13 +30,13 @@ where impl PMClosure where M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V, M::V, &Covariates), + F: Fn(&M::V, &SupportPoint, M::T, &mut M::V, M::V, &Covariates), { pub(crate) fn new( func: F, nstates: usize, nout: usize, - p: Rc, + p: Rc, covariates: Covariates, infusions: Vec, ) -> Self { @@ -57,7 +58,7 @@ where impl Op for PMClosure where M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V, M::V, &Covariates), + F: Fn(&M::V, &SupportPoint, M::T, &mut M::V, M::V, &Covariates), { type V = M::V; type T = M::T; @@ -82,7 +83,7 @@ where impl NonLinearOp for PMClosure where M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V, M::V, &Covariates), + F: Fn(&M::V, &SupportPoint, M::T, &mut M::V, M::V, &Covariates), { fn call_inplace(&self, x: &M::V, t: M::T, y: &mut M::V) { let mut rateiv = Self::V::zeros(self.nstates); @@ -102,7 +103,7 @@ where impl NonLinearOpJacobian for PMClosure where M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V, M::V, &Covariates), + F: Fn(&M::V, &SupportPoint, M::T, &mut M::V, M::V, &Covariates), { fn jac_mul_inplace(&self, _x: &M::V, t: M::T, v: &M::V, y: &mut M::V) { let rateiv = Self::V::zeros(self.nstates); diff --git a/src/simulator/equation/ode/diffsol_traits.rs b/src/simulator/equation/ode/diffsol_traits.rs index 029e3f1..56089ec 100644 --- a/src/simulator/equation/ode/diffsol_traits.rs +++ b/src/simulator/equation/ode/diffsol_traits.rs @@ -1,5 +1,8 @@ use super::closure::PMClosure; -use crate::data::{Covariates, Infusion}; +use crate::{ + data::{Covariates, Infusion}, + simulator::SupportPoint, +}; use diffsol::{ error::DiffsolError, matrix::Matrix, @@ -14,7 +17,7 @@ use std::rc::Rc; pub(crate) fn build_pm_ode( rhs: F, init: I, - p: M::V, + p: SupportPoint, t0: f64, h0: f64, rtol: f64, @@ -27,19 +30,23 @@ pub(crate) fn build_pm_ode( > where M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V, M::V, &Covariates), + F: Fn(&M::V, &SupportPoint, M::T, &mut M::V, M::V, &Covariates), I: Fn(&M::V, M::T) -> M::V, { let p = Rc::new(p); let t0 = M::T::from(t0); - let y0 = (init)(&p, t0); + let v = Rc::new(M::V::from_vec( + p.to_vec().into_iter().map(M::T::from).collect(), + )); + let y0 = (init)(&v, t0); let nstates = y0.len(); let rhs = PMClosure::new(rhs, nstates, nstates, p.clone(), cov, infusions); // let mass = Rc::new(UnitCallable::new(nstates)); let rhs = Rc::new(rhs); - let init = ConstantClosure::new(init, p.clone()); + + let init = ConstantClosure::new(init, v.clone()); let init = Rc::new(init); - let eqn = OdeSolverEquations::new(rhs, None, None, init, None, p); + let eqn = OdeSolverEquations::new(rhs, None, None, init, None, v); // let atol = M::V::from_element(nstates, M::T::from(atol)); OdeBuilder::new() .atol(vec![atol]) diff --git a/src/simulator/equation/ode/mod.rs 
b/src/simulator/equation/ode/mod.rs index ca6feef..e7a0058 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -7,7 +7,9 @@ use crate::{ data::{Covariates, Infusion}, error_model::ErrorModel, prelude::simulator::SubjectPredictions, - simulator::{likelihood::ToPrediction, DiffEq, Fa, Init, Lag, Neqs, Out, M, T, V}, + simulator::{ + likelihood::ToPrediction, DiffEq, Fa, Init, Lag, Neqs, Out, SupportPoint, M, T, V, + }, Observation, Subject, }; use cached::proc_macro::cached; @@ -52,35 +54,35 @@ impl State for V { } } -// Hash the support points by converting them to bits and summing them -// The wrapping_add is used to avoid overflow, and prevent panics -fn spphash(spp: &[f64]) -> u64 { - let mut hasher = std::hash::DefaultHasher::new(); - spp.iter().for_each(|&value| { - // Normalize negative zero to zero, e.g. -0.0 -> 0.0 - let normalized_value = if value == 0.0 && value.is_sign_negative() { - 0.0 - } else { - value - }; - // Convert the value to bits and hash it - let bits = normalized_value.to_bits(); - std::hash::Hash::hash(&bits, &mut hasher); - }); - - std::hash::Hasher::finish(&hasher) -} +// // Hash the support points by converting them to bits and summing them +// // The wrapping_add is used to avoid overflow, and prevent panics +// fn spphash(spp: &[f64]) -> u64 { +// let mut hasher = std::hash::DefaultHasher::new(); +// spp.iter().for_each(|&value| { +// // Normalize negative zero to zero, e.g. -0.0 -> 0.0 +// let normalized_value = if value == 0.0 && value.is_sign_negative() { +// 0.0 +// } else { +// value +// }; +// // Convert the value to bits and hash it +// let bits = normalized_value.to_bits(); +// std::hash::Hash::hash(&bits, &mut hasher); +// }); + +// std::hash::Hasher::finish(&hasher) +// } #[inline(always)] #[cached( ty = "UnboundCache", create = "{ UnboundCache::with_capacity(100_000) }", - convert = r#"{ format!("{}{}", subject.id(), spphash(support_point)) }"# + convert = r#"{ format!("{}{}", subject.id(), support_point.hash()) }"# )] fn _subject_predictions( ode: &ODE, subject: &Subject, - support_point: &Vec, + support_point: &SupportPoint, ) -> SubjectPredictions { ode.simulate_subject(subject, support_point, None).0 } @@ -88,7 +90,7 @@ fn _subject_predictions( fn _estimate_likelihood( ode: &ODE, subject: &Subject, - support_point: &Vec, + support_point: &SupportPoint, error_model: &ErrorModel, cache: bool, ) -> f64 { @@ -107,13 +109,13 @@ impl EquationTypes for ODE { impl EquationPriv for ODE { #[inline(always)] - fn get_lag(&self, spp: &[f64]) -> Option> { - Some((self.lag)(&V::from_vec(spp.to_owned()))) + fn get_lag(&self, spp: &SupportPoint) -> Option> { + Some((self.lag)(spp)) } #[inline(always)] - fn get_fa(&self, spp: &[f64]) -> Option> { - Some((self.fa)(&V::from_vec(spp.to_owned()))) + fn get_fa(&self, spp: &SupportPoint) -> Option> { + Some((self.fa)(spp)) } #[inline(always)] @@ -129,7 +131,7 @@ impl EquationPriv for ODE { fn solve( &self, state: &mut Self::S, - support_point: &Vec, + support_point: &SupportPoint, covariates: &Covariates, infusions: &Vec, start_time: f64, @@ -142,7 +144,7 @@ impl EquationPriv for ODE { let problem = build_pm_ode::( self.diffeq, |_p: &V, _t: T| state.clone(), - V::from_vec(support_point.to_vec()), + support_point.clone(), start_time, 1e-3, RTOL, @@ -163,7 +165,7 @@ impl EquationPriv for ODE { #[inline(always)] fn process_observation( &self, - support_point: &Vec, + support_point: &SupportPoint, observation: &Observation, error_model: Option<&ErrorModel>, _time: f64, @@ -174,13 
+176,7 @@ impl EquationPriv for ODE { ) { let mut y = V::zeros(self.get_nouteqs()); let out = &self.out; - (out)( - x, - &V::from_vec(support_point.clone()), - observation.time(), - covariates, - &mut y, - ); + (out)(x, support_point, observation.time(), covariates, &mut y); let pred = y[observation.outeq()]; let pred = observation.to_obs_pred(pred, x.as_slice().to_vec()); if let Some(error_model) = error_model { @@ -190,11 +186,16 @@ impl EquationPriv for ODE { } #[inline(always)] - fn initial_state(&self, spp: &Vec, covariates: &Covariates, occasion_index: usize) -> V { + fn initial_state( + &self, + spp: &SupportPoint, + covariates: &Covariates, + occasion_index: usize, + ) -> V { let init = &self.init; let mut x = V::zeros(self.get_nstates()); if occasion_index == 0 { - (init)(&V::from_vec(spp.to_vec()), 0.0, covariates, &mut x); + (init)(spp, 0.0, covariates, &mut x); } x } @@ -204,7 +205,7 @@ impl Equation for ODE { fn estimate_likelihood( &self, subject: &Subject, - support_point: &Vec, + support_point: &SupportPoint, error_model: &ErrorModel, cache: bool, ) -> f64 { @@ -212,22 +213,22 @@ impl Equation for ODE { } } -// Test spphash -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn test_spphash() { - let spp1 = vec![1.0, 2.0, 3.0]; - let spp2 = vec![1.0, 2.0, 3.0]; - let spp3 = vec![3.0, 2.0, 1.0]; - let spp4 = vec![1.0, 2.0, 3.000001]; - // Equal values should have the same hash - assert_eq!(spphash(&spp1), spphash(&spp2)); - // Mirrored values should have different hashes - assert_ne!(spphash(&spp1), spphash(&spp3)); - // Very close values should have different hashes - // Note: Due to f64 precision this will fail for values that are very close, e.g. 3.0 and 3.0000000000000001 - assert_ne!(spphash(&spp1), spphash(&spp4)); - } -} +// // Test spphash +// #[cfg(test)] +// mod tests { +// use super::*; +// #[test] +// fn test_spphash() { +// let spp1 = vec![1.0, 2.0, 3.0]; +// let spp2 = vec![1.0, 2.0, 3.0]; +// let spp3 = vec![3.0, 2.0, 1.0]; +// let spp4 = vec![1.0, 2.0, 3.000001]; +// // Equal values should have the same hash +// assert_eq!(spphash(&spp1), spphash(&spp2)); +// // Mirrored values should have different hashes +// assert_ne!(spphash(&spp1), spphash(&spp3)); +// // Very close values should have different hashes +// // Note: Due to f64 precision this will fail for values that are very close, e.g. 
3.0 and 3.0000000000000001 +// assert_ne!(spphash(&spp1), spphash(&spp4)); +// } +// } diff --git a/src/simulator/equation/sde/em.rs b/src/simulator/equation/sde/em.rs index dbbf946..8dc986e 100644 --- a/src/simulator/equation/sde/em.rs +++ b/src/simulator/equation/sde/em.rs @@ -1,6 +1,6 @@ use crate::{ data::Covariates, - simulator::{Diffusion, Drift}, + simulator::{Diffusion, Drift, SupportPoint}, }; use nalgebra::DVector; use rand::prelude::*; @@ -10,7 +10,7 @@ use rand_distr::{Distribution, Normal}; pub struct EM { drift: Drift, diffusion: Diffusion, - params: DVector, + params: SupportPoint, state: DVector, cov: Covariates, rateiv: DVector, @@ -21,7 +21,7 @@ impl EM { pub fn new( drift: Drift, diffusion: Diffusion, - params: DVector, + params: SupportPoint, initial_state: DVector, cov: Covariates, rateiv: DVector, diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index 61328ce..7c9bc6f 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -15,7 +15,9 @@ use crate::{ data::{Covariates, Infusion}, error_model::ErrorModel, prelude::simulator::Prediction, - simulator::{likelihood::ToPrediction, Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V}, + simulator::{ + likelihood::ToPrediction, Diffusion, Drift, Fa, Init, Lag, Neqs, Out, SupportPoint, V, + }, Subject, }; @@ -26,7 +28,7 @@ pub(crate) fn simulate_sde_event( drift: &Drift, difussion: &Diffusion, x: V, - support_point: &[f64], + support_point: &SupportPoint, cov: &Covariates, infusions: &[Infusion], ti: f64, @@ -46,7 +48,7 @@ pub(crate) fn simulate_sde_event( let mut sde = em::EM::new( *drift, *difussion, - DVector::from_column_slice(support_point), + support_point.clone(), x, cov.clone(), rateiv, @@ -127,13 +129,13 @@ impl EquationPriv for SDE { // } #[inline(always)] - fn get_lag(&self, spp: &[f64]) -> Option> { - Some((self.lag)(&V::from_vec(spp.to_owned()))) + fn get_lag(&self, spp: &SupportPoint) -> Option> { + Some((self.lag)(spp)) } #[inline(always)] - fn get_fa(&self, spp: &[f64]) -> Option> { - Some((self.fa)(&V::from_vec(spp.to_owned()))) + fn get_fa(&self, spp: &SupportPoint) -> Option> { + Some((self.fa)(spp)) } #[inline(always)] @@ -149,7 +151,7 @@ impl EquationPriv for SDE { fn solve( &self, state: &mut Self::S, - support_point: &Vec, + support_point: &SupportPoint, covariates: &Covariates, infusions: &Vec, ti: f64, @@ -178,7 +180,7 @@ impl EquationPriv for SDE { #[inline(always)] fn process_observation( &self, - support_point: &Vec, + support_point: &SupportPoint, observation: &crate::Observation, error_model: Option<&ErrorModel>, _time: f64, @@ -190,13 +192,7 @@ impl EquationPriv for SDE { let mut pred = vec![Prediction::default(); self.nparticles]; pred.par_iter_mut().enumerate().for_each(|(i, p)| { let mut y = V::zeros(self.get_nouteqs()); - (self.out)( - &x[i], - &V::from_vec(support_point.clone()), - observation.time(), - covariates, - &mut y, - ); + (self.out)(&x[i], support_point, observation.time(), covariates, &mut y); *p = observation.to_obs_pred(y[observation.outeq()], x[i].as_slice().to_vec()); }); let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone()).unwrap(); @@ -220,7 +216,7 @@ impl EquationPriv for SDE { #[inline(always)] fn initial_state( &self, - support_point: &Vec, + support_point: &SupportPoint, covariates: &Covariates, occasion_index: usize, ) -> Self::S { @@ -228,12 +224,7 @@ impl EquationPriv for SDE { for _ in 0..self.nparticles { let mut state = DVector::zeros(self.get_nstates()); if occasion_index == 0 { - 
(self.init)(
-                    &V::from_vec(support_point.to_vec()),
-                    0.0,
-                    covariates,
-                    &mut state,
-                );
+                (self.init)(support_point, 0.0, covariates, &mut state);
             }
             x.push(state);
         }
@@ -245,7 +236,7 @@ impl Equation for SDE {
     fn estimate_likelihood(
         &self,
         subject: &Subject,
-        support_point: &Vec<f64>,
+        support_point: &SupportPoint,
         error_model: &ErrorModel,
         cache: bool,
     ) -> f64 {
@@ -257,20 +248,20 @@
     }
 }
 
-fn spphash(spp: &[f64]) -> u64 {
-    spp.iter().fold(0, |acc, x| acc + x.to_bits())
-}
+// fn spphash(spp: &[f64]) -> u64 {
+//     spp.iter().fold(0, |acc, x| acc + x.to_bits())
+// }
 
 #[inline(always)]
 #[cached(
     ty = "UnboundCache<String, f64>",
     create = "{ UnboundCache::with_capacity(100_000) }",
-    convert = r#"{ format!("{}{}{}", subject.id(), spphash(support_point), error_model.gl()) }"#
+    convert = r#"{ format!("{}{}{}", subject.id(), support_point.hash(), error_model.gl()) }"#
 )]
 fn _estimate_likelihood(
     sde: &SDE,
     subject: &Subject,
-    support_point: &Vec<f64>,
+    support_point: &SupportPoint,
     error_model: &ErrorModel,
 ) -> f64 {
     let ypred = sde.simulate_subject(subject, support_point, Some(error_model));
diff --git a/src/simulator/fitting.rs b/src/simulator/fitting.rs
index 0769d36..60fd4b5 100644
--- a/src/simulator/fitting.rs
+++ b/src/simulator/fitting.rs
@@ -9,9 +9,12 @@ use crate::{
     Equation, Predictions,
 };
 
+use super::SupportPoint;
+
 struct SppOptimizer<'a, E: Equation> {
     equation: &'a E,
     subject: &'a Subject,
+    parameters: Vec<String>,
 }
 
 impl<'a, E: Equation> CostFunction for SppOptimizer<'a, E> {
@@ -19,20 +22,26 @@
     type Output = f64;
 
     fn cost(&self, point: &Self::Param) -> Result<Self::Output, Error> {
+        let support_point = SupportPoint::from_vec(point.to_vec(), self.parameters.clone());
         let (concentrations, _) = self
             .equation
-            .simulate_subject(self.subject, point.to_vec().as_ref(), None);
+            .simulate_subject(self.subject, &support_point, None);
         Ok(concentrations.squared_error()) //TODO: Change this to use the D function
     }
 }
 
 impl<'a, E: Equation> SppOptimizer<'a, E> {
-    fn new(equation: &'a E, subject: &'a Subject) -> Self {
-        Self { equation, subject }
+    fn new(equation: &'a E, subject: &'a Subject, parameters: Vec<String>) -> Self {
+        Self {
+            equation,
+            subject,
+            parameters,
+        }
     }
-    fn optimize(self, point: &Array1<f64>) -> Array1<f64> {
-        let simplex = create_initial_simplex(point);
+    fn optimize(self, point: &SupportPoint) -> Array1<f64> {
+        let point = Array1::from_vec(point.to_vec());
+        let simplex = create_initial_simplex(&point);
         let solver = NelderMead::new(simplex)
             .with_sd_tolerance(1e-2)
             .expect("Error setting up the solver");
@@ -71,21 +80,21 @@ fn create_initial_simplex(initial_point: &Array1<f64>) -> Vec<Array1<f64>> {
 }
 
 pub trait OptimalSupportPoint<E: Equation> {
-    fn optimal_support_point(&self, equation: &E, point: &Array1<f64>) -> Array1<f64>;
+    fn optimal_support_point(&self, equation: &E, point: &SupportPoint) -> Array1<f64>;
 }
 
 impl<E: Equation> OptimalSupportPoint<E> for Subject {
-    fn optimal_support_point(&self, equation: &E, point: &Array1<f64>) -> Array1<f64> {
-        SppOptimizer::new(equation, self).optimize(point)
+    fn optimal_support_point(&self, equation: &E, point: &SupportPoint) -> Array1<f64> {
+        SppOptimizer::new(equation, self, point.parameters()).optimize(point)
     }
 }
 
 pub trait EstimateTheta<E: Equation> {
-    fn estimate_theta(&self, equation: &E, point: &Array1<f64>) -> Array2<f64>;
+    fn estimate_theta(&self, equation: &E, point: &SupportPoint) -> Array2<f64>;
 }
 
 impl<E: Equation> EstimateTheta<E> for Data {
-    fn estimate_theta(&self, equation: &E, point: &Array1<f64>) -> Array2<f64> {
+    fn estimate_theta(&self, equation: &E, point: &SupportPoint) -> Array2<f64> {
         let n_sub
= self.len(); let mut theta = Array2::zeros((n_sub, point.len())); diff --git a/src/simulator/likelihood.rs b/src/simulator/likelihood.rs index 51abd65..b03d0c4 100644 --- a/src/simulator/likelihood.rs +++ b/src/simulator/likelihood.rs @@ -7,6 +7,8 @@ use indicatif::{ProgressBar, ProgressStyle}; use ndarray::{Array2, Axis, ShapeBuilder}; use rayon::prelude::*; +use super::Theta; + const FRAC_1_SQRT_2PI: f64 = std::f64::consts::FRAC_2_SQRT_PI * std::f64::consts::FRAC_1_SQRT_2 / 2.0; @@ -95,12 +97,12 @@ impl From> for PopulationPredictions { pub fn psi( equation: &impl Equation, subjects: &Data, - support_points: &Array2, + support_points: &Theta, error_model: &ErrorModel, progress: bool, cache: bool, ) -> Array2 { - let mut pred: Array2 = Array2::default((subjects.len(), support_points.nrows()).f()); + let mut pred: Array2 = Array2::default((subjects.len(), support_points.len()).f()); let subjects = subjects.get_subjects(); let pb = match progress { true => { @@ -126,12 +128,9 @@ pub fn psi( .enumerate() .for_each(|(j, mut element)| { let subject = subjects.get(i).unwrap(); - let likelihood = equation.estimate_likelihood( - subject, - support_points.row(j).to_vec().as_ref(), - error_model, - cache, - ); + let support_point = support_points.get(j); + let likelihood = + equation.estimate_likelihood(subject, &support_point, error_model, cache); element.fill(likelihood); if let Some(pb_ref) = pb.as_ref() { pb_ref.inc(1); diff --git a/src/simulator/mod.rs b/src/simulator/mod.rs index 62f7427..a10f8c8 100644 --- a/src/simulator/mod.rs +++ b/src/simulator/mod.rs @@ -1,11 +1,14 @@ pub mod equation; pub mod fitting; pub(crate) mod likelihood; +pub mod types; + use crate::{ data::{Covariates, Infusion}, error_model::ErrorModel, simulator::likelihood::{SubjectPredictions, ToPrediction}, }; +use types::*; use std::collections::HashMap; @@ -30,7 +33,7 @@ type M = nalgebra::DMatrix; /// dx[0] = -ka * x[0]; /// dx[1] = ka * x[0] - ke * x[1]; /// }; -pub type DiffEq = fn(&V, &V, T, &mut V, V, &Covariates); +pub type DiffEq = fn(&V, &SupportPoint, T, &mut V, V, &Covariates); /// This closure represents an Analytical solution of the model, see [analytical] module for examples. /// Params: @@ -40,7 +43,7 @@ pub type DiffEq = fn(&V, &V, T, &mut V, V, &Covariates); /// - rateiv: A vector of infusion rates at time t /// - cov: A reference to the covariates at time t; Use the [fetch_cov!] macro to extract the covariates /// TODO: Remove covariates. They are not used in the analytical solution -pub type AnalyticalEq = fn(&V, &V, T, V, &Covariates) -> V; +pub type AnalyticalEq = fn(&V, &SupportPoint, T, V, &Covariates) -> V; /// This closure represents the drift term of the model: /// Params: @@ -70,7 +73,7 @@ pub type Drift = DiffEq; /// - p: The parameters of the model; Use the [fetch_params!] macro to extract the parameters /// - d: A mutable reference to the diffusion term for each state variable /// (This vector should have the same length as the x, and dx vectors on the drift closure) -pub type Diffusion = fn(&V, &mut V); +pub type Diffusion = fn(&SupportPoint, &mut V); /// This closure represents the initial state of the system: /// Params: /// - p: The parameters of the model; Use the [fetch_params!] 
macro to extract the parameters
@@ -86,7 +89,7 @@ pub type Diffusion = fn(&V, &mut V);
 ///    x[0] = 500.0;
 ///    x[1] = 0.0;
 /// };
-pub type Init = fn(&V, T, &Covariates, &mut V);
+pub type Init = fn(&SupportPoint, T, &Covariates, &mut V);
 
 /// This closure represents the output equation of the model:
 /// Params:
@@ -102,7 +105,7 @@ pub type Init = fn(&V, T, &Covariates, &mut V);
 ///    fetch_params!(p, ka, ke, v);
 ///    y[0] = x[1] / v;
 /// };
-pub type Out = fn(&V, &V, T, &Covariates, &mut V);
+pub type Out = fn(&V, &SupportPoint, T, &Covariates, &mut V);
 
 /// This closure represents the secondary equation of the model, secondary equations are used to update
 /// the parameter values based on the covariates.
@@ -119,7 +122,7 @@ pub type Out = fn(&V, &V, T, &Covariates, &mut V);
 ///    ka = ka * wt;
 /// };
 /// ```
-pub type SecEq = fn(&mut V, T, &Covariates);
+pub type SecEq = fn(&mut SupportPoint, T, &Covariates);
 
 /// This closure represents the lag time of the model, the lag term delays the only the boluses going into
 /// an specific comparment.
@@ -138,7 +141,7 @@ pub type SecEq = fn(&mut V, T, &Covariates);
 /// ```
 /// This will lag the bolus going into the first compartment by tlag and the bolus going into the
 /// second compartment by 0.3
-pub type Lag = fn(&V) -> HashMap<usize, f64>;
+pub type Lag = fn(&SupportPoint) -> HashMap<usize, f64>;
 
 /// This closure represents the fraction absorbed (also called bioavailability or protein binding)
 /// of the model, the fa term is used to adjust the amount of drug that is absorbed into the system.
@@ -157,7 +160,7 @@ pub type Lag = fn(&V) -> HashMap<usize, f64>;
 /// ```
 /// This will adjust the amount of drug absorbed into the first compartment by fa and the amount of drug
 /// absorbed into the second compartment by 0.3
-pub type Fa = fn(&V) -> HashMap<usize, f64>;
+pub type Fa = fn(&SupportPoint) -> HashMap<usize, f64>;
 
 /// The number of states and output equations of the model
 /// The first element is the number of states and the second element is the number of output equations
diff --git a/src/simulator/types.rs b/src/simulator/types.rs
new file mode 100644
index 0000000..0c5b358
--- /dev/null
+++ b/src/simulator/types.rs
@@ -0,0 +1,107 @@
+use std::collections::hash_map::DefaultHasher;
+use std::collections::BTreeMap;
+use std::hash::{Hash, Hasher};
+
+use ndarray::Array2;
+
+#[derive(Clone, Debug)]
+pub struct Theta {
+    matrix: Array2<f64>,
+    parameters: Vec<String>,
+}
+
+impl Theta {
+    pub fn new(matrix: Array2<f64>, parameters: Vec<String>) -> Self {
+        Self { matrix, parameters }
+    }
+
+    pub fn len(&self) -> usize {
+        self.matrix.nrows()
+    }
+
+    pub fn get(&self, i: usize) -> SupportPoint {
+        let mut map = BTreeMap::new();
+        for (j, parameter) in self.parameters.iter().enumerate() {
+            map.insert(parameter.clone(), self.matrix[[i, j]]);
+        }
+        SupportPoint::new(map)
+    }
+}
+
+#[derive(Clone, Default, Debug)]
+pub struct SupportPoint {
+    map: BTreeMap<String, f64>,
+}
+#[macro_export]
+macro_rules! support_point {
+    ($($key:expr => $value:expr),* $(,)?) => {{
+        let mut map = BTreeMap::new();
+        $(
+            map.insert($key.to_string(), $value);
+        )*
+        SupportPoint::new(map)
+    }};
+}
+#[macro_export]
+macro_rules! fetch_params {
+    ($p:expr, $($param:ident),*) => {
+        $(
+            let $param = $p.get(stringify!($param).to_lowercase().as_str()).unwrap();
+        )*
+    };
+}
+
+impl SupportPoint {
+    /// Create a new support point from a map of parameter names to values
+    pub fn new(map: BTreeMap<String, f64>) -> Self {
+        Self { map }
+    }
+
+    pub fn from_vec(vec: Vec<f64>, parameters: Vec<impl Into<String>>) -> Self {
+        assert!(vec.len() == parameters.len());
+        let parameters: Vec<String> = parameters.into_iter().map(|s| s.into()).collect();
+        let mut map = BTreeMap::new();
+        for (i, parameter) in parameters.iter().enumerate() {
+            map.insert(parameter.clone(), vec[i]);
+        }
+        Self { map }
+    }
+
+    /// Get the value of a parameter in the support point
+    pub fn get(&self, key: &str) -> Option<f64> {
+        self.map.get(key).copied()
+    }
+
+    /// Insert a new parameter into the support point
+    pub fn insert(&mut self, key: String, value: f64) {
+        self.map.insert(key, value);
+    }
+
+    pub fn hash(&self) -> u64 {
+        let mut hasher = DefaultHasher::new();
+
+        self.map.iter().for_each(|(_, &value)| {
+            let bits = value.to_bits();
+            bits.hash(&mut hasher);
+        });
+
+        hasher.finish()
+    }
+
+    pub fn get_mut(&mut self, key: &str) -> Option<&mut f64> {
+        self.map.get_mut(key)
+    }
+
+    pub fn len(&self) -> usize {
+        self.map.len()
+    }
+
+    // NOTE: values are returned sorted by parameter name (BTreeMap order), not in declaration order
+    pub fn to_vec(&self) -> Vec<f64> {
+        self.map.values().copied().collect()
+    }
+
+    pub fn parameters(&self) -> Vec<String> {
+        self.map.keys().cloned().collect()
+    }
+}
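
Below is a minimal usage sketch (not part of the diff above) of the name-based `SupportPoint` API this patch introduces. It assumes the re-exports added to `lib.rs` in this diff (`SupportPoint`, the `support_point!` macro, and `BTreeMap` via `use pharmsol::*;`); the parameter names are illustrative only.

```rust
use pharmsol::*;

fn main() {
    // Build a support point by parameter name; the keys must match the names
    // used in fetch_params! inside the model closures (looked up lowercase).
    let spp = support_point!("ke" => 1.0, "v" => 50.0);

    // Values are looked up by name instead of by position.
    assert_eq!(spp.get("ke"), Some(1.0));
    assert_eq!(spp.get("missing"), None);

    // to_vec() returns values sorted by parameter name (BTreeMap order),
    // so here "ke" comes before "v".
    assert_eq!(spp.to_vec(), vec![1.0, 50.0]);

    // The same point can be rebuilt from a plain vector plus parameter names,
    // e.g. when reading rows of theta from a CSV file.
    let rebuilt = SupportPoint::from_vec(vec![1.0, 50.0], vec!["ke", "v"]);
    assert_eq!(rebuilt.parameters(), vec!["ke".to_string(), "v".to_string()]);
    assert_eq!(rebuilt.hash(), spp.hash());
}
```

Because `SupportPoint` is backed by a `BTreeMap`, caching keys derived from `hash()` are stable for equal points, but code that still relies on positional order (e.g. `to_vec()` feeding an ODE solver) implicitly depends on the alphabetical ordering of parameter names.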