|
| 1 | +import sys |
| 2 | +from pathlib import Path |
| 3 | +from ..utils.base_model import BaseModel |
| 4 | +from .. import logger |
| 5 | + |
| 6 | +fire_path = Path(__file__).parent / "../../third_party/LiftFeat" |
| 7 | +sys.path.append(str(fire_path)) |
| 8 | + |
| 9 | +from models.liftfeat_wrapper import LiftFeat, MODEL_PATH |
| 10 | + |
| 11 | + |
| 12 | +class Liftfeat(BaseModel): |
| 13 | + default_conf = { |
| 14 | + "keypoint_threshold": 0.05, |
| 15 | + "max_keypoints": 5000, |
| 16 | + } |
| 17 | + |
| 18 | + required_inputs = ["image"] |
| 19 | + |
| 20 | + def _init(self, conf): |
| 21 | + logger.info("Loading LiftFeat model...") |
| 22 | + self.net = LiftFeat( |
| 23 | + weight=MODEL_PATH, |
| 24 | + detect_threshold=self.conf["keypoint_threshold"], |
| 25 | + top_k=self.conf["max_keypoints"], |
| 26 | + ) |
| 27 | + logger.info("Loading LiftFeat model done!") |
| 28 | + |
| 29 | + def _forward(self, data): |
| 30 | + image = data["image"].cpu().numpy().squeeze() * 255 |
| 31 | + image = image.transpose(1, 2, 0) |
| 32 | + pred = self.net.extract(image) |
| 33 | + |
| 34 | + keypoints = pred["keypoints"] |
| 35 | + descriptors = pred["descriptors"] |
| 36 | + scores = pred["scores"] |
| 37 | + if self.conf["max_keypoints"] < len(keypoints): |
| 38 | + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] |
| 39 | + keypoints = keypoints[idxs, :2] |
| 40 | + descriptors = descriptors[idxs] |
| 41 | + scores = scores[idxs] |
| 42 | + |
| 43 | + pred = { |
| 44 | + "keypoints": keypoints[None], |
| 45 | + "descriptors": descriptors[None].permute(0, 2, 1), |
| 46 | + "scores": scores[None], |
| 47 | + } |
| 48 | + return pred |
0 commit comments