forked from hpi-xnor/BMXNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdorefa_ops.py
69 lines (61 loc) · 1.53 KB
/
dorefa_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import mxnet as mx
import numpy as np
from math_ops import *
def get_dorefa(nbit_w, nbit_a, nbit_g):
    """
    Build DoReFa-style quantization functions fw, fa, fg for weights,
    activations and gradients respectively.

    The returned functions operate on MXNet symbols. They use custom
    operators selected by ``op_type`` ('around', 'clip_by_0_1',
    'pro_channel_reduce_mean', 'amax'). These are presumably registered by
    ``math_ops`` (star-imported at the top of this file) -- confirm there.

    param:
        nbit_w: bit width of weights. 32 returns weights unchanged, 1
            binarizes them with a scaling factor, and any other value uses
            k-bit tanh-based quantization.
        nbit_a: bit width of activations. 32 returns activations unchanged,
            1 binarizes them to {-1, 1}, and any other value uses k-bit
            uniform quantization.
        nbit_g: bit width of gradients. Currently ignored: the gradient
            function is always the identity.
    return:
        (qua_w, qua_a, qua_g) -- the weight, activation and gradient
        quantization functions.
    raises:
        ValueError: if nbit_w or nbit_a is smaller than 1. k-bit quantization
            divides by 2**k - 1, which would be <= 0.
    """
    # Fail fast instead of silently building a graph that divides by
    # 2**k - 1 <= 0 (NaN/inf weights or activations at runtime).
    for name, nbit in (("nbit_w", nbit_w), ("nbit_a", nbit_a)):
        if nbit < 1:
            raise ValueError(
                "%s must be >= 1 (use 32 for full precision), got %r"
                % (name, nbit))

    def quantize(x, k):
        """
        Uniformly quantize x to k bits: round(x * (2**k - 1)) / (2**k - 1).

        x: input symbol. It is presumably already in [0, 1]; values outside
           that range are not clipped here.
        k: number of quantization bits (>= 1).
        """
        n = float(2 ** k - 1)
        # 'around' is a custom op (see math_ops), presumably rounding with a
        # straight-through gradient -- confirm in math_ops.
        x_q = mx.sym.Custom(data=x * n, op_type='around') / n
        return x_q

    def binary_sign(x):
        """
        Map x to {-1, 1}:
        - clip input tensor to [0, 1]
        - round it to {0, 1}
        - convert to {-1, 1}

        Because of the clip + round, the decision threshold is 0.5, not 0.
        Assuming 'around' rounds to nearest, inputs below 0.5 (including
        small positive values) map to -1.
        """
        x_1_0 = mx.sym.Custom(data=x, op_type='clip_by_0_1')
        x_round = mx.sym.Custom(data=x_1_0, op_type='around')
        return x_round * 2 - 1

    def qua_w(x):
        """
        Quantization function for weights.

        x: input weight symbol.
        """
        # 32 bit: full precision, no quantization.
        if nbit_w == 32:
            return x
        # 1 bit: binarize and rescale by E. Judging by the op name, E is
        # presumably the per-channel mean of |x|; confirm in math_ops.
        # NOTE(review): binary_sign thresholds x/E at 0.5, so weights in
        # (0, 0.5*E) become -E. This differs from sign(x)*E in the DoReFa
        # paper -- confirm this is intended.
        if nbit_w == 1:
            E = mx.sym.Custom(data=mx.sym.abs(x),
                              op_type='pro_channel_reduce_mean')
            return binary_sign(x / E) * E
        # k bit: squash with tanh, rescale into [0, 1] using the maximum of
        # |tanh(x)| ('amax' op), quantize, then map back to [-1, 1].
        x = mx.sym.Activation(data=x, act_type="tanh")
        x = x / mx.sym.Custom(data=mx.sym.abs(x), op_type='amax') * 0.5 + 0.5
        return 2 * quantize(x, nbit_w) - 1

    def qua_a(x):
        """
        Quantization function for activations.

        x: input activation symbol. It is presumably already clipped to
           [0, 1] by the caller; this function does not clip for k > 1.
        """
        if nbit_a == 32:
            return x
        if nbit_a == 1:
            return binary_sign(x)
        return quantize(x, nbit_a)

    def qua_g(x):
        """
        Gradient quantization function.

        This is not implemented: it returns x unchanged regardless of nbit_g.
        """
        # TODO: implement gradient quantization for nbit_g != 32.
        return x

    return qua_w, qua_a, qua_g