Skip to content

Commit

Permalink
Make get_shifts return value more useful; add test for shifts (#317)
Browse files Browse the repository at this point in the history
* Make get_shifts return value more useful; add test for shifts

* Fix type annotation for get_shifts
  • Loading branch information
ethanbb authored Sep 9, 2024
1 parent 5471878 commit 5d7c9b3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
29 changes: 8 additions & 21 deletions mesmerize_core/caiman_extensions/mcorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from caiman import load_memmap

from ._utils import validate
from typing import *


@pd.api.extensions.register_series_accessor("mcorr")
Expand Down Expand Up @@ -92,9 +91,7 @@ def get_output(self, mode: str = "r") -> np.ndarray:
return mc_movie

@validate("mcorr")
def get_shifts(
self, pw_rigid: bool = False
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
def get_shifts(self, pw_rigid) -> list[np.ndarray]:
"""
Gets file path to shifts array (.npy file) for item, processes shifts array
into a list of x and y shifts based on whether rigid or nonrigid
Expand All @@ -107,26 +104,16 @@ def get_shifts(
False = Rigid
Returns:
--------
List of Processed X and Y shifts arrays
List of Processed X and Y [and Z] shifts arrays
- For rigid correction, each element is a vector of length n_frames
- For pw_rigid correction, each element is an n_frames x n_patches matrix
"""
path = self._series.paths.resolve(self._series["outputs"]["shifts"])
shifts = np.load(str(path))

if pw_rigid:
n_pts = shifts.shape[1]
n_lines = shifts.shape[2]
xs = [np.linspace(0, n_pts, n_pts)]
ys = []

for i in range(shifts.shape[0]):
for j in range(n_lines):
ys.append(shifts[i, :, j])
shifts_by_dim = list(shifts) # dims-length list of n_frames x n_patches matrices
else:
n_pts = shifts.shape[0]
n_lines = shifts.shape[1]
xs = [np.linspace(0, n_pts, n_pts)]
ys = []

for i in range(n_lines):
ys.append(shifts[:, i])
return xs, ys
shifts_by_dim = list(shifts.T) # dims-length list of n_frames-length vectors

return shifts_by_dim
20 changes: 19 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

def _download_ground_truths():
print(f"Downloading ground truths")
url = f"https://zenodo.org/record/6828096/files/ground_truths.zip"
url = f"https://zenodo.org/record/13732996/files/ground_truths.zip"

# basically from https://stackoverflow.com/questions/37573483/progress-bar-while-download-file-over-http-with-requests/37573701
response = requests.get(url, stream=True)
Expand Down Expand Up @@ -252,6 +252,15 @@ def test_mcorr():
)
)

# test to check shifts output path
assert (
batch_dir.joinpath(df.iloc[-1]["outputs"]["shifts"])
== df.paths.resolve(df.iloc[-1]["outputs"]["shifts"])
== batch_dir.joinpath(
str(df.iloc[-1]["uuid"]), f'{df.iloc[-1]["uuid"]}_shifts.npy'
)
)

# test to check mean-projection output path
assert (
batch_dir.joinpath(df.iloc[-1]["outputs"]["mean-projection-path"])
Expand Down Expand Up @@ -303,6 +312,15 @@ def test_mcorr():
)
numpy.testing.assert_array_equal(mcorr_output, mcorr_output_actual)


# test to check mcorr get_shifts()
mcorr_shifts = df.iloc[-1].mcorr.get_shifts(pw_rigid=test_params[algo]["main"]["pw_rigid"])
mcorr_shifts_actual = numpy.load(
ground_truths_dir.joinpath("mcorr", "mcorr_shifts.npy")
)
numpy.testing.assert_array_equal(mcorr_shifts, mcorr_shifts_actual)


# test to check caiman get_input_movie_path()
assert df.iloc[-1].caiman.get_input_movie_path() == get_full_raw_data_path(
df.iloc[0]["input_movie_path"]
Expand Down

0 comments on commit 5d7c9b3

Please sign in to comment.