Skip to content

Commit

Permalink
feat: add support for multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
David Wallace committed Mar 10, 2024
1 parent 656c8f7 commit 0e003e4
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 1 deletion.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"pydantic_numpy>=4.1",
"loguru>=0.7",
"typer[all]",
"mpire[dill]~=2.10.0",
]

[project.optional-dependencies]
Expand Down
54 changes: 54 additions & 0 deletions src/raman_fitting/delegating/run_fit_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Dict, List

from loguru import logger
from mpire import WorkerPool

from raman_fitting.models.fit_models import SpectrumFitModel


def run_fit_multi(**kwargs) -> SpectrumFitModel:
    """Fit a single lmfit model to a single spectrum (worker-pool task).

    The function takes plain keyword arguments so it can be called with an
    unpacked dict, which is how ``run_fit_multiprocessing`` feeds it.

    Keyword Args:
        spectrum: Mapping providing ``"ramanshift"`` (x values) and
            ``"intensity"`` (y values).
        model: Mapping providing ``"lmfit_model"`` (the object whose
            ``make_params``/``fit`` methods are called) and ``"name"``
            (used in the log message).
        region: Region label; only used in the log message.
        method: Optional fitting method forwarded to ``lmfit_model.fit``.
            Defaults to ``"leastsq"`` when missing or empty.

    Returns:
        The object returned by ``lmfit_model.fit(...)``.
        NOTE(review): the ``SpectrumFitModel`` annotation is kept for
        interface stability, but the value returned here is the raw fit
        output, not a ``SpectrumFitModel`` -- confirm and adjust the
        annotation together with the file imports.

    Raises:
        KeyError: If ``spectrum``, ``model`` or ``region`` is not supplied.
    """
    # Function-scope import keeps this block self-contained.
    import time

    # TODO: optionally persist the result, see
    # https://lmfit.github.io/lmfit-py/model.html#saving-and-loading-modelresults
    spectrum = kwargs.pop("spectrum")
    model = kwargs.pop("model")
    lmfit_model = model["lmfit_model"]
    region = kwargs.pop("region")

    # Honour a caller-supplied method; previously it was silently dropped
    # because the dict was only filled when "method" was absent.
    lmfit_kwargs = {"method": kwargs.get("method") or "leastsq"}

    init_params = lmfit_model.make_params()
    x, y = spectrum["ramanshift"], spectrum["intensity"]

    # perf_counter is monotonic, so the difference can never be negative.
    start_time = time.perf_counter()
    out = lmfit_model.fit(y, init_params, x=x, **lmfit_kwargs)
    elapsed_time = time.perf_counter() - start_time

    logger.debug(
        f"Fit with model {model['name']} on {region} success: {out.success} in {elapsed_time:.2f}s."
    )
    return out


def run_fit_multiprocessing(
    spec_fits: List[SpectrumFitModel],
    n_jobs: int = 4,
) -> Dict[str, SpectrumFitModel]:
    """Run the fits of ``spec_fits`` in a pool of worker processes.

    Each item is serialised with ``model_dump()`` and handed to
    ``run_fit_multi``; the fit output is then stored on the original item
    as ``fit_result``.

    Args:
        spec_fits: The fit jobs to run. Every item is mutated in place
            (its ``fit_result`` attribute is set).
        n_jobs: Number of worker processes. Defaults to 4, the value that
            was previously hard-coded.

    Returns:
        Mapping of ``spec_fit.model.name`` to the (now fitted) item. If two
        items share the same model name, the later one wins.
    """
    if not spec_fits:
        # Nothing to fit: avoid starting worker processes for no work.
        return {}

    # Presumably the dumped dicts are unpacked as keyword arguments by the
    # pool, which is why run_fit_multi accepts **kwargs -- verify against
    # the mpire documentation.
    spec_fits_dumps = [spec_fit.model_dump() for spec_fit in spec_fits]

    with WorkerPool(n_jobs=n_jobs, use_dill=True) as pool:
        results = pool.map(
            run_fit_multi, spec_fits_dumps, progress_bar=True, progress_bar_style="rich"
        )

    if len(results) != len(spec_fits):
        logger.warning(
            f"Expected {len(spec_fits)} fit results but received {len(results)}."
        )

    # ``map`` returns results in input order, so pair them positionally
    # instead of searching by model name (which was quadratic and silently
    # dropped results whenever the name was missing or not unique).
    fit_model_results = {}
    for spec_fit, result in zip(spec_fits, results):
        spec_fit.fit_result = result
        fit_model_results[spec_fit.model.name] = spec_fit
    return fit_model_results
65 changes: 65 additions & 0 deletions src/raman_fitting/delegating/run_fit_spectrum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import List, Dict

from raman_fitting.delegating.run_fit_multi import run_fit_multiprocessing
from raman_fitting.models.spectrum import SpectrumData
from raman_fitting.types import LMFitModelCollection
from raman_fitting.delegating.models import AggregatedSampleSpectrumFitResult
from raman_fitting.delegating.pre_processing import (
prepare_aggregated_spectrum_from_files,
)
from raman_fitting.imports.models import RamanFileInfo
from raman_fitting.models.deconvolution.spectrum_regions import RegionNames
from raman_fitting.models.fit_models import SpectrumFitModel

from loguru import logger


def run_fit_over_selected_models(
    raman_files: List[RamanFileInfo],
    models: LMFitModelCollection,
    use_multiprocessing: bool = False,
) -> Dict[RegionNames, AggregatedSampleSpectrumFitResult]:
    """Fit every model group in ``models`` against the aggregated spectrum.

    For each ``(region_name, model group)`` pair an aggregated spectrum is
    prepared from ``raman_files``; regions for which that yields ``None``
    are skipped. The remaining regions are fitted either through
    ``run_fit_multiprocessing`` or ``run_fit_loop``, depending on
    ``use_multiprocessing``.

    Returns:
        Mapping of region name to its ``AggregatedSampleSpectrumFitResult``.
    """
    fits_per_region = {}
    for region_name, region_models in models.items():
        aggregate = prepare_aggregated_spectrum_from_files(region_name, raman_files)
        if aggregate is None:
            # No aggregated spectrum for this region: nothing to fit.
            continue

        pending_fits = prepare_spec_fit_regions(aggregate.spectrum, region_models)
        fit_runner = run_fit_multiprocessing if use_multiprocessing else run_fit_loop
        fits_per_region[region_name] = AggregatedSampleSpectrumFitResult(
            region_name=region_name,
            aggregated_spectrum=aggregate,
            fit_model_results=fit_runner(pending_fits),
        )
    return fits_per_region


def prepare_spec_fit_regions(
    spectrum: SpectrumData, model_region_grp
) -> List[SpectrumFitModel]:
    """Create one ``SpectrumFitModel`` per model in ``model_region_grp``.

    Every fit job shares the same ``spectrum``; its region is taken from
    the model's own ``region_name.name``. The mapping keys are not used.
    """
    return [
        SpectrumFitModel(
            spectrum=spectrum,
            model=candidate,
            region=candidate.region_name.name,
        )
        for _key, candidate in model_region_grp.items()
    ]


def run_fit_loop(spec_fits: List[SpectrumFitModel]) -> Dict[str, SpectrumFitModel]:
    """Run the fits in ``spec_fits`` one after another in this process.

    Each item's ``run_fit()`` is called, the outcome is logged at debug
    level, and the item is stored under its ``model.name``.

    Returns:
        Mapping of model name to the fitted item (later duplicates of a
        name overwrite earlier ones).
    """
    fitted_by_name = {}
    for fit_job in spec_fits:
        # TODO: optionally persist the result, see
        # https://lmfit.github.io/lmfit-py/model.html#saving-and-loading-modelresults
        fit_job.run_fit()
        summary = f"Fit with model {fit_job.model.name} on {fit_job.region} success: {fit_job.fit_result.success} in {fit_job.elapsed_time:.2f}s."
        logger.debug(summary)
        fitted_by_name[fit_job.model.name] = fit_job
    return fitted_by_name
3 changes: 2 additions & 1 deletion src/raman_fitting/interfaces/typer_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def run(
),
],
run_mode: Annotated[RunModes, typer.Argument()] = RunModes.NORMAL,
multiprocessing: Annotated[bool, typer.Option("--multiprocessing")] = False,
):
if run_mode is None:
print("No make run mode passed")
raise typer.Exit()
kwargs = {"run_mode": run_mode}
kwargs = {"run_mode": run_mode, "use_multiprocessing": multiprocessing}
if run_mode == RunModes.EXAMPLES:
kwargs.update({"fit_model_specific_names": ["2peaks", "3peaks", "4peaks"]})
logger.info(f"Starting raman_fitting with CLI args:\n{run_mode}")
Expand Down

0 comments on commit 0e003e4

Please sign in to comment.