|
| 1 | +import argparse |
| 2 | +import logging |
| 3 | +import os |
| 4 | +import zipfile |
| 5 | +import time |
| 6 | + |
| 7 | +import mxnet as mx |
| 8 | +import horovod.mxnet as hvd |
| 9 | +from mxnet import autograd, gluon, nd |
| 10 | +from mxnet.test_utils import download |
| 11 | + |
# Command-line options for the training run
parser = argparse.ArgumentParser(description='Apache MXNet MNIST Example')

# Each entry is (flag, keyword arguments forwarded to add_argument).
_OPTIONS = (
    ('--batch-size', dict(type=int, default=64,
                          help='training batch size (default: 64)')),
    ('--dtype', dict(type=str, default='float32',
                     help='training data type (default: float32)')),
    ('--epochs', dict(type=int, default=5,
                      help='number of training epochs (default: 5)')),
    ('--lr', dict(type=float, default=0.01,
                  help='learning rate (default: 0.01)')),
    ('--momentum', dict(type=float, default=0.9,
                        help='SGD momentum (default: 0.9)')),
    ('--no-cuda', dict(action='store_true', default=False,
                       help='disable training on GPU (default: False)')),
)
for _flag, _spec in _OPTIONS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

# Fall back to CPU when GPU training was requested but no GPU is visible.
if not args.no_cuda and mx.context.num_gpus() == 0:
    args.no_cuda = True

logging.basicConfig(level=logging.INFO)
logging.info(args)
| 36 | + |
| 37 | + |
# Build the MNIST training and validation iterators for one worker
def get_mnist_iterator(rank):
    """Download MNIST into a per-rank directory and return its iterators.

    The archive is downloaded and unpacked into ``data-<rank>``. The
    training iterator is sharded with ``num_parts=hvd.size()`` and
    ``part_index=hvd.rank()``; the validation iterator is not sharded.
    The batch size is read from the module-level ``args``.

    Args:
        rank: Integer used to name the download directory.

    Returns:
        A ``(train_iter, val_iter)`` tuple of ``mx.io.MNISTIter`` objects.
    """
    data_dir = "data-%d" % rank
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)

    archive_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
                            dirname=data_dir)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(data_dir)

    # Settings shared by both iterators.
    common = dict(
        input_shape=(1, 28, 28),
        batch_size=args.batch_size,
        flat=False,
    )

    train_iter = mx.io.MNISTIter(
        image="%s/train-images-idx3-ubyte" % data_dir,
        label="%s/train-labels-idx1-ubyte" % data_dir,
        shuffle=True,
        num_parts=hvd.size(),
        part_index=hvd.rank(),
        **common
    )

    val_iter = mx.io.MNISTIter(
        image="%s/t10k-images-idx3-ubyte" % data_dir,
        label="%s/t10k-labels-idx1-ubyte" % data_dir,
        **common
    )

    return train_iter, val_iter
| 71 | + |
| 72 | + |
# Build the convolutional network used for MNIST classification
def conv_nets():
    """Return an uninitialized ``gluon.nn.HybridSequential`` network.

    The stack is two Conv2D(relu) + MaxPool2D stages, a Flatten, a
    512-unit relu Dense layer and a final 10-unit Dense layer.
    """
    net = gluon.nn.HybridSequential()
    with net.name_scope():
        # Layers are constructed inside the name scope, in the same order
        # in which they are added to the network.
        layers = (
            gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'),
            gluon.nn.MaxPool2D(pool_size=2, strides=2),
            gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'),
            gluon.nn.MaxPool2D(pool_size=2, strides=2),
            gluon.nn.Flatten(),
            gluon.nn.Dense(512, activation="relu"),
            gluon.nn.Dense(10),
        )
        for layer in layers:
            net.add(layer)
    return net
| 85 | + |
| 86 | + |
# Compute classification accuracy of a model over a data iterator
def evaluate(model, data_iter, context):
    """Run ``model`` over every batch of ``data_iter`` and score accuracy.

    The iterator is reset first. Each batch's first data and label arrays
    are moved to ``context``, and the data is cast to the module-level
    ``args.dtype`` before the forward pass.

    Args:
        model: Callable applied to each data batch.
        data_iter: Resettable iterator yielding batches with ``data`` and
            ``label`` lists.
        context: Device context the arrays are moved to.

    Returns:
        The result of ``mx.metric.Accuracy().get()`` after all batches.
    """
    data_iter.reset()
    accuracy = mx.metric.Accuracy()
    for batch in data_iter:
        inputs = batch.data[0].as_in_context(context)
        targets = batch.label[0].as_in_context(context)
        predictions = model(inputs.astype(args.dtype, copy=False))
        accuracy.update([targets], [predictions])
    return accuracy.get()
| 97 | + |
| 98 | + |
# Initialize Horovod
hvd.init()

# Horovod: pin context to local rank
context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank())
num_workers = hvd.size()

# Load training and validation data
train_data, val_data = get_mnist_iterator(hvd.rank())

# Build model
model = conv_nets()
model.cast(args.dtype)
model.hybridize()

# Create optimizer; the base learning rate is multiplied by the number of
# Horovod workers.
optimizer_params = {'momentum': args.momentum,
                    'learning_rate': args.lr * hvd.size()}
opt = mx.optimizer.create('sgd', **optimizer_params)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
model.initialize(initializer, ctx=context)

# Horovod: fetch and broadcast parameters from rank 0 to all workers
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    train_data.reset()
    metric.reset()
    # Keep nbatch bound even if the iterator yields no batches, so the
    # throughput computation below cannot raise NameError.
    nbatch = 0
    for nbatch, batch in enumerate(train_data, start=1):
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        with autograd.record():
            output = model(data.astype(args.dtype, copy=False))
            loss = loss_fn(output, label)
        loss.backward()
        trainer.step(args.batch_size)
        metric.update([label], [output])

        if nbatch % 100 == 0:
            name, acc = metric.get()
            # Pass arguments lazily, matching the other logging calls here.
            logging.info('[Epoch %d Batch %d] Training: %s=%f',
                         epoch, nbatch, name, acc)

    # Only rank 0 reports throughput, extrapolated to all workers.
    if hvd.rank() == 0:
        elapsed = time.time() - tic
        speed = nbatch * args.batch_size * hvd.size() / elapsed
        logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
                     epoch, speed, elapsed)

    # Evaluate model accuracy
    _, train_acc = metric.get()
    name, val_acc = evaluate(model, val_data, context)
    if hvd.rank() == 0:
        logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
                     train_acc, name, val_acc)

    # Final-epoch sanity check on rank 0. The message is a single literal:
    # a backslash continuation inside the string would embed the next
    # line's indentation in the text.
    # NOTE(review): `assert` is skipped under `python -O`.
    if hvd.rank() == 0 and epoch == args.epochs - 1:
        assert val_acc > 0.96, (
            "Achieved accuracy (%f) is lower than expected (0.96)" % val_acc)
0 commit comments