diff --git a/.cspell_dict.txt b/.cspell_dict.txt index 6842f203..491a9105 100644 --- a/.cspell_dict.txt +++ b/.cspell_dict.txt @@ -24,10 +24,12 @@ axlast axpy Axrf Axrs +aypx Bcai Bcajsr bCaMK Bcass +bcast Beta0 Beta1 bnd_rigid @@ -278,6 +280,7 @@ modelparams modifyitems monodomain mpio +mult multigrid multiselect myocytes @@ -447,6 +450,7 @@ usecols v Varr vcell +Vecs vffrt vfrt vjsr diff --git a/.github/workflows/mpi.yml b/.github/workflows/mpi.yml new file mode 100644 index 00000000..44967cc5 --- /dev/null +++ b/.github/workflows/mpi.yml @@ -0,0 +1,24 @@ +name: CI mpi + +on: [push] + +jobs: + test: + name: Run tests with in paralell + runs-on: ubuntu-latest + timeout-minutes: 20 + container: + image: ghcr.io/scientificcomputing/fenics-gmsh:2023-02-20 + + steps: + - uses: actions/checkout@v3 + + - name: Install dependencies + run: | + python3 -m pip install --upgrade pip + python3 -m pip install h5py --no-binary=h5py + python3 -m pip install -e ".[dev]" + + - name: Test with pytest + run: | + mpirun -n 2 python3 -m pytest --no-cov diff --git a/demos/drug_factors/OM_ORd_0p01uM.json b/demos/drug_factors/OM_ORd_0p01uM.json new file mode 100644 index 00000000..8d621373 --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p01uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.0088856374565136, "scale_drug_kwu": 0.9996652773446093, "scale_drug_kws": 0.9996318733194867} diff --git a/demos/drug_factors/OM_ORd_0p05uM.json b/demos/drug_factors/OM_ORd_0p05uM.json new file mode 100644 index 00000000..298afaa8 --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p05uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.0978997656020737, "scale_drug_kwu": 0.9929543496360429, "scale_drug_kws": 0.9934127591569167} diff --git a/demos/drug_factors/OM_ORd_0p15uM.json b/demos/drug_factors/OM_ORd_0p15uM.json new file mode 100644 index 00000000..d43b9be5 --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p15uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.4767525048773726, "scale_drug_kwu": 0.9474720847816791, "scale_drug_kws": 0.9560162970990137} diff --git a/demos/drug_factors/OM_ORd_0p1uM.json b/demos/drug_factors/OM_ORd_0p1uM.json new file mode 100644 index 00000000..3331339e --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p1uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.269045922497537, "scale_drug_kwu": 0.9745091137929672, "scale_drug_kws": 0.9777882574597025} diff --git a/demos/drug_factors/OM_ORd_0p25uM.json b/demos/drug_factors/OM_ORd_0p25uM.json new file mode 100644 index 00000000..0f9f17bf --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p25uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.941480927420373, "scale_drug_kwu": 0.8785071795144263, "scale_drug_kws": 0.902608308710906} diff --git a/demos/drug_factors/OM_ORd_0p2uM.json b/demos/drug_factors/OM_ORd_0p2uM.json new file mode 100644 index 00000000..4763f224 --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p2uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.7044452424646337, "scale_drug_kwu": 0.9146133496969047, "scale_drug_kws": 0.930335941671022} diff --git a/demos/drug_factors/OM_ORd_0p3uM.json b/demos/drug_factors/OM_ORd_0p3uM.json new file mode 100644 index 00000000..888b8380 --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p3uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 2.180577509560597, "scale_drug_kwu": 0.8412966373970264, "scale_drug_kws": 0.8742891733263732} diff --git a/demos/drug_factors/OM_ORd_0p4uM.json b/demos/drug_factors/OM_ORd_0p4uM.json new file mode 100644 index 00000000..01cfcce5 --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p4uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 2.6465097661792636, "scale_drug_kwu": 0.7694839916285091, "scale_drug_kws": 0.8197525418911272} diff --git a/demos/drug_factors/OM_ORd_0p6uM.json b/demos/drug_factors/OM_ORd_0p6uM.json new file mode 100644 index 00000000..2617017b --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p6uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 3.469555524323334, "scale_drug_kwu": 0.6537900886896717, "scale_drug_kws": 0.7304623620053443} diff --git a/demos/drug_factors/OM_ORd_0p8uM.json b/demos/drug_factors/OM_ORd_0p8uM.json new file mode 100644 index 00000000..e3afb373 --- /dev/null +++ b/demos/drug_factors/OM_ORd_0p8uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 4.123192317376672, "scale_drug_kwu": 0.5767117324724851, "scale_drug_kws": 0.6687507645297328} diff --git a/demos/drug_factors/OM_ORd_10p0uM.json b/demos/drug_factors/OM_ORd_10p0uM.json new file mode 100644 index 00000000..599a2307 --- /dev/null +++ b/demos/drug_factors/OM_ORd_10p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 6.999309749749809, "scale_drug_kwu": 0.3922088221303537, "scale_drug_kws": 0.5026870922267004} diff --git a/demos/drug_factors/OM_ORd_1p0uM.json b/demos/drug_factors/OM_ORd_1p0uM.json new file mode 100644 index 00000000..e3609fde --- /dev/null +++ b/demos/drug_factors/OM_ORd_1p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 4.629630066864853, "scale_drug_kwu": 0.5266448966768194, "scale_drug_kws": 0.6271217565262506} diff --git a/demos/drug_factors/OM_ORd_2p0uM.json b/demos/drug_factors/OM_ORd_2p0uM.json new file mode 100644 index 00000000..0e0b93d8 --- /dev/null +++ b/demos/drug_factors/OM_ORd_2p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 5.929416378690433, "scale_drug_kwu": 0.43379486246612387, "scale_drug_kws": 0.5445866278613021} diff --git a/demos/drug_factors/OM_ORd_3p0uM.json b/demos/drug_factors/OM_ORd_3p0uM.json new file mode 100644 index 00000000..9cbcf952 --- /dev/null +++ b/demos/drug_factors/OM_ORd_3p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 6.412447446238673, "scale_drug_kwu": 0.41108288992071473, "scale_drug_kws": 0.5225309650727457} diff --git a/demos/drug_factors/OM_ORd_4p0uM.json b/demos/drug_factors/OM_ORd_4p0uM.json new file mode 100644 index 00000000..3ae9358b --- /dev/null +++ b/demos/drug_factors/OM_ORd_4p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 6.643997266654334, "scale_drug_kwu": 0.40238549865921347, "scale_drug_kws": 0.5136732885993582} diff --git a/demos/drug_factors/OM_ORd_5p0uM.json b/demos/drug_factors/OM_ORd_5p0uM.json new file mode 100644 index 00000000..8baa707d --- /dev/null +++ b/demos/drug_factors/OM_ORd_5p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 6.774221008260938, "scale_drug_kwu": 0.39816285965920994, "scale_drug_kws": 0.5092338211609909} diff --git a/demos/drug_factors/OM_Tor_0p01uM.json b/demos/drug_factors/OM_Tor_0p01uM.json new file mode 100644 index 00000000..84555df1 --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p01uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.0036066332738462, "scale_drug_kwu": 0.9988983548993555, "scale_drug_kws": 0.9916975158949597} diff --git a/demos/drug_factors/OM_Tor_0p05uM.json b/demos/drug_factors/OM_Tor_0p05uM.json new file mode 100644 index 00000000..f1b79ac6 --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p05uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.0761901109483183, "scale_drug_kwu": 0.9878623455371002, "scale_drug_kws": 0.9402233109846877} diff --git a/demos/drug_factors/OM_Tor_0p15uM.json b/demos/drug_factors/OM_Tor_0p15uM.json new file mode 100644 index 00000000..8677e109 --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p15uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.5822307206911217, "scale_drug_kwu": 0.9408920222338004, "scale_drug_kws": 0.8223227643719653} diff --git a/demos/drug_factors/OM_Tor_0p1uM.json b/demos/drug_factors/OM_Tor_0p1uM.json new file mode 100644 index 00000000..8b1bf993 --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p1uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.2784075771383034, "scale_drug_kwu": 0.966643572414661, "scale_drug_kws": 0.875893839056697} diff --git a/demos/drug_factors/OM_Tor_0p25uM.json b/demos/drug_factors/OM_Tor_0p25uM.json new file mode 100644 index 00000000..a003448a --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p25uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 2.3997308651089764, "scale_drug_kwu": 0.8832747952953535, "scale_drug_kws": 0.7455182673276282} diff --git a/demos/drug_factors/OM_Tor_0p2uM.json b/demos/drug_factors/OM_Tor_0p2uM.json new file mode 100644 index 00000000..c642a615 --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p2uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 1.9638584056192134, "scale_drug_kwu": 0.9126625800533243, "scale_drug_kws": 0.7795400331600741} diff --git a/demos/drug_factors/OM_Tor_0p3uM.json b/demos/drug_factors/OM_Tor_0p3uM.json new file mode 100644 index 00000000..e381888d --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p3uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 2.8681668815513666, "scale_drug_kwu": 0.8536314996303338, "scale_drug_kws": 0.7182554032147207} diff --git a/demos/drug_factors/OM_Tor_0p4uM.json b/demos/drug_factors/OM_Tor_0p4uM.json new file mode 100644 index 00000000..3b735481 --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p4uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 3.832274572267599, "scale_drug_kwu": 0.7958650208325873, "scale_drug_kws": 0.6780153285045393} diff --git a/demos/drug_factors/OM_Tor_0p6uM.json b/demos/drug_factors/OM_Tor_0p6uM.json new file mode 100644 index 00000000..cad6b95c --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p6uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 5.576479937026704, "scale_drug_kwu": 0.6938234586483305, "scale_drug_kws": 0.6303184177777927} diff --git a/demos/drug_factors/OM_Tor_0p8uM.json b/demos/drug_factors/OM_Tor_0p8uM.json new file mode 100644 index 00000000..451b6f65 --- /dev/null +++ b/demos/drug_factors/OM_Tor_0p8uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 6.89322663999222, "scale_drug_kwu": 0.6127852918097438, "scale_drug_kws": 0.6039420141669012} diff --git a/demos/drug_factors/OM_Tor_10p0uM.json b/demos/drug_factors/OM_Tor_10p0uM.json new file mode 100644 index 00000000..8c88a7dc --- /dev/null +++ b/demos/drug_factors/OM_Tor_10p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 10.697710897975815, "scale_drug_kwu": 0.25620303265744615, "scale_drug_kws": 0.5332675144347352} diff --git a/demos/drug_factors/OM_Tor_1p0uM.json b/demos/drug_factors/OM_Tor_1p0uM.json new file mode 100644 index 00000000..b6a830c8 --- /dev/null +++ b/demos/drug_factors/OM_Tor_1p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 7.826325860076493, "scale_drug_kwu": 0.5499969248258909, "scale_drug_kws": 0.5876053136816772} diff --git a/demos/drug_factors/OM_Tor_2p0uM.json b/demos/drug_factors/OM_Tor_2p0uM.json new file mode 100644 index 00000000..59f3edf7 --- /dev/null +++ b/demos/drug_factors/OM_Tor_2p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 9.746288795760044, "scale_drug_kwu": 0.3888488665897669, "scale_drug_kws": 0.555231579295989} diff --git a/demos/drug_factors/OM_Tor_3p0uM.json b/demos/drug_factors/OM_Tor_3p0uM.json new file mode 100644 index 00000000..9f48cf57 --- /dev/null +++ b/demos/drug_factors/OM_Tor_3p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 10.258254191231888, "scale_drug_kwu": 0.3289624699606213, "scale_drug_kws": 0.5452294488953127} diff --git a/demos/drug_factors/OM_Tor_4p0uM.json b/demos/drug_factors/OM_Tor_4p0uM.json new file mode 100644 index 00000000..ba280252 --- /dev/null +++ b/demos/drug_factors/OM_Tor_4p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 10.45914284553491, "scale_drug_kwu": 0.30025482501512335, "scale_drug_kws": 0.540584661821546} diff --git a/demos/drug_factors/OM_Tor_5p0uM.json b/demos/drug_factors/OM_Tor_5p0uM.json new file mode 100644 index 00000000..68de94ed --- /dev/null +++ b/demos/drug_factors/OM_Tor_5p0uM.json @@ -0,0 +1 @@ +{"scale_drug_kuw": 10.557668288612081, "scale_drug_kwu": 0.28410963029717573, "scale_drug_kws": 0.5379645956980111} diff --git a/demos/simple_demo.py b/demos/simple_demo.py index 44059758..4e8be591 100644 --- a/demos/simple_demo.py +++ b/demos/simple_demo.py @@ -46,7 +46,7 @@ # # We can now plot the state traces, where we also specify that we want the trace from the center of the slab -simcardems.postprocess.plot_state_traces(outdir.joinpath("results.h5"), "center") +# simcardems.postprocess.plot_state_traces(outdir.joinpath("results.h5"), "center") # This will create a figure in the output directory called `state_traces_center.png` which in this case is shown in {numref}`Figure {number} ` we see the resulting state traces, and can also see the instant drop in the active tension ($T_a$) at the time of the triggered release. # @@ -60,7 +60,7 @@ # We can also save the output to xdmf-files that can be viewed in Paraview # -simcardems.postprocess.make_xdmffiles(outdir.joinpath("results.h5")) +# simcardems.postprocess.make_xdmffiles(outdir.joinpath("results.h5")) # The `xdmf` files are can be opened in [Paraview](https://www.paraview.org/download/) to visualize the different variables such as in {numref}`Figure {number} `. # diff --git a/setup.cfg b/setup.cfg index acbb2391..98746856 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,6 +55,7 @@ dev = pre-commit pytest pytest-cov + pytest-mpi sphinx twine wheel diff --git a/src/simcardems/datacollector.py b/src/simcardems/datacollector.py index 36996a48..bb8dc645 100644 --- a/src/simcardems/datacollector.py +++ b/src/simcardems/datacollector.py @@ -204,31 +204,47 @@ def __init__( ) -> None: self.outdir = Path(outdir) self._results_file = self.outdir / outfilename - self.comm = geo.mesh.mpi_comm() # FIXME: Is this important? + self.comm = geo.mesh.mpi_comm() + self._value_extractor = ValueExtractor(geo) + file_exist = False if reset_state: - utils.remove_file(self._results_file) + logger.debug("Reset state") + utils.remove_file(self._results_file, comm=self.comm) + + else: + dolfin.MPI.barrier(self.comm) + if self.comm.rank == 0: + file_exist = self._results_file.is_file() + file_exist = self.comm.bcast(file_exist, root=0) + dolfin.MPI.barrier(self.comm) self._times_stamps = set() - if not self._results_file.is_file(): + logger.debug(f"File {self._results_file} exists: {file_exist}") + if not file_exist: + logger.debug("Dump geometry") geo.dump(self.results_file) + logger.debug("Done with dumping geometry") from . import __version__ from packaging.version import parse version = parse(__version__) - with h5pyfile(self._results_file, "a") as f: - f.create_dataset("version_major", data=version.major) - f.create_dataset("version_minor", data=version.minor) - f.create_dataset("version_micro", data=version.micro) + if self.comm.rank == 0: + with h5pyfile(self._results_file, "a", force_serial=True) as f: + f.create_dataset("version_major", data=version.major) + f.create_dataset("version_minor", data=version.minor) + f.create_dataset("version_micro", data=version.micro) else: + logger.debug("Try to read time stamps") try: with h5pyfile(self._results_file, "r") as f: self._times_stamps = set(f["ep"]["V"].keys()) except KeyError: pass + logger.debug("Done in datacollector init") self._functions: Dict[str, Dict[str, dolfin.Function]] = { "ep": {}, @@ -238,6 +254,9 @@ def __init__( "ep": {}, "mechanics": {}, } + # Let us synchronize so that we done have any processors + # that starts storing data before all processors get here + # dolfin.MPI.barrier(self.comm) @property def valid_reductions(self) -> List[str]: @@ -293,28 +312,38 @@ def store(self, t: float) -> None: self._times_stamps.add(t_str) # First do the full values + logger.debug( + f"Store in file : {self.results_file}, exists: {Path(self.results_file).exists()}", + ) + with dolfin.HDF5File(self.comm, self.results_file, "a") as h5file: for group, names in self.names.items(): logger.debug( - f"Save full states in HDF5File {self.results_file} for group {group}", + f"1. Save full states in HDF5File {self.results_file} for group {group}", ) for name in names: - logger.debug(f"Save {name}") + logger.debug( + f"1. Save {name}: reduction {self._reductions[group][name]}", + ) if self._reductions[group][name] == "full": f = self._functions[group][name] h5file.write(f, f"{group}/{name}/{t_str}") # Next do the other reductions - with h5pyfile(self.results_file, "a") as h5file: + with h5pyfile(self.results_file, "a", comm=self.comm) as h5file: for group, names in self.names.items(): logger.debug( - f"Save reduced states in HDF5File {self.results_file} for group {group}", + f"2. Save reduced states in HDF5File {self.results_file} for group {group}", ) for name in names: - logger.debug(f"Save {name}") + logger.debug( + f"2. Save {name}, reduction: {self._reductions[group][name]}", + ) if self._reductions[group][name] == "full": continue + logger.debug("Here") + f = self._functions[group][name] value = self._value_extractor.eval( f, @@ -323,6 +352,7 @@ def store(self, t: float) -> None: if f"{group}/{name}" not in h5file: h5file.create_group(f"{group}/{name}") h5file[f"{group}/{name}"].create_dataset(t_str, data=value) + logger.debug("Done storing") def save_residual(self, residual, index): logger.debug("Save residual") @@ -358,7 +388,7 @@ def __init__(self, h5name: utils.PathLike, empty_ok: bool = False) -> None: raise FileNotFoundError(f"File {h5name} does not exist") self.geo = load_geometry(self._h5name) - self._h5pyfile = h5pyfile(self._h5name, "r").__enter__() + self._h5pyfile = h5pyfile(self._h5name, "r", comm=self.geo.comm()).__enter__() self.version_major = extract_number_from_h5py( self._h5pyfile.get("version_major"), diff --git a/src/simcardems/ep_model.py b/src/simcardems/ep_model.py index 74ba631e..983c8990 100644 --- a/src/simcardems/ep_model.py +++ b/src/simcardems/ep_model.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from .models.em_model import BaseEMCoupling + from .models.cell_model import BaseCellModel logger = utils.getLogger(__name__) @@ -273,18 +274,20 @@ def load_json(filename: str): def handle_cell_params( - CellModel: Type[cbcbeat.CardiacCellModel], + CellModel: Type[BaseCellModel], cell_params: Optional[Dict[str, float]] = None, disease_state: str = "healthy", drug_factors_file: str = "", popu_factors_file: str = "", ): - cell_params_tmp = CellModel.default_parameters(disease_state) + cell_params_tmp = CellModel.default_parameters() # FIXME: In this case we update the parameters first, while in the # initial condition case we do that last. We need to be consistent # about this. if cell_params is not None: cell_params_tmp.update(cell_params) + + CellModel.update_disease_parameters(cell_params_tmp, disease_state=disease_state) # Adding optional drug factors to parameters (if drug_factors_file exists) if file_exist(drug_factors_file, ".json"): logger.info(f"Drug scaling factors loaded from {drug_factors_file}") diff --git a/src/simcardems/geometry.py b/src/simcardems/geometry.py index 6c4d9b71..19036910 100644 --- a/src/simcardems/geometry.py +++ b/src/simcardems/geometry.py @@ -159,8 +159,12 @@ def default_parameters() -> Dict[str, Any]: def default_stimulus_domain(mesh: dolfin.Mesh) -> StimulusDomain: # Default is to stimulate the entire tissue marker = 1 + #domain = dolfin.MeshFunction("size_t", mesh, mesh.topology().dim()) + #domain.set_all(marker) + subdomain = dolfin.CompiledSubDomain("x[0] < 1.0") domain = dolfin.MeshFunction("size_t", mesh, mesh.topology().dim()) - domain.set_all(marker) + domain.set_all(0) + subdomain.mark(domain, marker) return StimulusDomain(domain=domain, marker=marker) @staticmethod @@ -278,11 +282,12 @@ def dump( kwargs = {k: getattr(self, k) for k in schema if k != "info"} kwargs["info"] = self.parameters + logger.debug("Instantiating geometry") geo = Geometry(**kwargs, schema=schema) if schema_path is None: schema_path = path.with_suffix(".json") - + logger.debug("Save geo") geo.save(path, schema_path=schema_path, unlink=unlink) logger.info(f"Saved geometry to {fname}") diff --git a/src/simcardems/models/__init__.py b/src/simcardems/models/__init__.py index 0089071c..60d21e0b 100644 --- a/src/simcardems/models/__init__.py +++ b/src/simcardems/models/__init__.py @@ -1,3 +1,4 @@ +from . import cell_model from . import em_model from . import explicit_ORdmm_Land from . import fully_coupled_ORdmm_Land @@ -27,6 +28,7 @@ def list_coupling_types(): "fully_coupled_ORdmm_Land", "fully_coupled_Tor_Land", "em_model", + "cell_model", "pureEP_ORdmm_Land", "loggers", ] diff --git a/src/simcardems/models/cell_model.py b/src/simcardems/models/cell_model.py new file mode 100644 index 00000000..e1967fe3 --- /dev/null +++ b/src/simcardems/models/cell_model.py @@ -0,0 +1,15 @@ +from abc import ABC +from abc import abstractmethod +from typing import Dict + +from cbcbeat import CardiacCellModel + + +class BaseCellModel(CardiacCellModel, ABC): + @staticmethod + @abstractmethod + def update_disease_parameters( + params: Dict[str, float], + disease_state: str = "healthy", + ) -> None: + ... diff --git a/src/simcardems/models/explicit_ORdmm_Land/cell_model.py b/src/simcardems/models/explicit_ORdmm_Land/cell_model.py index 771a062e..1d1f9482 100644 --- a/src/simcardems/models/explicit_ORdmm_Land/cell_model.py +++ b/src/simcardems/models/explicit_ORdmm_Land/cell_model.py @@ -7,11 +7,11 @@ import dolfin import ufl -from cbcbeat.cellmodels import CardiacCellModel from dolfin import as_vector from dolfin import Constant from ... import utils +from ..cell_model import BaseCellModel from .em_model import EMCoupling logger = utils.getLogger(__name__) @@ -25,7 +25,7 @@ def Min(a, b): return (a + b - abs(a - b)) / Constant(2.0) -class ORdmmLandExplicit(CardiacCellModel): +class ORdmmLandExplicit(BaseCellModel): def __init__( self, coupling: EMCoupling, @@ -48,7 +48,26 @@ def __init__( self.dLambda = coupling.dLambda_ep @staticmethod - def default_parameters(disease_state: str = "healthy") -> Dict[str, float]: + def update_disease_parameters( + params: Dict[str, float], + disease_state: str = "healthy", + ) -> None: + if disease_state.lower() == "hf": + logger.info("Update scaling parameters for heart failure model") + params["HF_scaling_CaMKa"] = 1.50 + params["HF_scaling_Jrel_inf"] = pow(0.8, 8.0) + params["HF_scaling_Jleak"] = 1.3 + params["HF_scaling_Jup"] = 0.45 + params["HF_scaling_GNaL"] = 1.3 + params["HF_scaling_GK1"] = 0.68 + params["HF_scaling_thL"] = 1.8 + params["HF_scaling_Gto"] = 0.4 + params["HF_scaling_Gncx"] = 1.6 + params["HF_scaling_Pnak"] = 0.7 + params["HF_scaling_cat50_ref"] = 0.6 + + @staticmethod + def default_parameters() -> Dict[str, float]: """Set-up and return default parameters. Parameters @@ -235,20 +254,6 @@ def default_parameters(disease_state: str = "healthy") -> Dict[str, float]: ], ) - if disease_state.lower() == "hf": - logger.info("Update scaling parameters for heart failure model") - params["HF_scaling_CaMKa"] = 1.50 - params["HF_scaling_Jrel_inf"] = pow(0.8, 8.0) - params["HF_scaling_Jleak"] = 1.3 - params["HF_scaling_Jup"] = 0.45 - params["HF_scaling_GNaL"] = 1.3 - params["HF_scaling_GK1"] = 0.68 - params["HF_scaling_thL"] = 1.8 - params["HF_scaling_Gto"] = 0.4 - params["HF_scaling_Gncx"] = 1.6 - params["HF_scaling_Pnak"] = 0.7 - params["HF_scaling_cat50_ref"] = 0.6 - return params @staticmethod diff --git a/src/simcardems/models/explicit_ORdmm_Land/em_model.py b/src/simcardems/models/explicit_ORdmm_Land/em_model.py index 3afc69d9..8d98e71c 100644 --- a/src/simcardems/models/explicit_ORdmm_Land/em_model.py +++ b/src/simcardems/models/explicit_ORdmm_Land/em_model.py @@ -80,6 +80,29 @@ def __init__( self.lmbda_ep.vector()[:] = lmbda.vector() self.lmbda_ep_prev.vector()[:] = lmbda.vector() + self.transfer_matrix = dolfin.PETScDMCollection.create_transfer_matrix( + self.V_mech, + self.V_ep, + ).mat() + + def interpolate( + self, + f_mech: dolfin.Function, + f_ep: dolfin.Function, + ) -> dolfin.Function: + """Interpolates function from mechanics to ep mesh""" + + x = dolfin.as_backend_type(f_mech.vector()).vec() + a, temp = self.transfer_matrix.getVecs() + self.transfer_matrix.mult(x, temp) + f_ep.vector().vec().aypx(0.0, temp) + f_ep.vector().apply("") + + # Remember to free memory allocated by petsc: https://gitlab.com/petsc/petsc/-/issues/1309 + x.destroy() + a.destroy() + temp.destroy() + def __eq__(self, __o: object) -> bool: if not isinstance(__o, type(self)): return NotImplemented @@ -284,9 +307,7 @@ def mechanics_to_coupling(self): self.u_mech, utils.sub_function(self.mech_state, self._u_subspace_index), ) - self.lmbda_ep.interpolate(self.lmbda_mech_func) - # self.u_ep.interpolate(self.u_mech) - # self._project_lmbda() + self.interpolate(self.lmbda_mech_func, self.lmbda_ep) logger.debug("Done transferring variables from mechanics to coupling") @@ -334,6 +355,7 @@ def save_state( self.cell_params(), path, "ep/cell_params", + comm=self.geometry.comm(), ) @classmethod diff --git a/src/simcardems/models/fully_coupled_ORdmm_Land/cell_model.py b/src/simcardems/models/fully_coupled_ORdmm_Land/cell_model.py index 979354f8..577455e4 100644 --- a/src/simcardems/models/fully_coupled_ORdmm_Land/cell_model.py +++ b/src/simcardems/models/fully_coupled_ORdmm_Land/cell_model.py @@ -8,11 +8,11 @@ import dolfin import ufl -from cbcbeat.cellmodels import CardiacCellModel from dolfin import as_vector from dolfin import Constant from ... import utils +from ..cell_model import BaseCellModel from .em_model import EMCoupling logger = utils.getLogger(__name__) @@ -26,7 +26,7 @@ def Min(a, b): return (a + b - abs(a - b)) / Constant(2.0) -class ORdmmLandFull(CardiacCellModel): +class ORdmmLandFull(BaseCellModel): def __init__( self, coupling: EMCoupling, @@ -49,7 +49,26 @@ def __init__( self.Zetaw = coupling.Zetaw_ep @staticmethod - def default_parameters(disease_state: str = "healthy") -> Dict[str, float]: + def update_disease_parameters( + params: Dict[str, float], + disease_state: str = "healthy", + ) -> None: + if disease_state.lower() == "hf": + logger.info("Update scaling parameters for heart failure model") + params["HF_scaling_CaMKa"] = 1.50 + params["HF_scaling_Jrel_inf"] = pow(0.8, 8.0) + params["HF_scaling_Jleak"] = 1.3 + params["HF_scaling_Jup"] = 0.45 + params["HF_scaling_GNaL"] = 1.3 + params["HF_scaling_GK1"] = 0.68 + params["HF_scaling_thL"] = 1.8 + params["HF_scaling_Gto"] = 0.4 + params["HF_scaling_Gncx"] = 1.6 + params["HF_scaling_Pnak"] = 0.7 + params["HF_scaling_cat50_ref"] = 0.6 + + @staticmethod + def default_parameters() -> Dict[str, float]: """Set-up and return default parameters. Parameters ---------- @@ -191,6 +210,9 @@ def default_parameters(disease_state: str = "healthy") -> Dict[str, float]: ("scale_drug_IpCa", 1.0), ("scale_drug_Isacns", 1.0), ("scale_drug_Isack", 1.0), + ("scale_drug_kws", 1.0), + ("scale_drug_kuw", 1.0), + ("scale_drug_kwu", 1.0), # Population factors ("scale_popu_GNa", 1.0), ("scale_popu_GCaL", 1.0), @@ -234,20 +256,6 @@ def default_parameters(disease_state: str = "healthy") -> Dict[str, float]: ], ) - if disease_state.lower() == "hf": - logger.info("Update scaling parameters for heart failure model") - params["HF_scaling_CaMKa"] = 1.50 - params["HF_scaling_Jrel_inf"] = pow(0.8, 8.0) - params["HF_scaling_Jleak"] = 1.3 - params["HF_scaling_Jup"] = 0.45 - params["HF_scaling_GNaL"] = 1.3 - params["HF_scaling_GK1"] = 0.68 - params["HF_scaling_thL"] = 1.8 - params["HF_scaling_Gto"] = 0.4 - params["HF_scaling_Gncx"] = 1.6 - params["HF_scaling_Pnak"] = 0.7 - params["HF_scaling_cat50_ref"] = 0.6 - return params @staticmethod @@ -985,6 +993,9 @@ def F(self, v, s, time=None): scale_drug_IpCa = self._parameters["scale_drug_IpCa"] scale_drug_Isacns = self._parameters["scale_drug_Isacns"] scale_drug_Isack = self._parameters["scale_drug_Isack"] + scale_drug_kws = self._parameters["scale_drug_kws"] + scale_drug_kuw = self._parameters["scale_drug_kuw"] + scale_drug_kwu = self._parameters["scale_drug_kwu"] # Population factors scale_popu_GNa = self._parameters["scale_popu_GNa"] @@ -1587,12 +1598,14 @@ def F(self, v, s, time=None): ) gammasu = gammas * Max(zetas1, zetas2) - F_expressions[39] = kws * scale_popu_kws * XW - XS * gammasu - XS * ksu + F_expressions[39] = ( + kws * scale_drug_kws * scale_popu_kws * XW - XS * gammasu - XS * ksu + ) F_expressions[40] = ( - kuw * scale_popu_kuw * XU - - kws * scale_popu_kws * XW + kuw * scale_drug_kuw * scale_popu_kuw * XU + - kws * scale_drug_kws * scale_popu_kws * XW - XW * gammawu - - XW * kwu + - XW * kwu * scale_drug_kwu ) cat50 = ( cat50_ref * scale_popu_CaT50ref + Beta1 * (-1.0 + lambda_min12) diff --git a/src/simcardems/models/fully_coupled_ORdmm_Land/em_model.py b/src/simcardems/models/fully_coupled_ORdmm_Land/em_model.py index b803bf39..89f8c0db 100644 --- a/src/simcardems/models/fully_coupled_ORdmm_Land/em_model.py +++ b/src/simcardems/models/fully_coupled_ORdmm_Land/em_model.py @@ -47,6 +47,30 @@ def __init__( self.Zetas_ep = dolfin.Function(self.V_ep, name="Zetas_ep") self.Zetaw_ep = dolfin.Function(self.V_ep, name="Zetaw_ep") + self.transfer_matrix = dolfin.PETScDMCollection.create_transfer_matrix( + self.V_mech, + self.V_ep, + ).mat() + + def interpolate( + self, + f_mech: dolfin.Function, + f_ep: dolfin.Function, + ) -> dolfin.Function: + """Interpolates function from mechanics to ep mesh""" + + x = dolfin.as_backend_type(f_mech.vector()).vec() + a, temp = self.transfer_matrix.getVecs() + + self.transfer_matrix.mult(x, temp) + f_ep.vector().vec().aypx(0.0, temp) + f_ep.vector().apply("") + + # Remember to free memory allocated by petsc: https://gitlab.com/petsc/petsc/-/issues/1309 + x.destroy() + a.destroy() + temp.destroy() + @property def coupling_type(self): return "fully_coupled_ORdmm_Land" @@ -196,9 +220,9 @@ def coupling_to_mechanics(self): def mechanics_to_coupling(self): logger.debug("Interpolate EP") - self.lmbda_ep.interpolate(self.lmbda_mech) - self.Zetas_ep.interpolate(self.Zetas_mech) - self.Zetaw_ep.interpolate(self.Zetaw_mech) + self.interpolate(self.lmbda_mech, self.lmbda_ep) + self.interpolate(self.Zetas_mech, self.Zetas_ep) + self.interpolate(self.Zetaw_mech, self.Zetaw_ep) logger.debug("Done interpolating EP") def coupling_to_ep(self): @@ -267,6 +291,7 @@ def save_state( self.cell_params(), path, "ep/cell_params", + comm=self.geometry.comm(), ) @classmethod diff --git a/src/simcardems/models/fully_coupled_Tor_Land/cell_model.py b/src/simcardems/models/fully_coupled_Tor_Land/cell_model.py index fea279de..0963eccb 100644 --- a/src/simcardems/models/fully_coupled_Tor_Land/cell_model.py +++ b/src/simcardems/models/fully_coupled_Tor_Land/cell_model.py @@ -9,11 +9,11 @@ import dolfin import ufl -from cbcbeat.cellmodels import CardiacCellModel from dolfin import as_vector from dolfin import Constant from ... import utils +from ..cell_model import BaseCellModel from .em_model import EMCoupling logger = utils.getLogger(__name__) @@ -34,7 +34,7 @@ def vs_functions_to_dict(vs): } -class TorLandFull(CardiacCellModel): +class TorLandFull(BaseCellModel): def __init__( self, coupling: EMCoupling, @@ -58,7 +58,26 @@ def __init__( self.Zetaw = coupling.Zetaw_ep @staticmethod - def default_parameters(disease_state: str = "healthy") -> Dict[str, float]: + def update_disease_parameters( + params: Dict[str, float], + disease_state: str = "healthy", + ) -> None: + if disease_state.lower() == "hf": + logger.info("Update scaling parameters for heart failure model") + params["HF_scaling_CaMKa"] = 1.50 + params["HF_scaling_Jrel_inf"] = pow(0.8, 8.0) + params["HF_scaling_Jleak"] = 1.3 + params["HF_scaling_Jup"] = 0.45 + params["HF_scaling_GNaL"] = 1.3 + params["HF_scaling_GK1"] = 0.68 + params["HF_scaling_thL"] = 1.8 + params["HF_scaling_Gto"] = 0.4 + params["HF_scaling_Gncx"] = 1.6 + params["HF_scaling_Pnak"] = 0.7 + params["HF_scaling_cat50_ref"] = 0.7 + + @staticmethod + def default_parameters() -> Dict[str, float]: """Set-up and return default parameters. Parameters @@ -250,20 +269,6 @@ def default_parameters(disease_state: str = "healthy") -> Dict[str, float]: ], ) - if disease_state.lower() == "hf": - logger.info("Update scaling parameters for heart failure model") - params["HF_scaling_CaMKa"] = 1.50 - params["HF_scaling_Jrel_inf"] = pow(0.8, 8.0) - params["HF_scaling_Jleak"] = 1.3 - params["HF_scaling_Jup"] = 0.45 - params["HF_scaling_GNaL"] = 1.3 - params["HF_scaling_GK1"] = 0.68 - params["HF_scaling_thL"] = 1.8 - params["HF_scaling_Gto"] = 0.4 - params["HF_scaling_Gncx"] = 1.6 - params["HF_scaling_Pnak"] = 0.7 - params["HF_scaling_cat50_ref"] = 0.7 - return params @staticmethod @@ -663,7 +668,7 @@ def _I(self, v, s, time): ) / (1.0 + ufl.exp(-0.1629 * (v - EK + 14.207))) K1ss = aK1 / (aK1 + bK1) GK1 = 0.6992 * scale_IK1 * scale_drug_IK1 * scale_popu_GK1 * HF_scaling_GK1 - IK1 = ufl.sqrt(ko) * (-EK + v) * GK1 * GK1 * K1ss + IK1 = ufl.sqrt(ko / 5.0) * (-EK + v) * GK1 * K1ss # Expressions for the INaCa_i component hca = ufl.exp(F * qca * v / (R * T)) diff --git a/src/simcardems/models/fully_coupled_Tor_Land/em_model.py b/src/simcardems/models/fully_coupled_Tor_Land/em_model.py index 5225878e..74f41c5d 100644 --- a/src/simcardems/models/fully_coupled_Tor_Land/em_model.py +++ b/src/simcardems/models/fully_coupled_Tor_Land/em_model.py @@ -47,6 +47,29 @@ def __init__( self.Zetas_ep = dolfin.Function(self.V_ep, name="Zetas_ep") self.Zetaw_ep = dolfin.Function(self.V_ep, name="Zetaw_ep") + self.transfer_matrix = dolfin.PETScDMCollection.create_transfer_matrix( + self.V_mech, + self.V_ep, + ).mat() + + def interpolate( + self, + f_mech: dolfin.Function, + f_ep: dolfin.Function, + ) -> dolfin.Function: + """Interpolates function from mechanics to ep mesh""" + + x = dolfin.as_backend_type(f_mech.vector()).vec() + a, temp = self.transfer_matrix.getVecs() + self.transfer_matrix.mult(x, temp) + f_ep.vector().vec().aypx(0.0, temp) + f_ep.vector().apply("") + + # Remember to free memory allocated by petsc: https://gitlab.com/petsc/petsc/-/issues/1309 + x.destroy() + a.destroy() + temp.destroy() + @property def coupling_type(self): return "fully_coupled_Tor_Land" @@ -196,9 +219,9 @@ def coupling_to_mechanics(self): def mechanics_to_coupling(self): logger.debug("Interpolate EP") - self.lmbda_ep.interpolate(self.lmbda_mech) - self.Zetas_ep.interpolate(self.Zetas_mech) - self.Zetaw_ep.interpolate(self.Zetaw_mech) + self.interpolate(self.lmbda_mech, self.lmbda_ep) + self.interpolate(self.Zetas_mech, self.Zetas_ep) + self.interpolate(self.Zetaw_mech, self.Zetaw_ep) logger.debug("Done interpolating EP") def coupling_to_ep(self): @@ -267,6 +290,7 @@ def save_state( self.cell_params(), path, "ep/cell_params", + comm=self.geometry.comm(), ) @classmethod diff --git a/src/simcardems/models/pureEP_ORdmm_Land/cell_model.py b/src/simcardems/models/pureEP_ORdmm_Land/cell_model.py index ba5e6615..709f400b 100644 --- a/src/simcardems/models/pureEP_ORdmm_Land/cell_model.py +++ b/src/simcardems/models/pureEP_ORdmm_Land/cell_model.py @@ -7,11 +7,11 @@ import dolfin import ufl -from cbcbeat.cellmodels import CardiacCellModel from dolfin import as_vector from dolfin import Constant from ... import utils +from ..cell_model import BaseCellModel logger = utils.getLogger(__name__) @@ -24,7 +24,7 @@ def Min(a, b): return (a + b - abs(a - b)) / Constant(2.0) -class ORdmmLandPureEp(CardiacCellModel): +class ORdmmLandPureEp(BaseCellModel): def __init__(self, params=None, init_conditions=None, **kwargs): """ Create cardiac cell model @@ -40,7 +40,26 @@ def __init__(self, params=None, init_conditions=None, **kwargs): super().__init__(params, init_conditions) @staticmethod - def default_parameters(disease_state: str = "healthy") -> Dict[str, float]: + def update_disease_parameters( + params: Dict[str, float], + disease_state: str = "healthy", + ) -> None: + if disease_state.lower() == "hf": + logger.info("Update scaling parameters for heart failure model") + params["HF_scaling_CaMKa"] = 1.50 + params["HF_scaling_Jrel_inf"] = pow(0.8, 8.0) + params["HF_scaling_Jleak"] = 1.3 + params["HF_scaling_Jup"] = 0.45 + params["HF_scaling_GNaL"] = 1.3 + params["HF_scaling_GK1"] = 0.68 + params["HF_scaling_thL"] = 1.8 + params["HF_scaling_Gto"] = 0.4 + params["HF_scaling_Gncx"] = 1.6 + params["HF_scaling_Pnak"] = 0.7 + params["HF_scaling_cat50_ref"] = 0.6 + + @staticmethod + def default_parameters() -> Dict[str, float]: """Set-up and return default parameters. Parameters @@ -229,20 +248,6 @@ def default_parameters(disease_state: str = "healthy") -> Dict[str, float]: ], ) - if disease_state.lower() == "hf": - logger.info("Update scaling parameters for heart failure model") - params["HF_scaling_CaMKa"] = 1.50 - params["HF_scaling_Jrel_inf"] = pow(0.8, 8.0) - params["HF_scaling_Jleak"] = 1.3 - params["HF_scaling_Jup"] = 0.45 - params["HF_scaling_GNaL"] = 1.3 - params["HF_scaling_GK1"] = 0.68 - params["HF_scaling_thL"] = 1.8 - params["HF_scaling_Gto"] = 0.4 - params["HF_scaling_Gncx"] = 1.6 - params["HF_scaling_Pnak"] = 0.7 - params["HF_scaling_cat50_ref"] = 0.6 - return params @staticmethod diff --git a/src/simcardems/models/pureEP_ORdmm_Land/em_model.py b/src/simcardems/models/pureEP_ORdmm_Land/em_model.py index 0e0fcad8..9f8e0d41 100644 --- a/src/simcardems/models/pureEP_ORdmm_Land/em_model.py +++ b/src/simcardems/models/pureEP_ORdmm_Land/em_model.py @@ -116,6 +116,7 @@ def save_state( self.cell_params(), path, "ep/cell_params", + comm=self.geometry.comm(), ) @classmethod diff --git a/src/simcardems/postprocess.py b/src/simcardems/postprocess.py index 6d72ce56..f5f042db 100644 --- a/src/simcardems/postprocess.py +++ b/src/simcardems/postprocess.py @@ -1,4 +1,5 @@ import json +from collections import defaultdict from pathlib import Path from typing import Dict from typing import Iterable @@ -14,13 +15,14 @@ import tqdm from . import utils +from .datacollector import DataCollector from .datacollector import DataGroups from .datacollector import DataLoader logger = utils.getLogger(__name__) -def plot_peaks(fname, data, threshold): +def plot_peaks(fname, data, threshold, save_trace): # Find peaks for assessment steady state from scipy.signal import find_peaks @@ -37,6 +39,10 @@ def plot_peaks(fname, data, threshold): # Calculate difference between consecutive list elements change_y = [(s - q) / q * 100 for q, s in zip(y, y[1:])] + if save_trace: + change_y_np = np.array(change_y) + np.save(fname, change_y_np) + fig, ax = plt.subplots() ax.plot(change_y) ax.set_title("Compare peak values") @@ -83,7 +89,7 @@ def extract_traces( reduction=reduction, ) - # values["mechanics"]["inv_lmbda"] = 1 - values["mechanics"]["lmbda"] + values["mechanics"]["inv_lmbda"] = 1 - values["mechanics"]["lambda"] return values @@ -105,10 +111,24 @@ def plot_state_traces( if times[-1] > 2000 and flag_peaks: plot_peaks( - outdir.joinpath("compare-peak-values.png"), + outdir.joinpath("compare-peak-values-Ca.png"), values["ep"]["Ca"], 0.0002, + True, + ) + plot_peaks( + outdir.joinpath("compare-peak-values-V.png"), + values["ep"]["V"], + 10.0, + True, + ) + plot_peaks( + outdir.joinpath("compare-peak-values-Ta.png"), + values["mechanics"]["Ta"], + 2.0, + True, ) + fig, axs = plt.subplots(2, 2, figsize=(10, 8), sharex=True) for i, (group, key) in enumerate( @@ -286,7 +306,7 @@ def plot_population(results, outdir, num_models, reset_time=True): for PoMm in range(1, num_models + 1): ax[0, 0].plot( times, - np.array(results[f"m{PoMm}"]["mechanics"]["lmbda"], dtype=float), + np.array(results[f"m{PoMm}"]["mechanics"]["lambda"], dtype=float), ) ax[0, 1].plot( times, @@ -409,7 +429,7 @@ def get_biomarkers(results, outdir, num_models): V = results[f"m{PoMm}"]["ep"]["V"] Ca = results[f"m{PoMm}"]["ep"]["Ca"] Ta = results[f"m{PoMm}"]["mechanics"]["Ta"] - lmbda = results[f"m{PoMm}"]["mechanics"]["lmbda"] + lmbda = results[f"m{PoMm}"]["mechanics"]["lambda"] u = results[f"m{PoMm}"]["mechanics"]["u"] inv_lmbda = results[f"m{PoMm}"]["mechanics"]["inv_lmbda"] @@ -553,3 +573,99 @@ def activation_map( activation_map.vector()[np.where(dofs)[0]] = float(t) return activation_map + + +def extract_sub_results( + results_file: utils.PathLike, + output_file: utils.PathLike, + t_start: float = 0.0, + t_end: Optional[float] = None, + names: Optional[Dict[str, List[str]]] = None, +) -> DataCollector: + """Extract sub results from another results file. + This can be useful if you have stored a lot of data in one file + and you want to create a smaller file containing only a subset + of the data (e.g only the last beat) + + Parameters + ---------- + results_file : utils.PathLike + The input result file + output_file : utils.PathLike + Path to file where you want to store the sub results + t_start : float, optional + Time point indicating when the sub results should start, by default 0.0 + t_end : float | None, optional + Time point indicating when the sub results should end, by default None + in which case it will choose the last time point + names : Optional[Dict[str, List[str]]], optional + A dictionary of names for each group indicating which + functions to extract for the sub results. If not provided (default) + then all functions will be extracted. + + Returns + ------- + datacollector.DataCollector + A data collector containing the sub results + + Raises + ------ + FileNotFoundError + If the input file does not exist + KeyError + If some of the names provided does not exists + in the input file + """ + results_file = Path(results_file) + if not results_file.is_file(): + raise FileNotFoundError(f"File {results_file} does not exist") + + loader = DataLoader(results_file) + assert loader.time_stamps is not None + if names is None: + # Extract everything + names = loader.names + + t_start_idx = next( + (i for i, t in enumerate(map(float, loader.time_stamps)) if t > t_start - 1e-12) + ) + if t_end is None: + t_end_idx = len(loader.time_stamps) - 1 + else: + try: + t_end_idx = next( + i + for i, t in enumerate(map(float, loader.time_stamps)) + if t > t_end + 1e-12 + ) + except StopIteration: + t_end_idx = len(loader.time_stamps) - 1 + + out = Path(output_file) + collector = DataCollector( + outdir=out.parent, + outfilename=out.name, + geo=loader.geo, + ) + + functions: Dict[str, Dict[str, dolfin.Function]] = defaultdict(dict) + for group_name, group in names.items(): + for func_name in group: + try: + functions[group_name][func_name] = loader._functions[group_name][ + func_name + ] + except KeyError as e: + raise KeyError( + f"Invalid group {group_name} and function {func_name}", + ) from e + collector.register(group_name, func_name, functions[group_name][func_name]) + + for ti in loader.time_stamps[t_start_idx:t_end_idx]: + for group_name, group in names.items(): + for func_name in group: + functions[group_name][func_name].assign( + loader.get(DataGroups[group_name], func_name, ti), + ) + collector.store(float(ti)) + return collector diff --git a/src/simcardems/save_load_functions.py b/src/simcardems/save_load_functions.py index d6f2665b..c239bdf3 100644 --- a/src/simcardems/save_load_functions.py +++ b/src/simcardems/save_load_functions.py @@ -29,14 +29,17 @@ def vs_functions_to_dict(vs, state_names): @contextlib.contextmanager -def h5pyfile(h5name, filemode="r"): +def h5pyfile(h5name, filemode="r", force_serial: bool = False, comm=None): import h5py from mpi4py import MPI - if h5py.h5.get_config().mpi and dolfin.MPI.size(dolfin.MPI.comm_world) > 1: + if comm is None: + comm = dolfin.MPI.comm_world + + if h5py.h5.get_config().mpi and dolfin.MPI.size(comm) > 1 and not force_serial: h5file = h5py.File(h5name, filemode, driver="mpio", comm=MPI.COMM_WORLD) else: - if dolfin.MPI.size(dolfin.MPI.comm_world) > 1: + if dolfin.MPI.size(comm) > 1: warnings.warn("h5py is not installed with MPI support") h5file = h5py.File(h5name, filemode) yield h5file @@ -44,19 +47,27 @@ def h5pyfile(h5name, filemode="r"): h5file.close() -def dict_to_h5(data, h5name, h5group): - with h5pyfile(h5name, "a") as h5file: - if h5group == "": - group = h5file - else: - group = h5file.create_group(h5group) - for k, v in data.items(): - try: - group.create_dataset(k, data=v) - except OSError: - logger.warning( - f"Unable to save key {k} with data {v} in {h5name}/{h5group}", - ) +def dict_to_h5(data, h5name, h5group, use_attrs: bool = False, comm=None): + if comm is None: + comm = dolfin.MPI.comm_world + if comm.rank == 0: + with h5pyfile(h5name, "a", force_serial=True) as h5file: + if h5group == "": + group = h5file + else: + group = h5file.create_group(h5group) + for k, v in data.items(): + if use_attrs: + group.attrs[k] = v + else: + try: + group.create_dataset(k, data=v) + except OSError: + logger.warning( + f"Unable to save key {k} with data {v} in {h5name}/{h5group}", + ) + # Synchronize + dolfin.MPI.barrier(comm) def decode(x): @@ -138,17 +149,17 @@ def save_state( state_params: Optional[Dict[str, float]] = None, ): path = Path(path) - utils.remove_file(path) + utils.remove_file(path, comm=geo.comm()) logger.info(f"Save state to {path}") geo.dump(path) logger.debug("Save using dolfin.HDF5File") logger.debug("Save using h5py") - dict_to_h5(serialize_dict(config.as_dict()), path, "config") + dict_to_h5(serialize_dict(config.as_dict()), path, "config", comm=geo.comm()) if state_params is None: state_params = {} - dict_to_h5(serialize_dict(state_params), path, "state_params") + dict_to_h5(serialize_dict(state_params), path, "state_params", comm=geo.comm()) def load_state( @@ -166,6 +177,7 @@ def load_state( logger.debug("Open file with h5py") with h5pyfile(path) as h5file: config = Config(**h5_to_dict(h5file["config"])) + config.coupling_type = "fully_coupled_Tor_Land" if config.coupling_type == "explicit_ORdmm_Land": from .models.explicit_ORdmm_Land import EMCoupling diff --git a/src/simcardems/utils.py b/src/simcardems/utils.py index 9d90dde7..e38a6dbd 100644 --- a/src/simcardems/utils.py +++ b/src/simcardems/utils.py @@ -24,7 +24,7 @@ def getLogger(name): import daiquiri logger = daiquiri.getLogger(name) - logger.logger.addFilter(mpi_filt) + # logger.logger.addFilter(mpi_filt) return logger @@ -141,12 +141,14 @@ def compute_norm(x, x_prev): return norm -def remove_file(path): +def remove_file(path, comm=None): + if comm is None: + comm = dolfin.MPI.comm_world path = Path(path) - if dolfin.MPI.rank(dolfin.MPI.comm_world) == 0: + if comm.rank == 0: if path.is_file(): path.unlink() - dolfin.MPI.barrier(dolfin.MPI.comm_world) + dolfin.MPI.barrier(comm) def setup_assigner(vs, index): diff --git a/src/simcardems/value_extractor.py b/src/simcardems/value_extractor.py index c82534a0..e1f718b2 100644 --- a/src/simcardems/value_extractor.py +++ b/src/simcardems/value_extractor.py @@ -18,7 +18,7 @@ class ValueExtractor: def __init__(self, geo: BaseGeometry): self.geo = geo self.volume = dolfin.assemble(dolfin.Constant(1.0) * dolfin.dx(domain=geo.mesh)) - logger.debug("Creating ValueExtractor with geo: {geo!r}") + logger.debug(f"Creating ValueExtractor with geo: {geo!r}") if isinstance(self.geo, SlabGeometry): self.boundary: Boundary = SlabBoundary(geo.mesh) @@ -26,6 +26,7 @@ def __init__(self, geo: BaseGeometry): self.boundary = LVBoundary(geo.mesh) else: raise NotImplementedError + logger.debug("Done") def average(self, func: dolfin.Function) -> float: if func.value_rank() == 0: diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 8defe48a..886aa838 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -2,42 +2,41 @@ from unittest import mock import dolfin -import h5py import pytest import simcardems -def test_DataCollector_reset_state_when_file_exists(tmp_path, geo): - simcardems.DataCollector(tmp_path, geo=geo) +def test_DataCollector_reset_state_when_file_exists(mpi_tmp_path, geo): + simcardems.DataCollector(mpi_tmp_path, geo=geo) with mock.patch("simcardems.utils.remove_file") as remove_file_mock: - collector = simcardems.DataCollector(tmp_path, geo=geo, reset_state=True) + collector = simcardems.DataCollector(mpi_tmp_path, geo=geo, reset_state=True) remove_file_mock.assert_called() assert Path(collector.results_file).is_file() -def test_DataCollector_not_reset_state_when_file_exists(tmp_path, geo): - simcardems.DataCollector(tmp_path, geo=geo) +def test_DataCollector_not_reset_state_when_file_exists(mpi_tmp_path, geo): + simcardems.DataCollector(mpi_tmp_path, geo=geo) with mock.patch("simcardems.utils.remove_file") as remove_file_mock: - collector = simcardems.DataCollector(tmp_path, geo=geo, reset_state=False) + collector = simcardems.DataCollector(mpi_tmp_path, geo=geo, reset_state=False) remove_file_mock.assert_not_called() assert Path(collector.results_file).is_file() -def test_DataCollector_create_file_with_geo(tmp_path, geo): - collector = simcardems.DataCollector(tmp_path, geo=geo) +def test_DataCollector_create_file_with_geo(mpi_tmp_path, geo): + collector = simcardems.DataCollector(mpi_tmp_path, geo=geo) assert Path(collector.results_file).is_file() - with h5py.File(collector.results_file, "r") as h5file: + with simcardems.save_load_functions.h5pyfile(collector.results_file, "r") as h5file: assert "geometry" in h5file -def test_DataCollector_register(tmp_path, geo): +def test_DataCollector_register(mpi_tmp_path, geo): V = dolfin.FunctionSpace(geo.mesh, "CG", 1) f = dolfin.Function(V) simcardems.set_log_level(10) - collector = simcardems.DataCollector(tmp_path, geo=geo) + collector = simcardems.DataCollector(mpi_tmp_path, geo=geo) collector.register("ep", "func", f) assert "func" in collector.names["ep"] @@ -47,14 +46,14 @@ def test_DataCollector_register(tmp_path, geo): @pytest.mark.parametrize("group", ["ep", "mechanics"]) -def test_DataCollector_store(group, geo, tmp_path): +def test_DataCollector_store(group, geo, mpi_tmp_path): mesh = geo.mesh if group == "mechanics" else geo.ep_mesh V = dolfin.FunctionSpace(mesh, "CG", 1) f = dolfin.Function(V) f.vector()[:] = 42 simcardems.set_log_level(10) - collector = simcardems.DataCollector(tmp_path, geo=geo) + collector = simcardems.DataCollector(mpi_tmp_path, geo=geo) collector.register(group, "func", f) assert "func" in collector.names[group] @@ -68,14 +67,14 @@ def test_DataCollector_store(group, geo, tmp_path): assert all(g2.vector().get_local() == f.vector().get_local()) -def test_DataLoader_load_empty_files_raises_ValueError(tmp_path, geo): - collector = simcardems.DataCollector(tmp_path, geo=geo) +def test_DataLoader_load_empty_files_raises_ValueError(mpi_tmp_path, geo): + collector = simcardems.DataCollector(mpi_tmp_path, geo=geo) with pytest.raises(ValueError): simcardems.DataLoader(collector.results_file) -def test_DataCollector_store_version(tmp_path, geo): - collector = simcardems.DataCollector(tmp_path, geo=geo) +def test_DataCollector_store_version(mpi_tmp_path, geo): + collector = simcardems.DataCollector(mpi_tmp_path, geo=geo) loader = simcardems.DataLoader(collector.results_file, empty_ok=True) from packaging.version import parse diff --git a/tests/test_geometry.py b/tests/test_geometry.py index 2638420e..fe370f63 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -1,12 +1,19 @@ from pathlib import Path import dolfin +import pytest from simcardems import geometry from simcardems import slabgeometry here = Path(__file__).absolute().parent +no_mpi = pytest.mark.skipif( + dolfin.MPI.size(dolfin.MPI.comm_world) > 1, + reason="Only works in serial", +) + +@no_mpi def test_create_slab_geometry_normal(): parameters = {"lx": 1, "ly": 1, "lz": 1, "dx": 1, "num_refinements": 1} geo = slabgeometry.SlabGeometry(parameters=parameters) @@ -18,6 +25,7 @@ def test_create_slab_geometry_normal(): assert geo.num_refinements == 1 +@no_mpi def test_create_slab_geometry_with_mechanics_mesh(): parameters = {"lx": 1, "ly": 1, "lz": 1, "dx": 1, "num_refinements": 1} mesh = dolfin.UnitCubeMesh(1, 1, 1) @@ -33,6 +41,7 @@ def test_create_slab_geometry_with_mechanics_mesh(): assert geo.num_refinements == 1 +@no_mpi def test_load_geometry(): mesh_folder = here / ".." / "demos" / "geometries" mesh_path = mesh_folder / "slab.h5" @@ -44,18 +53,19 @@ def test_load_geometry(): assert geo.stimulus_domain.marker == 1 -def test_dump_geometry(tmp_path): +def test_dump_geometry(mpi_tmp_path): mesh_folder = here / ".." / "demos" / "geometries" mesh_path = mesh_folder / "slab.h5" schema_path = mesh_folder / "slab.json" geo = geometry.load_geometry(mesh_path=mesh_path, schema_path=schema_path) - outpath = tmp_path / "state.h5" + outpath = mpi_tmp_path / "state.h5" geo.dump(outpath) dumped_geo = geometry.load_geometry( mesh_path=outpath, schema_path=outpath.with_suffix(".json"), ) + print(geo.parameters, " != ", dumped_geo.parameters) assert dumped_geo == geo diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index 306a9e07..02f13a16 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -1,9 +1,16 @@ import math +from pathlib import Path import dolfin +import numpy as np +import pytest import simcardems +@pytest.mark.skipif( + dolfin.MPI.size(dolfin.MPI.comm_world) > 1, + reason="Only works in serial", +) def test_activation_map(): mesh = dolfin.UnitCubeMesh(5, 5, 5) V = dolfin.FunctionSpace(mesh, "CG", 1) @@ -41,3 +48,47 @@ def voltage_gen(): assert math.isclose(act(0.5, 0.5, 0.5), 3.0) assert math.isclose(act(1.0, 0.0, 0.0), 2.0) assert math.isclose(act(1.0, 0.5, 0.5), 2.0) + + +def test_extract_sub_results(geo, mpi_tmp_path): + results_file = mpi_tmp_path / "results.h5" + collector = simcardems.DataCollector( + outdir=results_file.parent, + outfilename=results_file.name, + geo=geo, + ) + + # Setup a two mech functions and one ep function + V_mech = dolfin.FunctionSpace(geo.mesh, "Lagrange", 1) + f1_mech = dolfin.Function(V_mech) + collector.register("mechanics", "func1", f1_mech) + f2_mech = dolfin.Function(V_mech) + collector.register("mechanics", "func2", f2_mech) + + V_ep = dolfin.FunctionSpace(geo.mesh, "Lagrange", 1) + f3_ep = dolfin.Function(V_ep) + collector.register("ep", "func3", f3_ep) + + times = np.arange(0, 10, 0.5) + for t in times: + f1_mech.assign(dolfin.Constant(t)) + f2_mech.assign(dolfin.Constant(10 + t)) + collector.store(t) + + loader = simcardems.DataLoader(collector.results_file) + + assert loader.time_stamps == [f"{ti:.2f}" for ti in times] + sub_results_file = mpi_tmp_path / "sub_results.h5" + sub_collector = simcardems.postprocess.extract_sub_results( + results_file=results_file, + output_file=sub_results_file, + t_start=5.0, + t_end=7.0, + names={"mechanics": ["func1"]}, + ) + + assert str(sub_collector.results_file) == str(sub_results_file) + assert Path(sub_results_file).is_file() + sub_loader = simcardems.DataLoader(sub_results_file) + assert sub_loader.names == {"ep": [], "mechanics": ["func1"]} + assert sub_loader.time_stamps == ["5.00", "5.50", "6.00", "6.50", "7.00"]