From d753113e5bd0655421c561f4d9582d56035c18d2 Mon Sep 17 00:00:00 2001
From: Ethan Blackwood
Date: Sat, 15 Jun 2024 13:34:06 -0400
Subject: [PATCH] Add more general option argument to _run_slurm

---
 mesmerize_core/batch_utils.py              | 15 +++++------
 mesmerize_core/caiman_extensions/common.py | 31 +++++++++++++---------
 2 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/mesmerize_core/batch_utils.py b/mesmerize_core/batch_utils.py
index 440973c..7c01f55 100644
--- a/mesmerize_core/batch_utils.py
+++ b/mesmerize_core/batch_utils.py
@@ -249,13 +249,10 @@ def get_full_raw_data_path(path: Union[Path, str]) -> Path:
     return path
 
 
-class OverwriteError(IndexError):
-    """
-    Error thrown when trying to write to an existing batch file, but there is a risk
-    of overwriting existing data.
-    Note this is a subclass of IndexError to avoid a breaking change, because
-    IndexError was previously thrown from df.caiman.save_to_disk.
-    """
+class PreventOverwriteError(IndexError):
+    """
+    Error thrown when trying to write to an existing batch file with a potential risk of removing existing rows.
+    """
     pass
 
 
@@ -283,7 +280,7 @@ def save_results_safely(batch_path: Union[Path, str], uuid, results: dict, runti
     """
     Try to load the given batch and save results to the given item
     Uses a file lock to ensure that no other process is writing to the same batch using this function,
-    which gives up after lock_timeout seconds (set to -1 to never give up)
+    which gives up after lock_timeout seconds.
     """
     try:
         with open_batch_for_safe_writing(batch_path) as df:
@@ -301,7 +298,7 @@
         msg = f"Batch file could not be written to"
         if isinstance(e, Timeout):
             msg += f" (file locked for {BatchLock.TIMEOUT} seconds)"
-        elif isinstance(e, OverwriteError):
+        elif isinstance(e, PreventOverwriteError):
             msg += f" (items would be overwritten, even though file was locked)"
 
     if results["success"]:
diff --git a/mesmerize_core/caiman_extensions/common.py b/mesmerize_core/caiman_extensions/common.py
index ae63500..db65779 100644
--- a/mesmerize_core/caiman_extensions/common.py
+++ b/mesmerize_core/caiman_extensions/common.py
@@ -9,6 +9,7 @@
 from datetime import datetime
 import time
 from copy import deepcopy
+import shlex
 
 import numpy as np
 import pandas as pd
@@ -22,7 +23,7 @@
     get_parent_raw_data_path,
     load_batch,
     open_batch_for_safe_writing,
-    OverwriteError
+    PreventOverwriteError
 )
 from ..utils import validate_path, IS_WINDOWS, make_runfile, warning_experimental
 from .cnmf import cnmf_cache
@@ -141,7 +142,7 @@ def save_to_disk(self, max_index_diff: int = 0):
         with open_batch_for_safe_writing(path) as disk_df:
             # check that max_index_diff is not exceeded
             if abs(disk_df.index.size - self._df.index.size) > max_index_diff:
-                raise OverwriteError(
+                raise PreventOverwriteError(
                     f"The number of rows in the DataFrame on disk differs more "
                     f"than has been allowed by the `max_index_diff` kwarg which "
                     f"is set to <{max_index_diff}>. This is to prevent overwriting "
@@ -466,12 +467,18 @@ def _run_slurm(
         self,
         runfile_path: str,
         wait: bool,
-        partition: Optional[Union[str, list[str]]] = None,
+        sbatch_opts: str = '',
         **kwargs
     ):
         """
         Run on a cluster using SLURM. Configurable options (to pass to run):
-        - partition: if given, tells SLRUM to run the job on the given partition(s).
+        - sbatch_opts: A single string containing additional options for sbatch.
+          The following options are configured here, but can be overridden:
+            --job-name
+            --cpus-per-task (only controls number of CPUs allocated to the job; the number used for
+              parallel processing is controlled by os.environ['MESMERIZE_N_PROCESSES'])
+          The following options should NOT be overridden:
+            --ntasks, --output, --wait
         """
 
         # this needs to match what's in the runfile
@@ -487,15 +494,15 @@
         output_path = output_dir / f'{uuid}.log'
 
         # --wait means that the lifetime of the created process corresponds to the lifetime of the job
-        submission_opts = (f'--job-name={self._series["algo"]}-{uuid[:8]} --ntasks=1 ' +
-                           f'--cpus-per-task={n_procs} --output={output_path} --wait')
+        submission_opts = [
+            f'--job-name={self._series["algo"]}-{uuid[:8]}',
+            '--ntasks=1',
+            f'--cpus-per-task={n_procs}',
+            f'--output={output_path}',
+            '--wait'
+        ] + shlex.split(sbatch_opts)
 
-        if partition is not None:
-            if isinstance(partition, str):
-                partition = [partition]
-            submission_opts += f' --partition={",".join(partition)}'
-
-        self.process = Popen(['sbatch', *submission_opts.split(" "), runfile_path])
+        self.process = Popen(['sbatch', *submission_opts, runfile_path])
 
         if wait:
             self.process.wait()
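
---
Usage sketch (not part of the diff): a minimal example of how the new argument
could be passed through `run` once this patch is applied, assuming a batch
DataFrame `df` loaded with mesmerize_core. The backend name "slurm" and the
partition/time values are illustrative assumptions, not defaults:

    # sbatch_opts is tokenized with shlex.split, so quote any value that
    # contains spaces; options given here are appended after the defaults,
    # letting sbatch's last-one-wins behavior override e.g. --cpus-per-task
    df.iloc[-1].caiman.run(
        backend="slurm",                                 # assumed backend name
        sbatch_opts="--partition=short --time=01:00:00", # illustrative values
    )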