diff --git a/bnd/cli.py b/bnd/cli.py
index 7f591d6..48573b1 100644
--- a/bnd/cli.py
+++ b/bnd/cli.py
@@ -45,6 +45,12 @@ def to_pyal(
         "--custom-map/--default-map",
         help="Run conversion with a custom map (-c) or the not (-C)",
     ),
+    validate: bool = typer.Option(
+        False,
+        "-v/-V",
+        "--validate/--no-validate",
+        help="Validate spike times alignment (-v) or skip validation (-V)",
+    ),
 ) -> None:
     """
     Convert session data into a pyaldata dataframe and saves it as a .mat
@@ -57,8 +63,9 @@ def to_pyal(
 
     \b
     Basic usage:
-    `bnd to-pyal M037_2024_01_01_10_00 # Kilosorts data and converts to pyaldata
-    `bnd to-pyal M037_2024_01_01_10_00 -c # Uses custom mapping
+    `bnd to-pyal M037_2024_01_01_10_00`     # Kilosorts data and converts to pyaldata
+    `bnd to-pyal M037_2024_01_01_10_00 -c`  # Uses custom mapping
+    `bnd to-pyal M037_2024_01_01_10_00 -v`  # Validates spike alignment
     """
     _check_processing_dependencies()
     from .pipeline.pyaldata import run_pyaldata_conversion
@@ -71,7 +78,7 @@ def to_pyal(
     _check_session_directory(session_path)
 
     # Run pipeline
-    run_pyaldata_conversion(session_path, kilosort_flag, custom_map)
+    run_pyaldata_conversion(session_path, kilosort_flag, custom_map, validate)
 
     return
 
@@ -91,6 +98,12 @@ def to_nwb(
         "--custom-map/--default-map",
         help="Run conversion with a custom map (-c) or the not (-C)",
     ),
+    validate: bool = typer.Option(
+        False,
+        "-v/-V",
+        "--validate/--no-validate",
+        help="Validate spike times alignment after conversion (-v) or skip validation (-V)",
+    ),
 ) -> None:
     """
     Convert session data into a nwb file and saves it as a .nwb
@@ -102,6 +115,7 @@ def to_nwb(
     Basic usage:
     `bnd to-nwb M037_2024_01_01_10_00`
     `bnd to-nwb M037_2024_01_01_10_00 -c`  # Use custom channel mapping
+    `bnd to-nwb M037_2024_01_01_10_00 -v`  # Validate spike alignment after conversion
     """
     # TODO: Add channel map argument: no-map, default-map, custom-map
     # _check_processing_dependencies()
@@ -115,6 +129,27 @@ def to_nwb(
 
     # Run pipeline
     run_nwb_conversion(session_path, kilosort_flag, custom_map)
+
+    # Validate spike times after conversion if requested
+    if validate:
+        from .pipeline.spike_validation import check_negative_spike_times
+        nwb_path = session_path / f"{session_name}.nwb"
+
+        if nwb_path.exists():
+            has_negative, min_spike_time = check_negative_spike_times(nwb_path)
+
+            if has_negative:
+                print(f"\n[green]✓ Validation passed: Found negative spike times (min = {min_spike_time:.3f}s)[/green]")
+                print("Alignment between pycontrol and ephys data looks correct.")
+            elif min_spike_time is None:
+                print("\n[yellow]⚠ No spike data found for validation[/yellow]")
+            else:
+                print(f"\n[red]✗ Validation failed: No negative spike times found (min = {min_spike_time:.3f}s)[/red]")
+                print("This suggests alignment issues between pycontrol and ephys data.")
+                print("Consider checking your raw data and re-running conversion.")
+        else:
+            print("\n[red]✗ NWB file not found for validation[/red]")
 
     return
 
@@ -142,9 +177,152 @@ def ksort(session_name: str = typer.Argument(help="Session name to kilosort")) -
 
     return
 
 
+@app.command()
+def validate_spikes(
+    session_name: str = typer.Argument(..., help="Session name to validate"),
+    delete_if_invalid: bool = typer.Option(
+        True,
+        "-d/-D",
+        "--delete/--no-delete",
+        help="Delete files if no negative spikes found (-d) or keep them (-D)",
+    ),
+    batch: bool = typer.Option(
+        False,
+        "-b/-B",
+        "--batch/--single",
+        help="Validate all sessions for this animal (-b) or single session (-B)",
+    ),
+) -> None:
+    """
+    Validate spike times alignment in NWB files.
+
+    Checks for negative spike times, which indicate correct pycontrol-ephys alignment.
+    If no negative spikes are found, optionally deletes NWB and pyaldata files.
+
+    \b
+    Basic usage:
+    `bnd validate-spikes M037_2024_01_01_10_00`     # Validate single session
+    `bnd validate-spikes M037_2024_01_01_10_00 -D`  # Check only, don't delete
+    `bnd validate-spikes M037 -b`                   # Validate all M037 sessions
+    """
+    _check_processing_dependencies()
+    from .pipeline.spike_validation import validate_and_clean_session, batch_validate_sessions
+
+    config = _load_config()
+
+    if batch:
+        # Extract animal name from session
+        animal_name = session_name.split("_")[0]
+        # Get animal directory and list sessions
+        animal_path = config.LOCAL_PATH / "raw" / animal_name
+
+        if not animal_path.exists():
+            print(f"[red]Animal directory not found: {animal_path}[/red]")
+            return
+
+        try:
+            # Get all sessions for this animal
+            _, session_names = list_session_datetime(animal_path)
+            session_paths = [config.get_local_session_path(name) for name in session_names]
+
+            if not session_paths:
+                print(f"[red]No sessions found for animal {animal_name}[/red]")
+                return
+
+            print(f"[bold]Validating {len(session_paths)} sessions for {animal_name}[/bold]")
+            batch_validate_sessions(session_paths, delete_if_no_negative=delete_if_invalid)
+
+        except Exception as e:
+            print(f"[red]Error getting sessions for {animal_name}: {str(e)}[/red]")
+            return
+    else:
+        # Single session validation
+        session_path = config.get_local_session_path(session_name)
+        _check_session_directory(session_path)
+
+        needs_reconversion = validate_and_clean_session(
+            session_path,
+            delete_if_no_negative=delete_if_invalid
+        )
+
+        if needs_reconversion and delete_if_invalid:
+            print(f"\n[yellow]Session {session_name} needs re-conversion.[/yellow]")
+            print(f"Run: `bnd to-pyal {session_name}` to re-convert")
+
+    return
+
+
 # ================================== Data Transfer ========================================
+@app.command()
+def replace_processed(
+    session_name: str = typer.Argument(..., help="Session name to replace processed files for"),
+    auto_confirm: bool = typer.Option(
+        False,
+        "-y/-Y",
+        "--yes/--no-yes",
+        help="Auto-confirm replacement (-y) or prompt for confirmation (-Y)",
+    ),
+) -> None:
+    """
+    Replace processed files (.nwb and .mat) on the server.
+
+    This command specifically replaces NWB and PyAlData files that may have been
+    corrected due to alignment issues. It will overwrite existing files on RDS.
+
+    \b
+    Basic usage:
+    `bnd replace-processed M037_2024_01_01_10_00`     # Replace with confirmation
+    `bnd replace-processed M037_2024_01_01_10_00 -y`  # Auto-confirm
+    """
+    _check_processing_dependencies()
+    from .data_transfer import replace_processed_files
+
+    config = _load_config()
+    session_path = config.get_local_session_path(session_name)
+
+    # Check session directory
+    _check_session_directory(session_path)
+
+    # Find processed files to replace
+    processed_files = []
+
+    # NWB files
+    nwb_files = list(session_path.glob("*.nwb"))
+    processed_files.extend(nwb_files)
+
+    # PyAlData files (could be partitioned)
+    mat_files = list(session_path.glob("*_pyaldata*.mat"))
+    processed_files.extend(mat_files)
+
+    if not processed_files:
+        print(f"[yellow]No processed files (.nwb or .mat) found in {session_name}[/yellow]")
+        return
+
+    # Show what will be replaced
+    print("[bold]Files to replace on RDS:[/bold]")
+    for file in processed_files:
+        print(f"  - {file.name}")
+
+    # Confirm replacement
+    if not auto_confirm:
+        response = input(f"\nReplace these {len(processed_files)} files on RDS? (y/n): ").strip().lower()
+        if "y" not in response:
+            print("[yellow]Operation cancelled.[/yellow]")
+            return
+
+    # Replace files
+    try:
+        replace_processed_files(session_name, processed_files)
+        print(f"[green]✓ Successfully replaced processed files for {session_name}[/green]")
+    except Exception as e:
+        print(f"[red]✗ Error replacing files: {str(e)}[/red]")
+
+    return
+
+
 @app.command()
 def up(
     session_or_animal_name: str = typer.Argument(
diff --git a/bnd/data_transfer.py b/bnd/data_transfer.py
index 0c8e72c..53dd76e 100644
--- a/bnd/data_transfer.py
+++ b/bnd/data_transfer.py
@@ -151,3 +151,59 @@ def download_animal(animal_name: str, file_extension: str, max_size_MB: float =
     _,session_list = list_session_datetime(remote_animal_path)
     for session_name in session_list:
         download_session(session_name, file_extension, max_size_MB, do_video)
+
+
+def _replace_file(local_file: Path, remote_file: Path):
+    """
+    Replace a file on the remote server, overwriting if it exists.
+
+    Parameters
+    ----------
+    local_file: Path
+        local path of the file to upload
+    remote_file: Path
+        remote path of the file to upload (will be overwritten if exists)
+    """
+    # Ensure the destination directory exists
+    remote_file.parent.mkdir(parents=True, exist_ok=True)
+
+    try:
+        shutil.copy2(local_file, remote_file)
+    except PermissionError:
+        shutil.copyfile(local_file, remote_file)
+
+    logger.info(f'Replaced "{local_file.name}" on server')
+
+
+def replace_processed_files(session_name: str, processed_files: list[Path]) -> None:
+    """
+    Replace specific processed files (.nwb and .mat) on the server.
+
+    This function will overwrite existing files, unlike the regular upload_session,
+    which refuses to overwrite existing files.
+
+    Parameters
+    ----------
+    session_name: str
+        Name of the session
+    processed_files: list[Path]
+        List of local processed files to replace on server
+    """
+    config = _load_config()
+
+    uploaded_count = 0
+
+    for local_file in processed_files:
+        if not local_file.is_file():
+            logger.warning(f"Skipping {local_file.name} - not a file")
+            continue
+
+        # Convert to remote path
+        remote_file = config.convert_to_remote(local_file)
+
+        # Replace the file (will overwrite if exists)
+        _replace_file(local_file=local_file, remote_file=remote_file)
+        uploaded_count += 1
+
+    logger.info(
+        f"Replacement complete. "
+        f"Replaced {uploaded_count} processed files for {session_name}."
+    )
diff --git a/bnd/pipeline/pyaldata.py b/bnd/pipeline/pyaldata.py
index b0ee052..53036ed 100644
--- a/bnd/pipeline/pyaldata.py
+++ b/bnd/pipeline/pyaldata.py
@@ -889,7 +889,10 @@ def save(self):
 
 
 def run_pyaldata_conversion(
-    session_path: Path, kilosort_flag: bool, custom_map: bool
+    session_path: Path,
+    kilosort_flag: bool,
+    custom_map: bool,
+    validate_spikes: bool = False
 ) -> None:
     """
     Main pyaldata conversion routine. Creates pyaldata file for a specific session. It will
@@ -900,6 +903,10 @@ def run_pyaldata_conversion(
     session_path : Path
     kilosort_flag : bool
         Whether to run kilosort or not. Defaults to True
+    custom_map : bool
+        Whether to use custom channel mapping
+    validate_spikes : bool
+        Whether to validate spike times for alignment issues
 
     Returns
     -------
@@ -912,6 +919,18 @@ def run_pyaldata_conversion(
     if isinstance(session_path, str):
         session_path = Path(session_path)
 
+    # Validate spike times if requested
+    if validate_spikes:
+        from .spike_validation import validate_and_clean_session
+        needs_reconversion = validate_and_clean_session(
+            session_path,
+            delete_if_no_negative=True
+        )
+        if needs_reconversion:
+            logger.info("Re-running conversion due to alignment issues...")
+            # Files have been deleted, now re-run the conversion
+            run_nwb_conversion(session_path, kilosort_flag, custom_map)
+
     # Get nwb file
     nwbfile_path = config.get_subdirectories_from_pattern(session_path, "*.nwb")
     if not nwbfile_path:
diff --git a/bnd/pipeline/spike_validation.py b/bnd/pipeline/spike_validation.py
new file mode 100644
index 0000000..4c3bb3e
--- /dev/null
+++ b/bnd/pipeline/spike_validation.py
@@ -0,0 +1,251 @@
+"""
+Module for validating spike times in NWB files to check alignment issues
+"""
+
+import os
+from pathlib import Path
+from typing import Tuple, Optional
+
+import numpy as np
+from pynwb import NWBHDF5IO
+from rich import print
+
+from ..logger import set_logging
+
+logger = set_logging(__name__)
+
+
+def check_negative_spike_times(nwb_path: Path) -> Tuple[bool, Optional[float]]:
+    """
+    Check if there are negative spike times in an NWB file.
+
+    Parameters
+    ----------
+    nwb_path : Path
+        Path to the NWB file to check
+
+    Returns
+    -------
+    Tuple[bool, Optional[float]]
+        (has_negative_spikes, min_spike_time)
+        - has_negative_spikes: True if negative spike times exist
+        - min_spike_time: The minimum spike time found (None if no spikes)
+    """
+    try:
+        with NWBHDF5IO(nwb_path, mode="r") as io:
+            nwbfile = io.read()
+
+            # Check if ecephys processing module exists
+            if not hasattr(nwbfile, "processing"):
+                logger.warning(f"No processing module in {nwb_path.name}")
+                return False, None
+
+            if "ecephys" not in nwbfile.processing:
+                logger.warning(f"No ecephys data in {nwb_path.name}")
+                return False, None
+
+            # Check all probe units for negative spike times
+            ecephys = nwbfile.processing["ecephys"].data_interfaces
+            min_spike_time = float('inf')
+            has_spikes = False
+
+            for probe_name, probe_units in ecephys.items():
+                if hasattr(probe_units, 'spike_times'):
+                    spike_times = probe_units.spike_times[:]
+                    if len(spike_times) > 0:
+                        has_spikes = True
+                        probe_min = np.min(spike_times)
+                        min_spike_time = min(min_spike_time, probe_min)
+
+                        if probe_min < 0:
+                            logger.info(
+                                f"Found negative spike times in {probe_name}: "
+                                f"min = {probe_min:.3f}s"
+                            )
+
+            if not has_spikes:
+                logger.warning(f"No spike times found in {nwb_path.name}")
+                return False, None
+
+            has_negative = min_spike_time < 0
+            return has_negative, min_spike_time
+
+    except Exception as e:
+        logger.error(f"Error reading NWB file {nwb_path.name}: {str(e)}")
+        raise
+
+
+def validate_and_clean_session(
+    session_path: Path,
+    force_rerun: bool = False,
+    delete_if_no_negative: bool = True
+) -> bool:
+    """
+    Validate spike times in a session and optionally clean/re-run if needed.
+
+    Parameters
+    ----------
+    session_path : Path
+        Path to the session directory
+    force_rerun : bool
+        Force re-run of conversion even if negative spikes exist
+    delete_if_no_negative : bool
+        Delete NWB and pyaldata files if no negative spikes found
+
+    Returns
+    -------
+    bool
+        True if session needs re-conversion, False otherwise
+    """
+    # Check for NWB file
+    nwb_path = session_path / f"{session_path.name}.nwb"
+    if not nwb_path.exists():
+        logger.warning(f"No NWB file found for session {session_path.name}")
+        return True  # Needs conversion
+
+    # Check for negative spike times
+    has_negative, min_spike_time = check_negative_spike_times(nwb_path)
+
+    if has_negative:
+        logger.info(
+            f"✓ Session {session_path.name} has negative spike times "
+            f"(min = {min_spike_time:.3f}s). Alignment looks correct."
+        )
+        return False  # No re-conversion needed
+
+    elif min_spike_time is None:
+        logger.warning(
+            f"Session {session_path.name} has no spike data to validate"
+        )
+        return False  # Can't validate, don't delete
+
+    else:
+        logger.warning(
+            f"⚠ Session {session_path.name} has NO negative spike times "
+            f"(min = {min_spike_time:.3f}s). This suggests alignment issues."
+        )
+
+        if delete_if_no_negative or force_rerun:
+            print("\n[red]⚠ IMPORTANT: Re-conversion requires raw ephys data[/red]")
+            print("The alignment issue was in NWB conversion, which needs access to")
+            print("raw recording files (.meta files) to select the correct recording.")
+            print()
+            print("[yellow]Before proceeding, ensure you have downloaded raw data:[/yellow]")
+            print(f"  bnd dl {session_path.name}")
+            print()
+
+            # Check if raw ephys data exists locally
+            has_raw_ephys = any(session_path.rglob("*.meta")) or any(session_path.rglob("*_g?"))
+
+            if not has_raw_ephys:
+                print("[red]✗ Raw ephys data not found locally[/red]")
+                print(f"Please download first: [bold]bnd dl {session_path.name}[/bold]")
+                print(f"Then re-run: [bold]bnd to-pyal {session_path.name} -v -K[/bold]")
+                return False
+            else:
+                print("[green]✓ Raw ephys data found locally[/green]")
+
+            # Delete existing files
+            files_to_delete = []
+
+            # NWB file
+            if nwb_path.exists():
+                files_to_delete.append(nwb_path)
+
+            # PyAlData files (could be partitioned)
+            mat_files = list(session_path.glob("*_pyaldata*.mat"))
+            files_to_delete.extend(mat_files)
+
+            if files_to_delete:
+                print("\n[yellow]Files to be deleted:[/yellow]")
+                for file in files_to_delete:
+                    print(f"  - {file.name}")
+
+                # Confirm deletion
+                response = input(
+                    "\nDelete these files and re-run conversion? (y/n): "
+                ).strip().lower()
+
+                if "y" in response:
+                    for file in files_to_delete:
+                        os.remove(file)
+                        logger.info(f"Deleted {file.name}")
+                    print("\n[green]Files deleted. Re-conversion will proceed...[/green]")
+                    return True  # Needs re-conversion
+                else:
+                    logger.info("Files not deleted. Keeping current data.")
+                    print("\n[yellow]To manually re-convert later:[/yellow]")
+                    print(f"  bnd to-pyal {session_path.name} -K")
+                    return False
+
+    return False
+
+
+def batch_validate_sessions(
+    sessions: list[Path],
+    delete_if_no_negative: bool = True,
+    summary_only: bool = False
+) -> dict:
+    """
+    Validate multiple sessions and provide a summary.
+
+    Parameters
+    ----------
+    sessions : list[Path]
+        List of session paths to validate
+    delete_if_no_negative : bool
+        Delete files if no negative spikes found
+    summary_only : bool
+        Only show summary without making changes
+
+    Returns
+    -------
+    dict
+        Summary of validation results
+    """
+    results = {
+        "valid": [],
+        "invalid": [],
+        "no_data": [],
+        "missing_nwb": []
+    }
+
+    for session_path in sessions:
+        nwb_path = session_path / f"{session_path.name}.nwb"
+
+        if not nwb_path.exists():
+            results["missing_nwb"].append(session_path.name)
+            continue
+
+        try:
+            has_negative, min_spike_time = check_negative_spike_times(nwb_path)
+
+            if min_spike_time is None:
+                results["no_data"].append(session_path.name)
+            elif has_negative:
+                results["valid"].append((session_path.name, min_spike_time))
+            else:
+                results["invalid"].append((session_path.name, min_spike_time))
+
+            # Only attempt cleanup for sessions that actually failed validation;
+            # valid and no-data sessions don't need a second pass over the file
+            if not summary_only and delete_if_no_negative and not has_negative and min_spike_time is not None:
+                validate_and_clean_session(
+                    session_path,
+                    delete_if_no_negative=True
+                )
+
+        except Exception as e:
+            logger.error(f"Error validating {session_path.name}: {str(e)}")
+
+    # Print summary
+    print("\n[bold]Validation Summary:[/bold]")
+    print(f"✓ Valid (with negative spikes): {len(results['valid'])}")
+    print(f"✗ Invalid (no negative spikes): {len(results['invalid'])}")
+    print(f"○ No spike data: {len(results['no_data'])}")
+    print(f"□ Missing NWB: {len(results['missing_nwb'])}")
+
+    if results["invalid"] and not summary_only:
+        print("\n[yellow]Sessions needing re-conversion:[/yellow]")
+        for session_name, min_time in results["invalid"]:
+            print(f"  - {session_name} (min spike time: {min_time:.3f}s)")
+
+    return results
\ No newline at end of file
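Reviewer note (not part of the patch): a minimal sketch of how the new helpers in bnd/pipeline/spike_validation.py can be driven outside the CLI, e.g. from a notebook, under the assumption that the package is importable and the session directory below is a placeholder for a real local session. Only functions and signatures added in this diff are used.

    from pathlib import Path

    from bnd.pipeline.spike_validation import (
        batch_validate_sessions,
        check_negative_spike_times,
    )

    # Placeholder session directory; adjust to your local data layout
    session = Path("/data/raw/M037/M037_2024_01_01_10_00")

    # Negative spike times indicate correct pycontrol-ephys alignment
    has_negative, min_t = check_negative_spike_times(session / f"{session.name}.nwb")
    print(f"aligned={has_negative}, min spike time={min_t}")

    # Report across sessions without prompting for deletions
    batch_validate_sessions([session], delete_if_no_negative=False, summary_only=True)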