Skip to content

Commit 9fc00bf

Browse files
committed
Add error checking for when frames are not of the expected size, or number of frames is bigger than batch size.
1 parent 9d37862 commit 9fc00bf

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

src/darknet/py/network.pyx

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ cdef class Network:
4646
cdef dn.network* _c_network
4747

4848
@staticmethod
49-
def open(config_url, weights_url):
49+
def open(config_url, weights_url, batch_size=1):
5050
with fsspec_cache_open(config_url, mode="rt") as config:
5151
with fsspec_cache_open(weights_url, mode="rb") as weights:
52-
return Network(config.name, weights.name)
52+
return Network(config.name, weights.name, batch_size)
5353

54-
def __cinit__(self, config_file, weights_file, clear = True, batch_size = 1):
54+
def __cinit__(self, str config_file, str weights_file, int batch_size, bint clear=True):
5555
self._c_network = dn.load_network_custom(config_file.encode(),
5656
weights_file.encode(),
57-
1 if clear else 0,
57+
clear,
5858
batch_size)
5959
if self._c_network is NULL:
6060
raise RuntimeError("Failed to create the DarkNet Network...")
@@ -146,7 +146,6 @@ cdef class Network:
146146
return rv
147147

148148
def detect_batch(self,
149-
int batch_size,
150149
np.ndarray[dtype=np.float32_t, ndim=1, mode="c"] frames,
151150
frame_size=None,
152151
float threshold=.5,
@@ -165,11 +164,20 @@ cdef class Network:
165164
imr.h = 0
166165
imr.data = <float *> frames.data
167166

167+
if frames.size % self.input_size() != 0:
168+
raise TypeError("The frames array is not divisible by network input size. "
169+
f"({frames.size} % {self.input_size()} != 0)")
170+
171+
num_frames = frames.size // self.input_size()
172+
if num_frames > self.batch_size:
173+
raise TypeError("There are more frames than the configured batch size. "
174+
f"({num_frames} > {self.batch_size})")
175+
168176
cdef dn.det_num_pair* batch_detections
169177
batch_detections = dn.network_predict_batch(
170178
self._c_network,
171179
imr,
172-
batch_size,
180+
num_frames,
173181
pred_width,
174182
pred_height,
175183
threshold,
@@ -180,9 +188,9 @@ cdef class Network:
180188
)
181189
rv = [
182190
convert_detections_to_tuples(batch_detections[b].dets, batch_detections[b].num, nms_type, nms_threshold)
183-
for b in range(batch_size)
191+
for b in range(num_frames)
184192
]
185-
dn.free_batch_detections(batch_detections, batch_size)
193+
dn.free_batch_detections(batch_detections, num_frames)
186194
return rv
187195

188196

0 commit comments

Comments
 (0)