
Commit

init
KindXiaoming committed Jan 19, 2025
0 parents commit bc2aa03
Showing 34 changed files with 15,282 additions and 0 deletions.
Binary file added .DS_Store
Binary file not shown.
6 changes: 6 additions & 0 deletions .ipynb_checkpoints/Untitled-checkpoint.ipynb
@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
84 changes: 84 additions & 0 deletions .ipynb_checkpoints/domino_model-checkpoint.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions .ipynb_checkpoints/geometry_m-checkpoint.ipynb
@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
205 changes: 205 additions & 0 deletions .ipynb_checkpoints/geometry_model-checkpoint.ipynb

Large diffs are not rendered by default.

147 changes: 147 additions & 0 deletions .ipynb_checkpoints/resource_model-checkpoint.ipynb

Large diffs are not rendered by default.

84 changes: 84 additions & 0 deletions domino_model.ipynb

Large diffs are not rendered by default.

205 changes: 205 additions & 0 deletions geometry_model.ipynb

Large diffs are not rendered by default.

171 changes: 171 additions & 0 deletions resource_model.ipynb

Large diffs are not rendered by default.

905 changes: 905 additions & 0 deletions scripts/Fig10_nsl.ipynb

Large diffs are not rendered by default.

206 changes: 206 additions & 0 deletions scripts/Fig12_quadratic_loss.ipynb

Large diffs are not rendered by default.

525 changes: 525 additions & 0 deletions scripts/Fig13_grokking.ipynb

Large diffs are not rendered by default.

1,313 changes: 1,313 additions & 0 deletions scripts/Fig15_optimizer.ipynb

Large diffs are not rendered by default.

1,014 changes: 1,014 additions & 0 deletions scripts/Fig16_sparse_parity_experiment_theory_compare.ipynb

Large diffs are not rendered by default.

411 changes: 411 additions & 0 deletions scripts/Fig17_task_dependence.ipynb

Large diffs are not rendered by default.

944 changes: 944 additions & 0 deletions scripts/Fig18_nsl.ipynb

Large diffs are not rendered by default.

1,717 changes: 1,717 additions & 0 deletions scripts/Fig19_modularity.ipynb

Large diffs are not rendered by default.

229 changes: 229 additions & 0 deletions scripts/Fig2_bottom_sparse_parity.ipynb

Large diffs are not rendered by default.

1,014 changes: 1,014 additions & 0 deletions scripts/Fig2_top_sparse_parity.ipynb

Large diffs are not rendered by default.

1,647 changes: 1,647 additions & 0 deletions scripts/Fig4_geometry_2task.ipynb

Large diffs are not rendered by default.

782 changes: 782 additions & 0 deletions scripts/Fig5_geometry_ntask.ipynb

Large diffs are not rendered by default.

626 changes: 626 additions & 0 deletions scripts/Fig6_resource_N0.ipynb

Large diffs are not rendered by default.

551 changes: 551 additions & 0 deletions scripts/Fig7_bottom_N0_dependence.ipynb

Large diffs are not rendered by default.

666 changes: 666 additions & 0 deletions scripts/Fig7_top_geometry_resource_compare.ipynb

Large diffs are not rendered by default.

590 changes: 590 additions & 0 deletions scripts/Fig8_bottom_ce_underparam.ipynb

Large diffs are not rendered by default.

484 changes: 484 additions & 0 deletions scripts/Fig8_top_mse_underparam.ipynb

Large diffs are not rendered by default.

90 changes: 90 additions & 0 deletions scripts/Fig9_domino.ipynb

Large diffs are not rendered by default.

149 changes: 149 additions & 0 deletions scripts/MLP.py
@@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from torch.optim import LBFGS  # needed for the "LBFGS" branch in fit()
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.cross_decomposition import CCA

seed = 0
torch.manual_seed(seed)


class MLP(nn.Module):

def __init__(self, width, act='relu', save_act=True, seed=0, device='cpu'):
super(MLP, self).__init__()

torch.manual_seed(seed)

linears = []
self.width = width
self.depth = depth = len(width) - 1
for i in range(depth):
layer = nn.Linear(width[i], width[i+1])
            # (disabled) optional sparse-mask initialization; requires a sparse_mask() helper:
            # sm = sparse_mask(width[i], width[i+1]).T
            # layer.weight.data *= sm * torch.sqrt(torch.tensor(width[i],))
linears.append(layer)
self.linears = nn.ModuleList(linears)

if act == 'silu':
self.act_fun = torch.nn.SiLU()
elif act == 'relu':
self.act_fun = torch.nn.ReLU()
elif act == 'identity':
self.act_fun = torch.nn.Identity()
self.save_act = save_act
self.device = device

@property
def w(self):
return [self.linears[l].weight for l in range(self.depth)]

    def forward(self, x):
        # linear layers with the activation applied between hidden layers (none on the output)
        for i in range(self.depth):
            x = self.linears[i](x)
            if i < self.depth - 1:
                x = self.act_fun(x)
        return x


def fit(self, dataset, opt="LBFGS", steps=100, log=1, mask=None, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1., batch=-1, metrics=None, in_vars=None, out_vars=None, beta=3, device='cpu', reg_metric='w', display_metrics=None, save_ckpt=False, save_freq=1, save_folder='ckpt'):
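        # Usage notes (inferred from the body below): `dataset` is a dict of tensors with keys
        # 'train_input', 'train_label', 'test_input', 'test_label'; `opt` is "Adam" or "LBFGS";
        # batch=-1 trains full-batch; `metrics` is an optional list of callables logged by name;
        # several arguments (e.g. lamb_l1, lamb_entropy, in_vars, out_vars, beta, reg_metric,
        # device) are accepted but not used in this script.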

pbar = tqdm(range(steps), desc='description', ncols=100)

        if loss_fn is None:
            if mask is None:
                loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
            else:
                loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2 * mask)
        else:
            loss_fn_eval = loss_fn

if opt == "Adam":
optimizer = torch.optim.Adam(self.parameters(), lr=lr, betas=(0.9,0.999))
elif opt == "LBFGS":
optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)


results = {}
results['train_loss'] = []
results['test_loss'] = []
        if metrics is not None:
            for metric in metrics:
                results[metric.__name__] = []

if batch == -1 or batch > dataset['train_input'].shape[0] or batch > dataset['test_input'].shape[0]:
print('using full batch')
batch_size = dataset['train_input'].shape[0]
batch_size_test = dataset['test_input'].shape[0]
else:
batch_size = batch
batch_size_test = batch

        # shared between the LBFGS closure and the training loop below
        train_loss = torch.tensor(float('nan'))
        reg_ = torch.tensor(0.)

        def closure():
            nonlocal train_loss, reg_
            optimizer.zero_grad()
            pred = self.forward(dataset['train_input'][train_id].to(self.device))
            train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
            reg_ = torch.tensor(0.)  # regularization placeholder (kept at zero in this script)
            objective = train_loss + lamb * reg_
            if opt == 'LBFGS':
                objective.backward()
            return objective

        for step in pbar:

            if save_ckpt and step % save_freq == 0:
                torch.save(self.state_dict(), f'./{save_folder}/{step}')

            train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
            test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)

if opt == "LBFGS":
optimizer.step(closure)

elif opt == "Adam":
pred = self.forward(dataset['train_input'][train_id].to(self.device))
train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
reg_ = torch.tensor(0.)
loss = train_loss + lamb * reg_
optimizer.zero_grad()
loss.backward()
optimizer.step()

test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device))


            if metrics is not None:
                for metric in metrics:
                    results[metric.__name__].append(metric().item())

results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())

            if step % log == 0:
                if display_metrics is None:
                    pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
                else:
                    string = ''
                    data = ()
                    for metric in display_metrics:
                        string += f' {metric}: %.2e |'
                        if metric not in results:
                            raise Exception(f'{metric} not recognized')
                        data += (results[metric][-1],)
                    pbar.set_description(string % data)

return results


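For reference, a minimal usage sketch of the MLP class and its fit() API added above; the toy dataset, network width, and hyperparameters below are illustrative and not part of this commit.

import torch
from MLP import MLP  # assumes this is run from the scripts/ directory

# toy regression dataset in the dict format fit() expects
n_train, n_test, d_in = 1000, 1000, 2
x_train, x_test = torch.rand(n_train, d_in), torch.rand(n_test, d_in)
target = lambda x: x[:, [0]] * x[:, [1]]  # illustrative target function
dataset = {
    'train_input': x_train, 'train_label': target(x_train),
    'test_input': x_test, 'test_label': target(x_test),
}

model = MLP(width=[d_in, 100, 1], act='silu', seed=0, device='cpu')
results = model.fit(dataset, opt='Adam', lr=1e-3, steps=1000, batch=-1, log=100)
print(results['train_loss'][-1], results['test_loss'][-1])  # RMSE at the final step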
165 changes: 165 additions & 0 deletions scripts/ademamix.py
@@ -0,0 +1,165 @@
"""
Adapted from: https://pytorch.org/docs/1.6.0/_modules/torch/optim/adam.html
"""
import math
import torch
from torch.optim import Optimizer


def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1):
if step < warmup:
a = step / float(warmup)
return (1.0-a) * alpha_start + a * alpha_end
return alpha_end


def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
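    # f maps an EMA coefficient beta to its half-life t (the t with beta**(t + 1) = 0.5),
    # and f_inv is its inverse; the warmup interpolates linearly in half-life space between
    # beta_start and beta_end, then maps the result back to a beta value.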

def f(beta, eps=1e-8):
return math.log(0.5)/math.log(beta+eps)-1

def f_inv(t):
return math.pow(0.5, 1/(t+1))

if step < warmup:
a = step / float(warmup)
return f_inv((1.0-a) * f(beta_start) + a * f(beta_end))
return beta_end


class AdEMAMix(Optimizer):
r"""Implements the AdEMAMix algorithm.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999, 0.9999))
corresponding to beta_1, beta_2, beta_3 in AdEMAMix
        alpha (float): AdEMAMix alpha coefficient mixing the slow and fast EMAs (default: 2.0)
        beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None)
        alpha_warmup (int, optional): number of warmup steps used to increase alpha (default: None)
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay as in AdamW (default: 0)
"""

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=2.0,
beta3_warmup=None, alpha_warmup=None, eps=1e-8,
weight_decay=0):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= betas[2] < 1.0:
raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= alpha:
raise ValueError("Invalid alpha value: {}".format(alpha))
defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup,
alpha_warmup=alpha_warmup, weight_decay=weight_decay)
super(AdEMAMix, self).__init__(params, defaults)

def __setstate__(self, state):
super(AdEMAMix, self).__setstate__(state)

@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:

lr = group["lr"]
lmbda = group["weight_decay"]
eps = group["eps"]
beta1, beta2, beta3_final = group["betas"]
beta3_warmup = group["beta3_warmup"]
alpha_final = group["alpha"]
alpha_warmup = group["alpha_warmup"]

for p in group['params']:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError('AdEMAMix does not support sparse gradients.')

state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
if beta1 != 0.0: # save memory in case beta1 is 0.0
state['exp_avg_fast'] = torch.zeros_like(p, memory_format=torch.preserve_format)
else:
state['exp_avg_fast'] = None
state['exp_avg_slow'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

exp_avg_fast, exp_avg_slow, exp_avg_sq = state['exp_avg_fast'], state['exp_avg_slow'], state['exp_avg_sq']

state['step'] += 1
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']

# Compute the effective alpha and beta3 in case warmup is used
if alpha_warmup is not None:
alpha = linear_warmup_scheduler(state["step"], alpha_end=alpha_final, alpha_start=0, warmup=alpha_warmup)
else:
alpha = alpha_final

if beta3_warmup is not None:
beta3 = linear_hl_warmup_scheduler(state["step"], beta_end=beta3_final, beta_start=beta1, warmup=beta3_warmup)
else:
beta3 = beta3_final

# Decay the first and second moment running average coefficient
if beta1 != 0.0:
exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1)
else:
exp_avg_fast = grad
exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom

# decay
update.add_(p, alpha=lmbda)

p.add_(-lr * update)

return loss


if __name__ == "__main__": # small dummy test

x = torch.randn((10,7))
model = torch.nn.Linear(7, 1, bias=False)
opt = AdEMAMix(params=model.parameters(), lr=1e-2, betas=(0.9, 0.999, 0.9999), alpha=2.0, beta3_warmup=45, alpha_warmup=45, weight_decay=0.1)
print(model.weight)
for itr in range(50):
y = model(x).mean()
opt.zero_grad()
y.backward()
opt.step()

print(model.weight)