Skip to content

Commit a38453a

Browse files
committed
onnx check
1 parent eee6fb1 commit a38453a

1 file changed

Lines changed: 26 additions & 15 deletions

File tree

src/model_manager.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -214,26 +214,37 @@ def load_face_detection_model(self):
214214
if not os.path.exists(local_path):
215215
use_onnx = False
216216

217-
if not use_onnx:
218-
local_path = os.path.join(self.models_dir, model_name_pt)
219-
if not os.path.exists(local_path):
220-
self._download_model('face_detect')
217+
# Try loading ONNX with a smoke test
218+
if use_onnx:
219+
local_path_onnx = os.path.join(self.models_dir, model_name_onnx)
220+
try:
221+
model = YOLO(local_path_onnx, task='detect')
222+
# Smoke test: run a tiny dummy inference to catch runtime errors early
223+
dummy = np.zeros((64, 64, 3), dtype=np.uint8)
224+
device_str = "0" if is_cuda else "cpu"
225+
model.predict(dummy, conf=0.9, verbose=False, device=device_str)
226+
self._models[model_name_onnx] = model
227+
print("ForbiddenVision: Loaded face detection model [ONNX]")
228+
return model
229+
except Exception as e:
230+
print(f"ForbiddenVision: ONNX runtime failed smoke test ({e}), falling back to PyTorch")
231+
self._models.pop(model_name_onnx, None)
232+
use_onnx = False
221233

222-
if not os.path.exists(local_path):
234+
# PyTorch fallback
235+
local_path_pt = os.path.join(self.models_dir, model_name_pt)
236+
if not os.path.exists(local_path_pt):
237+
self._download_model('face_detect')
238+
239+
if not os.path.exists(local_path_pt):
223240
return None
224241

225242
try:
226-
model = YOLO(local_path, task='detect')
227-
228-
if not use_onnx:
229-
model.to(device)
230-
231-
cache_key = model_name_onnx if use_onnx else model_name_pt
232-
self._models[cache_key] = model
233-
fmt = "ONNX" if use_onnx else "PyTorch"
234-
print(f"ForbiddenVision: Loaded face detection model [{fmt}]")
243+
model = YOLO(local_path_pt, task='detect')
244+
model.to(device)
245+
self._models[model_name_pt] = model
246+
print("ForbiddenVision: Loaded face detection model [PyTorch]")
235247
return model
236-
237248
except Exception as e:
238249
print(f"ForbiddenVision: Error loading face detection model: {e}")
239250
return None

0 commit comments

Comments
 (0)