diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index 5d54e34eaa06..fe1a90a6af55 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -687,6 +687,106 @@ def parse_args(input_args=None):
         "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
     )
 
+    # Additional comprehensive parameter validation
+    if args.learning_rate <= 0:
+        raise ValueError("`--learning_rate` must be positive")
+
+    if args.train_batch_size <= 0:
+        raise ValueError("`--train_batch_size` must be positive")
+
+    if args.num_train_epochs <= 0:
+        raise ValueError("`--num_train_epochs` must be positive")
+
+    if args.gradient_accumulation_steps <= 0:
+        raise ValueError("`--gradient_accumulation_steps` must be positive")
+
+    if args.max_train_steps is not None and args.max_train_steps <= 0:
+        raise ValueError("`--max_train_steps` must be positive when specified")
+
+    if args.checkpointing_steps <= 0:
+        raise ValueError("`--checkpointing_steps` must be positive")
+
+    if args.validation_steps <= 0:
+        raise ValueError("`--validation_steps` must be positive")
+
+    if args.num_validation_images <= 0:
+        raise ValueError("`--num_validation_images` must be positive")
+
+    if args.lr_warmup_steps < 0:
+        raise ValueError("`--lr_warmup_steps` must be non-negative")
+
+    if args.lr_num_cycles <= 0:
+        raise ValueError("`--lr_num_cycles` must be positive")
+
+    if args.lr_power <= 0:
+        raise ValueError("`--lr_power` must be positive")
+
+    if args.dataloader_num_workers < 0:
+        raise ValueError("`--dataloader_num_workers` must be non-negative")
+
+    if not (0.0 <= args.adam_beta1 < 1.0):
+        raise ValueError("`--adam_beta1` must be in the range [0.0, 1.0)")
+
+    if not (0.0 <= args.adam_beta2 < 1.0):
+        raise ValueError("`--adam_beta2` must be in the range [0.0, 1.0)")
+
+    if args.adam_weight_decay < 0:
+        raise ValueError("`--adam_weight_decay` must be non-negative")
+
+    if args.adam_epsilon <= 0:
+        raise ValueError("`--adam_epsilon` must be positive")
+
+    if args.max_grad_norm <= 0:
+        raise ValueError("`--max_grad_norm` must be positive")
+
+    if args.max_train_samples is not None and args.max_train_samples <= 0:
+        raise ValueError("`--max_train_samples` must be positive when specified")
+
+    if args.num_double_layers <= 0:
+        raise ValueError("`--num_double_layers` must be positive")
+
+    if args.num_single_layers <= 0:
+        raise ValueError("`--num_single_layers` must be positive")
+
+    if args.guidance_scale < 0:
+        raise ValueError("`--guidance_scale` must be non-negative")
+
+    if args.logit_std <= 0:
+        raise ValueError("`--logit_std` must be positive")
+
+    if args.mode_scale <= 0:
+        raise ValueError("`--mode_scale` must be positive")
+
+    if args.checkpoints_total_limit is not None and args.checkpoints_total_limit <= 0:
+        raise ValueError("`--checkpoints_total_limit` must be positive when specified")
+
+    # Validate resolution is reasonable (not too small or absurdly large)
+    if args.resolution < 64:
+        raise ValueError("`--resolution` must be at least 64 pixels")
+
+    if args.resolution > 4096:
+        raise ValueError("`--resolution` should not exceed 4096 pixels for memory efficiency")
+
+    # Validate crop coordinates are non-negative
+    if args.crops_coords_top_left_h < 0:
+        raise ValueError("`--crops_coords_top_left_h` must be non-negative")
+
+    if args.crops_coords_top_left_w < 0:
+        raise ValueError("`--crops_coords_top_left_w` must be non-negative")
+
+    # Warn about potentially problematic combinations
+    if args.gradient_accumulation_steps > 1 and args.train_batch_size > 32:
+        logger.warning(
+            f"Large batch size ({args.train_batch_size}) with gradient accumulation ({args.gradient_accumulation_steps}) "
+            "may cause memory issues. Consider reducing batch size or gradient accumulation steps."
+        )
+
+    if args.learning_rate > 1e-2:
+        logger.warning(
+            f"Learning rate ({args.learning_rate}) is quite high. This may cause training instability. "
+            "Consider using a lower learning rate (e.g., 1e-4 to 1e-5)."
+        )
+
     return args