@@ -22,6 +22,7 @@ def generate_session_displacement_recordings(
2222 recording_shifts = ((0 , 0 ), (0 , 25 ), (0 , 50 )),
2323 non_rigid_gradient = None ,
2424 recording_amplitude_scalings = None ,
25+ shift_units_outside_probe = False ,
2526 sampling_frequency = 30000.0 ,
2627 probe_name = "Neuropixel-128" ,
2728 generate_probe_kwargs = None ,
@@ -87,8 +88,16 @@ def generate_session_displacement_recordings(
8788 "scalings" - a list of numpy arrays, one for each recording, with
8889 each entry an array of length num_units holding the unit scalings.
8990 e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
91+ shift_units_outside_probe : bool
92+ By default (`False`), when units are shifted across sessions, new units are
93+ not introduced into the recording (e.g. the region in which units
94+ have been shifted out of is left at baseline level). In reality,
95+ when the probe shifts new units from outside the original recorded
96+ region are shifted into the recording. When `True`, new units
97+ are shifted into the generated recording.
9098 generate_sorting_kwargs : dict
9199 Only `firing_rates` and `refractory_period_ms` are expected if passed.
100+
92101 All other parameters are used as in from `generate_drifting_recording()`.
93102
94103 Returns
@@ -105,7 +114,6 @@ def generate_session_displacement_recordings(
105114 A list (length num records) of (num_units, num_samples, num_channels)
106115 arrays of templates that have been shifted.
107116
108-
109117 Notes
110118 -----
111119 It is important to consider what unit properties are maintained
@@ -141,12 +149,28 @@ def generate_session_displacement_recordings(
141149
142150 # Fix generate template kwargs, so they are the same for every created recording.
143151 # Also fix unit firing rates across recordings.
144- generate_templates_kwargs = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed )
152+ fixed_generate_templates_kwargs = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed )
145153
146154 fixed_firing_rates = _ensure_firing_rates (generate_sorting_kwargs ["firing_rates" ], num_units , seed )
147- generate_sorting_kwargs ["firing_rates" ] = fixed_firing_rates
155+ fixed_generate_sorting_kwargs = copy .deepcopy (generate_sorting_kwargs )
156+ fixed_generate_sorting_kwargs ["firing_rates" ] = fixed_firing_rates
157+
158+ if shift_units_outside_probe :
159+ num_units , unit_locations , fixed_generate_templates_kwargs , fixed_generate_sorting_kwargs = (
160+ _update_kwargs_for_extended_units (
161+ num_units ,
162+ channel_locations ,
163+ unit_locations ,
164+ generate_unit_locations_kwargs ,
165+ generate_templates_kwargs ,
166+ generate_sorting_kwargs ,
167+ fixed_generate_templates_kwargs ,
168+ fixed_generate_sorting_kwargs ,
169+ seed ,
170+ )
171+ )
148172
149- # Start looping over parameters, creating recordings shifted
173+ # Start looping over parameters, creating recordings shifted
150174 # and scaled as required
151175 extra_outputs_dict = {
152176 "unit_locations" : [],
@@ -174,7 +198,7 @@ def generate_session_displacement_recordings(
174198 num_units = num_units ,
175199 sampling_frequency = sampling_frequency ,
176200 durations = [duration ],
177- ** generate_sorting_kwargs ,
201+ ** fixed_generate_sorting_kwargs ,
178202 extra_outputs = True ,
179203 seed = seed ,
180204 )
@@ -195,7 +219,7 @@ def generate_session_displacement_recordings(
195219 unit_locations_moved ,
196220 sampling_frequency = sampling_frequency ,
197221 seed = seed ,
198- ** generate_templates_kwargs ,
222+ ** fixed_generate_templates_kwargs ,
199223 )
200224
201225 if recording_amplitude_scalings is not None :
@@ -210,7 +234,7 @@ def generate_session_displacement_recordings(
210234
211235 # Bring it all together in a `InjectTemplatesRecording` and
212236 # propagate all relevant metadata to the recording.
213- ms_before = generate_templates_kwargs ["ms_before" ]
237+ ms_before = fixed_generate_templates_kwargs ["ms_before" ]
214238 nbefore = int (sampling_frequency * ms_before / 1000.0 )
215239
216240 recording = InjectTemplatesRecording (
@@ -388,3 +412,89 @@ def _check_generate_session_displacement_arguments(
388412 "The entry for each recording in `recording_amplitude_scalings` "
389413 "must have the same length as the number of units."
390414 )
415+
416+
417+ def _update_kwargs_for_extended_units (
418+ num_units ,
419+ channel_locations ,
420+ unit_locations ,
421+ generate_unit_locations_kwargs ,
422+ generate_templates_kwargs ,
423+ generate_sorting_kwargs ,
424+ fixed_generate_templates_kwargs ,
425+ fixed_generate_sorting_kwargs ,
426+ seed ,
427+ ):
428+ """
429+ In a real world situation, if the probe moves up / down
430+ not only will previously recorded units be shifted, but
431+ new units will be introduced into the recording.
432+
433+ This function extends the default num units, unit locations,
434+ and template / sorting kwargs to extend the unit of units
435+ one probe's height (y dimension) above and below the probe.
436+ Now, when the unit locations are shifted, new units will be
437+ introduced into the recording (from below or above).
438+
439+ It is important that the unit kwargs for the units are kept the
440+ same across runs when seeded (i.e. whether `shift_units_outside_probe`
441+ is `True` or `False`). To acheive this, the fixed unit kwargs
442+ are extended with new units located above and below these fixed
443+ units. The seeds are shifted slightly, so the introduced
444+ units do not duplicate the existing units.
445+
446+ """
447+ seed_top = seed + 1 if seed is not None else None
448+ seed_bottom = seed - 1 if seed is not None else None
449+
450+ # Set unit locations above and below the probe and extend
451+ # the `unit_locations` array.
452+ channel_locations_extend_top = channel_locations .copy ()
453+ channel_locations_extend_top [:, 1 ] -= np .max (channel_locations [:, 1 ])
454+
455+ extend_top_locations = generate_unit_locations (
456+ num_units ,
457+ channel_locations_extend_top ,
458+ seed = seed_top ,
459+ ** generate_unit_locations_kwargs ,
460+ )
461+
462+ channel_locations_extend_bottom = channel_locations .copy ()
463+ channel_locations_extend_bottom [:, 1 ] += np .max (channel_locations [:, 1 ])
464+
465+ extend_bottom_locations = generate_unit_locations (
466+ num_units ,
467+ channel_locations_extend_bottom ,
468+ seed = seed_bottom ,
469+ ** generate_unit_locations_kwargs ,
470+ )
471+
472+ unit_locations = np .r_ [extend_bottom_locations , unit_locations , extend_top_locations ]
473+
474+ # For the new units located above and below the probe, generate a set of
475+ # firing rates and template kwargs.
476+
477+ # Extend the template kwargs
478+ template_kwargs_top = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed_top )
479+ template_kwargs_bottom = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed_bottom )
480+
481+ for key in fixed_generate_templates_kwargs ["unit_params" ].keys ():
482+ fixed_generate_templates_kwargs ["unit_params" ][key ] = np .r_ [
483+ template_kwargs_top ["unit_params" ][key ],
484+ fixed_generate_templates_kwargs ["unit_params" ][key ],
485+ template_kwargs_bottom ["unit_params" ][key ],
486+ ]
487+
488+ # Extend the firing rates
489+ firing_rates_top = _ensure_firing_rates (generate_sorting_kwargs ["firing_rates" ], num_units , seed_top )
490+ firing_rates_bottom = _ensure_firing_rates (generate_sorting_kwargs ["firing_rates" ], num_units , seed_bottom )
491+
492+ fixed_generate_sorting_kwargs ["firing_rates" ] = np .r_ [
493+ firing_rates_top , fixed_generate_sorting_kwargs ["firing_rates" ], firing_rates_bottom
494+ ]
495+
496+ # Update the number of units (3x as a
497+ # new set above and below the existing units)
498+ num_units *= 3
499+
500+ return num_units , unit_locations , fixed_generate_templates_kwargs , fixed_generate_sorting_kwargs
0 commit comments