Skip to content

Commit 73822d7

Browse files
committed
Support batch_size when loading the neural network
1 parent 5e0bde7 commit 73822d7

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/darknet/py/network.pyx

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ cdef class Network:
3232
return Network(config.name, weights.name)
3333

3434

35-
def __cinit__(self, config_file, weights_file):
36-
clear = 1
37-
self._c_network = dn.load_network(config_file.encode(), weights_file.encode(), clear)
35+
def __cinit__(self, config_file, weights_file, clear = True, batch_size = 1):
36+
self._c_network = dn.load_network_custom(config_file.encode(),
37+
weights_file.encode(),
38+
1 if clear else 0,
39+
batch_size)
3840
if self._c_network is NULL:
3941
raise RuntimeError("Failed to create the DarkNet Network...")
4042

src/libdarknet/__init__.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ cdef extern from "darknet.h":
5959
pass
6060

6161
network* load_network(char* cfg_filename, char* weights_filename, int clear)
62+
network* load_network_custom(char* cfg_filename, char* weights_filename, int clear, int batch_size)
6263
void free_network(network self)
6364

6465
int network_width(network *self);

0 commit comments

Comments
 (0)