11from __future__ import annotations
22
3+ import copy
4+
35import numpy as np
46import json
57from pathlib import Path
@@ -45,6 +47,7 @@ def correct_inter_session_displacement(
4547 from spikeinterface .sortingcomponents .motion_estimation import estimate_motion
4648 from spikeinterface .sortingcomponents .motion_interpolation import InterpolateMotionRecording
4749 from spikeinterface .core .node_pipeline import ExtractDenseWaveforms , run_node_pipeline
50+ from spikeinterface .sortingcomponents .motion_utils import Motion
4851
4952 # TODO: do not accept multi-segment recordings.
5053 # TODO: check all recordings have the same probe dimensions!
@@ -128,6 +131,7 @@ def correct_inter_session_displacement(
128131 spatial_bin_edges = None ,
129132 )
130133 else :
134+ assert NotImplementedError
131135 motion_histogram = make_3d_motion_histograms (
132136 recording ,
133137 peaks ,
@@ -141,6 +145,9 @@ def correct_inter_session_displacement(
141145 spatial_bin_edges = None ,
142146 )
143147 motion_histogram_list .append (motion_histogram [0 ].squeeze ())
148+ # store bin edges
149+ temporal_bin_edges = motion_histogram [1 ]
150+ spatial_bin_edges = motion_histogram [2 ]
144151
145152 # Do some checks on temporal and spatial bin edges that they are all the same?
146153 # TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence)
@@ -180,8 +187,42 @@ def correct_inter_session_displacement(
180187 # TODO: do multi-session optimisation
181188
182189 # Handle drift
190+ interpolate_motion_kwargs = {}
183191
184192 # TODO: add motion to motion if exists otherwise create InterpolateMotionRecording object!
185193 # Will need the y-axis bins for this
186- motion = Motion ([motion_array ], [temporal_bins ], non_rigid_window_centers , direction = direction )
187- recording_corrected = InterpolateMotionRecording (recording , motion , ** interpolate_motion_kwargs )
194+ all_recording_corrected = []
195+ all_motion_info = []
196+ for i , recording in enumerate (recordings_list ):
197+
198+ # TODO: direct copy, use 'get_window' from motion machinery
199+ bin_centers = spatial_bin_edges [:- 1 ] + bin_um / 2.0
200+ n = bin_centers .size
201+ non_rigid_windows = [np .ones (n , dtype = "float64" )]
202+ middle = (spatial_bin_edges [0 ] + spatial_bin_edges [- 1 ]) / 2.0
203+ non_rigid_window_centers = np .array ([middle ])
204+
205+ motion_array = shifts [i ] # TODO: this is the rigid case!
206+ temporal_bins = 0.5 * (temporal_bin_edges [1 :] + temporal_bin_edges [:- 1 ])
207+ motion = Motion (
208+ [np .atleast_2d (motion_array )], [temporal_bins ], non_rigid_window_centers , direction = "y"
209+ ) # will be same for all except for shifts
210+ all_motion_info .append (motion ) # not certain on this
211+
212+ if isinstance (recording , InterpolateMotionRecording ):
213+ raise NotImplementedError
214+ recording_corrected = copy .deepcopy (recording )
215+ # TODO: add interpolation to the existing one.
216+ # Not if inter-session motion correction already exists, but further
217+ # up the preprocessing chain, it will NOT be added and interpolation
218+ # will occur twice. Throw a warning here!
219+ else :
220+ recording_corrected = InterpolateMotionRecording (recording , motion , ** interpolate_motion_kwargs )
221+ all_recording_corrected .append (recording_corrected )
222+
223+ displacement_info = {
224+ "all_motion_info" : all_motion_info ,
225+ "all_motion_histograms" : motion_histogram_list , # TODO: naming
226+ "all_shifts" : shifts ,
227+ }
228+ return all_recording_corrected , displacement_info # TODO: output more stuff later e.g. the Motion object
0 commit comments