Skip to content

ONNX export compatible with OpenCV #462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions export_to_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import argparse
import io
import torch
from torch.autograd import Variable
import onnx

from ssd import build_ssd


def assertONNXExpected(binary_pb):
model_def = onnx.ModelProto.FromString(binary_pb)
onnx.helper.strip_doc_string(model_def)
return model_def


def export_to_string(model, inputs, version=None):
f = io.BytesIO()
with torch.no_grad():
torch.onnx.export(model, inputs, f, export_params=True, opset_version=version)
return f.getvalue()


def save_model(model, input, output):
onnx_model_pb = export_to_string(model, input)
model_def = assertONNXExpected(onnx_model_pb)
with open(output, 'wb') as file:
file.write(model_def.SerializeToString())


if __name__ == '__main__':
parser = argparse.ArgumentParser('Export trained model to ONNX format')
parser.add_argument('--model', required=True, help='Path to saved PyTorch network weights (*.pth)')
parser.add_argument('--output', default='ssd.onnx', help='Name of ouput file')
parser.add_argument('--size', default=300, help='Input resolution')
parser.add_argument('--num_classes', default=21, help='Number of trained classes + 1 for background')
args = parser.parse_args()

net = build_ssd('export', args.size, args.num_classes)
net.load_state_dict(torch.load(args.model, map_location='cpu'))
net.eval()

input = Variable(torch.randn(1, 3, args.size, args.size))
save_model(net, input, args.output)
45 changes: 27 additions & 18 deletions layers/functions/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@ class Detect(Function):
scores and threshold to a top_k number of output predictions for both
confidence score and locations.
"""
def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
self.num_classes = num_classes
self.background_label = bkg_label
self.top_k = top_k
# Parameters used in nms.
self.nms_thresh = nms_thresh
if nms_thresh <= 0:
raise ValueError('nms_threshold must be non negative.')
self.conf_thresh = conf_thresh
self.variance = cfg['variance']
@staticmethod
def symbolic(g, loc_data, conf_data, prior_data, num_classes, top_k, variance, conf_thresh, nms_thresh, phase):
return g.op('DetectionOutput', loc_data, conf_data, prior_data,
num_classes_i=num_classes,
top_k_i=top_k,
keep_top_k_i=top_k,
confidence_threshold_f=conf_thresh,
nms_threshold_f=nms_thresh,
share_location_i=1,
variance_encoded_in_target_i=0,
code_type_s='CENTER_SIZE',
background_label_id_i=0)

def forward(self, loc_data, conf_data, prior_data):
def forward(self, loc_data, conf_data, prior_data, num_classes, top_k, variance, conf_thresh, nms_thresh, phase):
"""
Args:
loc_data: (tensor) Loc preds from loc layers
Expand All @@ -31,32 +33,39 @@ def forward(self, loc_data, conf_data, prior_data):
prior_data: (tensor) Prior boxes and variances from priorbox layers
Shape: [1,num_priors,4]
"""
loc_data = loc_data.view(loc_data.shape[0], -1, 4)

if phase == 'export':
prior_data = prior_data.view(-1, 4)
# Ignore variance from priors data
prior_data = prior_data[:prior_data.shape[0] // 2]

num = loc_data.size(0) # batch size
num_priors = prior_data.size(0)
output = torch.zeros(num, self.num_classes, self.top_k, 5)
output = torch.zeros(num, num_classes, top_k, 5)
conf_preds = conf_data.view(num, num_priors,
self.num_classes).transpose(2, 1)
num_classes).transpose(2, 1)

# Decode predictions into bboxes.
for i in range(num):
decoded_boxes = decode(loc_data[i], prior_data, self.variance)
decoded_boxes = decode(loc_data[i], prior_data, variance)
# For each class, perform nms
conf_scores = conf_preds[i].clone()

for cl in range(1, self.num_classes):
c_mask = conf_scores[cl].gt(self.conf_thresh)
for cl in range(1, num_classes):
c_mask = conf_scores[cl].gt(conf_thresh)
scores = conf_scores[cl][c_mask]
if scores.size(0) == 0:
continue
l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
boxes = decoded_boxes[l_mask].view(-1, 4)
# idx of highest scoring and non-overlapping boxes per class
ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
ids, count = nms(boxes, scores, nms_thresh, top_k)
output[i, cl, :count] = \
torch.cat((scores[ids[:count]].unsqueeze(1),
boxes[ids[:count]]), 1)
flt = output.contiguous().view(num, -1, 5)
_, idx = flt[:, :, 0].sort(1, descending=True)
_, rank = idx.sort(1)
flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
flt[(rank < top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
return output
17 changes: 15 additions & 2 deletions layers/functions/prior_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@ class PriorBox(object):
"""Compute priorbox coordinates in center-offset form for each source
feature map.
"""
def __init__(self, cfg):
def __init__(self, cfg, phase):
super(PriorBox, self).__init__()
self.image_size = cfg['min_dim']
# number of priors for feature map location (either 4 or 6)
self.num_priors = len(cfg['aspect_ratios'])
self.variance = cfg['variance'] or [0.1]
self.variance = cfg['variance'] or [0.1, 0.1]
self.feature_maps = cfg['feature_maps']
self.min_sizes = cfg['min_sizes']
self.max_sizes = cfg['max_sizes']
self.steps = cfg['steps']
self.aspect_ratios = cfg['aspect_ratios']
self.clip = cfg['clip']
self.version = cfg['name']
self.phase = phase
for v in self.variance:
if v <= 0:
raise ValueError('Variances must be greater than 0')
Expand Down Expand Up @@ -52,4 +53,16 @@ def forward(self):
output = torch.Tensor(mean).view(-1, 4)
if self.clip:
output.clamp_(max=1, min=0)
if self.phase == 'export':
# CENTER based to CORNER based representaion
w, h = output[:,2], output[:,3]
output[:,0] -= w * 0.5
output[:,1] -= h * 0.5
output[:,2] = output[:,0] + w
output[:,3] = output[:,1] + h

# Append variance after prior boxes like in Caffe
variance = torch.Tensor([self.variance[0], self.variance[0], self.variance[1], self.variance[1]]) \
.repeat(output.shape[0]).view(-1, 4)
return torch.cat([output, variance], 0).view(1, 2, -1)
return output
2 changes: 1 addition & 1 deletion layers/modules/l2norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ def forward(self, x):
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps
#x /= norm
x = torch.div(x,norm)
out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
out = self.weight.view(1, -1, 1, 1) * x
return out
28 changes: 18 additions & 10 deletions ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, phase, size, base, extras, head, num_classes):
self.phase = phase
self.num_classes = num_classes
self.cfg = (coco, voc)[num_classes == 21]
self.priorbox = PriorBox(self.cfg)
self.priorbox = PriorBox(self.cfg, phase)
self.priors = Variable(self.priorbox.forward(), volatile=True)
self.size = size

Expand All @@ -43,9 +43,11 @@ def __init__(self, phase, size, base, extras, head, num_classes):
self.loc = nn.ModuleList(head[0])
self.conf = nn.ModuleList(head[1])

if phase == 'test':
if phase == 'test' or phase == 'export':
self.softmax = nn.Softmax(dim=-1)
self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)
self.top_k = 200
self.conf_thresh = 0.01
self.nms_thresh = 0.45

def forward(self, x):
"""Applies network layers and ops on input image(s) x.
Expand Down Expand Up @@ -95,12 +97,18 @@ def forward(self, x):

loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
if self.phase == "test":
output = self.detect(
loc.view(loc.size(0), -1, 4), # loc preds
self.softmax(conf.view(conf.size(0), -1,
self.num_classes)), # conf preds
self.priors.type(type(x.data)) # default boxes
if self.phase == "test" or self.phase == "export":
output = Detect.apply(
loc.view(loc.size(0), -1), # loc preds
self.softmax(conf.view(conf.size(0), -1, self.num_classes))
.view(conf.size(0), -1), # conf preds
self.priors.type(type(x.data)), # default boxes
self.num_classes,
self.top_k,
self.cfg['variance'],
self.conf_thresh,
self.nms_thresh,
self.phase
)
else:
output = (
Expand Down Expand Up @@ -196,7 +204,7 @@ def multibox(vgg, extra_layers, cfg, num_classes):


def build_ssd(phase, size=300, num_classes=21):
if phase != "test" and phase != "train":
if phase != "test" and phase != "train" and phase != "export":
print("ERROR: Phase: " + phase + " not recognized")
return
if size != 300:
Expand Down