Skip to content

Commit

Permalink
bug: turn off atomic add if not threaded (#49)
Browse files Browse the repository at this point in the history
* bug: turn off atomic add if not threaded

* remove unneeded test
  • Loading branch information
martinjrobins authored Jan 29, 2025
1 parent 0ea8e01 commit 6d55ab3
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 60 deletions.
59 changes: 0 additions & 59 deletions src/execution/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1702,65 +1702,6 @@ mod tests {
big_state_tridiag2: "b_ij { (0..100, 0..100): p + 2.0, (0..99, 1..100): 2.0, (1..100, 0..99): 1.0, (0, 99): 1.0, (99, 0): 2.0 } r_i { b_ij * u_j }" expect "r" vec![6.; 100]; vec![7.; 100]; vec![700.] ; vec![1.; 100]; vec![100.],
}

#[cfg(feature = "llvm")]
#[test]
fn test_bad_big_state_expr() {
let full_text = "
in = [p]
p { 1 }
u_i {
(0:50): x = p,
(50:100): y = p,
}
r_i { x_i }
F_i { u_i }";
let model = parse_ds_string(full_text).unwrap();
let discrete_model = DiscreteModel::build("test_bad_big_state_expr", &model).unwrap();
let compiler = Compiler::<crate::LlvmModule>::from_discrete_model(
&discrete_model,
CompilerMode::MultiThreaded(None),
)
.unwrap();
let mut data = compiler.get_new_data();
let inputs = vec![1.];
compiler.set_inputs(inputs.as_slice(), data.as_mut_slice());
let mut u0 = vec![0.; 100];
compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
let mut res = vec![0.; 100];
compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice());
let mut ddata = compiler.get_new_data();
let dr = compiler
.get_tensor_data_mut("r", ddata.as_mut_slice())
.unwrap();
let mut dinputs = vec![0.; 1];
dr.fill(1.);
let mut dres = vec![0.; 100];
let mut du0 = vec![0.; 100];
compiler.rhs_rgrad(
0.,
u0.as_slice(),
du0.as_mut_slice(),
data.as_mut_slice(),
ddata.as_mut_slice(),
res.as_slice(),
dres.as_mut_slice(),
);
assert_relative_eq!(du0[0..50], vec![1.; 50].as_slice());
compiler.set_u0_rgrad(
u0.as_mut_slice(),
du0.as_mut_slice(),
data.as_mut_slice(),
ddata.as_mut_slice(),
);
compiler.set_inputs_rgrad(
inputs.as_slice(),
dinputs.as_mut_slice(),
data.as_mut_slice(),
ddata.as_mut_slice(),
);
assert_relative_eq!(dinputs.as_slice(), vec![50.].as_slice());
}

#[test]
fn test_repeated_grad_cranelift() {
test_repeated_grad_common::<CraneliftModule>();
Expand Down
2 changes: 1 addition & 1 deletion src/execution/llvm/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3303,7 +3303,7 @@ impl<'ctx> CodeGen<'ctx> {
args_uncacheable.as_mut_ptr(),
args_uncacheable.len(),
std::ptr::null_mut(),
1,
if self.threaded { 1 } else { 0 },
)
};
if self.threaded {
Expand Down

0 comments on commit 6d55ab3

Please sign in to comment.