Skip to content

Commit

Permalink
Merge pull request #1452 from proektlab/sbx-framerate-fix
Browse files Browse the repository at this point in the history
Fix incorrect framerate when loading bidirectional SBX files
  • Loading branch information
pgunn authored Jan 18, 2025
2 parents a7d7fc9 + 7926044 commit 86ed5ab
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 66 deletions.
19 changes: 19 additions & 0 deletions caiman/tests/test_sbx.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_load_2d():
assert data_2d.ndim == 3, 'Loaded 2D data has wrong dimensionality'
assert data_2d.shape == SHAPE_2D, 'Loaded 2D data has wrong shape'
assert data_2d.shape == (meta_2d['num_frames'], *meta_2d['frame_size']), 'Shape in metadata does not match loaded data'
assert meta_2d['frame_rate'] == 15.625, 'Frame rate in metadata is incorrect (unidirectional)'
npt.assert_array_equal(data_2d[0, 0, :10], [712, 931, 1048, 825, 1383, 882, 601, 798, 1022, 966], 'Loaded 2D data has wrong values')

data_2d_movie = cm.load(file_2d)
Expand All @@ -37,6 +38,23 @@ def test_load_2d():
npt.assert_array_almost_equal(data_2d_movie, data_2d, err_msg='Movie loaded with cm.load has wrong values')


def test_load_2d_bidi():
    """Check that a bidirectional 2D SBX file loads with the right shape, values, and frame rate."""
    sbx_path = os.path.join(TESTDATA_PATH, '2d_sbx_bidi.sbx')
    raw_data = sbx_utils.sbxread(sbx_path)
    metadata = sbx_utils.sbx_meta_data(sbx_path)

    # Basic structural checks against the known test-file geometry
    assert raw_data.ndim == 3, 'Loaded 2D bidirectional data has wrong dimensionality'
    assert raw_data.shape == SHAPE_2D, 'Loaded 2D bidirectional data has wrong shape'
    expected_shape = (metadata['num_frames'], *metadata['frame_size'])
    assert raw_data.shape == expected_shape, 'Shape in metadata does not match loaded data'
    # Bidirectional scanning doubles the effective frame rate (15.625 * 2)
    assert metadata['frame_rate'] == 31.25, 'Frame rate in metadata is incorrect (bidirectional)'
    first_row_values = [2833, 1538, 1741, 1837, 2079, 2038, 1946, 1631, 2260, 2073]
    npt.assert_array_equal(raw_data[0, 0, :10], first_row_values, 'Loaded 2D bidirectional data has wrong values')

    # The generic caiman loader must agree with the direct SBX reader
    movie_data = cm.load(sbx_path)
    assert movie_data.ndim == raw_data.ndim, 'Movie loaded with cm.load has wrong dimensionality'
    assert movie_data.shape == raw_data.shape, 'Movie loaded with cm.load has wrong shape'
    npt.assert_array_almost_equal(movie_data, raw_data, err_msg='Movie loaded with cm.load has wrong values')


def test_load_3d():
file_3d = os.path.join(TESTDATA_PATH, '3d_sbx_1.sbx')
data_3d = sbx_utils.sbxread(file_3d)
Expand All @@ -45,6 +63,7 @@ def test_load_3d():
assert data_3d.ndim == 4, 'Loaded 3D data has wrong dimensionality'
assert data_3d.shape == SHAPE_3D, 'Loaded 3D data has wrong shape'
assert data_3d.shape == (meta_3d['num_frames'], *meta_3d['frame_size'], meta_3d['num_planes']), 'Shape in metadata does not match loaded data'
assert meta_3d['frame_rate'] == 15.625 / meta_3d['num_planes'], 'Frame rate in metadata is incorrect (bidirectional 3D)'
npt.assert_array_equal(data_3d[0, 0, :10, 0], [2167, 2525, 1713, 1747, 1887, 1741, 1873, 1244, 1747, 1637], 'Loaded 2D data has wrong values')

data_3d_movie = cm.load(file_3d, is3D=True)
Expand Down
132 changes: 66 additions & 66 deletions caiman/utils/sbx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
import os
import scipy
import tifffile
from typing import Iterable, Union, Optional
from typing import Any, Sequence, Union, Optional, cast

DimSubindices = Union[Iterable[int], slice]
FileSubindices = Union[DimSubindices, Iterable[DimSubindices]] # can have inds for just frames or also for y, x, z
ChainSubindices = Union[FileSubindices, Iterable[FileSubindices]] # one to apply to each file, or separate for each file
DimSubindices = Union[Sequence[int], slice]
FileSubindices = Union[DimSubindices, Sequence[DimSubindices]] # can have inds for just frames or also for y, x, z
ChainSubindices = Union[FileSubindices, Sequence[FileSubindices]] # one to apply to each file, or separate for each file

def loadmat_sbx(filename: str) -> dict:
def loadmat_sbx(filename: str) -> dict[str, Any]:
"""
this wrapper should be called instead of directly calling spio.loadmat
Expand All @@ -25,7 +25,7 @@ def loadmat_sbx(filename: str) -> dict:
"""
data_ = scipy.io.loadmat(filename, struct_as_record=False, squeeze_me=True)
_check_keys(data_)
return data_
return data_['info']


def _check_keys(checkdict: dict) -> None:
Expand Down Expand Up @@ -135,7 +135,7 @@ def sbx_chain_to_tif(filenames: list[str], fileout: str, subindices: Optional[Ch
fileout: str
filename to save, including the .tif suffix
subindices: Iterable[int] | slice | Iterable[Iterable[int] | slice | tuple[Iterable[int] | slice, ...]]
subindices: Sequence[int] | slice | Sequence[Sequence[int] | slice | tuple[Sequence[int] | slice, ...]]
see subindices for sbx_to_tif
can specify separate subindices for each file if nested 2 levels deep;
X, Y, and Z sizes must match for all files after indexing.
Expand All @@ -146,21 +146,25 @@ def sbx_chain_to_tif(filenames: list[str], fileout: str, subindices: Optional[Ch
subindices = slice(None)

# Validate aggressively to avoid failing after waiting to copy a lot of data
if isinstance(subindices, slice) or np.isscalar(subindices[0]):
# One set of subindices to repeat for each file
subindices = [(subindices,) for _ in filenames]
if isinstance(subindices, slice) or isinstance(subindices[0], int) or np.isscalar(subindices[0]):
# One set of subindices over time to repeat for each file
_subindices = [(cast(DimSubindices, subindices),) for _ in filenames]

elif isinstance(subindices[0], slice) or np.isscalar(subindices[0][0]):
elif isinstance(subindices[0], slice) or isinstance(subindices[0][0], int) or np.isscalar(subindices[0][0]):
# Interpret this as being an iterable over dimensions to repeat for each file
subindices = [subindices for _ in filenames]
_subindices = [cast(FileSubindices, subindices) for _ in filenames]

elif len(subindices) != len(filenames):
# Must be a separate subindices for each file; must match number of files
raise Exception('Length of subindices does not match length of file list')
raise Exception('Length of subindices does not match length of file list')

else:
_subindices = cast(Sequence[FileSubindices], subindices)
del subindices # ensure _subindices replaces subindices from here

# Get the total size of the file
all_shapes = [sbx_shape(file) for file in filenames]
all_shapes_out = np.stack([_get_output_shape(file, subind)[0] for (file, subind) in zip(filenames, subindices)])
all_shapes_out = np.stack([_get_output_shape(file, subind)[0] for (file, subind) in zip(filenames, _subindices)])

# Check that X, Y, and Z are consistent
for dimname, shapes in zip(('Y', 'X', 'Z'), all_shapes_out.T[1:]):
Expand Down Expand Up @@ -199,14 +203,15 @@ def sbx_chain_to_tif(filenames: list[str], fileout: str, subindices: Optional[Ch
# Now convert each file
tif_memmap = tifffile.memmap(fileout, series=0)
offset = 0
for filename, subind, file_N in zip(filenames, subindices, all_n_frames_out):
_sbxread_helper(filename, subindices=subind, channel=channel, out=tif_memmap[offset:offset+file_N], plane=plane, chunk_size=chunk_size)
for filename, subind, file_N in zip(filenames, _subindices, all_n_frames_out):
this_memmap = cast(np.memmap, tif_memmap[offset:offset+file_N])
_sbxread_helper(filename, subindices=subind, channel=channel, out=this_memmap, plane=plane, chunk_size=chunk_size)
offset += file_N

del tif_memmap # important to make sure file is closed (on Windows)


def sbx_shape(filename: str, info: Optional[dict] = None) -> tuple[int, int, int, int, int]:
def sbx_shape(filename: str, info: Optional[dict[str, Any]] = None) -> tuple[int, int, int, int, int]:
"""
Args:
filename: str
Expand All @@ -223,55 +228,41 @@ def sbx_shape(filename: str, info: Optional[dict] = None) -> tuple[int, int, int

# Load info
if info is None:
info = loadmat_sbx(filename + '.mat')['info']
info = loadmat_sbx(filename + '.mat')

# Image size
if 'sz' not in info:
info['sz'] = np.array([512, 796])

# Scan mode (0 indicates bidirectional)
if 'scanmode' in info and info['scanmode'] == 0:
info['recordsPerBuffer'] *= 2

# Fold lines (multiple subframes per scan) - basically means the frames are smaller and
# there are more of them than is reflected in the info file
if 'fold_lines' in info and info['fold_lines'] > 0:
if info['recordsPerBuffer'] % info['fold_lines'] != 0:
if info['sz'][0] % info['fold_lines'] != 0:
raise Exception('Non-integer folds per frame not supported')
n_folds = round(info['recordsPerBuffer'] / info['fold_lines'])
info['recordsPerBuffer'] = info['fold_lines']

info['sz'][0] = info['fold_lines']
if 'bytesPerBuffer' in info:
n_folds = round(info['sz'][0] / info['fold_lines'])
info['bytesPerBuffer'] /= n_folds
else:
n_folds = 1


# Defining number of channels/size factor
if 'chan' in info:
info['nChan'] = info['chan']['nchan']
factor = 1 # should not be used
elif info['channels'] == 1:
info['nChan'] = 2
else:
if info['channels'] == 1:
info['nChan'] = 2
factor = 1
elif info['channels'] == 2:
info['nChan'] = 1
factor = 2
elif info['channels'] == 3:
info['nChan'] = 1
factor = 2
info['nChan'] = 1

# Determine number of frames in whole file
filesize = os.path.getsize(filename + '.sbx')
if 'scanbox_version' in info:
if info['scanbox_version'] == 2:
info['max_idx'] = filesize / info['recordsPerBuffer'] / info['sz'][1] * factor / 4 - 1
elif info['scanbox_version'] == 3:
if info['scanbox_version'] in [2, 3]:
info['max_idx'] = filesize / np.prod(info['sz']) / info['nChan'] / 2 - 1
else:
raise Exception('Invalid Scanbox version')
else:
info['max_idx'] = filesize / info['bytesPerBuffer'] * factor - 1
info['max_idx'] = filesize / info['bytesPerBuffer'] * (2 // info['nChan']) - 1

n_frames = info['max_idx'] + 1 # Last frame

Expand All @@ -284,7 +275,7 @@ def sbx_shape(filename: str, info: Optional[dict] = None) -> tuple[int, int, int
n_planes = 1
n_frames //= n_planes

x = (int(info['nChan']), int(info['sz'][1]), int(info['recordsPerBuffer']), int(n_planes), int(n_frames))
x = (int(info['nChan']), int(info['sz'][1]), int(info['sz'][0]), int(n_planes), int(n_frames))
return x


Expand All @@ -302,7 +293,7 @@ def sbx_meta_data(filename: str):
if ext == '.sbx':
filename = basename

info = loadmat_sbx(filename + '.mat')['info']
info = loadmat_sbx(filename + '.mat')

meta_data = dict()
n_chan, n_x, n_y, n_planes, n_frames = sbx_shape(filename, info)
Expand Down Expand Up @@ -400,21 +391,22 @@ def _sbxread_helper(filename: str, subindices: FileSubindices = slice(None), cha
filename = basename

# Normalize so subindices is a list over dimensions
if isinstance(subindices, slice) or np.isscalar(subindices[0]):
subindices = [subindices]
if isinstance(subindices, slice) or isinstance(subindices[0], int) or np.isscalar(subindices[0]):
_subindices = [cast(DimSubindices, subindices)]
else:
subindices = list(subindices)
_subindices = list(cast(Sequence[DimSubindices], subindices))
del subindices # ensure _subindices replaces subindices from here

# Load info
info = loadmat_sbx(filename + '.mat')['info']
info = loadmat_sbx(filename + '.mat')

# Get shape (and update info)
data_shape = sbx_shape(filename, info) # (chans, X, Y, Z, frames)
n_chans, n_x, n_y, n_planes, n_frames = data_shape
is3D = n_planes > 1

# Fill in missing dimensions in subindices
subindices += [slice(None) for _ in range(max(0, 3 + is3D - len(subindices)))]
_subindices += [slice(None) for _ in range(max(0, 3 + is3D - len(_subindices)))]

if channel is None:
if n_chans > 1:
Expand All @@ -430,7 +422,7 @@ def _sbxread_helper(filename: str, subindices: FileSubindices = slice(None), cha
if frame_size <= 0:
raise Exception('Invalid scanbox metadata')

save_shape, subindices = _get_output_shape(data_shape, subindices)
save_shape, _subindices = _get_output_shape(data_shape, _subindices)
n_frames_out = save_shape[0]
if plane is not None:
if len(save_shape) < 4:
Expand All @@ -447,16 +439,18 @@ def _sbxread_helper(filename: str, subindices: FileSubindices = slice(None), cha
if not is3D: # squeeze out singleton plane dim
sbx_mmap = sbx_mmap[..., 0]
elif plane is not None: # select plane relative to subindices
sbx_mmap = sbx_mmap[..., subindices[-1][plane]]
subindices = subindices[:-1]
inds = np.ix_(*subindices)
sbx_mmap = sbx_mmap[..., _subindices[-1][plane]]
_subindices = _subindices[:-1]
inds = np.ix_(*_subindices)

out_arr: Optional[np.ndarray] = out # widen type
del out # ensure out_arr replaces out from here
if chunk_size is None:
# load a contiguous block all at once
chunk_size = n_frames_out
elif out is None:
elif out_arr is None:
# Pre-allocate destination when loading in chunks
out = np.empty(save_shape, dtype=np.uint16)
out_arr = np.empty(save_shape, dtype=np.uint16)

n_remaining = n_frames_out
offset = 0
Expand All @@ -468,23 +462,26 @@ def _sbxread_helper(filename: str, subindices: FileSubindices = slice(None), cha
# Note: SBX files store the values strangely, it's necessary to invert each uint16 value to get the correct ones
np.invert(chunk, out=chunk) # avoid copying, may be large

if out is None:
out = chunk # avoid copying when loading all data
if out_arr is None:
out_arr = chunk # avoid copying when loading all data
else:
out[offset:offset+this_chunk_size] = chunk
out_arr[offset:offset+this_chunk_size] = chunk
n_remaining -= this_chunk_size
offset += this_chunk_size

if out_arr is None:
raise RuntimeError('Nothing loaded - no frames selected?')

del sbx_mmap # Important to close file (on Windows)

if isinstance(out, np.memmap):
out.flush()
return out
if isinstance(out_arr, np.memmap):
out_arr.flush()
return out_arr


def _interpret_subindices(subindices: DimSubindices, dim_extent: int) -> tuple[Iterable[int], int]:
def _interpret_subindices(subindices: DimSubindices, dim_extent: int) -> tuple[Sequence[int], int]:
"""
Given the extent of a dimension in the corresponding recording, obtain an iterable over subindices
Given the extent of a dimension in the corresponding recording, obtain a sequence over subindices
and the step size (or 0 if the step size is not uniform).
"""
logger = logging.getLogger("caiman")
Expand All @@ -507,15 +504,18 @@ def _interpret_subindices(subindices: DimSubindices, dim_extent: int) -> tuple[I


def _get_output_shape(filename_or_shape: Union[str, tuple[int, ...]], subindices: FileSubindices
) -> tuple[tuple[int, ...], FileSubindices]:
) -> tuple[tuple[int, ...], tuple[Sequence[int], ...]]:
"""
Helper to determine what shape will be loaded/saved given subindices
Also returns back the subindices with slices transformed to ranges, for convenience
"""
if isinstance(subindices, slice) or np.isscalar(subindices[0]):
subindices = (subindices,)
_subindices = (cast(DimSubindices, subindices),)
else:
_subindices = cast(Sequence[DimSubindices], subindices)
del subindices # ensure _subindices replaces subindices from here

n_inds = len(subindices) # number of dimensions that are indexed
n_inds = len(_subindices) # number of dimensions that are indexed

if isinstance(filename_or_shape, str):
data_shape = sbx_shape(filename_or_shape)
Expand All @@ -529,7 +529,7 @@ def _get_output_shape(filename_or_shape: Union[str, tuple[int, ...]], subindices

shape_out = [n_frames, n_y, n_x, n_planes] if is3D else [n_frames, n_y, n_x]
subinds_out = []
for i, (dim, subind) in enumerate(zip(shape_out, subindices)):
for i, (dim, subind) in enumerate(zip(shape_out, _subindices)):
iterable_elements = _interpret_subindices(subind, dim)[0]
shape_out[i] = len(iterable_elements)
subinds_out.append(iterable_elements)
Expand Down
Binary file added testdata/2d_sbx_bidi.mat
Binary file not shown.
Binary file added testdata/2d_sbx_bidi.sbx
Binary file not shown.

0 comments on commit 86ed5ab

Please sign in to comment.