-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgru.py
73 lines (61 loc) · 2.9 KB
/
gru.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
class GRU(nn.Module):
"""
Реализация GRU слоя нейронной сети.
Расчет скрытого состояния проводится по формуле
z_t = sigmoid(W_z @ x_t + b_iz + W_hz @ h_(t-1) + b_hz)
r_t = sigmoid(W_r @ x_t + b_ir + W_hr @ h_(t-1) + b_hr)
n_t = tanh(W_in @ x_t + b_in + r_t * (W_hh @ h_(t-1) + b_hh))
h_t = (1 - z_t) * n_t + z_t * h_(t-1)
h_t - скрытое состояние в момент времени t
x_t - вход в момент времени t
h_(t-1) - скрытое состояние предыдущего слоя в момент времени (t-1)
r_t - reset gate
z_t - update gate
n_t - new gate
Можно задать тип инициализации весов ('orthogonal', 'uniform').
"""
def __init__(self, input_dim: int, hidden_dim: int, init_type: str = 'orthogonal'):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.W_zh = nn.Parameter(torch.rand(hidden_dim, hidden_dim))
self.b_zh = nn.Parameter(torch.rand(1, hidden_dim))
self.W_zx = nn.Parameter(torch.rand(input_dim, hidden_dim))
self.b_zx = nn.Parameter(torch.rand(1, hidden_dim))
self.W_rh = nn.Parameter(torch.rand(hidden_dim, hidden_dim))
self.b_rh = nn.Parameter(torch.rand(1, hidden_dim))
self.W_rx = nn.Parameter(torch.rand(input_dim, hidden_dim))
self.b_rx = nn.Parameter(torch.rand(1, hidden_dim))
self.W_nh = nn.Parameter(torch.rand(hidden_dim, hidden_dim))
self.b_nh = nn.Parameter(torch.rand(1, hidden_dim))
self.W_nx = nn.Parameter(torch.rand(input_dim, hidden_dim))
self.b_nx = nn.Parameter(torch.rand(1, hidden_dim))
self.init_parameters(init_type)
def init_parameters(self, init_type: str):
std = 1.0 / self.hidden_dim**0.5
for param in self.parameters():
if init_type == 'uniform':
nn.init.uniform_(param, -std, std)
elif init_type == 'orthogonal':
nn.init.orthogonal_(param)
else:
raise NotImplementedError
def forward(self, x: torch.Tensor, hidden=None):
# x = [batch_size, seq_len, embed_dim]
device = x.device
if hidden is None:
hidden = torch.zeros((x.size(0), self.hidden_dim), device=device)
for idx in range(x.size(1)):
z = torch.sigmoid(
x[:, idx] @ self.W_zx + self.b_zx + hidden @ self.W_zh + self.b_zh
)
r = torch.sigmoid(
x[:, idx] @ self.W_rx + self.b_rx + hidden @ self.W_rh + self.b_rh
)
n = torch.tanh(
x[:, idx] @ self.W_nx + self.b_nx + r * (hidden @ self.W_nh + self.b_nh)
)
hidden = (1 - z) * n + z * hidden
return hidden