Skip to content

Commit

Permalink
feat: split tensors into constant and input dependent (#48)
Browse files Browse the repository at this point in the history
* feat: time_indep_defn split to constant_dfns and input_dep_defns

* working through bugs

* fix bugs, turn atomic add back on

* update enzyme

* turn off llvm multithreaded tests for macos
  • Loading branch information
martinjrobins authored Jan 29, 2025
1 parent 4996619 commit 6d5bf33
Show file tree
Hide file tree
Showing 10 changed files with 440 additions and 90 deletions.
90 changes: 72 additions & 18 deletions src/discretise/discrete_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ pub struct DiscreteModel<'s> {
lhs: Option<Tensor<'s>>,
rhs: Tensor<'s>,
out: Option<Tensor<'s>>,
time_indep_defns: Vec<Tensor<'s>>,
constant_defns: Vec<Tensor<'s>>,
input_dep_defns: Vec<Tensor<'s>>,
time_dep_defns: Vec<Tensor<'s>>,
state_dep_defns: Vec<Tensor<'s>>,
dstate_dep_defns: Vec<Tensor<'s>>,
Expand All @@ -54,7 +55,10 @@ impl fmt::Display for DiscreteModel<'_> {
writeln!(f, "{}", input)?;
}
}
for defn in &self.time_indep_defns {
for defn in &self.constant_defns {
writeln!(f, "{}", defn)?;
}
for defn in &self.input_dep_defns {
writeln!(f, "{}", defn)?;
}
for defn in &self.time_dep_defns {
Expand Down Expand Up @@ -90,7 +94,8 @@ impl<'s> DiscreteModel<'s> {
lhs: None,
rhs: Tensor::new_empty("F"),
out: None,
time_indep_defns: Vec::new(),
constant_defns: Vec::new(),
input_dep_defns: Vec::new(),
time_dep_defns: Vec::new(),
state_dep_defns: Vec::new(),
dstate_dep_defns: Vec::new(),
Expand Down Expand Up @@ -258,7 +263,7 @@ impl<'s> DiscreteModel<'s> {
}

pub fn build(name: &'s str, model: &'s ast::DsModel) -> Result<Self, ValidationErrors> {
let mut env = Env::default();
let mut env = Env::new(model.inputs.as_slice());
let mut ret = Self::new(name);
let mut read_state = false;
let mut span_f = None;
Expand Down Expand Up @@ -374,6 +379,7 @@ impl<'s> DiscreteModel<'s> {
let dependent_on_state = env_entry.is_state_dependent();
let dependent_on_time = env_entry.is_time_dependent();
let dependent_on_dudt = env_entry.is_dstatedt_dependent();
let dependent_on_input = env_entry.is_input_dependent();
if is_input {
// inputs must be constants
if dependent_on_time || dependent_on_state {
Expand All @@ -383,12 +389,11 @@ impl<'s> DiscreteModel<'s> {
));
}
ret.inputs.push(built);
} else if !dependent_on_time && !dependent_on_input {
ret.constant_defns.push(built);
} else if !dependent_on_time {
ret.time_indep_defns.push(built);
} else if dependent_on_time
&& !dependent_on_state
&& !dependent_on_dudt
{
ret.input_dep_defns.push(built);
} else if !dependent_on_state && !dependent_on_dudt {
ret.time_dep_defns.push(built);
} else if dependent_on_state {
ret.state_dep_defns.push(built);
Expand Down Expand Up @@ -670,7 +675,7 @@ impl<'s> DiscreteModel<'s> {
.iter()
.map(DiscreteModel::dfn_to_array)
.collect();
let time_indep_defns = const_defns
let constant_defns = const_defns
.iter()
.map(DiscreteModel::dfn_to_array)
.collect();
Expand All @@ -687,7 +692,8 @@ impl<'s> DiscreteModel<'s> {
state,
state_dot: Some(state_dot),
out: Some(out_array),
time_indep_defns,
constant_defns,
input_dep_defns: Vec::new(), // todo: need to implement
time_dep_defns,
state_dep_defns,
dstate_dep_defns,
Expand All @@ -700,9 +706,14 @@ impl<'s> DiscreteModel<'s> {
self.inputs.as_ref()
}

pub fn time_indep_defns(&self) -> &[Tensor] {
self.time_indep_defns.as_ref()
pub fn constant_defns(&self) -> &[Tensor] {
self.constant_defns.as_ref()
}

pub fn input_dep_defns(&self) -> &[Tensor] {
self.input_dep_defns.as_ref()
}

pub fn time_dep_defns(&self) -> &[Tensor] {
self.time_dep_defns.as_ref()
}
Expand Down Expand Up @@ -772,7 +783,8 @@ mod tests {
let model_info = ModelInfo::build("circuit", &models).unwrap();
assert_eq!(model_info.errors.len(), 0);
let discrete = DiscreteModel::from(&model_info);
assert_eq!(discrete.time_indep_defns.len(), 0);
assert_eq!(discrete.input_dep_defns().len(), 0);
assert_eq!(discrete.constant_defns().len(), 0);
assert_eq!(discrete.time_dep_defns.len(), 1);
assert_eq!(discrete.time_dep_defns[0].name(), "inputVoltage");
assert_eq!(discrete.state_dep_defns.len(), 1);
Expand Down Expand Up @@ -849,7 +861,15 @@ mod tests {
);
assert_eq!(
model
.time_indep_defns()
.constant_defns()
.iter()
.map(|t| t.name())
.collect::<Vec<_>>(),
Vec::<&str>::new()
);
assert_eq!(
model
.input_dep_defns()
.iter()
.map(|t| t.name())
.collect::<Vec<_>>(),
Expand Down Expand Up @@ -1163,7 +1183,7 @@ mod tests {
let model = parse_ds_string(model_text.as_str()).unwrap();
match DiscreteModel::build("$name", &model) {
Ok(model) => {
let tensor = model.time_indep_defns.iter().chain(model.time_dep_defns.iter()).find(|t| t.name() == $tensor_name).unwrap();
let tensor = model.constant_defns().iter().chain(model.time_dep_defns.iter()).find(|t| t.name() == $tensor_name).unwrap();
let tensor_string = format!("{}", tensor).chars().filter(|c| !c.is_whitespace()).collect::<String>();
let tensor_check_string = $tensor_string.chars().filter(|c| !c.is_whitespace()).collect::<String>();
assert_eq!(tensor_string, tensor_check_string);
Expand Down Expand Up @@ -1248,6 +1268,40 @@ mod tests {
assert!(model.out().is_none());
}

#[test]
fn test_constants_and_input_dep() {
let text = "
in = [r]
r { 1, }
k { 1, }
r2 { 2 * r }
u_i {
y = k,
}
F_i {
r * y,
}
";
let model = parse_ds_string(text).unwrap();
let model = DiscreteModel::build("$name", &model).unwrap();
assert_eq!(
model
.constant_defns()
.iter()
.map(|t| t.name())
.collect::<Vec<_>>(),
["k"]
);
assert_eq!(
model
.input_dep_defns()
.iter()
.map(|t| t.name())
.collect::<Vec<_>>(),
["r2"]
);
}

#[test]
fn test_sparse_layout() {
let text = "
Expand All @@ -1272,12 +1326,12 @@ mod tests {
let model = parse_ds_string(text).unwrap();
let model = DiscreteModel::build("$name", &model).unwrap();
let r = model
.time_indep_defns()
.constant_defns()
.iter()
.find(|t| t.name() == "r")
.unwrap();
let b = model
.time_indep_defns()
.constant_defns()
.iter()
.find(|t| t.name() == "b")
.unwrap();
Expand Down
23 changes: 18 additions & 5 deletions src/discretise/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub struct EnvVar {
is_time_dependent: bool,
is_state_dependent: bool,
is_dstatedt_dependent: bool,
is_input_dependent: bool,
is_algebraic: bool,
}

Expand All @@ -34,6 +35,10 @@ impl EnvVar {
self.is_algebraic
}

pub fn is_input_dependent(&self) -> bool {
self.is_input_dependent
}

pub fn layout(&self) -> &Layout {
self.layout.as_ref()
}
Expand All @@ -43,10 +48,11 @@ pub struct Env {
current_span: Option<StringSpan>,
errs: ValidationErrors,
vars: HashMap<String, EnvVar>,
inputs: Vec<String>,
}

impl Default for Env {
fn default() -> Self {
impl Env {
pub fn new(inputs: &[&str]) -> Self {
let mut vars = HashMap::new();
vars.insert(
"t".to_string(),
Expand All @@ -55,18 +61,17 @@ impl Default for Env {
is_time_dependent: true,
is_state_dependent: false,
is_dstatedt_dependent: false,
is_input_dependent: false,
is_algebraic: true,
},
);
Env {
errs: ValidationErrors::default(),
vars,
current_span: None,
inputs: inputs.iter().map(|s| s.to_string()).collect(),
}
}
}

impl Env {
pub fn is_tensor_time_dependent(&self, tensor: &Tensor) -> bool {
if tensor.name() == "u" || tensor.name() == "dudt" {
return true;
Expand All @@ -83,6 +88,12 @@ impl Env {
self.is_tensor_dependent_on(tensor, "u")
}

pub fn is_tensor_input_dependent(&self, tensor: &Tensor) -> bool {
self.inputs
.iter()
.any(|input| self.is_tensor_dependent_on(tensor, input))
}

pub fn is_tensor_dstatedt_dependent(&self, tensor: &Tensor) -> bool {
self.is_tensor_dependent_on(tensor, "dudt")
}
Expand Down Expand Up @@ -112,6 +123,7 @@ impl Env {
is_time_dependent: self.is_tensor_time_dependent(var),
is_state_dependent: self.is_tensor_state_dependent(var),
is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var),
is_input_dependent: self.is_tensor_input_dependent(var),
},
);
}
Expand All @@ -125,6 +137,7 @@ impl Env {
is_time_dependent: self.is_tensor_time_dependent(var),
is_state_dependent: self.is_tensor_state_dependent(var),
is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var),
is_input_dependent: self.is_tensor_input_dependent(var),
},
);
}
Expand Down
Loading

0 comments on commit 6d5bf33

Please sign in to comment.