@@ -10,6 +10,26 @@ from .util import fsspec_cache_open
1010
1111np.import_array()
1212
13+ cdef convert_detections_to_tuples(dn.detection* detections, int num_dets, float nms_threshold, str nms_type):
14+ if nms_threshold > 0 and num_dets > 0 :
15+ if nms_type == " obj" :
16+ dn.do_nms_obj(detections, num_dets, detections[0 ].classes, nms_threshold)
17+ elif nms_type == " sort" :
18+ dn.do_nms_sort(detections, num_dets, detections[0 ].classes, nms_threshold)
19+ else :
20+ raise ValueError (f" non-maximum-suppression type {nms_type} is not one of {['obj', 'sort']}" )
21+ rv = [
22+ (j,
23+ detections[i].prob[j],
24+ (detections[i].bbox.x, detections[i].bbox.y, detections[i].bbox.w, detections[i].bbox.h)
25+ )
26+ for i in range (num_dets)
27+ for j in range (detections[i].classes)
28+ if detections[i].prob[j] > 0
29+ ]
30+ return sorted (rv, key = lambda x : x[1 ], reverse = True )
31+
32+
1333cdef class Metadata:
1434 classes = [] # typing: List[AnyStr]
1535
@@ -31,7 +51,6 @@ cdef class Network:
3151 with fsspec_cache_open(weights_url, mode = " rb" ) as weights:
3252 return Network(config.name, weights.name)
3353
34-
3554 def __cinit__ (self , config_file , weights_file , clear = True , batch_size = 1 ):
3655 self ._c_network = dn.load_network_custom(config_file.encode(),
3756 weights_file.encode(),
@@ -83,38 +102,28 @@ cdef class Network:
83102 output_shape[0] = self.output_size()
84103 return np.PyArray_SimpleNewFromData(1, output_shape , np.NPY_FLOAT32 , output )
85104
86- def detect(self , frame_size = None ,
87- float threshold = .5 , float hierarchical_threshold = .5 ,
88- int relative = 0 , int letterbox = 1 ,
89- str nms_type = " sort" , float nms_threshold = .45 ,
105+ def detect(self ,
106+ frame_size = None ,
107+ float threshold = .5 ,
108+ float hierarchical_threshold = .5 ,
109+ int relative = 0 ,
110+ int letterbox = 1 ,
111+ str nms_type = " sort" ,
112+ float nms_threshold = .45 ,
90113 ):
91- frame_size = self .shape if frame_size is None else frame_size
114+ pred_width, pred_height = self .shape if frame_size is None else frame_size
115+
92116 cdef int num_dets = 0
93117 cdef dn.detection* detections
94-
95118 detections = dn.get_network_boxes(self ._c_network,
96- frame_size[0 ], frame_size[1 ],
97- threshold, hierarchical_threshold,
119+ pred_width,
120+ pred_height,
121+ threshold,
122+ hierarchical_threshold,
98123 < int * > 0 ,
99124 relative,
100125 & num_dets,
101126 letterbox)
102-
103- if nms_threshold > 0 and num_dets:
104- if nms_type == " obj" :
105- dn.do_nms_obj(detections, num_dets, detections[0 ].classes, nms_threshold)
106- elif nms_type == " sort" :
107- dn.do_nms_sort(detections, num_dets, detections[0 ].classes, nms_threshold)
108- else :
109- raise ValueError (f" non-maximum-suppression type {nms_type} is not one of {['obj', 'sort']}" )
110-
111- rv = [
112- (j, detections[i].prob[j],
113- (detections[i].bbox.x, detections[i].bbox.y, detections[i].bbox.w, detections[i].bbox.h))
114- for i in range (num_dets)
115- for j in range (detections[i].classes)
116- if detections[i].prob[j] > 0
117- ]
118-
127+ rv = convert_detections_to_tuples(detections, num_dets, nms_type, nms_threshold)
119128 dn.free_detections(detections, num_dets)
120129 return sorted (rv, key = lambda x : x[1 ], reverse = True )
0 commit comments