Skip to content

Commit 765373e

Browse files
authored
update: download liftfeat models (#131)
1 parent d9d04b5 commit 765373e

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

imcui/hloc/extractors/liftfeat.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
import sys
22
from pathlib import Path
33
from ..utils.base_model import BaseModel
4-
from .. import logger
4+
from .. import logger, MODEL_REPO_ID
55

66
fire_path = Path(__file__).parent / "../../third_party/LiftFeat"
77
sys.path.append(str(fire_path))
88

9-
from models.liftfeat_wrapper import LiftFeat, MODEL_PATH
9+
from models.liftfeat_wrapper import LiftFeat
1010

1111

1212
class Liftfeat(BaseModel):
1313
default_conf = {
1414
"keypoint_threshold": 0.05,
1515
"max_keypoints": 5000,
16+
"model_name": "LiftFeat.pth",
1617
}
1718

1819
required_inputs = ["image"]
1920

2021
def _init(self, conf):
2122
logger.info("Loading LiftFeat model...")
23+
model_path = self._download_model(
24+
repo_id=MODEL_REPO_ID,
25+
filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
26+
)
2227
self.net = LiftFeat(
23-
weight=MODEL_PATH,
28+
weight=model_path,
2429
detect_threshold=self.conf["keypoint_threshold"],
2530
top_k=self.conf["max_keypoints"],
2631
)

0 commit comments

Comments
 (0)