diff --git a/train.py b/train.py index e95a712..3c5dcda 100644 --- a/train.py +++ b/train.py @@ -160,7 +160,7 @@ def train( if is_main and (step % cfg.val_interval == 0) and "val" in dataloaders: val_loss = validate( cfg=cfg, - model=model, + model=accelerator.unwrap_model(model), ema_model=ema_model, val_dataloader=dataloaders["val"], step=step, @@ -187,7 +187,8 @@ def main(cfg): model, optimizer, dataloaders["train"] = acc.prepare(model, optimizer, dataloaders["train"]) run = ddict() if acc.is_main_process: - ema.to(acc.device) + if ema is not None: + ema.to(acc.device) try: if cfg.meta.neptune: import neptune