Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fixes for excessive inference memory usage and a PyTorch version of SuperPoint #23

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ cd models

# SuperPoint.
git clone https://github.com/rpautrat/SuperPoint.git
mv SuperPoint/weights/superpoint_v6_from_tf.pth .
mv SuperPoint/pretrained_models/sp_v6.tgz . && rm -rf SuperPoint
tar zxvf sp_v6.tgz && rm sp_v6.tgz

Expand Down
5 changes: 2 additions & 3 deletions src/omniglue/dino_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
import tensorflow as tf
import torch


class DINOExtract:
"""Class to initialize DINO model and extract features from an image."""

def __init__(self, cpt_path: str, feature_layer: int = 1):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def __init__(self, cpt_path: str, feature_layer: int = 1, device = torch.device("cuda")):
self.device = device
self.feature_layer = feature_layer
self.model = dino.vit_base()
state_dict_raw = torch.load(cpt_path, map_location='cpu')
Expand Down
245 changes: 121 additions & 124 deletions src/omniglue/omniglue_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,142 +18,139 @@
from omniglue import dino_extract
from omniglue import superpoint_extract
from omniglue import utils
from omniglue.utils import get_device_framework
import tensorflow as tf

DINO_FEATURE_DIM = 768
MATCH_THRESHOLD = 1e-3


class OmniGlue:
# TODO(omniglue): class docstring
# TODO: Add class docstring

def __init__(self, og_export: str, sp_export: str | None = None, dino_export: str | None = None) -> None:
"""Initialize OmniGlue with specified exports."""

tf_device, torch_device = get_device_framework()
if tf_device != "GPU":
with tf.device(tf_device):
self.matcher = tf.saved_model.load(og_export)

self.matcher = tf.saved_model.load(og_export)

def __init__(
self,
og_export: str,
sp_export: str | None = None,
dino_export: str | None = None,
) -> None:
self.matcher = tf.saved_model.load(og_export)
if sp_export is not None:
self.sp_extract = superpoint_extract.SuperPointExtract(sp_export)
if dino_export is not None:
self.dino_extract = dino_extract.DINOExtract(dino_export, feature_layer=1)
if sp_export is not None:
if sp_export.endswith((".pth", "pt")):
self.sp_extract = superpoint_extract.SuperPointExtract_Pytorch(sp_export, torch_device)
else:
self.sp_extract = superpoint_extract.SuperPointExtract(sp_export, tf_device)

if dino_export is not None:
self.dino_extract = dino_extract.DINOExtract(dino_export, feature_layer=1, device=torch_device)

def FindMatches(self, image0: np.ndarray, image1: np.ndarray):
"""TODO(omniglue): docstring."""
height0, width0 = image0.shape[:2]
height1, width1 = image1.shape[:2]
sp_features0 = self.sp_extract(image0)
sp_features1 = self.sp_extract(image1)
dino_features0 = self.dino_extract(image0)
dino_features1 = self.dino_extract(image1)
dino_descriptors0 = dino_extract.get_dino_descriptors(
dino_features0,
tf.convert_to_tensor(sp_features0[0], dtype=tf.float32),
tf.convert_to_tensor(height0, dtype=tf.int32),
tf.convert_to_tensor(width0, dtype=tf.int32),
DINO_FEATURE_DIM,
)
dino_descriptors1 = dino_extract.get_dino_descriptors(
dino_features1,
tf.convert_to_tensor(sp_features1[0], dtype=tf.float32),
tf.convert_to_tensor(height1, dtype=tf.int32),
tf.convert_to_tensor(width1, dtype=tf.int32),
DINO_FEATURE_DIM,
)
def FindMatches(self, image0: np.ndarray, image1: np.ndarray):
"""Find matches between two images using SP and DINO features."""
height0, width0 = image0.shape[:2]
height1, width1 = image1.shape[:2]

sp_features0 = self.sp_extract(image0)
sp_features1 = self.sp_extract(image1)

dino_features0 = self.dino_extract(image0)
dino_features1 = self.dino_extract(image1)

dino_descriptors0 = dino_extract.get_dino_descriptors(
dino_features0,
tf.convert_to_tensor(sp_features0[0], dtype=tf.float32),
tf.convert_to_tensor(height0, dtype=tf.int32),
tf.convert_to_tensor(width0, dtype=tf.int32),
DINO_FEATURE_DIM,
)

dino_descriptors1 = dino_extract.get_dino_descriptors(
dino_features1,
tf.convert_to_tensor(sp_features1[0], dtype=tf.float32),
tf.convert_to_tensor(height1, dtype=tf.int32),
tf.convert_to_tensor(width1, dtype=tf.int32),
DINO_FEATURE_DIM,
)

inputs = self._construct_inputs(
width0,
height0,
width1,
height1,
sp_features0,
sp_features1,
dino_descriptors0,
dino_descriptors1,
)
inputs = self._construct_inputs(
width0, height0, width1, height1,
sp_features0, sp_features1,
dino_descriptors0, dino_descriptors1
)

og_outputs = self.matcher.signatures['serving_default'](**inputs)
soft_assignment = og_outputs['soft_assignment'][:, :-1, :-1]
og_outputs = self.matcher.signatures['serving_default'](**inputs)
soft_assignment = og_outputs['soft_assignment'][:, :-1, :-1]

match_matrix = (
utils.soft_assignment_to_match_matrix(soft_assignment, MATCH_THRESHOLD)
.numpy()
.squeeze()
)
match_matrix = (
utils.soft_assignment_to_match_matrix(soft_assignment, MATCH_THRESHOLD)
.numpy()
.squeeze()
)

# Filter out any matches with 0.0 confidence keypoints.
match_indices = np.argwhere(match_matrix)
keep = []
for i in range(match_indices.shape[0]):
match = match_indices[i, :]
if (sp_features0[2][match[0]] > 0.0) and (
sp_features1[2][match[1]] > 0.0
):
keep.append(i)
match_indices = match_indices[keep]
# Filter out any matches with 0.0 confidence keypoints.
match_indices = np.argwhere(match_matrix)
keep = []
for i in range(match_indices.shape[0]):
match = match_indices[i, :]
if (sp_features0[2][match[0]] > 0.0) and (sp_features1[2][match[1]] > 0.0):
keep.append(i)
match_indices = match_indices[keep]

# Format matches in terms of keypoint locations.
match_kp0s = []
match_kp1s = []
match_confidences = []
for match in match_indices:
match_kp0s.append(sp_features0[0][match[0], :])
match_kp1s.append(sp_features1[0][match[1], :])
match_confidences.append(soft_assignment[0, match[0], match[1]])
match_kp0s = np.array(match_kp0s)
match_kp1s = np.array(match_kp1s)
match_confidences = np.array(match_confidences)
return match_kp0s, match_kp1s, match_confidences
# Format matches in terms of keypoint locations.
match_kp0s = []
match_kp1s = []
match_confidences = []
for match in match_indices:
match_kp0s.append(sp_features0[0][match[0], :])
match_kp1s.append(sp_features1[0][match[1], :])
match_confidences.append(soft_assignment[0, match[0], match[1]])

match_kp0s = np.array(match_kp0s)
match_kp1s = np.array(match_kp1s)
match_confidences = np.array(match_confidences)

return match_kp0s, match_kp1s, match_confidences

### Private methods ###
### Private methods ###

def _construct_inputs(
    self,
    width0,
    height0,
    width1,
    height1,
    sp_features0,
    sp_features1,
    dino_descriptors0,
    dino_descriptors1,
):
    """Assemble the dict that is unpacked into the matcher's serving signature.

    Each `sp_features*` argument is indexed positionally: [0] feeds the
    'keypoints*' entry, [1] the 'descriptors*' entry and [2] the 'scores*'
    entry. Every value gets a leading axis of size 1; scores additionally get
    a trailing axis of size 1.

    Returns:
      dict mapping input names to tensors (float32 for SuperPoint features,
      int32 for image sizes; DINO descriptors keep their incoming dtype).
    """

    def _with_leading_axis(value, dtype):
        # Prepend a size-1 axis in NumPy, then hand the result to TensorFlow.
        return tf.convert_to_tensor(np.expand_dims(value, axis=0), dtype=dtype)

    def _score_tensor(scores):
        # Trailing axis first, leading axis second: same final layout as
        # expanding axis 0 and then axis -1.
        return _with_leading_axis(np.expand_dims(scores, axis=-1), tf.float32)

    inputs = {}
    inputs['keypoints0'] = _with_leading_axis(sp_features0[0], tf.float32)
    inputs['keypoints1'] = _with_leading_axis(sp_features1[0], tf.float32)
    inputs['descriptors0'] = _with_leading_axis(sp_features0[1], tf.float32)
    inputs['descriptors1'] = _with_leading_axis(sp_features1[1], tf.float32)
    inputs['scores0'] = _score_tensor(sp_features0[2])
    inputs['scores1'] = _score_tensor(sp_features1[2])
    inputs['descriptors0_dino'] = tf.expand_dims(dino_descriptors0, axis=0)
    inputs['descriptors1_dino'] = tf.expand_dims(dino_descriptors1, axis=0)
    inputs['width0'] = _with_leading_axis(width0, tf.int32)
    inputs['width1'] = _with_leading_axis(width1, tf.int32)
    inputs['height0'] = _with_leading_axis(height0, tf.int32)
    inputs['height1'] = _with_leading_axis(height1, tf.int32)
    return inputs
def _construct_inputs(self, width0, height0, width1, height1, sp_features0, sp_features1, dino_descriptors0, dino_descriptors1):
    """Construct the input dictionary for the matcher model.

    The SuperPoint feature arguments are read positionally as
    (keypoints, descriptors, scores). All entries receive a leading axis of
    size 1; score entries also receive a trailing axis of size 1.

    Returns:
      dict keyed by the matcher's input names, holding float32 tensors for
      SuperPoint features, int32 tensors for widths/heights, and the DINO
      descriptors expanded along axis 0.
    """
    kpts0, desc0, scores0 = sp_features0[0], sp_features0[1], sp_features0[2]
    kpts1, desc1, scores1 = sp_features1[0], sp_features1[1], sp_features1[2]

    def batch_float(array):
        # (…) -> (1, …) float32 tensor.
        return tf.convert_to_tensor(np.expand_dims(array, axis=0), dtype=tf.float32)

    def batch_int(value):
        # scalar -> (1,) int32 tensor.
        return tf.convert_to_tensor(np.expand_dims(value, axis=0), dtype=tf.int32)

    def batch_scores(array):
        # Leading axis, then trailing axis, then float32 conversion.
        expanded = np.expand_dims(array, axis=0)
        expanded = np.expand_dims(expanded, axis=-1)
        return tf.convert_to_tensor(expanded, dtype=tf.float32)

    return {
        'keypoints0': batch_float(kpts0),
        'keypoints1': batch_float(kpts1),
        'descriptors0': batch_float(desc0),
        'descriptors1': batch_float(desc1),
        'scores0': batch_scores(scores0),
        'scores1': batch_scores(scores1),
        'descriptors0_dino': tf.expand_dims(dino_descriptors0, axis=0),
        'descriptors1_dino': tf.expand_dims(dino_descriptors1, axis=0),
        'width0': batch_int(width0),
        'width1': batch_int(width1),
        'height0': batch_int(height0),
        'height1': batch_int(height1),
    }
86 changes: 78 additions & 8 deletions src/omniglue/superpoint_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

import cv2
import numpy as np
from omniglue import utils
import tensorflow as tf
import tensorflow.compat.v1 as tf1

import torch
from omniglue import utils
from omniglue import superpoint_pytorch

class SuperPointExtract:
"""Class to initialize SuperPoint model and extract features from an image.
Expand All @@ -33,13 +35,21 @@ class SuperPointExtract:
model_path: string, filepath to saved SuperPoint TF1 model weights.
"""

def __init__(self, model_path: str):
def __init__(self, model_path: str, device):
self.model_path = model_path
self._graph = tf1.Graph()
self._sess = tf1.Session(graph=self._graph)
tf1.saved_model.loader.load(
self._sess, [tf1.saved_model.tag_constants.SERVING], model_path
)
if device != "GPU":
with tf.device(device):
self._graph = tf1.Graph()
self._sess = tf1.Session(graph=self._graph)
tf1.saved_model.loader.load(
self._sess, [tf1.saved_model.tag_constants.SERVING], model_path
)
else:
self._graph = tf1.Graph()
self._sess = tf1.Session(graph=self._graph)
tf1.saved_model.loader.load(
self._sess, [tf1.saved_model.tag_constants.SERVING], model_path
)

def __call__(
self,
Expand Down Expand Up @@ -212,3 +222,63 @@ def _select_k_best(points, k):
descriptors.append(utils.lookup_descriptor_bilinear(kp, descriptor_map))
descriptors = np.array(descriptors)
return keypoints, descriptors, scores


class SuperPointExtract_Pytorch:
    """Wrap a PyTorch SuperPoint model and extract features from an image.

    To stay consistent with SuperPoint training and eval configurations, resize
    images to (320x240) or (640x480).

    Attributes:
      device: device the model and the input tensors are moved to.
      model_path: filepath the state dict is loaded from.
      model: `superpoint_pytorch.SuperPoint` instance, switched to eval mode.
    """

    # Smallest axis length that still rounds to a non-empty multiple of 8
    # (0..3 would round down to 0, and 1 would also divide by zero below).
    _MIN_DIM = 4

    def __init__(self, model_path: str, device):
        """Initialize the SuperPoint model.

        Args:
          model_path: filepath to a saved SuperPoint state dict.
          device: device passed to `torch.load(map_location=...)` and
            `Module.to(...)`.
        """
        self.device = device
        self.model_path = model_path
        self.model = superpoint_pytorch.SuperPoint()
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider `weights_only=True` if the minimum
        # supported torch version allows it).
        state_dict = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(device=self.device)
        self.model.eval()

    def __call__(self, image):
        """Extract features from the image.

        Returns:
          (keypoints, descriptors, scores) as produced by `_extract_features`.
        """
        return self._extract_features(image)

    def _preprocess_image(self, image):
        """Convert image to grayscale and normalize values for model input.

        The image is converted with `cv2.COLOR_RGB2GRAY`, given a trailing
        channel axis, cast to float32 and divided by 255.
        """
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        image = np.expand_dims(image, 2)
        image = image.astype(np.float32)
        image = image / 255.0
        return image

    def _resize_input_image(self, image, interpolation=cv2.INTER_LINEAR):
        """Resize image such that both dimensions are divisible by 8.

        Each of the first two axes is rounded to the nearest multiple of 8
        (a remainder of 4 or more rounds up).

        Args:
          image: array whose first two axes are (rows, cols).
          interpolation: OpenCV interpolation flag forwarded to `cv2.resize`.

        Returns:
          (resized_image, keypoint_scale_factors) where the scale factors are
          ordered (x, y) and equal (new_size - 1) / (old_size - 1) per axis.

        Raises:
          ValueError: if either of the first two axes is shorter than 4
            pixels; such sizes previously surfaced as a ZeroDivisionError or
            an OpenCV error about an empty target size.
        """
        new_dim = [-1, -1]
        keypoint_scale_factors = [1.0, 1.0]
        for i in range(2):
            dim_size = image.shape[i]
            if dim_size < self._MIN_DIM:
                raise ValueError(
                    f"Image axis {i} has size {dim_size}; at least "
                    f"{self._MIN_DIM} pixels are required to resize to a "
                    "multiple of 8."
                )
            # Round to the nearest multiple of 8, with ties (remainder 4)
            # rounding up.
            new_dim[i] = 8 * ((dim_size + 4) // 8)
            keypoint_scale_factors[i] = (new_dim[i] - 1) / (dim_size - 1)

        new_dim = new_dim[::-1]  # Convert from (row, col) to (x,y).
        keypoint_scale_factors = keypoint_scale_factors[::-1]
        image = cv2.resize(image, tuple(new_dim), interpolation=interpolation)
        return image, keypoint_scale_factors

    def _extract_features(self, image):
        """Extract keypoints, descriptors, and scores from the image.

        The image is resized to multiples of 8, preprocessed to a normalized
        grayscale tensor, and run through the model without gradients. The
        keypoints are then divided by the resize scale factors.

        Returns:
          (keypoints, descriptors, scores) as float64 NumPy arrays.
        """
        image, keypoint_scale_factors = self._resize_input_image(image)
        image_preprocessed = self._preprocess_image(image)
        # HWC -> CHW, then add a leading axis: [1, C, H, W].
        image_tensor = (
            torch.from_numpy(image_preprocessed.transpose(2, 0, 1)[None])
            .float()
            .to(self.device)
        )

        with torch.no_grad():
            pred = self.model({'image': image_tensor})

        # NOTE(review): assumes iterating `pred` yields exactly
        # (keypoints, descriptors, scores), each indexable by batch position —
        # TODO confirm against superpoint_pytorch.SuperPoint.forward (a dict
        # return value would iterate over its keys instead).
        keypoints, descriptors, scores = [
            p[0].cpu().numpy().astype("float64") for p in pred
        ]
        # Scale factors are (x, y); broadcasting applies them per column.
        keypoints = keypoints / np.asarray(keypoint_scale_factors)
        return keypoints, descriptors, scores
Loading