From 554760daa97161f7de38b69a2afd4897d2020976 Mon Sep 17 00:00:00 2001
From: David Wallace <mypydavid@proton.me>
Date: Sun, 10 Mar 2024 21:39:23 +0100
Subject: [PATCH] tests: add test for base peak

---
 src/raman_fitting/config/base_settings.py     |  8 +++++
 .../models/deconvolution/base_peak.py         | 34 -------------------
 tests/conftest.py                             |  5 +++
 tests/models/test_base_peak.py                | 34 +++++++++++++++++++
 4 files changed, 47 insertions(+), 34 deletions(-)
 create mode 100644 tests/models/test_base_peak.py

diff --git a/src/raman_fitting/config/base_settings.py b/src/raman_fitting/config/base_settings.py
index 424fb6e..f8b6d6b 100644
--- a/src/raman_fitting/config/base_settings.py
+++ b/src/raman_fitting/config/base_settings.py
@@ -16,6 +16,7 @@
 )
 from .default_models import load_config_from_toml_files
 from .path_settings import create_default_package_dir_or_ask, InternalPathSettings
+from types import MappingProxyType
 
 
 def get_default_models_and_peaks_from_definitions():
@@ -36,5 +37,12 @@ class Settings(BaseSettings):
         init_var=False,
         validate_default=False,
     )
+    default_definitions: MappingProxyType | None = Field(
+        default_factory=load_config_from_toml_files,
+        alias="my_default_definitions",
+        init_var=False,
+        validate_default=False,
+    )
+
     destination_dir: Path = Field(default_factory=create_default_package_dir_or_ask)
     internal_paths: InternalPathSettings = Field(default_factory=InternalPathSettings)
diff --git a/src/raman_fitting/models/deconvolution/base_peak.py b/src/raman_fitting/models/deconvolution/base_peak.py
index 4d05344..4649b34 100644
--- a/src/raman_fitting/models/deconvolution/base_peak.py
+++ b/src/raman_fitting/models/deconvolution/base_peak.py
@@ -217,37 +217,3 @@ def get_peaks_from_peak_definitions(
         for peak_name, peak_def in peak_type_defs.items():
             peak_models[peak_name] = BasePeak(**peak_def)
     return peak_models
-
-
-def _main():
-    model_definitions = load_config_from_toml_files()
-    print(model_definitions["first_order"]["models"])
-    peaks = {}
-    peak_items = {
-        **model_definitions["first_order"]["peaks"],
-        **model_definitions["second_order"]["peaks"],
-    }.items()
-    for k, v in peak_items:
-        peaks.update({k: BasePeak(**v)})
-
-    peak_d = BasePeak(**model_definitions["first_order"]["peaks"]["D"])
-    print(peak_d)
-    model_items = {
-        **model_definitions["first_order"]["models"],
-        **model_definitions["second_order"]["models"],
-    }.items()
-    models = {}
-    for model_name, model_comp in model_items:
-        print(k, v)
-        comps = model_comp.split("+")
-        peak_comps = [peaks[i] for i in comps]
-        lmfit_comp_model = sum(
-            map(lambda x: x.lmfit_model, peak_comps), peak_comps.pop().lmfit_model
-        )
-        models[model_name] = lmfit_comp_model
-        print(lmfit_comp_model)
-    # breakpoint()
-
-
-if __name__ == "__main__":
-    _main()
diff --git a/tests/conftest.py b/tests/conftest.py
index cef9296..9f95487 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -28,6 +28,11 @@ def example_files(internal_paths):
     return example_files
 
 
+@pytest.fixture(autouse=True)
+def default_definitions(internal_paths):
+    return settings.default_definitions
+
+
 @pytest.fixture(autouse=True)
 def default_models(internal_paths):
     return settings.default_models
diff --git a/tests/models/test_base_peak.py b/tests/models/test_base_peak.py
new file mode 100644
index 0000000..4be455c
--- /dev/null
+++ b/tests/models/test_base_peak.py
@@ -0,0 +1,34 @@
+from raman_fitting.models.deconvolution.base_peak import BasePeak
+
+
+def test_initialize_base_peaks(
+    default_definitions, default_models_first_order, default_models_second_order
+):
+    peaks = {}
+
+    peak_items = {
+        **default_definitions["first_order"]["peaks"],
+        **default_definitions["second_order"]["peaks"],
+    }.items()
+    for k, v in peak_items:
+        peaks.update({k: BasePeak(**v)})
+
+    peak_d = BasePeak(**default_definitions["first_order"]["peaks"]["D"])
+    assert (
+        peak_d.peak_name
+        == default_definitions["first_order"]["peaks"]["D"]["peak_name"]
+    )
+    assert (
+        peak_d.peak_type
+        == default_definitions["first_order"]["peaks"]["D"]["peak_type"]
+    )
+    assert (
+        peak_d.lmfit_model.components[0].prefix
+        == default_definitions["first_order"]["peaks"]["D"]["peak_name"] + "_"
+    )
+    assert (
+        peak_d.param_hints["center"].value
+        == default_definitions["first_order"]["peaks"]["D"]["param_hints"]["center"][
+            "value"
+        ]
+    )