diff --git a/MLFLOW_INTEGRATION.md b/MLFLOW_INTEGRATION.md new file mode 100644 index 0000000..6ea7b74 --- /dev/null +++ b/MLFLOW_INTEGRATION.md @@ -0,0 +1,187 @@ +# MLflow Integration for Flux Fine-Tuner + +This document describes the MLflow integration for tracking Flux LoRA training experiments. + +## Overview + +MLflow has been integrated alongside Weights & Biases to provide experiment tracking capabilities. You can use MLflow independently or alongside W&B to track your training runs. + +## Features + +The MLflow integration tracks: +- **Parameters**: All training hyperparameters (learning rate, batch size, steps, LoRA rank, etc.) +- **Metrics**: Training loss logged at each step +- **Artifacts**: + - Sample images generated during training + - Final LoRA weights (.safetensors file) + +## Setup + +### 1. Start MLflow Tracking Server + +To use MLflow, you need a tracking server. You can run one locally: + +```bash +mlflow server --host 0.0.0.0 --port 5000 +``` + +Or use a remote MLflow server if you have one deployed. + +### 2. Configure Training Parameters + +When running training, provide the MLflow parameters: + +```python +mlflow_tracking_uri="http://localhost:5000" # Your MLflow server URI +mlflow_experiment_name="flux-lora-training" # Experiment name +mlflow_run_name="my-custom-run" # Optional: specific run name +``` + +## Training Parameters + +### MLflow-specific Parameters + +- **mlflow_tracking_uri** (string, optional): MLflow tracking server URI + - Example: `"http://localhost:5000"` or `"https://your-mlflow-server.com"` + - Default: `None` (MLflow disabled) + +- **mlflow_experiment_name** (string): Name of the MLflow experiment + - Default: `"flux-lora-training"` + - Only applicable if `mlflow_tracking_uri` is set + +- **mlflow_run_name** (string, optional): Name for this specific run + - Default: `None` (auto-generated) + - Only applicable if `mlflow_tracking_uri` is set + +## Usage Example + +```python +from pathlib import Path + +result = train( + input_images=Path("my_images.zip"), + trigger_word="TOK", + steps=1000, + learning_rate=4e-4, + lora_rank=16, + # MLflow configuration + mlflow_tracking_uri="http://localhost:5000", + mlflow_experiment_name="my-flux-experiment", + mlflow_run_name="run-001", +) +``` + +## Viewing Results + +Once training starts, MLflow will print the URL to view the run: + +``` +MLflow tracking initialized. View at: http://localhost:5000/#/experiments/1/runs/abc123 +``` + +Open this URL in your browser to: +- View real-time training metrics +- Compare different runs +- Download artifacts (samples and weights) +- Analyze parameter impact + +## MLflow UI Features + +### Metrics Tab +- View training loss curves over time +- Compare metrics across multiple runs + +### Parameters Tab +- See all hyperparameters for the run +- Filter and sort runs by parameters + +### Artifacts Tab +- Download sample images generated during training +- Access final LoRA weights +- View organized by training step + +## Using with Weights & Biases + +You can use both MLflow and W&B simultaneously: + +```python +result = train( + input_images=Path("my_images.zip"), + # W&B configuration + wandb_api_key=my_wandb_key, + wandb_project="my-project", + # MLflow configuration + mlflow_tracking_uri="http://localhost:5000", + mlflow_experiment_name="my-flux-experiment", +) +``` + +Both systems will receive the same metrics and artifacts independently. + +## Architecture + +### MLflow Client (`mlflow_client.py`) + +The `MLflowClient` class provides: +- Experiment and run management +- Parameter logging +- Metric logging with step tracking +- Artifact (images and weights) logging + +### Integration Points + +The integration hooks into: +1. **CustomSDTrainer.hook_train_loop()**: Logs training loss at each step +2. **CustomSDTrainer.sample()**: Logs generated sample images +3. **CustomSDTrainer.post_save_hook()**: Logs LoRA weights + +## Troubleshooting + +### Connection Issues + +If you see connection errors: +- Verify the MLflow server is running +- Check the tracking URI is correct +- Ensure network connectivity to the server + +### Missing Artifacts + +If artifacts aren't appearing: +- Check disk space on the MLflow server +- Verify artifact storage is configured correctly +- Review server logs for errors + +### Performance + +MLflow logging is non-blocking and won't slow down training. However: +- Large sample image sets may take time to upload +- Consider reducing sample frequency for very frequent sampling + +## Advanced Configuration + +### Remote Artifact Storage + +MLflow can store artifacts in S3, Azure Blob Storage, or other backends: + +```bash +mlflow server \ + --backend-store-uri postgresql://user:pass@localhost/mlflow \ + --default-artifact-root s3://my-mlflow-bucket/ \ + --host 0.0.0.0 +``` + +### Database Backend + +For production use, configure a database backend: + +```bash +mlflow server \ + --backend-store-uri postgresql://user:pass@localhost/mlflow \ + --host 0.0.0.0 +``` + +## References + +- [MLflow Documentation](https://mlflow.org/docs/latest/) +- [MLflow Tracking](https://mlflow.org/docs/latest/tracking.html) +- [MLflow Python API](https://mlflow.org/docs/latest/python_api/mlflow.html) diff --git a/README.md b/README.md index 75fa8ef..5007ab8 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,8 @@ It also includes code for running inference with a fine-tuned model. - Image generation using the LoRA (inference) - Optionally uploads fine-tuned weights to Hugging Face after training - Automated test suite with [cog-safe-push](https://github.com/replicate/cog-safe-push) for continuous deployment -- Weights and biases integration +- Weights and Biases integration +- MLflow experiment tracking integration (see [MLFLOW_INTEGRATION.md](MLFLOW_INTEGRATION.md)) ## Getting Started diff --git a/cog.yaml b/cog.yaml index 8cae855..815c279 100644 --- a/cog.yaml +++ b/cog.yaml @@ -54,6 +54,7 @@ build: - "wandb==0.17.8" - "wavedrom==2.0.3.post3" - "Pygments==2.16.1" + - "mlflow==2.18.0" run: - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget - pip install "git+https://github.com/Gothos/diffusers.git@flux-inpaint" diff --git a/mlflow_client.py b/mlflow_client.py new file mode 100644 index 0000000..4d448e5 --- /dev/null +++ b/mlflow_client.py @@ -0,0 +1,111 @@ +from pathlib import Path +from typing import Any, Sequence +from contextlib import suppress +import mlflow +import mlflow.pytorch + + +class MLflowClient: + """Client for tracking Flux LoRA training experiments with MLflow.""" + + def __init__( + self, + tracking_uri: str, + experiment_name: str, + run_name: str | None, + config: dict, + sample_prompts: list[str], + ): + """Initialize MLflow tracking. + + Args: + tracking_uri: MLflow tracking server URI (e.g., "http://localhost:5000") + experiment_name: Name of the MLflow experiment + run_name: Optional name for this specific run + config: Training configuration dictionary to log as parameters + sample_prompts: List of prompts used for sample generation + """ + self.sample_prompts = sample_prompts + self.tracking_uri = tracking_uri + self.experiment_name = experiment_name + + # Set tracking URI + mlflow.set_tracking_uri(self.tracking_uri) + + # Set or create experiment + try: + mlflow.set_experiment(self.experiment_name) + except Exception as e: + raise ValueError(f"Failed to set MLflow experiment: {e}") + + # Start the run + try: + self.run = mlflow.start_run(run_name=run_name) + # Log all config parameters + mlflow.log_params(config) + print( + f"MLflow tracking initialized. View at: {self.tracking_uri}/#/experiments/{mlflow.active_run().info.experiment_id}/runs/{mlflow.active_run().info.run_id}" + ) + except Exception as e: + raise ValueError(f"Failed to start MLflow run: {e}") + + def log_loss(self, loss_dict: dict[str, Any], step: int | None): + """Log training loss metrics. + + Args: + loss_dict: Dictionary of loss values to log + step: Training step number + """ + try: + mlflow.log_metrics(loss_dict, step=step) + except Exception as e: + print(f"Failed to log metrics to MLflow: {e}") + + def log_samples(self, image_paths: Sequence[Path], step: int | None): + """Log generated sample images. + + Args: + image_paths: List of paths to generated images + step: Training step number + """ + try: + # Log each image with its corresponding prompt + for path in image_paths: + mlflow.log_artifact(str(path), artifact_path=f"samples/step_{step}") + except Exception as e: + print(f"Failed to log samples to MLflow: {e}") + + def save_weights(self, lora_path: Path): + """Save LoRA weights as an artifact. + + Args: + lora_path: Path to the LoRA safetensors file + """ + try: + # Log the weights file as an artifact + mlflow.log_artifact(str(lora_path), artifact_path="weights") + print(f"Logged weights to MLflow: {lora_path.name}") + except Exception as e: + print(f"Failed to save weights to MLflow: {e}") + + def finish(self): + """End the MLflow run.""" + with suppress(Exception): + mlflow.end_run() + print("MLflow run completed successfully") + + +def truncate(text, max_chars=50): + """Truncate text to max_chars, adding ellipsis in the middle if needed. + + Args: + text: Text to truncate + max_chars: Maximum character length + + Returns: + Truncated text + """ + if len(text) <= max_chars: + return text + half = (max_chars - 3) // 2 + return f"{text[:half]}...{text[-half:]}" diff --git a/train.py b/train.py index 556b9db..d88d088 100644 --- a/train.py +++ b/train.py @@ -29,6 +29,7 @@ from caption import Captioner from wandb_client import WeightsAndBiasesClient, logout_wandb +from mlflow_client import MLflowClient from layer_match import match_layers_to_optimize, available_layers_to_optimize @@ -48,11 +49,14 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.seen_samples = set() self.wandb: WeightsAndBiasesClient | None = None + self.mlflow: MLflowClient | None = None def hook_train_loop(self, batch): loss_dict = super().hook_train_loop(batch) if self.wandb: self.wandb.log_loss(loss_dict, self.step_num) + if self.mlflow: + self.mlflow.log_loss(loss_dict, self.step_num) return loss_dict def sample(self, step=None, is_first=False): @@ -63,6 +67,9 @@ def sample(self, step=None, is_first=False): if self.wandb: image_paths = [output_dir / p for p in sorted(new_samples)] self.wandb.log_samples(image_paths, step) + if self.mlflow: + image_paths = [output_dir / p for p in sorted(new_samples)] + self.mlflow.log_samples(image_paths, step) self.seen_samples = all_samples def post_save_hook(self, save_path): @@ -75,11 +82,17 @@ def post_save_hook(self, save_path): if self.wandb: print(f"Saving weights to W&B: {lora_path.name}") self.wandb.save_weights(lora_path) + if self.mlflow: + print(f"Saving weights to MLflow: {lora_path.name}") + self.mlflow.save_weights(lora_path) class CustomJob(BaseJob): def __init__( - self, config: OrderedDict, wandb_client: WeightsAndBiasesClient | None + self, + config: OrderedDict, + wandb_client: WeightsAndBiasesClient | None, + mlflow_client: MLflowClient | None = None, ): super().__init__(config) self.device = self.get_conf("device", "cpu") @@ -87,6 +100,7 @@ def __init__( self.load_processes(self.process_dict) for process in self.process: process.wandb = wandb_client + process.mlflow = mlflow_client def run(self): super().run() @@ -204,6 +218,18 @@ def train( default=100, ge=1, ), + mlflow_tracking_uri: str = Input( + description="MLflow tracking server URI, if you'd like to log training progress to MLflow. For example, 'http://localhost:5000' or 'https://your-mlflow-server.com'.", + default=None, + ), + mlflow_experiment_name: str = Input( + description="MLflow experiment name. Only applicable if mlflow_tracking_uri is set.", + default="flux-lora-training", + ), + mlflow_run_name: str = Input( + description="MLflow run name. Only applicable if mlflow_tracking_uri is set.", + default=None, + ), skip_training_and_use_pretrained_hf_lora_url: Optional[str] = Input( description="If you'd like to skip LoRA training altogether and instead create a Replicate model from a pre-trained LoRA that's on HuggingFace, use this field with a HuggingFace download URL. For example, https://huggingface.co/fofr/flux-80s-cyberpunk/resolve/main/lora.safetensors.", default=None, @@ -336,30 +362,46 @@ def train( "only_if_contains": layers_to_optimize } + # Initialize tracking clients wandb_client = None + mlflow_client = None + + # Config dict for both tracking systems + tracking_config = { + "trigger_word": trigger_word, + "autocaption": autocaption, + "autocaption_prefix": autocaption_prefix, + "autocaption_suffix": autocaption_suffix, + "steps": steps, + "learning_rate": learning_rate, + "batch_size": batch_size, + "resolution": resolution, + "lora_rank": lora_rank, + "caption_dropout_rate": caption_dropout_rate, + "optimizer": optimizer, + "gradient_checkpointing": gradient_checkpointing, + "cache_latents_to_disk": cache_latents_to_disk, + } + if wandb_api_key: - wandb_config = { - "trigger_word": trigger_word, - "autocaption": autocaption, - "autocaption_prefix": autocaption_prefix, - "autocaption_suffix": autocaption_suffix, - "steps": steps, - "learning_rate": learning_rate, - "batch_size": batch_size, - "resolution": resolution, - "lora_rank": lora_rank, - "caption_dropout_rate": caption_dropout_rate, - "optimizer": optimizer, - } wandb_client = WeightsAndBiasesClient( api_key=wandb_api_key.get_secret_value(), - config=wandb_config, + config=tracking_config, sample_prompts=sample_prompts, project=wandb_project, entity=wandb_entity, name=wandb_run, ) + if mlflow_tracking_uri: + mlflow_client = MLflowClient( + tracking_uri=mlflow_tracking_uri, + experiment_name=mlflow_experiment_name, + run_name=mlflow_run_name, + config=tracking_config, + sample_prompts=sample_prompts, + ) + download_weights() extract_zip(input_images, INPUT_DIR) @@ -375,12 +417,15 @@ def train( torch.cuda.empty_cache() print("Starting train job") - job = CustomJob(get_config(train_config, name=None), wandb_client) + job = CustomJob(get_config(train_config, name=None), wandb_client, mlflow_client) job.run() if wandb_client: wandb_client.finish() + if mlflow_client: + mlflow_client.finish() + job.cleanup() lora_file = JOB_DIR / f"{JOB_NAME}.safetensors"