Skip to content

Commit e7c22dc

Browse files
committed
Temp save, will need to remove many parts.
1 parent 81ff1c5 commit e7c22dc

File tree

7 files changed

+158
-11
lines changed

7 files changed

+158
-11
lines changed

debugging/alignment_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges,
2323
"""
2424
TODO: assumes 1-segment recording
2525
"""
26+
# TODO: this is weird, don't return spatial_bin_edges here... amybe assert..
2627
entire_session_hist, temporal_bin_edges, spatial_bin_edges = \
2728
make_2d_motion_histogram(
2829
recording,
@@ -35,6 +36,7 @@ def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges,
3536
hist_margin_um=None,
3637
spatial_bin_edges=spatial_bin_edges,
3738
)
39+
3840
entire_session_hist = entire_session_hist[0]
3941

4042
entire_session_hist /= recording.get_duration(segment_index=0)
@@ -290,7 +292,7 @@ def prep_recording(recording, plot=False):
290292
:param recording:
291293
:return:
292294
"""
293-
peaks = detect_peaks(recording, method="by_channel") # "locally_exclusive")
295+
peaks = detect_peaks(recording, method="locally_exclusive")
294296

295297
peak_locations = localize_peaks(recording, peaks,
296298
method="grid_convolution")

debugging/all_recordings.pickle

23.1 MB
Binary file not shown.

debugging/main.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
2+
import numpy as np
3+
import plotting
4+
import alignment_utils
5+
import matplotlib.pyplot as plt
6+
import pickle
7+
scalings = [np.ones(10), np.r_[np.zeros(3), np.ones(7)]]
8+
9+
SAVE = True
10+
11+
if SAVE:
12+
recordings_list, _ = generate_session_displacement_recordings(
13+
non_rigid_gradient=None,
14+
num_units=35,
15+
recording_durations=(100, 100),
16+
recording_shifts=(
17+
(0, 0), (0, 0),
18+
),
19+
recording_amplitude_scalings=None, # {"method": "by_amplitude_and_firing_rate", "scalings": scalings},
20+
seed=None,
21+
)
22+
23+
peaks_list = []
24+
peak_locations_list = []
25+
26+
for recording in recordings_list:
27+
peaks, peak_locations = alignment_utils.prep_recording(
28+
recording, plot=False,
29+
)
30+
peaks_list.append(peaks)
31+
peak_locations_list.append(peak_locations)
32+
33+
# something relatively easy, only 15 units
34+
with open('all_recordings.pickle', 'wb') as handle:
35+
pickle.dump((recordings_list, peaks_list, peak_locations_list),
36+
handle, protocol=pickle.HIGHEST_PROTOCOL)
37+
38+
with open('all_recordings.pickle', 'rb') as handle:
39+
recordings_list, peaks_list, peak_locations_list = pickle.load(handle)
40+
41+
bin_um = 2
42+
43+
# TODO: own function
44+
min_y = np.min([np.min(locs["y"]) for locs in peak_locations_list])
45+
max_y = np.max([np.max(locs["y"]) for locs in peak_locations_list])
46+
47+
spatial_bin_edges = np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin...
48+
spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges) # TODO: own function
49+
50+
session_histogram_list = []
51+
for recording, peaks, peak_locations in zip(recordings_list, peaks_list, peak_locations_list):
52+
53+
hist, temp, spat = alignment_utils.get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges, log_scale=False)
54+
55+
session_histogram_list.append(
56+
hist
57+
)
58+
# TODO: need to get all outputs and check are same size
59+
plotting.SessionAlignmentWidget(
60+
recordings_list,
61+
peaks_list,
62+
peak_locations_list,
63+
session_histogram_list,
64+
histogram_spatial_bin_centers=spatial_bin_centers,
65+
)
66+
plt.show()

debugging/plotting.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
143143
ax_top.set_title(f"Session {i + 1}")
144144
ax_top.set_xlabel(None)
145145

146-
plot = DriftRasterMapWidget(
146+
DriftRasterMapWidget(
147147
dp.peaks_list[i],
148148
dp.corrected_peak_locations_list[i],
149149
recording=dp.recordings_list[i],
@@ -163,7 +163,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
163163

164164
ax = fig.add_subplot(gs[num_rows, :])
165165

166-
plot = SessionAlignmentHistogramWidget(
166+
SessionAlignmentHistogramWidget(
167167
dp.session_histogram_list,
168168
dp.histogram_spatial_bin_centers,
169169
ax=ax,
@@ -242,6 +242,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
242242
if isinstance(linewidths, int):
243243
linewidths = [linewidths] * num_histograms
244244

245+
# TODO: this leads to quite unexpected behaviours, figure something else out.
245246
if spatial_bin_centers is None:
246247
num_bins = dp.session_histogram_list[0].size
247248
spatial_bin_centers = [np.arange(num_bins)] * num_histograms

debugging/session_alignment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def estimate_inter_session_displacement(
125125
max_y = np.max([np.max(locs["y"]) for locs in peak_locations_list])
126126

127127
spatial_bin_edges = np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin...
128-
spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges)
128+
spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges) # TODO: own function
129129

130130
# Estimate an activity histogram per-session
131131
all_session_hists = [] # TODO: probably better as a dict

debugging/test_session_alignment.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@
4646
# handle the case where the passed recordings are not motion correction recordings.
4747

4848

49-
# 1) get all commits and PRs in order. Work on the original PR
50-
# 2) investigate why the expected peaks do not drop when recording_amplitude_scalings (rename) is used
5149
# 3) think about and add new neurons that are introduced when shifted
5250

5351
# 4) add interpolation of the histograms prior to cross correlation

src/spikeinterface/generation/session_displacement_generator.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,27 @@ def generate_session_displacement_recordings(
141141

142142
# Fix generate template kwargs, so they are the same for every created recording.
143143
# Also fix unit firing rates across recordings.
144-
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)
144+
fixed_generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)
145145

146146
fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed)
147-
generate_sorting_kwargs["firing_rates"] = fixed_firing_rates
147+
fixed_generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs)
148+
fixed_generate_sorting_kwargs["firing_rates"] = fixed_firing_rates
149+
150+
extend = True
151+
if extend:
152+
num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs = (
153+
_update_kwargs_for_extended_units(
154+
num_units,
155+
channel_locations,
156+
unit_locations,
157+
generate_unit_locations_kwargs,
158+
generate_templates_kwargs,
159+
generate_sorting_kwargs,
160+
fixed_generate_templates_kwargs,
161+
fixed_generate_sorting_kwargs,
162+
seed,
163+
)
164+
)
148165

149166
# Start looping over parameters, creating recordings shifted
150167
# and scaled as required
@@ -174,7 +191,7 @@ def generate_session_displacement_recordings(
174191
num_units=num_units,
175192
sampling_frequency=sampling_frequency,
176193
durations=[duration],
177-
**generate_sorting_kwargs,
194+
**fixed_generate_sorting_kwargs,
178195
extra_outputs=True,
179196
seed=seed,
180197
)
@@ -195,7 +212,7 @@ def generate_session_displacement_recordings(
195212
unit_locations_moved,
196213
sampling_frequency=sampling_frequency,
197214
seed=seed,
198-
**generate_templates_kwargs,
215+
**fixed_generate_templates_kwargs,
199216
)
200217

201218
if recording_amplitude_scalings is not None:
@@ -210,7 +227,7 @@ def generate_session_displacement_recordings(
210227

211228
# Bring it all together in a `InjectTemplatesRecording` and
212229
# propagate all relevant metadata to the recording.
213-
ms_before = generate_templates_kwargs["ms_before"]
230+
ms_before = fixed_generate_templates_kwargs["ms_before"]
214231
nbefore = int(sampling_frequency * ms_before / 1000.0)
215232

216233
recording = InjectTemplatesRecording(
@@ -389,3 +406,66 @@ def _check_generate_session_displacement_arguments(
389406
"The entry for each recording in `recording_amplitude_scalings` "
390407
"must have the same length as the number of units."
391408
)
409+
410+
411+
def _update_kwargs_for_extended_units(
412+
num_units,
413+
channel_locations,
414+
unit_locations,
415+
generate_unit_locations_kwargs,
416+
generate_templates_kwargs,
417+
generate_sorting_kwargs,
418+
fixed_generate_templates_kwargs,
419+
fixed_generate_sorting_kwargs,
420+
seed,
421+
):
422+
423+
seed_top = seed + 1 if seed is not None else None
424+
seed_bottom = seed - 1 if seed is not None else None
425+
426+
# Set unit locations above and below the probe
427+
channel_locations_extend_top = channel_locations.copy()
428+
channel_locations_extend_top[:, 1] -= np.max(channel_locations[:, 1])
429+
430+
extend_top_locations = generate_unit_locations(
431+
num_units,
432+
channel_locations_extend_top,
433+
seed=seed_top, # explain
434+
**generate_unit_locations_kwargs,
435+
)
436+
437+
channel_locations_extend_bottom = channel_locations.copy()
438+
channel_locations_extend_bottom[:, 1] += np.max(channel_locations[:, 1])
439+
440+
extend_bottom_locations = generate_unit_locations(
441+
num_units,
442+
channel_locations_extend_bottom,
443+
seed=seed_bottom,
444+
**generate_unit_locations_kwargs,
445+
)
446+
447+
unit_locations = np.r_[extend_bottom_locations, unit_locations, extend_top_locations]
448+
449+
# Set firing rates and params for these units
450+
451+
# Do the same here
452+
template_kwargs_top = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_top)
453+
template_kwargs_bottom = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_bottom)
454+
455+
for key in fixed_generate_templates_kwargs["unit_params"].keys():
456+
fixed_generate_templates_kwargs["unit_params"][key] = np.r_[
457+
template_kwargs_top["unit_params"][key],
458+
fixed_generate_templates_kwargs["unit_params"][key],
459+
template_kwargs_bottom["unit_params"][key],
460+
]
461+
462+
firing_rates_top = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_top)
463+
firing_rates_bottom = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_bottom)
464+
465+
fixed_generate_sorting_kwargs["firing_rates"] = np.r_[
466+
firing_rates_top, fixed_generate_sorting_kwargs["firing_rates"], firing_rates_bottom
467+
]
468+
469+
num_units *= 3
470+
471+
return num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs

0 commit comments

Comments
 (0)