From 6b42cf8957a2a11c8ce7bd2c48829b8521a9870b Mon Sep 17 00:00:00 2001 From: David Wallace Date: Sun, 10 Mar 2024 11:39:37 +0100 Subject: [PATCH] chore: clean up preprocessing --- .../delegating/pre_processing.py | 38 ++---- .../imports/spectrum/spectra_collection.py | 120 +++++++----------- .../imports/spectrum/spectrum_constructor.py | 85 ------------- 3 files changed, 54 insertions(+), 189 deletions(-) delete mode 100644 src/raman_fitting/imports/spectrum/spectrum_constructor.py diff --git a/src/raman_fitting/delegating/pre_processing.py b/src/raman_fitting/delegating/pre_processing.py index b28afee..f58b63c 100644 --- a/src/raman_fitting/delegating/pre_processing.py +++ b/src/raman_fitting/delegating/pre_processing.py @@ -1,6 +1,5 @@ from typing import List -from raman_fitting.models.spectrum import SpectrumData from raman_fitting.models.splitter import RegionNames from raman_fitting.imports.spectrumdata_parser import SpectrumReader from raman_fitting.processing.post_processing import SpectrumProcessor @@ -10,17 +9,17 @@ PreparedSampleSpectrum, ) -import numpy as np from loguru import logger from raman_fitting.config.path_settings import CLEAN_SPEC_REGION_NAME_PREFIX +from ..imports.spectrum.spectra_collection import SpectraDataCollection def prepare_aggregated_spectrum_from_files( region_name: RegionNames, raman_files: List[RamanFileInfo] ) -> AggregatedSampleSpectrum | None: - selected_processed_data = f"{CLEAN_SPEC_REGION_NAME_PREFIX}{region_name}" - clean_data_for_region = {} + select_region_key = f"{CLEAN_SPEC_REGION_NAME_PREFIX}{region_name}" + clean_data_for_region = [] data_sources = [] for i in raman_files: read = SpectrumReader(i.file) @@ -29,32 +28,17 @@ def prepare_aggregated_spectrum_from_files( file_info=i, read=read, processed=processed ) data_sources.append(prepared_spec) - selected_clean_data = processed.clean_spectrum.spec_regions[ - selected_processed_data - ] - clean_data_for_region[i.file] = selected_clean_data - + selected_clean_data = processed.clean_spectrum.spec_regions[select_region_key] + clean_data_for_region.append(selected_clean_data) if not clean_data_for_region: - logger.info("prepare_mean_data_for_fitting received no files.") + logger.warning( + f"prepare_mean_data_for_fitting received no files. {region_name}" + ) return - # wrap this in a ProcessedSpectraCollection model - mean_int = np.mean( - np.vstack([i.intensity for i in clean_data_for_region.values()]), axis=0 - ) - mean_ramanshift = np.mean( - np.vstack([i.ramanshift for i in clean_data_for_region.values()]), axis=0 - ) - source_files = list(map(str, clean_data_for_region.keys())) - mean_spec = SpectrumData( - **{ - "ramanshift": mean_ramanshift, - "intensity": mean_int, - "label": f"clean_{region_name}_mean", - "region_name": region_name, - "source": source_files, - } + spectra_collection = SpectraDataCollection( + spectra=clean_data_for_region, region_name=region_name ) aggregated_spectrum = AggregatedSampleSpectrum( - sources=data_sources, spectrum=mean_spec + sources=data_sources, spectrum=spectra_collection.mean_spectrum ) return aggregated_spectrum diff --git a/src/raman_fitting/imports/spectrum/spectra_collection.py b/src/raman_fitting/imports/spectrum/spectra_collection.py index 3324a5b..da840ca 100644 --- a/src/raman_fitting/imports/spectrum/spectra_collection.py +++ b/src/raman_fitting/imports/spectrum/spectra_collection.py @@ -1,97 +1,63 @@ -import logging -from operator import itemgetter -from typing import Dict, List +from typing import List import numpy as np -from pydantic import BaseModel, ValidationError, model_validator, ConfigDict +from pydantic import BaseModel, ValidationError, model_validator -from .spectrum_constructor import SpectrumDataLoader +from raman_fitting.models.deconvolution.spectrum_regions import RegionNames from raman_fitting.models.spectrum import SpectrumData -logger = logging.getLogger(__name__) -SPECTRUM_KEYS = ("ramanshift", "intensity") - - -class PostProcessedSpectrum(BaseModel): - pass - class SpectraDataCollection(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) + spectra: List[SpectrumData] + region_name: RegionNames + mean_spectrum: SpectrumData | None = None - spectra: List[SpectrumDataLoader] + @model_validator(mode="after") + def check_spectra_have_same_label(self) -> "SpectraDataCollection": + """checks member of lists""" + labels = set(i.label for i in self.spectra) + if len(labels) > 1: + raise ValidationError(f"Spectra have different labels {labels}") + return self @model_validator(mode="after") - def check_spectra_have_clean_spectrum(self) -> "SpectraDataCollection": + def check_spectra_have_same_region(self) -> "SpectraDataCollection": """checks member of lists""" - if not all(hasattr(spec, "clean_spectrum") for spec in self.spectra): - raise ValidationError("missing clean_data attribute") + region_names = set(i.region_name for i in self.spectra) + if len(region_names) > 1: + raise ValidationError(f"Spectra have different region_names {region_names}") return self @model_validator(mode="after") def check_spectra_lengths(self) -> "SpectraDataCollection": - unique_lengths = set([i.spectrum_length for i in self.spectra]) - if len(unique_lengths) > 1: + unique_lengths_rs = set(len(i.ramanshift) for i in self.spectra) + unique_lengths_int = set(len(i.intensity) for i in self.spectra) + if len(unique_lengths_rs) > 1: raise ValidationError( - f"The spectra have different lenghts where they should be the same.\n\t{unique_lengths}" + f"The spectra have different ramanshift lengths where they should be the same.\n\t{unique_lengths_rs}" + ) + if len(unique_lengths_int) > 1: + raise ValidationError( + f"The spectra have different intensity lengths where they should be the same.\n\t{unique_lengths_int}" ) - return self - - -def get_mean_spectra_info(spectra: List[SpectrumDataLoader]) -> Dict: - """retrieves the info dict from spec instances and merges dict in keys that have 1 common value""" - - all_spec_info = [spec.info for spec in spectra] - _all_spec_info_merged = {k: val for i in all_spec_info for k, val in i.items()} - _all_spec_info_sets = [ - (k, set([i.get(k, None) for i in all_spec_info])) for k in _all_spec_info_merged - ] - mean_spec_info = { - k: list(val)[0] for k, val in _all_spec_info_sets if len(val) == 1 - } - mean_spec_info.update({"mean_spectrum": True}) - return mean_spec_info - - -def calculate_mean_spectrum_from_spectra( - spectra: List[SpectrumDataLoader], -) -> Dict[str, SpectrumData]: - """retrieves the clean data from spec instances and makes lists of tuples""" - - try: - spectra_regions = [i.clean_spectrum.spec_regions for i in spectra] - mean_spec_regions = {} - for region_name in spectra_regions[0].keys(): - regions_specs = [i[region_name] for i in spectra_regions] - ramanshift_mean = np.mean([i.ramanshift for i in regions_specs], axis=0) - intensity_mean = np.mean([i.intensity for i in regions_specs], axis=0) - - _data = { - "ramanshift": ramanshift_mean, - "intensity": intensity_mean, - "label": regions_specs[0].label + "_mean", - "region_name": region_name + "_mean", - } - mean_spec = SpectrumData(**_data) - mean_spec_regions[region_name] = mean_spec - - except Exception: - logger.warning(f"get_mean_spectra_prep_data failed for spectra {spectra}") - mean_spec_regions = {} - - return mean_spec_regions - -def get_best_guess_spectra_length(spectra: List[SpectrumDataLoader]) -> List: - lengths = [i.spectrum_length for i in spectra] - set_lengths = set(lengths) - if len(set_lengths) == 1: - # print(f'Spectra all same length {set_lengths}') - return spectra + return self - length_counts = [(i, lengths.count(i)) for i in set_lengths] - best_guess_length = max(length_counts, key=itemgetter(1))[0] - print(f"Spectra not same length {length_counts} took {best_guess_length}") - spectra = [spec for spec in spectra if spec.spectrum_length == best_guess_length] - return spectra + @model_validator(mode="after") + def set_mean_spectrum(self) -> "SpectraDataCollection": + # wrap this in a ProcessedSpectraCollection model + mean_int = np.mean(np.vstack([i.intensity for i in self.spectra]), axis=0) + mean_ramanshift = np.mean( + np.vstack([i.ramanshift for i in self.spectra]), axis=0 + ) + source_files = list(set(i.source for i in self.spectra)) + _label = "".join(map(str, set(i.label for i in self.spectra))) + mean_spec = SpectrumData( + ramanshift=mean_ramanshift, + intensity=mean_int, + label=f"clean_{self.region_name}_mean", + region_name=self.region_name, + source=source_files, + ) + self.mean_spectrum = mean_spec diff --git a/src/raman_fitting/imports/spectrum/spectrum_constructor.py b/src/raman_fitting/imports/spectrum/spectrum_constructor.py deleted file mode 100644 index 5a58aae..0000000 --- a/src/raman_fitting/imports/spectrum/spectrum_constructor.py +++ /dev/null @@ -1,85 +0,0 @@ -import logging -from dataclasses import dataclass, field -from pathlib import Path -from typing import Dict - -import pandas as pd - -from raman_fitting.imports.spectrumdata_parser import SpectrumReader -from pydantic import BaseModel - -from raman_fitting.processing.post_processing import SpectrumProcessor - -from raman_fitting.models.splitter import SplitSpectrum - -logger = logging.getLogger(__name__) -SPECTRUM_KEYS = ("ramanshift", "intensity") - - -class PostProcessedSpectrum(BaseModel): - pass - - -@dataclass(order=True, frozen=False) -class SpectrumDataLoader: - """ - Raman Spectrum Loader Dataclass, reads in the file and constructs - a clean spectrum from the data. - A sequence of steps is performed on the raw data from SpectrumReader. - The steps/methods are: smoothening filter, despiking and baseline correction. - """ - - file: Path = field(default_factory=Path) - info: Dict = field(default_factory=dict, repr=False) - # ovv: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) - run_kwargs: Dict = field(default_factory=dict, repr=False) - spectrum_length: int = field(default=0, init=False) - clean_spectrum: SplitSpectrum = field(default=None, init=False) - - def __post_init__(self): - self._qcnm = self.__class__.__qualname__ - self.register = {} # this stores the data of each method as they are performed - self.load_data_delegator() - - def validate_info_with_filepath(self): - if not self.info: - self.info = {"FilePath": self.file} - return - filepath_ = self.info.get("FilePath", None) - if filepath_ and Path(filepath_) != self.file: - raise ValueError( - f"Mismatch in value for FilePath: {self.file} != {filepath_}" - ) - - def load_data_delegator(self): - """calls the SpectrumReader class""" - - self.validate_info_with_filepath() - raw_spectrum = SpectrumReader(self.file) - self._raw_spectrum = raw_spectrum - self.info = {**self.info, **self.run_kwargs} - self.spectrum_length = 0 - if raw_spectrum.spectrum is None or raw_spectrum.spectrum_length == 0: - logger.error(f"{self._qcnm} load data fail for:\n\t {self.file}") - return - - spectrum_processor = SpectrumProcessor(raw_spectrum.spectrum) - self.clean_spectrum = spectrum_processor.clean_spectrum - self.spectrum_length = raw_spectrum.spectrum_length - - def set_clean_spectrum_df(self): - if self.clean_spectrum is None: - return - - self.clean_df = { - k: pd.DataFrame({"ramanshift": val.ramanshift, "int": val.intensity}) - for k, val in self.clean_spectrum.items() - } - - def plot_raw(self): - _raw_lbls = [ - i - for i in self.register_df.columns - if not any(a in i for a in ["ramanshift", "clean_spectrum"]) - ] - self.register_df.plot(x="ramanshift", y=_raw_lbls)