Skip to content

Commit a8fdd2c

Browse files
authored
update documentation (#123)
* Apply suggestions from code review
1 parent 42aca4d commit a8fdd2c

File tree

5 files changed

+129
-98
lines changed

5 files changed

+129
-98
lines changed

src/litmodels/integrations/checkpoints.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@
3535
# Create a singleton upload manager
3636
@cache
3737
def get_model_manager() -> "ModelManager":
38-
"""Get or create the singleton upload manager."""
38+
"""Get or create the singleton background manager for uploads/removals.
39+
40+
Returns:
41+
ModelManager: A process-wide, cached instance managing asynchronous tasks.
42+
"""
3943
return ModelManager()
4044

4145

@@ -55,12 +59,16 @@ class RemoveType(StrEnum):
5559

5660

5761
class ModelManager:
58-
"""Manages uploads and removals with a single queue but separate counters."""
62+
"""Manage asynchronous uploads and removals via a single worker queue.
63+
64+
This manager runs a daemon worker thread that processes queued upload and removal tasks.
65+
It maintains separate counters for pending uploads and removals and supports graceful shutdown.
66+
"""
5967

6068
task_queue: queue.Queue
6169

6270
def __init__(self) -> None:
63-
"""Initialize the ModelManager with a task queue and counters."""
71+
"""Initialize the manager with a task queue, counters, and a daemon worker thread."""
6472
self.task_queue = queue.Queue()
6573
self.upload_count = 0
6674
self.remove_count = 0
@@ -142,7 +150,11 @@ def shutdown(self) -> None:
142150

143151
# Base class to be inherited
144152
class LitModelCheckpointMixin(ABC):
145-
"""Mixin class for LitModel checkpoint functionality."""
153+
"""Mixin adding upload/remove behavior for Lightning checkpoint callbacks.
154+
155+
This mixin queues uploads to Lightning Cloud upon checkpoint save and can optionally
156+
remove local checkpoints or skip cloud-side pruning based on configuration.
157+
"""
146158

147159
_datetime_stamp: str
148160
model_registry: Optional[str] = None
@@ -151,12 +163,12 @@ class LitModelCheckpointMixin(ABC):
151163
def __init__(
152164
self, model_registry: Optional[str], keep_all_uploaded: bool = False, clear_all_local: bool = False
153165
) -> None:
154-
"""Initialize with model name.
166+
"""Configure model registry and pruning behavior.
155167
156168
Args:
157-
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
158-
keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so.
159-
clear_all_local: Whether to clear local models after uploading to the cloud.
169+
model_registry: Target model registry in the form 'organization/teamspace/modelname'.
170+
keep_all_uploaded: If True, never delete uploaded cloud versions even if local pruning occurs.
171+
clear_all_local: If True, remove local checkpoint files after they are uploaded to cloud.
160172
"""
161173
if not model_registry:
162174
rank_zero_warn(
@@ -201,7 +213,7 @@ def _upload_model(self, trainer: "pl.Trainer", filepath: Union[str, Path], metad
201213

202214
@rank_zero_only
203215
def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
204-
"""Remove the local version of the model if requested."""
216+
"""Queue removal of local and/or cloud artifacts according to configuration."""
205217
get_model_manager().queue_remove(
206218
filepath=filepath,
207219
# skip the local removal here; it is put in the queue right after the upload
@@ -211,7 +223,7 @@ def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> No
211223
)
212224

213225
def default_model_name(self, pl_model: "pl.LightningModule") -> str:
214-
"""Generate a default model name based on the class name and timestamp."""
226+
"""Generate a default model name using the LightningModule class name and a timestamp."""
215227
return pl_model.__class__.__name__ + f"_{self._datetime_stamp}"
216228

217229
def _update_model_name(self, pl_model: "pl.LightningModule") -> None:
@@ -252,14 +264,14 @@ def _update_model_name(self, pl_model: "pl.LightningModule") -> None:
252264
if _LIGHTNING_AVAILABLE:
253265

254266
class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoint):
255-
"""Lightning ModelCheckpoint with LitModel support.
267+
"""Drop-in ModelCheckpoint that uploads saved checkpoints to Lightning Cloud.
256268
257269
Args:
258-
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
259-
keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so.
260-
clear_all_local: Whether to clear local models after uploading to the cloud.
261-
*args: Additional arguments to pass to the parent class.
262-
**kwargs: Additional keyword arguments to pass to the parent class.
270+
model_registry: Target model registry in the form 'organization/teamspace/modelname'.
271+
keep_all_uploaded: If True, does not remove cloud versions when local pruning occurs.
272+
clear_all_local: If True, removes local checkpoint files after successful upload.
273+
*args: Additional positional arguments forwarded to the base ModelCheckpoint.
274+
**kwargs: Additional keyword arguments forwarded to the base ModelCheckpoint.
263275
"""
264276

265277
def __init__(
@@ -291,7 +303,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
291303
self._update_model_name(pl_module)
292304

293305
def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
294-
"""Extend the save checkpoint method to upload the model."""
306+
"""Save the checkpoint and queue an upload from the global-zero process."""
295307
_LightningModelCheckpoint._save_checkpoint(self, trainer, filepath)
296308
if trainer.is_global_zero: # Only upload from the main process
297309
self._upload_model(trainer=trainer, filepath=filepath)
@@ -311,14 +323,14 @@ def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
311323
if _PYTORCHLIGHTNING_AVAILABLE:
312324

313325
class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
314-
"""PyTorch Lightning ModelCheckpoint with LitModel support.
326+
"""Drop-in ModelCheckpoint for PyTorch Lightning that uploads to Lightning Cloud.
315327
316328
Args:
317329
model_registry: Target model registry in the form 'organization/teamspace/modelname'.
318-
keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so.
319-
clear_all_local: Whether to clear local models after uploading to the cloud.
320-
args: Additional arguments to pass to the parent class.
321-
kwargs: Additional keyword arguments to pass to the parent class.
330+
keep_all_uploaded: If True, does not remove cloud versions when local pruning occurs.
331+
clear_all_local: If True, removes local checkpoint files after successful upload.
332+
*args: Additional positional arguments forwarded to the base ModelCheckpoint.
333+
**kwargs: Additional keyword arguments forwarded to the base ModelCheckpoint.
322334
"""
323335

324336
def __init__(

src/litmodels/integrations/duplicate.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@ def duplicate_hf_model(
2121
verbose: int = 1,
2222
metadata: Optional[dict] = None,
2323
) -> str:
24-
"""Downloads the model from Hugging Face and uploads it to Lightning Cloud.
24+
"""Download a model from Hugging Face and upload it to Lightning Cloud as a new model.
2525
2626
Args:
27-
hf_model: The name of the Hugging Face model to duplicate.
28-
lit_model: The name of the Lightning Cloud model to create.
29-
local_workdir:
30-
The local working directory to use for the duplication process. If not set a temp folder will be created.
31-
verbose: Shot a progress bar for the upload.
32-
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
27+
hf_model: Hugging Face model identifier, for example 'org/name' or 'user/name'.
28+
lit_model: Target Lightning Cloud model name. If omitted, derived from `hf_model` by replacing '/' with '_'.
29+
local_workdir: Working directory used for download and staging. A temporary directory is created if omitted.
30+
verbose: Verbosity for upload progress (0 = silent, 1 = print link once, 2 = print link always).
31+
metadata: Optional metadata to attach to the uploaded model. Integration markers are added automatically.
3332
3433
Returns:
3534
The name of the duplicated model in Lightning Cloud.

src/litmodels/io/cloud.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121

2222

2323
def _print_model_link(name: str, verbose: Union[bool, int]) -> None:
24-
"""Print a link to the uploaded model.
24+
"""Print a stable URL to the uploaded model.
2525
2626
Args:
27-
name: Name of the model.
28-
verbose: Whether to print the link:
29-
30-
- If set to 0, no link will be printed.
31-
- If set to 1, the link will be printed only once.
32-
- If set to 2, the link will be printed every time.
27+
name: Model registry name. Teamspace defaults may be applied before URL construction.
28+
verbose: Controls printing behavior:
29+
- 0: do not print
30+
- 1: print the link only once for a given model
31+
- 2: always print the link
3332
"""
3433
name = _extend_model_name_with_teamspace(name)
3534
org_name, teamspace_name, model_name, _ = _parse_org_teamspace_model_version(name)
@@ -51,18 +50,18 @@ def upload_model_files(
5150
verbose: Union[bool, int] = 1,
5251
metadata: Optional[dict[str, str]] = None,
5352
) -> "UploadedModelInfo":
54-
"""Upload a local checkpoint file to the model store.
53+
"""Upload local artifact(s) to Lightning Cloud using the SDK.
5554
5655
Args:
57-
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
58-
where entity is either your username or the name of an organization you are part of.
59-
path: Path to the model file to upload.
60-
progress_bar: Whether to show a progress bar for the upload.
61-
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
62-
automatically.
63-
verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed.
64-
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
56+
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
57+
path: File path, directory path, or list of paths to upload.
58+
progress_bar: Whether to show a progress bar during upload.
59+
cloud_account: Optional cloud account to store the model in, when it cannot be auto-resolved.
60+
verbose: Verbosity for printing the model link (0 = no output, 1 = print once, 2 = print always).
61+
metadata: Optional metadata to attach to the model/version. The package version is added automatically.
6562
63+
Returns:
64+
UploadedModelInfo describing the created or updated model version.
6665
"""
6766
if not metadata:
6867
metadata = {}
@@ -84,17 +83,15 @@ def download_model_files(
8483
download_dir: Union[str, Path] = ".",
8584
progress_bar: bool = True,
8685
) -> Union[str, list[str]]:
87-
"""Download a checkpoint from the model store.
86+
"""Download artifact(s) for a model version using the SDK.
8887
8988
Args:
90-
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
91-
where entity is either your username or the name of an organization you are part of.
92-
download_dir: A path to directory where the model should be downloaded. Defaults
93-
to the current working directory.
94-
progress_bar: Whether to show a progress bar for the download.
89+
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
90+
download_dir: Directory where downloaded artifact(s) will be stored. Defaults to the current directory.
91+
progress_bar: Whether to show a progress bar during download.
9592
9693
Returns:
97-
The absolute path to the downloaded model file or folder.
94+
str | list[str]: Absolute path(s) to the downloaded artifact(s).
9895
"""
9996
return sdk_download_model(
10097
name=name,
@@ -104,10 +101,10 @@ def download_model_files(
104101

105102

106103
def _list_available_teamspaces() -> dict[str, dict]:
107-
"""List available teamspaces for the authenticated user.
104+
"""List teamspaces available to the authenticated user.
108105
109106
Returns:
110-
Dict with teamspace names as keys and their details as values.
107+
dict[str, dict]: Mapping of 'org/teamspace' to a metadata dictionary with details.
111108
"""
112109
from lightning_sdk.api import OrgApi, UserApi
113110
from lightning_sdk.utils import resolve as sdk_resolvers
@@ -128,13 +125,12 @@ def _list_available_teamspaces() -> dict[str, dict]:
128125

129126
def delete_model_version(
130127
name: str,
131-
version: Optional[str] = None,
128+
version: str,
132129
) -> None:
133-
"""Delete a model version from the model store.
130+
"""Delete a specific model version from the model store.
134131
135132
Args:
136-
name: Name of the model to delete. Must be in the format 'organization/teamspace/modelname'
137-
where entity is either your username or the name of an organization you are part of.
138-
version: Version of the model to delete. If not provided, all versions will be deleted.
133+
name: Base model registry name in the form 'organization/teamspace/modelname'.
134+
version: Identifier of the version to delete. This argument is required.
139135
"""
140136
sdk_delete_model(name=f"{name}:{version}")

src/litmodels/io/gateway.py

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,22 @@ def upload_model(
2424
verbose: Union[bool, int] = 1,
2525
metadata: Optional[dict[str, str]] = None,
2626
) -> "UploadedModelInfo":
27-
"""Upload a checkpoint to the model store.
27+
"""Upload a local artifact (file or directory) to Lightning Cloud Models.
2828
2929
Args:
30-
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
31-
where entity is either your username or the name of an organization you are part of.
32-
model: The model to upload. Can be a path to a checkpoint file or a folder.
33-
progress_bar: Whether to show a progress bar for the upload.
34-
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
35-
automatically.
36-
verbose: Whether to print some additional information about the uploaded model.
37-
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
30+
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
31+
If the version is omitted, one may be assigned automatically by the service.
32+
model: Path to a checkpoint file or a directory containing model artifacts.
33+
progress_bar: Whether to show a progress bar during the upload.
34+
cloud_account: Optional cloud account to store the model in, when it cannot be auto-resolved.
35+
verbose: Verbosity of informational output (0 = silent, 1 = print link once, 2 = print link always).
36+
metadata: Optional metadata key/value pairs to attach to the uploaded model/version.
3837
38+
Returns:
39+
UploadedModelInfo describing the created or updated model version.
40+
41+
Raises:
42+
ValueError: If `model` is not a filesystem path. For in-memory objects, use `save_model()` instead.
3943
"""
4044
if not isinstance(model, (str, Path)):
4145
raise ValueError(
@@ -62,20 +66,29 @@ def save_model(
6266
verbose: Union[bool, int] = 1,
6367
metadata: Optional[dict[str, str]] = None,
6468
) -> "UploadedModelInfo":
65-
"""Upload a checkpoint to the model store.
69+
"""Serialize an in-memory model and upload it to Lightning Cloud Models.
70+
71+
Supported models:
72+
- TorchScript (torch.jit.ScriptModule) → saved as .ts via model.save()
73+
- PyTorch nn.Module → saved as .pth (state_dict via torch.save)
74+
- Keras (tf.keras.Model) → saved as .keras via model.save()
75+
- Any other Python object → saved as .pkl via pickle or joblib
6676
6777
Args:
68-
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
69-
where entity is either your username or the name of an organization you are part of.
70-
model: The model to upload. Can be a PyTorch model, or a Lightning model a.
71-
progress_bar: Whether to show a progress bar for the upload.
72-
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
73-
automatically.
74-
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
75-
be created and used.
76-
verbose: Whether to print some additional information about the uploaded model.
77-
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
78+
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
79+
model: The in-memory model instance to serialize and upload.
80+
progress_bar: Whether to show a progress bar during the upload.
81+
cloud_account: Optional cloud account to store the model in, when it cannot be auto-resolved.
82+
staging_dir: Optional temporary directory used for serialization. A new temp directory is created if omitted.
83+
verbose: Verbosity of informational output (0 = silent, 1 = print link once, 2 = print link always).
84+
metadata: Optional metadata key/value pairs to attach to the uploaded model/version. Integration markers are
85+
added automatically.
7886
87+
Returns:
88+
UploadedModelInfo describing the created or updated model version.
89+
90+
Raises:
91+
ValueError: If `model` is a path. For file/folder uploads use `upload_model()` instead.
7992
"""
8093
if isinstance(model, (str, Path)):
8194
raise ValueError(
@@ -120,17 +133,15 @@ def download_model(
120133
download_dir: Union[str, Path] = ".",
121134
progress_bar: bool = True,
122135
) -> Union[str, list[str]]:
123-
"""Download a checkpoint from the model store.
136+
"""Download a model version from Lightning Cloud Models to a local directory.
124137
125138
Args:
126-
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
127-
where entity is either your username or the name of an organization you are part of.
128-
download_dir: A path to directory where the model should be downloaded. Defaults
129-
to the current working directory.
130-
progress_bar: Whether to show a progress bar for the download.
139+
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
140+
download_dir: Directory where the artifact(s) will be stored. Defaults to the current working directory.
141+
progress_bar: Whether to show a progress bar during the download.
131142
132143
Returns:
133-
The absolute path to the downloaded model file or folder.
144+
str | list[str]: Absolute path(s) to the downloaded file(s) or directory content.
134145
"""
135146
return download_model_files(
136147
name=name,
@@ -140,16 +151,22 @@ def download_model(
140151

141152

142153
def load_model(name: str, download_dir: str = ".") -> Any:
143-
"""Download a model from the model store and load it into memory.
154+
"""Download a model and load it into memory based on its file extension.
155+
156+
Supported formats:
157+
- .ts → torch.jit.load
158+
- .keras → keras.models.load_model
159+
- .pkl → pickle/joblib via load_pickle
144160
145161
Args:
146-
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
147-
where entity is either your username or the name of an organization you are part of.
148-
download_dir: A path to directory where the model should be downloaded. Defaults
149-
to the current working directory.
162+
name: Model registry name in the form 'organization/teamspace/modelname[:version]'.
163+
download_dir: Directory to store the downloaded artifact(s) before loading. Defaults to the current directory.
150164
151165
Returns:
152-
The loaded model.
166+
Any: The loaded model object.
167+
168+
Raises:
169+
NotImplementedError: If multiple files are downloaded or the file extension is not supported.
153170
"""
154171
download_paths = download_model(name=name, download_dir=download_dir)
155172
# filter out all Markdown, TXT and RST files

0 commit comments

Comments
 (0)