diff --git a/apex/examples/imagenet/main_amp.py b/apex/examples/imagenet/main_amp.py index bddf24827e..2591992c35 100644 --- a/apex/examples/imagenet/main_amp.py +++ b/apex/examples/imagenet/main_amp.py @@ -140,11 +140,10 @@ def main(): model = model.cuda() # Scale learning rate based on global batch size - args.lr = args.lr*float(args.batch_size*args.world_size)/256. + args.lr = args.lr * float(args.batch_size * args.world_size) / 256. optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - # Initialize Amp. Amp accepts either values or strings for the optional override arguments, # for convenient interoperation with argparse. model, optimizer = amp.initialize(model, optimizer,