fix(ckpt): correct safetensors path for decimal-suffix stems; return actual save path #44
Bug A (utils_nn.py CkptMixin.save):
Path("…metric_0.91").with_suffix(".safetensors") treated ".91" as the
existing extension and replaced it, silently producing "…metric_0.safetensors".
Fix: after determining use_safetensors=True, if path.suffix is not already
".safetensors" we append rather than replace, giving the correct
"…metric_0.91.safetensors". The redundant path.with_suffix(".safetensors")
inside the single-file branch is removed (path is already normalised).
save() now returns the final Path instead of None so callers know exactly
where the file was written.
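The truncation and the append-style fix can be reproduced with pathlib alone. This is a minimal sketch, not the code in utils_nn.py: the helper name normalise_safetensors_path and the stem ckpt_metric_0.91 are hypothetical and serve only to illustrate the behaviour described above.

```python
from pathlib import Path


def normalise_safetensors_path(path: Path) -> Path:
    """Append ".safetensors" unless the path already ends with it.

    Hypothetical helper for illustration; not the function used in torch_ecg.
    """
    if path.suffix == ".safetensors":
        return path
    # Appending keeps a decimal fragment such as ".91" in the file name.
    return path.with_name(path.name + ".safetensors")


stem = Path("ckpt_metric_0.91")  # hypothetical checkpoint stem

# with_suffix treats ".91" as the extension and replaces it.
truncated = stem.with_suffix(".safetensors")
# The append-style normalisation keeps the full metric value.
fixed = normalise_safetensors_path(stem)

print(truncated)  # ckpt_metric_0.safetensors
print(fixed)      # ckpt_metric_0.91.safetensors
```

Returning the normalised path from the save routine, as the commit does, lets callers use the final path directly rather than reconstructing it.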
Bug B (components/trainer.py BaseTrainer):
saved_models stored the raw stem path while the actual file on disk carried a
".safetensors" suffix, so every os.remove() in keep_checkpoint_max cleanup
raised FileNotFoundError. The trainer now stores the Path returned by
save_checkpoint(), which in turn forwards the value returned by model.save().
Directory-style (non-single-file) checkpoints are now removed with
shutil.rmtree instead of os.remove. shutil is promoted to a top-level import.
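The cleanup logic described for Bug B could look roughly like the sketch below. It assumes the trainer keeps a queue of the paths actually written and evicts the oldest once a limit is exceeded; remove_checkpoint, track_checkpoint, keep_max and the file names are hypothetical, not the trainer's real API.

```python
import os
import shutil
import tempfile
from collections import deque
from pathlib import Path


def remove_checkpoint(path: Path) -> None:
    """Delete a checkpoint whether it is a single file or a directory."""
    if path.is_dir():
        shutil.rmtree(path)
    elif path.exists():
        os.remove(path)


def track_checkpoint(saved: deque, new_path: Path, keep_max: int) -> None:
    """Record the path actually written and evict the oldest beyond keep_max."""
    saved.append(new_path)
    while len(saved) > keep_max:
        remove_checkpoint(saved.popleft())


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Single-file checkpoint (hypothetical name).
    file_ckpt = root / "epoch1_metric_0.91.safetensors"
    file_ckpt.write_bytes(b"")
    # Directory-style checkpoint (hypothetical layout).
    dir_ckpt = root / "epoch2_metric_0.92"
    dir_ckpt.mkdir()
    (dir_ckpt / "model.safetensors").write_bytes(b"")
    last_ckpt = root / "epoch3_metric_0.93.safetensors"
    last_ckpt.write_bytes(b"")

    saved_models = deque()
    for ckpt in (file_ckpt, dir_ckpt, last_ckpt):
        track_checkpoint(saved_models, ckpt, keep_max=1)

    remaining = sorted(p.name for p in root.iterdir())
    print(remaining)
```

Because the queue holds the paths that exist on disk, both the single-file and the directory checkpoint are removed and only the newest one remains.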
Also adds test_ckpt_decimal_suffix_path covering all three code paths:
single-file safetensors, directory safetensors, and pth/torch.save fallback.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
sig = test_sig.clone()
sig = bp(sig)
with pytest.warns(RuntimeWarning, match="lowcut <= 0"):
    sig = bp(sig)
Pull request overview
The PR is described as a focused fix for two checkpoint-related bugs (decimal-suffix paths being truncated by pathlib.with_suffix, and BaseTrainer silently failing to clean up checkpoints because it stored the un-normalised path). In practice the diff is much broader: it also performs a sweeping migration from numbers.Real to int / float / Union[int, float] across ~40 files, adds input validation (with new warnings/errors) to bandpass_filter, and makes a few small ancillary changes (view → reshape in baseline_removal, Tuple[Union[type(None), int], ...] → Tuple[Union[None, int], ...], etc.).
Changes:
- Fix CkptMixin.save() decimal-suffix truncation; make save() and save_checkpoint() return the actual Path; clean up directory-style checkpoints with shutil.rmtree.
- Replace numbers.Real annotations/isinstance checks with int/float/Union[int, float] throughout the codebase.
- Add validation and warnings to bandpass_filter and update related tests.
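The numbers.Real migration narrows what isinstance checks accept, which is why it deserves attention beyond the checkpoint fix. The standard-library sketch below uses fractions.Fraction purely as an example of a value that is a Real but neither an int nor a float; it is not taken from the PR.

```python
from fractions import Fraction
from numbers import Real

value = Fraction(1, 2)  # illustrative value, not from the PR

# Fraction is part of the numeric tower, so the old check accepts it.
accepted_by_real = isinstance(value, Real)
# The narrowed check only matches the two built-in types.
accepted_by_builtin = isinstance(value, (int, float))

print(accepted_by_real, accepted_by_builtin)  # True False
```

Any call site that previously relied on the abstract base class therefore needs every concrete type it should accept listed explicitly.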
Reviewed changes
Copilot reviewed 62 out of 62 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| torch_ecg/utils/utils_nn.py | Append .safetensors instead of with_suffix; return saved Path; drop Real. |
| torch_ecg/components/trainer.py | Use returned Path from save(); shutil.rmtree for dir checkpoints; promote shutil import. |
| torch_ecg/utils/utils_signal_t.py | New input validation in bandpass_filter; view→reshape in baseline_removal; Real→int/float. |
| torch_ecg/utils/utils_metrics.py, utils_interval.py, _preproc.py, _edr.py | Real → int/float; _getxy now returns a Python float via .item(). |
| torch_ecg/models/_nets.py | Real → float/(float, dict) in several isinstance checks (regression: int no longer accepted). |
| torch_ecg/models/cnn/{xception,resnet,regnet,mobilenet}.py, models/loss.py, models/ecg_fcn.py | Real → (int, float, …) in isinstance checks and annotations. |
| torch_ecg/preprocessors/*.py, _preprocessors/normalize.py, augmenters/*.py | Type-hint cleanup (Real → int/float); minor dtype additions in augmenters. |
| torch_ecg/databases/**/*.py, components/{inputs,outputs,metrics,loggers}.py | Doc/type updates from Real to int/float; Union[str, type(None)] → Union[str, None]. |
| test/test_utils/test_utils_nn.py | New regression test test_ckpt_decimal_suffix_path covering the three save branches. |
| test/test_preprocessors.py, test/test_preprocessors_t.py, test/test_databases/test_shhs.py | Tests updated to drop Real and to cover new bandpass_filter validation. |
| CHANGELOG.rst | Two Fixed entries for the checkpoint bugs (the broader Real migration and bandpass validation are not noted). |
| benchmarks/**/*.py, torch_ecg/preprocessors/README.md | Annotation updates from Real to int/float and minor import cleanups. |
- Fix ``CkptMixin.save()`` silently truncating checkpoint filenames that contain
  decimal values (e.g. ``…metric_0.91``): ``pathlib.Path.with_suffix(".safetensors")``
  treated ``.91`` as the existing suffix and replaced it, producing
  ``…metric_0.safetensors`` instead of the correct ``…metric_0.91.safetensors``.
  The method now appends ``.safetensors`` for paths without a recognised
  extension, and returns the final ``Path`` used so callers can track it.
- Fix ``BaseTrainer`` checkpoint cleanup permanently failing: ``saved_models``
  stored the raw stem path while the actual file on disk had a ``.safetensors``
  suffix, causing every ``os.remove()`` call to raise ``FileNotFoundError``.
  The trainer now stores the path returned by ``save_checkpoint()``, and handles
  both single-file (``os.remove``) and directory (``shutil.rmtree``) checkpoints.
if effective_lowcut is not None:
    if effective_lowcut <= 0:
        warnings.warn(
            "lowcut <= 0 in bandpass_filter; disabling high-pass side.",
            RuntimeWarning,
        )
        effective_lowcut = None
    elif effective_lowcut >= nyquist:
        raise ValueError(
            f"lowcut must be less than Nyquist frequency (fs/2={nyquist}), " f"got lowcut={effective_lowcut!r}"
        )
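The excerpt above shows only part of bandpass_filter. The self-contained sketch below isolates the same lowcut validation so that both outcomes can be exercised; validate_lowcut and the sampling rate fs=500 are hypothetical and not part of the torch_ecg API.

```python
import warnings
from typing import Optional


def validate_lowcut(lowcut: Optional[float], fs: float) -> Optional[float]:
    """Return the effective lowcut, or None when the high-pass side is disabled.

    Hypothetical helper mirroring the validation shown in the diff excerpt.
    """
    nyquist = fs / 2
    if lowcut is None:
        return None
    if lowcut <= 0:
        # A non-positive cutoff is tolerated, but the caller is warned.
        warnings.warn(
            "lowcut <= 0 in bandpass_filter; disabling high-pass side.",
            RuntimeWarning,
        )
        return None
    if lowcut >= nyquist:
        # A cutoff at or above Nyquist cannot be realised and is rejected.
        raise ValueError(
            f"lowcut must be less than Nyquist frequency (fs/2={nyquist}), "
            f"got lowcut={lowcut!r}"
        )
    return lowcut


# Non-positive lowcut: a RuntimeWarning is emitted and None is returned.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    disabled = validate_lowcut(0.0, fs=500)

# Lowcut above Nyquist: a ValueError is raised.
try:
    validate_lowcut(300, fs=500)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Warning for a non-positive cutoff while raising for an out-of-range one keeps existing callers that pass lowcut=0 working, at the cost of adding a new warning to their output.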
Returns
-------
Path, optional
    The actual path the checkpoint was saved to (suffix may differ
    from ``path`` after normalisation, e.g. ``.safetensors``).
    Returns ``None`` when the model does not implement ``save()``.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c274891267
if isinstance(dropouts, (float, dict)):
    _dropouts = list(repeat(dropouts, self.__num_convs))
else:
    _dropouts = list(dropouts)  # type: ignore
When callers pass dropouts=0 (a common way to disable dropout and previously accepted because numbers.Real includes int), this now falls into the else branch and executes list(0), raising TypeError during model construction. This also affects the same replacement in BranchedConv/SeqLin; keep int in the scalar check so existing integer dropout configs continue to work.
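The regression and the suggested remedy can be sketched as follows. The function expand_dropouts is a hypothetical stand-in for the logic inside the model constructors, assuming the scalar check is widened to include int as the comment recommends.

```python
from itertools import repeat
from typing import Sequence, Union


def expand_dropouts(dropouts: Union[int, float, dict, Sequence], num_convs: int) -> list:
    """Broadcast a scalar or dict dropout setting to one entry per conv layer.

    Hypothetical helper; int is kept in the scalar check so dropouts=0 works.
    """
    if isinstance(dropouts, (int, float, dict)):
        return list(repeat(dropouts, num_convs))
    return list(dropouts)


scalar_int = expand_dropouts(0, 3)
scalar_float = expand_dropouts(0.2, 2)
per_layer = expand_dropouts([0.1, 0.2], 2)

# With the narrowed (float, dict) check, dropouts=0 would reach list(0):
try:
    list(0)
    narrowed_check_error = None
except TypeError as exc:
    narrowed_check_error = type(exc).__name__

print(scalar_int, scalar_float, per_layer, narrowed_check_error)
```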
❌ 1 test failed.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Summary
Two related bugs introduced when the checkpoint format switched from .pth.tar to .safetensors.

Bug A: CkptMixin.save() silently truncates filenames with decimal values
File: torch_ecg/utils/utils_nn.py
Trainer generates checkpoint folder names like …_metric_0.91. pathlib treats .91 as the file extension, so the old path.with_suffix(".safetensors") replaced .91 instead of appending, producing …metric_0.safetensors instead of the correct …metric_0.91.safetensors.
Fix: after determining use_safetensors=True, if path.suffix != ".safetensors" we use string concatenation instead of with_suffix.
The now-redundant path.with_suffix(".safetensors") inside the single-file save_file call is also removed (path is already normalised at that point).
save() now returns Path (the actual file/directory written) instead of None.

Bug B: BaseTrainer checkpoint cleanup always fails
File: torch_ecg/components/trainer.py
saved_models stored the raw stem path (…metric_0.91), but the file on disk was …metric_0.91.safetensors. Every os.remove(model_to_remove) in the keep_checkpoint_max cleanup therefore raised FileNotFoundError silently.
Fix: save_checkpoint() forwards the Path returned by model.save(); the training loop stores that actual path.
Directory-style checkpoints (non-single-file mode) are now cleaned up with shutil.rmtree instead of os.remove. shutil is promoted to a top-level import.

Changes
- torch_ecg/utils/utils_nn.py: save() returns Path
- torch_ecg/components/trainer.py: saved_models; handle dir cleanup; top-level shutil import
- test/test_utils/test_utils_nn.py: test_ckpt_decimal_suffix_path covering all three save branches
- CHANGELOG.rst: Fixed entries under Unreleased

Tests
test_mixin_classes is excluded only because it requires a Dropbox network connection (pre-existing, unrelated).