Skip to content

Commit 1d3ffc4

Browse files
committed
Refactor NMS and detections to tuple array convertion.
1 parent c8976c2 commit 1d3ffc4

File tree

1 file changed

+35
-26
lines changed

1 file changed

+35
-26
lines changed

src/darknet/py/network.pyx

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@ from .util import fsspec_cache_open
1010

1111
np.import_array()
1212

13+
cdef convert_detections_to_tuples(dn.detection* detections, int num_dets, float nms_threshold, str nms_type):
14+
if nms_threshold > 0 and num_dets > 0:
15+
if nms_type == "obj":
16+
dn.do_nms_obj(detections, num_dets, detections[0].classes, nms_threshold)
17+
elif nms_type == "sort":
18+
dn.do_nms_sort(detections, num_dets, detections[0].classes, nms_threshold)
19+
else:
20+
raise ValueError(f"non-maximum-suppression type {nms_type} is not one of {['obj', 'sort']}")
21+
rv = [
22+
(j,
23+
detections[i].prob[j],
24+
(detections[i].bbox.x, detections[i].bbox.y, detections[i].bbox.w, detections[i].bbox.h)
25+
)
26+
for i in range(num_dets)
27+
for j in range(detections[i].classes)
28+
if detections[i].prob[j] > 0
29+
]
30+
return sorted(rv, key=lambda x: x[1], reverse=True)
31+
32+
1333
cdef class Metadata:
1434
classes = [] # typing: List[AnyStr]
1535

@@ -31,7 +51,6 @@ cdef class Network:
3151
with fsspec_cache_open(weights_url, mode="rb") as weights:
3252
return Network(config.name, weights.name)
3353

34-
3554
def __cinit__(self, config_file, weights_file, clear = True, batch_size = 1):
3655
self._c_network = dn.load_network_custom(config_file.encode(),
3756
weights_file.encode(),
@@ -83,38 +102,28 @@ cdef class Network:
83102
output_shape[0] = self.output_size()
84103
return np.PyArray_SimpleNewFromData(1, output_shape, np.NPY_FLOAT32, output)
85104

86-
def detect(self, frame_size=None,
87-
float threshold=.5, float hierarchical_threshold=.5,
88-
int relative=0, int letterbox=1,
89-
str nms_type="sort", float nms_threshold=.45,
105+
def detect(self,
106+
frame_size=None,
107+
float threshold=.5,
108+
float hierarchical_threshold=.5,
109+
int relative=0,
110+
int letterbox=1,
111+
str nms_type="sort",
112+
float nms_threshold=.45,
90113
):
91-
frame_size = self.shape if frame_size is None else frame_size
114+
pred_width, pred_height = self.shape if frame_size is None else frame_size
115+
92116
cdef int num_dets = 0
93117
cdef dn.detection* detections
94-
95118
detections = dn.get_network_boxes(self._c_network,
96-
frame_size[0], frame_size[1],
97-
threshold, hierarchical_threshold,
119+
pred_width,
120+
pred_height,
121+
threshold,
122+
hierarchical_threshold,
98123
<int*>0,
99124
relative,
100125
&num_dets,
101126
letterbox)
102-
103-
if nms_threshold > 0 and num_dets:
104-
if nms_type == "obj":
105-
dn.do_nms_obj(detections, num_dets, detections[0].classes, nms_threshold)
106-
elif nms_type == "sort":
107-
dn.do_nms_sort(detections, num_dets, detections[0].classes, nms_threshold)
108-
else:
109-
raise ValueError(f"non-maximum-suppression type {nms_type} is not one of {['obj', 'sort']}")
110-
111-
rv = [
112-
(j, detections[i].prob[j],
113-
(detections[i].bbox.x, detections[i].bbox.y, detections[i].bbox.w, detections[i].bbox.h))
114-
for i in range(num_dets)
115-
for j in range(detections[i].classes)
116-
if detections[i].prob[j] > 0
117-
]
118-
127+
rv = convert_detections_to_tuples(detections, num_dets, nms_type, nms_threshold)
119128
dn.free_detections(detections, num_dets)
120129
return sorted(rv, key=lambda x: x[1], reverse=True)

0 commit comments

Comments
 (0)