Skip to content

Commit 96fa47f

Browse files
authored
simple lightning checkpoint callback (#39)
* checkpoint callback * linter * test
1 parent b0a2c77 commit 96fa47f

File tree

8 files changed

+86
-19
lines changed

8 files changed

+86
-19
lines changed

.github/workflows/ci-testing.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
os: ["ubuntu-latest", "macOS-latest", "windows-latest"]
19+
os: ["ubuntu-22.04", "macOS-13", "windows-2022"]
2020
python-version: ["3.9", "3.12"]
2121
requires: ["latest"]
2222
include:
@@ -44,7 +44,7 @@ jobs:
4444
- name: Install package & dependencies
4545
run: |
4646
pip --version
47-
pip install -e '.[test]' -U -q --find-links $TORCH_URL
47+
pip install -e '.[test,extra]' -U -q --find-links $TORCH_URL
4848
pip list
4949
5050
- name: Tests
@@ -60,7 +60,7 @@ jobs:
6060
uses: codecov/codecov-action@v5
6161
with:
6262
token: ${{ secrets.CODECOV_TOKEN }}
63-
file: ./coverage.xml
63+
files: ./coverage.xml
6464
flags: unittests
6565
env_vars: OS,PYTHON
6666
name: codecov-umbrella

_requirements/extra.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
lightning >= 2.0.0

examples/train-model-with-lightning-callback.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,13 @@
44

55
import torch.utils.data as data
66
import torchvision as tv
7-
from lightning import Callback, Trainer
8-
from litmodels import upload_model
7+
from lightning import Trainer
8+
from litmodels.integrations.lightning_checkpoint import LitModelCheckpoint
99
from sample_model import LitAutoEncoder
1010

1111
# Define the model name - this should be unique to your model
1212
# The format is <organization>/<teamspace>/<model-name>
13-
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-callback"
14-
15-
16-
class UploadModelCallback(Callback):
17-
def on_train_epoch_end(self, trainer, pl_module):
18-
# Get the best model path from the checkpoint callback
19-
best_model_path = trainer.checkpoint_callback.best_model_path
20-
if best_model_path:
21-
print(f"Uploading model: {best_model_path}")
22-
upload_model(model=best_model_path, name=MY_MODEL_NAME)
13+
MY_MODEL_NAME = "lightning-ai/jirka/lit-auto-encoder-callback"
2314

2415

2516
if __name__ == "__main__":
@@ -30,7 +21,7 @@ def on_train_epoch_end(self, trainer, pl_module):
3021

3122
trainer = Trainer(
3223
max_epochs=2,
33-
callbacks=[UploadModelCallback()],
24+
callbacks=LitModelCheckpoint(model_name=MY_MODEL_NAME),
3425
)
3526
trainer.fit(
3627
autoencoder,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Integrations with training frameworks like PyTorch Lightning, TensorFlow, and others."""
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from lightning_utilities import module_available
2+
3+
_LIGHTNING_AVAILABLE = module_available("lightning")
4+
_PYTORCHLIGHTNING_AVAILABLE = module_available("pytorch_lightning")
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from typing import Any
2+
3+
from litmodels import upload_model
4+
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
5+
6+
# Resolve `Trainer` and `ModelCheckpoint` from whichever Lightning distribution is installed;
# the unified `lightning` package is checked first, then the standalone `pytorch_lightning`.
if _LIGHTNING_AVAILABLE:
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint
elif _PYTORCHLIGHTNING_AVAILABLE:
    # `Trainer` is exported from the package root, not from `pytorch_lightning.callbacks`,
    # so it has to be imported separately from the callback class.
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
else:
    raise ModuleNotFoundError("No module named 'lightning' or 'pytorch_lightning'")
13+
14+
15+
class LitModelCheckpoint(ModelCheckpoint):
    """Lightning ModelCheckpoint that also uploads each saved checkpoint as a LitModel.

    Args:
        model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
            where organization is either your username or the name of an organization you are part of.
        args: Additional arguments to pass to the parent class.
        kwargs: Additional keyword arguments to pass to the parent class.

    """

    def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
        """Initialize the LitModelCheckpoint."""
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
        """Save the checkpoint through the parent class, then upload it under ``self.model_name``.

        Args:
            trainer: The trainer that triggered the checkpoint.
            filepath: Location of the checkpoint file written by the parent class.

        """
        super()._save_checkpoint(trainer, filepath)
        # In multi-process runs, upload only from global rank zero so the same checkpoint
        # is not pushed once per process.
        if not trainer.is_global_zero:
            return
        # TODO: upload in the background so that training does not stall
        # TODO: use the filename as the version, after validating that such a version does not exist yet
        upload_model(name=self.model_name, model=filepath)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import re
2+
from unittest import mock
3+
4+
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
5+
from litmodels.integrations.lightning_checkpoint import LitModelCheckpoint
6+
7+
if _LIGHTNING_AVAILABLE:
8+
from lightning import Trainer
9+
from lightning.pytorch.demos.boring_classes import BoringModel
10+
elif _PYTORCHLIGHTNING_AVAILABLE:
11+
from pytorch_lightning import Trainer
12+
from pytorch_lightning.demos.boring_classes import BoringModel
13+
14+
15+
@mock.patch("litmodels.io.cloud.sdk_upload_model")
def test_lightning_checkpoint_callback(mock_upload_model, tmp_path):
    """Check that LitModelCheckpoint triggers one upload per checkpoint saved during a two-epoch fit."""
    mock_upload_model.return_value.name = "org-name/teamspace/model-name"

    trainer = Trainer(
        # Keep logs and checkpoints inside pytest's temporary directory instead of the working directory.
        default_root_dir=str(tmp_path),
        max_epochs=2,
        callbacks=LitModelCheckpoint(model_name="org-name/teamspace/model-name"),
    )
    trainer.fit(BoringModel())

    assert mock_upload_model.call_count == 2
    assert mock_upload_model.call_args_list == [
        mock.call(name="org-name/teamspace/model-name", path=mock.ANY, progress_bar=True, cloud_account=None),
        mock.call(name="org-name/teamspace/model-name", path=mock.ANY, progress_bar=True, cloud_account=None),
    ]

    # Verify paths match the expected pattern
    for call_args in mock_upload_model.call_args_list:
        path = call_args.kwargs["path"]
        assert re.match(r".*[/\\]lightning_logs[/\\]version_\d+[/\\]checkpoints[/\\]epoch=\d+-step=\d+\.ckpt$", path)

tests/test_io_cloud.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,18 @@ def test_wrong_model_name(name):
2727
],
2828
)
2929
@mock.patch("litmodels.io.cloud.sdk_upload_model")
30-
def test_upload_model(mock_upload_model, tmpdir, model, model_path, verbose):
30+
def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose):
3131
mock_upload_model.return_value.name = "org-name/teamspace/model-name"
3232

3333
# The lit-logger function is just a wrapper around the SDK function
3434
upload_model(
3535
model=model,
3636
name="org-name/teamspace/model-name",
3737
cloud_account="cluster_id",
38-
staging_dir=tmpdir,
38+
staging_dir=str(tmp_path),
3939
verbose=verbose,
4040
)
41-
expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
41+
expected_path = model_path % str(tmp_path) if "%" in model_path else model_path
4242
mock_upload_model.assert_called_once_with(
4343
path=expected_path,
4444
name="org-name/teamspace/model-name",

0 commit comments

Comments
 (0)