-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add support for multiprocessing
- Loading branch information
David Wallace
committed
Mar 10, 2024
1 parent
656c8f7
commit 0e003e4
Showing
4 changed files
with
122 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import Dict, List | ||
|
||
from loguru import logger | ||
from mpire import WorkerPool | ||
|
||
from raman_fitting.models.fit_models import SpectrumFitModel | ||
|
||
|
||
def run_fit_multi(**kwargs) -> SpectrumFitModel: | ||
# include optional https://lmfit.github.io/lmfit-py/model.html#saving-and-loading-modelresults | ||
spectrum = kwargs.pop("spectrum") | ||
model = kwargs.pop("model") | ||
lmfit_model = model["lmfit_model"] | ||
region = kwargs.pop("region") | ||
import time | ||
|
||
lmfit_kwargs = {} | ||
if "method" not in kwargs: | ||
lmfit_kwargs["method"] = "leastsq" | ||
|
||
init_params = lmfit_model.make_params() | ||
start_time = time.time() | ||
x, y = spectrum["ramanshift"], spectrum["intensity"] | ||
out = lmfit_model.fit(y, init_params, x=x, **lmfit_kwargs) # 'leastsq' | ||
end_time = time.time() | ||
elapsed_seconds = abs(start_time - end_time) | ||
elapsed_time = elapsed_seconds | ||
logger.debug( | ||
f"Fit with model {model['name']} on {region} success: {out.success} in {elapsed_time:.2f}s." | ||
) | ||
return out | ||
|
||
|
||
def run_fit_multiprocessing( | ||
spec_fits: List[SpectrumFitModel], | ||
) -> Dict[str, SpectrumFitModel]: | ||
spec_fits_dumps = [i.model_dump() for i in spec_fits] | ||
|
||
with WorkerPool(n_jobs=4, use_dill=True) as pool: | ||
results = pool.map( | ||
run_fit_multi, spec_fits_dumps, progress_bar=True, progress_bar_style="rich" | ||
) | ||
# patch spec_fits, setattr fit_result | ||
fit_model_results = {} | ||
for result in results: | ||
_spec_fit_search = [ | ||
i for i in spec_fits if i.model.lmfit_model.name == result.model.name | ||
] | ||
if len(_spec_fit_search) != 1: | ||
continue | ||
_spec_fit = _spec_fit_search[0] | ||
_spec_fit.fit_result = result | ||
fit_model_results[_spec_fit.model.name] = _spec_fit | ||
return fit_model_results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import List, Dict | ||
|
||
from raman_fitting.delegating.run_fit_multi import run_fit_multiprocessing | ||
from raman_fitting.models.spectrum import SpectrumData | ||
from raman_fitting.types import LMFitModelCollection | ||
from raman_fitting.delegating.models import AggregatedSampleSpectrumFitResult | ||
from raman_fitting.delegating.pre_processing import ( | ||
prepare_aggregated_spectrum_from_files, | ||
) | ||
from raman_fitting.imports.models import RamanFileInfo | ||
from raman_fitting.models.deconvolution.spectrum_regions import RegionNames | ||
from raman_fitting.models.fit_models import SpectrumFitModel | ||
|
||
from loguru import logger | ||
|
||
|
||
def run_fit_over_selected_models( | ||
raman_files: List[RamanFileInfo], | ||
models: LMFitModelCollection, | ||
use_multiprocessing: bool = False, | ||
) -> Dict[RegionNames, AggregatedSampleSpectrumFitResult]: | ||
results = {} | ||
for region_name, model_region_grp in models.items(): | ||
aggregated_spectrum = prepare_aggregated_spectrum_from_files( | ||
region_name, raman_files | ||
) | ||
if aggregated_spectrum is None: | ||
continue | ||
spec_fits = prepare_spec_fit_regions( | ||
aggregated_spectrum.spectrum, model_region_grp | ||
) | ||
if use_multiprocessing: | ||
fit_model_results = run_fit_multiprocessing(spec_fits) | ||
else: | ||
fit_model_results = run_fit_loop(spec_fits) | ||
fit_region_results = AggregatedSampleSpectrumFitResult( | ||
region_name=region_name, | ||
aggregated_spectrum=aggregated_spectrum, | ||
fit_model_results=fit_model_results, | ||
) | ||
results[region_name] = fit_region_results | ||
return results | ||
|
||
|
||
def prepare_spec_fit_regions( | ||
spectrum: SpectrumData, model_region_grp | ||
) -> List[SpectrumFitModel]: | ||
spec_fits = [] | ||
for model_name, model in model_region_grp.items(): | ||
region = model.region_name.name | ||
spec_fit = SpectrumFitModel(spectrum=spectrum, model=model, region=region) | ||
spec_fits.append(spec_fit) | ||
return spec_fits | ||
|
||
|
||
def run_fit_loop(spec_fits: List[SpectrumFitModel]) -> Dict[str, SpectrumFitModel]: | ||
fit_model_results = {} | ||
for spec_fit in spec_fits: | ||
# include optional https://lmfit.github.io/lmfit-py/model.html#saving-and-loading-modelresults | ||
spec_fit.run_fit() | ||
logger.debug( | ||
f"Fit with model {spec_fit.model.name} on {spec_fit.region} success: {spec_fit.fit_result.success} in {spec_fit.elapsed_time:.2f}s." | ||
) | ||
fit_model_results[spec_fit.model.name] = spec_fit | ||
return fit_model_results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters