Merged
39 changes: 19 additions & 20 deletions bnd/config.py
@@ -149,30 +149,29 @@ def _load_config() -> Config:
     return Config()
 
 
-def find_file(
-    main_path: str | Path, extension: tuple[str | Path] = (".txt",)
-) -> list[Path]:
+def find_file(path: str | Path, extension: tuple[str] = ('.raw.kwd',)) -> list[Path]:
     """
-    This function finds all the file types specified by 'extension' in the 'main_path' directory
-    and all its subdirectories and their sub-subdirectories etc.,
-    and returns a list of all file paths
-    'extension' is a list of desired file extensions: ['.dat','.prm']
+    Recursively finds files with the specified extensions within the given path.
+    `path` (str or Path): The directory in which to search for files.
+    `extension`: A tuple of file extensions, e.g. ('.dat', '.prm').
     """
-    if isinstance(main_path, str):
-        path = Path(main_path)
-    else:
-        path = main_path
+    p = Path(path)
+    if not p.exists():
+        raise FileNotFoundError(f"Path does not exist: {p}")
 
     # Convert extension to list if it is a string.
     if isinstance(extension, str):
-        extension = extension.split()  # turning extension into a list with a single element
-
-    return [
-        Path(walking[0] / goodfile)
-        for walking in path.walk()
-        for goodfile in walking[2]
-        for ext in extension
-        if goodfile.endswith(ext)
-    ]
+        extension = extension.split()
+
+    # Normalize extensions to ensure they start with a dot.
+    normalized_exts = [ext if ext.startswith('.') else '.' + ext for ext in extension]
+
+    found_files = []
+    for ext in normalized_exts:
+        for file in p.rglob(f"*{ext}"):
+            if file.is_file():
+                found_files.append(file)
+    return found_files
 
 
 def list_dirs(main_path: str | Path) -> list[str]:
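A quick usage sketch of the reworked find_file (directory names are hypothetical): extensions may be passed as a tuple or as a single string, and bare extensions are normalized to start with a dot.

from bnd.config import find_file

# Tuple of extensions; 'dat' is normalized to '.dat' before matching.
files = find_file("/data/session_01", ("dat", ".prm"))

# A single string is split into a one-element list, so this matches '*.ap.bin'.
ap_files = find_file("/data/session_01", "ap.bin")

# A nonexistent directory now fails fast instead of silently returning [].
# find_file("/no/such/dir")  # raises FileNotFoundError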
72 changes: 53 additions & 19 deletions bnd/pipeline/kilosort.py
@@ -1,34 +1,64 @@
 from pathlib import Path
+from configparser import ConfigParser
+import os
 
 import torch
 from kilosort import run_kilosort
 from kilosort.utils import PROBE_DIR, download_probes
 
 from bnd import set_logging
 from bnd.config import Config, _load_config
+from ..config import find_file
 
 logger = set_logging(__name__)
 
 
+def read_metadata(filepath: Path) -> dict:
+    """Parse a section-less INI file (e.g. an NPx metadata file) and return a dictionary of key-value pairs."""
+    with open(filepath, 'r') as f:
+        content = f.read()
+    # Inject a dummy section header so ConfigParser accepts the section-less file.
+    content_with_section = '[dummy_section]\n' + content
+
+    config = ConfigParser()
+    config.optionxform = str  # disable lowercasing of keys
+    config.read_string(content_with_section)
+
+    return dict(config.items('dummy_section'))
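SpikeGLX .meta files are bare key=value lines with no [section] header, which ConfigParser cannot parse on its own; injecting the dummy section works around that. A minimal sketch of the round trip (file contents hypothetical):

from pathlib import Path

# Hypothetical SpikeGLX-style metadata: key=value lines, no [section].
meta_path = Path("example.ap.meta")
meta_path.write_text("imDatPrb_type=0\nnSavedChans=385\nimSampRate=30000\n")

meta = read_metadata(meta_path)
print(meta["nSavedChans"])  # -> '385' (ConfigParser returns every value as a string)
print(meta["imSampRate"])   # key case is preserved because optionxform is set to str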
+
+
+def add_entry_to_metadata(filepath: Path, tag: str, value: str) -> None:
+    """
+    Append a tag=value entry to the NPx metadata.
+    """
+    with open(filepath, 'a') as f:  # append mode
+        f.write(f"{tag}={value}\n")
+
 def _read_probe_type(meta_file_path: str) -> str:
-    with open(meta_file_path, "r") as meta_file:
-        for line in meta_file:
-            if "imDatPrb_type" in line:
-                _, value = line.strip().split("=")
-                break
-
-    if int(value) == 0:
-        probe_type = (
-            "neuropixPhase3B1_kilosortChanMap.mat"  # Neuropixels Phase3B1 (staggered)
-        )
-    elif int(value) == 21:
-        probe_type = "NP2_kilosortChanMap.mat"
-    else:
-        raise ValueError(
-            "Probe type not recogised. It appears to be different from Npx 1.0 or 2.0"
-        )
+    meta = read_metadata(meta_file_path)
+    probe_type_val = meta["imDatPrb_type"]
+    if int(probe_type_val) == 0:
+        probe_type = (
+            "neuropixPhase3B1_kilosortChanMap.mat"  # Neuropixels Phase3B1 (staggered)
+        )
+    elif int(probe_type_val) == 21:
+        probe_type = "NP2_kilosortChanMap.mat"
+    else:
+        raise ValueError(
+            "Probe type not recognised. It appears to be different from Npx 1.0 or 2.0"
+        )
     return probe_type
 
+def _fix_session_metadata(meta_file_path: Path) -> None:
+    """Inject `fileSizeBytes` and `fileTimeSecs` into the NPx metadata if they are missing."""
+    meta = read_metadata(meta_file_path)
+    if "fileSizeBytes" not in meta:
+        datafile_path = find_file(meta_file_path.parent, 'ap.bin')[0]
+        data_size = os.path.getsize(datafile_path)
+        add_entry_to_metadata(meta_file_path, "fileSizeBytes", str(data_size))
+        data_duration = data_size / int(meta['nSavedChans']) / 2 / int(meta["imSampRate"])
+        add_entry_to_metadata(meta_file_path, "fileTimeSecs", str(data_duration))
+        logger.warning(f"Metadata missing values: injected fileSizeBytes: {data_size} and fileTimeSecs: {data_duration}")
+
 def run_kilosort_on_stream(
     config: Config,
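The duration computed in _fix_session_metadata divides the file size by the bytes per sample frame: nSavedChans channels times 2 bytes per int16 sample, at imSampRate frames per second. A worked example with typical NPx 1.0 values (numbers assumed):

# Hypothetical 10-minute NPx 1.0 recording.
data_size = 13_860_000_000            # bytes, as returned by os.path.getsize()
n_saved_chans = 385                   # 384 AP channels + 1 sync channel
samp_rate = 30_000                    # Hz

bytes_per_frame = n_saved_chans * 2   # int16 samples -> 2 bytes each
file_time_secs = data_size / bytes_per_frame / samp_rate
print(file_time_secs)                 # -> 600.0 seconds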
@@ -55,9 +85,10 @@ def run_kilosort_on_stream(
     -------
 
     """
+    meta_file_path = config.get_subdirectories_from_pattern(probe_folder_path, "*ap.meta")[0]
 
     sorter_params = {
-        "n_chan_bin": 385,
+        "n_chan_bin": int(read_metadata(meta_file_path)["nSavedChans"]),
     }
 
     ksort_output_path = (
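Previously n_chan_bin was hard-coded to 385, which only fits NPx 1.0 recordings (384 AP channels plus one sync channel). Reading nSavedChans from the metadata keeps the setting correct for other probe configurations; a sketch of the equivalent lookup (metadata values assumed):

meta = read_metadata(meta_file_path)
n_chan_bin = int(meta["nSavedChans"])  # e.g. 385 on NPx 1.0: 384 AP + 1 sync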
@@ -76,9 +107,12 @@ def run_kilosort_on_stream(
     # Sometimes the gateway can throw an error so just double check.
     download_probes()
 
+
+    # Check that the metadata file is complete:
+    # when SpikeGLX crashes, the metadata is missing some values.
+    _fix_session_metadata(meta_file_path)
     # Find out which probe type we have
-    meta_file_path = config.get_subdirectories_from_pattern(probe_folder_path, "*ap.meta")
-    probe_name = _read_probe_type(str(meta_file_path[0]))
+    probe_name = _read_probe_type(meta_file_path)
 
     _ = run_kilosort(
         settings=sorter_params,
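Taken together, run_kilosort_on_stream now derives every probe-specific setting from the .meta file instead of hard-coding it. A condensed sketch of the resulting flow; sort_probe_folder is a hypothetical wrapper, and the data-path and output-path arguments that run_kilosort also needs are omitted:

from pathlib import Path

def sort_probe_folder(probe_folder_path: Path) -> None:
    # Locate the SpikeGLX metadata that sits next to the recording.
    meta_file_path = find_file(probe_folder_path, "ap.meta")[0]

    # Repair metadata truncated by a SpikeGLX crash, then read it once.
    _fix_session_metadata(meta_file_path)
    meta = read_metadata(meta_file_path)

    _ = run_kilosort(
        settings={"n_chan_bin": int(meta["nSavedChans"])},  # from metadata, not hard-coded
        probe_name=_read_probe_type(meta_file_path),        # NP1 vs NP2 channel map
    )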