make SuperPoint compatible with TorchScript
valgur committed Jul 29, 2020
1 parent 82b22d0 commit 2e80db8
Showing 4 changed files with 60 additions and 49 deletions.
2 changes: 1 addition & 1 deletion demo_superglue.py
@@ -153,7 +153,7 @@
     assert ret, 'Error when reading the first frame (try different --input?)'
 
     frame_tensor = frame2tensor(frame, device)
-    last_data = matching.superpoint({'image': frame_tensor})
+    last_data = matching.superpoint(frame_tensor)
     last_data = {k+'0': last_data[k] for k in keys}
     last_data['image0'] = frame_tensor
     last_frame = frame
5 changes: 5 additions & 0 deletions jit.py
@@ -0,0 +1,5 @@
+from models.superpoint import SuperPoint
+import torch
+
+superpoint = SuperPoint({})
+torch.jit.save(superpoint, 'SuperPoint.zip')
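For context, a minimal usage sketch (not part of the commit) of how the archive written by jit.py could be reloaded and run; the file name comes from jit.py above, while the input shape and value range are assumptions:

# Usage sketch (assumed, not part of the commit): reload and run the exported module
import torch

model = torch.jit.load('SuperPoint.zip')   # archive written by jit.py above
image = torch.rand(1, 1, 480, 640)          # assumed grayscale input of shape (B, 1, H, W), values in [0, 1]
with torch.no_grad():
    pred = model(image)                     # dict with 'keypoints', 'scores', 'descriptors' (one entry per image)
print(pred['keypoints'][0].shape)           # (N, 2) keypoint coordinates in (x, y) order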
4 changes: 2 additions & 2 deletions models/matching.py
@@ -63,10 +63,10 @@ def forward(self, data):
 
         # Extract SuperPoint (keypoints, scores, descriptors) if not provided
         if 'keypoints0' not in data:
-            pred0 = self.superpoint({'image': data['image0']})
+            pred0 = self.superpoint(data['image0'])
             pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
         if 'keypoints1' not in data:
-            pred1 = self.superpoint({'image': data['image1']})
+            pred1 = self.superpoint(data['image1'])
             pred = {**pred, **{k+'1': v for k, v in pred1.items()}}
 
         # Batch all features
98 changes: 52 additions & 46 deletions models/superpoint.py
@@ -44,20 +44,21 @@
 import torch
 from torch import nn
 
+
+def max_pool(x, nms_radius: int):
+    return torch.nn.functional.max_pool2d(
+        x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
+
 def simple_nms(scores, nms_radius: int):
     """ Fast Non-maximum suppression to remove nearby points """
     assert(nms_radius >= 0)
 
-    def max_pool(x):
-        return torch.nn.functional.max_pool2d(
-            x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
-
     zeros = torch.zeros_like(scores)
-    max_mask = scores == max_pool(scores)
+    max_mask = scores == max_pool(scores, nms_radius)
     for _ in range(2):
-        supp_mask = max_pool(max_mask.float()) > 0
+        supp_mask = max_pool(max_mask.float(), nms_radius) > 0
         supp_scores = torch.where(supp_mask, zeros, scores)
-        new_max_mask = supp_scores == max_pool(supp_scores)
+        new_max_mask = supp_scores == max_pool(supp_scores, nms_radius)
         max_mask = max_mask | (new_max_mask & (~supp_mask))
     return torch.where(max_mask, scores, zeros)
 
@@ -81,18 +82,16 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
     """ Interpolate descriptors at keypoint locations """
     b, c, h, w = descriptors.shape
     keypoints = keypoints - s / 2 + 0.5
-    keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
-                              ).to(keypoints)[None]
+    keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)]).to(keypoints).unsqueeze(0)
     keypoints = keypoints*2 - 1  # normalize to (-1, 1)
-    args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
     descriptors = torch.nn.functional.grid_sample(
-        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', align_corners=True)
     descriptors = torch.nn.functional.normalize(
-        descriptors.reshape(b, c, -1), p=2, dim=1)
+        descriptors.reshape(b, c, -1), p=2., dim=1)
     return descriptors
 
 
-class SuperPoint(nn.Module):
+class SuperPoint(torch.jit.ScriptModule):
     """SuperPoint Convolutional Detector and Descriptor
 
     SuperPoint: Self-Supervised Interest Point Detection and
@@ -112,6 +111,12 @@ def __init__(self, config):
         super().__init__()
         self.config = {**self.default_config, **config}
 
+        self.descriptor_dim = self.config['descriptor_dim']
+        self.nms_radius = self.config['nms_radius']
+        self.keypoint_threshold = self.config['keypoint_threshold']
+        self.max_keypoints = self.config['max_keypoints']
+        self.remove_borders = self.config['remove_borders']
+
         self.relu = nn.ReLU(inplace=True)
         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
         c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
@@ -130,22 +135,23 @@ def __init__(self, config):
 
         self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
         self.convDb = nn.Conv2d(
-            c5, self.config['descriptor_dim'],
+            c5, self.descriptor_dim,
             kernel_size=1, stride=1, padding=0)
 
         path = Path(__file__).parent / 'weights/superpoint_v1.pth'
         self.load_state_dict(torch.load(str(path)))
 
-        mk = self.config['max_keypoints']
+        mk = self.max_keypoints
         if mk == 0 or mk < -1:
-            raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
+            raise ValueError('"max_keypoints" must be positive or "-1"')
 
         print('Loaded SuperPoint model')
 
-    def forward(self, data):
+    @torch.jit.script_method
+    def forward(self, image):
         """ Compute keypoints, scores, descriptors for image """
         # Shared Encoder
-        x = self.relu(self.conv1a(data['image']))
+        x = self.relu(self.conv1a(image))
         x = self.relu(self.conv1b(x))
         x = self.pool(x)
         x = self.relu(self.conv2a(x))
@@ -164,39 +170,39 @@ def forward(self, data):
         b, _, h, w = scores.shape
         scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
         scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
-        scores = simple_nms(scores, self.config['nms_radius'])
-
-        # Extract keypoints
-        keypoints = [
-            torch.nonzero(s > self.config['keypoint_threshold'])
-            for s in scores]
-        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
-
-        # Discard keypoints near the image borders
-        keypoints, scores = list(zip(*[
-            remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
-            for k, s in zip(keypoints, scores)]))
-
-        # Keep the k keypoints with highest score
-        if self.config['max_keypoints'] >= 0:
-            keypoints, scores = list(zip(*[
-                top_k_keypoints(k, s, self.config['max_keypoints'])
-                for k, s in zip(keypoints, scores)]))
-
-        # Convert (h, w) to (x, y)
-        keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+        scores = simple_nms(scores, self.nms_radius)
 
         # Compute the dense descriptors
         cDa = self.relu(self.convDa(x))
         descriptors = self.convDb(cDa)
-        descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
-
-        # Extract descriptors
-        descriptors = [sample_descriptors(k[None], d[None], 8)[0]
-                       for k, d in zip(keypoints, descriptors)]
+        descriptors = torch.nn.functional.normalize(descriptors, p=2., dim=1)
 
+        keypoints = []
+        scores_out = []
+        descriptors_out = []
+        for i in range(b):
+            # Extract keypoints
+            s = scores[i]
+            k = torch.nonzero(s > self.keypoint_threshold)
+            s = s[s > self.keypoint_threshold]
+
+            # Discard keypoints near the image borders
+            k, s = remove_borders(k, s, self.remove_borders, h*8, w*8)
+
+            # Keep the k keypoints with highest score
+            if self.max_keypoints >= 0:
+                k, s = top_k_keypoints(k, s, self.max_keypoints)
+
+            # Convert (h, w) to (x, y)
+            k = torch.flip(k, [1]).float()
+
+            # Extract descriptors
+            descriptors_out.append(sample_descriptors(k.unsqueeze(0), descriptors[i].unsqueeze(0), 8)[0])
+            keypoints.append(k)
+            scores_out.append(s)
+
         return {
             'keypoints': keypoints,
-            'scores': scores,
-            'descriptors': descriptors,
+            'scores': scores_out,
+            'descriptors': descriptors_out,
         }
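
For reference, a minimal sketch (not part of the commit) of the calling convention after this change: forward now takes the image tensor directly instead of a {'image': tensor} dict and returns per-image lists; the input shape and config value below are assumptions.

# Calling-convention sketch (assumed, not part of the commit)
import torch
from models.superpoint import SuperPoint

sp = SuperPoint({'max_keypoints': 1024})  # config keys as in default_config
image = torch.rand(1, 1, 480, 640)        # assumed grayscale input, shape (B, 1, H, W)
pred = sp(image)                          # previously: sp({'image': image})
kpts = pred['keypoints'][0]               # (N, 2) tensor of (x, y) keypoints for image 0
desc = pred['descriptors'][0]             # (descriptor_dim, N) descriptors sampled at those keypoints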
