-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrmsprop.py
57 lines (53 loc) · 2.73 KB
/
rmsprop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import Optional, Callable
import torch
class RMSprop(torch.optim.Optimizer):
"""
Реализация метода RMSprop для нейронной сети.
Метод представляет собой доработку AdaGrad в виде добавления скользящего среднего
по параметру v
"""
def __init__(self, params, lr: float = 1e-2, alpha: float = 0.8, eps: float = 1e-6):
"""
Метод RMSprop: AdaGrad + Moving Average
:param params: параметры модели для пересчета градиента
:param lr: learning rate
:param alpha: коэффициент для скользящего среднего
:param eps: защита от деления на 0
"""
assert (alpha > 0) and (alpha < 1.0)
# поскольку params - это генератор, его нужно сохранить для обхода в нескольких вызовах
model_params = list(params)
opt_params = {
'lr': lr,
'alpha': alpha,
'eps': eps,
}
self.v = [torch.zeros_like(param) for param in model_params]
super().__init__(model_params, opt_params)
@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]:
"""
Расчет итеративного обновления весов по формуле:
v_(t+1) = alpha * v_(t) + (1 - alpha) * grad(L(w_(t)))**2
w_(t+1) = w_(t) - lr * grad(L(w_(t))) / sqrt(v_(t+1) + eps),
w_(t+1) - значения весов на следующей итерации (t+1)
w_(t) - значения весов на итерации t
v_(t+1) - нормировка темпа сходимости на шаге (t+1)
v_(0) - матрица из нулей размерности grad(L(w_(t)))
lr - learning rate
alpha - влияние предыдущего значения v
eps - малая константа для защиты от деления на нуль
grad(*) - градиент от *
L(w_(t)) - функция потерь от w_(t)
:param closure:
:return:
"""
for group in self.param_groups:
lr = group['lr']
alpha = group['alpha']
eps = group['eps']
for i, param in enumerate(group['params']):
if param.grad is not None:
self.v[i] = alpha * self.v[i] + (1 - alpha) * torch.square(param.grad)
param.data += (- lr * param.grad / torch.sqrt(self.v[i] + eps))
return