Skip to content

Commit 4417d32

Browse files
ygerpre-commit-ci[bot]samuelgarcia
authored
Speedup template_similarity (#4211)
* Caching the norms * Use sparsity internaly while getting templates in analyzer * WIP * Patch for numba * syntax * Patch for numba * Patch for numba * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleaning * Cleaning * Speedup merging * Message --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Garcia Samuel <[email protected]>
1 parent f9f1dca commit 4417d32

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

src/spikeinterface/benchmark/benchmark_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def compute_results(self, case_keys=None, verbose=False, **result_params):
431431
if verbose:
432432
print("### Compute result", key, "###")
433433
benchmark = self.benchmarks[key]
434-
assert benchmark is not None
434+
assert benchmark is not None, f"Benchmkark for key {key} has not been run yet!"
435435
benchmark.compute_result(**result_params)
436436
benchmark.save_result(self.folder / "results" / self.key_to_str(key))
437437

src/spikeinterface/core/template_tools.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_in
4444
raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates")
4545
else:
4646
raise ValueError("Input should be Templates or SortingAnalyzer")
47-
4847
return templates_array
4948

5049

src/spikeinterface/postprocessing/template_similarity.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,16 @@ def _merge_extension_data(
110110
n = all_new_unit_ids.size
111111
similarity = np.zeros((n, n), dtype=old_similarity.dtype)
112112

113+
local_mask = ~np.isin(all_new_unit_ids, new_unit_ids)
114+
sub_units_ids = all_new_unit_ids[local_mask]
115+
sub_units_inds = np.flatnonzero(local_mask)
116+
old_units_inds = self.sorting_analyzer.sorting.ids_to_indices(sub_units_ids)
117+
113118
# copy old similarity
114-
for unit_ind1, unit_id1 in enumerate(all_new_unit_ids):
115-
if unit_id1 not in new_unit_ids:
116-
old_ind1 = self.sorting_analyzer.sorting.id_to_index(unit_id1)
117-
for unit_ind2, unit_id2 in enumerate(all_new_unit_ids):
118-
if unit_id2 not in new_unit_ids:
119-
old_ind2 = self.sorting_analyzer.sorting.id_to_index(unit_id2)
120-
s = self.data["similarity"][old_ind1, old_ind2]
121-
similarity[unit_ind1, unit_ind2] = s
122-
similarity[unit_ind1, unit_ind2] = s
119+
for old_ind1, unit_ind1 in zip(old_units_inds, sub_units_inds):
120+
s = self.data["similarity"][old_ind1, old_units_inds]
121+
similarity[unit_ind1, sub_units_inds] = s
122+
similarity[sub_units_inds, unit_ind1] = s
123123

124124
# insert new similarity both way
125125
for unit_ind, unit_id in enumerate(all_new_unit_ids):
@@ -319,9 +319,14 @@ def _compute_similarity_matrix_numba(
319319
sparsity_mask[i, :], other_sparsity_mask
320320
) # shape (other_num_templates, num_channels)
321321
elif support == "union":
322+
connected_mask = np.logical_and(sparsity_mask[i, :], other_sparsity_mask)
323+
not_connected_mask = np.sum(connected_mask, axis=1) == 0
322324
local_mask = np.logical_or(
323325
sparsity_mask[i, :], other_sparsity_mask
324326
) # shape (other_num_templates, num_channels)
327+
for local_i in np.flatnonzero(not_connected_mask):
328+
local_mask[local_i] = False
329+
325330
elif support == "dense":
326331
local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)
327332

@@ -386,7 +391,11 @@ def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsi
386391
if support == "intersection":
387392
mask = np.logical_and(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels)
388393
elif support == "union":
394+
connected_mask = np.logical_and(sparsity[template_index, :], other_sparsity)
395+
not_connected_mask = np.sum(connected_mask, axis=1) == 0
389396
mask = np.logical_or(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels)
397+
for i in np.flatnonzero(not_connected_mask):
398+
mask[i] = False
390399
elif support == "dense":
391400
mask = np.ones(other_sparsity.shape, dtype=bool)
392401
return mask

0 commit comments

Comments
 (0)