Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions hloc/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@

from . import extractors, logger
from .utils.base_model import dynamic_load
from .utils.inference_device import (
dataloader_num_workers_extract,
dataloader_pin_memory,
select_inference_device,
tensor_non_blocking,
)
from .utils.io import list_h5_names, read_image
from .utils.parsers import parse_image_lists

Expand Down Expand Up @@ -276,16 +282,21 @@ def main(
logger.info("Skipping the extraction.")
return feature_path

device = "cuda" if torch.cuda.is_available() else "cpu"
device = select_inference_device()
logger.info("Using device %s for feature extraction", device)
Model = dynamic_load(extractors, conf["model"]["name"])
model = Model(conf["model"]).eval().to(device)

nb = tensor_non_blocking(device)
loader = torch.utils.data.DataLoader(
dataset, num_workers=1, shuffle=False, pin_memory=True
dataset,
num_workers=dataloader_num_workers_extract(device),
shuffle=False,
pin_memory=dataloader_pin_memory(device),
)
for idx, data in enumerate(tqdm(loader)):
name = dataset.names[idx]
pred = model({"image": data["image"].to(device, non_blocking=True)})
pred = model({"image": data["image"].to(device, non_blocking=nb)})
pred = {k: v[0].cpu().numpy() for k, v in pred.items()}

pred["image_size"] = original_size = data["original_size"][0].numpy()
Expand Down
14 changes: 12 additions & 2 deletions hloc/match_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
from .extract_features import read_image, resize_image
from .match_features import find_unique_new_pairs
from .utils.base_model import dynamic_load
from .utils.inference_device import (
dataloader_num_workers_dense,
dataloader_pin_memory,
select_inference_device,
)
from .utils.io import list_h5_names
from .utils.parsers import names_to_pair, parse_retrieval

Expand Down Expand Up @@ -236,13 +241,18 @@ def match_dense(
match_path: Path, # out
existing_refs: Optional[List] = [],
):
device = "cuda" if torch.cuda.is_available() else "cpu"
device = select_inference_device()
logger.info("Using device %s for dense matching", device)
Model = dynamic_load(matchers, conf["model"]["name"])
model = Model(conf["model"]).eval().to(device)

dataset = ImagePairDataset(image_dir, conf["preprocessing"], pairs)
loader = torch.utils.data.DataLoader(
dataset, num_workers=16, batch_size=1, shuffle=False
dataset,
num_workers=dataloader_num_workers_dense(device),
batch_size=1,
shuffle=False,
pin_memory=dataloader_pin_memory(device),
)

logger.info("Performing dense matching...")
Expand Down
18 changes: 15 additions & 3 deletions hloc/match_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@

from . import logger, matchers
from .utils.base_model import dynamic_load
from .utils.inference_device import (
dataloader_num_workers_match,
dataloader_pin_memory,
select_inference_device,
tensor_non_blocking,
)
from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval

"""
Expand Down Expand Up @@ -241,19 +247,25 @@ def match_from_paths(
logger.info("Skipping the matching.")
return

device = "cuda" if torch.cuda.is_available() else "cpu"
device = select_inference_device()
logger.info("Using device %s for feature matching", device)
Model = dynamic_load(matchers, conf["model"]["name"])
model = Model(conf["model"]).eval().to(device)

nb = tensor_non_blocking(device)
dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref)
loader = torch.utils.data.DataLoader(
dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True
dataset,
num_workers=dataloader_num_workers_match(device),
batch_size=1,
shuffle=False,
pin_memory=dataloader_pin_memory(device),
)
writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5)

for idx, data in enumerate(tqdm(loader, smoothing=0.1)):
data = {
k: v if k.startswith("image") else v.to(device, non_blocking=True)
k: v if k.startswith("image") else v.to(device, non_blocking=nb)
for k, v in data.items()
}
pred = model(data)
Expand Down
9 changes: 7 additions & 2 deletions hloc/matchers/lightglue.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from lightglue import LightGlue as LightGlue_

from ..utils.base_model import BaseModel

from ..utils.inference_device import select_inference_device

class LightGlue(BaseModel):
compiled = None
Expand All @@ -24,7 +24,12 @@ def _init(self, conf):
if conf.get("compile_network", False):
if not LightGlue.compiled:
LightGlue.compiled = LightGlue_(conf.pop("features"), **conf)
LightGlue.compiled = LightGlue.compiled.eval().cuda()
LightGlue.compiled = LightGlue.compiled.eval()
_device = select_inference_device()
if _device == "cuda":
LightGlue.compiled = LightGlue.compiled.cuda()
else:
LightGlue.compiled = LightGlue.compiled.to(_device)
LightGlue.compiled.compile()
self.net = LightGlue.compiled
else:
Expand Down
4 changes: 3 additions & 1 deletion hloc/pairs_from_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch

from . import logger
from .utils.inference_device import select_inference_device
from .utils.io import list_h5_names
from .utils.parsers import parse_image_lists
from .utils.read_write_model import read_images_binary
Expand Down Expand Up @@ -108,7 +109,8 @@ def main(
raise ValueError("Could not find any database image.")
query_names = parse_names(query_prefix, query_list, query_names_h5)

device = "cuda" if torch.cuda.is_available() else "cpu"
device = select_inference_device()
logger.info("Using device %s for retrieval pairing", device)
db_desc = get_descriptors(db_names, db_descriptors, name2db)
query_desc = get_descriptors(query_names, descriptors)
sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device))
Expand Down
72 changes: 72 additions & 0 deletions hloc/utils/inference_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Select inference device for hloc (CUDA, MPS on Apple Silicon, or CPU).

Override with env ``HLOC_DEVICE`` = ``cuda`` | ``mps`` | ``cpu`` (case-insensitive).
See discussion in https://github.com/cvg/LightGlue/issues (MPS in user scripts) and
https://github.com/cvg/Hierarchical-Localization/issues/491 (OpenMP / workers on Mac).
"""

from __future__ import annotations

import os
from contextlib import contextmanager

import torch

@contextmanager
def use_hloc_device(forced: str | None):
    """Set the ``HLOC_DEVICE`` environment variable for the duration of a ``with`` body.

    Args:
        forced: Value to store in ``HLOC_DEVICE`` (e.g. ``cpu``, ``mps`` or
            ``cuda``). A falsy value (``None`` or ``""``) turns the context
            manager into a no-op that leaves the environment untouched.

    On exit, including exit through an exception, the variable is put back
    to its previous value, or removed if it was not set beforehand. The
    value is not validated here.
    """
    if forced:
        saved = os.environ.get("HLOC_DEVICE")
        os.environ["HLOC_DEVICE"] = forced
        try:
            yield
        finally:
            # Restore the pre-existing state exactly: either the old value
            # or "not set at all".
            if saved is not None:
                os.environ["HLOC_DEVICE"] = saved
            else:
                os.environ.pop("HLOC_DEVICE", None)
    else:
        # Nothing requested: run the body without touching the environment.
        yield

def select_inference_device() -> str:
    """Return the device name to run inference on: ``"cuda"``, ``"mps"`` or ``"cpu"``.

    A non-blank ``HLOC_DEVICE`` environment variable takes precedence; it is
    stripped and lower-cased before being checked. Without an override, the
    first available backend is picked in the order CUDA, MPS, CPU.

    Raises:
        RuntimeError: ``HLOC_DEVICE`` asks for ``cuda`` or ``mps`` but that
            backend is not available.
        ValueError: ``HLOC_DEVICE`` holds anything other than ``cuda``,
            ``mps`` or ``cpu``.
    """
    requested = os.environ.get("HLOC_DEVICE", "").strip().lower()
    # Looked up defensively: the ``mps`` attribute may be missing from
    # ``torch.backends``, in which case MPS is treated as unavailable.
    mps_backend = getattr(torch.backends, "mps", None)

    # No override: automatic selection.
    if not requested:
        if torch.cuda.is_available():
            return "cuda"
        mps_ready = mps_backend is not None and mps_backend.is_available()
        return "mps" if mps_ready else "cpu"

    # Explicit override: honour it or fail loudly, never fall back silently.
    if requested == "cpu":
        return "cpu"
    if requested == "cuda":
        if torch.cuda.is_available():
            return "cuda"
        raise RuntimeError("HLOC_DEVICE=cuda but torch.cuda.is_available() is False")
    if requested == "mps":
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        raise RuntimeError("HLOC_DEVICE=mps but MPS is not available")
    raise ValueError(f"Invalid HLOC_DEVICE={requested!r} (expected cuda, mps, or cpu)")


def tensor_non_blocking(device: str) -> bool:
    """Return ``True`` only when ``device`` is exactly the string ``"cuda"``.

    Callers pass the result as the ``non_blocking`` flag of ``Tensor.to``.
    """
    on_cuda = device == "cuda"
    return on_cuda


def dataloader_pin_memory(device: str) -> bool:
    """Return ``True`` only when ``device`` is exactly the string ``"cuda"``.

    Callers pass the result as the ``pin_memory`` argument of
    ``torch.utils.data.DataLoader``.
    """
    on_cuda = device == "cuda"
    return on_cuda


def dataloader_num_workers_extract(device: str) -> int:
    """Return the DataLoader worker count for feature extraction.

    ``"mps"`` and ``"cpu"`` get ``0`` workers; every other device string
    gets ``1``.
    """
    # NOTE(review): presumably 0 workers avoids the macOS worker/OpenMP
    # problems referenced in the module docstring; confirm.
    if device in ("mps", "cpu"):
        return 0
    return 1


def dataloader_num_workers_match(device: str) -> int:
    """Return the DataLoader worker count for feature matching.

    ``"mps"`` and ``"cpu"`` get ``0`` workers; every other device string
    gets ``5``.
    """
    # NOTE(review): presumably 0 workers avoids the macOS worker/OpenMP
    # problems referenced in the module docstring; confirm.
    if device in ("mps", "cpu"):
        return 0
    return 5


def dataloader_num_workers_dense(device: str) -> int:
    """Return the DataLoader worker count for dense matching.

    ``"mps"`` and ``"cpu"`` get ``0`` workers; every other device string
    gets ``16``.
    """
    # NOTE(review): presumably 0 workers avoids the macOS worker/OpenMP
    # problems referenced in the module docstring; confirm.
    if device in ("mps", "cpu"):
        return 0
    return 16