This is the easiest way to calculate GNS (Gradient Noise Scale) for your PyTorch models. No hooks, gradient accumulation, or multi-GPU setup needed. Just pass in your per-example losses and model.
GNS measures how noisy your gradient estimates are during training and approximates the critical batch size. See https://arxiv.org/pdf/1812.06162 and https://openreview.net/forum?id=xINTMAvPQA
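Roughly, following the first paper above (McCandlish et al.), the simple noise scale is

$$B_\text{simple} = \frac{\operatorname{tr}(\Sigma)}{\lVert G \rVert^2}$$

where $\Sigma$ is the per-example gradient covariance and $G$ is the true gradient. Training with a batch size around $B_\text{simple}$ is roughly compute- and data-efficient at the same time; much larger batches give diminishing returns.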
```bash
pip install gns-pytorch
```
Simple usage:
```python
from gns_pytorch import compute_gns
import torch

model = YourModel()
optimizer = torch.optim.Adam(model.parameters())

gns_ema = 0.0  # EMA of the (noisy) GNS estimates

def training_step(batch, global_step):
    global gns_ema
    x, y = batch
    logits = model(x)
    # Keep per-example losses (reduction='none'); compute_gns needs them.
    per_example_losses = torch.nn.functional.cross_entropy(logits, y, reduction='none')

    if global_step % 100 == 0:
        gns_value = compute_gns(per_example_losses, model)
        gns_ema = 0.9 * gns_ema + 0.1 * gns_value
        print(f"Current GNS (EMA): {gns_ema}")

    loss = per_example_losses.mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
With an accurate GNS estimate you can schedule your batch size (via gradient accumulation) so it stays near the critical batch size throughout training, which can substantially improve convergence and sample efficiency. DeepSeek-V3 used a similar batch size schedule during training.
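A minimal sketch of what such a schedule could look like, assuming you track a GNS EMA as in the usage example above. The helper name, the clamping bound, and the batch numbers are illustrative and not part of gns-pytorch:

```python
# Hypothetical helper (not part of gns-pytorch): choose gradient accumulation
# steps so that the effective global batch size tracks the current GNS estimate.
def accum_steps_for(gns_ema: float, micro_batch_size: int, max_steps: int = 64) -> int:
    target = max(1, round(gns_ema / micro_batch_size))
    return min(target, max_steps)

# Example: a GNS EMA of 512 with micro-batches of 32 -> accumulate 16 steps.
print(accum_steps_for(gns_ema=512.0, micro_batch_size=32))  # 16
```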
- Call `compute_gns` every N steps (e.g. every 100+) to avoid overhead.
- Use an EMA on the GNS values, since the raw estimates are very noisy.
- The `param_percentage` argument lets you sample a subset of model parameters for faster computation.
- Enable vmap with `use_vmap=True` to speed up computation by parallelizing per-example gradients (unfortunately, PyTorch's vmap isn't composable with flex attention and `torch.compile` yet). Both options are shown in the sketch after this list.
- GNS directly approximates the critical batch size. For example, if the GNS reading is 64 and your global batch size is 32, you should double your gradient accumulation steps.
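For reference, here is how the two options above might be combined in a single call, reusing `per_example_losses` and `model` from the usage example. The keyword names come from this README, but the exact value semantics (e.g. whether `param_percentage` expects a percent or a fraction) are an assumption, so check the function's docstring:

```python
# Sketch only: kwargs are the ones mentioned above; the value passed to
# param_percentage is a placeholder -- its expected range may differ.
gns_value = compute_gns(
    per_example_losses,
    model,
    param_percentage=10,  # sample a subset (~10%?) of parameters for speed
    use_vmap=True,        # parallelize per-example gradients with vmap
)
```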