diff --git a/README.md b/README.md
index dc79cdb..8052f25 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,7 @@ cd models
 
 # SuperPoint.
 git clone https://github.com/rpautrat/SuperPoint.git
+mv SuperPoint/weights/superpoint_v6_from_tf.pth .
 mv SuperPoint/pretrained_models/sp_v6.tgz . && rm -rf SuperPoint
 tar zxvf sp_v6.tgz && rm sp_v6.tgz
 
diff --git a/src/omniglue/dino_extract.py b/src/omniglue/dino_extract.py
index 57f659a..8c46fb3 100644
--- a/src/omniglue/dino_extract.py
+++ b/src/omniglue/dino_extract.py
@@ -21,12 +21,11 @@
 import tensorflow as tf
 import torch
 
-
 class DINOExtract:
   """Class to initialize DINO model and extract features from an image."""
 
-  def __init__(self, cpt_path: str, feature_layer: int = 1):
-    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+  def __init__(self, cpt_path: str, feature_layer: int = 1, device=None):
+    self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
     self.feature_layer = feature_layer
     self.model = dino.vit_base()
     state_dict_raw = torch.load(cpt_path, map_location='cpu')
diff --git a/src/omniglue/omniglue_extract.py b/src/omniglue/omniglue_extract.py
index e4b901e..0dc0d2f 100644
--- a/src/omniglue/omniglue_extract.py
+++ b/src/omniglue/omniglue_extract.py
@@ -18,142 +18,139 @@
 from omniglue import dino_extract
 from omniglue import superpoint_extract
 from omniglue import utils
+from omniglue.utils import get_device_framework
 import tensorflow as tf
 
 DINO_FEATURE_DIM = 768
 MATCH_THRESHOLD = 1e-3
 
-
 class OmniGlue:
-  # TODO(omniglue): class docstring
+  """Wrapper around the OmniGlue matcher and its SuperPoint/DINO extractors."""
+
+  def __init__(self, og_export: str, sp_export: str | None = None, dino_export: str | None = None) -> None:
+    """Initialize OmniGlue with specified exports."""
+
+    tf_device, torch_device = get_device_framework()
+    if tf_device != "GPU":
+      with tf.device(tf_device):
+        self.matcher = tf.saved_model.load(og_export)
+    else:
+      self.matcher = tf.saved_model.load(og_export)
 
-  def __init__(
-      self,
-      og_export: str,
-      sp_export: str | None = None,
-      dino_export: str | None = None,
-  ) -> None:
-    self.matcher = tf.saved_model.load(og_export)
-    if sp_export is not None:
-      self.sp_extract = superpoint_extract.SuperPointExtract(sp_export)
-    if dino_export is not None:
-      self.dino_extract = dino_extract.DINOExtract(dino_export, feature_layer=1)
+    if sp_export is not None:
+      if sp_export.endswith((".pth", ".pt")):
+        self.sp_extract = superpoint_extract.SuperPointExtract_Pytorch(sp_export, torch_device)
+      else:
+        self.sp_extract = superpoint_extract.SuperPointExtract(sp_export, tf_device)
+
+    if dino_export is not None:
+      self.dino_extract = dino_extract.DINOExtract(dino_export, feature_layer=1, device=torch_device)
 
-  def FindMatches(self, image0: np.ndarray, image1: np.ndarray):
-    """TODO(omniglue): docstring."""
-    height0, width0 = image0.shape[:2]
-    height1, width1 = image1.shape[:2]
-    sp_features0 = self.sp_extract(image0)
-    sp_features1 = self.sp_extract(image1)
-    dino_features0 = self.dino_extract(image0)
-    dino_features1 = self.dino_extract(image1)
-    dino_descriptors0 = dino_extract.get_dino_descriptors(
-        dino_features0,
-        tf.convert_to_tensor(sp_features0[0], dtype=tf.float32),
-        tf.convert_to_tensor(height0, dtype=tf.int32),
-        tf.convert_to_tensor(width0, dtype=tf.int32),
-        DINO_FEATURE_DIM,
-    )
-    dino_descriptors1 = dino_extract.get_dino_descriptors(
-        dino_features1,
-        tf.convert_to_tensor(sp_features1[0], dtype=tf.float32),
-        tf.convert_to_tensor(height1, dtype=tf.int32),
-        tf.convert_to_tensor(width1, dtype=tf.int32),
-        DINO_FEATURE_DIM,
-    )
+  def FindMatches(self, image0: np.ndarray, image1: np.ndarray):
+    """Find matches between two images using SP and DINO features."""
+    height0, width0 = image0.shape[:2]
+    height1, width1 = image1.shape[:2]
+
+    sp_features0 = self.sp_extract(image0)
+    sp_features1 = self.sp_extract(image1)
+
+    dino_features0 = self.dino_extract(image0)
+    dino_features1 = self.dino_extract(image1)
+
+    dino_descriptors0 = dino_extract.get_dino_descriptors(
+        dino_features0,
+        tf.convert_to_tensor(sp_features0[0], dtype=tf.float32),
+        tf.convert_to_tensor(height0, dtype=tf.int32),
+        tf.convert_to_tensor(width0, dtype=tf.int32),
+        DINO_FEATURE_DIM,
+    )
+
+    dino_descriptors1 = dino_extract.get_dino_descriptors(
+        dino_features1,
+        tf.convert_to_tensor(sp_features1[0], dtype=tf.float32),
+        tf.convert_to_tensor(height1, dtype=tf.int32),
+        tf.convert_to_tensor(width1, dtype=tf.int32),
+        DINO_FEATURE_DIM,
+    )
 
-    inputs = self._construct_inputs(
-        width0,
-        height0,
-        width1,
-        height1,
-        sp_features0,
-        sp_features1,
-        dino_descriptors0,
-        dino_descriptors1,
-    )
+    inputs = self._construct_inputs(
+        width0, height0, width1, height1,
+        sp_features0, sp_features1,
+        dino_descriptors0, dino_descriptors1
+    )
 
-    og_outputs = self.matcher.signatures['serving_default'](**inputs)
-    soft_assignment = og_outputs['soft_assignment'][:, :-1, :-1]
+    og_outputs = self.matcher.signatures['serving_default'](**inputs)
+    soft_assignment = og_outputs['soft_assignment'][:, :-1, :-1]
 
-    match_matrix = (
-        utils.soft_assignment_to_match_matrix(soft_assignment, MATCH_THRESHOLD)
-        .numpy()
-        .squeeze()
-    )
+    match_matrix = (
+        utils.soft_assignment_to_match_matrix(soft_assignment, MATCH_THRESHOLD)
+        .numpy()
+        .squeeze()
+    )
 
-    # Filter out any matches with 0.0 confidence keypoints.
-    match_indices = np.argwhere(match_matrix)
-    keep = []
-    for i in range(match_indices.shape[0]):
-      match = match_indices[i, :]
-      if (sp_features0[2][match[0]] > 0.0) and (
-          sp_features1[2][match[1]] > 0.0
-      ):
-        keep.append(i)
-    match_indices = match_indices[keep]
+    # Filter out any matches with 0.0 confidence keypoints.
+    match_indices = np.argwhere(match_matrix)
+    keep = []
+    for i in range(match_indices.shape[0]):
+      match = match_indices[i, :]
+      if (sp_features0[2][match[0]] > 0.0) and (sp_features1[2][match[1]] > 0.0):
+        keep.append(i)
+    match_indices = match_indices[keep]
 
-    # Format matches in terms of keypoint locations.
-    match_kp0s = []
-    match_kp1s = []
-    match_confidences = []
-    for match in match_indices:
-      match_kp0s.append(sp_features0[0][match[0], :])
-      match_kp1s.append(sp_features1[0][match[1], :])
-      match_confidences.append(soft_assignment[0, match[0], match[1]])
-    match_kp0s = np.array(match_kp0s)
-    match_kp1s = np.array(match_kp1s)
-    match_confidences = np.array(match_confidences)
-    return match_kp0s, match_kp1s, match_confidences
+    # Format matches in terms of keypoint locations.
+ match_kp0s = [] + match_kp1s = [] + match_confidences = [] + for match in match_indices: + match_kp0s.append(sp_features0[0][match[0], :]) + match_kp1s.append(sp_features1[0][match[1], :]) + match_confidences.append(soft_assignment[0, match[0], match[1]]) + + match_kp0s = np.array(match_kp0s) + match_kp1s = np.array(match_kp1s) + match_confidences = np.array(match_confidences) + + return match_kp0s, match_kp1s, match_confidences - ### Private methods ### + ### Private methods ### - def _construct_inputs( - self, - width0, - height0, - width1, - height1, - sp_features0, - sp_features1, - dino_descriptors0, - dino_descriptors1, - ): - inputs = { - 'keypoints0': tf.convert_to_tensor( - np.expand_dims(sp_features0[0], axis=0), - dtype=tf.float32, - ), - 'keypoints1': tf.convert_to_tensor( - np.expand_dims(sp_features1[0], axis=0), dtype=tf.float32 - ), - 'descriptors0': tf.convert_to_tensor( - np.expand_dims(sp_features0[1], axis=0), dtype=tf.float32 - ), - 'descriptors1': tf.convert_to_tensor( - np.expand_dims(sp_features1[1], axis=0), dtype=tf.float32 - ), - 'scores0': tf.convert_to_tensor( - np.expand_dims(np.expand_dims(sp_features0[2], axis=0), axis=-1), - dtype=tf.float32, - ), - 'scores1': tf.convert_to_tensor( - np.expand_dims(np.expand_dims(sp_features1[2], axis=0), axis=-1), - dtype=tf.float32, - ), - 'descriptors0_dino': tf.expand_dims(dino_descriptors0, axis=0), - 'descriptors1_dino': tf.expand_dims(dino_descriptors1, axis=0), - 'width0': tf.convert_to_tensor( - np.expand_dims(width0, axis=0), dtype=tf.int32 - ), - 'width1': tf.convert_to_tensor( - np.expand_dims(width1, axis=0), dtype=tf.int32 - ), - 'height0': tf.convert_to_tensor( - np.expand_dims(height0, axis=0), dtype=tf.int32 - ), - 'height1': tf.convert_to_tensor( - np.expand_dims(height1, axis=0), dtype=tf.int32 - ), - } - return inputs + def _construct_inputs(self, width0, height0, width1, height1, sp_features0, sp_features1, dino_descriptors0, dino_descriptors1): + """Construct input dictionary for the model.""" + inputs = { + 'keypoints0': tf.convert_to_tensor( + np.expand_dims(sp_features0[0], axis=0), + dtype=tf.float32, + ), + 'keypoints1': tf.convert_to_tensor( + np.expand_dims(sp_features1[0], axis=0), dtype=tf.float32 + ), + 'descriptors0': tf.convert_to_tensor( + np.expand_dims(sp_features0[1], axis=0), dtype=tf.float32 + ), + 'descriptors1': tf.convert_to_tensor( + np.expand_dims(sp_features1[1], axis=0), dtype=tf.float32 + ), + 'scores0': tf.convert_to_tensor( + np.expand_dims(np.expand_dims(sp_features0[2], axis=0), axis=-1), + dtype=tf.float32, + ), + 'scores1': tf.convert_to_tensor( + np.expand_dims(np.expand_dims(sp_features1[2], axis=0), axis=-1), + dtype=tf.float32, + ), + 'descriptors0_dino': tf.expand_dims(dino_descriptors0, axis=0), + 'descriptors1_dino': tf.expand_dims(dino_descriptors1, axis=0), + 'width0': tf.convert_to_tensor( + np.expand_dims(width0, axis=0), dtype=tf.int32 + ), + 'width1': tf.convert_to_tensor( + np.expand_dims(width1, axis=0), dtype=tf.int32 + ), + 'height0': tf.convert_to_tensor( + np.expand_dims(height0, axis=0), dtype=tf.int32 + ), + 'height1': tf.convert_to_tensor( + np.expand_dims(height1, axis=0), dtype=tf.int32 + ), + } + return inputs diff --git a/src/omniglue/superpoint_extract.py b/src/omniglue/superpoint_extract.py index 547a9fb..da49ff8 100644 --- a/src/omniglue/superpoint_extract.py +++ b/src/omniglue/superpoint_extract.py @@ -19,9 +19,11 @@ import cv2 import numpy as np -from omniglue import utils +import tensorflow as tf import tensorflow.compat.v1 as tf1 - 
+import torch
+from omniglue import utils
+from omniglue import superpoint_pytorch
 
 class SuperPointExtract:
   """Class to initialize SuperPoint model and extract features from an image.
@@ -33,13 +35,21 @@ class SuperPointExtract:
     model_path: string, filepath to saved SuperPoint TF1 model weights.
   """
 
-  def __init__(self, model_path: str):
+  def __init__(self, model_path: str, device):
     self.model_path = model_path
-    self._graph = tf1.Graph()
-    self._sess = tf1.Session(graph=self._graph)
-    tf1.saved_model.loader.load(
-        self._sess, [tf1.saved_model.tag_constants.SERVING], model_path
-    )
+    self._graph = tf1.Graph()
+    self._sess = tf1.Session(graph=self._graph)
+    # Pin the load to an explicit device (e.g. "/CPU:0") when one is given;
+    # "GPU" keeps TensorFlow's default placement.
+    if device != "GPU":
+      with tf.device(device):
+        tf1.saved_model.loader.load(
+            self._sess, [tf1.saved_model.tag_constants.SERVING], model_path
+        )
+    else:
+      tf1.saved_model.loader.load(
+          self._sess, [tf1.saved_model.tag_constants.SERVING], model_path
+      )
 
   def __call__(
       self,
@@ -212,3 +222,63 @@
     descriptors.append(utils.lookup_descriptor_bilinear(kp, descriptor_map))
   descriptors = np.array(descriptors)
   return keypoints, descriptors, scores
+
+
+class SuperPointExtract_Pytorch:
+  """Class to initialize SuperPoint model and extract features from an image.
+
+  To stay consistent with SuperPoint training and eval configurations, resize
+  images to (320x240) or (640x480).
+  """
+
+  def __init__(self, model_path: str, device):
+    """Initialize the SuperPoint model."""
+    self.device = device
+    self.model_path = model_path
+    self.model = superpoint_pytorch.SuperPoint()
+    self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
+    self.model = self.model.to(device=self.device)
+    self.model.eval()
+
+  def __call__(self, image):
+    """Extract features from the image."""
+    return self._extract_features(image)
+
+  def _preprocess_image(self, image):
+    """Convert image to grayscale and normalize values for model input."""
+    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+    image = np.expand_dims(image, 2)
+    image = image.astype(np.float32)
+    image = image / 255.0
+    return image
+
+  def _resize_input_image(self, image, interpolation=cv2.INTER_LINEAR):
+    """Resize image such that both dimensions are divisible by 8."""
+    new_dim = [-1, -1]
+    keypoint_scale_factors = [1.0, 1.0]
+    for i in range(2):
+      dim_size = image.shape[i]
+      mod_eight = dim_size % 8
+      if mod_eight < 4:
+        new_dim[i] = dim_size - mod_eight  # Round down to nearest multiple of 8.
+      else:
+        new_dim[i] = dim_size + (8 - mod_eight)  # Round up to nearest multiple of 8.
+      keypoint_scale_factors[i] = (new_dim[i] - 1) / (dim_size - 1)
+
+    new_dim = new_dim[::-1]  # Convert from (row, col) to (x,y).
+    keypoint_scale_factors = keypoint_scale_factors[::-1]
+    image = cv2.resize(image, tuple(new_dim), interpolation=interpolation)
+    return image, keypoint_scale_factors
+
+  def _extract_features(self, image):
+    """Extract keypoints, descriptors, and scores from the image."""
+    image, keypoint_scale_factors = self._resize_input_image(image)
+    image_preprocessed = self._preprocess_image(image)
+    image_tensor = torch.from_numpy(image_preprocessed.transpose(2, 0, 1)[None]).float().to(self.device)  # [1, C, H, W]
+
+    with torch.no_grad():
+      pred = self.model({'image': image_tensor})
+
+    keypoints, descriptors, scores = [p[0].cpu().numpy().astype("float64") for p in pred]
+    keypoints = keypoints / keypoint_scale_factors
+    return keypoints, descriptors, scores
diff --git a/src/omniglue/superpoint_pytorch.py b/src/omniglue/superpoint_pytorch.py
new file mode 100644
index 0000000..aa958eb
--- /dev/null
+++ b/src/omniglue/superpoint_pytorch.py
@@ -0,0 +1,173 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Reference: https://github.com/rpautrat/SuperPoint/blob/master/superpoint_pytorch.py

+"""PyTorch implementation of the SuperPoint model."""
+
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+from types import SimpleNamespace
+
+def sample_descriptors(keypoints, descriptors, s: int = 8):
+    """Interpolate descriptors at keypoint locations."""
+    b, c, h, w = descriptors.shape
+    keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s)
+    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
+    descriptors = torch.nn.functional.grid_sample(
+        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
+    )
+    descriptors = torch.nn.functional.normalize(
+        descriptors.reshape(b, c, -1), p=2, dim=1
+    )
+    return descriptors
+
+def batched_nms(scores, nms_radius: int):
+    """Perform batched non-maximum suppression."""
+    assert nms_radius >= 0
+
+    def max_pool(x):
+        return torch.nn.functional.max_pool2d(
+            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+        )
+
+    zeros = torch.zeros_like(scores)
+    max_mask = scores == max_pool(scores)
+    for _ in range(2):
+        supp_mask = max_pool(max_mask.float()) > 0
+        supp_scores = torch.where(supp_mask, zeros, scores)
+        new_max_mask = supp_scores == max_pool(supp_scores)
+        max_mask = max_mask | (new_max_mask & (~supp_mask))
+    return torch.where(max_mask, scores, zeros)
+
+def select_top_k_keypoints(keypoints, scores, k):
+    """Select top k keypoints based on scores."""
+    if k >= len(keypoints):
+        return keypoints, scores
+    scores, indices = torch.topk(scores, k, dim=0, sorted=True)
+    return keypoints[indices], scores
+
+class VGGBlock(nn.Sequential):
+    """VGG block used in the SuperPoint model."""
+    def __init__(self, c_in, c_out, kernel_size, relu=True):
+        padding = (kernel_size - 1) // 2
+        conv = nn.Conv2d(
+            c_in, c_out, kernel_size=kernel_size, stride=1, padding=padding
+        )
+        activation = nn.ReLU(inplace=True) if relu else nn.Identity()
+        bn = nn.BatchNorm2d(c_out, eps=0.001)
+        super().__init__(
+            OrderedDict(
+                [
+                    ("conv", conv),
+                    ("activation", activation),
+                    ("bn", bn),
+                ]
+            )
+        )
+
+class SuperPoint(nn.Module):
+    """SuperPoint model definition."""
+    default_conf = {
+        "nms_radius": 4,
+        "max_num_keypoints": 1024,
+        "detection_threshold": 0.005,
+        "remove_borders": 4,
+        "descriptor_dim": 256,
+        "channels": [64, 64, 128, 128, 256],
+    }
+
+    def __init__(self, **conf):
+        super().__init__()
+        conf = {**self.default_conf, **conf}
+        self.conf = SimpleNamespace(**conf)
+        self.stride = 2 ** (len(self.conf.channels) - 2)
+        channels = [1, *self.conf.channels[:-1]]
+
+        backbone = []
+        for i, c in enumerate(channels[1:], 1):
+            layers = [VGGBlock(channels[i - 1], c, 3), VGGBlock(c, c, 3)]
+            if i < len(channels) - 1:
+                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
+            backbone.append(nn.Sequential(*layers))
+        self.backbone = nn.Sequential(*backbone)
+
+        c = self.conf.channels[-1]
+        self.detector = nn.Sequential(
+            VGGBlock(channels[-1], c, 3),
+            VGGBlock(c, self.stride**2 + 1, 1, relu=False),
+        )
+        self.descriptor = nn.Sequential(
+            VGGBlock(channels[-1], c, 3),
+            VGGBlock(c, self.conf.descriptor_dim, 1, relu=False),
+        )
+
+    def forward(self, data):
+        """Forward pass for the SuperPoint model."""
+        image = data["image"]
+
+        features = self.backbone(image)
+        descriptors_dense = torch.nn.functional.normalize(
+            self.descriptor(features), p=2, dim=1
+        )
+
+        # Decode the detection scores
+        scores = self.detector(features)
+        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
+        b, _, h, w = scores.shape
+        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, self.stride, self.stride)
+        scores = scores.permute(0, 1, 3, 2, 4).reshape(
+            b, h * self.stride, w * self.stride
+        )
+        scores = batched_nms(scores, self.conf.nms_radius)
+
+        # Discard keypoints near the image borders
+        if self.conf.remove_borders:
+            pad = self.conf.remove_borders
+            scores[:, :pad] = -1
+            scores[:, :, :pad] = -1
+            scores[:, -pad:] = -1
+            scores[:, :, -pad:] = -1
+
+        # Extract keypoints
+        if b > 1:
+            idxs = torch.where(scores > self.conf.detection_threshold)
+            mask = idxs[0] == torch.arange(b, device=scores.device)[:, None]
+        else:  # Faster shortcut
+            scores = scores.squeeze(0)
+            idxs = torch.where(scores > self.conf.detection_threshold)
+
+        # Convert (i, j) to (x, y)
+        keypoints_all = torch.stack(idxs[-2:], dim=-1).flip(1).float()
+        scores_all = scores[idxs]
+
+        keypoints = []
+        scores = []
+        descriptors = []
+        for i in range(b):
+            if b > 1:
+                k = keypoints_all[mask[i]]
+                s = scores_all[mask[i]]
+            else:
+                k = keypoints_all
+                s = scores_all
+            if self.conf.max_num_keypoints is not None:
+                k, s = select_top_k_keypoints(k, s, self.conf.max_num_keypoints)
+            d = sample_descriptors(k[None], descriptors_dense[i, None], self.stride)
+            keypoints.append(k)
+            scores.append(s)
+            descriptors.append(d.squeeze(0).transpose(0, 1))
+
+        return keypoints, descriptors, scores
diff --git a/src/omniglue/utils.py b/src/omniglue/utils.py
index fcec00a..683203c 100644
--- a/src/omniglue/utils.py
+++ b/src/omniglue/utils.py
@@ -19,7 +19,7 @@
 import cv2
 import numpy as np
 import tensorflow as tf
-
+import torch
 
 def lookup_descriptor_bilinear(
     keypoint: np.ndarray, descriptor_map: np.ndarray
@@ -272,3 +272,26 @@
       cv2.LINE_AA,
     )
   return viz
+
+def get_device_framework():
+  """Check if GPU is available for both TensorFlow and PyTorch, and return the appropriate device for each."""
+  # TensorFlow device check.
+  physical_devices = tf.config.list_physical_devices('GPU')
+  if physical_devices:
+    for gpu in physical_devices:
+      try:
+        tf.config.experimental.set_memory_growth(gpu, True)
+        print(f"Enabled memory growth for {gpu}")
+      except RuntimeError as e:
+        print(e)
+    tf_device = "GPU"
+  else:
+    tf_device = "/CPU:0"
+
+  # PyTorch device check.
+  if torch.cuda.is_available():
+    torch_device = torch.device("cuda:0")  # Use the first visible GPU.
+  else:
+    torch_device = torch.device("cpu")
+
+  return tf_device, torch_device
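
Usage note (not part of the patch): with these changes, `OmniGlue.__init__` resolves devices once via `get_device_framework()` and dispatches on the SuperPoint export path's suffix, so `.pth`/`.pt` weights go through the PyTorch implementation while a TF1 SavedModel directory keeps the original code path. Below is a minimal sketch of the resulting API, assuming the `models/` layout from the README; the input arrays are stand-ins for real RGB images.

```python
import numpy as np
import omniglue  # assumes the installed package exposes OmniGlue at top level

# Paths follow the README's models/ directory; adjust to your layout.
og = omniglue.OmniGlue(
    og_export="./models/og_export",
    sp_export="./models/superpoint_v6_from_tf.pth",  # .pth suffix -> SuperPointExtract_Pytorch
    dino_export="./models/dinov2_vitb14_pretrain.pth",
)

# Stand-in images; replace with real photos loaded via e.g. cv2 or PIL.
image0 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
image1 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

match_kp0, match_kp1, confidences = og.FindMatches(image0, image1)
print(match_kp0.shape, match_kp1.shape, confidences.shape)  # (N, 2), (N, 2), (N,)
```

Passing `sp_export="./models/sp_v6"` (the extracted SavedModel directory) instead would select the TF1 `SuperPointExtract` class, placed on whichever device `get_device_framework()` reports for TensorFlow.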