-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcompute_flops.py
125 lines (96 loc) · 4.19 KB
/
compute_flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def _dims(tensor):
    """Return up to three non-batch dimensions of *tensor* as a plain list.

    Reads ``tensor.shape[1:4]``: the leading (batch) axis is dropped and at
    most three further axes are kept. Unknown dimensions stay ``None``.
    """
    dims = tensor.shape[1:4]
    # tf.TensorShape slices expose as_list(); plain tuple shapes do not.
    if hasattr(dims, "as_list"):
        return dims.as_list()
    return list(dims)


def net_flops(model, table=False):
    """Estimate, print and return the FLOPs and MACCs of a Keras model.

    Every layer is classified by substring-matching ``str(layer)``, and a
    closed-form cost is computed from its input/output shapes. Shapes are
    assumed channels-last, i.e. ``(H, W, C)``; TODO confirm for your model.

    * Add / Maximum / Concatenate: ``(n_inputs - 1) * prod(input_dims)``
    * Average (merge layer, not pooling): ``n_inputs * prod(input_dims)``
    * Pooling (non-global): ``H_out * W_out * kh * kw * C``
    * Global pooling: ``prod(input_dims)``
    * Dense: ``2 * in_features * prod(output_dims)``
    * Conv2D: ``2 * kh * kw * C_in * filters * H_out * W_out``
    * DepthwiseConv2D: ``2 * kh * kw * C_out * H_out * W_out``

    MACCs are counted for Dense and convolution layers only, as half of
    their FLOPs. BatchNormalization, activation, padding, reshape, flatten,
    input and SeparableConv2D layers contribute 0.

    Args:
        model: Keras model. This function reads ``model.layers`` and, per
            layer, ``name``, ``input``, ``output`` and, where relevant,
            ``strides``, ``pool_size``, ``kernel_size`` and ``filters``.
        table: If true, print one row per layer (name, input/output shape,
            kernel size, filters, strides, FLOPs) before the totals.

    Returns:
        ``(total_flops, total_maccs)``. ``total_flops`` is in units of
        10^6; ``total_maccs`` is a raw count. These are the same numbers
        that are printed.

    Raises:
        ValueError: A Dense layer's input size is unknown and no earlier
            Flatten or global-pooling layer provided a feature-vector length.
    """
    import math

    factor = 1000000  # FLOPs are reported in units of 10^6
    t_flops = 0
    t_macc = 0
    # Length of the most recent flattened / globally pooled feature vector.
    # Dense falls back to it when its static input size is unknown.
    out_vec = None

    if table:
        print('%25s | %16s | %16s | %16s | %16s | %6s | %6s' % (
            'Layer Name', 'Input Shape', 'Output Shape', 'Kernel Size',
            'Filters', 'Strides', 'FLOPS'))
        print('-' * 130)

    for l in model.layers:
        # Placeholder values shown in the table for layers that don't set them.
        o_shape, i_shape = ['', '', ''], ['', '', '']
        strides, ks, filters = [1, 1], [0, 0], [0, 0]
        flops = 0
        macc = 0
        desc = str(l)  # matched against layer-type substrings below

        if 'InputLayer' in desc:
            i_shape = _dims(l.input)
            o_shape = i_shape

        if 'Reshape' in desc:
            i_shape = _dims(l.input)
            o_shape = _dims(l.output)

        if 'Add' in desc or 'Maximum' in desc or 'Concatenate' in desc:
            n_in = len(l.input)
            in_dims = _dims(l.input[0])
            i_shape = in_dims + [n_in]
            o_shape = _dims(l.output)
            # Combining n inputs element-wise takes n - 1 binary ops per element.
            flops = (n_in - 1) * math.prod(in_dims)

        if 'Average' in desc and 'pool' not in desc:
            n_in = len(l.input)
            in_dims = _dims(l.input[0])
            i_shape = in_dims + [n_in]
            o_shape = _dims(l.output)
            flops = n_in * math.prod(in_dims)

        if 'BatchNormalization' in desc:
            # Shapes are recorded for the table; the cost is not counted.
            i_shape = _dims(l.input)
            o_shape = _dims(l.output)

        if 'Activation' in desc or 'activation' in desc:
            # Shapes are recorded for the table; the cost is not counted.
            i_shape = _dims(l.input)
            o_shape = _dims(l.output)

        if 'pool' in desc and 'Global' not in desc:
            i_shape = _dims(l.input)
            o_shape = _dims(l.output)
            strides = l.strides
            ks = l.pool_size
            # Each output position reads a kh x kw window in every channel.
            flops = o_shape[0] * o_shape[1] * ks[0] * ks[1] * i_shape[2]

        if 'Flatten' in desc:
            i_shape = _dims(l.input)
            out_vec = math.prod(i_shape)
            o_shape = out_vec

        if 'Dense' in desc:
            # Dense acts on the last axis, so that is its input feature size.
            in_features = _dims(l.input)[-1]
            if in_features is None:
                in_features = out_vec
            if in_features is None:
                raise ValueError(
                    'Cannot infer the input size of Dense layer %r' % l.name)
            i_shape = in_features
            o_shape = _dims(l.output)
            # One multiply-accumulate per (input feature, output element) pair.
            macc = in_features * math.prod(o_shape)
            flops = 2 * macc

        if 'Padding' in desc:
            flops = 0

        if 'Global' in desc:
            i_shape = _dims(l.input)
            o_shape = _dims(l.output)
            flops = math.prod(i_shape)
            out_vec = math.prod(o_shape)

        # The trailing space keeps e.g. 'Conv2DTranspose' from matching.
        if 'Conv2D ' in desc and 'SeparableConv2D' not in desc:
            strides = l.strides
            ks = l.kernel_size
            filters = getattr(l, 'filters', None)
            i_shape = _dims(l.input)
            o_shape = _dims(l.output)
            if filters is None:
                filters = i_shape[2]
            out_positions = o_shape[0] * o_shape[1]
            if 'DepthwiseConv2D' in desc:
                # One kh x kw filter per output channel (C_in * depth_multiplier).
                macc = ks[0] * ks[1] * o_shape[2] * out_positions
            else:
                macc = ks[0] * ks[1] * i_shape[2] * filters * out_positions
            flops = 2 * macc

        if table:
            print('%25s | %16s | %16s | %16s | %16s | %6s | %5.4f' % (
                l.name, str(i_shape), str(o_shape), str(ks), str(filters),
                str(strides), flops))

        t_macc += macc
        t_flops += flops

    t_flops = t_flops / factor
    print('\nTotal FLOPS (x 10^6): %10.8f\n' % (t_flops))
    print('\nTotal MACCs: %10.8f\n' % (t_macc))
    return t_flops, t_macc
##### Run function #######
if __name__ == "__main__":
    # This file never builds a model itself. `model` is presumably a Keras
    # model defined earlier in the running namespace (e.g. a previous
    # notebook cell); verify in your workflow.
    _model = globals().get("model")
    if _model is None:
        raise SystemExit(
            "net_flops: define a Keras model named `model` before running this script."
        )
    net_flops(_model, table=True)