diff --git a/landmarkdiff/inference.py b/landmarkdiff/inference.py index e7f827d..4c8cb82 100644 --- a/landmarkdiff/inference.py +++ b/landmarkdiff/inference.py @@ -361,6 +361,14 @@ def generate( else: composited = mask_composite(raw_output, image_512, mask) + # confidence scoring and breakdown + confidence_data = self._calculate_confidence( + face=face, + identity_check=identity_check, + mask=mask, + mode=self.mode, + ) + return { "output": composited, "output_raw": raw_output, @@ -378,6 +386,56 @@ def generate( "ip_adapter_active": self._ip_adapter_loaded, "identity_check": identity_check, "restore_used": restore_used, + "confidence": confidence_data["confidence"], + "confidence_breakdown": confidence_data["breakdown"], + } + + def _calculate_confidence( + self, + face: Optional[FaceLandmarks], + identity_check: Optional[dict], + mask: np.ndarray, + mode: str, + ) -> dict: + """Aggregate multiple metrics into a unified confidence score (0-1).""" + breakdown = { + "face_detection": 1.0 if face is not None else 0.0, + "identity_preservation": 1.0, + "landmark_accuracy": 0.95, # baseline for MediaPipe/TPS + "mask_coverage": 1.0, + } + + # identity score from ArcFace (if available) + if identity_check and identity_check.get("similarity", -1) >= 0: + breakdown["identity_preservation"] = float(identity_check["similarity"]) + elif mode == "tps": + breakdown["identity_preservation"] = 1.0 # TPS is identity-perfect + else: + # without verification, we assume a lower bound for diffusion + breakdown["identity_preservation"] = 0.85 if mode == "controlnet_ip" else 0.7 + + # mask coverage (ensure mask isn't empty or oversized) + mask_area = np.mean(mask) + if mask_area < 0.01: # too small + breakdown["mask_coverage"] = 0.5 + elif mask_area > 0.6: # too large (covers most of face) + breakdown["mask_coverage"] = 0.8 + else: + breakdown["mask_coverage"] = 1.0 + + # overall confidence (weighted average) + weights = { + "face_detection": 0.2, + "identity_preservation": 0.5, + "landmark_accuracy": 0.2, + "mask_coverage": 0.1, + } + + confidence = sum(breakdown[k] * weights[k] for k in weights) + + return { + "confidence": round(confidence, 2), + "breakdown": {k: round(v, 2) for k, v in breakdown.items()}, } def _generate_controlnet( @@ -471,6 +529,7 @@ def run_inference( seed: int = 42, mode: str = "img2img", ip_adapter_scale: float = 0.6, + explain: bool = False, ) -> None: out = Path(output_dir) out.mkdir(parents=True, exist_ok=True) @@ -500,6 +559,13 @@ def run_inference( if view.get("warning"): print(f"WARNING: {view['warning']}") print(f"Face view: {view.get('view', 'unknown')} (yaw={view.get('yaw', 0)})") + + print(f"Confidence: {result['confidence']:.2f}") + if explain: + print("\nConfidence Breakdown:") + for k, v in result["confidence_breakdown"].items(): + print(f" - {k.replace('_', ' ').capitalize()}: {v:.2f}") + print(f"Results saved to {out}/") @@ -517,9 +583,10 @@ def run_inference( choices=["img2img", "controlnet", "controlnet_ip", "tps"], ) parser.add_argument("--ip-adapter-scale", type=float, default=0.6) + parser.add_argument("--explain", action="store_true", help="Show confidence breakdown") args = parser.parse_args() run_inference( args.image, args.procedure, args.intensity, args.output, - args.seed, args.mode, args.ip_adapter_scale, + args.seed, args.mode, args.ip_adapter_scale, args.explain, )