Add more general option argument to _run_slurm
ethanbb committed Jun 15, 2024
1 parent a39ddbe commit d753113
Showing 2 changed files with 25 additions and 21 deletions.
15 changes: 6 additions & 9 deletions mesmerize_core/batch_utils.py
@@ -249,13 +249,10 @@ def get_full_raw_data_path(path: Union[Path, str]) -> Path:
     return path
 
 
-class OverwriteError(IndexError):
-    """
-    Error thrown when trying to write to an existing batch file, but there is a risk
-    of overwriting existing data.
-    Note this is a subclass of IndexError to avoid a breaking change, because
-    IndexError was previously thrown from df.caiman.save_to_disk.
-    """
+class PreventOverwriteError(IndexError):
+    """
+    Error thrown when trying to write to an existing batch file with a potential risk of removing existing rows.
+    """
     pass


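For context, a minimal sketch of why the rename stays backward compatible (the import path is assumed from the file above; the error message is made up):

# Minimal sketch, assuming PreventOverwriteError is importable from
# mesmerize_core.batch_utils as defined above. Because it subclasses
# IndexError, callers that previously caught IndexError keep working.
from mesmerize_core.batch_utils import PreventOverwriteError

try:
    raise PreventOverwriteError("refusing to overwrite existing batch rows")
except IndexError as e:  # the new class is still caught here
    print(f"caught: {e}")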
@@ -283,7 +280,7 @@ def save_results_safely(batch_path: Union[Path, str], uuid, results: dict, runti
     """
     Try to load the given batch and save results to the given item
     Uses a file lock to ensure that no other process is writing to the same batch using this function,
-    which gives up after lock_timeout seconds (set to -1 to never give up)
+    which gives up after lock_timeout seconds.
     """
     try:
         with open_batch_for_safe_writing(batch_path) as df:
@@ -301,7 +298,7 @@ def save_results_safely(batch_path: Union[Path, str], uuid, results: dict, runti
         msg = f"Batch file could not be written to"
         if isinstance(e, Timeout):
             msg += f" (file locked for {BatchLock.TIMEOUT} seconds)"
-        elif isinstance(e, OverwriteError):
+        elif isinstance(e, PreventOverwriteError):
             msg += f" (items would be overwritten, even though file was locked)"
 
         if results["success"]:
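For context, a minimal sketch of the file-lock pattern this function relies on, assuming the filelock package (the source of the Timeout exception caught above); the path and timeout here are hypothetical:

# Minimal sketch of the locking pattern; path and timeout are hypothetical.
from filelock import FileLock, Timeout

lock = FileLock("batch.pickle.lock")
try:
    with lock.acquire(timeout=30):
        ...  # safe to read, modify, and rewrite the batch file here
except Timeout:
    print("Batch file could not be written to (file locked for 30 seconds)")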
31 changes: 19 additions & 12 deletions mesmerize_core/caiman_extensions/common.py
@@ -9,6 +9,7 @@
 from datetime import datetime
 import time
 from copy import deepcopy
+import shlex
 
 import numpy as np
 import pandas as pd
@@ -22,7 +23,7 @@
     get_parent_raw_data_path,
     load_batch,
     open_batch_for_safe_writing,
-    OverwriteError
+    PreventOverwriteError
 )
 from ..utils import validate_path, IS_WINDOWS, make_runfile, warning_experimental
 from .cnmf import cnmf_cache
@@ -141,7 +142,7 @@ def save_to_disk(self, max_index_diff: int = 0):
         with open_batch_for_safe_writing(path) as disk_df:
             # check that max_index_diff is not exceeded
             if abs(disk_df.index.size - self._df.index.size) > max_index_diff:
-                raise OverwriteError(
+                raise PreventOverwriteError(
                     f"The number of rows in the DataFrame on disk differs more "
                     f"than has been allowed by the `max_index_diff` kwarg which "
                     f"is set to <{max_index_diff}>. This is to prevent overwriting "
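As an illustration, a hedged usage sketch of this guard (the batch path and top-level import are assumptions; df.caiman.save_to_disk is the accessor named in the removed docstring earlier in this commit):

# Hypothetical usage: tolerate at most a one-row difference between the
# in-memory DataFrame and the batch file on disk; a larger difference now
# raises PreventOverwriteError instead of silently dropping rows.
from mesmerize_core import load_batch  # assumed import path

df = load_batch("/path/to/batch.pickle")  # hypothetical path
df.caiman.save_to_disk(max_index_diff=1)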
@@ -466,12 +467,18 @@ def _run_slurm(
         self,
         runfile_path: str,
         wait: bool,
-        partition: Optional[Union[str, list[str]]] = None,
+        sbatch_opts: str = '',
         **kwargs
     ):
         """
         Run on a cluster using SLURM. Configurable options (to pass to run):
-        - partition: if given, tells SLRUM to run the job on the given partition(s).
+        - sbatch_opts: A single string containing additional options for sbatch.
+          The following options are configured here, but can be overridden:
+            --job-name
+            --cpus-per-task (only controls number of CPUs allocated to the job; the number used for
+                             parallel processing is controlled by os.environ['MESMERIZE_N_PROCESSES'])
+          The following options should NOT be overridden:
+            --ntasks, --output, --wait
         """
 
         # this needs to match what's in the runfile
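Since the docstring says these options are passed through run, a hedged usage sketch (the backend name and flag values are assumptions, and df is a loaded batch DataFrame):

# Hypothetical usage sketch: extra sbatch flags are supplied as one string;
# run() forwards its kwargs to the backend, here _run_slurm.
df.iloc[-1].caiman.run(
    backend="slurm",                                # assumed backend name
    sbatch_opts="--partition=gpu --time=02:00:00",  # hypothetical flags
)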
@@ -487,15 +494,15 @@
         output_path = output_dir / f'{uuid}.log'
 
         # --wait means that the lifetme of the created process corresponds to the lifetime of the job
-        submission_opts = (f'--job-name={self._series["algo"]}-{uuid[:8]} --ntasks=1 ' +
-                           f'--cpus-per-task={n_procs} --output={output_path} --wait')
+        submission_opts = [
+            f'--job-name={self._series["algo"]}-{uuid[:8]}',
+            '--ntasks=1',
+            f'--cpus-per-task={n_procs}',
+            f'--output={output_path}',
+            '--wait'
+        ] + shlex.split(sbatch_opts)
 
-        if partition is not None:
-            if isinstance(partition, str):
-                partition = [partition]
-            submission_opts += f' --partition={",".join(partition)}'
-
-        self.process = Popen(['sbatch', *submission_opts.split(" "), runfile_path])
+        self.process = Popen(['sbatch', *submission_opts, runfile_path])
         if wait:
             self.process.wait()

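For context, a short sketch of what shlex.split buys over the old submission_opts.split(" "): it honors shell-style quoting, so an option value containing spaces stays a single argv entry:

import shlex

opts = '--time=01:00:00 --comment="nightly batch run"'
print(opts.split(" "))
# ['--time=01:00:00', '--comment="nightly', 'batch', 'run"']
print(shlex.split(opts))
# ['--time=01:00:00', '--comment=nightly batch run']

With the default sbatch_opts='', shlex.split returns an empty list, so no extra flags are appended.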
