diff --git a/mesmerize_core/caiman_extensions/mcorr.py b/mesmerize_core/caiman_extensions/mcorr.py index 9cd5fb7..0699d98 100644 --- a/mesmerize_core/caiman_extensions/mcorr.py +++ b/mesmerize_core/caiman_extensions/mcorr.py @@ -5,7 +5,6 @@ from caiman import load_memmap from ._utils import validate -from typing import * @pd.api.extensions.register_series_accessor("mcorr") @@ -92,9 +91,7 @@ def get_output(self, mode: str = "r") -> np.ndarray: return mc_movie @validate("mcorr") - def get_shifts( - self, pw_rigid: bool = False - ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + def get_shifts(self, pw_rigid) -> list[np.ndarray]: """ Gets file path to shifts array (.npy file) for item, processes shifts array into a list of x and y shifts based on whether rigid or nonrigid @@ -107,26 +104,16 @@ def get_shifts( False = Rigid Returns: -------- - List of Processed X and Y shifts arrays + List of Processed X and Y [and Z] shifts arrays + - For rigid correction, each element is a vector of length n_frames + - For pw_rigid correction, each element is an n_frames x n_patches matrix """ path = self._series.paths.resolve(self._series["outputs"]["shifts"]) shifts = np.load(str(path)) if pw_rigid: - n_pts = shifts.shape[1] - n_lines = shifts.shape[2] - xs = [np.linspace(0, n_pts, n_pts)] - ys = [] - - for i in range(shifts.shape[0]): - for j in range(n_lines): - ys.append(shifts[i, :, j]) + shifts_by_dim = list(shifts) # dims-length list of n_frames x n_patches matrices else: - n_pts = shifts.shape[0] - n_lines = shifts.shape[1] - xs = [np.linspace(0, n_pts, n_pts)] - ys = [] - - for i in range(n_lines): - ys.append(shifts[:, i]) - return xs, ys + shifts_by_dim = list(shifts.T) # dims-length list of n_frames-length vectors + + return shifts_by_dim diff --git a/tests/test_core.py b/tests/test_core.py index 163189e..878a43a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -44,7 +44,7 @@ def _download_ground_truths(): print(f"Downloading ground truths") - url = f"https://zenodo.org/record/6828096/files/ground_truths.zip" + url = f"https://zenodo.org/record/13732996/files/ground_truths.zip" # basically from https://stackoverflow.com/questions/37573483/progress-bar-while-download-file-over-http-with-requests/37573701 response = requests.get(url, stream=True) @@ -252,6 +252,15 @@ def test_mcorr(): ) ) + # test to check shifts output path + assert ( + batch_dir.joinpath(df.iloc[-1]["outputs"]["shifts"]) + == df.paths.resolve(df.iloc[-1]["outputs"]["shifts"]) + == batch_dir.joinpath( + str(df.iloc[-1]["uuid"]), f'{df.iloc[-1]["uuid"]}_shifts.npy' + ) + ) + # test to check mean-projection output path assert ( batch_dir.joinpath(df.iloc[-1]["outputs"]["mean-projection-path"]) @@ -303,6 +312,15 @@ def test_mcorr(): ) numpy.testing.assert_array_equal(mcorr_output, mcorr_output_actual) + + # test to check mcorr get_shifts() + mcorr_shifts = df.iloc[-1].mcorr.get_shifts(pw_rigid=test_params[algo]["main"]["pw_rigid"]) + mcorr_shifts_actual = numpy.load( + ground_truths_dir.joinpath("mcorr", "mcorr_shifts.npy") + ) + numpy.testing.assert_array_equal(mcorr_shifts, mcorr_shifts_actual) + + # test to check caiman get_input_movie_path() assert df.iloc[-1].caiman.get_input_movie_path() == get_full_raw_data_path( df.iloc[0]["input_movie_path"]