Skip to content

Commit

Permalink
now have M and F, tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 14, 2024
1 parent 4758e79 commit 5fe7a47
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 182 deletions.
4 changes: 2 additions & 2 deletions examples/logistic.ds
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ dudt_i {
dydt = 0,
dzdt = 0,
}
F_i {
M_i {
dydt,
0,
}
G_i {
F_i {
(r * y) * (1 - (y / k)),
(2 * y) - z,
}
Expand Down
127 changes: 86 additions & 41 deletions src/discretise/discrete_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub struct DiscreteModel<'s> {
time_indep_defns: Vec<Tensor<'s>>,
time_dep_defns: Vec<Tensor<'s>>,
state_dep_defns: Vec<Tensor<'s>>,
dstate_dep_defns: Vec<Tensor<'s>>,
inputs: Vec<Tensor<'s>>,
state: Tensor<'s>,
state_dot: Tensor<'s>,
Expand Down Expand Up @@ -87,6 +88,7 @@ impl<'s> DiscreteModel<'s> {
time_indep_defns: Vec::new(),
time_dep_defns: Vec::new(),
state_dep_defns: Vec::new(),
dstate_dep_defns: Vec::new(),
inputs: Vec::new(),
state: Tensor::new_empty("u"),
state_dot: Tensor::new_empty("u_dot"),
Expand All @@ -95,41 +97,6 @@ impl<'s> DiscreteModel<'s> {
}
}

// residual = F(t, u, u_dot) - G(t, u)
// return a tensor equal to the residual
pub fn residual(&self) -> Tensor<'s> {
let mut residual = self.lhs.clone();
residual.set_name("residual");
let indices = self.lhs.indices().to_vec();
let lhs = Ast {
kind: AstKind::new_indexed_name("F", indices.clone()),
span: None,
};
let rhs = Ast {
kind: AstKind::new_indexed_name("G", indices),
span: None,
};
let name = "residual";
let indices = self.lhs.indices().to_vec();
let layout = self.lhs.layout_ptr().clone();
let elmts = vec![
TensorBlock::new(
None,
Index::from_vec(vec![0]),
indices.clone(),
self.lhs.layout_ptr().clone(),
self.lhs.layout_ptr().clone(),
Ast {
kind: AstKind::new_binop('-', lhs, rhs),
span: None,
},
)
];
Tensor::new(name, elmts, layout, indices)
}



fn build_array(array: &ast::Tensor<'s>, env: &mut Env) -> Option<Tensor<'s>> {
let rank = array.indices().len();
let mut elmts = Vec::new();
Expand Down Expand Up @@ -317,6 +284,14 @@ impl<'s> DiscreteModel<'s> {
if let Some(built) = Self::build_array(tensor, &mut env) {
ret.stop = Some(built);
}
// check that stop is not dependent on dudt
let stop = env.get("stop").unwrap();
if stop.is_dstatedt_dependent() {
env.errs_mut().push(ValidationError::new(
"stop must not be dependent on dudt".to_string(),
tensor_ast.span,
));
}
}
"out" => {
read_out = true;
Expand All @@ -329,13 +304,22 @@ impl<'s> DiscreteModel<'s> {
}
ret.out = built;
}
// check that out is not dependent on dudt
let out = env.get("out").unwrap();
if out.is_dstatedt_dependent() {
env.errs_mut().push(ValidationError::new(
"out must not be dependent on dudt".to_string(),
tensor_ast.span,
));
}
}
_name => {
if let Some(built) = Self::build_array(tensor, &mut env) {
let is_input = model.inputs.iter().any(|name| *name == _name);
if let Some(env_entry) = env.get(built.name()) {
let dependent_on_state = env_entry.is_state_dependent();
let dependent_on_time = env_entry.is_time_dependent();
let dependent_on_dudt = env_entry.is_dstatedt_dependent();
if is_input {
// inputs must be constants
if dependent_on_time || dependent_on_state {
Expand All @@ -347,10 +331,14 @@ impl<'s> DiscreteModel<'s> {
ret.inputs.push(built);
} else if !dependent_on_time {
ret.time_indep_defns.push(built);
} else if dependent_on_time && !dependent_on_state {
} else if dependent_on_time && !dependent_on_state && !dependent_on_dudt {
ret.time_dep_defns.push(built);
} else {
} else if dependent_on_state {
ret.state_dep_defns.push(built);
} else if dependent_on_dudt {
ret.dstate_dep_defns.push(built);
} else {
panic!("all the cases should be covered")
}
}
}
Expand Down Expand Up @@ -635,6 +623,7 @@ impl<'s> DiscreteModel<'s> {
let rhs = Tensor::new_no_layout("F", f_elmts, vec!['i']);
let name = model.name;
let stop = None;
let dstate_dep_defns = Vec::new();
DiscreteModel {
name,
lhs,
Expand All @@ -646,6 +635,7 @@ impl<'s> DiscreteModel<'s> {
time_indep_defns,
time_dep_defns,
state_dep_defns,
dstate_dep_defns,
is_algebraic,
stop,
}
Expand All @@ -664,6 +654,10 @@ impl<'s> DiscreteModel<'s> {
pub fn state_dep_defns(&self) -> &[Tensor] {
self.state_dep_defns.as_ref()
}

pub fn dstate_dep_defns(&self) -> &[Tensor] {
self.dstate_dep_defns.as_ref()
}

pub fn state(&self) -> &Tensor<'s> {
&self.state
Expand Down Expand Up @@ -749,6 +743,54 @@ mod tests {
assert_eq!(discrete.out.elmts()[2].expr().to_string(), "z");
println!("{}", discrete);
}

#[test]
fn tensor_classification() {
let text = "
in = [r, k, ]
r { 1, }
k { 1, }
z { 2 * r }
g { 2 * t }
u_i {
y = 1,
z = 0,
}
u2_i {
2 * y,
2 * z,
}
dudt_i {
dydt = 0,
dzdt = 0,
}
dudt2_i {
2 * dydt,
0,
}
M_i {
dydt,
0,
}
F_i {
(r * y) * (1 - (y / k)),
(2 * y) - z,
}
out_i {
y,
t,
z,
}
";
let model = parse_ds_string(text).unwrap();
let model = DiscreteModel::build("$name", &model).unwrap();
assert_eq!(model.inputs().iter().map(|t| t.name()).collect::<Vec<_>>(), ["r", "k"]);
assert_eq!(model.time_indep_defns().iter().map(|t| t.name()).collect::<Vec<_>>(), ["z"]);
assert_eq!(model.time_dep_defns().iter().map(|t| t.name()).collect::<Vec<_>>(), ["g"]);
assert_eq!(model.state_dep_defns().iter().map(|t| t.name()).collect::<Vec<_>>(), ["u2"]);
assert_eq!(model.dstate_dep_defns().iter().map(|t| t.name()).collect::<Vec<_>>(), ["dudt2"]);
assert_eq!(model.inputs().iter().map(|t| t.name()).collect::<Vec<_>>(), ["r", "k"]);
}

macro_rules! count {
() => (0usize);
Expand Down Expand Up @@ -924,7 +966,7 @@ mod tests {
1,
}
" ["F and u must have the same shape",],
error_f_dep_on_dudt: "
error_dep_on_dudt: "
u_i {
y = 1,
}
Expand All @@ -934,10 +976,13 @@ mod tests {
F_i {
dydt,
}
stop_i {
dydt,
}
out_i {
y,
dydt,
}
" ["G and u must have the same shape",],
" ["F must not be dependent on dudt", "stop must not be dependent on dudt", "out must not be dependent on dudt",],
error_m_dep_on_u: "
u_i {
y = 1,
Expand All @@ -957,7 +1002,7 @@ mod tests {
out_i {
y,
}
" ["G and u must have the same shape",],
" ["M must not be dependent on u",],


);
Expand Down
Loading

0 comments on commit 5fe7a47

Please sign in to comment.