
GNS PyTorch

This is the easiest way to calculate GNS (Gradient Noise Scale) for your PyTorch models. No hooks, gradient accumulation, or multi-GPU setup needed. Just pass in your per-example losses and model.

What's GNS?

GNS measures how noisy your gradient estimates are: it compares the variance of per-example gradients to the magnitude of the true gradient, and it directly predicts the critical batch size beyond which larger batches give diminishing returns. See https://arxiv.org/pdf/1812.06162 and https://openreview.net/forum?id=xINTMAvPQA
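
For intuition, here is a minimal sketch of the "simple" noise scale estimator from the McCandlish et al. paper (Appendix A). This illustrates the math only, not how this library computes GNS internally; the function name and arguments are made up:

import torch

def gns_two_batch_estimate(grads_small, grads_big, b_small, b_big):
    # grads_small / grads_big: lists of gradient tensors computed from a
    # small batch of size b_small and a big batch of size b_big
    g_small_sq = sum(g.pow(2).sum() for g in grads_small)  # |G_small|^2
    g_big_sq = sum(g.pow(2).sum() for g in grads_big)      # |G_big|^2
    # unbiased estimate of the true gradient's squared norm |G|^2
    g_sq = (b_big * g_big_sq - b_small * g_small_sq) / (b_big - b_small)
    # unbiased estimate of the gradient noise tr(Sigma)
    s = (g_small_sq - g_big_sq) / (1 / b_small - 1 / b_big)
    return s / g_sq  # B_simple, the (simple) gradient noise scale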

Install

pip install gns-pytorch

Usage

Simple usage:

from gns_pytorch import compute_gns
import torch

model = YourModel()
optimizer = torch.optim.Adam(model.parameters())

global_step = 0
gns_ema = 0.0  # GNS readings are noisy; track an EMA (warms up from 0)

def training_step(batch):
    global global_step, gns_ema
    x, y = batch
    logits = model(x)
    # per-example losses (reduction='none') are what compute_gns expects
    per_example_losses = torch.nn.functional.cross_entropy(logits, y, reduction='none')

    if global_step % 100 == 0:
        gns_value = compute_gns(per_example_losses, model)
        gns_ema = 0.9 * gns_ema + 0.1 * gns_value
        print(f"Current GNS (EMA): {gns_ema}")

    loss = per_example_losses.mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    global_step += 1

Adaptive Batch Size Scheduling

With an accurate GNS estimate you can schedule your batch size (via gradient accumulation) to track the critical batch size throughout training, which can substantially improve convergence speed and sample efficiency. DeepSeek-V3 used a similar batch size schedule.
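
As a rough illustration (the helper below is hypothetical, not part of this library), a scheduler might double the accumulation steps whenever the GNS EMA runs well ahead of the current global batch:

def schedule_accum_steps(gns_ema, micro_batch_size, current_accum, max_accum=64):
    # global batch = micro batch size * gradient accumulation steps
    global_batch = micro_batch_size * current_accum
    if gns_ema > 2 * global_batch:
        # catch up toward the estimated critical batch size
        return min(current_accum * 2, max_accum)
    return current_accum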

Tips

  • Call compute_gns only every N steps (e.g. every 100) to keep overhead low
  • Keep an EMA of the GNS values, since individual readings are very noisy
  • The param_percentage argument samples a subset of model parameters for faster computation (see the example after this list)
  • Enable use_vmap=True to speed up computation by parallelizing per-example gradients (note that PyTorch's vmap is not yet composable with flex attention or torch.compile)
  • GNS directly approximates the critical batch size. For example, if the logged GNS is 64 and your global batch size is 32, you should double your gradient accumulation steps
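
Putting the tips together, a call might look like this (param_percentage and use_vmap are the arguments named above; the value 0.1 is illustrative):

gns_value = compute_gns(
    per_example_losses,
    model,
    param_percentage=0.1,  # illustrative: sample 10% of parameters
    use_vmap=True,         # parallelize per-example gradients via vmap
)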
