@@ -46,15 +46,15 @@ cdef class Network:
4646 cdef dn.network* _c_network
4747
4848 @staticmethod
49- def open (config_url , weights_url ):
49+ def open (config_url , weights_url , batch_size = 1 ):
5050 with fsspec_cache_open(config_url, mode = " rt" ) as config:
5151 with fsspec_cache_open(weights_url, mode = " rb" ) as weights:
52- return Network(config.name, weights.name)
52+ return Network(config.name, weights.name, batch_size )
5353
54- def __cinit__ (self , config_file , weights_file , clear = True , batch_size = 1 ):
54+ def __cinit__ (self , str config_file , str weights_file , int batch_size , bint clear = True ):
5555 self ._c_network = dn.load_network_custom(config_file.encode(),
5656 weights_file.encode(),
57- 1 if clear else 0 ,
57+ clear,
5858 batch_size)
5959 if self ._c_network is NULL :
6060 raise RuntimeError (" Failed to create the DarkNet Network..." )
@@ -146,7 +146,6 @@ cdef class Network:
146146 return rv
147147
148148 def detect_batch (self ,
149- int batch_size ,
150149 np.ndarray[dtype = np.float32_t, ndim = 1 , mode = " c" ] frames,
151150 frame_size = None ,
152151 float threshold = .5 ,
@@ -165,11 +164,20 @@ cdef class Network:
165164 imr.h = 0
166165 imr.data = < float * > frames.data
167166
167+ if frames.size % self .input_size() != 0 :
168+ raise TypeError (" The frames array is not divisible by network input size. "
169+ f" ({frames.size} % {self.input_size()} != 0)" )
170+
171+ num_frames = frames.size // self .input_size()
172+ if num_frames > self .batch_size:
173+ raise TypeError (" There are more frames than the configured batch size. "
174+ f" ({num_frames} > {self.batch_size})" )
175+
168176 cdef dn.det_num_pair* batch_detections
169177 batch_detections = dn.network_predict_batch(
170178 self ._c_network,
171179 imr,
172- batch_size ,
180+ num_frames ,
173181 pred_width,
174182 pred_height,
175183 threshold,
@@ -180,9 +188,9 @@ cdef class Network:
180188 )
181189 rv = [
182190 convert_detections_to_tuples(batch_detections[b].dets, batch_detections[b].num, nms_type, nms_threshold)
183- for b in range (batch_size )
191+ for b in range (num_frames )
184192 ]
185- dn.free_batch_detections(batch_detections, batch_size )
193+ dn.free_batch_detections(batch_detections, num_frames )
186194 return rv
187195
188196
0 commit comments