77
88from scipy import stats
99
10- # TODO: spike_times -> spike_indexes
10+ # TODO: spike_times -> spike_indices
1111"""
1212Notes
1313-----
1414- not everything is used for current purposes
1515- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
1616"""
1717
18+ ########################################################################################################################
19+ # Get Spike Data
20+ ########################################################################################################################
21+
1822
1923def compute_spike_amplitude_and_depth (
2024 sorter_output : str | Path ,
2125 localised_spikes_only ,
2226 exclude_noise ,
2327 gain : float | None = None ,
24- localised_spikes_channel_cutoff : int = None , # TODO
28+ localised_spikes_channel_cutoff : int = None ,
2529) -> tuple [np .ndarray , ...]:
2630 """
2731 Compute the amplitude and depth of all detected spikes from the kilosort output.
@@ -46,8 +50,8 @@ def compute_spike_amplitude_and_depth(
4650
4751 Returns
4852 -------
49- spike_indexes : np.ndarray
50- (num_spikes,) array of spike indexes .
53+ spike_indices : np.ndarray
54+ (num_spikes,) array of spike indices .
5155 spike_amplitudes : np.ndarray
5256 (num_spikes,) array of corresponding spike amplitudes.
5357 spike_depths : np.ndarray
@@ -66,7 +70,7 @@ def compute_spike_amplitude_and_depth(
6670 if isinstance (sorter_output , str ):
6771 sorter_output = Path (sorter_output )
6872
69- params = _load_ks_dir (sorter_output , load_pcs = True , exclude_noise = exclude_noise )
73+ params = load_ks_dir (sorter_output , load_pcs = True , exclude_noise = exclude_noise )
7074
7175 if localised_spikes_only :
7276 localised_templates = []
@@ -81,10 +85,52 @@ def compute_spike_amplitude_and_depth(
8185
8286 localised_template_by_spike = np .isin (params ["spike_templates" ], localised_templates )
8387
84- _strip_spikes (params , localised_template_by_spike )
88+ params ["spike_templates" ] = params ["spike_templates" ][localised_template_by_spike ]
89+ params ["spike_indices" ] = params ["spike_indices" ][localised_template_by_spike ]
90+ params ["spike_clusters" ] = params ["spike_clusters" ][localised_template_by_spike ]
91+ params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][localised_template_by_spike ]
92+ params ["pc_features" ] = params ["pc_features" ][localised_template_by_spike ]
93+
94+ spike_locations , spike_max_sites = _get_locations_from_pc_features (params )
95+
96+ # Amplitude is calculated for each spike as the template amplitude
97+ # multiplied by the `template_scaling_amplitudes`.
98+ template_amplitudes_unscaled , * _ = get_unwhite_template_info (
99+ params ["templates" ],
100+ params ["whitening_matrix_inv" ],
101+ params ["channel_positions" ],
102+ )
103+ spike_amplitudes = template_amplitudes_unscaled [params ["spike_templates" ]] * params ["temp_scaling_amplitudes" ]
104+
105+ if gain is not None :
106+ spike_amplitudes *= gain
85107
108+ compute_template_amplitudes_from_spikes (params ["templates" ], params ["spike_templates" ], spike_amplitudes )
109+
110+ if localised_spikes_only :
111+ # Interpolate the channel ids to location.
112+ # Remove spikes > 5 um from average position
113+ # Above we already removed non-localized templates, but that on its own is insufficient.
114+ # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
115+ # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
116+ # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
117+ # 3) just use depth. Probably go for that. check with others.
118+ spike_depths = spike_locations [:, 1 ]
119+ b = stats .linregress (spike_depths , spike_max_sites ).slope
120+ i = np .abs (spike_max_sites - b * spike_depths ) <= 5
121+
122+ params ["spike_indices" ] = params ["spike_indices" ][i ]
123+ spike_amplitudes = spike_amplitudes [i ]
124+ spike_locations = spike_locations [i , :]
125+ spike_max_sites = spike_max_sites [i ]
126+
127+ return params ["spike_indices" ], spike_amplitudes , spike_locations , spike_max_sites
128+
129+
130+ def _get_locations_from_pc_features (params ):
131+ """ """
86132 # Compute spike depths
87- pc_features = params ["pc_features" ][:, 0 , :] # Do this compute
133+ pc_features = params ["pc_features" ][:, 0 , :]
88134 pc_features [pc_features < 0 ] = 0
89135
90136 # Some spikes do not load at all onto the first PC. To avoid biasing the
@@ -109,58 +155,28 @@ def compute_spike_amplitude_and_depth(
109155 "to extend this code section to handle more components."
110156 )
111157
112- # Get the channel indexes corresponding to the 32 channels from the PC.
158+ # Get the channel indices corresponding to the 32 channels from the PC.
113159 spike_features_indices = params ["pc_features_indices" ][params ["spike_templates" ], :]
114160
115161 # Compute the spike locations as the center of mass of the PC scores
116162 spike_feature_coords = params ["channel_positions" ][spike_features_indices , :]
117- norm_weights = pc_features / np .sum (pc_features , axis = 1 )[:, np .newaxis ] # TOOD: see why they use square
163+ norm_weights = (
164+ pc_features / np .sum (pc_features , axis = 1 )[:, np .newaxis ]
165+ ) # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI.
118166 spike_locations = spike_feature_coords * norm_weights [:, :, np .newaxis ]
119167 spike_locations = np .sum (spike_locations , axis = 1 )
120168
121169 # TODO: now max site per spike is computed from PCs, not as the channel max site as previous
122- spike_sites = spike_features_indices [np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )]
170+ spike_max_sites = spike_features_indices [
171+ np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )
172+ ]
123173
124- # Amplitude is calculated for each spike as the template amplitude
125- # multiplied by the `template_scaling_amplitudes`.
126- template_amplitudes_unscaled , * _ = get_unwhite_template_info (
127- params ["templates" ],
128- params ["whitening_matrix_inv" ],
129- params ["channel_positions" ],
130- )
131- spike_amplitudes = template_amplitudes_unscaled [params ["spike_templates" ]] * params ["temp_scaling_amplitudes" ]
132-
133- if gain is not None :
134- spike_amplitudes *= gain
135-
136- if localised_spikes_only :
137- # Interpolate the channel ids to location.
138- # Remove spikes > 5 um from average position
139- # Above we already removed non-localized templates, but that on its own is insufficient.
140- # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
141- # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
142- # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
143- # 3) just use depth. Probably go for that. check with others.
144- spike_depths = spike_locations [:, 1 ]
145- b = stats .linregress (spike_depths , spike_sites ).slope
146- i = np .abs (spike_sites - b * spike_depths ) <= 5 # TODO: need to expose this
174+ return spike_locations , spike_max_sites
147175
148- params ["spike_indexes" ] = params ["spike_indexes" ][i ]
149- spike_amplitudes = spike_amplitudes [i ]
150- spike_locations = spike_locations [i , :]
151176
152- return params ["spike_indexes" ], spike_amplitudes , spike_locations , spike_sites
153-
154-
155- def _strip_spikes_in_place (params , indices ):
156- """ """
157- params ["spike_templates" ] = params ["spike_templates" ][
158- indices
159- ] # TODO: make an function for this. because we do this a lot
160- params ["spike_indexes" ] = params ["spike_indexes" ][indices ]
161- params ["spike_clusters" ] = params ["spike_clusters" ][indices ]
162- params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][indices ]
163- params ["pc_features" ] = params ["pc_features" ][indices ] # TODO: be conciststetn! change indees to indices
177+ ########################################################################################################################
178+ # Get Template Data
179+ ########################################################################################################################
164180
165181
166182def get_unwhite_template_info (
@@ -213,7 +229,7 @@ def get_unwhite_template_info(
213229
214230 template_amplitudes_unscaled = np .max (template_amplitudes_per_channel , axis = 1 )
215231
216- # Zero any small channel amplitudes
232+ # Zero any small channel amplitudes TODO: removed this.
217233 # threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
218234 # template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
219235
@@ -253,9 +269,11 @@ def get_unwhite_template_info(
253269 )
254270
255271
256- def compute_template_amplitudes_from_spikes ():
257- # Take the average of all spike amplitudes to get actual template amplitudes
258- # (since tempScalingAmps are equal mean for all templates)
272+ def compute_template_amplitudes_from_spikes (templates , spike_templates , spike_amplitudes ):
273+ """
274+ Take the average of all spike amplitudes to get actual template amplitudes
275+ (since tempScalingAmps are equal mean for all templates)
276+ """
259277 num_indices = templates .shape [0 ]
260278 sum_per_index = np .zeros (num_indices , dtype = np .float64 )
261279 np .add .at (sum_per_index , spike_templates , spike_amplitudes )
@@ -264,7 +282,12 @@ def compute_template_amplitudes_from_spikes():
264282 return template_amplitudes
265283
266284
267- def _load_ks_dir (sorter_output : Path , exclude_noise : bool = True , load_pcs : bool = False ) -> dict :
285+ ########################################################################################################################
286+ # Load Parameters from KS Directory
287+ ########################################################################################################################
288+
289+
290+ def load_ks_dir (sorter_output : Path , exclude_noise : bool = True , load_pcs : bool = False ) -> dict :
268291 """
269292 Loads the output of Kilosort into a `params` dict.
270293
@@ -300,7 +323,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
300323
301324 params = read_python (sorter_output / "params.py" )
302325
303- spike_indexes = np .load (sorter_output / "spike_times.npy" )
326+ spike_indices = np .load (sorter_output / "spike_times.npy" )
304327 spike_templates = np .load (sorter_output / "spike_templates.npy" )
305328
306329 if (clusters_path := sorter_output / "spike_clusters.csv" ).is_dir ():
@@ -328,7 +351,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
328351 noise_cluster_ids = cluster_ids [cluster_groups == 0 ]
329352 not_noise_clusters_by_spike = ~ np .isin (spike_clusters .ravel (), noise_cluster_ids )
330353
331- spike_indexes = spike_indexes [not_noise_clusters_by_spike ]
354+ spike_indices = spike_indices [not_noise_clusters_by_spike ]
332355 spike_templates = spike_templates [not_noise_clusters_by_spike ]
333356 temp_scaling_amplitudes = temp_scaling_amplitudes [not_noise_clusters_by_spike ]
334357
@@ -343,7 +366,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
343366 cluster_groups = 3 * np .ones (cluster_ids .size )
344367
345368 new_params = {
346- "spike_indexes " : spike_indexes .squeeze (),
369+ "spike_indices " : spike_indices .squeeze (),
347370 "spike_templates" : spike_templates .squeeze (),
348371 "spike_clusters" : spike_clusters .squeeze (),
349372 "pc_features" : pc_features ,
0 commit comments