diff --git a/bnd/config.py b/bnd/config.py index 5a2da76..9746cf3 100644 --- a/bnd/config.py +++ b/bnd/config.py @@ -149,30 +149,29 @@ def _load_config() -> Config: return Config() -def find_file( - main_path: str | Path, extension: tuple[str | Path] = (".txt",) -) -> list[Path]: +def find_file(path: str | Path, extension: tuple[str, ...] = ('.raw.kwd',)) -> list[Path]: """ - This function finds all the file types specified by 'extension' in the 'main_path' directory - and all its subdirectories and their sub-subdirectories etc., - and returns a list of all file paths - 'extension' is a list of desired file extensions: ['.dat','.prm'] + Recursively finds files with the specified extensions within the given path. + `path` (str or Path): The directory in which to search for files. + `extension`: A tuple of file extensions, e.g. ('.dat', '.prm'). """ - if isinstance(main_path, str): - path = Path(main_path) - else: - path = main_path + p = Path(path) + if not p.exists(): + raise FileNotFoundError(f"Path does not exist: {p}") + # Convert extension to list if it is a string. if isinstance(extension, str): - extension = extension.split() # turning extension into a list with a single element - - return [ - Path(walking[0] / goodfile) - for walking in path.walk() - for goodfile in walking[2] - for ext in extension - if goodfile.endswith(ext) - ] + extension = extension.split() + + # Normalize extensions to ensure they start with a dot. + normalized_exts = [ext if ext.startswith('.') else '.' 
+ ext for ext in extension] + + found_files = [] + for ext in normalized_exts: + for file in p.rglob(f"*{ext}"): + if file.is_file(): + found_files.append(file) + return found_files def list_dirs(main_path: str | Path) -> list[str]: diff --git a/bnd/pipeline/kilosort.py b/bnd/pipeline/kilosort.py index f765093..38aeb58 100644 --- a/bnd/pipeline/kilosort.py +++ b/bnd/pipeline/kilosort.py @@ -1,4 +1,6 @@ from pathlib import Path +from configparser import ConfigParser +import os import torch from kilosort import run_kilosort @@ -6,29 +8,57 @@ from bnd import set_logging from bnd.config import Config, _load_config +from ..config import find_file logger = set_logging(__name__) +def read_metadata(filepath: Path) -> dict[str, str]: + """Parse a section-less INI file (e.g. an NPx metadata file) and return a dictionary of key-value pairs.""" + with open(filepath, 'r') as f: + content = f.read() + # Inject a dummy section header + content_with_section = '[dummy_section]\n' + content + + config = ConfigParser() + config.optionxform = str # disables lowercasing + config.read_string(content_with_section) + + return dict(config.items('dummy_section')) + + +def add_entry_to_metadata(filepath: Path, tag: str, value: str) -> None: + """ + Append a tag=value entry to the NPx metadata file. + """ + with open(filepath, 'a') as f: # append mode + f.write(f"{tag}={value}\n") + def _read_probe_type(meta_file_path: str) -> str: - with open(meta_file_path, "r") as meta_file: - for line in meta_file: - if "imDatPrb_type" in line: - _, value = line.strip().split("=") - break - - if int(value) == 0: - probe_type = ( - "neuropixPhase3B1_kilosortChanMap.mat" # Neuropixels Phase3B1 (staggered) - ) - elif int(value) == 21: - probe_type = "NP2_kilosortChanMap.mat" - else: - raise ValueError( - "Probe type not recogised. 
It appears to be different from Npx 1.0 or 2.0" - ) + meta = read_metadata(meta_file_path) + probe_type_val = meta["imDatPrb_type"] + if int(probe_type_val) == 0: + probe_type = ( + "neuropixPhase3B1_kilosortChanMap.mat" # Neuropixels Phase3B1 (staggered) + ) + elif int(probe_type_val) == 21: + probe_type = "NP2_kilosortChanMap.mat" + else: + raise ValueError( + "Probe type not recognised. It appears to be different from Npx 1.0 or 2.0" + ) return probe_type +def _fix_session_metadata(meta_file_path: Path) -> None: + """Inject `fileSizeBytes` and `fileTimeSecs` into the NPx metadata if they are missing.""" + meta = read_metadata(meta_file_path) + if "fileSizeBytes" not in meta: + datafile_path = find_file(meta_file_path.parent, 'ap.bin')[0] + data_size = os.path.getsize(datafile_path) + add_entry_to_metadata(meta_file_path, "fileSizeBytes", str(data_size)) + data_duration = data_size / int(meta['nSavedChans']) / 2 / int(meta["imSampRate"]) + add_entry_to_metadata(meta_file_path, "fileTimeSecs", str(data_duration)) + logger.warning(f"Metadata missing values: Injected fileSizeBytes: {data_size} and fileTimeSecs: {data_duration}") def run_kilosort_on_stream( config: Config, @@ -55,9 +85,10 @@ def run_kilosort_on_stream( ------- """ + meta_file_path = config.get_subdirectories_from_pattern(probe_folder_path, "*ap.meta")[0] sorter_params = { - "n_chan_bin": 385, + "n_chan_bin": int(read_metadata(meta_file_path)["nSavedChans"]), } ksort_output_path = ( @@ -76,9 +107,12 @@ def run_kilosort_on_stream( # Sometimes the gateway can throw an error so just double check. download_probes() + + # Check if the metadata file is complete + # when SpikeGLX crashes, metadata misses some values. + _fix_session_metadata(meta_file_path) # Find out which probe type we have - meta_file_path = config.get_subdirectories_from_pattern(probe_folder_path, "*ap.meta") - probe_name = _read_probe_type(str(meta_file_path[0])) + probe_name = _read_probe_type(meta_file_path) _ = run_kilosort( settings=sorter_params,