Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 181 additions & 3 deletions bnd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def to_pyal(
"--custom-map/--default-map",
help="Run conversion with a custom map (-c) or the not (-C)",
),
validate: bool = typer.Option(
False,
"-v/-V",
"--validate/--no-validate",
help="Validate spike times alignment (-v) or skip validation (-V)",
),
) -> None:
"""
Convert session data into a pyaldata dataframe and saves it as a .mat
Expand All @@ -57,8 +63,9 @@ def to_pyal(

\b
Basic usage:
`bnd to-pyal M037_2024_01_01_10_00 # Kilosorts data and converts to pyaldata
`bnd to-pyal M037_2024_01_01_10_00 -c # Uses custom mapping
`bnd to-pyal M037_2024_01_01_10_00` # Kilosorts data and converts to pyaldata
`bnd to-pyal M037_2024_01_01_10_00 -c` # Uses custom mapping
`bnd to-pyal M037_2024_01_01_10_00 -v` # Validates spike alignment
"""
_check_processing_dependencies()
from .pipeline.pyaldata import run_pyaldata_conversion
Expand All @@ -71,7 +78,7 @@ def to_pyal(
_check_session_directory(session_path)

# Run pipeline
run_pyaldata_conversion(session_path, kilosort_flag, custom_map)
run_pyaldata_conversion(session_path, kilosort_flag, custom_map, validate)

return

Expand All @@ -91,6 +98,12 @@ def to_nwb(
"--custom-map/--default-map",
help="Run conversion with a custom map (-c) or the not (-C)",
),
validate: bool = typer.Option(
False,
"-v/-V",
"--validate/--no-validate",
help="Validate spike times alignment after conversion (-v) or skip validation (-V)",
),
) -> None:
"""
Convert session data into a nwb file and saves it as a .nwb
Expand All @@ -102,6 +115,7 @@ def to_nwb(
Basic usage:
`bnd to-nwb M037_2024_01_01_10_00`
`bnd to-nwb M037_2024_01_01_10_00 -c` # Use custom channel mapping
`bnd to-nwb M037_2024_01_01_10_00 -v` # Validate spike alignment after conversion
"""
# TODO: Add channel map argument: no-map, default-map, custom-map
# _check_processing_dependencies()
Expand All @@ -115,6 +129,27 @@ def to_nwb(

# Run pipeline
run_nwb_conversion(session_path, kilosort_flag, custom_map)

# Validate spike times after conversion if requested
if validate:
from .pipeline.spike_validation import check_negative_spike_times
nwb_path = session_path / f"{session_name}.nwb"

if nwb_path.exists():
has_negative, min_spike_time = check_negative_spike_times(nwb_path)

if has_negative:
print(f"\n[green]✓ Validation passed: Found negative spike times (min = {min_spike_time:.3f}s)[/green]")
print(f"Alignment between pycontrol and ephys data looks correct.")
elif min_spike_time is None:
print(f"\n[yellow]⚠ No spike data found for validation[/yellow]")
else:
print(f"\n[red]✗ Validation failed: No negative spike times found (min = {min_spike_time:.3f}s)[/red]")
print(f"This suggests alignment issues between pycontrol and ephys data.")
print(f"Consider checking your raw data and re-running conversion.")
else:
print(f"\n[red]✗ NWB file not found for validation[/red]")

return


Expand Down Expand Up @@ -142,9 +177,152 @@ def ksort(session_name: str = typer.Argument(help="Session name to kilosort")) -
return


@app.command()
def validate_spikes(
    session_name: str = typer.Argument(..., help="Session name to validate"),
    delete_if_invalid: bool = typer.Option(
        True,
        "-d/-D",
        "--delete/--no-delete",
        help="Delete files if no negative spikes found (-d) or keep them (-D)",
    ),
    batch: bool = typer.Option(
        False,
        "-b/-B",
        "--batch/--single",
        help="Validate all sessions for this animal (-b) or single session (-B)",
    ),
) -> None:
    """
    Validate spike times alignment in NWB files.

    Checks for negative spike times which indicate correct pycontrol-ephys alignment.
    If no negative spikes are found, optionally deletes NWB and pyaldata files.

    \b
    Basic usage:
        `bnd validate-spikes M037_2024_01_01_10_00` # Validate single session
        `bnd validate-spikes M037_2024_01_01_10_00 -D` # Check only, don't delete
        `bnd validate-spikes M037 -b` # Validate all M037 sessions
    """
    _check_processing_dependencies()
    from .pipeline.spike_validation import validate_and_clean_session, batch_validate_sessions

    config = _load_config()

    if batch:
        # Animal name is the first underscore-separated token of the session name
        animal_name = session_name.split("_")[0]
        # Get animal directory and list sessions
        animal_path = config.LOCAL_PATH / "raw" / animal_name

        if not animal_path.exists():
            print(f"[red]Animal directory not found: {animal_path}[/red]")
            return

        try:
            # Get all sessions for this animal
            _, session_names = list_session_datetime(animal_path)
            session_paths = [config.get_local_session_path(name) for name in session_names]

            if not session_paths:
                print(f"[red]No sessions found for animal {animal_name}[/red]")
                return

            print(f"[bold]Validating {len(session_paths)} sessions for {animal_name}[/bold]")
            batch_validate_sessions(session_paths, delete_if_no_negative=delete_if_invalid)

        except Exception as e:
            print(f"[red]Error getting sessions for {animal_name}: {str(e)}[/red]")
            return
    else:
        # Single session validation
        session_path = config.get_local_session_path(session_name)
        _check_session_directory(session_path)

        needs_reconversion = validate_and_clean_session(
            session_path,
            delete_if_no_negative=delete_if_invalid
        )

        if needs_reconversion and delete_if_invalid:
            print(f"\n[yellow]Session {session_name} needs re-conversion.[/yellow]")
            # BUG FIX: was a plain string; without the f-prefix the literal
            # text "{session_name}" was printed instead of the session name.
            print(f"Run: `bnd to-pyal {session_name}` to re-convert")

    return


# ================================== Data Transfer ========================================


@app.command()
def replace_processed(
    session_name: str = typer.Argument(..., help="Session name to replace processed files for"),
    auto_confirm: bool = typer.Option(
        False,
        "-y/-Y",
        "--yes/--no-yes",
        help="Auto-confirm replacement (-y) or prompt for confirmation (-Y)",
    ),
) -> None:
    """
    Replace processed files (.nwb and .mat) on the server.

    This command specifically replaces NWB and PyAlData files that may have been
    corrected due to alignment issues. It will overwrite existing files on RDS.

    \b
    Basic usage:
        `bnd replace-processed M037_2024_01_01_10_00` # Replace with confirmation
        `bnd replace-processed M037_2024_01_01_10_00 -y` # Auto-confirm
    """
    _check_processing_dependencies()
    from .data_transfer import replace_processed_files

    config = _load_config()
    session_path = config.get_local_session_path(session_name)

    # Check session directory
    _check_session_directory(session_path)

    # Find processed files to replace
    processed_files = []

    # NWB files
    nwb_files = list(session_path.glob("*.nwb"))
    processed_files.extend(nwb_files)

    # PyAlData files (could be partitioned)
    mat_files = list(session_path.glob("*_pyaldata*.mat"))
    processed_files.extend(mat_files)

    if not processed_files:
        print(f"[yellow]No processed files (.nwb or .mat) found in {session_name}[/yellow]")
        return

    # Show what will be replaced
    print("[bold]Files to replace on RDS:[/bold]")
    for file in processed_files:
        print(f"  - {file.name}")

    # Confirm replacement
    if not auto_confirm:
        response = input(f"\nReplace these {len(processed_files)} files on RDS? (y/n): ").strip().lower()
        # BUG FIX: was `if "y" not in response`, which treated ANY input
        # containing the letter "y" (e.g. "nay", "maybe") as confirmation.
        # Only explicit affirmative answers proceed now.
        if response not in ("y", "yes"):
            print("[yellow]Operation cancelled.[/yellow]")
            return

    # Replace files
    try:
        replace_processed_files(session_name, processed_files)
        print(f"[green]✓ Successfully replaced processed files for {session_name}[/green]")
    except Exception as e:
        print(f"[red]✗ Error replacing files: {str(e)}[/red]")

    return


@app.command()
def up(
session_or_animal_name: str = typer.Argument(
Expand Down
56 changes: 56 additions & 0 deletions bnd/data_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,59 @@ def download_animal(animal_name: str, file_extension: str, max_size_MB: float =
_,session_list = list_session_datetime(remote_animal_path)
for session_name in session_list:
download_session(session_name, file_extension, max_size_MB, do_video)


def _replace_file(local_file: Path, remote_file: Path):
    """
    Copy a local file to the remote server, overwriting any existing copy.

    Parameters
    ----------
    local_file: Path
        local path of the file to upload
    remote_file: Path
        remote path of the file to upload (will be overwritten if exists)
    """
    # Make sure the destination directory exists before copying
    destination_dir = remote_file.parent
    destination_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Preferred: copy content and metadata (timestamps, permissions)
        shutil.copy2(local_file, remote_file)
    except PermissionError:
        # Fall back to a content-only copy when metadata cannot be written
        shutil.copyfile(local_file, remote_file)

    logger.info(f'Replaced "{local_file.name}" on server')


def replace_processed_files(session_name: str, processed_files: list[Path]) -> None:
    """
    Replace specific processed files (.nwb and .mat) on the server.

    Unlike the regular upload_session, which refuses to overwrite existing
    files, this function overwrites them in place.

    Parameters
    ----------
    session_name: str
        Name of the session
    processed_files: list[Path]
        List of local processed files to replace on server
    """
    config = _load_config()

    replaced = 0
    for candidate in processed_files:
        # Only regular files can be uploaded; skip anything else
        if not candidate.is_file():
            logger.warning(f"Skipping {candidate.name} - not a file")
            continue

        # Map the local path to its counterpart on the server and overwrite it
        _replace_file(local_file=candidate, remote_file=config.convert_to_remote(candidate))
        replaced += 1

    logger.info(f"Replacement complete. Replaced {replaced} processed files for {session_name}.")
21 changes: 20 additions & 1 deletion bnd/pipeline/pyaldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,10 @@ def save(self):


def run_pyaldata_conversion(
session_path: Path, kilosort_flag: bool, custom_map: bool
session_path: Path,
kilosort_flag: bool,
custom_map: bool,
validate_spikes: bool = False
) -> None:
"""
Main pyaldata conversion routine. Creates pyaldata file for a specific session. It will
Expand All @@ -900,6 +903,10 @@ def run_pyaldata_conversion(
session_path : Path
kilosort_flag : bool
Whether to run kilosort or not. Defaults to True
custom_map : bool
Whether to use custom channel mapping
validate_spikes : bool
Whether to validate spike times for alignment issues

Returns
-------
Expand All @@ -912,6 +919,18 @@ def run_pyaldata_conversion(
if isinstance(session_path, str):
session_path = Path(session_path)

# Validate spike times if requested
if validate_spikes:
from .spike_validation import validate_and_clean_session
needs_reconversion = validate_and_clean_session(
session_path,
delete_if_no_negative=True
)
if needs_reconversion:
logger.info("Re-running conversion due to alignment issues...")
# Files have been deleted, now re-run the conversion
run_nwb_conversion(session_path, kilosort_flag, custom_map)

# Get nwb file
nwbfile_path = config.get_subdirectories_from_pattern(session_path, "*.nwb")
if not nwbfile_path:
Expand Down
Loading
Loading