Skip to content

Commit 2031c23

Browse files
authored
Merge pull request #7 from zeroae/f/override-batch-size
Support batch_size when loading the neural network
2 parents 5e0bde7 + 91dbf70 commit 2031c23

File tree

5 files changed

+368
-34
lines changed

5 files changed

+368
-34
lines changed

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ dependencies:
2323
- click-plugins
2424
- darknet
2525
- entrypoints
26-
- fsspec
26+
- fsspec <=0.7.5
2727
- numpy
2828
- pillow
2929

examples/batch.ipynb

Lines changed: 238 additions & 0 deletions
Large diffs are not rendered by default.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"click>=7.0",
1515
"click-plugins",
1616
"entrypoints",
17-
"fsspec",
17+
"fsspec <=0.7.5",
1818
"numpy",
1919
"pillow",
2020
# fmt: on

src/darknet/py/network.pyx

Lines changed: 106 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@ from .util import fsspec_cache_open
1010

1111
np.import_array()
1212

13+
cdef convert_detections_to_tuples(dn.detection* detections, int num_dets, str nms_type, float nms_threshold):
    """
    Convert a C array of darknet detections into a list of Python tuples.

    If ``nms_threshold`` is positive and at least one detection is present,
    non-maximum suppression is first applied in place through
    ``dn.do_nms_obj`` or ``dn.do_nms_sort``, selected by ``nms_type``.

    :param detections: pointer to the first of ``num_dets`` detections.
    :param num_dets: number of entries readable through ``detections``.
    :param nms_type: ``"obj"`` or ``"sort"``; only inspected when NMS runs.
    :param nms_threshold: threshold forwarded to the NMS routine; a value
        ``<= 0`` disables NMS.
    :return: list of ``(class_index, prob, (x, y, w, h))`` tuples, one for
        every (detection, class) pair whose ``prob`` is greater than zero,
        ordered by ``prob`` from highest to lowest.
    :raises ValueError: if NMS runs and ``nms_type`` is not ``"obj"`` or
        ``"sort"``.

    This function does not free ``detections``; that is left to the caller.
    """
    cdef int det_idx
    cdef int class_idx

    # detections[0] is only dereferenced when there is at least one entry.
    if num_dets > 0 and nms_threshold > 0:
        if nms_type == "sort":
            dn.do_nms_sort(detections, num_dets, detections[0].classes, nms_threshold)
        elif nms_type == "obj":
            dn.do_nms_obj(detections, num_dets, detections[0].classes, nms_threshold)
        else:
            raise ValueError(f"non-maximum-suppression type {nms_type} is not one of {['obj', 'sort']}")

    results = []
    for det_idx in range(num_dets):
        for class_idx in range(detections[det_idx].classes):
            # Only classes with a strictly positive prob are reported.
            if detections[det_idx].prob[class_idx] > 0:
                box = (
                    detections[det_idx].bbox.x,
                    detections[det_idx].bbox.y,
                    detections[det_idx].bbox.w,
                    detections[det_idx].bbox.h,
                )
                results.append((class_idx, detections[det_idx].prob[class_idx], box))

    # Stable sort on prob, highest first.
    results.sort(key=lambda entry: entry[1], reverse=True)
    return results
31+
32+
1333
cdef class Metadata:
1434
classes = [] # typing: List[AnyStr]
1535

@@ -26,15 +46,16 @@ cdef class Network:
2646
cdef dn.network* _c_network
2747

2848
@staticmethod
29-
def open(config_url, weights_url):
49+
def open(config_url, weights_url, batch_size=1):
3050
with fsspec_cache_open(config_url, mode="rt") as config:
3151
with fsspec_cache_open(weights_url, mode="rb") as weights:
32-
return Network(config.name, weights.name)
33-
52+
return Network(config.name, weights.name, batch_size)
3453

35-
def __cinit__(self, config_file, weights_file):
36-
clear = 1
37-
self._c_network = dn.load_network(config_file.encode(), weights_file.encode(), clear)
54+
def __cinit__(self, str config_file, str weights_file, int batch_size, bint clear=True):
55+
self._c_network = dn.load_network_custom(config_file.encode(),
56+
weights_file.encode(),
57+
clear,
58+
batch_size)
3859
if self._c_network is NULL:
3960
raise RuntimeError("Failed to create the DarkNet Network...")
4061

@@ -43,10 +64,26 @@ cdef class Network:
4364
dn.free_network(self._c_network[0])
4465
free(self._c_network)
4566

67+
@property
68+
def batch_size(self):
69+
return dn.network_batch_size(self._c_network)
70+
4671
@property
4772
def shape(self):
4873
return dn.network_width(self._c_network), dn.network_height(self._c_network)
4974

75+
@property
76+
def width(self):
77+
return dn.network_width(self._c_network)
78+
79+
@property
80+
def height(self):
81+
return dn.network_height(self._c_network)
82+
83+
@property
84+
def depth(self):
85+
return dn.network_depth(self._c_network)
86+
5087
def input_size(self):
5188
return dn.network_input_size(self._c_network)
5289

@@ -81,38 +118,79 @@ cdef class Network:
81118
output_shape[0] = self.output_size()
82119
return np.PyArray_SimpleNewFromData(1, output_shape, np.NPY_FLOAT32, output)
83120

84-
def detect(self,
           frame_size=None,
           float threshold=.5,
           float hierarchical_threshold=.5,
           int relative=0,
           int letterbox=1,
           str nms_type="sort",
           float nms_threshold=.45,
           ):
    """
    Collect the network's detection boxes and return them as Python tuples.

    The boxes are obtained with ``dn.get_network_boxes`` and converted by
    ``convert_detections_to_tuples``.
    NOTE(review): presumably requires a prior forward pass (predict call)
    on this network -- confirm against darknet.

    :param frame_size: ``(width, height)`` passed to darknet as the box
        reference size; defaults to ``self.shape`` (the network width and
        height) when ``None``.
    :param threshold: forwarded as ``thresh`` to ``get_network_boxes``.
    :param hierarchical_threshold: forwarded as ``hier_thresh``.
    :param relative: forwarded as ``relative``.
    :param letterbox: forwarded as ``letter``.
    :param nms_type: ``"obj"`` or ``"sort"``; selects the NMS routine.
    :param nms_threshold: NMS threshold; ``<= 0`` disables NMS.
    :return: list of ``(class_index, prob, (x, y, w, h))`` tuples sorted by
        ``prob`` in descending order.
    :raises ValueError: if NMS runs and ``nms_type`` is not supported.
    """
    cdef int num_dets = 0
    cdef dn.detection* detections

    pred_width, pred_height = self.shape if frame_size is None else frame_size

    detections = dn.get_network_boxes(self._c_network,
                                      pred_width,
                                      pred_height,
                                      threshold,
                                      hierarchical_threshold,
                                      <int*>0,
                                      relative,
                                      &num_dets,
                                      letterbox)
    # The conversion may raise (e.g. unknown nms_type); the C detection
    # array must be released on that path too, otherwise it leaks.
    try:
        return convert_detections_to_tuples(detections, num_dets, nms_type, nms_threshold)
    finally:
        dn.free_detections(detections, num_dets)
108147

148+
def detect_batch(self,
                 np.ndarray[dtype=np.float32_t, ndim=1, mode="c"] frames,
                 frame_size=None,
                 float threshold=.5,
                 float hierarchical_threshold=.5,
                 int relative=0,
                 int letterbox=1,
                 str nms_type="sort",
                 float nms_threshold=.45
                 ):
    """
    Run batch prediction over several frames and return per-frame detections.

    :param frames: flat, C-contiguous float32 array holding the frames back
        to back; its size must be a multiple of ``self.input_size()`` and
        must not hold more frames than ``self.batch_size``.
        NOTE(review): the per-frame memory layout expected by darknet is
        not visible here -- confirm against ``network_predict_batch``.
    :param frame_size: ``(width, height)`` forwarded to darknet; defaults to
        ``self.shape`` when ``None``.
    :param threshold: forwarded as ``thresh`` to ``network_predict_batch``.
    :param hierarchical_threshold: forwarded as ``hier_thresh``.
    :param relative: forwarded as ``relative``.
    :param letterbox: forwarded as ``letter``.
    :param nms_type: ``"obj"`` or ``"sort"``; selects the NMS routine.
    :param nms_threshold: NMS threshold; ``<= 0`` disables NMS.
    :return: a list with one entry per frame; each entry is a list of
        ``(class_index, prob, (x, y, w, h))`` tuples sorted by ``prob`` in
        descending order.
    :raises TypeError: if ``frames.size`` is not divisible by the network
        input size, or if there are more frames than the batch size.
    :raises ValueError: if NMS runs and ``nms_type`` is not supported.
    """
    cdef dn.image imr
    cdef dn.det_num_pair* batch_detections

    pred_width, pred_height = self.shape if frame_size is None else frame_size

    # This looks awkward, but the batch predict *does not* use c, w, h.
    imr.c = 0
    imr.w = 0
    imr.h = 0
    imr.data = <float *> frames.data

    if frames.size % self.input_size() != 0:
        raise TypeError("The frames array is not divisible by network input size. "
                        f"({frames.size} % {self.input_size()} != 0)")

    num_frames = frames.size // self.input_size()
    if num_frames > self.batch_size:
        raise TypeError("There are more frames than the configured batch size. "
                        f"({num_frames} > {self.batch_size})")

    batch_detections = dn.network_predict_batch(
        self._c_network,
        imr,
        num_frames,
        pred_width,
        pred_height,
        threshold,
        hierarchical_threshold,
        <int*>0,
        relative,
        letterbox
    )
    # Converting any frame may raise (e.g. unknown nms_type); the batch
    # result must be released on that path too, otherwise it leaks.
    try:
        return [
            convert_detections_to_tuples(batch_detections[b].dets, batch_detections[b].num, nms_type, nms_threshold)
            for b in range(num_frames)
        ]
    finally:
        dn.free_batch_detections(batch_detections, num_frames)
195+
116196

117-
dn.free_detections(detections, num_dets)
118-
return sorted(rv, key=lambda x: x[1], reverse=True)

src/libdarknet/__init__.pxd

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,19 @@ cdef extern from "darknet.h":
55
/*
66
* darknet.h forgot to extern some useful network functions
77
*/
8+
static int network_depth(network* net) {
9+
return net->c;
10+
}
11+
static int network_batch_size(network* net) {
12+
return net->batch;
13+
}
814
static int network_input_size(network* net) {
9-
return net->layers[0].inputs;
15+
return net->layers[0].inputs;
1016
}
1117
static int network_output_size(network* net) {
12-
int i;
13-
for(i = net->n-1; i > 0; --i) if(net->layers[i].type != COST) break;
14-
return net->layers[i].outputs;
18+
int i;
19+
for(i = net->n-1; i > 0; --i) if(net->layers[i].type != COST) break;
20+
return net->layers[i].outputs;
1521
}
1622
"""
1723

@@ -51,6 +57,13 @@ cdef extern from "darknet.h":
5157

5258
void free_detections(detection* detections, int len)
5359

60+
ctypedef struct det_num_pair:
61+
int num;
62+
detection* dets;
63+
64+
void free_batch_detections(det_num_pair* det_num_pairs, int len)
65+
66+
5467
void do_nms_sort(detection* detections, int len, int num_classes, float thresh)
5568
void do_nms_obj(detection* detections, int len, int num_classes, float thresh)
5669

@@ -59,14 +72,19 @@ cdef extern from "darknet.h":
5972
pass
6073

6174
network* load_network(char* cfg_filename, char* weights_filename, int clear)
75+
network* load_network_custom(char* cfg_filename, char* weights_filename, int clear, int batch_size)
6276
void free_network(network self)
6377

78+
int network_batch_size(network *self);
6479
int network_width(network *self);
6580
int network_height(network *self);
81+
int network_depth(network *self);
6682
int network_input_size(network* self);
6783
int network_output_size(network* self);
6884
float* network_predict(network, float* input)
6985
float* network_predict_image(network*, image)
86+
7087
detection* get_network_boxes(network* self, int width, int height, float thresh, float hier_thresh, int* map, int relative, int* out_len, int letter)
88+
det_num_pair* network_predict_batch(network* self, image, int batch_size, int width, int height, float thresh, float hier_thresh, int* map, int relative, int letter)
7189

7290

0 commit comments

Comments
 (0)