From 0ea85b81b15e9fc70c1536feca6d75df71010af1 Mon Sep 17 00:00:00 2001 From: WEN Hao Date: Tue, 31 Mar 2026 22:46:42 +0800 Subject: [PATCH 1/5] update the usage of type hints for the functions --- .../train_crnn_cinc2020/scoring_metrics.py | 11 +- .../train_crnn_cinc2021/scoring_metrics.py | 11 +- .../train_crnn_cinc2021/special_detectors.py | 33 +++-- benchmarks/train_crnn_cinc2023/data_reader.py | 11 +- benchmarks/train_hybrid_cpsc2020/metrics.py | 12 +- .../signal_processing/ecg_denoise.py | 13 +- .../signal_processing/ecg_preproc.py | 9 +- .../signal_processing/ecg_rpeaks.py | 16 +-- .../signal_processing/ecg_rpeaks_dl.py | 5 +- .../train_hybrid_cpsc2021/aux_metrics.py | 11 +- benchmarks/train_hybrid_cpsc2021/model.py | 13 +- benchmarks/train_mtl_cinc2022/inputs.py | 2 +- benchmarks/train_multi_cpsc2019/metrics.py | 3 +- benchmarks/train_unet_ludb/metrics.py | 17 ++- test/test_databases/test_shhs.py | 11 +- test/test_preprocessors.py | 3 +- test/test_preprocessors_t.py | 26 +++- torch_ecg/_preprocessors/normalize.py | 7 +- torch_ecg/augmenters/mixup.py | 9 +- torch_ecg/augmenters/random_flip.py | 8 +- torch_ecg/augmenters/random_masking.py | 20 +-- torch_ecg/augmenters/random_renormalize.py | 9 +- torch_ecg/augmenters/stretch_compress.py | 13 +- torch_ecg/components/inputs.py | 2 +- torch_ecg/components/loggers.py | 11 +- torch_ecg/components/metrics.py | 2 +- torch_ecg/components/outputs.py | 2 +- .../databases/aux_data/cinc2020_aux_data.py | 11 +- .../databases/aux_data/cinc2021_aux_data.py | 11 +- torch_ecg/databases/base.py | 9 +- .../databases/cpsc_databases/cpsc2018.py | 5 +- .../databases/cpsc_databases/cpsc2019.py | 13 +- .../databases/cpsc_databases/cpsc2020.py | 13 +- .../databases/cpsc_databases/cpsc2021.py | 13 +- torch_ecg/databases/nsrr_databases/shhs.py | 6 +- .../databases/other_databases/cachet_cadb.py | 15 +-- torch_ecg/databases/other_databases/sph.py | 2 +- .../physionet_databases/apnea_ecg.py | 17 ++- .../databases/physionet_databases/cinc2018.py | 25 ++-- .../databases/physionet_databases/cinc2020.py | 15 +-- .../databases/physionet_databases/cinc2021.py | 15 +-- .../databases/physionet_databases/ltafdb.py | 5 +- torch_ecg/models/_nets.py | 21 ++- torch_ecg/models/cnn/mobilenet.py | 3 +- torch_ecg/models/cnn/regnet.py | 5 +- torch_ecg/models/cnn/resnet.py | 7 +- torch_ecg/models/cnn/xception.py | 15 +-- torch_ecg/models/ecg_fcn.py | 1 - torch_ecg/models/loss.py | 11 +- torch_ecg/preprocessors/README.md | 3 +- torch_ecg/preprocessors/bandpass.py | 13 +- torch_ecg/preprocessors/baseline_remove.py | 5 +- torch_ecg/preprocessors/normalize.py | 25 ++-- torch_ecg/utils/_edr.py | 15 +-- torch_ecg/utils/_preproc.py | 122 ++++++++++-------- torch_ecg/utils/utils_interval.py | 41 +++--- torch_ecg/utils/utils_metrics.py | 33 +++-- torch_ecg/utils/utils_nn.py | 5 +- torch_ecg/utils/utils_signal_t.py | 85 ++++++++---- 59 files changed, 428 insertions(+), 447 deletions(-) diff --git a/benchmarks/train_crnn_cinc2020/scoring_metrics.py b/benchmarks/train_crnn_cinc2020/scoring_metrics.py index b21fc851..6a42d863 100644 --- a/benchmarks/train_crnn_cinc2020/scoring_metrics.py +++ b/benchmarks/train_crnn_cinc2020/scoring_metrics.py @@ -2,19 +2,10 @@ metrics from the official scoring repository """ -from numbers import Real from typing import List, Sequence, Tuple import numpy as np -try: - import torch_ecg # noqa: F401 -except ModuleNotFoundError: - import sys - from pathlib import Path - - sys.path.insert(0, str(Path(__file__).absolute().parents[2])) - from torch_ecg.databases.aux_data.cinc2020_aux_data import load_weights __all__ = [ @@ -168,7 +159,7 @@ def compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> float: # Compute F-beta and G-beta measures from the unofficial phase of the Challenge. -def compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real) -> Tuple[float, float]: +def compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: float) -> Tuple[float, float]: """checked,""" num_recordings, num_classes = np.shape(labels) diff --git a/benchmarks/train_crnn_cinc2021/scoring_metrics.py b/benchmarks/train_crnn_cinc2021/scoring_metrics.py index 983c3b96..ccfb86fe 100644 --- a/benchmarks/train_crnn_cinc2021/scoring_metrics.py +++ b/benchmarks/train_crnn_cinc2021/scoring_metrics.py @@ -2,19 +2,10 @@ metrics from the official scoring repository """ -from numbers import Real from typing import List, Sequence, Tuple, Union import numpy as np -try: - import torch_ecg # noqa: F401 -except ModuleNotFoundError: - import sys - from pathlib import Path - - sys.path.insert(0, str(Path(__file__).absolute().parents[2])) - from torch_ecg.databases.aux_data.cinc2021_aux_data import load_weights __all__ = [ @@ -233,7 +224,7 @@ def compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, n # Compute F-beta and G-beta measures from the unofficial phase of the Challenge. -def compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real) -> Tuple[float, float]: +def compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: float) -> Tuple[float, float]: """checked,""" num_recordings, num_classes = np.shape(labels) diff --git a/benchmarks/train_crnn_cinc2021/special_detectors.py b/benchmarks/train_crnn_cinc2021/special_detectors.py index 2b0aee58..77670bb5 100644 --- a/benchmarks/train_crnn_cinc2021/special_detectors.py +++ b/benchmarks/train_crnn_cinc2021/special_detectors.py @@ -17,7 +17,6 @@ """ from itertools import repeat -from numbers import Real from typing import Any, Optional, Sequence import numpy as np @@ -54,7 +53,7 @@ def special_detectors( raw_sig: np.ndarray, - fs: Real, + fs: int, sig_fmt: str = "channel_first", leads: Sequence[str] = Standard12Leads, verbose: int = 0, @@ -66,7 +65,7 @@ def special_detectors( ---------- raw_sig: ndarray, the raw multi-lead ecg signal, with units in mV - fs: real number, + fs: int, sampling frequency of `sig` sig_fmt: str, default "channel_first", format of the multi-lead ecg signal, @@ -133,19 +132,19 @@ def special_detectors( def pacing_rhythm_detector( raw_sig: np.ndarray, - fs: Real, + fs: int, sig_fmt: str = "channel_first", leads: Sequence[str] = Standard12Leads, ret_prob: bool = True, verbose: int = 0, -) -> Real: +) -> float: """to be improved (fine-tuning hyper-parameters in cfg.py), Parameters ---------- raw_sig: ndarray, the raw multi-lead ecg signal, with units in mV - fs: real number, + fs: int, sampling frequency of `sig` sig_fmt: str, default "channel_first", format of the multi-lead ecg signal, @@ -238,7 +237,7 @@ def pacing_rhythm_detector( def electrical_axis_detector( filtered_sig: np.ndarray, rpeaks: np.ndarray, - fs: Real, + fs: int, sig_fmt: str = "channel_first", leads: Sequence[str] = Standard12Leads, method: Optional[str] = None, @@ -255,7 +254,7 @@ def electrical_axis_detector( the filtered multi-lead ecg signal, with units in mV rpeaks: ndarray, array of indices of the R peaks - fs: real number, + fs: int, sampling frequency of `sig` sig_fmt: str, default "channel_first", format of the multi-lead ecg signal, @@ -375,8 +374,8 @@ def electrical_axis_detector( def brady_tachy_detector( rpeaks: np.ndarray, - fs: Real, - normal_rr_range: Optional[Sequence[Real]] = None, + fs: int, + normal_rr_range: Optional[Sequence[int]] = None, verbose: int = 0, ) -> str: """to be improved (fine-tuning hyper-parameters in cfg.py), @@ -391,7 +390,7 @@ def brady_tachy_detector( ---------- rpeaks: ndarray, array of indices of the R peaks - fs: real number, + fs: int, sampling frequency of the ecg signal normal_rr_range: sequence of int, optional, the range of normal rr interval, with units in ms; @@ -438,7 +437,7 @@ def brady_tachy_detector( def LQRSV_detector( filtered_sig: np.ndarray, rpeaks: np.ndarray, - fs: Real, + fs: int, sig_fmt: str = "channel_first", leads: Sequence[str] = Standard12Leads, verbose: int = 0, @@ -451,7 +450,7 @@ def LQRSV_detector( the filtered multi-lead ecg signal, with units in mV rpeaks: ndarray, array of indices of the R peaks - fs: real number, + fs: int, sampling frequency of the ecg signal sig_fmt: str, default "channel_first", format of the 12 lead ecg signal, @@ -511,7 +510,7 @@ def LQRSV_detector( def LQRSV_detector_backup( filtered_sig: np.ndarray, rpeaks: np.ndarray, - fs: Real, + fs: int, sig_fmt: str = "channel_first", leads: Sequence[str] = Standard12Leads, verbose: int = 0, @@ -524,7 +523,7 @@ def LQRSV_detector_backup( the filtered 12-lead ecg signal, with units in mV rpeaks: ndarray, array of indices of the R peaks - fs: real number, + fs: int, sampling frequency of the ecg signal sig_fmt: str, default "channel_first", format of the 12 lead ecg signal, @@ -608,7 +607,7 @@ def LQRSV_detector_backup( def PRWP_detector( filtered_sig: np.ndarray, rpeaks: np.ndarray, - fs: Real, + fs: int, sig_fmt: str = "channel_first", leads: Sequence[str] = Standard12Leads, verbose: int = 0, @@ -621,7 +620,7 @@ def PRWP_detector( the filtered multi-lead ecg signal, with units in mV rpeaks: ndarray, array of indices of the R peaks - fs: real number, + fs: int, sampling frequency of the ecg signal sig_fmt: str, default "channel_first", format of the 12 lead ecg signal, diff --git a/benchmarks/train_crnn_cinc2023/data_reader.py b/benchmarks/train_crnn_cinc2023/data_reader.py index 6f63aa47..898df31c 100644 --- a/benchmarks/train_crnn_cinc2023/data_reader.py +++ b/benchmarks/train_crnn_cinc2023/data_reader.py @@ -4,7 +4,6 @@ import re import warnings from ast import literal_eval -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -599,9 +598,9 @@ def load_data( return_channels: bool = False, ) -> Union[ np.ndarray, - Tuple[np.ndarray, Real], + Tuple[np.ndarray, int], Tuple[np.ndarray, List[str]], - Tuple[np.ndarray, Real, List[str]], + Tuple[np.ndarray, int, List[str]], ]: """Load EEG data from the record. @@ -640,7 +639,7 @@ def load_data( ------- data : numpy.ndarray The loaded EEG data. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. data_channels : list of str, optional @@ -721,7 +720,7 @@ def load_bipolar_data( units: Literal["mV", "uV", "muV", "μV", None] = "uV", fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: """Load bipolar EEG data from the record. Bipolar EEG is the difference between two channels. @@ -754,7 +753,7 @@ def load_bipolar_data( ------- data : numpy.ndarray The loaded EEG data. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. """ diff --git a/benchmarks/train_hybrid_cpsc2020/metrics.py b/benchmarks/train_hybrid_cpsc2020/metrics.py index 21aa5bee..83b3407a 100644 --- a/benchmarks/train_hybrid_cpsc2020/metrics.py +++ b/benchmarks/train_hybrid_cpsc2020/metrics.py @@ -1,18 +1,8 @@ """ """ -from numbers import Real from typing import List, Sequence, Tuple, Union import numpy as np - -try: - import torch_ecg # noqa: F401 -except ModuleNotFoundError: - import sys - from pathlib import Path - - sys.path.insert(0, str(Path(__file__).absolute().parents[2])) - from cfg import BaseCfg from torch_ecg.cfg import CFG @@ -323,7 +313,7 @@ def compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> float: # Compute F-beta and G-beta measures from the unofficial phase of the Challenge. -def compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real) -> Tuple[float, float]: +def compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: float) -> Tuple[float, float]: """checked,""" num_recordings, num_classes = np.shape(labels) diff --git a/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_denoise.py b/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_denoise.py index 36d93883..f1094d65 100644 --- a/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_denoise.py +++ b/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_denoise.py @@ -15,19 +15,10 @@ to add """ -from numbers import Real from typing import List import numpy as np -try: - import torch_ecg # noqa: F401 -except ModuleNotFoundError: - import sys - from pathlib import Path - - sys.path.insert(0, str(Path(__file__).absolute().parents[3])) - from torch_ecg.cfg import CFG from torch_ecg.utils.utils_data import mask_to_intervals @@ -36,7 +27,7 @@ ] -def ecg_denoise(filtered_sig: np.ndarray, fs: Real, config: CFG) -> List[List[int]]: +def ecg_denoise(filtered_sig: np.ndarray, fs: int, config: CFG) -> List[List[int]]: """ a naive function removing non-ECG segments (flat and motion artefact) @@ -45,7 +36,7 @@ def ecg_denoise(filtered_sig: np.ndarray, fs: Real, config: CFG) -> List[List[in ---------- filtered_sig: ndarray, 1d filtered (typically bandpassed) ECG signal, - fs: real number, + fs: int, sampling frequency of `filtered_sig` config: dict, configs of relavant parameters, like window, step, etc. diff --git a/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_preproc.py b/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_preproc.py index 2cee4b7c..66223b18 100644 --- a/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_preproc.py +++ b/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_preproc.py @@ -19,7 +19,6 @@ import os import time from copy import deepcopy -from numbers import Real from typing import Dict, Optional import numpy as np @@ -64,14 +63,14 @@ ] -def preprocess_signal(raw_sig: np.ndarray, fs: Real, config: Optional[CFG] = None) -> Dict[str, np.ndarray]: +def preprocess_signal(raw_sig: np.ndarray, fs: int, config: Optional[CFG] = None) -> Dict[str, np.ndarray]: """ Parameters ---------- raw_sig: ndarray, the raw ecg signal - fs: real number, + fs: int, sampling frequency of `raw_sig` config: dict, optional, extra process configuration, @@ -135,7 +134,7 @@ def preprocess_signal(raw_sig: np.ndarray, fs: Real, config: Optional[CFG] = Non def parallel_preprocess_signal( raw_sig: np.ndarray, - fs: Real, + fs: int, config: Optional[CFG] = None, save_dir: Optional[str] = None, save_fmt: str = "npy", @@ -147,7 +146,7 @@ def parallel_preprocess_signal( ---------- raw_sig: ndarray, the raw ecg signal - fs: real number, + fs: int, sampling frequency of `raw_sig` config: dict, optional, extra process configuration, diff --git a/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_rpeaks.py b/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_rpeaks.py index 25fa51f5..a48d5c07 100644 --- a/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_rpeaks.py +++ b/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_rpeaks.py @@ -13,8 +13,6 @@ [1] Liu, Feifei, et al. "Performance analysis of ten common QRS detectors on different ECG application cases." Journal of healthcare engineering 2018 (2018). """ -from numbers import Real - import biosppy.signals.ecg as BSE import numpy as np from wfdb.processing.qrs import GQRS, XQRS # noqa: F401 @@ -34,7 +32,7 @@ # --------------------------------------------------------------------- # algorithms from wfdb -def xqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def xqrs_detect(sig: np.ndarray, fs: int, **kwargs) -> np.ndarray: """ default kwargs: sampfrom=0, sampto='end', conf=None, learn=True, verbose=True @@ -52,7 +50,7 @@ def xqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def gqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def gqrs_detect(sig: np.ndarray, fs: int, **kwargs) -> np.ndarray: """ default kwargs: d_sig=None, adc_gain=None, adc_zero=None, @@ -90,7 +88,7 @@ def gqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: # --------------------------------------------------------------------- # algorithms from biosppy -def hamilton_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def hamilton_detect(sig: np.ndarray, fs: int, **kwargs) -> np.ndarray: """ References @@ -110,7 +108,7 @@ def hamilton_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def ssf_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def ssf_detect(sig: np.ndarray, fs: int, **kwargs) -> np.ndarray: """ Slope Sum Function (SSF) @@ -138,7 +136,7 @@ def ssf_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def christov_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def christov_detect(sig: np.ndarray, fs: int, **kwargs) -> np.ndarray: """ References @@ -156,7 +154,7 @@ def christov_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def engzee_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def engzee_detect(sig: np.ndarray, fs: int, **kwargs) -> np.ndarray: """ References @@ -179,7 +177,7 @@ def engzee_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def gamboa_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def gamboa_detect(sig: np.ndarray, fs: int, **kwargs) -> np.ndarray: """ References diff --git a/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_rpeaks_dl.py b/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_rpeaks_dl.py index 508b97df..a57f2e3f 100644 --- a/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_rpeaks_dl.py +++ b/benchmarks/train_hybrid_cpsc2020/signal_processing/ecg_rpeaks_dl.py @@ -8,7 +8,6 @@ import math from itertools import repeat -from numbers import Real from typing import Sequence, Union import biosppy.signals.ecg as BSE @@ -34,7 +33,7 @@ CNN_MODEL, CRNN_MODEL = load_model("keras_ecg_seq_lab_net") -def seq_lab_net_detect(sig: np.ndarray, fs: Real, correction: bool = False, **kwargs) -> np.ndarray: +def seq_lab_net_detect(sig: np.ndarray, fs: int, correction: bool = False, **kwargs) -> np.ndarray: """ use model of entry 0416 of CPSC2019, @@ -46,7 +45,7 @@ def seq_lab_net_detect(sig: np.ndarray, fs: Real, correction: bool = False, **kw ---------- sig: ndarray, the (raw) ECG signal of arbitrary length, with units in mV - fs: real number, + fs: int, sampling frequency of `sig` correction: bool, default False, if True, correct rpeaks to local maximum in a small nbh diff --git a/benchmarks/train_hybrid_cpsc2021/aux_metrics.py b/benchmarks/train_hybrid_cpsc2021/aux_metrics.py index 5d959379..a148b618 100644 --- a/benchmarks/train_hybrid_cpsc2021/aux_metrics.py +++ b/benchmarks/train_hybrid_cpsc2021/aux_metrics.py @@ -7,7 +7,6 @@ """ import multiprocessing as mp -from numbers import Real from typing import Dict, Optional, Sequence, Union import numpy as np @@ -37,7 +36,7 @@ def compute_rpeak_metric( rpeaks_truths: Sequence[Union[np.ndarray, Sequence[int]]], rpeaks_preds: Sequence[Union[np.ndarray, Sequence[int]]], - fs: Real, + fs: int, thr: float = 0.075, verbose: int = 0, ) -> Dict[str, float]: @@ -49,7 +48,7 @@ def compute_rpeak_metric( sequence of ground truths of rpeaks locations (indices) from multiple records rpeaks_preds: sequence, predictions of ground truths of rpeaks locations (indices) for multiple records - fs: real number, + fs: int, sampling frequency of ECG signal thr: float, default 0.075, threshold for a prediction to be truth positive, @@ -179,7 +178,7 @@ def compute_rr_metric( def compute_main_task_metric( mask_truths: Sequence[Union[np.ndarray, Sequence[int]]], mask_preds: Sequence[Union[np.ndarray, Sequence[int]]], - fs: Real, + fs: int, reduction: int, weight_masks: Optional[Sequence[Union[np.ndarray, Sequence[int]]]] = None, rpeaks: Optional[Sequence[Sequence[int]]] = None, @@ -196,7 +195,7 @@ def compute_main_task_metric( sequences of AF labels on rr intervals, of shape (n_samples, seq_len) mask_preds: array_like, sequences of AF predictions on rr intervals, of shape (n_samples, seq_len) - fs: Real, + fs: int, sampling frequency of the model input ECGs, used when (indices of) `rpeaks` not privided reduction: int, @@ -270,7 +269,7 @@ def compute_main_task_metric( # """ # __name__ = "WeightedBoundaryLoss" -# def __init__(self, weight_map:Dict[int,Real], sigma:Real, w:Real) -> None: +# def __init__(self, weight_map:Dict[int, Union[int, float]], sigma:float, w:float) -> None: # """ # """ # self.weight_map = weight_map diff --git a/benchmarks/train_hybrid_cpsc2021/model.py b/benchmarks/train_hybrid_cpsc2021/model.py index 9d02c102..59af1152 100644 --- a/benchmarks/train_hybrid_cpsc2021/model.py +++ b/benchmarks/train_hybrid_cpsc2021/model.py @@ -10,7 +10,6 @@ """ from itertools import repeat -from numbers import Real from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np @@ -733,7 +732,7 @@ def inference( episode_len_thr=episode_len_thr, ) if rpeaks is not None: - if isinstance((rpeaks[0]), Real): + if isinstance((rpeaks[0]), (int, np.integer)): _rpeaks = [rpeaks] else: _rpeaks = rpeaks @@ -791,7 +790,7 @@ def from_checkpoint(path: str, device: Optional[torch.device] = None) -> Tuple[t def _qrs_detection_post_process( prob: np.ndarray, - fs: Real, + fs: int, reduction: int, bin_pred_thr: float = 0.5, skip_dist: int = 500, @@ -805,7 +804,7 @@ def _qrs_detection_post_process( ---------- prob: ndarray, array of predicted probability - fs: real number, + fs: int, sampling frequency of the ECG reduction: int, reduction (granularity) of `prob` w.r.t. the ECG @@ -901,10 +900,10 @@ def _qrs_detection_post_process( def _main_task_post_process( prob: np.ndarray, - fs: Real, + fs: int, reduction: int, bin_pred_thr: float = 0.5, - rpeaks: Sequence[Sequence[int]] = None, + rpeaks: Optional[Sequence[Sequence[int]]] = None, siglens: Optional[Sequence[int]] = None, episode_len_thr: int = 5, ) -> Tuple[List[List[List[int]]], np.ndarray]: @@ -917,7 +916,7 @@ def _main_task_post_process( ---------- prob: ndarray, predicted af mask, of shape (batch_size, seq_len) - fs: real number, + fs: int, sampling frequency of the signal reduction: int, reduction ratio of the predicted af mask w.r.t. the signal diff --git a/benchmarks/train_mtl_cinc2022/inputs.py b/benchmarks/train_mtl_cinc2022/inputs.py index 1c482bb9..8b7e7e73 100644 --- a/benchmarks/train_mtl_cinc2022/inputs.py +++ b/benchmarks/train_mtl_cinc2022/inputs.py @@ -138,7 +138,7 @@ def dtype(self) -> torch.dtype: def device(self) -> torch.device: return self._device - def compute_input_shape(self, waveform_shape: Union[Sequence[int], torch.Size]) -> Tuple[Union[type(None), int], ...]: + def compute_input_shape(self, waveform_shape: Union[Sequence[int], torch.Size]) -> Tuple[Union[None, int], ...]: """ computes the input shape of the model based on the input type and the waveform shape diff --git a/benchmarks/train_multi_cpsc2019/metrics.py b/benchmarks/train_multi_cpsc2019/metrics.py index 3b02b88d..6e246c06 100644 --- a/benchmarks/train_multi_cpsc2019/metrics.py +++ b/benchmarks/train_multi_cpsc2019/metrics.py @@ -5,7 +5,6 @@ """ import math -from numbers import Real from typing import Sequence, Union import numpy as np @@ -18,7 +17,7 @@ def compute_metrics( rpeaks_truths: Sequence[Union[np.ndarray, Sequence[int]]], rpeaks_preds: Sequence[Union[np.ndarray, Sequence[int]]], - fs: Real, + fs: int, thr: float = 0.075, verbose: int = 0, ) -> float: diff --git a/benchmarks/train_unet_ludb/metrics.py b/benchmarks/train_unet_ludb/metrics.py index 22d92a8d..f7163ff0 100644 --- a/benchmarks/train_unet_ludb/metrics.py +++ b/benchmarks/train_unet_ludb/metrics.py @@ -14,7 +14,6 @@ """ -from numbers import Real from typing import Dict, Sequence import numpy as np @@ -44,7 +43,7 @@ def compute_metrics( truth_masks: Sequence[np.ndarray], pred_masks: Sequence[np.ndarray], class_map: Dict[str, int], - fs: Real, + fs: int, mask_format: str = "channel_first", ) -> Dict[str, Dict[str, float]]: """ @@ -62,7 +61,7 @@ def compute_metrics( class_map: dict, class map, mapping names to waves to numbers from 0 to n_classes-1, the keys should contain 'pwave', 'qrs', 'twave' - fs: real number, + fs: int, sampling frequency of the signal corresponding to the masks, used to compute the duration of each waveform, hence the error and standard deviations of errors @@ -101,7 +100,7 @@ class map, mapping names to waves to numbers from 0 to n_classes-1, def compute_metrics_waveform( truth_waveforms: Sequence[Sequence[ECGWaveForm]], pred_waveforms: Sequence[Sequence[ECGWaveForm]], - fs: Real, + fs: int, ) -> Dict[str, Dict[str, float]]: """ @@ -116,7 +115,7 @@ def compute_metrics_waveform( pred_waveforms: sequence of sequence of `ECGWaveForm`s, the predictions corresponding to `truth_waveforms`, each element is a sequence of `ECGWaveForm`s from the same sample - fs: real number, + fs: int, sampling frequency of the signal corresponding to the waveforms, used to compute the duration of each waveform, hence the error and standard deviations of errors @@ -213,7 +212,7 @@ def compute_metrics_waveform( def _compute_metrics_waveform( - truths: Sequence[ECGWaveForm], preds: Sequence[ECGWaveForm], fs: Real + truths: Sequence[ECGWaveForm], preds: Sequence[ECGWaveForm], fs: int ) -> Dict[str, Dict[str, float]]: """ @@ -226,7 +225,7 @@ def _compute_metrics_waveform( the ground truth preds: sequence of `ECGWaveForm`s, the predictions corresponding to `truths`, - fs: real number, + fs: int, sampling frequency of the signal corresponding to the waveforms, used to compute the duration of each waveform, hence the error and standard deviations of errors @@ -295,7 +294,7 @@ def _compute_metrics_waveform( return scorings -def _compute_metrics_base(truths: Sequence[Real], preds: Sequence[Real], fs: Real) -> Dict[str, float]: +def _compute_metrics_base(truths: Sequence[int], preds: Sequence[int], fs: int) -> Dict[str, float]: """ Parameters @@ -304,7 +303,7 @@ def _compute_metrics_base(truths: Sequence[Real], preds: Sequence[Real], fs: Rea ground truth of indices of corresponding critical points preds: sequence of real numbers, predicted indices of corresponding critical points - fs: real number, + fs: int, sampling frequency of the signal corresponding to the critical points, used to compute the duration of each waveform, hence the error and standard deviations of errors diff --git a/test/test_databases/test_shhs.py b/test/test_databases/test_shhs.py index 9f66e376..1f851b93 100644 --- a/test/test_databases/test_shhs.py +++ b/test/test_databases/test_shhs.py @@ -6,7 +6,6 @@ import os import time -from numbers import Real from pathlib import Path import numpy as np @@ -86,20 +85,20 @@ def test_load_psg_data(self): assert isinstance(value, tuple) assert len(value) == 2 assert isinstance(value[0], np.ndarray) - assert isinstance(value[1], Real) and value[1] > 0 # type: ignore + assert isinstance(value[1], (int, float)) and value[1] > 0 # type: ignore available_signals = reader.get_available_signals(0) for signal in available_signals: # type: ignore psg_data = reader.load_psg_data(0, channel=signal, physical=True) assert isinstance(psg_data, tuple) assert len(psg_data) == 2 assert isinstance(psg_data[0], np.ndarray) - assert isinstance(psg_data[1], Real) and psg_data[1] > 0 # type: ignore + assert isinstance(psg_data[1], (int, float)) and psg_data[1] > 0 # type: ignore def test_load_data(self): data, fs = reader.load_data(0) assert isinstance(data, np.ndarray) assert data.ndim == 2 - assert isinstance(fs, Real) and fs > 0 # type: ignore + assert isinstance(fs, (int, float)) and fs > 0 # type: ignore data_1, fs_1 = reader.load_data(0, fs=500, data_format="flat") assert isinstance(data_1, np.ndarray) assert data_1.ndim == 1 @@ -509,11 +508,11 @@ def test_get_fs(self): available_signals = reader.get_available_signals(0) for sig in available_signals: # type: ignore fs = reader.get_fs(0, sig) - assert isinstance(fs, Real) and fs > 0 # type: ignore + assert isinstance(fs, (int, float)) and fs > 0 # type: ignore rec = reader.rec_with_rpeaks_ann[0] fs = reader.get_fs(rec, "rpeak") - assert isinstance(fs, Real) and fs > 0 # type: ignore + assert isinstance(fs, (int, float)) and fs > 0 # type: ignore rec = "shhs2-200001" # a record (both signal and ann. files) that does not exist fs = reader.get_fs(rec) diff --git a/test/test_preprocessors.py b/test/test_preprocessors.py index abd8f20c..a90d70f4 100644 --- a/test/test_preprocessors.py +++ b/test/test_preprocessors.py @@ -1,7 +1,6 @@ """ """ import itertools -from numbers import Real from typing import Tuple import numpy as np @@ -30,7 +29,7 @@ class DummyPreProcessor(PreProcessor): def __init__(self) -> None: super().__init__() - def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]: + def apply(self, sig: np.ndarray, fs: int) -> Tuple[np.ndarray, int]: return sig, fs diff --git a/test/test_preprocessors_t.py b/test/test_preprocessors_t.py index 51db59db..a3aa3fbd 100644 --- a/test/test_preprocessors_t.py +++ b/test/test_preprocessors_t.py @@ -116,6 +116,8 @@ def test_preproc_manager() -> None: def test_bandpass() -> None: + from torch_ecg.utils.utils_signal_t import bandpass_filter + bp = BandPass(fs=500) # Tensor sig = test_sig.clone() @@ -126,14 +128,36 @@ def test_bandpass() -> None: sig_np = bp(sig_np) assert isinstance(sig_np, np.ndarray) + # lowcut=0 now warns and disables the high-pass side (treated as lowpass) bp = BandPass(fs=500, lowcut=0, highcut=40) sig = test_sig.clone() - sig = bp(sig) + with pytest.warns(RuntimeWarning, match="lowcut <= 0"): + sig = bp(sig) bp = BandPass(fs=500, lowcut=1.5, highcut=None, inplace=False) sig = test_sig.clone() sig = bp(sig) + # highcut >= nyquist warns and disables the low-pass side (treated as highpass) + with pytest.warns(RuntimeWarning, match="highcut >= Nyquist"): + bandpass_filter(test_sig.clone(), fs=500, lowcut=1.0, highcut=250.0) + + # invalid fs + with pytest.raises(ValueError, match="fs must be a positive real number"): + bandpass_filter(test_sig.clone(), fs=-1) + + # lowcut >= nyquist raises + with pytest.raises(ValueError, match="lowcut must be less than Nyquist"): + bandpass_filter(test_sig.clone(), fs=500, lowcut=300.0) + + # highcut <= 0 raises + with pytest.raises(ValueError, match="highcut must be positive"): + bandpass_filter(test_sig.clone(), fs=500, highcut=-5.0) + + # lowcut >= highcut raises + with pytest.raises(ValueError, match="lowcut must be less than highcut"): + bandpass_filter(test_sig.clone(), fs=500, lowcut=80.0, highcut=40.0) + del bp, sig diff --git a/torch_ecg/_preprocessors/normalize.py b/torch_ecg/_preprocessors/normalize.py index bcbca694..58e825dd 100644 --- a/torch_ecg/_preprocessors/normalize.py +++ b/torch_ecg/_preprocessors/normalize.py @@ -1,6 +1,5 @@ """Normalization of the signals.""" -from numbers import Real from typing import Any, List, Literal, Tuple, Union from numpy.typing import NDArray @@ -80,13 +79,13 @@ def __init__( self.mean = mean self.std = std self.per_channel = per_channel - if isinstance(std, Real): + if isinstance(std, (float, int)): assert std > 0, "standard deviation should be positive" else: assert (std > 0).all(), "standard deviations should all be positive" if not per_channel: - assert isinstance(mean, Real) and isinstance( - std, Real + assert isinstance(mean, (float, int)) and isinstance( + std, (float, int) ), "mean and std should be real numbers in the non per-channel setting" def apply(self, sig: NDArray, fs: Union[float, int]) -> Tuple[NDArray, Union[float, int]]: diff --git a/torch_ecg/augmenters/mixup.py b/torch_ecg/augmenters/mixup.py index f73083f4..f1bbaef4 100644 --- a/torch_ecg/augmenters/mixup.py +++ b/torch_ecg/augmenters/mixup.py @@ -1,7 +1,6 @@ """ """ from copy import deepcopy -from numbers import Real from random import shuffle from typing import Any, List, Optional, Sequence, Tuple @@ -32,9 +31,9 @@ class Mixup(Augmenter): ---------- fs : int, optional Sampling frequency of the ECGs to be augmented. - alpha : numbers.Real, default 0.5 + alpha : float, default 0.5 alpha parameter of the Beta distribution used in Mixup. - beta : numbers.Real, optional + beta : float, optional beta parameter of the Beta distribution used in Mixup, defaults to `alpha`. prob : float, default 0.5 @@ -67,8 +66,8 @@ class Mixup(Augmenter): def __init__( self, fs: Optional[int] = None, - alpha: Real = 0.5, - beta: Optional[Real] = None, + alpha: float = 0.5, + beta: Optional[float] = None, prob: float = 0.5, inplace: bool = True, **kwargs: Any, diff --git a/torch_ecg/augmenters/random_flip.py b/torch_ecg/augmenters/random_flip.py index 0401ee47..c755d054 100644 --- a/torch_ecg/augmenters/random_flip.py +++ b/torch_ecg/augmenters/random_flip.py @@ -1,12 +1,12 @@ """ """ -from numbers import Real from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch from torch import Tensor +from ..cfg import DEFAULTS from .base import Augmenter from .registry import AUGMENTERS @@ -60,10 +60,10 @@ def __init__( self.per_channel = per_channel self.inplace = inplace self.prob = prob - if isinstance(self.prob, Real): - self.prob = np.array([self.prob, self.prob]) + if isinstance(self.prob, (float, int)): + self.prob = np.array([self.prob, self.prob], dtype=DEFAULTS.np_dtype) else: - self.prob = np.array(self.prob) + self.prob = np.array(self.prob, dtype=DEFAULTS.np_dtype) assert (self.prob >= 0).all() and (self.prob <= 1).all(), "Probability must be between 0 and 1" def forward( diff --git a/torch_ecg/augmenters/random_masking.py b/torch_ecg/augmenters/random_masking.py index 768553e9..40b8605c 100644 --- a/torch_ecg/augmenters/random_masking.py +++ b/torch_ecg/augmenters/random_masking.py @@ -1,6 +1,5 @@ """ """ -from numbers import Real from random import randint from typing import Any, List, Optional, Sequence, Tuple, Union @@ -8,6 +7,7 @@ import torch from torch import Tensor +from ..cfg import DEFAULTS from .base import Augmenter from .registry import AUGMENTERS @@ -25,11 +25,11 @@ class RandomMasking(Augmenter): ---------- fs : int Sampling frequency of the ECGs to be augmented. - mask_value : numbers.Real, default 0.0 + mask_value : int or float, default 0.0 Value to mask with. - mask_width : Sequence[numbers.Real], default ``[0.08, 0.18]`` + mask_width : Sequence[int or float], default ``[0.08, 0.18]`` Width range of the masking window, with units in seconds - prob : numbers.Real or Sequence[numbers.Real], default ``[0.3, 0.15]`` + prob : float or Sequence[float], default ``[0.3, 0.15]`` Probabilities of masking ECG signals, the first probality is for the batch dimension, the second probability is for the lead dimension. @@ -55,19 +55,19 @@ class RandomMasking(Augmenter): def __init__( self, fs: int, - mask_value: Real = 0.0, - mask_width: Sequence[Real] = [0.08, 0.18], - prob: Union[Sequence[Real], Real] = [0.3, 0.15], + mask_value: Union[int, float] = 0.0, + mask_width: Sequence[Union[int, float]] = [0.08, 0.18], + prob: Union[Sequence[float], float] = [0.3, 0.15], inplace: bool = True, **kwargs: Any, ) -> None: super().__init__() self.fs = fs self.prob = prob - if isinstance(self.prob, Real): - self.prob = np.array([self.prob, self.prob]) + if isinstance(self.prob, (float, int)): + self.prob = np.array([self.prob, self.prob], dtype=DEFAULTS.np_dtype) else: - self.prob = np.array(self.prob) + self.prob = np.array(self.prob, dtype=DEFAULTS.np_dtype) assert (self.prob >= 0).all() and (self.prob <= 1).all(), "Probability must be between 0 and 1" self.mask_value = mask_value self.mask_width = (np.array(mask_width) * self.fs).round().astype(int) diff --git a/torch_ecg/augmenters/random_renormalize.py b/torch_ecg/augmenters/random_renormalize.py index 69fad112..687721f0 100644 --- a/torch_ecg/augmenters/random_renormalize.py +++ b/torch_ecg/augmenters/random_renormalize.py @@ -1,6 +1,5 @@ """ """ -from numbers import Real from typing import Any, Iterable, List, Optional, Sequence, Tuple import numpy as np @@ -55,18 +54,18 @@ class RandomRenormalize(Augmenter): def __init__( self, - mean: Iterable[Real] = [-0.05, 0.1], - std: Iterable[Real] = [0.08, 0.32], + mean: Iterable[float] = [-0.05, 0.1], + std: Iterable[float] = [0.08, 0.32], per_channel: bool = False, prob: float = 0.5, inplace: bool = True, **kwargs: Any, ) -> None: super().__init__() - self.mean = np.array(mean) + self.mean = np.array(mean, dtype=DEFAULTS.np_dtype) self.mean_mean = self.mean.mean(axis=-1, keepdims=True) self.mean_scale = (self.mean[..., -1] - self.mean_mean) * 0.3 - self.std = np.array(std) + self.std = np.array(std, dtype=DEFAULTS.np_dtype) self.std_mean = self.std.mean(axis=-1, keepdims=True) self.std_scale = (self.std[..., -1] - self.std_mean) * 0.3 self.per_channel = per_channel diff --git a/torch_ecg/augmenters/stretch_compress.py b/torch_ecg/augmenters/stretch_compress.py index 6024b655..c582524d 100644 --- a/torch_ecg/augmenters/stretch_compress.py +++ b/torch_ecg/augmenters/stretch_compress.py @@ -1,6 +1,5 @@ """ """ -from numbers import Real from random import choice, randint from typing import Any, List, Optional, Sequence, Tuple, Union @@ -32,7 +31,7 @@ class StretchCompress(Augmenter): Parameters ---------- - ratio : numbers.Real, default 6 + ratio : int or float, default 6 Mean ratio of the stretch or compress. If it is in the interval[1, 100], then it will be transformed to [0, 1]. @@ -59,7 +58,7 @@ class StretchCompress(Augmenter): __name__ = "StretchCompress" - def __init__(self, ratio: Real = 6, prob: float = 0.5, inplace: bool = True, **kwargs: Any) -> None: + def __init__(self, ratio: Union[int, float] = 6, prob: float = 0.5, inplace: bool = True, **kwargs: Any) -> None: super().__init__() self.prob = prob assert 0 <= self.prob <= 1, "Probability must be between 0 and 1" @@ -228,13 +227,13 @@ def extra_repr_keys(self) -> List[str]: def _stretch_compress_one_batch_element( - ratio: Real, sig: Tensor, *labels: Sequence[Tensor] + ratio: Union[int, float], sig: Tensor, *labels: Sequence[Tensor] ) -> Union[Tensor, Tuple[Tensor, ...]]: """Stretch or compress one batch element of the ECGs. Parameters ---------- - ratio : numbers.Real + ratio : int or float Ratio of the stretch/compress. sig : torch.Tensor The ECGs to be stretched or compressed, @@ -339,7 +338,7 @@ class StretchCompressOffline(ReprMixin): Parameters ---------- - ratio : numbers.Real, default 6 + ratio : int or float, default 6 Mean ratio of the stretch or compress. If it is in the interval [1, 100], then it will be transformed to [0, 1]. @@ -368,7 +367,7 @@ class StretchCompressOffline(ReprMixin): def __init__( self, - ratio: Real = 6, + ratio: Union[int, float] = 6, prob: float = 0.5, overlap: float = 0.5, critical_overlap: float = 0.85, diff --git a/torch_ecg/components/inputs.py b/torch_ecg/components/inputs.py index 3153b0a4..4a40eedd 100644 --- a/torch_ecg/components/inputs.py +++ b/torch_ecg/components/inputs.py @@ -202,7 +202,7 @@ def dtype(self) -> torch.dtype: def device(self) -> torch.device: return self._device - def compute_input_shape(self, waveform_shape: Union[Sequence[int], torch.Size]) -> Tuple[Union[type(None), int], ...]: + def compute_input_shape(self, waveform_shape: Union[Sequence[int], torch.Size]) -> Tuple[Union[None, int], ...]: """Computes the input shape of the model based on the input type and the waveform shape. diff --git a/torch_ecg/components/loggers.py b/torch_ecg/components/loggers.py index e56fc8ff..4da136ea 100644 --- a/torch_ecg/components/loggers.py +++ b/torch_ecg/components/loggers.py @@ -11,7 +11,6 @@ import warnings from abc import ABC, abstractmethod from datetime import datetime -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -87,7 +86,7 @@ class BaseLogger(ReprMixin, ABC): @abstractmethod def log_metrics( self, - metrics: Dict[str, Union[Real, torch.Tensor]], + metrics: Dict[str, Union[int, float, torch.Tensor]], step: Optional[int] = None, epoch: Optional[int] = None, part: str = "train", @@ -178,7 +177,7 @@ def __init__( @add_docstring(_log_metrics_doc) def log_metrics( self, - metrics: Dict[str, Union[Real, torch.Tensor]], + metrics: Dict[str, Union[int, float, torch.Tensor]], step: Optional[int] = None, epoch: Optional[int] = None, part: str = "train", @@ -309,7 +308,7 @@ def __init__( @add_docstring(_log_metrics_doc) def log_metrics( self, - metrics: Dict[str, Union[Real, torch.Tensor]], + metrics: Dict[str, Union[int, float, torch.Tensor]], step: Optional[int] = None, epoch: Optional[int] = None, part: str = "train", @@ -408,7 +407,7 @@ def __init__( @add_docstring(_log_metrics_doc) def log_metrics( self, - metrics: Dict[str, Union[Real, torch.Tensor]], + metrics: Dict[str, Union[int, float, torch.Tensor]], step: Optional[int] = None, epoch: Optional[int] = None, part: str = "train", @@ -573,7 +572,7 @@ def _add_tensorboardx_logger(self) -> None: @add_docstring(_log_metrics_doc) def log_metrics( self, - metrics: Dict[str, Union[Real, torch.Tensor]], + metrics: Dict[str, Union[int, float, torch.Tensor]], step: Optional[int] = None, epoch: Optional[int] = None, part: str = "train", diff --git a/torch_ecg/components/metrics.py b/torch_ecg/components/metrics.py index 83ef28da..7f5e3257 100644 --- a/torch_ecg/components/metrics.py +++ b/torch_ecg/components/metrics.py @@ -509,7 +509,7 @@ def set_macro(self, macro: bool) -> None: class_map : dict Class map, mapping names to waves to numbers from 0 to n_classes-1, the keys should contain {", ".join([f'"{item}"' for item in ECGWaveFormNames])}. - fs : numbers.Real + fs : int Sampling frequency of the signal corresponding to the masks, used to compute the duration of each waveform, and thus the error and standard deviations of errors. diff --git a/torch_ecg/components/outputs.py b/torch_ecg/components/outputs.py index f86ce4d8..834ef27d 100644 --- a/torch_ecg/components/outputs.py +++ b/torch_ecg/components/outputs.py @@ -356,7 +356,7 @@ def required_fields(self) -> Set[str]: Parameters ---------- - fs : numbers.Real + fs : int Sampling frequency of the signal corresponding to the masks, used to compute the duration of each waveform, and thus the error and standard deviations of errors. diff --git a/torch_ecg/databases/aux_data/cinc2020_aux_data.py b/torch_ecg/databases/aux_data/cinc2020_aux_data.py index b260d313..c863158b 100644 --- a/torch_ecg/databases/aux_data/cinc2020_aux_data.py +++ b/torch_ecg/databases/aux_data/cinc2020_aux_data.py @@ -5,7 +5,6 @@ """ from io import StringIO -from numbers import Real from typing import Dict, Literal, Optional, Sequence, Union import pandas as pd @@ -428,10 +427,10 @@ def get_class_weight( exclude_classes: Optional[Sequence[str]] = None, scored_only: bool = False, normalize: bool = True, - threshold: Optional[Real] = 0, + threshold: Optional[Union[int, float]] = 0, fmt: str = "a", - min_weight: Real = 0.5, -) -> Dict[str, int]: + min_weight: Union[int, float] = 0.5, +) -> Dict[str, Union[int, float]]: """Get the weight of each class in each tranche. Parameters @@ -446,7 +445,7 @@ def get_class_weight( normalize : bool, default True Whether collapse equivalent classes into one or not, used only when `scored_only` is True. - threshold : numbers.Real, default 0 + threshold : int or float, default 0 Minimum ratio (0-1) or absolute number (>1) of a class to be counted. fmt : str, default "a" Format of the names of the classes in the returned dict, @@ -454,7 +453,7 @@ def get_class_weight( - "a", abbreviations - "f", full names - "s", SNOMED CT Code - min_weight : numbers.Real, default 0.5 + min_weight : int or float, default 0.5 Minimum value of the weight of all classes, or equivalently the weight of the largest class. diff --git a/torch_ecg/databases/aux_data/cinc2021_aux_data.py b/torch_ecg/databases/aux_data/cinc2021_aux_data.py index 53e7e93c..f08cb9b8 100644 --- a/torch_ecg/databases/aux_data/cinc2021_aux_data.py +++ b/torch_ecg/databases/aux_data/cinc2021_aux_data.py @@ -5,7 +5,6 @@ """ from io import StringIO -from numbers import Real from typing import Dict, List, Literal, Optional, Sequence, Union import pandas as pd @@ -517,10 +516,10 @@ def get_class_weight( exclude_classes: Optional[Sequence[str]] = None, scored_only: bool = False, normalize: bool = True, - threshold: Optional[Real] = 0, + threshold: Optional[Union[int, float]] = 0, fmt: str = "a", - min_weight: Real = 0.5, -) -> Dict[str, int]: + min_weight: Union[int, float] = 0.5, +) -> Dict[str, Union[int, float]]: """Get the weight of each class in each tranche. Parameters @@ -535,7 +534,7 @@ def get_class_weight( normalize : bool, default True Whether collapse equivalent classes into one or not, used only when `scored_only` is True. - threshold : numbers.Real, default 0 + threshold : int or float, default 0 Minimum ratio (0-1) or absolute number (>1) of a class to be counted. fmt : str, default "a" Format of the names of the classes in the returned dict, @@ -543,7 +542,7 @@ def get_class_weight( - "a", abbreviations - "f", full names - "s", SNOMED CT Code - min_weight : numbers.Real, default 0.5 + min_weight : int or float, default 0.5 Minimum value of the weight of all classes, or equivalently the weight of the largest class. diff --git a/torch_ecg/databases/base.py b/torch_ecg/databases/base.py index 34a51afa..2ef3b48c 100644 --- a/torch_ecg/databases/base.py +++ b/torch_ecg/databases/base.py @@ -26,7 +26,6 @@ from abc import ABC, abstractmethod from copy import deepcopy from dataclasses import dataclass -from numbers import Real from pathlib import Path from string import punctuation from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -625,9 +624,9 @@ def load_data( # type: ignore sampto: Optional[int] = None, data_format: str = "channel_first", units: Union[str, None] = "mV", - fs: Optional[Real] = None, + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """Load physical (converted from digital) ECG data, which is more understandable for humans; or load digital signal directly. @@ -651,7 +650,7 @@ def load_data( # type: ignore units : str or None, default "mV" Units of the output signal, can also be "μV" (aliases "uV", "muV"). None for digital data, without digital-to-physical conversion. - fs : numbers.Real, optional + fs : int, optional Sampling frequency of the output signal. If not None, the loaded data will be resampled to this frequency; if None, `self.fs` will be used if available and not None; @@ -666,7 +665,7 @@ def load_data( # type: ignore data : numpy.ndarray The ECG data loaded from the record, with given `units` and `data_format`. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. diff --git a/torch_ecg/databases/cpsc_databases/cpsc2018.py b/torch_ecg/databases/cpsc_databases/cpsc2018.py index 4ce15724..d5440017 100644 --- a/torch_ecg/databases/cpsc_databases/cpsc2018.py +++ b/torch_ecg/databases/cpsc_databases/cpsc2018.py @@ -3,7 +3,6 @@ import os import re import warnings -from numbers import Real from pathlib import Path from typing import Any, List, Optional, Sequence, Tuple, Union @@ -265,7 +264,7 @@ def load_data( data_format="channel_first", units: str = "mV", return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """Load the ECG data of a record. Parameters @@ -287,7 +286,7 @@ def load_data( ------- data : numpy.ndarray The loaded ECG data. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. diff --git a/torch_ecg/databases/cpsc_databases/cpsc2019.py b/torch_ecg/databases/cpsc_databases/cpsc2019.py index fc229b76..1470d349 100644 --- a/torch_ecg/databases/cpsc_databases/cpsc2019.py +++ b/torch_ecg/databases/cpsc_databases/cpsc2019.py @@ -2,7 +2,6 @@ import json import os -from numbers import Real from pathlib import Path from typing import Any, List, Optional, Sequence, Tuple, Union @@ -221,9 +220,9 @@ def load_data( rec: Union[int, str], data_format: str = "channel_first", units: str = "mV", - fs: Optional[Real] = None, + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """Load the ECG data of the record `rec`. Parameters @@ -237,7 +236,7 @@ def load_data( "flat" (alias "plain"). units : str or None, default "mV" Units of the output signal, can also be "μV" (with aliases "uV", "muV"). - fs : numbers.Real, optional + fs : int, optional If provided, the loaded data will be resampled to this frequency, otherwise the original sampling frequency will be used. return_fs : bool, default False @@ -247,7 +246,7 @@ def load_data( ------- data : numpy.ndarray, The loaded ECG data. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. @@ -400,7 +399,7 @@ def webpage(self) -> str: def compute_metrics( rpeaks_truths: Sequence[Union[NDArray, Sequence[int]]], rpeaks_preds: Sequence[Union[NDArray, Sequence[int]]], - fs: Real, + fs: int, thr: float = 0.075, verbose: int = 0, ) -> float: @@ -413,7 +412,7 @@ def compute_metrics( Sequence of ground truths of rpeaks locations from multiple records. rpeaks_preds : sequence Predictions of ground truths of rpeaks locations for multiple records. - fs : numbers.Real + fs : int Sampling frequency of ECG signal. thr : float, default 0.075 Threshold for a prediction to be truth positive, diff --git a/torch_ecg/databases/cpsc_databases/cpsc2020.py b/torch_ecg/databases/cpsc_databases/cpsc2020.py index f36f029f..6c8f654b 100644 --- a/torch_ecg/databases/cpsc_databases/cpsc2020.py +++ b/torch_ecg/databases/cpsc_databases/cpsc2020.py @@ -2,7 +2,6 @@ import math import os -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -379,9 +378,9 @@ def load_data( sampto: Optional[int] = None, data_format: str = "channel_first", units: str = "mV", - fs: Optional[Real] = None, + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """Load the ECG data of the record `rec`. Parameters @@ -400,7 +399,7 @@ def load_data( units : str or None, default "mV" Units of the output signal, can also be "μV" (with aliases "uV", "muV"). - fs : numbers.Real, optional + fs : int, optional Frequency of the output signal. if not None, the loaded data will be resampled to this frequency; if None, the loaded data will be returned as is. @@ -411,7 +410,7 @@ def load_data( ------- data : numpy.ndarray The loaded ECG data. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. @@ -531,7 +530,7 @@ def locate_premature_beats( self, rec: Union[int, str], premature_type: Optional[str] = None, - window: Real = 10, + window: Union[int, float] = 10, sampfrom: Optional[int] = None, sampto: Optional[int] = None, ) -> List[List[int]]: @@ -547,7 +546,7 @@ def locate_premature_beats( premature_type : str, optional Premature beat type, can be one of "SPB", "PVC". If not specified, both SPBs and PVCs will be located. - window : numbers.Real, default 10 + window : int or float, default 10 Window length of each premature beat, with units in seconds. sampfrom : int, optional diff --git a/torch_ecg/databases/cpsc_databases/cpsc2021.py b/torch_ecg/databases/cpsc_databases/cpsc2021.py index 990c4f59..bec78411 100644 --- a/torch_ecg/databases/cpsc_databases/cpsc2021.py +++ b/torch_ecg/databases/cpsc_databases/cpsc2021.py @@ -5,7 +5,6 @@ import os import time import warnings -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -476,7 +475,7 @@ def load_rpeaks( sampto: Optional[int] = None, keep_original: bool = False, valid_only: bool = True, - fs: Optional[Real] = None, + fs: Optional[int] = None, ) -> NDArray: """Load position (in terms of samples) of rpeaks. @@ -499,7 +498,7 @@ def load_rpeaks( otherwise, all indices in the `sample` field of the annotation will be returned. Valid rpeaks are those with symbol in `WFDB_Beat_Annotations`. Symbols in `WFDB_Non_Beat_Annotations` are considered as invalid rpeaks - fs : numbers.Real, optional + fs : int, optional If not None, positions of the loaded rpeaks will be ajusted according to this sampling frequency. @@ -539,7 +538,7 @@ def load_rpeak_indices( sampto: Optional[int] = None, keep_original: bool = False, valid_only: bool = True, - fs: Optional[Real] = None, + fs: Optional[int] = None, ) -> NDArray: """alias of `self.load_rpeaks`""" return self.load_rpeaks( @@ -559,7 +558,7 @@ def load_af_episodes( sampfrom: Optional[int] = None, sampto: Optional[int] = None, keep_original: bool = False, - fs: Optional[Real] = None, + fs: Optional[int] = None, fmt: Literal["intervals", "mask", "c_intervals"] = "intervals", ) -> Union[List[List[int]], NDArray]: """Load the episodes of atrial fibrillation, @@ -582,7 +581,7 @@ def load_af_episodes( If True, indices will keep the same with the annotation file, otherwise subtract `sampfrom` if specified. Valid only when `fmt` is not "c_intervals". - fs : numbers.Real, optional + fs : int, optional If not None, positions of the loaded intervals or mask will be ajusted according to this sampling frequency. Otherwise, the sampling frequency of the record will be used. @@ -991,7 +990,7 @@ def plot( plt.subplots_adjust(hspace=0.2) plt.show() - def _round(self, n: Real) -> int: + def _round(self, n: Union[int, float]) -> int: """ dealing with round(0.5) = 0, hence keeping accordance with output length of `resample_poly` diff --git a/torch_ecg/databases/nsrr_databases/shhs.py b/torch_ecg/databases/nsrr_databases/shhs.py index c43d3762..115e697b 100644 --- a/torch_ecg/databases/nsrr_databases/shhs.py +++ b/torch_ecg/databases/nsrr_databases/shhs.py @@ -2279,16 +2279,16 @@ def str_to_real_number(self, s: Union[str, float, int]) -> Union[float, int]: """Convert a string to a real number. Some columns in the annotations might incorrectly - been converted from numbers.Real to string, using ``xmltodict``. + been converted from int or float to string, using ``xmltodict``. Parameters ---------- - s : str or numbers.Real + s : str or int or float The string to be converted. Returns ------- - numbers.Real + int or float The converted number. """ diff --git a/torch_ecg/databases/other_databases/cachet_cadb.py b/torch_ecg/databases/other_databases/cachet_cadb.py index 6855c475..3c15aa61 100644 --- a/torch_ecg/databases/other_databases/cachet_cadb.py +++ b/torch_ecg/databases/other_databases/cachet_cadb.py @@ -5,7 +5,6 @@ import warnings import xml.etree.ElementTree as ET from copy import deepcopy -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -347,10 +346,10 @@ def load_data( sampfrom: Optional[int] = None, sampto: Optional[int] = None, data_format: str = "channel_first", - units: Union[str, type(None)] = "mV", - fs: Optional[Real] = None, + units: Union[str, None] = "mV", + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """Load physical (converted from digital) ECG data, or load digital signal directly. @@ -371,7 +370,7 @@ def load_data( units : str or None, default "mV" Units of the output signal, can also be "μV" (aliases "uV", "muV"); None for digital data, without digital-to-physical conversion. - fs : numbers.Real, optional + fs : int, optional Sampling frequency of the output signal. If not None, the loaded data will be resampled to this frequency, otherwise, the original sampling frequency will be used. @@ -382,7 +381,7 @@ def load_data( ------- data : numpy.ndarray The loaded ECG data. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. @@ -434,7 +433,7 @@ def load_context_data( sampto: Optional[int] = None, channels: Optional[Union[str, int, List[str], List[int]]] = None, units: Optional[str] = None, - fs: Optional[Real] = None, + fs: Optional[int] = None, ) -> Union[NDArray, pd.DataFrame]: """Load context data (e.g. accelerometer, heart rate, etc.). @@ -457,7 +456,7 @@ def load_context_data( Units of the output signal, currently can only be "default"; None for digital data, without digital-to-physical conversion. - fs : numbers.Real, optional + fs : int, optional Sampling frequency of the output signal. If not None, the loaded data will be resampled to this frequency, otherwise, the original sampling frequency will be used. diff --git a/torch_ecg/databases/other_databases/sph.py b/torch_ecg/databases/other_databases/sph.py index 5aabc0d3..e2786017 100644 --- a/torch_ecg/databases/other_databases/sph.py +++ b/torch_ecg/databases/other_databases/sph.py @@ -196,7 +196,7 @@ def load_data( ------- data : numpy.ndarray The loaded ECG data. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. diff --git a/torch_ecg/databases/physionet_databases/apnea_ecg.py b/torch_ecg/databases/physionet_databases/apnea_ecg.py index 28677d3e..fbdc7f91 100644 --- a/torch_ecg/databases/physionet_databases/apnea_ecg.py +++ b/torch_ecg/databases/physionet_databases/apnea_ecg.py @@ -2,7 +2,6 @@ import os from datetime import datetime -from numbers import Real from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np @@ -198,10 +197,10 @@ def load_data( sampfrom: Optional[int] = None, sampto: Optional[int] = None, data_format: str = "channel_first", - units: Union[str, type(None)] = "mV", - fs: Optional[Real] = None, + units: Union[str, None] = "mV", + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: return super().load_data(rec, leads, sampfrom, sampto, data_format, units, fs, return_fs) @add_docstring(PhysioNetDataBase.load_data.__doc__) @@ -211,10 +210,10 @@ def load_ecg_data( sampfrom: Optional[int] = None, sampto: Optional[int] = None, data_format: str = "channel_first", - units: Union[str, type(None)] = "mV", - fs: Optional[Real] = None, + units: Union[str, None] = "mV", + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: if isinstance(rec, int): rec = self[rec] if rec not in self.ecg_records: @@ -242,8 +241,8 @@ def load_rsp_data( sampfrom: Optional[int] = None, sampto: Optional[int] = None, data_format: str = "channel_first", - units: Union[str, type(None)] = "mV", - fs: Optional[Real] = None, + units: Union[str, None] = "mV", + fs: Optional[int] = None, return_fs: bool = False, ) -> NDArray: if rec not in self.rsp_records: diff --git a/torch_ecg/databases/physionet_databases/cinc2018.py b/torch_ecg/databases/physionet_databases/cinc2018.py index a7460b89..1e3d6938 100644 --- a/torch_ecg/databases/physionet_databases/cinc2018.py +++ b/torch_ecg/databases/physionet_databases/cinc2018.py @@ -2,7 +2,6 @@ import os from collections import defaultdict -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -305,9 +304,9 @@ def load_psg_data( sampto: Optional[int] = None, data_format: str = "channel_first", physical: bool = True, - fs: Optional[Real] = None, + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """Load PSG data of the record. Parameters @@ -330,7 +329,7 @@ def load_psg_data( physical : bool, default True If True, the data will be converted to physical units, otherwise, the data will be in digital units. - fs : numbers.Real, optional + fs : int, optional Sampling frequency of the output signal. If not None, the loaded data will be resampled to this frequency, otherwise, the original sampling frequency will be used. @@ -341,7 +340,7 @@ def load_psg_data( ------- data : numpy.ndarray PSG data corr. to the given `channel` of the record. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. """ @@ -403,10 +402,10 @@ def load_data( sampfrom: Optional[int] = None, sampto: Optional[int] = None, data_format: str = "channel_first", - units: Union[str, type(None)] = "mV", - fs: Optional[Real] = None, + units: Union[str, None] = "mV", + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """Load ECG data of the record. Parameters @@ -428,7 +427,7 @@ def load_data( units : str or None, default "mV" Units of the output signal, can also be "μV" (aliases "uV", "muV"). None for digital data, without digital-to-physical conversion. - fs : numbers.Real, optional + fs : int, optional Sampling frequency of the output signal. If not None, the loaded data will be resampled to this frequency, otherwise, the original sampling frequency will be used. @@ -440,7 +439,7 @@ def load_data( data : numpy.ndarray The ECG data loaded from the record, with given `units` and `data_format`. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. @@ -475,10 +474,10 @@ def load_ecg_data( sampfrom: Optional[int] = None, sampto: Optional[int] = None, data_format: str = "channel_first", - units: Union[str, type(None)] = "mV", - fs: Optional[Real] = None, + units: Union[str, None] = "mV", + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """alias of `load_data`""" return self.load_data( rec=rec, diff --git a/torch_ecg/databases/physionet_databases/cinc2020.py b/torch_ecg/databases/physionet_databases/cinc2020.py index f296bdbd..4d420dc7 100644 --- a/torch_ecg/databases/physionet_databases/cinc2020.py +++ b/torch_ecg/databases/physionet_databases/cinc2020.py @@ -8,7 +8,6 @@ import time from copy import deepcopy from datetime import datetime -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -500,9 +499,9 @@ def load_data( data_format: Literal["channel_first", "lead_first", "channel_last", "lead_last"] = "channel_first", backend: Literal["wfdb", "scipy"] = "wfdb", units: Literal["mV", "μV", "uV", None] = "mV", - fs: Optional[Real] = None, + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """Load physical (converted from digital) ECG data, which is more understandable for humans; or load digital signal directly. @@ -522,7 +521,7 @@ def load_data( units : str or None, default "mV" Units of the output signal, can also be "μV" (aliases "uV", "muV"). None for digital data, without digital-to-physical conversion. - fs : numbers.Real, optional + fs : int, optional Sampling frequency of the output signal. If not None, the loaded data will be resampled to this frequency, otherwise, the original sampling frequency will be used. @@ -533,7 +532,7 @@ def load_data( ------- data : numpy.ndarray The ECG data of the record. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. @@ -890,7 +889,7 @@ def get_labels( labels = [self.label_trans_dict.get(item, item) for item in labels] return labels - def get_fs(self, rec: Union[str, int]) -> Real: + def get_fs(self, rec: Union[str, int]) -> int: """Get the sampling frequency of the record. Parameters @@ -900,7 +899,7 @@ def get_fs(self, rec: Union[str, int]) -> Real: Returns ------- - fs : numbers.Real + fs : int Sampling frequency of the record. """ @@ -1533,7 +1532,7 @@ def _compute_f_measure(labels: NDArray, outputs: NDArray) -> float: return macro_f_measure -def _compute_beta_measures(labels: NDArray, outputs: NDArray, beta: Real) -> Tuple[float, float]: +def _compute_beta_measures(labels: NDArray, outputs: NDArray, beta: float) -> Tuple[float, float]: """Compute F-beta and G-beta measures. Parameters diff --git a/torch_ecg/databases/physionet_databases/cinc2021.py b/torch_ecg/databases/physionet_databases/cinc2021.py index 51a57344..d23b7629 100644 --- a/torch_ecg/databases/physionet_databases/cinc2021.py +++ b/torch_ecg/databases/physionet_databases/cinc2021.py @@ -9,7 +9,6 @@ import warnings from copy import deepcopy from datetime import datetime -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -670,9 +669,9 @@ def load_data( data_format: str = "channel_first", backend: Literal["wfdb", "scipy"] = "wfdb", units: Literal["mV", "μV", "uV", "muV", None] = "mV", - fs: Optional[Real] = None, + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: """Load physical (converted from digital) ECG data. Parameters @@ -690,7 +689,7 @@ def load_data( units : str or None, default "mV" Units of the output signal, can also be "μV" (aliases "uV", "muV"). None for digital data, without digital-to-physical conversion. - fs : numbers.Real, optional + fs : int, optional Sampling frequency of the output signal. If not None, the loaded data will be resampled to this frequency, otherwise, the original sampling frequency will be used. @@ -701,7 +700,7 @@ def load_data( ------- data : numpy.ndarray The loaded ECG data. - data_fs : numbers.Real, optional + data_fs : int, optional Sampling frequency of the output signal. Returned if `return_fs` is True. @@ -1092,7 +1091,7 @@ def get_labels( labels = _labels return labels - def get_fs(self, rec: Union[str, int], from_hea: bool = True) -> Real: + def get_fs(self, rec: Union[str, int], from_hea: bool = True) -> int: """Get the sampling frequency of the record. Parameters @@ -1106,7 +1105,7 @@ def get_fs(self, rec: Union[str, int], from_hea: bool = True) -> Real: Returns ------- - fs : numbers.Real + fs : int Sampling frequency of the record. """ @@ -1974,7 +1973,7 @@ def _compute_f_measure(labels: NDArray, outputs: NDArray) -> Tuple[float, NDArra return macro_f_measure, f_measure -def _compute_beta_measures(labels: NDArray, outputs: NDArray, beta: Real) -> Tuple[float, float]: +def _compute_beta_measures(labels: NDArray, outputs: NDArray, beta: float) -> Tuple[float, float]: """Compute F-beta and G-beta measures. Parameters diff --git a/torch_ecg/databases/physionet_databases/ltafdb.py b/torch_ecg/databases/physionet_databases/ltafdb.py index 4d458ef7..3b71fb24 100644 --- a/torch_ecg/databases/physionet_databases/ltafdb.py +++ b/torch_ecg/databases/physionet_databases/ltafdb.py @@ -4,7 +4,6 @@ import math import os from copy import deepcopy -from numbers import Real from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union import numpy as np @@ -159,9 +158,9 @@ def load_data( sampto: Optional[int] = None, data_format: str = "channel_first", units: str = "mV", - fs: Optional[Real] = None, + fs: Optional[int] = None, return_fs: bool = False, - ) -> Union[NDArray, Tuple[NDArray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, int]]: return super().load_data(rec, leads, sampfrom, sampto, data_format, units, fs, return_fs) def load_ann( diff --git a/torch_ecg/models/_nets.py b/torch_ecg/models/_nets.py index 66486fe2..240893e6 100644 --- a/torch_ecg/models/_nets.py +++ b/torch_ecg/models/_nets.py @@ -7,7 +7,6 @@ from inspect import isclass from itertools import repeat from math import sqrt -from numbers import Real from typing import Any, Callable, List, Literal, Optional, Sequence, Tuple, Union import numpy as np @@ -270,7 +269,7 @@ def get_normalization(norm: Union[str, nn.Module, None], kw_norm: Optional[dict] ---------- input_len : int, optional The length of the input. - fs : numbers.Real, optional + fs : int, optional The sampling frequency of the input signal. If is not ``None``, then the receptive field is returned in seconds. @@ -369,7 +368,7 @@ def compute_output_shape( return output_shape @add_docstring(_COMPUTE_RECEPTIVE_FIELD_DOC.replace("layer", "block")) - def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[Real] = None) -> Union[int, float]: + def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[int] = None) -> Union[int, float]: return 1 @@ -751,7 +750,7 @@ def compute_output_shape( return output_shape @add_docstring(_COMPUTE_RECEPTIVE_FIELD_DOC.replace("layer", "block")) - def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[Real] = None) -> Union[int, float]: + def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[int] = None) -> Union[int, float]: return compute_receptive_field( kernel_sizes=self.__kernel_size, strides=self.__stride, @@ -883,7 +882,7 @@ def __init__( len(strides) == self.__num_convs ), f"`subsample_lengths` must be of type int or sequence of int of length {self.__num_convs}" - if isinstance(dropouts, (Real, dict)): + if isinstance(dropouts, (float, dict)): _dropouts = list(repeat(dropouts, self.__num_convs)) else: _dropouts = list(dropouts) # type: ignore @@ -962,7 +961,7 @@ def compute_output_shape( return output_shape # type: ignore @add_docstring(_COMPUTE_RECEPTIVE_FIELD_DOC.replace("layer", "block")) - def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[Real] = None) -> Union[int, float]: + def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[int] = None) -> Union[int, float]: kernel_sizes, strides, dilations = [], [], [] for module in self: if hasattr(module, "__name__") and module.__name__ == Conv_Bn_Activation.__name__: @@ -1052,7 +1051,7 @@ def __init__( len(strides) == self.__num_branches ), f"`subsample_lengths` must be of type int or sequence of int of length {self.__num_branches}" - if isinstance(dropouts, (Real, dict)): + if isinstance(dropouts, (float, dict)): _dropouts = list(repeat(dropouts, self.__num_branches)) else: _dropouts = list(dropouts) # type: ignore @@ -1124,14 +1123,14 @@ def compute_output_shape( output_shapes.append(branch_output_shape) return output_shapes - def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[Real] = None) -> Tuple[Union[int, float]]: + def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[int] = None) -> Tuple[Union[int, float]]: """Compute the receptive field of each branch. Parameters ---------- input_len : int, optional Length of the input. - fs : numbers.Real, optional + fs : int, optional The sampling frequency of the input signal. If is not ``None``, then the receptive field is returned in seconds. @@ -1317,7 +1316,7 @@ def compute_output_shape( return output_shape @add_docstring(_COMPUTE_RECEPTIVE_FIELD_DOC.replace("layer", "block")) - def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[Real] = None) -> Union[int, float]: + def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[int] = None) -> Union[int, float]: return compute_receptive_field( kernel_sizes=[self.__kernel_size, 1], strides=[self.__stride, 1], @@ -2781,7 +2780,7 @@ def __init__( else: self.__kernel_initializer = None self.__bias = bias - if isinstance(dropouts, Real): + if isinstance(dropouts, float): if self.__num_layers > 1: self.__dropouts = list(repeat(dropouts, self.__num_layers - 1)) + [0.0] else: diff --git a/torch_ecg/models/cnn/mobilenet.py b/torch_ecg/models/cnn/mobilenet.py index f50a6ba8..ae0f3d91 100644 --- a/torch_ecg/models/cnn/mobilenet.py +++ b/torch_ecg/models/cnn/mobilenet.py @@ -12,7 +12,6 @@ import textwrap from copy import deepcopy from itertools import repeat -from numbers import Real from typing import Any, List, Optional, Sequence, Union from deprecate_kwargs import deprecate_kwargs @@ -972,7 +971,7 @@ def __init__( self.__in_channels = in_channels self.__n_blocks = n_blocks self.__expansion = expansion - if isinstance(expansion, Real): + if isinstance(expansion, (float, int)): self.__expansion = list(repeat(expansion, self.n_blocks)) else: self.__expansion = expansion diff --git a/torch_ecg/models/cnn/regnet.py b/torch_ecg/models/cnn/regnet.py index 650989f9..6961a8d5 100644 --- a/torch_ecg/models/cnn/regnet.py +++ b/torch_ecg/models/cnn/regnet.py @@ -7,7 +7,6 @@ import warnings from collections import Counter from itertools import repeat -from numbers import Real from typing import List, Optional, Sequence, Union import torch @@ -376,7 +375,7 @@ def _get_stage_configs(self) -> List[CFG]: f"`config.num_filters` indicates {len(self.__num_filters)} stages, " f"while `config.num_blocks` indicates {len(self.config.num_blocks)}" ) - if isinstance(self.config.dropouts, Real): + if isinstance(self.config.dropouts, (float, int)): self.__dropouts = list(repeat(self.config.dropouts, len(self.config.num_blocks))) else: self.__dropouts = self.config.dropouts @@ -464,7 +463,7 @@ def _get_stage_configs(self) -> List[CFG]: f"while there are {num_stages} computed from " "`config.w_a`, `config.w_0`, `config.w_m`, `config.tot_blocks`" ) - if isinstance(self.config.dropouts, Real): + if isinstance(self.config.dropouts, (float, int)): self.__dropouts = list(repeat(self.config.dropouts, num_stages)) else: self.__dropouts = self.config.dropouts diff --git a/torch_ecg/models/cnn/resnet.py b/torch_ecg/models/cnn/resnet.py index b7b72dd7..8c90777e 100644 --- a/torch_ecg/models/cnn/resnet.py +++ b/torch_ecg/models/cnn/resnet.py @@ -6,7 +6,6 @@ import textwrap from copy import deepcopy from itertools import repeat -from numbers import Real from typing import List, Optional, Sequence, Union import torch.nn.functional as F @@ -284,7 +283,7 @@ class ResNetBottleNeck(nn.Module, SizeMixin): For more details, ref. :class:`torch.nn.Conv1d`. dilation : int, default 1 Dilation of the conv layers. - base_width : numbers.Real, default 12*4 + base_width : int or float, default 12*4 Number of filters per group for the neck conv layer. Usually number of filters of the initial conv layer of the whole ResNet model. @@ -325,7 +324,7 @@ def __init__( subsample_length: int, groups: int = 1, dilation: int = 1, - base_width: Real = 12 * 4, + base_width: Union[int, float] = 12 * 4, base_groups: int = 1, base_filter_length: int = 1, attn: Optional[dict] = None, @@ -785,7 +784,7 @@ def __init__(self, in_channels: int, **config) -> None: f"`config.num_filters` indicates {len(self.__num_filters)} macro blocks, " f"while `config.num_blocks` indicates {len(self.config.num_blocks)}" ) - if isinstance(self.config.dropouts, (Real, dict)): + if isinstance(self.config.dropouts, (int, float, dict)): self.__dropouts = list(repeat(self.config.dropouts, len(self.config.num_blocks))) else: self.__dropouts = self.config.dropouts diff --git a/torch_ecg/models/cnn/xception.py b/torch_ecg/models/cnn/xception.py index 85308e6b..b6e0f599 100644 --- a/torch_ecg/models/cnn/xception.py +++ b/torch_ecg/models/cnn/xception.py @@ -7,7 +7,6 @@ import textwrap from copy import deepcopy from itertools import repeat -from numbers import Real from typing import List, Optional, Sequence, Union from torch import Tensor, nn @@ -116,7 +115,7 @@ def __init__( assert self.__num_convs == len(self.__dilations), ( f"the main stream has {self.__num_convs} convolutions, " f"while `dilations` indicates {len(self.__dilations)}" ) - if isinstance(dropouts, (Real, dict)): + if isinstance(dropouts, (int, float, dict)): self.__dropouts = list(repeat(dropouts, self.__num_convs)) else: self.__dropouts = list(dropouts) @@ -307,14 +306,14 @@ def __init__( assert self.__num_blocks == len(self.__dilations), ( f"the entry flow has {self.__num_blocks} blocks, " f"while `dilations` indicates {len(self.__dilations)}" ) - if isinstance(dropouts, (Real, dict)): + if isinstance(dropouts, (int, float, dict)): self.__dropouts = list(repeat(dropouts, self.__num_blocks)) else: self.__dropouts = list(dropouts) assert self.__num_blocks == len(self.__dropouts), ( f"the entry flow has {self.__num_blocks} blocks, " f"while `dropouts` indicates {len(self.__dropouts)}" ) - if isinstance(block_dropouts, (Real, dict)): + if isinstance(block_dropouts, (int, float, dict)): self.__block_dropouts = list(repeat(block_dropouts, self.__num_blocks)) else: self.__block_dropouts = list(block_dropouts) @@ -495,14 +494,14 @@ def __init__( assert self.__num_blocks == len(self.__dilations), ( f"the middle flow has {self.__num_blocks} blocks, " f"while `dilations` indicates {len(self.__dilations)}" ) - if isinstance(dropouts, (Real, dict)): + if isinstance(dropouts, (int, float, dict)): self.__dropouts = list(repeat(dropouts, self.__num_blocks)) else: self.__dropouts = list(dropouts) assert self.__num_blocks == len(self.__dropouts), ( f"the middle flow has {self.__num_blocks} blocks, " f"while `dropouts` indicates {len(self.__dropouts)}" ) - if isinstance(block_dropouts, (Real, dict)): + if isinstance(block_dropouts, (int, float, dict)): self.__block_dropouts = list(repeat(block_dropouts, self.__num_blocks)) else: self.__block_dropouts = list(block_dropouts) @@ -693,14 +692,14 @@ def __init__( assert self.__num_blocks == len(self.__dilations), ( f"the exit flow has {self.__num_blocks} blocks, " f"while `dilations` indicates {len(self.__dilations)}" ) - if isinstance(dropouts, (Real, dict)): + if isinstance(dropouts, (int, float, dict)): self.__dropouts = list(repeat(dropouts, self.__num_blocks)) else: self.__dropouts = list(dropouts) assert self.__num_blocks == len(self.__dropouts), ( f"the exit flow has {self.__num_blocks} blocks, " f"while `dropouts` indicates {len(self.__dropouts)}" ) - if isinstance(block_dropouts, (Real, dict)): + if isinstance(block_dropouts, (int, float, dict)): self.__block_dropouts = list(repeat(block_dropouts, self.__num_blocks + len(final_num_filters))) else: self.__block_dropouts = list(block_dropouts) diff --git a/torch_ecg/models/ecg_fcn.py b/torch_ecg/models/ecg_fcn.py index f68a56f8..41004774 100644 --- a/torch_ecg/models/ecg_fcn.py +++ b/torch_ecg/models/ecg_fcn.py @@ -9,7 +9,6 @@ # from collections import OrderedDict # from copy import deepcopy -# from numbers import Number, Real # from typing import Any, List, Optional, Sequence, Tuple, Union # import numpy as np diff --git a/torch_ecg/models/loss.py b/torch_ecg/models/loss.py index 0c3e5eac..6d0c979f 100644 --- a/torch_ecg/models/loss.py +++ b/torch_ecg/models/loss.py @@ -24,8 +24,7 @@ """ -from numbers import Real -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union import torch import torch.nn.functional as F @@ -439,9 +438,9 @@ class AsymmetricLoss(nn.Module): Parameters ---------- - gamma_neg : numbers.Real, default 4 + gamma_neg : int or float, default 4 Exponent of the multiplier to the negative loss. - gamma_pos : numbers.Real, default 1 + gamma_pos : int or float, default 1 Exponent of the multiplier to the positive loss. prob_margin : float, default 0.05 The probability margin @@ -475,8 +474,8 @@ class AsymmetricLoss(nn.Module): def __init__( self, - gamma_neg: Real = 4, - gamma_pos: Real = 1, + gamma_neg: Union[int, float] = 4, + gamma_pos: Union[int, float] = 1, prob_margin: float = 0.05, disable_torch_grad_focal_loss: bool = False, reduction: Literal["none", "mean", "sum"] = "mean", diff --git a/torch_ecg/preprocessors/README.md b/torch_ecg/preprocessors/README.md index e0a38d73..3a0abd0c 100644 --- a/torch_ecg/preprocessors/README.md +++ b/torch_ecg/preprocessors/README.md @@ -77,7 +77,6 @@ sig = ppm(sig) Here is another example for `numpy` version custom preprocessors ```python -from numbers import Real from typing import Tuple import numpy as np @@ -92,7 +91,7 @@ class DummyPreProcessor(PreProcessor): a dummy preprocessor that does nothing """ __name__ = "DummyPreProcessor" - def apply(self, sig:np.ndarray, fs:Real) -> Tuple[np.ndarray, int]: + def apply(self, sig : np.ndarray, fs : int) -> Tuple[np.ndarray, int]: """ """ return sig, fs diff --git a/torch_ecg/preprocessors/bandpass.py b/torch_ecg/preprocessors/bandpass.py index 2f71a4a9..d7181eeb 100644 --- a/torch_ecg/preprocessors/bandpass.py +++ b/torch_ecg/preprocessors/bandpass.py @@ -1,6 +1,5 @@ """ """ -from numbers import Real from typing import Any, Optional, Union import numpy as np @@ -22,11 +21,11 @@ class BandPass(torch.nn.Module): Parameters ---------- - fs : numbers.Real + fs : int Sampling frequency of the ECG signal to be filtered. - lowcut : numbers.Real, optional + lowcut : float, optional Low cutoff frequency. - highcut : numbers.Real, optional + highcut : float, optional High cutoff frequency. inplace : bool, default True Whether to perform the filtering in-place. @@ -39,9 +38,9 @@ class BandPass(torch.nn.Module): def __init__( self, - fs: Real, - lowcut: Optional[Real] = 0.5, - highcut: Optional[Real] = 45, + fs: int, + lowcut: Optional[float] = 0.5, + highcut: Optional[float] = 45, inplace: bool = True, **kwargs: Any, ) -> None: diff --git a/torch_ecg/preprocessors/baseline_remove.py b/torch_ecg/preprocessors/baseline_remove.py index c5531d81..d33bac04 100644 --- a/torch_ecg/preprocessors/baseline_remove.py +++ b/torch_ecg/preprocessors/baseline_remove.py @@ -1,7 +1,6 @@ """ """ import warnings -from numbers import Real from typing import Any, Union import numpy as np @@ -23,7 +22,7 @@ class BaselineRemove(torch.nn.Module): Parameters ---------- - fs : numbers.Real + fs : int Sampling frequency of the ECG signal to be filtered. window1 : float, default 0.2 The smaller window size, with units in seconds. @@ -38,7 +37,7 @@ class BaselineRemove(torch.nn.Module): __name__ = "BaselineRemove" - def __init__(self, fs: Real, window1: float = 0.2, window2: float = 0.6, inplace: bool = True, **kwargs: Any) -> None: + def __init__(self, fs: int, window1: float = 0.2, window2: float = 0.6, inplace: bool = True, **kwargs: Any) -> None: super().__init__() self.fs = fs self.window1 = window1 diff --git a/torch_ecg/preprocessors/normalize.py b/torch_ecg/preprocessors/normalize.py index dadc4d50..10100eef 100644 --- a/torch_ecg/preprocessors/normalize.py +++ b/torch_ecg/preprocessors/normalize.py @@ -1,6 +1,5 @@ """ """ -from numbers import Real from typing import Any, Iterable, Literal, Union import numpy as np @@ -41,14 +40,14 @@ class Normalize(torch.nn.Module): ---------- method : {"naive", "min-max", "z-score"}, default "z-score", Normalization method, by default "z-score", case-insensitive. - mean : numbers.Real or array_like, default 0.0 + mean : int or float or array_like, default 0.0 If `method` is "z-score", then `mean is the mean value of the normalized signal, or mean values for each lead of the normalized signal. If `method` is "naive", then `mean` is the mean value to be subtracted from the original signal. Useless if `method` is ``"min-max"``. - std : numbers.Real or array_like, default 1.0 + std : int or float or array_like, default 1.0 If `method` is "z-score", then `std` is the standard deviation of the normalized signal, or standard deviations for each lead of the normalized signal. @@ -67,8 +66,8 @@ class Normalize(torch.nn.Module): def __init__( self, method: Literal["naive", "min-max", "z-score"] = "z-score", - mean: Union[Real, Iterable[Real]] = 0.0, - std: Union[Real, Iterable[Real]] = 1.0, + mean: Union[int, float, Iterable[Union[int, float]]] = 0.0, + std: Union[int, float, Iterable[Union[int, float]]] = 1.0, per_channel: bool = False, inplace: bool = True, **kwargs: Any, @@ -163,9 +162,9 @@ class NaiveNormalize(Normalize): Parameters ---------- - mean : numbers.Real or array_like, default 0.0 + mean : int or float or array_like, default 0.0 Value(s) to be subtracted. - std : numbers.Real or array_like, default 1.0 + std : int or float or array_like, default 1.0 Value(s) to be divided. per_channel : bool, default False Whether to perform the normalization per channel. @@ -178,8 +177,8 @@ class NaiveNormalize(Normalize): def __init__( self, - mean: Union[Real, Iterable[Real]] = 0.0, - std: Union[Real, Iterable[Real]] = 1.0, + mean: Union[int, float, Iterable[Union[int, float]]] = 0.0, + std: Union[int, float, Iterable[Union[int, float]]] = 1.0, per_channel: bool = False, inplace: bool = True, **kwargs: Any, @@ -206,10 +205,10 @@ class ZScoreNormalize(Normalize): Parameters ---------- - mean : numbers.Real or array_like, default 0.0 + mean : int or float or array_like, default 0.0 Mean value of the normalized signal, or mean values for each lead of the normalized signal. - std : numbers.Real or array_like, default 1.0 + std : int or float or array_like, default 1.0 Standard deviation of the normalized signal, or standard deviations for each lead of the normalized signal. per_channel : bool, default False @@ -223,8 +222,8 @@ class ZScoreNormalize(Normalize): def __init__( self, - mean: Union[Real, Iterable[Real]] = 0.0, - std: Union[Real, Iterable[Real]] = 1.0, + mean: Union[int, float, Iterable[Union[int, float]]] = 0.0, + std: Union[int, float, Iterable[Union[int, float]]] = 1.0, per_channel: bool = False, inplace: bool = True, **kwargs: Any, diff --git a/torch_ecg/utils/_edr.py b/torch_ecg/utils/_edr.py index a084ccec..b741654a 100644 --- a/torch_ecg/utils/_edr.py +++ b/torch_ecg/utils/_edr.py @@ -4,8 +4,7 @@ A python re-implementation of the `edr` function of physionet edr.c """ -from numbers import Real -from typing import Sequence +from typing import Sequence, Union import numpy as np from numpy.typing import NDArray @@ -19,8 +18,8 @@ def phs_edr( sig: Sequence, fs: int, rpeaks: Sequence, - winL_t: Real = 40, - winR_t: Real = 40, + winL_t: Union[int, float] = 40, + winR_t: Union[int, float] = 40, return_with_time: bool = True, mode: str = "complex", verbose: int = 0, @@ -38,10 +37,10 @@ def phs_edr( sampling frequency of the signal rpeaks: array-like, indices of R peaks in the signal - winL_t: numbers.Real, default 40, + winL_t: int or float, default 40, left length of the window at R peaks for the computation of the area of a QRS complex, with units in milliseconds - winR_t: numbers.Real, default 40, + winR_t: int or float, default 40, right length of the window at R peaks for the computation of the area of a QRS complex, with units in milliseconds return_with_time: bool, default True, @@ -124,6 +123,6 @@ def phs_edr( return ecg_der_rsp -def _getxy(sig: Sequence, von: int, bis: int) -> Real: +def _getxy(sig: Sequence, von: int, bis: int) -> float: """compute the integrand from `von` to `bis` of the signals with baseline removed""" - return (np.array(sig)[von : bis + 1]).sum() + return (np.array(sig)[von : bis + 1]).sum().item() diff --git a/torch_ecg/utils/_preproc.py b/torch_ecg/utils/_preproc.py index 6899fcad..1eeb5459 100644 --- a/torch_ecg/utils/_preproc.py +++ b/torch_ecg/utils/_preproc.py @@ -20,8 +20,7 @@ import multiprocessing as mp from collections import Counter -from numbers import Real -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import numpy as np @@ -64,10 +63,10 @@ def preprocess_multi_lead_signal( raw_sig: NDArray, - fs: Real, + fs: int, sig_fmt: str = "channel_first", - bl_win: Optional[List[Real]] = None, - band_fs: Optional[List[Real]] = None, + bl_win: Optional[List[Union[int, float]]] = None, + band_fs: Optional[List[Union[int, float]]] = None, rpeak_fn: Optional[str] = None, verbose: int = 0, ) -> Dict[str, NDArray]: @@ -77,33 +76,36 @@ def preprocess_multi_lead_signal( Parameters ---------- - raw_sig: ndarray, - the raw ECG signal, with units in mV - fs: numbers.Real, - sampling frequency of `raw_sig` - sig_fmt: str, default "channel_first", - format of the multi-lead ECG signal, + raw_sig : NDArray + The raw ECG signal, with units in mV. + fs : int + Sampling frequency of `raw_sig`. + sig_fmt : str, default "channel_first" + Format of the multi-lead ECG signal, "channel_last" (alias "lead_last"), or "channel_first" (alias "lead_first", original) - bl_win: list (of 2 numbers.Real), optional, - window (units in second) of baseline removal using `median_filter`, + bl_win : list (of 2 int or float), optional + Window (units in second) of baseline removal using `median_filter`, the first is the shorter one, the second the longer one, a typical pair is [0.2, 0.6], if is None or empty, baseline removal will not be performed - band_fs: list (of 2 numbers.Real), optional, - frequency band of the bandpass filter, + band_fs : list (of 2 int or float), optional + Frequency band of the bandpass filter, a typical pair is [0.5, 45], be careful when detecting paced rhythm, if is None or empty, bandpass filtering will not be performed - rpeak_fn: str, optional, - name of the function detecting rpeaks, + rpeak_fn : str, optional + Name of the function detecting rpeaks, can be one of keys of `QRS_DETECTORS`, case insensitive - verbose: int, default 0, - print verbosity + verbose : int, default 0 + Verbosity level. + If 1, print the rpeak candidates for each lead. + If 2, also print the split indices and corresponding intervals + when merging rpeaks. Returns ------- - retval: dict, + retval : dict with items - "filtered_ecg": the array of the processed ECG signal - "rpeaks": the array of indices of rpeaks; empty if `rpeak_fn` is not given @@ -148,9 +150,9 @@ def preprocess_multi_lead_signal( def preprocess_single_lead_signal( raw_sig: NDArray, - fs: Real, - bl_win: Optional[List[Real]] = None, - band_fs: Optional[List[Real]] = None, + fs: int, + bl_win: Optional[List[Union[int, float]]] = None, + band_fs: Optional[List[Union[int, float]]] = None, rpeak_fn: Optional[str] = None, verbose: int = 0, ) -> Dict[str, NDArray]: @@ -160,29 +162,32 @@ def preprocess_single_lead_signal( Parameters ---------- - raw_sig: ndarray, - the raw ECG signal, with units in mV - fs: numbers.Real, - sampling frequency of `raw_sig` - bl_win: list (of 2 numbers.Real), optional, - window (units in second) of baseline removal using `median_filter`, + raw_sig : NDArray + The raw ECG signal, with units in mV. + fs : int + Sampling frequency of `raw_sig`. + bl_win : list (of 2 int or float), optional + Window (units in second) of baseline removal using `median_filter`, the first is the shorter one, the second the longer one, a typical pair is [0.2, 0.6], if is None or empty, baseline removal will not be performed - band_fs: list (of 2 numbers.Real), optional, - frequency band of the bandpass filter, + band_fs : list (of 2 int or float), optional + Frequency band of the bandpass filter, a typical pair is [0.5, 45], be careful when detecting paced rhythm, if is None or empty, bandpass filtering will not be performed - rpeak_fn: str, optional, - name of the function detecting rpeaks, + rpeak_fn : str, optional + Name of the function detecting rpeaks, can be one of keys of `QRS_DETECTORS`, case insensitive - verbose: int, default 0, - print verbosity + verbose : int, default 0, + Verbosity level. + If 1, print the rpeak candidates. + If 2, also print the split indices and corresponding intervals + when merging rpeaks. Returns ------- - retval: dict, + retval : dict with items - "filtered_ecg": the array of the processed ECG signal - "rpeaks": the array of indices of rpeaks; empty if `rpeak_fn` is not given @@ -227,7 +232,7 @@ def preprocess_single_lead_signal( def rpeaks_detect_multi_leads( sig: NDArray, - fs: Real, + fs: int, sig_fmt: str = "channel_first", rpeak_fn: str = "xqrs", verbose: int = 0, @@ -237,24 +242,27 @@ def rpeaks_detect_multi_leads( Parameters ---------- - sig: ndarray, - the (better be filtered) ECG signal, with units in mV - fs: numbers.Real, - sampling frequency of `sig` - sig_fmt: str, default "channel_first", - format of the multi-lead ECG signal, + sig : NDArray + The (better be filtered) ECG signal, with units in mV. + fs : int + Sampling frequency of `sig`. + sig_fmt : str, default "channel_first" + Format of the multi-lead ECG signal, "channel_last" (alias "lead_last"), or "channel_first" (alias "lead_first", original) - rpeak_fn: str, - name of the function detecting rpeaks, - can be one of keys of `QRS_DETECTORS`, case insensitive - verbose: int, default 0, - print verbosity + rpeak_fn : str + Name of the function detecting rpeaks, + can be one of keys of `QRS_DETECTORS`, case insensitive. + verbose : int, default 0 + Verbosity level. + If 1, print the rpeak candidates. + If 2, also print the split indices and corresponding intervals + when merging rpeaks. Returns ------- - rpeaks: NDArray, - array of indices of the detected rpeaks of the multi-lead ECG signal + rpeaks : NDArray + Array of indices of the detected rpeaks of the multi-lead ECG signal. """ assert sig_fmt.lower() in [ @@ -274,19 +282,19 @@ def rpeaks_detect_multi_leads( return rpeaks -def merge_rpeaks(rpeaks_candidates: List[NDArray], sig: NDArray, fs: Real, verbose: int = 0) -> NDArray: +def merge_rpeaks(rpeaks_candidates: List[NDArray], sig: NDArray, fs: int, verbose: int = 0) -> NDArray: """ merge rpeaks that are detected from each lead of multi-lead signals (with units in mV), using certain criterion merging qrs masks from each lead Parameters ---------- - rpeaks_candidates: list of ndarray, - each element (ndarray) is the array of indices of rpeaks of corr. lead - sig: ndarray, - the multi-lead ECG signal, with units in mV - fs: numbers.Real, - sampling frequency of `sig` + rpeaks_candidates : list of ndarray + Each element (ndarray) is the array of indices of rpeaks of the corresponding lead. + sig : NDArray + The multi-lead ECG signal, with units in mV. + fs : int + Sampling frequency of `sig`. verbose: int, default 0, print verbosity diff --git a/torch_ecg/utils/utils_interval.py b/torch_ecg/utils/utils_interval.py index 5eb48aa8..9a5d8bc5 100644 --- a/torch_ecg/utils/utils_interval.py +++ b/torch_ecg/utils/utils_interval.py @@ -16,7 +16,6 @@ import time import warnings from copy import deepcopy -from numbers import Real from typing import Any, List, Literal, Sequence, Tuple, Union import numpy as np @@ -43,7 +42,7 @@ EMPTY_SET = [] -Interval = Union[Sequence[Real], type(EMPTY_SET)] +Interval = Union[Sequence[Union[int, float]], type(EMPTY_SET)] GeneralizedInterval = Union[Sequence[Interval], type(EMPTY_SET)] @@ -137,12 +136,12 @@ def validate_interval( return False, [] -def in_interval(val: Real, interval: Interval, left_closed: bool = True, right_closed: bool = False) -> bool: +def in_interval(val: Union[int, float], interval: Interval, left_closed: bool = True, right_closed: bool = False) -> bool: """Check whether val is inside interval or not. Parameters ---------- - val : numbers.Real + val : int or float The value to be checked. interval : Interval The interval to be checked. @@ -185,7 +184,7 @@ def in_interval(val: Real, interval: Interval, left_closed: bool = True, right_c def in_generalized_interval( - val: Real, + val: Union[int, float], generalized_interval: GeneralizedInterval, left_closed: bool = True, right_closed: bool = False, @@ -194,7 +193,7 @@ def in_generalized_interval( Parameters ---------- - val : numbers.Real + val : int or float The value to be checked whether it is inside `generalized_interval` or not. generalized_interval : GeneralizedInterval @@ -485,10 +484,10 @@ def generalized_interval_complement(total_interval: Interval, generalized_interv def get_optimal_covering( total_interval: Interval, - to_cover: List[Union[Real, Interval]], - min_len: Real, - split_threshold: Real, - isolated_point_dist_threshold: Real = 0, + to_cover: List[Union[int, float, Interval]], + min_len: Union[int, float], + split_threshold: Union[int, float], + isolated_point_dist_threshold: Union[int, float] = 0, traceback: bool = False, **kwargs: Any, ) -> Union[GeneralizedInterval, Tuple[GeneralizedInterval, list]]: @@ -506,11 +505,11 @@ def get_optimal_covering( The total interval that the covering is picked from. to_cover : list A list of intervals or points to cover. - min_len : numbers.Real + min_len : int or float Minimun length (positive) of the intervals of the covering. - split_threshold : numbers.Real + split_threshold : int or float Minumun distance (positive) of intervals of the covering. - isolated_point_dist_threshold : numbers.Real, default 0.0 + isolated_point_dist_threshold : int or float, default 0.0 The minimum distance (non-negative) of isolated points in `to_cover` to the interval boundaries of the interval containing the point in the covering. @@ -582,8 +581,8 @@ def get_optimal_covering( tmp = sorted(total_interval) tot_start, tot_end = tmp[0], tmp[-1] - if (tot_start > min([item if isinstance(item, Real) else item[0] for item in to_cover])) or ( - tot_end < max([item if isinstance(item, Real) else item[-1] for item in to_cover]) + if (tot_start > min([item if isinstance(item, (int, float)) else item[0] for item in to_cover])) or ( + tot_end < max([item if isinstance(item, (int, float)) else item[-1] for item in to_cover]) ): raise ValueError("some of the elements in `to_cover` exceeds the range of `total_interval`") @@ -751,7 +750,7 @@ def get_optimal_covering( return covering -def find_max_cont_len(sublist: Interval, tot_rng: Real) -> dict: +def find_max_cont_len(sublist: Interval, tot_rng: Union[int, float]) -> dict: """Compute the maximum length of continuous (consecutive) sublists. This function computes the maximum length of continuous (consecutive) @@ -763,7 +762,7 @@ def find_max_cont_len(sublist: Interval, tot_rng: Real) -> dict: ---------- sublist : Interval The sublist. - tot_rng : numbers.Real + tot_rng : int or float The total range. Returns @@ -796,7 +795,7 @@ def find_max_cont_len(sublist: Interval, tot_rng: Real) -> dict: return ret -def interval_len(interval: Interval) -> Real: +def interval_len(interval: Interval) -> Union[int, float]: """Compute the length of an interval. Parameters @@ -806,7 +805,7 @@ def interval_len(interval: Interval) -> Real: Returns ------- - numbers.Real + int or float The "length" of `interval`, 0 for the empty interval []. Examples @@ -826,7 +825,7 @@ def interval_len(interval: Interval) -> Real: return itv_len -def generalized_interval_len(generalized_interval: GeneralizedInterval) -> Real: +def generalized_interval_len(generalized_interval: GeneralizedInterval) -> Union[int, float]: """Compute the length of an interval. Parameters @@ -836,7 +835,7 @@ def generalized_interval_len(generalized_interval: GeneralizedInterval) -> Real: Returns ------- - numbers.Real + int or float The "length" of `generalized_interval`, 0 for the empty interval []. diff --git a/torch_ecg/utils/utils_metrics.py b/torch_ecg/utils/utils_metrics.py index e7bed79c..ac86a0e7 100644 --- a/torch_ecg/utils/utils_metrics.py +++ b/torch_ecg/utils/utils_metrics.py @@ -6,7 +6,6 @@ """ import warnings -from numbers import Real from typing import Dict, Optional, Sequence, Tuple, Union import einops @@ -610,7 +609,7 @@ def accuracy( def QRS_score( rpeaks_truths: Sequence[Union[NDArray, Sequence[int]]], rpeaks_preds: Sequence[Union[NDArray, Sequence[int]]], - fs: Real, + fs: int, thr: float = 0.075, ) -> float: """ @@ -624,7 +623,7 @@ def QRS_score( rpeaks_preds : array_like predictions of ground truths of rpeaks locations (indices) for multiple records. - fs : numbers.Real + fs : int Sampling frequency of ECG signal thr : float, default 0.075 Threshold for a prediction to be truth positive, @@ -771,9 +770,9 @@ def compute_wave_delineation_metrics( truth_masks: Sequence[NDArray], pred_masks: Sequence[NDArray], class_map: Dict[str, int], - fs: Real, + fs: int, mask_format: str = "channel_first", - tol: Real = 0.15, + tol: float = 0.15, ) -> Dict[str, Dict[str, float]]: f""" Compute metrics for the task of ECG wave delineation @@ -794,7 +793,7 @@ def compute_wave_delineation_metrics( class_map : dict Class map, mapping names to waves to numbers from 0 to n_classes-1, the keys should contain {", ".join([f'"{item}"' for item in ECGWaveFormNames])}. - fs : numbers.Real + fs : int Sampling frequency of the signal corresponding to the masks, used to compute the duration of each waveform, and thus the error and standard deviations of errors. @@ -836,8 +835,8 @@ def compute_wave_delineation_metrics( def compute_metrics_waveform( truth_waveforms: Sequence[Sequence[ECGWaveForm]], pred_waveforms: Sequence[Sequence[ECGWaveForm]], - fs: Real, - tol: Real = 0.15, + fs: int, + tol: float = 0.15, ) -> Dict[str, Dict[str, float]]: """ compute the sensitivity, precision, f1_score, mean error @@ -852,7 +851,7 @@ def compute_metrics_waveform( pred_waveforms : Sequence[Sequence[ECGWaveForm]] The predictions corresponding to `truth_waveforms`, each element is a sequence of :class:`ECGWaveForm` from the same sample. - fs : numbers.Real + fs : int Sampling frequency of the signal corresponding to the waveforms, used to compute the duration of each waveform, and thus the error and standard deviations of errors. @@ -911,8 +910,8 @@ def compute_metrics_waveform( def _compute_metrics_waveform( truths: Sequence[ECGWaveForm], preds: Sequence[ECGWaveForm], - fs: Real, - tol: Real = 0.15, + fs: int, + tol: float = 0.15, ) -> Dict[str, Dict[str, float]]: """ Compute the sensitivity, precision, f1_score, mean error @@ -925,7 +924,7 @@ def _compute_metrics_waveform( The ground truth preds : Sequence[ECGWaveForm] The predictions corresponding to `truths`. - fs : numbers.Real + fs : int Sampling frequency of the signal corresponding to the waveforms, used to compute the duration of each waveform, and thus the error and standard deviations of errors. @@ -994,16 +993,16 @@ def _compute_metrics_waveform( return scorings -def _compute_metrics_base(truths: Sequence[Real], preds: Sequence[Real], fs: Real, tol: Real = 0.15) -> Dict[str, float]: +def _compute_metrics_base(truths: Sequence[int], preds: Sequence[int], fs: int, tol: float = 0.15) -> Dict[str, float]: r"""Base function for computing the metrics of the onset and offset of a waveform. Parameters ---------- - truths : Sequence[numbers.Real] + truths : Sequence[int] Ground truth of indices of corresponding critical points. - preds : Sequence[numbers.Real] + preds : Sequence[int] Predicted indices of corresponding critical points. - fs : numbers.Real + fs : int Sampling frequency of the signal corresponding to the critical points, used to compute the duration of each waveform, and thus the error and standard deviations of errors. @@ -1060,4 +1059,4 @@ def _compute_metrics_base(truths: Sequence[Real], preds: Sequence[Real], fs: Rea f1_score, mean_error, standard_deviation, - ) + ) # type: ignore diff --git a/torch_ecg/utils/utils_nn.py b/torch_ecg/utils/utils_nn.py index d01ec960..4c50e507 100644 --- a/torch_ecg/utils/utils_nn.py +++ b/torch_ecg/utils/utils_nn.py @@ -11,7 +11,6 @@ from copy import deepcopy from itertools import chain, repeat from math import floor -from numbers import Real from pathlib import Path, PosixPath, WindowsPath from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -823,7 +822,7 @@ def compute_receptive_field( The sequence of strides for all the layers in the flow input_len : int, optional Length of the first feature map in the flow. - fs : numbers.Real, optional + fs : int, optional Sampling frequency of the input signal. If is not ``None``, then the receptive field is returned in seconds. @@ -980,7 +979,7 @@ def _adjust_cnn_filter_lengths( _adjust_cnn_filter_lengths({"filter_length": fl, "fs": config["fs"]}, fs, ensure_odd)["filter_length"] for fl in v ] - elif isinstance(v, Real): + elif isinstance(v, (int, float)): # DO NOT use `int`, which might not work for numpy array elements if v > 1: # type: ignore config[k] = int(round(v * fs / config["fs"])) diff --git a/torch_ecg/utils/utils_signal_t.py b/torch_ecg/utils/utils_signal_t.py index a8b4ff85..545ab8c4 100644 --- a/torch_ecg/utils/utils_signal_t.py +++ b/torch_ecg/utils/utils_signal_t.py @@ -4,7 +4,6 @@ """ import warnings -from numbers import Real from typing import Callable, Iterable, Literal, Optional, Union import torch @@ -20,8 +19,8 @@ def normalize( sig: torch.Tensor, method: Literal["z-score", "naive", "min-max"] = "z-score", - mean: Union[Real, Iterable[Real]] = 0.0, - std: Union[Real, Iterable[Real]] = 1.0, + mean: Union[int, float, Iterable[Union[int, float]]] = 0.0, + std: Union[int, float, Iterable[Union[int, float]]] = 1.0, per_channel: bool = False, inplace: bool = True, ) -> torch.Tensor: @@ -47,12 +46,12 @@ def normalize( Signal to be normalized, assumed to have shape ``(..., n_leads, siglen)``. method : {"z-score", "min-max", "naive"}, default "z-score" Normalization method, case insensitive. - mean : numbers.Real or array_like, default 0.0 + mean : int or float or array_like, default 0.0 Mean value of the normalized signal, or mean values for each lead of the normalized signal, if `method` is "z-score"; mean values to be subtracted from the original signal, if `method` is "naive". Useless if `method` is "min-max". - std : numbers.Real or array_like, default 1.0 + std : int or float or array_like, default 1.0 Standard deviation of the normalized signal, or standard deviations for each lead of the normalized signal, if `method` is "z-score"; or std to be divided from the original signal, if `method` is "naive". @@ -114,7 +113,7 @@ def normalize( bs = sig.shape[0] device = sig.device dtype = sig.dtype - if isinstance(std, Real): + if isinstance(std, (int, float)): assert std > 0, "standard deviation should be positive" _std = torch.full((sig.shape[0], 1, 1), std, dtype=dtype, device=device) else: @@ -128,7 +127,7 @@ def normalize( _std = _std.view((-1, sig.shape[1], 1)) else: raise ValueError(f"shape of `sig` = {sig.shape} and `std` = {_std.shape} mismatch") - if isinstance(mean, Real): + if isinstance(mean, (int, float)): _mean = torch.full((sig.shape[0], 1, 1), mean, dtype=dtype, device=device) else: _mean = torch.as_tensor(mean, dtype=dtype, device=device) @@ -454,9 +453,9 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: def bandpass_filter( sig: torch.Tensor, - fs: Real, - lowcut: Optional[Real] = None, - highcut: Optional[Real] = None, + fs: Union[int, float], + lowcut: Optional[Union[int, float]] = None, + highcut: Optional[Union[int, float]] = None, ) -> torch.Tensor: """Zero-phase bandpass filter using FFT. @@ -464,11 +463,11 @@ def bandpass_filter( ---------- sig : torch.Tensor Signal to be filtered, assumed to have shape ``(..., n_leads, siglen)``. - fs : numbers.Real + fs : int or float Sampling frequency of the signal. - lowcut : numbers.Real, optional + lowcut : int or float, optional Low cutoff frequency. - highcut : numbers.Real, optional + highcut : int or float, optional High cutoff frequency. Returns @@ -477,7 +476,39 @@ def bandpass_filter( The filtered signal. """ - if lowcut is None and highcut is None: + if not isinstance(fs, (int, float)) or fs <= 0: + raise ValueError(f"fs must be a positive real number, got {fs!r}") + nyquist = fs / 2.0 + # normalize cutoff frequencies, clipping/disabling values at invalid boundaries + effective_lowcut = lowcut + effective_highcut = highcut + if effective_lowcut is not None: + if effective_lowcut <= 0: + warnings.warn( + "lowcut <= 0 in bandpass_filter; disabling high-pass side.", + RuntimeWarning, + ) + effective_lowcut = None + elif effective_lowcut >= nyquist: + raise ValueError( + f"lowcut must be less than Nyquist frequency (fs/2={nyquist}), " f"got lowcut={effective_lowcut!r}" + ) + if effective_highcut is not None: + if effective_highcut >= nyquist: + warnings.warn( + "highcut >= Nyquist frequency in bandpass_filter; disabling low-pass side.", + RuntimeWarning, + ) + effective_highcut = None + elif effective_highcut <= 0: + raise ValueError(f"highcut must be positive, got highcut={effective_highcut!r}") + if effective_lowcut is not None and effective_highcut is not None: + if effective_lowcut >= effective_highcut: + raise ValueError( + f"lowcut must be less than highcut for bandpass_filter, " + f"got lowcut={effective_lowcut!r}, highcut={effective_highcut!r}" + ) + if effective_lowcut is None and effective_highcut is None: return sig n = sig.shape[-1] @@ -485,10 +516,10 @@ def bandpass_filter( freqs = torch.fft.rfftfreq(n, d=1 / fs, device=sig.device) # mask for bandpass mask = torch.ones_like(freqs, dtype=sig.dtype) - if lowcut is not None: - mask[freqs < lowcut] = 0 - if highcut is not None: - mask[freqs > highcut] = 0 + if effective_lowcut is not None: + mask[freqs < effective_lowcut] = 0 + if effective_highcut is not None: + mask[freqs > effective_highcut] = 0 # perform FFT sig_fft = torch.fft.rfft(sig, dim=-1) @@ -502,9 +533,9 @@ def bandpass_filter( def baseline_removal( sig: torch.Tensor, - fs: Real, - window1: Optional[Real] = 0.2, - window2: Optional[Real] = 0.6, + fs: Union[int, float], + window1: Optional[Union[int, float]] = 0.2, + window2: Optional[Union[int, float]] = 0.6, ) -> torch.Tensor: """Remove baseline wander using sliding average (median filter alternative). @@ -515,11 +546,11 @@ def baseline_removal( ---------- sig : torch.Tensor Signal to be processed, assumed to have shape ``(..., n_leads, siglen)``. - fs : numbers.Real + fs : int or float Sampling frequency of the signal. - window1 : numbers.Real, default 0.2 + window1 : int or float, default 0.2 The first window size in seconds. - window2 : numbers.Real, default 0.6 + window2 : int or float, default 0.6 The second window size in seconds. Returns @@ -533,8 +564,8 @@ def baseline_removal( ori_shape = sig.shape # Ensure 3D (batch, leads, siglen) - sig = sig.view(-1, ori_shape[-2], ori_shape[-1]) - siglen = ori_shape[-1] + sig = sig.reshape(-1, ori_shape[-2], ori_shape[-1]) + siglen = sig.shape[-1] baseline = sig for window in [window1, window2]: @@ -557,4 +588,4 @@ def baseline_removal( stride=1, ) - return (sig - baseline).view(ori_shape) + return (sig - baseline).reshape(ori_shape) From c2748912674790acb93ec6331c0f4538f1941a59 Mon Sep 17 00:00:00 2001 From: WEN Hao Date: Sun, 31 May 2026 23:14:06 +0800 Subject: [PATCH 2/5] fix(ckpt): correct safetensors path for decimal-suffix stems; return actual save path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug A (utils_nn.py CkptMixin.save): Path("…metric_0.91").with_suffix(".safetensors") treated ".91" as the existing extension and replaced it, silently producing "…metric_0.safetensors". Fix: after determining use_safetensors=True, if path.suffix is not already ".safetensors" we append rather than replace, giving the correct "…metric_0.91.safetensors". The redundant path.with_suffix(".safetensors") inside the single-file branch is removed (path is already normalised). save() now returns the final Path instead of None so callers know exactly where the file was written. Bug B (components/trainer.py BaseTrainer): saved_models stored the raw stem path while the actual file on disk carried a ".safetensors" suffix, so every os.remove() in keep_checkpoint_max cleanup raised FileNotFoundError. The trainer now stores the Path returned by save_checkpoint(), which in turn forwards the value returned by model.save(). Directory-style (non-single-file) checkpoints are now removed with shutil.rmtree instead of os.remove. shutil is promoted to a top-level import. Also adds test_ckpt_decimal_suffix_path covering all three code paths: single-file safetensors, directory safetensors, and pth/torch.save fallback. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.rst | 11 +++++++ test/test_utils/test_utils_nn.py | 54 ++++++++++++++++++++++++++++++++ torch_ecg/components/trainer.py | 22 ++++++++++--- torch_ecg/utils/utils_nn.py | 19 ++++++++--- 4 files changed, 96 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3afe47b7..d0287490 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -72,6 +72,17 @@ Fixed - Fix bugs in utility function `torch_ecg.utils.make_serializable`: the previous implementation does not drop some types of unserializable items correctly. Two additional parameters `drop_unserializable` and `drop_paths` are added. +- Fix ``CkptMixin.save()`` silently truncating checkpoint filenames that contain + decimal values (e.g. ``…metric_0.91``): ``pathlib.Path.with_suffix(".safetensors")`` + treated ``.91`` as the existing suffix and replaced it, producing + ``…metric_0.safetensors`` instead of the correct ``…metric_0.91.safetensors``. + The method now appends ``.safetensors`` for paths without a recognised + extension, and returns the final ``Path`` used so callers can track it. +- Fix ``BaseTrainer`` checkpoint cleanup permanently failing: ``saved_models`` + stored the raw stem path while the actual file on disk had a ``.safetensors`` + suffix, causing every ``os.remove()`` call to raise ``FileNotFoundError``. + The trainer now stores the path returned by ``save_checkpoint()``, and handles + both single-file (``os.remove``) and directory (``shutil.rmtree``) checkpoints. Security ~~~~~~~~ diff --git a/test/test_utils/test_utils_nn.py b/test/test_utils/test_utils_nn.py index 7ed005dc..52a62626 100644 --- a/test/test_utils/test_utils_nn.py +++ b/test/test_utils/test_utils_nn.py @@ -584,6 +584,60 @@ def test_mixin_classes(): assert inp is out +def test_ckpt_decimal_suffix_path(): + """Regression test for Bug A: CkptMixin.save() with a decimal-valued suffix. + + Paths like ``…metric_0.91`` have ``.91`` as their pathlib suffix. + The old code did ``path.with_suffix(".safetensors")`` which silently + truncated the name to ``…metric_0.safetensors``. The fix appends the + extension instead of replacing it, so the actual file is + ``…metric_0.91.safetensors``. + """ + import shutil + + tmp = Path(__file__).resolve().parents[1] / "tmp" + tmp.mkdir(exist_ok=True) + + model = Model1D(12, CFG(out_channels=128)) + + # --- single-file safetensors (trainer default) --- + # Simulate the path that BaseTrainer generates for a checkpoint with a + # decimal metric value, e.g. "…_metric_0.91" + stem = tmp / "checkpoint_epoch1_metric_0.91" + expected = Path(str(stem) + ".safetensors") # correct: append + wrong = stem.with_suffix(".safetensors") # old bug: "…metric_0.safetensors" + + actual = model.save(stem, CFG(n_leads=12)) + + assert actual == expected, f"Expected {expected}, got {actual}" + assert expected.is_file(), "Correct .safetensors file was not created" + assert not wrong.is_file(), "Buggy truncated path should NOT exist" + + # Verify the saved file can be loaded back + loaded, _ = Model1D.from_checkpoint(expected) + assert repr(model) == repr(loaded) + expected.unlink() + + # --- non-single-file safetensors (directory) --- + # The returned path should be the directory (stem without .safetensors) + stem_dir = tmp / "checkpoint_epoch1_metric_0.91" + actual_dir = model.save(stem_dir, CFG(n_leads=12), safetensors_single_file=False) + # After normalization: "…0.91.safetensors", then with_suffix("") → "…0.91" (dir) + expected_dir = Path(str(stem_dir)) + assert actual_dir == expected_dir, f"Expected dir {expected_dir}, got {actual_dir}" + assert expected_dir.is_dir() + loaded, _ = Model1D.from_checkpoint(expected_dir) + assert repr(model) == repr(loaded) + shutil.rmtree(expected_dir) + + # --- pth fallback (torch.save) --- + pth_path = tmp / "checkpoint_epoch1.pth" + actual_pth = model.save(pth_path, CFG(n_leads=12), use_safetensors=False) + assert actual_pth == pth_path + assert pth_path.is_file() + pth_path.unlink() + + def test_make_safe_globals(): # CFG and dict cfg = CFG(a=1, b=None, c={"d": 2}) diff --git a/torch_ecg/components/trainer.py b/torch_ecg/components/trainer.py index f98d8ae0..a1ace131 100644 --- a/torch_ecg/components/trainer.py +++ b/torch_ecg/components/trainer.py @@ -5,6 +5,7 @@ import logging import os +import shutil import textwrap from abc import ABC, abstractmethod from collections import OrderedDict, deque @@ -250,13 +251,16 @@ def train(self) -> OrderedDict: save_folder = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}" save_path = self.train_config.checkpoints / save_folder # type: ignore if self.train_config.keep_checkpoint_max != 0: # type: ignore - self.save_checkpoint(str(save_path)) - self.saved_models.append(save_path) + actual_save_path = self.save_checkpoint(str(save_path)) + self.saved_models.append(actual_save_path if actual_save_path is not None else save_path) # remove outdated models if len(self.saved_models) > self.train_config.keep_checkpoint_max > 0: # type: ignore model_to_remove = self.saved_models.popleft() try: - os.remove(model_to_remove) + if model_to_remove.is_dir(): + shutil.rmtree(model_to_remove) + else: + os.remove(model_to_remove) except Exception: self.log_manager.log_message(f"failed to remove {str(model_to_remove)}") # type: ignore @@ -764,7 +768,7 @@ def resume_from_checkpoint(self, checkpoint: Union[str, dict]) -> None: self._setup_from_config(ckpt["train_config"]) # TODO: resume optimizer, etc. - def save_checkpoint(self, path: str) -> None: + def save_checkpoint(self, path: str) -> Optional[Path]: """Save the current state of the trainer to a checkpoint. Parameters @@ -772,10 +776,17 @@ def save_checkpoint(self, path: str) -> None: path : str Path to save the checkpoint + Returns + ------- + Path, optional + The actual path the checkpoint was saved to (suffix may differ + from ``path`` after normalisation, e.g. ``.safetensors``). + Returns ``None`` when the model does not implement ``save()``. + """ # if self._model has method `save`, then use it if hasattr(self._model, "save"): - self._model.save( + return self._model.save( path=path, train_config=self.train_config, extra_items={ @@ -797,6 +808,7 @@ def save_checkpoint(self, path: str) -> None: }, path, ) + return Path(path) def extra_repr_keys(self) -> List[str]: return [ diff --git a/torch_ecg/utils/utils_nn.py b/torch_ecg/utils/utils_nn.py index 4c50e507..9272dca5 100644 --- a/torch_ecg/utils/utils_nn.py +++ b/torch_ecg/utils/utils_nn.py @@ -1338,7 +1338,7 @@ def save( extra_items: Optional[dict] = None, use_safetensors: Optional[bool] = None, safetensors_single_file: bool = True, - ) -> None: + ) -> Path: """Save the model to disk. .. note:: @@ -1375,7 +1375,9 @@ def save( Returns ------- - None + Path + The actual path the model was saved to (may differ from the input + ``path`` when the suffix is normalised, e.g. ``.pth`` → ``.safetensors``). """ if isinstance(path, bytes): @@ -1401,6 +1403,12 @@ def save( elif path.suffix == ".safetensors": use_safetensors = True + if use_safetensors and path.suffix != ".safetensors": + # path has no recognised extension (or a non-standard one such as ".91" + # from a decimal metric value like "…metric_0.91"). path.with_suffix() + # would silently truncate the numeric part, so we append instead. + path = Path(str(path) + ".safetensors") + if use_safetensors and safetensors_single_file: tensors = dict(self.state_dict()) # type: ignore @@ -1422,8 +1430,8 @@ def save( continue meta[f"{_SFT_META_EXTRA_JSON_PREFIX}{key}"] = json.dumps(make_serializable(val), ensure_ascii=False) - save_file(tensors, path.with_suffix(".safetensors"), metadata=meta) - return + save_file(tensors, path, metadata=meta) + return path if use_safetensors: # not single file # save the model with safetensors with the same name as `path` @@ -1442,7 +1450,7 @@ def save( save_file(val, path / f"{key}.safetensors") else: (path / f"{key}.json").write_text(json.dumps(make_serializable(val), ensure_ascii=False)) - return + return path torch.save( { @@ -1453,3 +1461,4 @@ def save( }, path, ) + return path From e1c52c93db72a10c3b6309f606a9e911e6f8b780 Mon Sep 17 00:00:00 2001 From: Jingsu Kang <40791785+kjs11@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:05:09 +0800 Subject: [PATCH 3/5] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- torch_ecg/models/_nets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_ecg/models/_nets.py b/torch_ecg/models/_nets.py index 240893e6..5d042286 100644 --- a/torch_ecg/models/_nets.py +++ b/torch_ecg/models/_nets.py @@ -882,7 +882,7 @@ def __init__( len(strides) == self.__num_convs ), f"`subsample_lengths` must be of type int or sequence of int of length {self.__num_convs}" - if isinstance(dropouts, (float, dict)): + if isinstance(dropouts, (int, float, dict)): _dropouts = list(repeat(dropouts, self.__num_convs)) else: _dropouts = list(dropouts) # type: ignore From e80b228c3dfb48b3dc8807fe3bbdd88be7521b7a Mon Sep 17 00:00:00 2001 From: Jingsu Kang <40791785+kjs11@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:06:35 +0800 Subject: [PATCH 4/5] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- torch_ecg/models/_nets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_ecg/models/_nets.py b/torch_ecg/models/_nets.py index 5d042286..afd24223 100644 --- a/torch_ecg/models/_nets.py +++ b/torch_ecg/models/_nets.py @@ -2780,7 +2780,7 @@ def __init__( else: self.__kernel_initializer = None self.__bias = bias - if isinstance(dropouts, float): + if isinstance(dropouts, (int, float)): if self.__num_layers > 1: self.__dropouts = list(repeat(dropouts, self.__num_layers - 1)) + [0.0] else: From 917245f543b0a6a6d77eca655d846d48c471e338 Mon Sep 17 00:00:00 2001 From: Jingsu Kang <40791785+kjs11@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:06:57 +0800 Subject: [PATCH 5/5] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- torch_ecg/models/_nets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_ecg/models/_nets.py b/torch_ecg/models/_nets.py index afd24223..c916ba19 100644 --- a/torch_ecg/models/_nets.py +++ b/torch_ecg/models/_nets.py @@ -1051,7 +1051,7 @@ def __init__( len(strides) == self.__num_branches ), f"`subsample_lengths` must be of type int or sequence of int of length {self.__num_branches}" - if isinstance(dropouts, (float, dict)): + if isinstance(dropouts, (int, float, dict)): _dropouts = list(repeat(dropouts, self.__num_branches)) else: _dropouts = list(dropouts) # type: ignore