Skip to content

Commit

Permalink
Re-add error handling for neuron creation
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Jan 20, 2024
1 parent 19261d7 commit 30b094f
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 31 deletions.
7 changes: 4 additions & 3 deletions nblast-js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl NblastArena {
#[wasm_bindgen(js_name = "addPoints")]
pub fn add_points(&mut self, flat_points: &[f64]) -> JsResult<usize> {
let points = flat_to_array3(flat_points);
let neuron = Neuron::new(points, self.k);
let neuron = Neuron::new(points, self.k).map_err(JsError::new)?;
Ok(self.arena.add_neuron(neuron))
}

Expand All @@ -106,7 +106,8 @@ impl NblastArena {
})
.collect();
let points = flat_to_array3(flat_points);
let neuron = Neuron::new_with_tangents_alphas(points, tangents_alphas);
let neuron =
Neuron::new_with_tangents_alphas(points, tangents_alphas).map_err(JsError::new)?;
Ok(self.arena.add_neuron(neuron))
}

Expand Down Expand Up @@ -165,7 +166,7 @@ impl NblastArena {
#[wasm_bindgen(js_name = "makeFlatTangentsAlphas")]
pub fn make_flat_tangents_alphas(flat_points: &[f64], k: usize) -> JsResult<Float64Array> {
let points = flat_to_array3(flat_points);
let neuron = Neuron::new(points, k);
let neuron = Neuron::new(points, k).map_err(JsError::new)?;
let out = Float64Array::new_with_length(neuron.len() as u32);
for (idx, val) in neuron
.tangents()
Expand Down
16 changes: 12 additions & 4 deletions nblast-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ impl ArenaWrapper {
.map(|r| [r[0], r[1], r[2]])
.collect(),
self.k,
);
)
.map_err(|e| PyErr::new::<PyValueError, _>(e))?;
Ok(self.arena.add_neuron(neuron))
}

Expand Down Expand Up @@ -112,7 +113,8 @@ impl ArenaWrapper {
.map(|r| [r[0], r[1], r[2]])
.collect(),
tangents_alphas,
);
)
.map_err(|e| PyErr::new::<PyValueError, _>(e))?;
Ok(self.arena.add_neuron(neuron))
}

Expand Down Expand Up @@ -355,13 +357,19 @@ fn make_neurons_many(
pool.install(|| {
points_list
.into_par_iter()
.map(|ps| Neuron::new(ps.into_iter().map(|p| [p[0], p[1], p[2]]).collect(), k))
.map(|ps| {
Neuron::new(ps.into_iter().map(|p| [p[0], p[1], p[2]]).collect(), k)
.expect("Invalid neuron")
})
.collect()
})
} else {
points_list
.into_iter()
.map(|ps| Neuron::new(ps.into_iter().map(|p| [p[0], p[1], p[2]]).collect(), k))
.map(|ps| {
Neuron::new(ps.into_iter().map(|p| [p[0], p[1], p[2]]).collect(), k)
.expect("invalid neuron")
})
.collect()
}
}
Expand Down
12 changes: 6 additions & 6 deletions nblast-rs/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,16 +286,16 @@ fn bench_query_nabo(b: &mut Bencher) {

fn bench_query_kiddo(b: &mut Bencher) {
let score_fn = get_score_fn();
let query = KiddoTangentsAlphas::new(read_points(NAMES[0]), N_NEIGHBORS);
let target = KiddoTangentsAlphas::new(read_points(NAMES[1]), N_NEIGHBORS);
let query = KiddoTangentsAlphas::new(read_points(NAMES[0]), N_NEIGHBORS).unwrap();
let target = KiddoTangentsAlphas::new(read_points(NAMES[1]), N_NEIGHBORS).unwrap();

b.iter(|| query.query(&target, false, &score_fn))
}

fn bench_query_exact_kiddo(b: &mut Bencher) {
let score_fn = get_score_fn();
let query = ExactKiddoTangentsAlphas::new(read_points(NAMES[0]), N_NEIGHBORS);
let target = ExactKiddoTangentsAlphas::new(read_points(NAMES[1]), N_NEIGHBORS);
let query = ExactKiddoTangentsAlphas::new(read_points(NAMES[0]), N_NEIGHBORS).unwrap();
let target = ExactKiddoTangentsAlphas::new(read_points(NAMES[1]), N_NEIGHBORS).unwrap();

b.iter(|| query.query(&target, false, &score_fn))
}
Expand Down Expand Up @@ -408,7 +408,7 @@ fn bench_all_to_all_serial_kiddo(b: &mut Bencher) {
let mut idxs = Vec::new();
for name in NAMES.iter() {
let points = read_points(name);
idxs.push(arena.add_neuron(KiddoTangentsAlphas::new(points, N_NEIGHBORS)));
idxs.push(arena.add_neuron(KiddoTangentsAlphas::new(points, N_NEIGHBORS).unwrap()));
}

b.iter(|| arena.queries_targets(&idxs, &idxs, false, &None, None));
Expand All @@ -419,7 +419,7 @@ fn bench_all_to_all_serial_exact_kiddo(b: &mut Bencher) {
let mut idxs = Vec::new();
for name in NAMES.iter() {
let points = read_points(name);
idxs.push(arena.add_neuron(ExactKiddoTangentsAlphas::new(points, N_NEIGHBORS)));
idxs.push(arena.add_neuron(ExactKiddoTangentsAlphas::new(points, N_NEIGHBORS).unwrap()));
}

b.iter(|| arena.queries_targets(&idxs, &idxs, false, &None, None));
Expand Down
12 changes: 7 additions & 5 deletions nblast-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@
//! // Add some neurons built from points and a neighborhood size,
//! // returning their indices in the arena
//! let idx1 = arena.add_neuron(
//! Neuron::new(random_points(6, &mut rng), 5))
//! Neuron::new(random_points(6, &mut rng), 5).expect("cannot construct neuron")
//! );
//! let idx2 = arena.add_neuron(
//! Neuron::new(random_points(8, &mut rng), 5))
//! Neuron::new(random_points(8, &mut rng), 5).expect("cannot construct neuron")
//! );
//!
//! // get a raw score (not normalized by self-hit, no symmetry)
Expand Down Expand Up @@ -966,7 +966,7 @@ mod test {
#[test]
fn test_neuron() {
let (points, exp_tan, _exp_alpha) = tangent_data();
let tgt = Neuron::new(points, N_NEIGHBORS);
let tgt = Neuron::new(points, N_NEIGHBORS).unwrap();
assert!(equivalent_tangents(&tgt.tangents()[0], &exp_tan));
// tested from the python side
// assert_close(tgt.alphas()[0], exp_alpha);
Expand Down Expand Up @@ -1089,8 +1089,10 @@ mod test {
RangeTable::new_from_bins(vec![dist_thresholds, dot_thresholds], cells).unwrap(),
);

let query = Neuron::new(make_points(&[0., 0., 0.], &[1., 0., 0.], 10), N_NEIGHBORS);
let target = Neuron::new(make_points(&[0.5, 0., 0.], &[1.1, 0., 0.], 10), N_NEIGHBORS);
let query =
Neuron::new(make_points(&[0., 0., 0.], &[1., 0., 0.], 10), N_NEIGHBORS).unwrap();
let target =
Neuron::new(make_points(&[0.5, 0., 0.], &[1.1, 0., 0.], 10), N_NEIGHBORS).unwrap();

let mut arena = NblastArena::new(score_calc, false);
let q_idx = arena.add_neuron(query);
Expand Down
29 changes: 16 additions & 13 deletions nblast-rs/src/neurons/kiddo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ pub struct KiddoTangentsAlphas {
impl KiddoTangentsAlphas {
/// Calculate tangents from constructed R*-tree.
/// `k` is the number of points to calculate each tangent with.
pub fn new(points: Vec<Point3>, k: usize) -> Self {
pub fn new(points: Vec<Point3>, k: usize) -> Result<Self, &'static str> {
let tree: KdTree = points.as_slice().into();
if points.len() < k {
return Err("Not enough points to calculate neighborhood");
}
let points_tangents_alphas = points
.iter()
.map(|p| {
Expand All @@ -32,22 +35,25 @@ impl KiddoTangentsAlphas {
})
.collect();

Self {
Ok(Self {
tree,
points_tangents_alphas,
}
})
}

/// Use pre-calculated tangents.
pub fn new_with_tangents_alphas(
points: Vec<Point3>,
tangents_alphas: Vec<TangentAlpha>,
) -> Self {
) -> Result<Self, &'static str> {
if points.len() != tangents_alphas.len() {
return Err("Mismatch in points and tangents_alphas length");
}
let tree: KdTree = points.as_slice().into();
Self {
Ok(Self {
tree,
points_tangents_alphas: points.into_iter().zip(tangents_alphas).collect(),
}
})
}

fn nearest_match_dist_dot_inner(
Expand Down Expand Up @@ -159,19 +165,16 @@ pub struct ExactKiddoTangentsAlphas(KiddoTangentsAlphas);
impl ExactKiddoTangentsAlphas {
/// Calculate tangents from constructed R*-tree.
/// `k` is the number of points to calculate each tangent with.
pub fn new(points: Vec<Point3>, k: usize) -> Self {
Self(KiddoTangentsAlphas::new(points, k))
pub fn new(points: Vec<Point3>, k: usize) -> Result<Self, &'static str> {
KiddoTangentsAlphas::new(points, k).map(Self)
}

/// Use pre-calculated tangents.
pub fn new_with_tangents_alphas(
points: Vec<Point3>,
tangents_alphas: Vec<TangentAlpha>,
) -> Self {
Self(KiddoTangentsAlphas::new_with_tangents_alphas(
points,
tangents_alphas,
))
) -> Result<Self, &'static str> {
KiddoTangentsAlphas::new_with_tangents_alphas(points, tangents_alphas).map(Self)
}
}

Expand Down

0 comments on commit 30b094f

Please sign in to comment.