Skip to content

Commit 763ff29

Browse files
authored
Support exporting for CPU Mask & Keypoint nets
Prerequisite: facebookresearch#372. Purpose: enable exporting all the models for CPU by exporting two separate nets: one for the bounding boxes and one for the rest of the inference. There are two main modifications: (1) main() is refactored so that it calls a function, convert_to_pb, for each sub-net; (2) run_model_pb always runs the bbox inference first and then calls the mask or keypoint part if needed, adopting the exact same approach for each. The helper functions are only lightly modified to fit the new objective of exporting two pb files.
1 parent 0375b05 commit 763ff29

1 file changed

Lines changed: 178 additions & 61 deletions

File tree

tools/convert_pkl_to_pb.py

Lines changed: 178 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,13 @@
5151
from detectron.utils.model_convert_utils import convert_op_in_proto
5252
from detectron.utils.model_convert_utils import op_filter
5353
import detectron.core.test_engine as test_engine
54+
import detectron.core.test as test
5455
import detectron.utils.c2 as c2_utils
5556
import detectron.utils.model_convert_utils as mutils
5657
import detectron.utils.vis as vis_utils
58+
import detectron.utils.blob as blob_utils
59+
import detectron.utils.keypoints as keypoint_utils
60+
import pycocotools.mask as mask_utils
5761

5862
c2_utils.import_contrib_ops()
5963
c2_utils.import_detectron_ops()
@@ -315,14 +319,14 @@ def gen_init_net(net, blobs, empty_blobs):
315319
def _save_image_graphs(args, all_net, all_init_net):
316320
print('Saving model graph...')
317321
mutils.save_graph(
318-
all_net.Proto(), os.path.join(args.out_dir, "model_def.png"),
322+
all_net.Proto(), os.path.join(args.out_dir, all_net.Proto().name + '.png'),
319323
op_only=False)
320324
print('Model def image saved to {}.'.format(args.out_dir))
321325

322326

323327
def _save_models(all_net, all_init_net, args):
324328
print('Writing converted model to {}...'.format(args.out_dir))
325-
fname = "model"
329+
fname = all_net.Proto().name
326330

327331
if not os.path.exists(args.out_dir):
328332
os.makedirs(args.out_dir)
@@ -380,13 +384,14 @@ def run_model_cfg(args, im, check_blobs):
380384
cls_boxes, cls_segms, cls_keyps = test_engine.im_detect_all(
381385
model, im, None, None,
382386
)
383-
384-
boxes, segms, keypoints, classes = vis_utils.convert_from_cls_format(
387+
boxes, segms, keypoints, classids = vis_utils.convert_from_cls_format(
385388
cls_boxes, cls_segms, cls_keyps)
386389

390+
segms = mask_utils.decode(segms) if segms else None
391+
387392
# sort the results based on score for comparison
388-
boxes, segms, keypoints, classes = _sort_results(
389-
boxes, segms, keypoints, classes)
393+
boxes, segms, keypoints, classids = _sort_results(
394+
boxes, segms, keypoints, classids)
390395

391396
# write final results back to workspace
392397
def _ornone(res):
@@ -395,12 +400,16 @@ def _ornone(res):
395400
workspace.FeedBlob(core.ScopedName('result_boxes'), _ornone(boxes))
396401
workspace.FeedBlob(core.ScopedName('result_segms'), _ornone(segms))
397402
workspace.FeedBlob(core.ScopedName('result_keypoints'), _ornone(keypoints))
398-
workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classes))
403+
workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classids))
399404

400405
# get result blobs
401406
with c2_utils.NamedCudaScope(0):
402407
ret = _get_result_blobs(check_blobs)
403408

409+
print('result_boxes', _ornone(boxes))
410+
print('result_segms', _ornone(segms))
411+
print('result_keypoints', _ornone(keypoints))
412+
print('result_classids', _ornone(classids))
404413
return ret
405414

406415

@@ -438,13 +447,13 @@ def _prepare_blobs(
438447
return blobs
439448

440449

441-
def run_model_pb(args, net, init_net, im, check_blobs):
450+
def run_model_pb(args, models_pb, im, check_blobs):
442451
workspace.ResetWorkspace()
452+
net, init_net = models_pb['net']
443453
workspace.RunNetOnce(init_net)
444454
mutils.create_input_blobs_for_net(net.Proto())
445455
workspace.CreateNet(net)
446456

447-
# input_blobs, _ = core_test._get_blobs(im, None)
448457
input_blobs = _prepare_blobs(
449458
im,
450459
cfg.PIXEL_MEANS,
@@ -462,37 +471,137 @@ def run_model_pb(args, net, init_net, im, check_blobs):
462471
)
463472

464473
try:
465-
workspace.RunNet(net.Proto().name)
466-
scores = workspace.FetchBlob('score_nms')
467-
classids = workspace.FetchBlob('class_nms')
468-
boxes = workspace.FetchBlob('bbox_nms')
474+
workspace.RunNet(net)
475+
scores = workspace.FetchBlob(core.ScopedName('score_nms'))
476+
classids = workspace.FetchBlob(core.ScopedName('class_nms'))
477+
boxes = workspace.FetchBlob(core.ScopedName('bbox_nms'))
469478
except Exception as e:
470479
print('Running pb model failed.\n{}'.format(e))
471-
# may not detect anything at all
472480
R = 0
473481
scores = np.zeros((R,), dtype=np.float32)
474482
boxes = np.zeros((R, 4), dtype=np.float32)
475483
classids = np.zeros((R,), dtype=np.float32)
476484

485+
cls_keyps, cls_segms = None, None
486+
487+
if 'keypoint_net' in models_pb:
488+
keypoint_net, init_keypoint_net = models_pb['keypoint_net']
489+
workspace.RunNetOnce(init_keypoint_net)
490+
mutils.create_input_blobs_for_net(keypoint_net.Proto())
491+
keypoint_net.Proto().external_input.extend(['rpn_rois', 'bbox_pred', 'im_info', 'cls_prob'])
492+
workspace.CreateNet(keypoint_net)
493+
494+
im_scale = input_blobs['im_info'][0][2]
495+
input_blobs = {'keypoint_rois': test._get_rois_blob(boxes, im_scale)}
496+
497+
# Add multi-level rois for FPN
498+
if cfg.FPN.MULTILEVEL_ROIS:
499+
test._add_multilevel_rois_for_test(input_blobs, 'keypoint_rois')
500+
501+
gpu_blobs = []
502+
if args.device == 'gpu':
503+
gpu_blobs = ['data']
504+
for k, v in list(input_blobs.items()):
505+
workspace.FeedBlob(
506+
core.ScopedName(k),
507+
v,
508+
mutils.get_device_option_cuda() if k in gpu_blobs else
509+
mutils.get_device_option_cpu()
510+
)
511+
512+
try:
513+
workspace.RunNet(keypoint_net)
514+
pred_heatmaps = workspace.FetchBlob(core.ScopedName('kps_score')).squeeze()
515+
# squeeze() also drops the leading axis when there is a single ROI; restore it
516+
if pred_heatmaps.ndim == 3:
517+
pred_heatmaps = np.expand_dims(pred_heatmaps, axis=0)
518+
except Exception as e:
519+
print('Running pb model failed.\n{}'.format(e))
520+
R, M = 0, cfg.KRCNN.HEATMAP_SIZE
521+
pred_heatmaps = np.zeros((R, cfg.KRCNN.NUM_KEYPOINTS, M, M), np.float32)
522+
523+
xy_preds = keypoint_utils.heatmaps_to_keypoints(pred_heatmaps, boxes)
524+
cls_keyps = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]
525+
cls_keyps[1] = [xy_preds[i] for i in range(xy_preds.shape[0])]
526+
527+
if 'mask_net' in models_pb:
528+
mask_net, init_mask_net = models_pb['mask_net']
529+
workspace.RunNetOnce(init_mask_net)
530+
mutils.create_input_blobs_for_net(mask_net.Proto())
531+
mask_net.Proto().external_input.extend(['rpn_rois', 'bbox_pred', 'im_info', 'cls_prob'])
532+
workspace.CreateNet(mask_net)
533+
534+
im_scale = input_blobs['im_info'][0][2]
535+
input_blobs = {'mask_rois': test._get_rois_blob(boxes, im_scale)}
536+
537+
# Add multi-level rois for FPN
538+
if cfg.FPN.MULTILEVEL_ROIS:
539+
test._add_multilevel_rois_for_test(input_blobs, 'mask_rois')
540+
541+
gpu_blobs = []
542+
if args.device == 'gpu':
543+
gpu_blobs = ['data']
544+
for k, v in list(input_blobs.items()):
545+
workspace.FeedBlob(
546+
core.ScopedName(k),
547+
v,
548+
mutils.get_device_option_cuda() if k in gpu_blobs else
549+
mutils.get_device_option_cpu()
550+
)
551+
M = cfg.MRCNN.RESOLUTION
552+
try:
553+
workspace.RunNet(mask_net)
554+
# Fetch masks
555+
pred_masks = workspace.FetchBlob(core.ScopedName('mask_fcn_probs')).squeeze()
556+
if cfg.MRCNN.CLS_SPECIFIC_MASK:
557+
pred_masks = pred_masks.reshape([-1, cfg.MODEL.NUM_CLASSES, M, M])
558+
else:
559+
pred_masks = pred_masks.reshape([-1, 1, M, M])
560+
except Exception as e:
561+
print('Running pb model failed.\n{}'.format(e))
562+
R = 0
563+
if cfg.MRCNN.CLS_SPECIFIC_MASK:
564+
pred_masks = np.zeros((R, cfg.MODEL.NUM_CLASSES, M, M), dtype=np.float32)
565+
else:
566+
pred_masks = np.zeros((R, 1, M, M), dtype=np.float32)
567+
568+
cls_boxes = [np.empty(list(classids).count(i)) for i in range(cfg.MODEL.NUM_CLASSES)]
569+
cls_segms = test.segm_results(cls_boxes, pred_masks, boxes, im.shape[0], im.shape[1])
570+
477571
boxes = np.column_stack((boxes, scores))
478572

573+
_, segms, keypoints, _ = vis_utils.convert_from_cls_format([], cls_segms, cls_keyps)
574+
segms = mask_utils.decode(segms) if segms else None
575+
479576
# sort the results based on score for comparison
480-
boxes, _, _, classids = _sort_results(
481-
boxes, None, None, classids)
577+
boxes, segms, keypoints, classids = _sort_results(
578+
boxes, segms, keypoints, classids)
482579

483580
# write final result back to workspace
484-
workspace.FeedBlob('result_boxes', boxes)
485-
workspace.FeedBlob('result_classids', classids)
581+
def _ornone(res):
582+
return np.array(res) if res is not None else np.array([], dtype=np.float32)
583+
workspace.FeedBlob(core.ScopedName('result_boxes'), _ornone(boxes))
584+
workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classids))
585+
workspace.FeedBlob(core.ScopedName('result_segms'), _ornone(segms))
586+
workspace.FeedBlob(core.ScopedName('result_keypoints'), _ornone(keypoints))
486587

487588
ret = _get_result_blobs(check_blobs)
488589

590+
print('result_boxes', _ornone(boxes))
591+
print('result_segms', _ornone(segms))
592+
print('result_keypoints', _ornone(keypoints))
593+
print('result_classids', _ornone(classids))
489594
return ret
490595

491596

492-
def verify_model(args, model_pb, test_img_file):
493-
check_blobs = [
494-
"result_boxes", "result_classids", # result
495-
]
597+
def verify_model(args, models_pb, test_img_file):
598+
check_blobs = ['result_boxes', 'result_classids']
599+
600+
if cfg.MODEL.MASK_ON:
601+
check_blobs.append('result_segms')
602+
603+
if cfg.MODEL.KEYPOINTS_ON:
604+
check_blobs.append('result_keypoints')
496605

497606
print('Loading test file {}...'.format(test_img_file))
498607
test_img = cv2.imread(test_img_file)
@@ -502,13 +611,49 @@ def _run_cfg_func(im, blobs):
502611
return run_model_cfg(args, im, check_blobs)
503612

504613
def _run_pb_func(im, blobs):
505-
return run_model_pb(args, model_pb[0], model_pb[1], im, check_blobs)
614+
return run_model_pb(args, models_pb, im, check_blobs)
506615

507616
print('Checking models...')
508617
assert mutils.compare_model(
509618
_run_cfg_func, _run_pb_func, test_img, check_blobs)
510619

511620

621+
def convert_to_pb(args, net, blobs, part_name='net', input_blobs=None):
    """Convert one sub-net into a (net, init_net) pair ready for export.

    Args:
        args: parsed command-line options. This function reads
            `net_execution_type`, `fuse_af`, `use_nnpack`, `device` and
            `net_name`; `args` is also forwarded to the conversion helpers.
        net: net definition whose `op`, `external_input` and
            `external_output` fields are deep-copied into a fresh net, so
            the caller's definition is left untouched by this function.
        blobs: parameter blobs handed to the conversion helpers and to
            `gen_init_net`.
        part_name: suffix identifying the sub-net; the converted nets are
            named `<args.net_name>_<part_name>` and
            `<args.net_name>_<part_name>_init`.
        input_blobs: blob names forwarded to `gen_init_net` as its third
            argument. Defaults to an empty list.

    Returns:
        Tuple `(pb_net, pb_init_net)`.
    """
    # A fresh list per call: a `[]` default would be one object shared by
    # every call and by whatever `gen_init_net` does with it.
    if input_blobs is None:
        input_blobs = []

    pb_net = core.Net('')
    pb_proto = pb_net.Proto()
    pb_proto.op.extend(copy.deepcopy(net.op))
    pb_proto.external_input.extend(copy.deepcopy(net.external_input))
    pb_proto.external_output.extend(copy.deepcopy(net.external_output))
    pb_proto.type = args.net_execution_type
    pb_proto.num_workers = 1 if args.net_execution_type == 'simple' else 4

    # Reset the device_option, change to unscope name and replace python
    # operators
    convert_net(args, pb_proto, blobs)

    # add operators for bbox
    # NOTE(review): this runs for every part, including the mask and
    # keypoint nets -- confirm that is intended.
    add_bbox_ops(args, pb_net, blobs)

    if args.fuse_af:
        print('Fusing affine channel...')
        # Rebinds the local `blobs` only; the caller's object is not
        # reassigned here.
        pb_net, blobs = mutils.fuse_net_affine(pb_net, blobs)

    if args.use_nnpack:
        # Re-fetch the proto: `pb_net` may have been replaced above.
        mutils.update_mobile_engines(pb_net.Proto())

    # generate init net
    pb_init_net = gen_init_net(pb_net, blobs, input_blobs)

    if args.device == 'gpu':
        pb_net, pb_init_net = convert_model_gpu(args, pb_net, pb_init_net)

    # The net name is later used as the output file name by the save helpers.
    full_name = args.net_name + '_' + part_name
    pb_net.Proto().name = full_name
    pb_init_net.Proto().name = full_name + '_init'

    return pb_net, pb_init_net
655+
656+
512657
def main():
513658
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
514659
args = parse_args()
@@ -523,52 +668,24 @@ def main():
523668
logger.info('Conerting model with config:')
524669
logger.info(pprint.pformat(cfg))
525670

526-
assert not cfg.MODEL.KEYPOINTS_ON, "Keypoint model not supported."
527-
assert not cfg.MODEL.MASK_ON, "Mask model not supported."
528-
assert not cfg.FPN.FPN_ON, "FPN not supported."
529-
assert not cfg.RETINANET.RETINANET_ON, "RetinaNet model not supported."
530-
671+
models_pb = {}
531672
# load model from cfg
532673
model, blobs = load_model(args)
533674

534-
net = core.Net('')
535-
net.Proto().op.extend(copy.deepcopy(model.net.Proto().op))
536-
net.Proto().external_input.extend(
537-
copy.deepcopy(model.net.Proto().external_input))
538-
net.Proto().external_output.extend(
539-
copy.deepcopy(model.net.Proto().external_output))
540-
net.Proto().type = args.net_execution_type
541-
net.Proto().num_workers = 1 if args.net_execution_type == 'simple' else 4
542-
543-
# Reset the device_option, change to unscope name and replace python operators
544-
convert_net(args, net.Proto(), blobs)
545-
546-
# add operators for bbox
547-
add_bbox_ops(args, net, blobs)
548-
549-
if args.fuse_af:
550-
print('Fusing affine channel...')
551-
net, blobs = mutils.fuse_net_affine(
552-
net, blobs)
553-
554-
if args.use_nnpack:
555-
mutils.update_mobile_engines(net.Proto())
675+
input_net = ['data', 'im_info']
676+
models_pb['net'] = convert_to_pb(args, model.net.Proto(), blobs, input_blobs=input_net)
556677

557-
# generate init net
558-
empty_blobs = ['data', 'im_info']
559-
init_net = gen_init_net(net, blobs, empty_blobs)
678+
if cfg.MODEL.MASK_ON:
679+
models_pb['mask_net'] = convert_to_pb(args, model.mask_net.Proto(), blobs, part_name='mask_net')
560680

561-
if args.device == 'gpu':
562-
[net, init_net] = convert_model_gpu(args, net, init_net)
681+
if cfg.MODEL.KEYPOINTS_ON:
682+
models_pb['keypoint_net'] = convert_to_pb(args, model.keypoint_net.Proto(), blobs, part_name='keypoint_net')
563683

564-
net.Proto().name = args.net_name
565-
init_net.Proto().name = args.net_name + "_init"
684+
for (pb_net, pb_init_net) in models_pb.values():
685+
_save_models(pb_net, pb_init_net, args)
566686

567687
if args.test_img is not None:
568-
verify_model(args, [net, init_net], args.test_img)
569-
570-
_save_models(net, init_net, args)
571-
688+
verify_model(args, models_pb, args.test_img)
572689

573690
if __name__ == '__main__':
574691
main()

0 commit comments

Comments
 (0)