Skip to content

Commit d0bd749

Browse files
authored
try Auth before training starts (#41)
* try Auth before training starts * ConnectionError * linter
1 parent 153289e commit d0bd749

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/litmodels/integrations/lightning_checkpoint.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any
22

3+
from lightning_sdk.lightning_cloud.login import Auth
4+
35
from litmodels import upload_model
46
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
57

@@ -28,6 +30,13 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
2830
super().__init__(*args, **kwargs)
2931
self.model_name = model_name
3032

33+
try:
34+
# authenticate before anything else starts
35+
auth = Auth()
36+
auth.authenticate()
37+
except Exception:
38+
raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")
39+
3140
def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
3241
super()._save_checkpoint(trainer, filepath)
3342
# todo: uploading on background so training does nt stops

tests/integrations/test_lightning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414

1515
@mock.patch("litmodels.io.cloud.sdk_upload_model")
16-
def test_lightning_checkpoint_callback(mock_upload_model, tmp_path):
16+
@mock.patch("litmodels.integrations.lightning_checkpoint.Auth")
17+
def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, tmp_path):
1718
mock_upload_model.return_value.name = "org-name/teamspace/model-name"
1819

1920
trainer = Trainer(

0 commit comments

Comments
 (0)