226 changes: 226 additions & 0 deletions cookbooks/maestro_early_stopping.ipynb
@@ -0,0 +1,226 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0341f3f7",
"metadata": {},
"source": [
"# Multimodal Maestro: Using Early Stopping for Efficient Training\n",
"\n",
"This notebook demonstrates how to use the early stopping feature with Multimodal Maestro models to reduce training time and prevent overfitting."
]
},
{
"cell_type": "markdown",
"id": "2008f170",
"metadata": {},
"source": [
"## Introduction\n",
"\n",
"Early stopping is a regularization technique to prevent overfitting in machine learning models. It works by monitoring a validation metric (typically validation loss) and stopping training when the model performance on the validation set stops improving for a specified number of epochs.\n",
"\n",
"Benefits of early stopping:\n",
"1. Reduces training time\n",
"2. Prevents overfitting\n",
"3. Automatically determines optimal training duration\n",
"\n",
"In this notebook, we'll demonstrate how to enable early stopping with Florence-2 model training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b0cd0d1",
"metadata": {},
"outputs": [],
"source": [
"# Install required packages\n",
"%pip install multimodal-maestro supervision --quiet"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e81adcdc",
"metadata": {},
"outputs": [],
"source": [
"# Import necessary libraries\n",
"import os\n",
"\n",
"from maestro.trainer.common.metrics import MeanAveragePrecisionMetric\n",
"from maestro.trainer.models.florence_2.core import Florence2Configuration, train"
]
},
{
"cell_type": "markdown",
"id": "fdc038f9",
"metadata": {},
"source": [
"## Downloading a sample dataset\n",
"\n",
"For this example, we'll use a small object detection dataset. You can replace this with your own dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e9e51e1",
"metadata": {},
"outputs": [],
"source": [
"# Download a sample dataset (chess pieces dataset)\n",
"%pip install roboflow\n",
"\n",
"from roboflow import Roboflow\n",
"\n",
"rf = Roboflow(api_key=\"YOUR_API_KEY\") # Replace with your API key or remove if using public datasets\n",
"project = rf.workspace(\"roboflow-100\").project(\"chess-pieces-detection\")\n",
"dataset = project.version(2).download(\"coco\")"
]
},
{
"cell_type": "markdown",
"id": "1888b2dc",
"metadata": {},
"source": [
"## Configuring Training with Early Stopping\n",
"\n",
"Now we'll set up the training configuration with early stopping enabled."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da75158d",
"metadata": {},
"outputs": [],
"source": [
"# Get the dataset path\n",
"dataset_path = os.path.join(os.getcwd(), dataset.location)\n",
"\n",
"# Configure the training\n",
"config = Florence2Configuration(\n",
"    dataset=dataset_path,\n",
"    epochs=20,  # Set a large enough number of epochs\n",
"    batch_size=2,  # Use a small batch size for this example\n",
"    lr=1e-5,\n",
"    optimization_strategy=\"lora\",\n",
"    metrics=[MeanAveragePrecisionMetric()],\n",
"    # Early stopping configuration\n",
"    early_stopping=True,  # Enable early stopping\n",
"    early_stopping_patience=3,  # Stop after 3 epochs with no improvement\n",
"    early_stopping_threshold=0.01,  # Minimum change to be considered as improvement\n",
"    early_stopping_monitor=\"val_loss\",  # Metric to monitor\n",
")"
]
},
{
"cell_type": "markdown",
"id": "705dc6a9",
"metadata": {},
"source": [
"## Training the Model with Early Stopping\n",
"\n",
"Now we'll start training the model. With early stopping enabled, training stops automatically once the validation loss fails to improve by more than 0.01 for 3 consecutive validation epochs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84e11341",
"metadata": {},
"outputs": [],
"source": [
"# Train the model\n",
"train(config)"
]
},
{
"cell_type": "markdown",
"id": "438768a8",
"metadata": {},
"source": [
"## Visualizing Training Metrics\n",
"\n",
"After training completes, you can examine the training and validation metrics to see how early stopping worked."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b259009",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Find the most recent training run\n",
"runs = sorted(glob.glob(\"./training/florence_2/*\"))\n",
"latest_run = runs[-1] if runs else None\n",
"\n",
"if latest_run:\n",
"    # Try to load the metrics\n",
"    try:\n",
"        metrics_dir = os.path.join(latest_run, \"metrics\")\n",
"        train_loss = pd.read_csv(os.path.join(metrics_dir, \"train_loss.csv\"))\n",
"        val_loss = pd.read_csv(os.path.join(metrics_dir, \"val_loss.csv\"))\n",
"\n",
"        # Plot training and validation loss\n",
"        plt.figure(figsize=(10, 5))\n",
"        plt.plot(train_loss[\"epoch\"], train_loss[\"value\"], label=\"Training Loss\")\n",
"        plt.plot(val_loss[\"epoch\"], val_loss[\"value\"], label=\"Validation Loss\")\n",
"        plt.xlabel(\"Epoch\")\n",
"        plt.ylabel(\"Loss\")\n",
"        plt.legend()\n",
"        plt.title(\"Training and Validation Loss (with Early Stopping)\")\n",
"        plt.grid(True, linestyle=\"--\", alpha=0.7)\n",
"        plt.show()\n",
"\n",
"        # Report the best epoch (looked up by row, not by DataFrame index)\n",
"        # and the epoch at which training stopped\n",
"        best_row = val_loss.loc[val_loss[\"value\"].idxmin()]\n",
"        print(f\"Best epoch: {int(best_row['epoch'])}\")\n",
"        print(f\"Best validation loss: {best_row['value']}\")\n",
"        print(f\"Training stopped at epoch: {val_loss['epoch'].max()}\")\n",
"    except Exception as e:\n",
"        print(f\"Could not load metrics: {e}\")\n",
"else:\n",
"    print(\"No training runs found\")"
]
},
{
"cell_type": "markdown",
"id": "2b52ad9a",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"In this notebook, we've demonstrated how to use early stopping with the Florence-2 model in Multimodal Maestro. The same approach can be applied to other models such as PaliGemma-2 and Qwen2.5-VL.\n",
"\n",
"Early stopping is a valuable technique for efficient model training, as it:\n",
"\n",
"1. Saves training time and computational resources\n",
"2. Automatically determines the optimal number of training epochs\n",
"3. Helps prevent overfitting\n",
"\n",
"By adjusting the `early_stopping_patience`, `early_stopping_threshold`, and `early_stopping_monitor` parameters, you can fine-tune the early stopping behavior to suit your specific training needs."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
64 changes: 64 additions & 0 deletions cookbooks/maestro_early_stopping_example.py
@@ -0,0 +1,64 @@
"""
Example script demonstrating how to enable early stopping in Maestro models.
This is useful to prevent overfitting and reduce training time when model
performance on the validation set has stopped improving.
"""

from maestro.trainer.models.florence_2.core import Florence2Configuration
from maestro.trainer.models.florence_2.core import train as train_florence
from maestro.trainer.models.paligemma_2.core import PaliGemma2Configuration
from maestro.trainer.models.paligemma_2.core import train as train_paligemma
from maestro.trainer.models.qwen_2_5_vl.core import Qwen25VLConfiguration
from maestro.trainer.models.qwen_2_5_vl.core import train as train_qwen


# Example with Florence-2 model
def train_florence_with_early_stopping():
    """Train a Florence-2 model with early stopping enabled"""
    config = Florence2Configuration(
        dataset="path/to/your/dataset",  # Replace with your dataset path
        epochs=20,  # Set a larger number of epochs
        early_stopping=True,  # Enable early stopping
        early_stopping_patience=3,  # Stop after 3 epochs without improvement
        early_stopping_threshold=0.01,  # Minimum change to be considered as improvement
        early_stopping_monitor="val_loss",  # Metric to monitor (default: val_loss)
    )

    train_florence(config)


# Example with PaliGemma-2 model
def train_paligemma_with_early_stopping():
    """Train a PaliGemma-2 model with early stopping enabled"""
    config = PaliGemma2Configuration(
        dataset="path/to/your/dataset",  # Replace with your dataset path
        epochs=20,  # Set a larger number of epochs
        early_stopping=True,  # Enable early stopping
        early_stopping_patience=5,  # Stop after 5 epochs without improvement
        early_stopping_threshold=0.001,  # More sensitive to small improvements
        early_stopping_monitor="val_loss",  # Metric to monitor
    )

    train_paligemma(config)


# Example with Qwen2.5-VL model
def train_qwen_with_early_stopping():
    """Train a Qwen2.5-VL model with early stopping enabled"""
    config = Qwen25VLConfiguration(
        dataset="path/to/your/dataset",  # Replace with your dataset path
        epochs=20,  # Set a larger number of epochs
        early_stopping=True,  # Enable early stopping
        early_stopping_patience=3,  # Stop after 3 epochs without improvement
        early_stopping_threshold=0.01,  # Minimum change to be considered as improvement
        early_stopping_monitor="val_loss",  # Metric to monitor
    )

    train_qwen(config)


if __name__ == "__main__":
    # Choose one of the training functions to run
    train_florence_with_early_stopping()
    # train_paligemma_with_early_stopping()
    # train_qwen_with_early_stopping()
35 changes: 34 additions & 1 deletion maestro/trainer/common/callbacks.py
@@ -3,7 +3,7 @@
from typing import Callable

import lightning
-from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.callbacks import Callback, EarlyStopping

from maestro.trainer.common.training import MaestroTrainer, TModel, TProcessor

@@ -26,3 +26,36 @@ def on_train_epoch_end(self, trainer: lightning.Trainer, pl_module: MaestroTrain

    def on_train_end(self, trainer: lightning.Trainer, pl_module: MaestroTrainer):
        pass


class EarlyStoppingCallback(EarlyStopping):
    """
    Early stopping callback for PyTorch Lightning trainers.

    This callback stops training when a monitored metric has stopped improving.

    Attributes:
        monitor (str): Quantity to be monitored. Default is 'val_loss'.
        min_delta (float): Minimum change in monitored quantity to qualify as improvement.
        patience (int): Number of validation epochs with no improvement after which training will be stopped.
        mode (str): One of 'min', 'max'. In 'min' mode, training will stop when the quantity monitored
            has stopped decreasing; in 'max' mode it will stop when the quantity monitored
            has stopped increasing. Default is 'min'.
        verbose (bool): Whether to print progress messages.
    """

    def __init__(
        self,
        monitor: str = "val_loss",
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = True,
        mode: str = "min",
    ):
        super().__init__(
            monitor=monitor,
            min_delta=min_delta,
            patience=patience,
            verbose=verbose,
            mode=mode,
        )
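
The `patience` and `min_delta` semantics described in the docstring above can be illustrated without Lightning. The following is a minimal pure-Python sketch, not the Lightning implementation: the function name `stopping_epoch` and the sample loss values are made up for illustration, and it assumes that an improvement must exceed `min_delta` to reset the patience counter.

```python
from typing import Optional, Sequence


def stopping_epoch(
    values: Sequence[float],
    patience: int = 3,
    min_delta: float = 0.0,
    mode: str = "min",
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None.

    Hypothetical helper for illustration only; it mimics the usual
    patience/min_delta bookkeeping of an early stopping callback.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")

    # Flip the sign in 'max' mode so that lower is always better below.
    sign = 1.0 if mode == "min" else -1.0
    best = float("inf")
    wait = 0  # consecutive epochs without a sufficient improvement

    for epoch, value in enumerate(values):
        score = sign * value
        if score < best - min_delta:
            # Improvement larger than min_delta: record it and reset the counter.
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


val_losses = [1.00, 0.80, 0.70, 0.695, 0.694, 0.693, 0.60]
print(stopping_epoch(val_losses, patience=3, min_delta=0.01))  # → 5
print(stopping_epoch(val_losses, patience=3, min_delta=0.0))   # → None
```

With `min_delta=0.01`, the small gains after epoch 2 do not count as improvements, so the run stops at epoch 5 and never reaches the later drop to 0.60. With `min_delta=0.0`, every decrease resets the counter and training runs to completion. This is the trade-off to keep in mind when choosing `early_stopping_threshold` and `early_stopping_patience`.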
32 changes: 31 additions & 1 deletion maestro/trainer/models/florence_2/core.py
@@ -86,6 +86,15 @@ class Florence2Configuration:
            Random seed for ensuring reproducibility. If None, no seeding is applied.
        peft_advanced_params (Optional[dict]):
            Custom LoRA configuration. If None, default configuration is applied.
        early_stopping (bool):
            Whether to use early stopping. Default is False.
        early_stopping_patience (int):
            Number of epochs with no improvement after which training will be stopped.
            Only applies if early_stopping is True. Default is 3.
        early_stopping_threshold (float):
            Minimum change in monitored quantity to qualify as improvement. Default is 0.0.
        early_stopping_monitor (str):
            Quantity to be monitored for early stopping. Default is "val_loss".
    """

    dataset: str
@@ -106,6 +115,10 @@ class Florence2Configuration:
    max_new_tokens: int = 1024
    random_seed: Optional[int] = None
    peft_advanced_params: Optional[dict] = None
    early_stopping: bool = False
    early_stopping_patience: int = 3
    early_stopping_threshold: float = 0.0
    early_stopping_monitor: str = "val_loss"

    def __post_init__(self):
        if self.val_batch_size is None:
@@ -273,12 +286,29 @@ def train(config: Florence2Configuration | dict) -> None:
    )
    save_checkpoints_path = os.path.join(config.output_dir, "checkpoints")
    save_checkpoint_callback = SaveCheckpoint(result_path=save_checkpoints_path, save_model_callback=save_model)

    callbacks = [save_checkpoint_callback]

    # Add early stopping if enabled
    if config.early_stopping:
        from maestro.trainer.common.callbacks import EarlyStoppingCallback

        early_stopping_callback = EarlyStoppingCallback(
            monitor=config.early_stopping_monitor,
            min_delta=config.early_stopping_threshold,
            patience=config.early_stopping_patience,
            verbose=True,
            mode="min" if config.early_stopping_monitor == "val_loss" else "max",
        )
        callbacks.append(early_stopping_callback)
        logger.info(f"Early stopping enabled with patience {config.early_stopping_patience}")

    trainer = lightning.Trainer(
        max_epochs=config.epochs,
        accumulate_grad_batches=config.accumulate_grad_batches,
        check_val_every_n_epoch=1,
        limit_val_batches=1,
        log_every_n_steps=10,
-        callbacks=[save_checkpoint_callback],
+        callbacks=callbacks,
    )
    trainer.fit(pl_module)
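
The `mode` expression in the hunk above maximizes every monitor other than `val_loss`, so a monitor such as a training loss would be treated as higher-is-better. A minimal sketch of a more general rule follows, assuming that loss-like metric names contain `loss` or `error`. The helper `infer_mode` and the example metric names are hypothetical and not part of Maestro or Lightning.

```python
def infer_mode(monitor: str) -> str:
    """Guess whether a monitored metric should be minimized or maximized.

    Hypothetical helper: assumes lower-is-better metrics have 'loss' or
    'error' somewhere in their name; everything else is maximized.
    """
    lower_is_better_tokens = ("loss", "error")
    name = monitor.lower()
    if any(token in name for token in lower_is_better_tokens):
        return "min"
    return "max"


print(infer_mode("val_loss"))                # → min
print(infer_mode("train_loss"))              # → min
print(infer_mode("mean_average_precision"))  # → max
```

A name-based heuristic like this can still guess wrong for unusual metric names, so exposing the mode as an explicit configuration field may be the more robust design.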