@@ -10,6 +10,26 @@ from .util import fsspec_cache_open
10
10
11
11
np.import_array()
12
12
13
+ cdef convert_detections_to_tuples(dn.detection* detections, int num_dets, str nms_type, float nms_threshold):
14
+ if nms_threshold > 0 and num_dets > 0 :
15
+ if nms_type == " obj" :
16
+ dn.do_nms_obj(detections, num_dets, detections[0 ].classes, nms_threshold)
17
+ elif nms_type == " sort" :
18
+ dn.do_nms_sort(detections, num_dets, detections[0 ].classes, nms_threshold)
19
+ else :
20
+ raise ValueError (f" non-maximum-suppression type {nms_type} is not one of {['obj', 'sort']}" )
21
+ rv = [
22
+ (j,
23
+ detections[i].prob[j],
24
+ (detections[i].bbox.x, detections[i].bbox.y, detections[i].bbox.w, detections[i].bbox.h)
25
+ )
26
+ for i in range (num_dets)
27
+ for j in range (detections[i].classes)
28
+ if detections[i].prob[j] > 0
29
+ ]
30
+ return sorted (rv, key = lambda x : x[1 ], reverse = True )
31
+
32
+
13
33
cdef class Metadata:
14
34
classes = [] # typing: List[AnyStr]
15
35
@@ -26,15 +46,16 @@ cdef class Network:
26
46
cdef dn.network* _c_network
27
47
28
48
@staticmethod
29
- def open (config_url , weights_url ):
49
+ def open (config_url , weights_url , batch_size = 1 ):
30
50
with fsspec_cache_open(config_url, mode = " rt" ) as config:
31
51
with fsspec_cache_open(weights_url, mode = " rb" ) as weights:
32
- return Network(config.name, weights.name)
33
-
52
+ return Network(config.name, weights.name, batch_size)
34
53
35
- def __cinit__ (self , config_file , weights_file ):
36
- clear = 1
37
- self ._c_network = dn.load_network(config_file.encode(), weights_file.encode(), clear)
54
+ def __cinit__ (self , str config_file , str weights_file , int batch_size , bint clear = True ):
55
+ self ._c_network = dn.load_network_custom(config_file.encode(),
56
+ weights_file.encode(),
57
+ clear,
58
+ batch_size)
38
59
if self ._c_network is NULL :
39
60
raise RuntimeError (" Failed to create the DarkNet Network..." )
40
61
@@ -43,10 +64,26 @@ cdef class Network:
43
64
dn.free_network(self ._c_network[0 ])
44
65
free(self ._c_network)
45
66
67
+ @property
68
+ def batch_size (self ):
69
+ return dn.network_batch_size(self ._c_network)
70
+
46
71
@property
47
72
def shape (self ):
48
73
return dn.network_width(self ._c_network), dn.network_height(self ._c_network)
49
74
75
+ @property
76
+ def width (self ):
77
+ return dn.network_width(self ._c_network)
78
+
79
+ @property
80
+ def height (self ):
81
+ return dn.network_height(self ._c_network)
82
+
83
+ @property
84
+ def depth (self ):
85
+ return dn.network_depth(self ._c_network)
86
+
50
87
def input_size (self ):
51
88
return dn.network_input_size(self ._c_network)
52
89
@@ -81,38 +118,79 @@ cdef class Network:
81
118
output_shape[0] = self.output_size()
82
119
return np.PyArray_SimpleNewFromData(1, output_shape , np.NPY_FLOAT32 , output )
83
120
84
- def detect(self , frame_size = None ,
85
- float threshold = .5 , float hierarchical_threshold = .5 ,
86
- int relative = 0 , int letterbox = 1 ,
87
- str nms_type = " sort" , float nms_threshold = .45 ,
121
+ def detect(self ,
122
+ frame_size = None ,
123
+ float threshold = .5 ,
124
+ float hierarchical_threshold = .5 ,
125
+ int relative = 0 ,
126
+ int letterbox = 1 ,
127
+ str nms_type = " sort" ,
128
+ float nms_threshold = .45 ,
88
129
):
89
- frame_size = self .shape if frame_size is None else frame_size
130
+ pred_width, pred_height = self .shape if frame_size is None else frame_size
131
+
90
132
cdef int num_dets = 0
91
133
cdef dn.detection* detections
92
-
93
134
detections = dn.get_network_boxes(self ._c_network,
94
- frame_size[0 ], frame_size[1 ],
95
- threshold, hierarchical_threshold,
135
+ pred_width,
136
+ pred_height,
137
+ threshold,
138
+ hierarchical_threshold,
96
139
< int * > 0 ,
97
140
relative,
98
141
& num_dets,
99
142
letterbox)
143
+ rv = convert_detections_to_tuples(detections, num_dets, nms_type, nms_threshold)
144
+ dn.free_detections(detections, num_dets)
100
145
101
- if nms_threshold > 0 and num_dets:
102
- if nms_type == " obj" :
103
- dn.do_nms_obj(detections, num_dets, detections[0 ].classes, nms_threshold)
104
- elif nms_type == " sort" :
105
- dn.do_nms_sort(detections, num_dets, detections[0 ].classes, nms_threshold)
106
- else :
107
- raise ValueError (f" non-maximum-suppression type {nms_type} is not one of {['obj', 'sort']}" )
146
+ return rv
108
147
148
+ def detect_batch (self ,
149
+ np.ndarray[dtype = np.float32_t, ndim = 1 , mode = " c" ] frames,
150
+ frame_size = None ,
151
+ float threshold = .5 ,
152
+ float hierarchical_threshold = .5 ,
153
+ int relative = 0 ,
154
+ int letterbox = 1 ,
155
+ str nms_type = " sort" ,
156
+ float nms_threshold = .45
157
+ ):
158
+ pred_width, pred_height = self .shape if frame_size is None else frame_size
159
+
160
+ cdef dn.image imr
161
+ # This looks awkward, but the batch predict *does not* use c, w, h.
162
+ imr.c = 0
163
+ imr.w = 0
164
+ imr.h = 0
165
+ imr.data = < float * > frames.data
166
+
167
+ if frames.size % self .input_size() != 0 :
168
+ raise TypeError (" The frames array is not divisible by network input size. "
169
+ f" ({frames.size} % {self.input_size()} != 0)" )
170
+
171
+ num_frames = frames.size // self .input_size()
172
+ if num_frames > self .batch_size:
173
+ raise TypeError (" There are more frames than the configured batch size. "
174
+ f" ({num_frames} > {self.batch_size})" )
175
+
176
+ cdef dn.det_num_pair* batch_detections
177
+ batch_detections = dn.network_predict_batch(
178
+ self ._c_network,
179
+ imr,
180
+ num_frames,
181
+ pred_width,
182
+ pred_height,
183
+ threshold,
184
+ hierarchical_threshold,
185
+ < int * > 0 ,
186
+ relative,
187
+ letterbox
188
+ )
109
189
rv = [
110
- (j, detections[i].prob[j],
111
- (detections[i].bbox.x, detections[i].bbox.y, detections[i].bbox.w, detections[i].bbox.h))
112
- for i in range (num_dets)
113
- for j in range (detections[i].classes)
114
- if detections[i].prob[j] > 0
190
+ convert_detections_to_tuples(batch_detections[b].dets, batch_detections[b].num, nms_type, nms_threshold)
191
+ for b in range (num_frames)
115
192
]
193
+ dn.free_batch_detections(batch_detections, num_frames)
194
+ return rv
195
+
116
196
117
- dn.free_detections(detections, num_dets)
118
- return sorted (rv, key = lambda x : x[1 ], reverse = True )
0 commit comments