@@ -141,10 +141,27 @@ def generate_session_displacement_recordings(
141141
142142 # Fix generate template kwargs, so they are the same for every created recording.
143143 # Also fix unit firing rates across recordings.
144- generate_templates_kwargs = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed )
144+ fixed_generate_templates_kwargs = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed )
145145
146146 fixed_firing_rates = _ensure_firing_rates (generate_sorting_kwargs ["firing_rates" ], num_units , seed )
147- generate_sorting_kwargs ["firing_rates" ] = fixed_firing_rates
147+ fixed_generate_sorting_kwargs = copy .deepcopy (generate_sorting_kwargs )
148+ fixed_generate_sorting_kwargs ["firing_rates" ] = fixed_firing_rates
149+
150+ extend = True
151+ if extend :
152+ num_units , unit_locations , fixed_generate_templates_kwargs , fixed_generate_sorting_kwargs = (
153+ _update_kwargs_for_extended_units (
154+ num_units ,
155+ channel_locations ,
156+ unit_locations ,
157+ generate_unit_locations_kwargs ,
158+ generate_templates_kwargs ,
159+ generate_sorting_kwargs ,
160+ fixed_generate_templates_kwargs ,
161+ fixed_generate_sorting_kwargs ,
162+ seed ,
163+ )
164+ )
148165
149166 # Start looping over parameters, creating recordings shifted
150167 # and scaled as required
@@ -174,7 +191,7 @@ def generate_session_displacement_recordings(
174191 num_units = num_units ,
175192 sampling_frequency = sampling_frequency ,
176193 durations = [duration ],
177- ** generate_sorting_kwargs ,
194+ ** fixed_generate_sorting_kwargs ,
178195 extra_outputs = True ,
179196 seed = seed ,
180197 )
@@ -195,7 +212,7 @@ def generate_session_displacement_recordings(
195212 unit_locations_moved ,
196213 sampling_frequency = sampling_frequency ,
197214 seed = seed ,
198- ** generate_templates_kwargs ,
215+ ** fixed_generate_templates_kwargs ,
199216 )
200217
201218 if recording_amplitude_scalings is not None :
@@ -210,7 +227,7 @@ def generate_session_displacement_recordings(
210227
211228 # Bring it all together in a `InjectTemplatesRecording` and
212229 # propagate all relevant metadata to the recording.
213- ms_before = generate_templates_kwargs ["ms_before" ]
230+ ms_before = fixed_generate_templates_kwargs ["ms_before" ]
214231 nbefore = int (sampling_frequency * ms_before / 1000.0 )
215232
216233 recording = InjectTemplatesRecording (
@@ -389,3 +406,66 @@ def _check_generate_session_displacement_arguments(
389406 "The entry for each recording in `recording_amplitude_scalings` "
390407 "must have the same length as the number of units."
391408 )
409+
410+
411+ def _update_kwargs_for_extended_units (
412+ num_units ,
413+ channel_locations ,
414+ unit_locations ,
415+ generate_unit_locations_kwargs ,
416+ generate_templates_kwargs ,
417+ generate_sorting_kwargs ,
418+ fixed_generate_templates_kwargs ,
419+ fixed_generate_sorting_kwargs ,
420+ seed ,
421+ ):
422+
423+ seed_top = seed + 1 if seed is not None else None
424+ seed_bottom = seed - 1 if seed is not None else None
425+
426+ # Set unit locations above and below the probe
427+ channel_locations_extend_top = channel_locations .copy ()
428+ channel_locations_extend_top [:, 1 ] -= np .max (channel_locations [:, 1 ])
429+
430+ extend_top_locations = generate_unit_locations (
431+ num_units ,
432+ channel_locations_extend_top ,
433+ seed = seed_top , # explain
434+ ** generate_unit_locations_kwargs ,
435+ )
436+
437+ channel_locations_extend_bottom = channel_locations .copy ()
438+ channel_locations_extend_bottom [:, 1 ] += np .max (channel_locations [:, 1 ])
439+
440+ extend_bottom_locations = generate_unit_locations (
441+ num_units ,
442+ channel_locations_extend_bottom ,
443+ seed = seed_bottom ,
444+ ** generate_unit_locations_kwargs ,
445+ )
446+
447+ unit_locations = np .r_ [extend_bottom_locations , unit_locations , extend_top_locations ]
448+
449+ # Set firing rates and params for these units
450+
451+ # Do the same here
452+ template_kwargs_top = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed_top )
453+ template_kwargs_bottom = fix_generate_templates_kwargs (generate_templates_kwargs , num_units , seed_bottom )
454+
455+ for key in fixed_generate_templates_kwargs ["unit_params" ].keys ():
456+ fixed_generate_templates_kwargs ["unit_params" ][key ] = np .r_ [
457+ template_kwargs_top ["unit_params" ][key ],
458+ fixed_generate_templates_kwargs ["unit_params" ][key ],
459+ template_kwargs_bottom ["unit_params" ][key ],
460+ ]
461+
462+ firing_rates_top = _ensure_firing_rates (generate_sorting_kwargs ["firing_rates" ], num_units , seed_top )
463+ firing_rates_bottom = _ensure_firing_rates (generate_sorting_kwargs ["firing_rates" ], num_units , seed_bottom )
464+
465+ fixed_generate_sorting_kwargs ["firing_rates" ] = np .r_ [
466+ firing_rates_top , fixed_generate_sorting_kwargs ["firing_rates" ], firing_rates_bottom
467+ ]
468+
469+ num_units *= 3
470+
471+ return num_units , unit_locations , fixed_generate_templates_kwargs , fixed_generate_sorting_kwargs
0 commit comments