# fast_activation.py
'''
Functions copied directly from the e3nn library, trimmed to speed up
the special cases used in GIN and GAT.
'''

import torch

from e3nn import o3
from e3nn.math import normalize2mom
from e3nn.util._argtools import _get_device
from e3nn.util.jit import compile_mode

@compile_mode('trace')
class Activation(torch.nn.Module):
    '''
    Apply activation functions to the scalar (type-0) parts of the input,
    with a fast path that skips `narrow` when there is only one irrep group.
    '''

    def __init__(self, irreps_in, acts):
        super().__init__()
        irreps_in = o3.Irreps(irreps_in)
        assert len(irreps_in) == len(acts), (irreps_in, acts)

        # Normalize the second moment so that unit-variance inputs stay
        # unit-variance after the activation.
        acts = [normalize2mom(act) if act is not None else None for act in acts]

        irreps_out = []
        for (mul, (l_in, p_in)), act in zip(irreps_in, acts):
            if act is not None:
                if l_in != 0:
                    raise ValueError("Activation: cannot apply an activation function to a non-scalar input.")

                # Probe the activation numerically to classify its parity:
                # even (+1), odd (-1), or neither (0).
                x = torch.linspace(0, 10, 256, device=_get_device(act))
                a1, a2 = act(x), act(-x)
                if (a1 - a2).abs().max() < 1e-5:
                    p_act = 1
                elif (a1 + a2).abs().max() < 1e-5:
                    p_act = -1
                else:
                    p_act = 0

                p_out = p_act if p_in == -1 else p_in
                irreps_out.append((mul, (0, p_out)))

                if p_out == 0:
                    raise ValueError("Activation: the parity is violated! The input scalar is odd but the activation is neither even nor odd.")
            else:
                irreps_out.append((mul, (l_in, p_in)))

        self.irreps_in = irreps_in
        self.irreps_out = o3.Irreps(irreps_out)
        self.acts = torch.nn.ModuleList(acts)
        assert len(self.irreps_in) == len(self.acts)

    def extra_repr(self):
        return super().extra_repr() + f'{self.irreps_in} -> {self.irreps_out}, '

    def forward(self, features, dim=-1):
        # Fast path: with a single irrep group the whole tensor can be
        # activated in one call, without any `narrow`. Guard against a
        # lone `None` entry, which would otherwise be called as a function.
        if len(self.acts) == 1:
            act = self.acts[0]
            return features if act is None else act(features)

        output = []
        index = 0
        for (mul, ir), act in zip(self.irreps_in, self.acts):
            if act is not None:
                # Activations are only allowed on scalars (l = 0),
                # so the slice width is just `mul`.
                output.append(act(features.narrow(dim, index, mul)))
            else:
                output.append(features.narrow(dim, index, mul * ir.dim))
            index += mul * ir.dim

        if len(output) > 1:
            return torch.cat(output, dim=dim)
        elif len(output) == 1:
            return output[0]
        else:
            return torch.zeros_like(features)
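
# A minimal usage sketch for `Activation` (the irreps and activation below
# are illustrative assumptions, not values used elsewhere in this file):
#
#     act = Activation('16x0e', acts=[torch.nn.functional.silu])
#     x = torch.randn(32, 16)
#     y = act(x)   # single scalar irrep group -> fast path, no `narrow`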

@compile_mode('script')
class Gate(torch.nn.Module):
    '''
    Equivariant gated nonlinearity. Differences from the e3nn version:
    1. Use `narrow` to split the input tensor.
    2. Use the `Activation` defined in this file.
    '''

    def __init__(self, irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated):
        super().__init__()
        irreps_scalars = o3.Irreps(irreps_scalars)
        irreps_gates = o3.Irreps(irreps_gates)
        irreps_gated = o3.Irreps(irreps_gated)

        if len(irreps_gates) > 0 and irreps_gates.lmax > 0:
            raise ValueError(f"Gate scalars must be scalars, instead got irreps_gates = {irreps_gates}")
        if len(irreps_scalars) > 0 and irreps_scalars.lmax > 0:
            raise ValueError(f"Scalars must be scalars, instead got irreps_scalars = {irreps_scalars}")
        if irreps_gates.num_irreps != irreps_gated.num_irreps:
            raise ValueError(f"There are {irreps_gated.num_irreps} irreps in irreps_gated, but a different number ({irreps_gates.num_irreps}) of gate scalars in irreps_gates")

        self.irreps_scalars = irreps_scalars
        self.irreps_gates = irreps_gates
        self.irreps_gated = irreps_gated
        self._irreps_in = (irreps_scalars + irreps_gates + irreps_gated).simplify()

        self.act_scalars = Activation(irreps_scalars, act_scalars)
        irreps_scalars = self.act_scalars.irreps_out

        self.act_gates = Activation(irreps_gates, act_gates)
        irreps_gates = self.act_gates.irreps_out

        # Multiply each gated irrep elementwise by its (activated) gate scalar.
        self.mul = o3.ElementwiseTensorProduct(irreps_gated, irreps_gates)
        irreps_gated = self.mul.irreps_out

        self._irreps_out = irreps_scalars + irreps_gated

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.irreps_in} -> {self.irreps_out})"

    def forward(self, features):
        # Split the flat feature tensor into [scalars | gates | gated] blocks.
        scalars_dim = self.irreps_scalars.dim
        gates_dim = self.irreps_gates.dim
        input_dim = self.irreps_in.dim
        scalars = features.narrow(-1, 0, scalars_dim)
        gates = features.narrow(-1, scalars_dim, gates_dim)
        gated = features.narrow(-1, scalars_dim + gates_dim,
                                input_dim - scalars_dim - gates_dim)

        scalars = self.act_scalars(scalars)
        if gates.shape[-1]:
            gates = self.act_gates(gates)
            gated = self.mul(gated, gates)
            features = torch.cat([scalars, gated], dim=-1)
        else:
            features = scalars
        return features

    @property
    def irreps_in(self):
        """Input representations."""
        return self._irreps_in

    @property
    def irreps_out(self):
        """Output representations."""
        return self._irreps_out
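
# Minimal smoke test. This block is a sketch: the specific irreps and
# activations are arbitrary assumptions chosen only to exercise the code.
if __name__ == '__main__':
    gate = Gate(
        '8x0e', [torch.nn.functional.silu],  # scalars and their activation
        '4x0e', [torch.sigmoid],             # gate scalars and their activation
        '4x1o',                              # gated (non-scalar) irreps
    )
    x = torch.randn(10, gate.irreps_in.dim)  # 8 + 4 + 4 * 3 = 24 features
    y = gate(x)
    assert y.shape == (10, gate.irreps_out.dim)  # 8x0e + 4x1o -> 20 features
    print(gate, y.shape)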