diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 22b9a840..aa188aeb 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -40,6 +40,7 @@ class train_config:
 
     # training spec
     batch_size: int = 2
+    grad_accum_steps: int = 1
     num_steps: int = 1000000
     training_stage: str = "initial"
     learning_rate: float = 3e-4
diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py
index 2d512cb2..78c6f512 100644
--- a/fms_fsdp/utils/train_utils.py
+++ b/fms_fsdp/utils/train_utils.py
@@ -80,6 +80,7 @@ def train(
                 run["hparams"] = asdict(cfg)
 
     model.train()
+    optimizer.zero_grad()
     ddp_stats = torch.zeros(3).to(local_rank)
 
     start = time.time()
@@ -91,20 +92,24 @@ def train(
         input = input.to(local_rank)
         label = label.to(local_rank)
 
-        optimizer.zero_grad()
         output = model(input)
         output = output.logits if hasattr(output, "logits") else output
         ce_loss = torch.nn.CrossEntropyLoss()
         loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())
         loss = loss + .0001 * torch.logsumexp(output, dim=-1).pow(2).mean()
+        loss = loss / cfg.grad_accum_steps
 
         loss.backward()
-        ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
-        optimizer.step()
+
+        if batch_idx % cfg.grad_accum_steps == 0:
+            ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
+            optimizer.step()
+            optimizer.zero_grad()
+            ddp_stats[2] += 1
+
         scheduler.step()
 
         ddp_stats[0] += loss.item()
-        ddp_stats[2] += 1
 
         if profiler:
             profiler.step()
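
For context on the change above: with gradient accumulation, `loss.backward()` runs on every micro-batch (with the loss scaled by `1 / grad_accum_steps`), while gradient clipping, `optimizer.step()`, and `optimizer.zero_grad()` run only once every `grad_accum_steps` micro-batches, so the effective global batch size becomes `batch_size * grad_accum_steps * world_size`. The snippet below is a minimal, self-contained sketch of that pattern; the toy model, optimizer, and data are illustrative placeholders and are not fms_fsdp code.

```python
# Minimal, self-contained sketch of the accumulation pattern used in the patch.
# The toy model, optimizer, and data below are illustrative placeholders; they
# are not fms_fsdp code.
import torch

grad_accum_steps = 4                      # stands in for cfg.grad_accum_steps
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()

optimizer.zero_grad()                     # clear gradients once before the loop
for batch_idx in range(1, 17):            # 16 micro-batches -> 4 optimizer steps
    x = torch.randn(4, 8)
    y = torch.randint(0, 2, (4,))

    loss = loss_fn(model(x), y)
    loss = loss / grad_accum_steps        # scale so accumulated grads match one large batch
    loss.backward()                       # gradients accumulate across micro-batches

    if batch_idx % grad_accum_steps == 0:
        # plain clip_grad_norm_ stands in for FSDP's model.clip_grad_norm_
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()             # reset only after an optimizer step
```

With the default `grad_accum_steps: int = 1`, the modulo condition is true on every batch, so existing configs keep their current per-batch stepping behavior.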