@@ -0,0 +1,6 @@
{
 "cells": [],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,6 @@
{
 "cells": [],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.cross_decomposition import CCA
# The LBFGS used in fit() accepts tolerance_ys, so it is the project's custom
# LBFGS variant rather than torch.optim.LBFGS; this import assumes that module
# sits alongside this file.
from LBFGS import LBFGS

seed = 0
torch.manual_seed(seed)


class MLP(nn.Module):
    """Plain fully connected network: width[0] inputs -> hidden layers -> width[-1] outputs."""

    def __init__(self, width, act='relu', save_act=True, seed=0, device='cpu'):
        super(MLP, self).__init__()

        torch.manual_seed(seed)

        linears = []
        self.width = width
        self.depth = depth = len(width) - 1
        for i in range(depth):
            layer = nn.Linear(width[i], width[i+1])
            '''sm = sparse_mask(width[i], width[i+1]).T
            layer.weight.data *= sm * torch.sqrt(torch.tensor(width[i],))'''
            linears.append(layer)
        self.linears = nn.ModuleList(linears)

        if act == 'silu':
            self.act_fun = torch.nn.SiLU()
        elif act == 'relu':
            self.act_fun = torch.nn.ReLU()
        elif act == 'identity':
            self.act_fun = torch.nn.Identity()
        self.save_act = save_act
        self.device = device

    @property
    def w(self):
        # Weight matrices of all linear layers, in order.
        return [self.linears[l].weight for l in range(self.depth)]

    def forward(self, x):
        for i in range(self.depth):
            x = self.linears[i](x)
            if i < self.depth - 1:  # no activation after the output layer
                x = self.act_fun(x)
        return x

    def fit(self, dataset, opt="LBFGS", steps=100, log=1, mask=None, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1., batch=-1, metrics=None, in_vars=None, out_vars=None, beta=3, device='cpu', reg_metric='w', display_metrics=None, save_ckpt=False, save_freq=1, save_folder='ckpt'):

        pbar = tqdm(range(steps), desc='description', ncols=100)

        # Default to (optionally masked) MSE when no loss function is given.
        if loss_fn is None:
            if mask is None:
                loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
            else:
                loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2 * mask)
        else:
            loss_fn_eval = loss_fn

        if opt == "Adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=lr, betas=(0.9, 0.999))
        elif opt == "LBFGS":
            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)

        results = {}
        results['train_loss'] = []
        results['test_loss'] = []
        if metrics is not None:
            for i in range(len(metrics)):
                results[metrics[i].__name__] = []

        # Fall back to full-batch training when batch is -1 or larger than the dataset.
        if batch == -1 or batch > dataset['train_input'].shape[0] or batch > dataset['test_input'].shape[0]:
            print('using full batch')
            batch_size = dataset['train_input'].shape[0]
            batch_size_test = dataset['test_input'].shape[0]
        else:
            batch_size = batch
            batch_size_test = batch

        global train_loss, reg_

        def closure():
            # Re-evaluates the objective for LBFGS; exposes train_loss and reg_ for logging.
            global train_loss, reg_
            optimizer.zero_grad()
            pred = self.forward(dataset['train_input'][train_id].to(self.device))
            train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
            reg_ = torch.tensor(0.)  # regularization placeholder (unused here)
            objective = train_loss + lamb * reg_
            if opt == 'LBFGS':
                objective.backward()
            return objective

        for _ in pbar:

            if save_ckpt and _ % save_freq == 0:
                torch.save(self.state_dict(), f'./{save_folder}/{_}')  # assumes save_folder already exists

            train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
            test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)

            if opt == "LBFGS":
                optimizer.step(closure)

            elif opt == "Adam":
                pred = self.forward(dataset['train_input'][train_id].to(self.device))
                train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
                reg_ = torch.tensor(0.)
                loss = train_loss + lamb * reg_
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device))

            if metrics is not None:
                for i in range(len(metrics)):
                    results[metrics[i].__name__].append(metrics[i]().item())

            # Losses are logged as RMSE.
            results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
            results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())

            if _ % log == 0:
                if display_metrics is None:
                    pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
                else:
                    string = ''
                    data = ()
                    for metric in display_metrics:
                        string += f' {metric}: %.2e |'
                        try:
                            results[metric]
                        except KeyError:
                            raise Exception(f'{metric} not recognized')
                        data += (results[metric][-1],)
                    pbar.set_description(string % data)

        return results
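For reference, fit() expects dataset to be a dict of tensors keyed by train_input, train_label, test_input, and test_label, and width lists the layer sizes from input to output. Below is a minimal smoke-test sketch: the synthetic target, layer widths, and hyperparameters are illustrative assumptions rather than part of the commit, it assumes the MLP class above is in scope, and it uses opt='Adam' so the custom LBFGS module is not required.

import math
import torch

torch.manual_seed(0)

# Hypothetical regression target: y = sin(pi * x1) + x2^2 on [-1, 1]^2.
X = torch.rand(1200, 2) * 2 - 1
y = torch.sin(math.pi * X[:, [0]]) + X[:, [1]] ** 2
dataset = {
    'train_input': X[:1000], 'train_label': y[:1000],
    'test_input':  X[1000:], 'test_label':  y[1000:],
}

model = MLP(width=[2, 32, 32, 1], act='relu')  # 2 -> 32 -> 32 -> 1
results = model.fit(dataset, opt='Adam', steps=200, lr=1e-3, batch=128, log=10)
print('final train RMSE:', results['train_loss'][-1])
print('final test RMSE:', results['test_loss'][-1])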
@@ -0,0 +1,165 @@
""" | ||
Adapted from: https://pytorch.org/docs/1.6.0/_modules/torch/optim/adam.html | ||
""" | ||
import math | ||
import torch | ||
from torch.optim import Optimizer | ||
|
||
|
||
def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1): | ||
if step < warmup: | ||
a = step / float(warmup) | ||
return (1.0-a) * alpha_start + a * alpha_end | ||
return alpha_end | ||
|
||
|
||
def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1): | ||
|
||
def f(beta, eps=1e-8): | ||
return math.log(0.5)/math.log(beta+eps)-1 | ||
|
||
def f_inv(t): | ||
return math.pow(0.5, 1/(t+1)) | ||
|
||
if step < warmup: | ||
a = step / float(warmup) | ||
return f_inv((1.0-a) * f(beta_start) + a * f(beta_end)) | ||
return beta_end | ||
|
||
|
||
class AdEMAMix(Optimizer): | ||
r"""Implements the AdEMAMix algorithm. | ||
Arguments: | ||
params (iterable): iterable of parameters to optimize or dicts defining | ||
parameter groups | ||
lr (float, optional): learning rate (default: 1e-3) | ||
betas (Tuple[float, float, float], optional): coefficients used for computing | ||
running averages of gradient and its square (default: (0.9, 0.999, 0.9999)) | ||
corresponding to beta_1, beta_2, beta_3 in AdEMAMix | ||
alpha (float): AdEMAMix alpha coeficient mixing the slow and fast EMAs (default: 2) | ||
beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None) | ||
alpha_warmup: (int, optional): number of warmup steps used to increase alpha (default: None) | ||
eps (float, optional): term added to the denominator to improve | ||
numerical stability (default: 1e-8) | ||
weight_decay (float, optional): weight decay as in AdamW (default: 0) | ||
""" | ||
|
||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=2.0, | ||
beta3_warmup=None, alpha_warmup=None, eps=1e-8, | ||
weight_decay=0): | ||
if not 0.0 <= lr: | ||
raise ValueError("Invalid learning rate: {}".format(lr)) | ||
if not 0.0 <= eps: | ||
raise ValueError("Invalid epsilon value: {}".format(eps)) | ||
if not 0.0 <= betas[0] < 1.0: | ||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | ||
if not 0.0 <= betas[1] < 1.0: | ||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | ||
if not 0.0 <= betas[2] < 1.0: | ||
raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) | ||
if not 0.0 <= weight_decay: | ||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) | ||
if not 0.0 <= alpha: | ||
raise ValueError("Invalid alpha value: {}".format(alpha)) | ||
defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup, | ||
alpha_warmup=alpha_warmup, weight_decay=weight_decay) | ||
super(AdEMAMix, self).__init__(params, defaults) | ||
|
||
def __setstate__(self, state): | ||
super(AdEMAMix, self).__setstate__(state) | ||
|
||
@torch.no_grad() | ||
def step(self, closure=None): | ||
"""Performs a single optimization step. | ||
Arguments: | ||
closure (callable, optional): A closure that reevaluates the model | ||
and returns the loss. | ||
""" | ||
loss = None | ||
if closure is not None: | ||
with torch.enable_grad(): | ||
loss = closure() | ||
|
||
for group in self.param_groups: | ||
|
||
lr = group["lr"] | ||
lmbda = group["weight_decay"] | ||
eps = group["eps"] | ||
beta1, beta2, beta3_final = group["betas"] | ||
beta3_warmup = group["beta3_warmup"] | ||
alpha_final = group["alpha"] | ||
alpha_warmup = group["alpha_warmup"] | ||
|
||
for p in group['params']: | ||
if p.grad is None: | ||
continue | ||
grad = p.grad | ||
if grad.is_sparse: | ||
raise RuntimeError('AdEMAMix does not support sparse gradients.') | ||
|
||
state = self.state[p] | ||
|
||
# State initialization | ||
if len(state) == 0: | ||
state['step'] = 0 | ||
# Exponential moving average of gradient values | ||
if beta1 != 0.0: # save memory in case beta1 is 0.0 | ||
state['exp_avg_fast'] = torch.zeros_like(p, memory_format=torch.preserve_format) | ||
else: | ||
state['exp_avg_fast'] = None | ||
state['exp_avg_slow'] = torch.zeros_like(p, memory_format=torch.preserve_format) | ||
# Exponential moving average of squared gradient values | ||
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) | ||
|
||
exp_avg_fast, exp_avg_slow, exp_avg_sq = state['exp_avg_fast'], state['exp_avg_slow'], state['exp_avg_sq'] | ||
|
||
state['step'] += 1 | ||
bias_correction1 = 1 - beta1 ** state['step'] | ||
bias_correction2 = 1 - beta2 ** state['step'] | ||
|
||
# Compute the effective alpha and beta3 in case warmup is used | ||
if alpha_warmup is not None: | ||
alpha = linear_warmup_scheduler(state["step"], alpha_end=alpha_final, alpha_start=0, warmup=alpha_warmup) | ||
else: | ||
alpha = alpha_final | ||
|
||
if beta3_warmup is not None: | ||
beta3 = linear_hl_warmup_scheduler(state["step"], beta_end=beta3_final, beta_start=beta1, warmup=beta3_warmup) | ||
else: | ||
beta3 = beta3_final | ||
|
||
# Decay the first and second moment running average coefficient | ||
if beta1 != 0.0: | ||
exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1) | ||
else: | ||
exp_avg_fast = grad | ||
exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3) | ||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) | ||
|
||
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) | ||
|
||
update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom | ||
|
||
# decay | ||
update.add_(p, alpha=lmbda) | ||
|
||
p.add_(-lr * update) | ||
|
||
return loss | ||
|
||
|
||
if __name__ == "__main__": # small dummy test | ||
|
||
x = torch.randn((10,7)) | ||
model = torch.nn.Linear(7, 1, bias=False) | ||
opt = AdEMAMix(params=model.parameters(), lr=1e-2, betas=(0.9, 0.999, 0.9999), alpha=2.0, beta3_warmup=45, alpha_warmup=45, weight_decay=0.1) | ||
print(model.weight) | ||
for itr in range(50): | ||
y = model(x).mean() | ||
opt.zero_grad() | ||
y.backward() | ||
opt.step() | ||
|
||
print(model.weight) |
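A note on linear_hl_warmup_scheduler: it warms up beta3 by interpolating the EMA's half-life rather than beta itself. f(beta) = log(0.5)/log(beta) - 1 is (approximately) the number of steps after which a past gradient's relative weight in the EMA has decayed to one half, and f_inv maps a half-life back to a decay rate. The standalone check below is illustrative only and not part of the commit; the helper names are my own.

import math

def half_life(beta, eps=1e-8):
    # Steps t with beta ** (t + 1) == 0.5 -- the f() used inside linear_hl_warmup_scheduler.
    return math.log(0.5) / math.log(beta + eps) - 1

def beta_from_half_life(t):
    # Inverse mapping -- the scheduler's f_inv().
    return 0.5 ** (1.0 / (t + 1))

print(round(half_life(0.9), 1))      # ~5.6 steps
print(round(half_life(0.9999), 1))   # ~6930.8 steps

# Halfway through warmup the scheduler returns the beta whose half-life is the
# midpoint of the two half-lives, not the midpoint of the two betas.
mid_hl = 0.5 * (half_life(0.9) + half_life(0.9999))
print(round(beta_from_half_life(mid_hl), 6))   # ~0.9998, far from (0.9 + 0.9999) / 2 = 0.94995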