|
| 1 | +import sys |
| 2 | +import yaml |
| 3 | +from pathlib import Path |
| 4 | +from ..utils.base_model import BaseModel |
| 5 | +from .. import logger, MODEL_REPO_ID, DEVICE |
| 6 | + |
| 7 | +rdd_path = Path(__file__).parent / "../../third_party/rdd" |
| 8 | +sys.path.append(str(rdd_path)) |
| 9 | + |
| 10 | +from RDD.RDD import build as build_rdd |
| 11 | + |
| 12 | + |
| 13 | +class Rdd(BaseModel): |
| 14 | + default_conf = { |
| 15 | + "keypoint_threshold": 0.1, |
| 16 | + "max_keypoints": 4096, |
| 17 | + "model_name": "RDD-v2.pth", |
| 18 | + } |
| 19 | + |
| 20 | + required_inputs = ["image"] |
| 21 | + |
| 22 | + def _init(self, conf): |
| 23 | + logger.info("Loading RDD model...") |
| 24 | + model_path = self._download_model( |
| 25 | + repo_id=MODEL_REPO_ID, |
| 26 | + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), |
| 27 | + ) |
| 28 | + config_path = rdd_path / "configs/default.yaml" |
| 29 | + with open(config_path, "r") as file: |
| 30 | + config = yaml.safe_load(file) |
| 31 | + config["top_k"] = conf["max_keypoints"] |
| 32 | + config["detection_threshold"] = conf["keypoint_threshold"] |
| 33 | + config["device"] = DEVICE |
| 34 | + self.net = build_rdd(config=config, weights=model_path) |
| 35 | + self.net.eval() |
| 36 | + logger.info("Loading RDD model done!") |
| 37 | + |
| 38 | + def _forward(self, data): |
| 39 | + image = data["image"] |
| 40 | + pred = self.net.extract(image)[0] |
| 41 | + keypoints = pred["keypoints"] |
| 42 | + descriptors = pred["descriptors"] |
| 43 | + scores = pred["scores"] |
| 44 | + if self.conf["max_keypoints"] < len(keypoints): |
| 45 | + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] |
| 46 | + keypoints = keypoints[idxs, :2] |
| 47 | + descriptors = descriptors[idxs] |
| 48 | + scores = scores[idxs] |
| 49 | + |
| 50 | + pred = { |
| 51 | + "keypoints": keypoints[None], |
| 52 | + "descriptors": descriptors[None].permute(0, 2, 1), |
| 53 | + "scores": scores[None], |
| 54 | + } |
| 55 | + return pred |
0 commit comments