5151from detectron .utils .model_convert_utils import convert_op_in_proto
5252from detectron .utils .model_convert_utils import op_filter
5353import detectron .core .test_engine as test_engine
54+ import detectron .core .test as test
5455import detectron .utils .c2 as c2_utils
5556import detectron .utils .model_convert_utils as mutils
5657import detectron .utils .vis as vis_utils
58+ import detectron .utils .blob as blob_utils
59+ import detectron .utils .keypoints as keypoint_utils
60+ import pycocotools .mask as mask_utils
5761
5862c2_utils .import_contrib_ops ()
5963c2_utils .import_detectron_ops ()
@@ -315,14 +319,14 @@ def gen_init_net(net, blobs, empty_blobs):
315319def _save_image_graphs (args , all_net , all_init_net ):
316320 print ('Saving model graph...' )
317321 mutils .save_graph (
318- all_net .Proto (), os .path .join (args .out_dir , "model_def. png" ),
322+ all_net .Proto (), os .path .join (args .out_dir , all_net . Proto (). name + '. png' ),
319323 op_only = False )
320324 print ('Model def image saved to {}.' .format (args .out_dir ))
321325
322326
323327def _save_models (all_net , all_init_net , args ):
324328 print ('Writing converted model to {}...' .format (args .out_dir ))
325- fname = "model"
329+ fname = all_net . Proto (). name
326330
327331 if not os .path .exists (args .out_dir ):
328332 os .makedirs (args .out_dir )
@@ -380,13 +384,14 @@ def run_model_cfg(args, im, check_blobs):
380384 cls_boxes , cls_segms , cls_keyps = test_engine .im_detect_all (
381385 model , im , None , None ,
382386 )
383-
384- boxes , segms , keypoints , classes = vis_utils .convert_from_cls_format (
387+ boxes , segms , keypoints , classids = vis_utils .convert_from_cls_format (
385388 cls_boxes , cls_segms , cls_keyps )
386389
390+ segms = mask_utils .decode (segms ) if segms else None
391+
387392 # sort the results based on score for comparision
388- boxes , segms , keypoints , classes = _sort_results (
389- boxes , segms , keypoints , classes )
393+ boxes , segms , keypoints , classids = _sort_results (
394+ boxes , segms , keypoints , classids )
390395
391396 # write final results back to workspace
392397 def _ornone (res ):
@@ -395,12 +400,16 @@ def _ornone(res):
395400 workspace .FeedBlob (core .ScopedName ('result_boxes' ), _ornone (boxes ))
396401 workspace .FeedBlob (core .ScopedName ('result_segms' ), _ornone (segms ))
397402 workspace .FeedBlob (core .ScopedName ('result_keypoints' ), _ornone (keypoints ))
398- workspace .FeedBlob (core .ScopedName ('result_classids' ), _ornone (classes ))
403+ workspace .FeedBlob (core .ScopedName ('result_classids' ), _ornone (classids ))
399404
400405 # get result blobs
401406 with c2_utils .NamedCudaScope (0 ):
402407 ret = _get_result_blobs (check_blobs )
403408
409+ print ('result_boxes' , _ornone (boxes ))
410+ print ('result_segms' , _ornone (segms ))
411+ print ('result_keypoints' , _ornone (keypoints ))
412+ print ('result_classids' , _ornone (classids ))
404413 return ret
405414
406415
@@ -438,13 +447,13 @@ def _prepare_blobs(
438447 return blobs
439448
440449
441- def run_model_pb (args , net , init_net , im , check_blobs ):
450+ def run_model_pb (args , models_pb , im , check_blobs ):
442451 workspace .ResetWorkspace ()
452+ net , init_net = models_pb ['net' ]
443453 workspace .RunNetOnce (init_net )
444454 mutils .create_input_blobs_for_net (net .Proto ())
445455 workspace .CreateNet (net )
446456
447- # input_blobs, _ = core_test._get_blobs(im, None)
448457 input_blobs = _prepare_blobs (
449458 im ,
450459 cfg .PIXEL_MEANS ,
@@ -462,37 +471,137 @@ def run_model_pb(args, net, init_net, im, check_blobs):
462471 )
463472
464473 try :
465- workspace .RunNet (net . Proto (). name )
466- scores = workspace .FetchBlob ('score_nms' )
467- classids = workspace .FetchBlob ('class_nms' )
468- boxes = workspace .FetchBlob ('bbox_nms' )
474+ workspace .RunNet (net )
475+ scores = workspace .FetchBlob (core . ScopedName ( 'score_nms' ) )
476+ classids = workspace .FetchBlob (core . ScopedName ( 'class_nms' ) )
477+ boxes = workspace .FetchBlob (core . ScopedName ( 'bbox_nms' ) )
469478 except Exception as e :
470479 print ('Running pb model failed.\n {}' .format (e ))
471- # may not detect anything at all
472480 R = 0
473481 scores = np .zeros ((R ,), dtype = np .float32 )
474482 boxes = np .zeros ((R , 4 ), dtype = np .float32 )
475483 classids = np .zeros ((R ,), dtype = np .float32 )
476484
485+ cls_keyps , cls_segms = None , None
486+
487+ if 'keypoint_net' in models_pb :
488+ keypoint_net , init_keypoint_net = models_pb ['keypoint_net' ]
489+ workspace .RunNetOnce (init_keypoint_net )
490+ mutils .create_input_blobs_for_net (keypoint_net .Proto ())
491+ keypoint_net .Proto ().external_input .extend (['rpn_rois' , 'bbox_pred' , 'im_info' , 'cls_prob' ])
492+ workspace .CreateNet (keypoint_net )
493+
494+ im_scale = input_blobs ['im_info' ][0 ][2 ]
495+ input_blobs = {'keypoint_rois' : test ._get_rois_blob (boxes , im_scale )}
496+
497+ # Add multi-level rois for FPN
498+ if cfg .FPN .MULTILEVEL_ROIS :
499+ test ._add_multilevel_rois_for_test (input_blobs , 'keypoint_rois' )
500+
501+ gpu_blobs = []
502+ if args .device == 'gpu' :
503+ gpu_blobs = ['data' ]
504+ for k , v in list (input_blobs .items ()):
505+ workspace .FeedBlob (
506+ core .ScopedName (k ),
507+ v ,
508+ mutils .get_device_option_cuda () if k in gpu_blobs else
509+ mutils .get_device_option_cpu ()
510+ )
511+
512+ try :
513+ workspace .RunNet (keypoint_net )
514+ pred_heatmaps = workspace .FetchBlob (core .ScopedName ('kps_score' )).squeeze ()
515+ # In case of 1
516+ if pred_heatmaps .ndim == 3 :
517+ pred_heatmaps = np .expand_dims (pred_heatmaps , axis = 0 )
518+ except Exception as e :
519+ print ('Running pb model failed.\n {}' .format (e ))
520+ R , M = 0 , cfg .KRCNN .HEATMAP_SIZE
521+ pred_heatmaps = np .zeros ((R , cfg .KRCNN .NUM_KEYPOINTS , M , M ), np .float32 )
522+
523+ xy_preds = keypoint_utils .heatmaps_to_keypoints (pred_heatmaps , boxes )
524+ cls_keyps = [[] for _ in range (cfg .MODEL .NUM_CLASSES )]
525+ cls_keyps [1 ] = [xy_preds [i ] for i in range (xy_preds .shape [0 ])]
526+
527+ if 'mask_net' in models_pb :
528+ mask_net , init_mask_net = models_pb ['mask_net' ]
529+ workspace .RunNetOnce (init_mask_net )
530+ mutils .create_input_blobs_for_net (mask_net .Proto ())
531+ mask_net .Proto ().external_input .extend (['rpn_rois' , 'bbox_pred' , 'im_info' , 'cls_prob' ])
532+ workspace .CreateNet (mask_net )
533+
534+ im_scale = input_blobs ['im_info' ][0 ][2 ]
535+ input_blobs = {'mask_rois' : test ._get_rois_blob (boxes , im_scale )}
536+
537+ # Add multi-level rois for FPN
538+ if cfg .FPN .MULTILEVEL_ROIS :
539+ test ._add_multilevel_rois_for_test (input_blobs , 'mask_rois' )
540+
541+ gpu_blobs = []
542+ if args .device == 'gpu' :
543+ gpu_blobs = ['data' ]
544+ for k , v in list (input_blobs .items ()):
545+ workspace .FeedBlob (
546+ core .ScopedName (k ),
547+ v ,
548+ mutils .get_device_option_cuda () if k in gpu_blobs else
549+ mutils .get_device_option_cpu ()
550+ )
551+ M = cfg .MRCNN .RESOLUTION
552+ try :
553+ workspace .RunNet (mask_net )
554+ # Fetch masks
555+ pred_masks = workspace .FetchBlob (core .ScopedName ('mask_fcn_probs' )).squeeze ()
556+ if cfg .MRCNN .CLS_SPECIFIC_MASK :
557+ pred_masks = pred_masks .reshape ([- 1 , cfg .MODEL .NUM_CLASSES , M , M ])
558+ else :
559+ pred_masks = pred_masks .reshape ([- 1 , 1 , M , M ])
560+ except Exception as e :
561+ print ('Running pb model failed.\n {}' .format (e ))
562+ R = 0
563+ if cfg .MRCNN .CLS_SPECIFIC_MASK :
564+ pred_masks = np .zeros ((R , cfg .MODEL .NUM_CLASSES , M , M ), dtype = np .float32 )
565+ else :
566+ pred_masks = np .zeros ((R , 1 , M , M ), dtype = np .float32 )
567+
568+ cls_boxes = [np .empty (list (classids ).count (i )) for i in range (cfg .MODEL .NUM_CLASSES )]
569+ cls_segms = test .segm_results (cls_boxes , pred_masks , boxes , im .shape [0 ], im .shape [1 ])
570+
477571 boxes = np .column_stack ((boxes , scores ))
478572
573+ _ , segms , keypoints , _ = vis_utils .convert_from_cls_format ([], cls_segms , cls_keyps )
574+ segms = mask_utils .decode (segms ) if segms else None
575+
479576 # sort the results based on score for comparision
480- boxes , _ , _ , classids = _sort_results (
481- boxes , None , None , classids )
577+ boxes , segms , keypoints , classids = _sort_results (
578+ boxes , segms , keypoints , classids )
482579
483580 # write final result back to workspace
484- workspace .FeedBlob ('result_boxes' , boxes )
485- workspace .FeedBlob ('result_classids' , classids )
581+ def _ornone (res ):
582+ return np .array (res ) if res is not None else np .array ([], dtype = np .float32 )
583+ workspace .FeedBlob (core .ScopedName ('result_boxes' ), _ornone (boxes ))
584+ workspace .FeedBlob (core .ScopedName ('result_classids' ), _ornone (classids ))
585+ workspace .FeedBlob (core .ScopedName ('result_segms' ), _ornone (segms ))
586+ workspace .FeedBlob (core .ScopedName ('result_keypoints' ), _ornone (keypoints ))
486587
487588 ret = _get_result_blobs (check_blobs )
488589
590+ print ('result_boxes' , _ornone (boxes ))
591+ print ('result_segms' , _ornone (segms ))
592+ print ('result_keypoints' , _ornone (keypoints ))
593+ print ('result_classids' , _ornone (classids ))
489594 return ret
490595
491596
492- def verify_model (args , model_pb , test_img_file ):
493- check_blobs = [
494- "result_boxes" , "result_classids" , # result
495- ]
597+ def verify_model (args , models_pb , test_img_file ):
598+ check_blobs = ['result_boxes' , 'result_classids' ]
599+
600+ if cfg .MODEL .MASK_ON :
601+ check_blobs .append ('result_segms' )
602+
603+ if cfg .MODEL .KEYPOINTS_ON :
604+ check_blobs .append ('result_keypoints' )
496605
497606 print ('Loading test file {}...' .format (test_img_file ))
498607 test_img = cv2 .imread (test_img_file )
@@ -502,13 +611,49 @@ def _run_cfg_func(im, blobs):
502611 return run_model_cfg (args , im , check_blobs )
503612
504613 def _run_pb_func (im , blobs ):
505- return run_model_pb (args , model_pb [ 0 ], model_pb [ 1 ] , im , check_blobs )
614+ return run_model_pb (args , models_pb , im , check_blobs )
506615
507616 print ('Checking models...' )
508617 assert mutils .compare_model (
509618 _run_cfg_func , _run_pb_func , test_img , check_blobs )
510619
511620
621+ def convert_to_pb (args , net , blobs , part_name = 'net' , input_blobs = []):
622+ pb_net = core .Net ('' )
623+ pb_net .Proto ().op .extend (copy .deepcopy (net .op ))
624+
625+ pb_net .Proto ().external_input .extend (
626+ copy .deepcopy (net .external_input ))
627+ pb_net .Proto ().external_output .extend (
628+ copy .deepcopy (net .external_output ))
629+ pb_net .Proto ().type = args .net_execution_type
630+ pb_net .Proto ().num_workers = 1 if args .net_execution_type == 'simple' else 4
631+
632+ # Reset the device_option, change to unscope name and replace python operators
633+ convert_net (args , pb_net .Proto (), blobs )
634+
635+ # add operators for bbox
636+ add_bbox_ops (args , pb_net , blobs )
637+
638+ if args .fuse_af :
639+ print ('Fusing affine channel...' )
640+ pb_net , blobs = mutils .fuse_net_affine (pb_net , blobs )
641+
642+ if args .use_nnpack :
643+ mutils .update_mobile_engines (pb_net .Proto ())
644+
645+ # generate init net
646+ pb_init_net = gen_init_net (pb_net , blobs , input_blobs )
647+
648+ if args .device == 'gpu' :
649+ [pb_net , pb_init_net ] = convert_model_gpu (args , pb_net , pb_init_net )
650+
651+ pb_net .Proto ().name = args .net_name + '_' + part_name
652+ pb_init_net .Proto ().name = args .net_name + '_' + part_name + '_init'
653+
654+ return pb_net , pb_init_net
655+
656+
512657def main ():
513658 workspace .GlobalInit (['caffe2' , '--caffe2_log_level=0' ])
514659 args = parse_args ()
@@ -523,52 +668,24 @@ def main():
523668 logger .info ('Conerting model with config:' )
524669 logger .info (pprint .pformat (cfg ))
525670
526- assert not cfg .MODEL .KEYPOINTS_ON , "Keypoint model not supported."
527- assert not cfg .MODEL .MASK_ON , "Mask model not supported."
528- assert not cfg .FPN .FPN_ON , "FPN not supported."
529- assert not cfg .RETINANET .RETINANET_ON , "RetinaNet model not supported."
530-
671+ models_pb = {}
531672 # load model from cfg
532673 model , blobs = load_model (args )
533674
534- net = core .Net ('' )
535- net .Proto ().op .extend (copy .deepcopy (model .net .Proto ().op ))
536- net .Proto ().external_input .extend (
537- copy .deepcopy (model .net .Proto ().external_input ))
538- net .Proto ().external_output .extend (
539- copy .deepcopy (model .net .Proto ().external_output ))
540- net .Proto ().type = args .net_execution_type
541- net .Proto ().num_workers = 1 if args .net_execution_type == 'simple' else 4
542-
543- # Reset the device_option, change to unscope name and replace python operators
544- convert_net (args , net .Proto (), blobs )
545-
546- # add operators for bbox
547- add_bbox_ops (args , net , blobs )
548-
549- if args .fuse_af :
550- print ('Fusing affine channel...' )
551- net , blobs = mutils .fuse_net_affine (
552- net , blobs )
553-
554- if args .use_nnpack :
555- mutils .update_mobile_engines (net .Proto ())
675+ input_net = ['data' , 'im_info' ]
676+ models_pb ['net' ] = convert_to_pb (args , model .net .Proto (), blobs , input_blobs = input_net )
556677
557- # generate init net
558- empty_blobs = ['data' , 'im_info' ]
559- init_net = gen_init_net (net , blobs , empty_blobs )
678+ if cfg .MODEL .MASK_ON :
679+ models_pb ['mask_net' ] = convert_to_pb (args , model .mask_net .Proto (), blobs , part_name = 'mask_net' )
560680
561- if args . device == 'gpu' :
562- [ net , init_net ] = convert_model_gpu (args , net , init_net )
681+ if cfg . MODEL . KEYPOINTS_ON :
682+ models_pb [ 'keypoint_net' ] = convert_to_pb (args , model . keypoint_net . Proto (), blobs , part_name = 'keypoint_net' )
563683
564- net . Proto (). name = args . net_name
565- init_net . Proto (). name = args . net_name + "_init"
684+ for ( pb_net , pb_init_net ) in models_pb . values ():
685+ _save_models ( pb_net , pb_init_net , args )
566686
567687 if args .test_img is not None :
568- verify_model (args , [net , init_net ], args .test_img )
569-
570- _save_models (net , init_net , args )
571-
688+ verify_model (args , models_pb , args .test_img )
572689
573690if __name__ == '__main__' :
574691 main ()
0 commit comments