3535# Create a singleton upload manager
3636@cache
3737def get_model_manager () -> "ModelManager" :
38- """Get or create the singleton upload manager."""
38+ """Get or create the singleton background manager for uploads/removals.
39+
40+ Returns:
41+ ModelManager: A process-wide, cached instance managing asynchronous tasks.
42+ """
3943 return ModelManager ()
4044
4145
@@ -55,12 +59,16 @@ class RemoveType(StrEnum):
5559
5660
5761class ModelManager :
58- """Manages uploads and removals with a single queue but separate counters."""
62+ """Manage asynchronous uploads and removals via a single worker queue.
63+
64+ This manager runs a daemon worker thread that processes queued upload and removal tasks.
65+ It maintains separate counters for pending uploads and removals and supports graceful shutdown.
66+ """
5967
6068 task_queue : queue .Queue
6169
6270 def __init__ (self ) -> None :
63- """Initialize the ModelManager with a task queue and counters ."""
71+ """Initialize the manager with a task queue, counters, and a daemon worker thread ."""
6472 self .task_queue = queue .Queue ()
6573 self .upload_count = 0
6674 self .remove_count = 0
@@ -142,7 +150,11 @@ def shutdown(self) -> None:
142150
143151# Base class to be inherited
144152class LitModelCheckpointMixin (ABC ):
145- """Mixin class for LitModel checkpoint functionality."""
153+ """Mixin adding upload/remove behavior for Lightning checkpoint callbacks.
154+
155+ This mixin queues uploads to Lightning Cloud upon checkpoint save and can optionally
156+ remove local checkpoints or skip cloud-side pruning based on configuration.
157+ """
146158
147159 _datetime_stamp : str
148160 model_registry : Optional [str ] = None
@@ -151,12 +163,12 @@ class LitModelCheckpointMixin(ABC):
151163 def __init__ (
152164 self , model_registry : Optional [str ], keep_all_uploaded : bool = False , clear_all_local : bool = False
153165 ) -> None :
154- """Initialize with model name .
166+ """Configure model registry and pruning behavior .
155167
156168 Args:
157- model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
158- keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so .
159- clear_all_local: Whether to clear local models after uploading to the cloud.
169+ model_registry: Target model registry in the form 'organization/teamspace/modelname'.
170+ keep_all_uploaded: If True, never delete uploaded cloud versions even if local pruning occurs .
171+ clear_all_local: If True, remove local checkpoint files after they are uploaded to cloud.
160172 """
161173 if not model_registry :
162174 rank_zero_warn (
@@ -201,7 +213,7 @@ def _upload_model(self, trainer: "pl.Trainer", filepath: Union[str, Path], metad
201213
202214 @rank_zero_only
203215 def _remove_model (self , trainer : "pl.Trainer" , filepath : Union [str , Path ]) -> None :
204- """Remove the local version of the model if requested ."""
216+ """Queue removal of local and/or cloud artifacts according to configuration ."""
205217 get_model_manager ().queue_remove (
206218 filepath = filepath ,
207219 # skip the local removal we put it in the queue right after the upload
@@ -211,7 +223,7 @@ def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> No
211223 )
212224
213225 def default_model_name (self , pl_model : "pl.LightningModule" ) -> str :
214- """Generate a default model name based on the class name and timestamp."""
226+ """Generate a default model name using the LightningModule class name and a timestamp."""
215227 return pl_model .__class__ .__name__ + f"_{ self ._datetime_stamp } "
216228
217229 def _update_model_name (self , pl_model : "pl.LightningModule" ) -> None :
@@ -252,14 +264,14 @@ def _update_model_name(self, pl_model: "pl.LightningModule") -> None:
252264if _LIGHTNING_AVAILABLE :
253265
254266 class LightningModelCheckpoint (LitModelCheckpointMixin , _LightningModelCheckpoint ):
255- """Lightning ModelCheckpoint with LitModel support .
267+ """Drop-in ModelCheckpoint that uploads saved checkpoints to Lightning Cloud .
256268
257269 Args:
258- model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
259- keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so .
260- clear_all_local: Whether to clear local models after uploading to the cloud .
261- *args: Additional arguments to pass to the parent class .
262- **kwargs: Additional keyword arguments to pass to the parent class .
270+ model_registry: Target model registry in the form 'organization/teamspace/modelname'.
271+ keep_all_uploaded: If True, does not remove cloud versions when local pruning occurs .
272+ clear_all_local: If True, removes local checkpoint files after successful upload .
273+ *args: Additional positional arguments forwarded to the base ModelCheckpoint .
274+ **kwargs: Additional keyword arguments forwarded to the base ModelCheckpoint .
263275 """
264276
265277 def __init__ (
@@ -291,7 +303,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
291303 self ._update_model_name (pl_module )
292304
293305 def _save_checkpoint (self , trainer : "pl.Trainer" , filepath : str ) -> None :
294- """Extend the save checkpoint method to upload the model ."""
306+ """Save the checkpoint and queue an upload from the global-zero process ."""
295307 _LightningModelCheckpoint ._save_checkpoint (self , trainer , filepath )
296308 if trainer .is_global_zero : # Only upload from the main process
297309 self ._upload_model (trainer = trainer , filepath = filepath )
@@ -311,14 +323,14 @@ def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
311323if _PYTORCHLIGHTNING_AVAILABLE :
312324
313325 class PytorchLightningModelCheckpoint (LitModelCheckpointMixin , _PytorchLightningModelCheckpoint ):
314- """PyTorch Lightning ModelCheckpoint with LitModel support .
326+ """Drop-in ModelCheckpoint for PyTorch Lightning that uploads to Lightning Cloud .
315327
316328 Args:
317329 model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
318- keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so .
319- clear_all_local: Whether to clear local models after uploading to the cloud .
320- args: Additional arguments to pass to the parent class .
321- kwargs: Additional keyword arguments to pass to the parent class .
330+ keep_all_uploaded: If True, does not remove cloud versions when local pruning occurs .
331+ clear_all_local: If True, removes local checkpoint files after successful upload .
332+ args: Additional positional arguments forwarded to the base ModelCheckpoint .
333+ kwargs: Additional keyword arguments forwarded to the base ModelCheckpoint .
322334 """
323335
324336 def __init__ (
0 commit comments