Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions MLFLOW_INTEGRATION.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# MLflow Integration for Flux Fine-Tuner

This document describes the MLflow integration for tracking Flux LoRA training experiments.

## Overview

MLflow has been integrated alongside Weights & Biases to provide experiment tracking capabilities. You can use MLflow independently or alongside W&B to track your training runs.

## Features

The MLflow integration tracks:
- **Parameters**: All training hyperparameters (learning rate, batch size, steps, LoRA rank, etc.)
- **Metrics**: Training loss logged at each step
- **Artifacts**:
- Sample images generated during training
- Final LoRA weights (.safetensors file)

## Setup

### 1. Start MLflow Tracking Server

To use MLflow, you need a tracking server. You can run one locally:

```bash
mlflow server --host 0.0.0.0 --port 5000
```

Or use a remote MLflow server if you have one deployed.

### 2. Configure Training Parameters

When running training, provide the MLflow parameters:

```python
mlflow_tracking_uri="http://localhost:5000" # Your MLflow server URI
mlflow_experiment_name="flux-lora-training" # Experiment name
mlflow_run_name="my-custom-run" # Optional: specific run name
```

## Training Parameters

### MLflow-specific Parameters

- **mlflow_tracking_uri** (string, optional): MLflow tracking server URI
- Example: `"http://localhost:5000"` or `"https://your-mlflow-server.com"`
- Default: `None` (MLflow disabled)

- **mlflow_experiment_name** (string): Name of the MLflow experiment
- Default: `"flux-lora-training"`
- Only applicable if `mlflow_tracking_uri` is set

- **mlflow_run_name** (string, optional): Name for this specific run
- Default: `None` (auto-generated)
- Only applicable if `mlflow_tracking_uri` is set

## Usage Example

```python
from pathlib import Path

result = train(
input_images=Path("my_images.zip"),
trigger_word="TOK",
steps=1000,
learning_rate=4e-4,
lora_rank=16,
# MLflow configuration
mlflow_tracking_uri="http://localhost:5000",
mlflow_experiment_name="my-flux-experiment",
mlflow_run_name="run-001",
)
```

## Viewing Results

Once training starts, MLflow will print the URL to view the run:

```
MLflow tracking initialized. View at: http://localhost:5000/#/experiments/1/runs/abc123
```

Open this URL in your browser to:
- View real-time training metrics
- Compare different runs
- Download artifacts (samples and weights)
- Analyze parameter impact

## MLflow UI Features

### Metrics Tab
- View training loss curves over time
- Compare metrics across multiple runs

### Parameters Tab
- See all hyperparameters for the run
- Filter and sort runs by parameters

### Artifacts Tab
- Download sample images generated during training
- Access final LoRA weights
- View organized by training step

## Using with Weights & Biases

You can use both MLflow and W&B simultaneously:

```python
result = train(
input_images=Path("my_images.zip"),
# W&B configuration
wandb_api_key=my_wandb_key,
wandb_project="my-project",
# MLflow configuration
mlflow_tracking_uri="http://localhost:5000",
mlflow_experiment_name="my-flux-experiment",
)
```

Both systems will receive the same metrics and artifacts independently.

## Architecture

### MLflow Client (`mlflow_client.py`)

The `MLflowClient` class provides:
- Experiment and run management
- Parameter logging
- Metric logging with step tracking
- Artifact (images and weights) logging

### Integration Points

The integration hooks into:
1. **CustomSDTrainer.hook_train_loop()**: Logs training loss at each step
2. **CustomSDTrainer.sample()**: Logs generated sample images
3. **CustomSDTrainer.post_save_hook()**: Logs LoRA weights

## Troubleshooting

### Connection Issues

If you see connection errors:
- Verify the MLflow server is running
- Check the tracking URI is correct
- Ensure network connectivity to the server

### Missing Artifacts

If artifacts aren't appearing:
- Check disk space on the MLflow server
- Verify artifact storage is configured correctly
- Review server logs for errors

### Performance

MLflow logging is non-blocking and won't slow down training. However:
- Large sample image sets may take time to upload
- Consider reducing sample frequency for very frequent sampling

## Advanced Configuration

### Remote Artifact Storage

MLflow can store artifacts in S3, Azure Blob Storage, or other backends:

```bash
mlflow server \
--backend-store-uri postgresql://user:pass@localhost/mlflow \
--default-artifact-root s3://my-mlflow-bucket/ \
--host 0.0.0.0
```

### Database Backend

For production use, configure a database backend:

```bash
mlflow server \
--backend-store-uri postgresql://user:pass@localhost/mlflow \
--host 0.0.0.0
```

## References

- [MLflow Documentation](https://mlflow.org/docs/latest/)
- [MLflow Tracking](https://mlflow.org/docs/latest/tracking.html)
- [MLflow Python API](https://mlflow.org/docs/latest/python_api/mlflow.html)
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ It also includes code for running inference with a fine-tuned model.
- Image generation using the LoRA (inference)
- Optionally uploads fine-tuned weights to Hugging Face after training
- Automated test suite with [cog-safe-push](https://github.com/replicate/cog-safe-push) for continuous deployment
- Weights and biases integration
- Weights and Biases integration
- MLflow experiment tracking integration (see [MLFLOW_INTEGRATION.md](MLFLOW_INTEGRATION.md))

## Getting Started

Expand Down
1 change: 1 addition & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ build:
- "wandb==0.17.8"
- "wavedrom==2.0.3.post3"
- "Pygments==2.16.1"
- "mlflow==2.18.0"
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
- pip install "git+https://github.com/Gothos/diffusers.git@flux-inpaint"
Expand Down
111 changes: 111 additions & 0 deletions mlflow_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from pathlib import Path
from typing import Any, Sequence
from contextlib import suppress
import mlflow
import mlflow.pytorch


class MLflowClient:
"""Client for tracking Flux LoRA training experiments with MLflow."""

def __init__(
self,
tracking_uri: str,
experiment_name: str,
run_name: str | None,
config: dict,
sample_prompts: list[str],
):
"""Initialize MLflow tracking.

Args:
tracking_uri: MLflow tracking server URI (e.g., "http://localhost:5000")
experiment_name: Name of the MLflow experiment
run_name: Optional name for this specific run
config: Training configuration dictionary to log as parameters
sample_prompts: List of prompts used for sample generation
"""
self.sample_prompts = sample_prompts
self.tracking_uri = tracking_uri
self.experiment_name = experiment_name

# Set tracking URI
mlflow.set_tracking_uri(self.tracking_uri)

# Set or create experiment
try:
mlflow.set_experiment(self.experiment_name)
except Exception as e:
raise ValueError(f"Failed to set MLflow experiment: {e}")

# Start the run
try:
self.run = mlflow.start_run(run_name=run_name)
# Log all config parameters
mlflow.log_params(config)
print(
f"MLflow tracking initialized. View at: {self.tracking_uri}/#/experiments/{mlflow.active_run().info.experiment_id}/runs/{mlflow.active_run().info.run_id}"
)
except Exception as e:
raise ValueError(f"Failed to start MLflow run: {e}")

def log_loss(self, loss_dict: dict[str, Any], step: int | None):
"""Log training loss metrics.

Args:
loss_dict: Dictionary of loss values to log
step: Training step number
"""
try:
mlflow.log_metrics(loss_dict, step=step)
except Exception as e:
print(f"Failed to log metrics to MLflow: {e}")

def log_samples(self, image_paths: Sequence[Path], step: int | None):
"""Log generated sample images.

Args:
image_paths: List of paths to generated images
step: Training step number
"""
try:
# Log each image with its corresponding prompt
for path in image_paths:
mlflow.log_artifact(str(path), artifact_path=f"samples/step_{step}")
except Exception as e:
print(f"Failed to log samples to MLflow: {e}")

def save_weights(self, lora_path: Path):
"""Save LoRA weights as an artifact.

Args:
lora_path: Path to the LoRA safetensors file
"""
try:
# Log the weights file as an artifact
mlflow.log_artifact(str(lora_path), artifact_path="weights")
print(f"Logged weights to MLflow: {lora_path.name}")
except Exception as e:
print(f"Failed to save weights to MLflow: {e}")

def finish(self):
"""End the MLflow run."""
with suppress(Exception):
mlflow.end_run()
print("MLflow run completed successfully")


def truncate(text, max_chars=50):
"""Truncate text to max_chars, adding ellipsis in the middle if needed.

Args:
text: Text to truncate
max_chars: Maximum character length

Returns:
Truncated text
"""
if len(text) <= max_chars:
return text
half = (max_chars - 3) // 2
return f"{text[:half]}...{text[-half:]}"
Loading