From 31207553f566763615287c104d92f11821a3e672 Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 5 Mar 2026 15:30:59 +0100 Subject: [PATCH] WIP some tweaks on adam and muon --- objective.py | 4 ++-- solvers/adam.py | 5 ++++- solvers/muon.py | 10 +++++----- solvers/scion.py | 2 +- solvers/soap.py | 7 +++++-- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/objective.py b/objective.py index a4ce567..4d46901 100644 --- a/objective.py +++ b/objective.py @@ -26,7 +26,7 @@ def set_data(self, train_dataloader, val_dataloader, model): def evaluate_result(self, model, dist=None): model.eval() - val_batch_size = 64 # Batch of 64 for validation + val_batch_size = 64 # Batch 64 for validation as in NanoGPT codebase if dist is not None: # In distributed mode, we use the distributed data generator rank, size = dist.get_rank(), dist.get_world_size() @@ -43,7 +43,7 @@ def evaluate_result(self, model, dist=None): # Compute the validation loss val_loss, n_batches = 0.0, 0 for data in val_loader: - loss, *_ = self.model(*data) + loss, *_ = model(*data) val_loss += loss.item() n_batches += 1 val_loss /= n_batches diff --git a/solvers/adam.py b/solvers/adam.py index 38cbc9c..b3b5835 100644 --- a/solvers/adam.py +++ b/solvers/adam.py @@ -17,7 +17,7 @@ class Solver(BaseSolver): parameters = { 'learning_rate': [1e-3], 'weight_decay': [1e-4], - 'num_steps': [6200], + 'num_steps': [8_000], 'batch_size': [64], "slurm_nodes": [1, 2], "sin_init": [True], @@ -63,7 +63,10 @@ def get_next(self, stop_val): return stop_val + 250 def warm_up(self): + n_iter = self.num_steps + self.num_steps = 10 self.run_once(stop_val=10) + self.num_steps = n_iter def run(self, cb): diff --git a/solvers/muon.py b/solvers/muon.py index 48d7520..bc46242 100644 --- a/solvers/muon.py +++ b/solvers/muon.py @@ -1,10 +1,10 @@ +from benchopt import BaseSolver from contextlib import nullcontext import torch from benchmark_utils.distributed_tools import setup_distributed from benchmark_utils.lr_scheduler import get_lr from benchmark_utils.optimizers.muon import Muon -from benchopt import BaseSolver from torch.optim import AdamW from tqdm.auto import tqdm @@ -13,11 +13,11 @@ class Solver(BaseSolver): name = "Muon" parameters = { - "muon_lr": [0.02], + "muon_lr": [3.6e-4], "muon_momentum": [0.95], - "adam_lr": [3e-4], - "adam_weight_decay": [0.0], - "num_steps": [6200], + "adam_lr": [0.0036], + "adam_weight_decay": [1e-4], + "num_steps": [8_000], "batch_size": [64], "slurm_nodes": [1, 2], } diff --git a/solvers/scion.py b/solvers/scion.py index 90c4f32..5731db6 100644 --- a/solvers/scion.py +++ b/solvers/scion.py @@ -24,7 +24,7 @@ class Solver(BaseSolver): "momentum": [0.1], "hidden_radius": [50.0], "lm_head_radius": [3000.0], - "num_steps": [6200], + "num_steps": [8_000], "batch_size": [64], "slurm_nodes": [1, 2], } diff --git a/solvers/soap.py b/solvers/soap.py index a6f27eb..0e23c66 100644 --- a/solvers/soap.py +++ b/solvers/soap.py @@ -15,9 +15,9 @@ class Solver(BaseSolver): name = "SOAP" parameters = { - "learning_rate": [3e-3], + "learning_rate": [0.0036], "weight_decay": [1e-4], - "num_steps": [6200], + "num_steps": [7_500], "batch_size": [64], "slurm_nodes": [1, 2], } @@ -53,7 +53,10 @@ def get_next(self, stop_val): return stop_val + 250 def warm_up(self): + n_iter = self.num_steps + self.num_steps = 10 self.run_once(stop_val=10) + self.num_steps = n_iter def run(self, cb): param_dict = {