From 8462b6c8366db92f442c3f7a283523c93f35734b Mon Sep 17 00:00:00 2001
From: Cade Gordon
Date: Thu, 28 Jul 2022 13:18:19 -0500
Subject: [PATCH] Update logging style and add samples/s to log_data

---
 src/training/train.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/training/train.py b/src/training/train.py
index 8341be540..05dee34be 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -65,7 +65,6 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
     loss_m = AverageMeter()
     batch_time_m = AverageMeter()
     data_time_m = AverageMeter()
-    samples_second_m = AverageMeter()
     end = time.time()
     for i, batch in enumerate(dataloader):
         step = num_batches_per_epoch * epoch + i
@@ -101,7 +100,6 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
             unwrap_model(model).logit_scale.clamp_(0, math.log(100))
 
         batch_time_m.update(time.time() - end)
-        samples_second_m.update(args.batch_size * args.world_size / batch_time_m.val)
         end = time.time()
         batch_count = i + 1
         if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
@@ -116,9 +114,8 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
             logging.info(
                 f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                 f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
-                f"Samples/Second: {samples_second_m.avg:#g} "
                 f"Data (t): {data_time_m.avg:.3f} "
-                f"Batch (t): {batch_time_m.avg:.3f} "
+                f"Batch (t): {batch_time_m.avg:.3f}, {args.batch_size * args.world_size / batch_time_m.val:#g}/s "
                 f"LR: {optimizer.param_groups[0]['lr']:5f} "
                 f"Logit Scale: {logit_scale_scalar:.3f}"
             )
@@ -128,6 +125,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
                 "loss": loss_m.val,
                 "data_time": data_time_m.val,
                 "batch_time": batch_time_m.val,
+                "samples_per_second": args.batch_size * args.world_size / batch_time_m.val,
                 "scale": logit_scale_scalar,
                 "lr": optimizer.param_groups[0]["lr"]
             }
@@ -142,7 +140,6 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
             # resetting batch / data time meters per log window
             batch_time_m.reset()
             data_time_m.reset()
-            samples_second_m.reset()
     # end for
 
 
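Note on the logged quantity: the throughput this patch reports is the instantaneous global samples-per-second, i.e. the per-GPU batch size times the number of distributed workers, divided by the wall-clock time of the most recent step, rather than a value accumulated in a dedicated meter. Below is a minimal runnable sketch of that computation; the AverageMeter is written in the spirit of the helper used by train.py, and the timing loop, batch_size, and world_size values are hypothetical stand-ins for the real training loop.

    import time

    class AverageMeter:
        """Tracks the most recent value and a running average."""
        def __init__(self):
            self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

    batch_size, world_size = 256, 8      # hypothetical per-GPU batch size and GPU count
    batch_time_m = AverageMeter()
    end = time.time()
    for step in range(5):
        time.sleep(0.05)                 # stand-in for the forward/backward pass
        batch_time_m.update(time.time() - end)
        end = time.time()
        # One global batch (batch_size * world_size samples) divided by the
        # latest step time gives the instantaneous throughput that gets logged.
        samples_per_second = batch_size * world_size / batch_time_m.val
        print(f"step {step}: {samples_per_second:#g} samples/s")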