Skip to content

Commit 6f627a8

Browse files
vandanavkterrytangyuan
authored andcommitted
MXNet distributed training (#122)
* MXNet distributed training * change apiVersion * Addressed some review comments Newline related comments * Revert "change apiVersion" This reverts commit 163aed7.
1 parent 59cdbae commit 6f627a8

File tree

3 files changed

+248
-0
lines changed

3 files changed

+248
-0
lines changed

examples/mxnet/Dockerfile

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
FROM horovod/horovod:0.16.2-tf1.12.0-torch1.1.0-mxnet1.4.1-py3.5 AS build
2+
3+
# Create a wrapper for OpenMPI to allow running as root by default
4+
RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \
5+
echo '#!/bin/bash' > /usr/local/bin/mpirun && \
6+
echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \
7+
chmod a+x /usr/local/bin/mpirun
8+
9+
# Configure OpenMPI to run good defaults:
10+
RUN echo "hwloc_base_binding_policy = none" >> /usr/local/etc/openmpi-mca-params.conf && \
11+
echo "rmaps_base_mapping_policy = slot" >> /usr/local/etc/openmpi-mca-params.conf && \
12+
echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf
13+
14+
# Set default NCCL parameters
15+
RUN echo NCCL_DEBUG=INFO >> /etc/nccl.conf && \
16+
echo NCCL_SOCKET_IFNAME=^docker0 >> /etc/nccl.conf
17+
18+
# --------------------------------------------------------------------
19+
20+
# Other packages needed for running examples
21+
RUN pip install gluoncv
22+
23+
# add the example script to examples folder
24+
ADD mxnet_mnist.py /examples/mxnet_mnist.py
25+
26+
WORKDIR "/"
27+
CMD ["bin/bash"]

examples/mxnet/mxnet-mnist.yaml

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
apiVersion: kubeflow.org/v1alpha1
2+
kind: MPIJob
3+
metadata:
4+
labels:
5+
ksonnet.io/component: mxnet-mnist-horovod-job
6+
name: mxnet-mnist-horovod-job
7+
namespace: default
8+
spec:
9+
replicas: 2
10+
template:
11+
spec:
12+
containers:
13+
- command:
14+
- mpirun
15+
- -mca
16+
- btl_tcp_if_exclude
17+
- lo
18+
- -mca
19+
- pml
20+
- ob1
21+
- -mca
22+
- btl
23+
- ^openib
24+
- --bind-to
25+
- none
26+
- -map-by
27+
- slot
28+
- -x
29+
- LD_LIBRARY_PATH
30+
- -x
31+
- PATH
32+
- -x
33+
- NCCL_SOCKET_IFNAME=eth0
34+
- -x
35+
- NCCL_DEBUG=INFO
36+
- -x
37+
- MXNET_CUDNN_AUTOTUNE_DEFAULT=0
38+
- python
39+
- /examples/mxnet_mnist.py
40+
- --save-frequency
41+
- "1"
42+
- --batch-size
43+
- "64"
44+
- --epochs
45+
- "5"
46+
image: mpioperator/mxnet-horovod:latest
47+
name: mxnet-mnist-horovod-job
48+
resources:
49+
limits:
50+
nvidia.com/gpu: 4

examples/mxnet/mxnet_mnist.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import argparse
2+
import logging
3+
import os
4+
import zipfile
5+
import time
6+
7+
import mxnet as mx
8+
import horovod.mxnet as hvd
9+
from mxnet import autograd, gluon, nd
10+
from mxnet.test_utils import download
11+
12+
# Training settings
13+
parser = argparse.ArgumentParser(description='Apache MXNet MNIST Example')
14+
15+
parser.add_argument('--batch-size', type=int, default=64,
16+
help='training batch size (default: 64)')
17+
parser.add_argument('--dtype', type=str, default='float32',
18+
help='training data type (default: float32)')
19+
parser.add_argument('--epochs', type=int, default=5,
20+
help='number of training epochs (default: 5)')
21+
parser.add_argument('--lr', type=float, default=0.01,
22+
help='learning rate (default: 0.01)')
23+
parser.add_argument('--momentum', type=float, default=0.9,
24+
help='SGD momentum (default: 0.9)')
25+
parser.add_argument('--no-cuda', action='store_true', default=False,
26+
help='disable training on GPU (default: False)')
27+
args = parser.parse_args()
28+
29+
if not args.no_cuda:
30+
# Disable CUDA if there are no GPUs.
31+
if mx.context.num_gpus() == 0:
32+
args.no_cuda = True
33+
34+
logging.basicConfig(level=logging.INFO)
35+
logging.info(args)
36+
37+
38+
# Function to get mnist iterator given a rank
39+
def get_mnist_iterator(rank):
40+
data_dir = "data-%d" % rank
41+
if not os.path.isdir(data_dir):
42+
os.makedirs(data_dir)
43+
zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
44+
dirname=data_dir)
45+
with zipfile.ZipFile(zip_file_path) as zf:
46+
zf.extractall(data_dir)
47+
48+
input_shape = (1, 28, 28)
49+
batch_size = args.batch_size
50+
51+
train_iter = mx.io.MNISTIter(
52+
image="%s/train-images-idx3-ubyte" % data_dir,
53+
label="%s/train-labels-idx1-ubyte" % data_dir,
54+
input_shape=input_shape,
55+
batch_size=batch_size,
56+
shuffle=True,
57+
flat=False,
58+
num_parts=hvd.size(),
59+
part_index=hvd.rank()
60+
)
61+
62+
val_iter = mx.io.MNISTIter(
63+
image="%s/t10k-images-idx3-ubyte" % data_dir,
64+
label="%s/t10k-labels-idx1-ubyte" % data_dir,
65+
input_shape=input_shape,
66+
batch_size=batch_size,
67+
flat=False,
68+
)
69+
70+
return train_iter, val_iter
71+
72+
73+
# Function to define neural network
74+
def conv_nets():
75+
net = gluon.nn.HybridSequential()
76+
with net.name_scope():
77+
net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
78+
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
79+
net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
80+
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
81+
net.add(gluon.nn.Flatten())
82+
net.add(gluon.nn.Dense(512, activation="relu"))
83+
net.add(gluon.nn.Dense(10))
84+
return net
85+
86+
87+
# Function to evaluate accuracy for a model
88+
def evaluate(model, data_iter, context):
89+
data_iter.reset()
90+
metric = mx.metric.Accuracy()
91+
for _, batch in enumerate(data_iter):
92+
data = batch.data[0].as_in_context(context)
93+
label = batch.label[0].as_in_context(context)
94+
output = model(data.astype(args.dtype, copy=False))
95+
metric.update([label], [output])
96+
return metric.get()
97+
98+
99+
# Initialize Horovod
100+
hvd.init()
101+
102+
# Horovod: pin context to local rank
103+
context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank())
104+
num_workers = hvd.size()
105+
106+
# Load training and validation data
107+
train_data, val_data = get_mnist_iterator(hvd.rank())
108+
109+
# Build model
110+
model = conv_nets()
111+
model.cast(args.dtype)
112+
model.hybridize()
113+
114+
# Create optimizer
115+
optimizer_params = {'momentum': args.momentum,
116+
'learning_rate': args.lr * hvd.size()}
117+
opt = mx.optimizer.create('sgd', **optimizer_params)
118+
119+
# Initialize parameters
120+
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
121+
magnitude=2)
122+
model.initialize(initializer, ctx=context)
123+
124+
# Horovod: fetch and broadcast parameters
125+
params = model.collect_params()
126+
if params is not None:
127+
hvd.broadcast_parameters(params, root_rank=0)
128+
129+
# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
130+
trainer = hvd.DistributedTrainer(params, opt)
131+
132+
# Create loss function and train metric
133+
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
134+
metric = mx.metric.Accuracy()
135+
136+
# Train model
137+
for epoch in range(args.epochs):
138+
tic = time.time()
139+
train_data.reset()
140+
metric.reset()
141+
for nbatch, batch in enumerate(train_data, start=1):
142+
data = batch.data[0].as_in_context(context)
143+
label = batch.label[0].as_in_context(context)
144+
with autograd.record():
145+
output = model(data.astype(args.dtype, copy=False))
146+
loss = loss_fn(output, label)
147+
loss.backward()
148+
trainer.step(args.batch_size)
149+
metric.update([label], [output])
150+
151+
if nbatch % 100 == 0:
152+
name, acc = metric.get()
153+
logging.info('[Epoch %d Batch %d] Training: %s=%f' %
154+
(epoch, nbatch, name, acc))
155+
156+
if hvd.rank() == 0:
157+
elapsed = time.time() - tic
158+
speed = nbatch * args.batch_size * hvd.size() / elapsed
159+
logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
160+
epoch, speed, elapsed)
161+
162+
# Evaluate model accuracy
163+
_, train_acc = metric.get()
164+
name, val_acc = evaluate(model, val_data, context)
165+
if hvd.rank() == 0:
166+
logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
167+
train_acc, name, val_acc)
168+
169+
if hvd.rank() == 0 and epoch == args.epochs - 1:
170+
assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected\
171+
(0.96)" % val_acc

0 commit comments

Comments
 (0)